diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..1fcf1442d --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +github: jackc diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 000000000..637113357 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,54 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: bug +assignees: '' + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +Steps to reproduce the behavior: + +If possible, please provide runnable example such as: + +```go +package main + +import ( + "context" + "log" + "os" + + "github.com/jackc/pgx/v5" +) + +func main() { + conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) + if err != nil { + log.Fatal(err) + } + defer conn.Close(context.Background()) + + // Your code here... +} +``` + +Please run your example with the race detector enabled. For example, `go run -race main.go` or `go test -race`. + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Actual behavior** +A clear and concise description of what actually happened. + +**Version** + - Go: `$ go version` -> [e.g. go version go1.18.3 darwin/amd64] + - PostgreSQL: `$ psql --no-psqlrc --tuples-only -c 'select version()'` -> [e.g. PostgreSQL 14.4 on x86_64-apple-darwin21.5.0, compiled by Apple clang version 13.1.6 (clang-1316.0.21.2.5), 64-bit] + - pgx: `$ grep 'github.com/jackc/pgx/v[0-9]' go.mod` -> [e.g. v4.16.1] + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 000000000..bbcbbe7d6 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: '' +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/.github/ISSUE_TEMPLATE/other-issues.md b/.github/ISSUE_TEMPLATE/other-issues.md new file mode 100644 index 000000000..27862a317 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/other-issues.md @@ -0,0 +1,10 @@ +--- +name: Other issues +about: Any issue that is not a bug or a feature request +title: '' +labels: '' +assignees: '' + +--- + +Please describe the issue in detail. If this is a question about how to use pgx please use discussions instead. diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 000000000..68fe8c6f3 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,156 @@ +name: CI + +on: + push: + branches: [master] + pull_request: + branches: [master] + +jobs: + test: + name: Test + runs-on: ubuntu-22.04 + + strategy: + matrix: + go-version: ["1.24", "1.25"] + pg-version: [13, 14, 15, 16, 17, cockroachdb] + include: + - pg-version: 13 + pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" + pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test" + pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" + pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test" + pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test" + pgx-ssl-password: certpw + pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" + - pg-version: 14 + pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" + pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test" + pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" + pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test" + pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test" + pgx-ssl-password: certpw + pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" + - pg-version: 15 + pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" + pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test" + pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" + pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test" + pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test" + pgx-ssl-password: certpw + pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" + - pg-version: 16 + pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" + pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test" + pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" + pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test" + pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test" + pgx-ssl-password: certpw + pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" + - pg-version: 17 + pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" + pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test" + pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" + pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test" + pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test" + pgx-ssl-password: certpw + pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" + - pg-version: cockroachdb + pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" + + steps: + - name: Check out code into the Go module directory + uses: actions/checkout@v4 + + - name: Set up Go ${{ matrix.go-version }} + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + + - name: Setup database server for testing + run: ci/setup_test.bash + env: + PGVERSION: ${{ matrix.pg-version }} + + # - name: Setup upterm session + # uses: lhotari/action-upterm@v1 + # with: + # ## limits ssh access and adds the ssh public key for the user which triggered the workflow + # limit-access-to-actor: true + # env: + # PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }} + # PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }} + # PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }} + # PGX_TEST_SCRAM_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-scram-password-conn-string }} + # PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }} + # PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }} + # PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }} + # PGX_SSL_PASSWORD: ${{ matrix.pgx-ssl-password }} + # PGX_TEST_TLS_CLIENT_CONN_STRING: ${{ matrix.pgx-test-tls-client-conn-string }} + + - name: Check formatting + run: | + gofmt -l -s -w . + git status + git diff --exit-code + + - name: Test + # parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner. + run: go test -parallel=1 -race ./... + env: + PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }} + PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }} + PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }} + PGX_TEST_SCRAM_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-scram-password-conn-string }} + PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }} + PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }} + # TestConnectTLS fails. However, it succeeds if I connect to the CI server with upterm and run it. Give up on that test for now. + # PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }} + PGX_SSL_PASSWORD: ${{ matrix.pgx-ssl-password }} + PGX_TEST_TLS_CLIENT_CONN_STRING: ${{ matrix.pgx-test-tls-client-conn-string }} + + test-windows: + name: Test Windows + runs-on: windows-latest + strategy: + matrix: + go-version: ["1.24", "1.25"] + + steps: + - name: Setup PostgreSQL + id: postgres + uses: ikalnytskyi/action-setup-postgres@v4 + with: + database: pgx_test + + - name: Check out code into the Go module directory + uses: actions/checkout@v4 + + - name: Set up Go ${{ matrix.go-version }} + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + + - name: Initialize test database + run: | + psql -f testsetup/postgresql_setup.sql pgx_test + env: + PGSERVICE: ${{ steps.postgres.outputs.service-name }} + shell: bash + + - name: Test + # parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner. + run: go test -parallel=1 -race ./... + env: + PGX_TEST_DATABASE: ${{ steps.postgres.outputs.connection-uri }} diff --git a/.gitignore b/.gitignore index 0ff008008..a2ebbe9c6 100644 --- a/.gitignore +++ b/.gitignore @@ -21,5 +21,7 @@ _testmain.go *.exe -conn_config_test.go .envrc +/.testdb + +.DS_Store diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 000000000..ca74c703a --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,21 @@ +# See for configurations: https://golangci-lint.run/usage/configuration/ +version: 2 + +# See: https://golangci-lint.run/usage/formatters/ +formatters: + default: none + enable: + - gofmt # https://pkg.go.dev/cmd/gofmt + - gofumpt # https://github.com/mvdan/gofumpt + + settings: + gofmt: + simplify: true # Simplify code: gofmt with `-s` option. + + gofumpt: + # Module path which contains the source code being formatted. + # Default: "" + module-path: github.com/jackc/pgx/v5 # Should match with module in go.mod + # Choose whether to use the extra rules. + # Default: false + extra-rules: true diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 6d4b3cd2f..000000000 --- a/.travis.yml +++ /dev/null @@ -1,33 +0,0 @@ -language: go - -go: - - 1.x - - tip - -# Derived from https://github.com/lib/pq/blob/master/.travis.yml -before_install: - - ./travis/before_install.bash - -env: - global: - - PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test - matrix: - - CRATEVERSION=2.1 - - PGVERSION=10 - - PGVERSION=9.6 - - PGVERSION=9.5 - - PGVERSION=9.4 - - PGVERSION=9.3 - -before_script: - - ./travis/before_script.bash - -install: - - ./travis/install.bash - -script: - - ./travis/script.bash - -matrix: - allow_failures: - - go: tip diff --git a/CHANGELOG.md b/CHANGELOG.md index 720ad4e60..6c9c99b5e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,310 +1,473 @@ -# Unreleased - -## Features - -* Support sslkey, sslcert, and sslrootcert URI params (Sean Chittenden) -* Allow any scheme in ParseURI (for convenience with cockroachdb) (Sean Chittenden) - -## Fixes - -* Fix Rows.Values returning same value for multiple columns of same complex type - -# 3.1.0 (January 15, 2018) - -## Features - -* Add QueryEx, QueryRowEx, ExecEx, and RollbackEx to Tx -* Add more ColumnType support (Timothée Peignier) -* Add UUIDArray type (Kelsey Francis) -* Add zap log adapter (Kelsey Francis) -* Add CreateReplicationSlotEx that consistent_point and snapshot_name (Mark Fletcher) -* Add BeginBatch to Tx (Gaspard Douady) -* Support CrateDB (Felix Geisendörfer) -* Allow use of logrus logger with fields configured (André Bierlein) -* Add array of enum support -* Add support for bit type -* Handle timeout parameters (Timothée Peignier) -* Allow overriding connection info (James Lawrence) -* Add support for bpchar type (Iurii Krasnoshchok) -* Add ConnConfig.PreferSimpleProtocol +# 5.7.6 (September 8, 2025) + +* Use ParseConfigError in pgx.ParseConfig and pgxpool.ParseConfig (Yurasov Ilia) +* Add PrepareConn hook to pgxpool (Jonathan Hall) +* Reduce allocations in QueryContext (Dominique Lefevre) +* Add MarshalJSON and UnmarshalJSON for pgtype.Uint32 (Panos Koutsovasilis) +* Configure ping behavior on pgxpool with ShouldPing (Christian Kiely) +* zeronull int types implement Int64Valuer and Int64Scanner (Li Zeghong) +* Fix panic when receiving terminate connection message during CopyFrom (Michal Drausowski) +* Fix statement cache not being invalidated on error during batch (Muhammadali Nazarov) + +# 5.7.5 (May 17, 2025) + +* Support sslnegotiation connection option (divyam234) +* Update golang.org/x/crypto to v0.37.0. This placates security scanners that were unable to see that pgx did not use the behavior affected by https://pkg.go.dev/vuln/GO-2025-3487. +* TraceLog now logs Acquire and Release at the debug level (dave sinclair) +* Add support for PGTZ environment variable +* Add support for PGOPTIONS environment variable +* Unpin memory used by Rows quicker +* Remove PlanScan memoization. This resolves a rare issue where scanning could be broken for one type by first scanning another. The problem was in the memoization system and benchmarking revealed that memoization was not providing any meaningful benefit. + +# 5.7.4 (March 24, 2025) + +* Fix / revert change to scanning JSON `null` (Felix Röhrich) + +# 5.7.3 (March 21, 2025) + +* Expose EmptyAcquireWaitTime in pgxpool.Stat (vamshiaruru32) +* Improve SQL sanitizer performance (ninedraft) +* Fix Scan confusion with json(b), sql.Scanner, and automatic dereferencing (moukoublen, felix-roehrich) +* Fix Values() for xml type always returning nil instead of []byte +* Add ability to send Flush message in pipeline mode (zenkovev) +* Fix pgtype.Timestamp's JSON behavior to match PostgreSQL (pconstantinou) +* Better error messages when scanning structs (logicbomb) +* Fix handling of error on batch write (bonnefoa) +* Match libpq's connection fallback behavior more closely (felix-roehrich) +* Add MinIdleConns to pgxpool (djahandarie) + +# 5.7.2 (December 21, 2024) + +* Fix prepared statement already exists on batch prepare failure +* Add commit query to tx options (Lucas Hild) +* Fix pgtype.Timestamp json unmarshal (Shean de Montigny-Desautels) +* Add message body size limits in frontend and backend (zene) +* Add xid8 type +* Ensure planning encodes and scans cannot infinitely recurse +* Implement pgtype.UUID.String() (Konstantin Grachev) +* Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs functions (Alexander Rumyantsev) +* Update golang.org/x/crypto +* Fix json(b) columns prefer sql.Scanner interface like database/sql (Ludovico Russo) + +# 5.7.1 (September 10, 2024) + +* Fix data race in tracelog.TraceLog +* Update puddle to v2.2.2. This removes the import of nanotime via linkname. +* Update golang.org/x/crypto and golang.org/x/text + +# 5.7.0 (September 7, 2024) + +* Add support for sslrootcert=system (Yann Soubeyrand) +* Add LoadTypes to load multiple types in a single SQL query (Nick Farrell) +* Add XMLCodec supports encoding + scanning XML column type like json (nickcruess-soda) +* Add MultiTrace (Stepan Rabotkin) +* Add TraceLogConfig with customizable TimeKey (stringintech) +* pgx.ErrNoRows wraps sql.ErrNoRows to aid in database/sql compatibility with native pgx functions (merlin) +* Support scanning binary formatted uint32 into string / TextScanner (jennifersp) +* Fix interval encoding to allow 0s and avoid extra spaces (Carlos Pérez-Aradros Herce) +* Update pgservicefile - fixes panic when parsing invalid file +* Better error message when reading past end of batch +* Don't print url when url.Parse returns an error (Kevin Biju) +* Fix snake case name normalization collision in RowToStructByName with db tag (nolandseigler) +* Fix: Scan and encode types with underlying types of arrays + +# 5.6.0 (May 25, 2024) + +* Add StrictNamedArgs (Tomas Zahradnicek) +* Add support for macaddr8 type (Carlos Pérez-Aradros Herce) +* Add SeverityUnlocalized field to PgError / Notice +* Performance optimization of RowToStructByPos/Name (Zach Olstein) +* Allow customizing context canceled behavior for pgconn +* Add ScanLocation to pgtype.Timestamp[tz]Codec +* Add custom data to pgconn.PgConn +* Fix ResultReader.Read() to handle nil values +* Do not encode interval microseconds when they are 0 (Carlos Pérez-Aradros Herce) +* pgconn.SafeToRetry checks for wrapped errors (tjasko) +* Failed connection attempts include all errors +* Optimize LargeObject.Read (Mitar) +* Add tracing for connection acquire and release from pool (ngavinsir) +* Fix encode driver.Valuer not called when nil +* Add support for custom JSON marshal and unmarshal (Mitar) +* Use Go default keepalive for TCP connections (Hans-Joachim Kliemeck) + +# 5.5.5 (March 9, 2024) + +Use spaces instead of parentheses for SQL sanitization. + +This still solves the problem of negative numbers creating a line comment, but this avoids breaking edge cases such as +`set foo to $1` where the substitution is taking place in a location where an arbitrary expression is not allowed. + +# 5.5.4 (March 4, 2024) + +Fix CVE-2024-27304 + +SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. An integer +overflow in the calculated message size can cause the one large message to be sent as multiple messages under the +attacker's control. + +Thanks to Paul Gerste for reporting this issue. + +* Fix behavior of CollectRows to return empty slice if Rows are empty (Felix) +* Fix simple protocol encoding of json.RawMessage +* Fix *Pipeline.getResults should close pipeline on error +* Fix panic in TryFindUnderlyingTypeScanPlan (David Kurman) +* Fix deallocation of invalidated cached statements in a transaction +* Handle invalid sslkey file +* Fix scan float4 into sql.Scanner +* Fix pgtype.Bits not making copy of data from read buffer. This would cause the data to be corrupted by future reads. + +# 5.5.3 (February 3, 2024) + +* Fix: prepared statement already exists +* Improve CopyFrom auto-conversion of text-ish values +* Add ltree type support (Florent Viel) +* Make some properties of Batch and QueuedQuery public (Pavlo Golub) +* Add AppendRows function (Edoardo Spadolini) +* Optimize convert UUID [16]byte to string (Kirill Malikov) +* Fix: LargeObject Read and Write of more than ~1GB at a time (Mitar) + +# 5.5.2 (January 13, 2024) + +* Allow NamedArgs to start with underscore +* pgproto3: Maximum message body length support (jeremy.spriet) +* Upgrade golang.org/x/crypto to v0.17.0 +* Add snake_case support to RowToStructByName (Tikhon Fedulov) +* Fix: update description cache after exec prepare (James Hartig) +* Fix: pipeline checks if it is closed (James Hartig and Ryan Fowler) +* Fix: normalize timeout / context errors during TLS startup (Samuel Stauffer) +* Add OnPgError for easier centralized error handling (James Hartig) + +# 5.5.1 (December 9, 2023) + +* Add CopyFromFunc helper function. (robford) +* Add PgConn.Deallocate method that uses PostgreSQL protocol Close message. +* pgx uses new PgConn.Deallocate method. This allows deallocating statements to work in a failed transaction. This fixes a case where the prepared statement map could become invalid. +* Fix: Prefer driver.Valuer over json.Marshaler for json fields. (Jacopo) +* Fix: simple protocol SQL sanitizer previously panicked if an invalid $0 placeholder was used. This now returns an error instead. (maksymnevajdev) +* Add pgtype.Numeric.ScanScientific (Eshton Robateau) + +# 5.5.0 (November 4, 2023) + +* Add CollectExactlyOneRow. (Julien GOTTELAND) +* Add OpenDBFromPool to create *database/sql.DB from *pgxpool.Pool. (Lev Zakharov) +* Prepare can automatically choose statement name based on sql. This makes it easier to explicitly manage prepared statements. +* Statement cache now uses deterministic, stable statement names. +* database/sql prepared statement names are deterministically generated. +* Fix: SendBatch wasn't respecting context cancellation. +* Fix: Timeout error from pipeline is now normalized. +* Fix: database/sql encoding json.RawMessage to []byte. +* CancelRequest: Wait for the cancel request to be acknowledged by the server. This should improve PgBouncer compatibility. (Anton Levakin) +* stdlib: Use Ping instead of CheckConn in ResetSession +* Add json.Marshaler and json.Unmarshaler for Float4, Float8 (Kirill Mironov) + +# 5.4.3 (August 5, 2023) + +* Fix: QCharArrayOID was defined with the wrong OID (Christoph Engelbert) +* Fix: connect_timeout for sslmode=allow|prefer (smaher-edb) +* Fix: pgxpool: background health check cannot overflow pool +* Fix: Check for nil in defer when sending batch (recover properly from panic) +* Fix: json scan of non-string pointer to pointer +* Fix: zeronull.Timestamptz should use pgtype.Timestamptz +* Fix: NewConnsCount was not correctly counting connections created by Acquire directly. (James Hartig) +* RowTo(AddrOf)StructByPos ignores fields with "-" db tag +* Optimization: improve text format numeric parsing (horpto) + +# 5.4.2 (July 11, 2023) + +* Fix: RowScanner errors are fatal to Rows +* Fix: Enable failover efforts when pg_hba.conf disallows non-ssl connections (Brandon Kauffman) +* Hstore text codec internal improvements (Evan Jones) +* Fix: Stop timers for background reader when not in use. Fixes memory leak when closing connections (Adrian-Stefan Mares) +* Fix: Stop background reader as soon as possible. +* Add PgConn.SyncConn(). This combined with the above fix makes it safe to directly use the underlying net.Conn. + +# 5.4.1 (June 18, 2023) + +* Fix: concurrency bug with pgtypeDefaultMap and simple protocol (Lev Zakharov) +* Add TxOptions.BeginQuery to allow overriding the default BEGIN query + +# 5.4.0 (June 14, 2023) + +* Replace platform specific syscalls for non-blocking IO with more traditional goroutines and deadlines. This returns to the v4 approach with some additional improvements and fixes. This restores the ability to use a pgx.Conn over an ssh.Conn as well as other non-TCP or Unix socket connections. In addition, it is a significantly simpler implementation that is less likely to have cross platform issues. +* Optimization: The default type registrations are now shared among all connections. This saves about 100KB of memory per connection. `pgtype.Type` and `pgtype.Codec` values are now required to be immutable after registration. This was already necessary in most cases but wasn't documented until now. (Lev Zakharov) +* Fix: Ensure pgxpool.Pool.QueryRow.Scan releases connection on panic +* CancelRequest: don't try to read the reply (Nicola Murino) +* Fix: correctly handle bool type aliases (Wichert Akkerman) +* Fix: pgconn.CancelRequest: Fix unix sockets: don't use RemoteAddr() +* Fix: pgx.Conn memory leak with prepared statement caching (Evan Jones) +* Add BeforeClose to pgxpool.Pool (Evan Cordell) +* Fix: various hstore fixes and optimizations (Evan Jones) +* Fix: RowToStructByPos with embedded unexported struct +* Support different bool string representations (Lev Zakharov) +* Fix: error when using BatchResults.Exec on a select that returns an error after some rows. +* Fix: pipelineBatchResults.Exec() not returning error from ResultReader +* Fix: pipeline batch results not closing pipeline when error occurs while reading directly from results instead of using + a callback. +* Fix: scanning a table type into a struct +* Fix: scan array of record to pointer to slice of struct +* Fix: handle null for json (Cemre Mengu) +* Batch Query callback is called even when there is an error +* Add RowTo(AddrOf)StructByNameLax (Audi P. Risa P) + +# 5.3.1 (February 27, 2023) + +* Fix: Support v4 and v5 stdlib in same program (Tomáš Procházka) +* Fix: sql.Scanner not being used in certain cases +* Add text format jsonpath support +* Fix: fake non-blocking read adaptive wait time + +# 5.3.0 (February 11, 2023) + +* Fix: json values work with sql.Scanner +* Fixed / improved error messages (Mark Chambers and Yevgeny Pats) +* Fix: support scan into single dimensional arrays +* Fix: MaxConnLifetimeJitter setting actually jitter (Ben Weintraub) +* Fix: driver.Value representation of bytea should be []byte not string +* Fix: better handling of unregistered OIDs +* CopyFrom can use query cache to avoid extra round trip to get OIDs (Alejandro Do Nascimento Mora) +* Fix: encode to json ignoring driver.Valuer +* Support sql.Scanner on renamed base type +* Fix: pgtype.Numeric text encoding of negative numbers (Mark Chambers) +* Fix: connect with multiple hostnames when one can't be resolved +* Upgrade puddle to remove dependency on uber/atomic and fix alignment issue on 32-bit platform +* Fix: scanning json column into **string +* Multiple reductions in memory allocations +* Fake non-blocking read adapts its max wait time +* Improve CopyFrom performance and reduce memory usage +* Fix: encode []any to array +* Fix: LoadType for composite with dropped attributes (Felix Röhrich) +* Support v4 and v5 stdlib in same program +* Fix: text format array decoding with string of "NULL" +* Prefer binary format for arrays + +# 5.2.0 (December 5, 2022) + +* `tracelog.TraceLog` implements the pgx.PrepareTracer interface. (Vitalii Solodilov) +* Optimize creating begin transaction SQL string (Petr Evdokimov and ksco) +* `Conn.LoadType` supports range and multirange types (Vitalii Solodilov) +* Fix scan `uint` and `uint64` `ScanNumeric`. This resolves a PostgreSQL `numeric` being incorrectly scanned into `uint` and `uint64`. -## Fixes +# 5.1.1 (November 17, 2022) -* Fix numeric EncodeBinary bug (Wei Congrui) -* Fix logrus updated package name (Damir Vandic) -* Fix some invalid one round trip execs failing to return non-nil error. (Kelsey Francis) -* Return ErrClosedPool when Acquire() with closed pool (Mike Graf) -* Fix decoding row with same type values -* Always return non-nil \*Rows from Query to fix QueryRow (Kelsey Francis) -* Fix pgtype types that can Set database/sql/driver.driver.Valuer -* Prefix types in namespaces other than pg_catalog or public (Kelsey Francis) -* Fix incomplete selects during batch (Gaspard Douady and Jack Christensen) -* Support nil pointers to value implementing driver.Valuer -* Fix time logging for QueryEx -* Fix ranges with text format where end is unbounded -* Detect erroneous JSON(B) encoding -* Fix missing interval mapping -* ConnPool begin should not retry if ctx is done (Gaspard Douady) -* Fix reading interrupted messages could break connection -* Return error on unknown oid while decoding record instead of panic (Iurii Krasnoshchok) +* Fix simple query sanitizer where query text contains a Unicode replacement character. +* Remove erroneous `name` argument from `DeallocateAll()`. Technically, this is a breaking change, but given that method was only added 5 days ago this change was accepted. (Bodo Kaiser) -## Changes +# 5.1.0 (November 12, 2022) -* Align sslmode "require" more closely to libpq (Johan Brandhorst) +* Update puddle to v2.1.2. This resolves a race condition and a deadlock in pgxpool. +* `QueryRewriter.RewriteQuery` now returns an error. Technically, this is a breaking change for any external implementers, but given the minimal likelihood that there are actually any external implementers this change was accepted. +* Expose `GetSSLPassword` support to pgx. +* Fix encode `ErrorResponse` unknown field handling. This would only affect pgproto3 being used directly as a proxy with a non-PostgreSQL server that included additional error fields. +* Fix date text format encoding with 5 digit years. +* Fix date values passed to a `sql.Scanner` as `string` instead of `time.Time`. +* DateCodec.DecodeValue can return `pgtype.InfinityModifier` instead of `string` for infinite values. This now matches the behavior of the timestamp types. +* Add domain type support to `Conn.LoadType()`. +* Add `RowToStructByName` and `RowToAddrOfStructByName`. (Pavlo Golub) +* Add `Conn.DeallocateAll()` to clear all prepared statements including the statement cache. (Bodo Kaiser) -# 3.0.1 (August 12, 2017) +# 5.0.4 (October 24, 2022) -## Fixes +* Fix: CollectOneRow prefers PostgreSQL error over pgx.ErrorNoRows +* Fix: some reflect Kind checks to first check for nil +* Bump golang.org/x/text dependency to placate snyk +* Fix: RowToStructByPos on structs with multiple anonymous sub-structs (Baptiste Fontaine) +* Fix: Exec checks if tx is closed -* Fix compilation on 32-bit platform -* Fix invalid MarshalJSON of types with status Undefined -* Fix pid logging +# 5.0.3 (October 14, 2022) -# 3.0.0 (July 24, 2017) +* Fix `driver.Valuer` handling edge cases that could cause infinite loop or crash -## Changes +# v5.0.2 (October 8, 2022) -* Pid to PID in accordance with Go naming conventions. -* Conn.Pid changed to accessor method Conn.PID() -* Conn.SecretKey removed -* Remove Conn.TxStatus -* Logger interface reduced to single Log method. -* Replace BeginIso with BeginEx. BeginEx adds support for read/write mode and deferrable mode. -* Transaction isolation level constants are now typed strings instead of bare strings. -* Conn.WaitForNotification now takes context.Context instead of time.Duration for cancellation support. -* Conn.WaitForNotification no longer automatically pings internally every 15 seconds. -* ReplicationConn.WaitForReplicationMessage now takes context.Context instead of time.Duration for cancellation support. -* Reject scanning binary format values into a string (e.g. binary encoded timestamptz to string). See https://github.com/jackc/pgx/issues/219 and https://github.com/jackc/pgx/issues/228 -* No longer can read raw bytes of any value into a []byte. Use pgtype.GenericBinary if this functionality is needed. -* Remove CopyTo (functionality is now in CopyFrom) -* OID constants moved from pgx to pgtype package -* Replaced Scanner, Encoder, and PgxScanner interfaces with pgtype system -* Removed ValueReader -* ConnPool.Close no longer waits for all acquired connections to be released. Instead, it immediately closes all available connections, and closes acquired connections when they are released in the same manner as ConnPool.Reset. -* Removed Rows.Fatal(error) -* Removed Rows.AfterClose() -* Removed Rows.Conn() -* Removed Tx.AfterClose() -* Removed Tx.Conn() -* Use Go casing convention for OID, UUID, JSON(B), ACLItem, CID, TID, XID, and CIDR -* Replaced stdlib.OpenFromConnPool with DriverConfig system +* Fix date encoding in text format to always use 2 digits for month and day +* Prefer driver.Valuer over wrap plans when encoding +* Fix scan to pointer to pointer to renamed type +* Allow scanning NULL even if PG and Go types are incompatible -## Features +# v5.0.1 (September 24, 2022) -* Entirely revamped pluggable type system that supports approximately 60 PostgreSQL types. -* Types support database/sql interfaces and therefore can be used with other drivers -* Added context methods supporting cancellation where appropriate -* Added simple query protocol support -* Added single round-trip query mode -* Added batch query operations -* Added OnNotice -* github.com/pkg/errors used where possible for errors -* Added stdlib.DriverConfig which directly allows full configuration of underlying pgx connections without needing to use a pgx.ConnPool -* Added AcquireConn and ReleaseConn to stdlib to allow acquiring a connection from a database/sql connection. +* Fix 32-bit atomic usage +* Add MarshalJSON for Float8 (yogipristiawan) +* Add `[` and `]` to text encoding of `Lseg` +* Fix sqlScannerWrapper NULL handling -# 2.11.0 (June 5, 2017) +# v5.0.0 (September 17, 2022) -## Fixes +## Merged Packages -* Fix race with concurrent execution of stdlib.OpenFromConnPool (Terin Stock) +`github.com/jackc/pgtype`, `github.com/jackc/pgconn`, and `github.com/jackc/pgproto3` are now included in the main +`github.com/jackc/pgx` repository. Previously there was confusion as to where issues should be reported, additional +release work due to releasing multiple packages, and less clear changelogs. -## Features +## pgconn -* .pgpass support (j7b) -* Add missing CopyFrom delegators to Tx and ConnPool (Jack Christensen) -* Add ParseConnectionString (James Lawrence) +`CommandTag` is now an opaque type instead of directly exposing an underlying `[]byte`. -## Performance +The return value `ResultReader.Values()` is no longer safe to retain a reference to after a subsequent call to `NextRow()` or `Close()`. -* Optimize HStore encoding (René Kroon) +`Trace()` method adds low level message tracing similar to the `PQtrace` function in `libpq`. -# 2.10.0 (March 17, 2017) +pgconn now uses non-blocking IO. This is a significant internal restructuring, but it should not cause any visible changes on its own. However, it is important in implementing other new features. -## Fixes +`CheckConn()` checks a connection's liveness by doing a non-blocking read. This can be used to detect database restarts or network interruptions without executing a query or a ping. -* database/sql driver created through stdlib.OpenFromConnPool closes connections when requested by database/sql rather than release to underlying connection pool. +pgconn now supports pipeline mode. -# 2.11.0 (June 5, 2017) +`*PgConn.ReceiveResults` removed. Use pipeline mode instead. -## Fixes +`Timeout()` no longer considers `context.Canceled` as a timeout error. `context.DeadlineExceeded` still is considered a timeout error. -* Fix race with concurrent execution of stdlib.OpenFromConnPool (Terin Stock) +## pgxpool -## Features +`Connect` and `ConnectConfig` have been renamed to `New` and `NewWithConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect. -* .pgpass support (j7b) -* Add missing CopyFrom delegators to Tx and ConnPool (Jack Christensen) -* Add ParseConnectionString (James Lawrence) +## pgtype -## Performance +The `pgtype` package has been significantly changed. -* Optimize HStore encoding (René Kroon) +### NULL Representation -# 2.10.0 (March 17, 2017) +Previously, types had a `Status` field that could be `Undefined`, `Null`, or `Present`. This has been changed to a +`Valid` `bool` field to harmonize with how `database/sql` represents `NULL` and to make the zero value useable. -## Fixes +Previously, a type that implemented `driver.Valuer` would have the `Value` method called even on a nil pointer. All nils +whether typed or untyped now represent `NULL`. -* Oid underlying type changed to uint32, previously it was incorrectly int32 (Manni Wood) -* Explicitly close checked-in connections on ConnPool.Reset, previously they were closed by GC +### Codec and Value Split -## Features +Previously, the type system combined decoding and encoding values with the value types. e.g. Type `Int8` both handled +encoding and decoding the PostgreSQL representation and acted as a value object. This caused some difficulties when +there was not an exact 1 to 1 relationship between the Go types and the PostgreSQL types For example, scanning a +PostgreSQL binary `numeric` into a Go `float64` was awkward (see https://github.com/jackc/pgtype/issues/147). This +concepts have been separated. A `Codec` only has responsibility for encoding and decoding values. Value types are +generally defined by implementing an interface that a particular `Codec` understands (e.g. `PointScanner` and +`PointValuer` for the PostgreSQL `point` type). -* Add xid type support (Manni Wood) -* Add cid type support (Manni Wood) -* Add tid type support (Manni Wood) -* Add "char" type support (Manni Wood) -* Add NullOid type (Manni Wood) -* Add json/jsonb binary support to allow use with CopyTo -* Add named error ErrAcquireTimeout (Alexander Staubo) -* Add logical replication decoding (Kris Wehner) -* Add PgxScanner interface to allow types to simultaneously support database/sql and pgx (Jack Christensen) -* Add CopyFrom with schema support (Jack Christensen) +### Array Types -## Compatibility +All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This also +means that less common array types such as `point[]` are now supported. `Array[T]` supports PostgreSQL multi-dimensional +arrays. -* jsonb now defaults to binary format. This means passing a []byte to a jsonb column will no longer work. -* CopyTo is now deprecated but will continue to work. +### Composite Types -# 2.9.0 (August 26, 2016) +Composite types must be registered before use. `CompositeFields` may still be used to construct and destruct composite +values, but any type may now implement `CompositeIndexGetter` and `CompositeIndexScanner` to be used as a composite. -## Fixes +### Range Types -* Fix *ConnPool.Deallocate() not deleting prepared statement from map -* Fix stdlib not logging unprepared query SQL (Krzysztof Dryś) -* Fix Rows.Values() with varchar binary format -* Concurrent ConnPool.Acquire calls with Dialer timeouts now timeout in the expected amount of time (Konstantin Dzreev) +Range types are now handled with types `RangeCodec` and `Range[T]`. This allows additional user defined range types to +easily be handled. Multirange types are handled similarly with `MultirangeCodec` and `Multirange[T]`. -## Features +### pgxtype -* Add CopyTo -* Add PrepareEx -* Add basic record to []interface{} decoding -* Encode and decode between all Go and PostgreSQL integer types with bounds checking -* Decode inet/cidr to net.IP -* Encode/decode [][]byte to/from bytea[] -* Encode/decode named types whose underlying types are string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64 +`LoadDataType` moved to `*Conn` as `LoadType`. -## Performance +### Bytea -* Substantial reduction in memory allocations +The `Bytea` and `GenericBinary` types have been replaced. Use the following instead: -# 2.8.1 (March 24, 2016) +* `[]byte` - For normal usage directly use `[]byte`. +* `DriverBytes` - Uses driver memory only available until next database method call. Avoids a copy and an allocation. +* `PreallocBytes` - Uses preallocated byte slice to avoid an allocation. +* `UndecodedBytes` - Avoids any decoding. Allows working with raw bytes. -## Features +### Dropped lib/pq Support -* Scan accepts nil argument to ignore a column +`pgtype` previously supported and was tested against [lib/pq](https://github.com/lib/pq). While it will continue to work +in most cases this is no longer supported. -## Fixes +### database/sql Scan -* Fix compilation on 32-bit architecture -* Fix Tx.status not being set on error on Commit -* Fix Listen/Unlisten with special characters +Previously, most `Scan` implementations would convert `[]byte` to `string` automatically to decode a text value. Now +only `string` is handled. This is to allow the possibility of future binary support in `database/sql` mode by +considering `[]byte` to be binary format and `string` text format. This change should have no effect for any use with +`pgx`. The previous behavior was only necessary for `lib/pq` compatibility. -# 2.8.0 (March 18, 2016) +Added `*Map.SQLScanner` to create a `sql.Scanner` for types such as `[]int32` and `Range[T]` that do not implement +`sql.Scanner` directly. -## Fixes +### Number Type Fields Include Bit size -* Fix unrecognized commit failure -* Fix msgReader.rxMsg bug when msgReader already has error -* Go float64 can no longer be encoded to a PostgreSQL float4 -* Fix connection corruption when query with error is closed early +`Int2`, `Int4`, `Int8`, `Float4`, `Float8`, and `Uint32` fields now include bit size. e.g. `Int` is renamed to `Int64`. +This matches the convention set by `database/sql`. In addition, for comparable types like `pgtype.Int8` and +`sql.NullInt64` the structures are identical. This means they can be directly converted one to another. -## Features +### 3rd Party Type Integrations -This release adds multiple extension points helpful when wrapping pgx with -custom application behavior. pgx can now use custom types designed for the -standard database/sql package such as -[github.com/shopspring/decimal](https://github.com/shopspring/decimal). +* Extracted integrations with https://github.com/shopspring/decimal and https://github.com/gofrs/uuid to + https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. This trims + the pgx dependency tree. -* Add *Tx.AfterClose() hook -* Add *Tx.Conn() -* Add *Tx.Status() -* Add *Tx.Err() -* Add *Rows.AfterClose() hook -* Add *Rows.Conn() -* Add *Conn.SetLogger() to allow changing logger -* Add *Conn.SetLogLevel() to allow changing log level -* Add ConnPool.Reset method -* Add support for database/sql.Scanner and database/sql/driver.Valuer interfaces -* Rows.Scan errors now include which argument caused error -* Add Encode() to allow custom Encoders to reuse internal encoding functionality -* Add Decode() to allow customer Decoders to reuse internal decoding functionality -* Add ConnPool.Prepare method -* Add ConnPool.Deallocate method -* Add Scan to uint32 and uint64 (utrack) -* Add encode and decode to []uint16, []uint32, and []uint64 (Max Musatov) - -## Performance - -* []byte skips encoding/decoding - -# 2.7.1 (October 26, 2015) - -* Disable SSL renegotiation - -# 2.7.0 (October 16, 2015) - -* Add RuntimeParams to ConnConfig -* ParseURI extracts RuntimeParams -* ParseDSN extracts RuntimeParams -* ParseEnvLibpq extracts PGAPPNAME -* Prepare is now idempotent -* Rows.Values now supports oid type -* ConnPool.Release automatically unlistens connections (Joseph Glanville) -* Add trace log level -* Add more efficient log leveling -* Retry automatically on ConnPool.Begin (Joseph Glanville) -* Encode from net.IP to inet and cidr -* Generalize encoding pointer to string to any PostgreSQL type -* Add UUID encoding from pointer to string (Joseph Glanville) -* Add null mapping to pointer to pointer (Jonathan Rudenberg) -* Add JSON and JSONB type support (Joseph Glanville) - -# 2.6.0 (September 3, 2015) - -* Add inet and cidr type support -* Add binary decoding to TimestampOid in stdlib driver (Samuel Stauffer) -* Add support for specifying sslmode in connection strings (Rick Snyder) -* Allow ConnPool to have MaxConnections of 1 -* Add basic PGSSLMODE to support to ParseEnvLibpq -* Add fallback TLS config -* Expose specific error for TSL refused -* More error details exposed in PgError -* Support custom dialer (Lewis Marshall) - -# 2.5.0 (April 15, 2015) - -* Fix stdlib nil support (Blaž Hrastnik) -* Support custom Scanner not reading entire value -* Fix empty array scanning (Laurent Debacker) -* Add ParseDSN (deoxxa) -* Add timestamp support to NullTime -* Remove unused text format scanners -* Return error when too many parameters on Prepare -* Add Travis CI integration (Jonathan Rudenberg) -* Large object support (Jonathan Rudenberg) -* Fix reading null byte arrays (Karl Seguin) -* Add timestamptz[] support -* Add timestamp[] support (Karl Seguin) -* Add bool[] support (Karl Seguin) -* Allow writing []byte into text and varchar columns without type conversion (Hari Bhaskaran) -* Fix ConnPool Close panic -* Add Listen / notify example -* Reduce memory allocations (Karl Seguin) - -# 2.4.0 (October 3, 2014) - -* Add per connection oid to name map -* Add Hstore support (Andy Walker) -* Move introductory docs to godoc from readme -* Fix documentation references to TextEncoder and BinaryEncoder -* Add keep-alive to TCP connections (Andy Walker) -* Add support for EmptyQueryResponse / Allow no-op Exec (Andy Walker) -* Allow reading any type into []byte -* WaitForNotification detects lost connections quicker - -# 2.3.0 (September 16, 2014) - -* Truncate logged strings and byte slices -* Extract more error information from PostgreSQL -* Fix data race with Rows and ConnPool +### Other Changes + +* `Bit` and `Varbit` are both replaced by the `Bits` type. +* `CID`, `OID`, `OIDValue`, and `XID` are replaced by the `Uint32` type. +* `Hstore` is now defined as `map[string]*string`. +* `JSON` and `JSONB` types removed. Use `[]byte` or `string` directly. +* `QChar` type removed. Use `rune` or `byte` directly. +* `Inet` and `Cidr` types removed. Use `netip.Addr` and `netip.Prefix` directly. These types are more memory efficient than the previous `net.IPNet`. +* `Macaddr` type removed. Use `net.HardwareAddr` directly. +* Renamed `pgtype.ConnInfo` to `pgtype.Map`. +* Renamed `pgtype.DataType` to `pgtype.Type`. +* Renamed `pgtype.None` to `pgtype.Finite`. +* `RegisterType` now accepts a `*Type` instead of `Type`. +* Assorted array helper methods and types made private. + +## stdlib + +* Removed `AcquireConn` and `ReleaseConn` as that functionality has been built in since Go 1.13. + +## Reduced Memory Usage by Reusing Read Buffers + +Previously, the connection read buffer would allocate large chunks of memory and never reuse them. This allowed +transferring ownership to anything such as scanned values without incurring an additional allocation and memory copy. +However, this came at the cost of overall increased memory allocation size. But worse it was also possible to pin large +chunks of memory by retaining a reference to a small value that originally came directly from the read buffer. Now +ownership remains with the read buffer and anything needing to retain a value must make a copy. + +## Query Execution Modes + +Control over automatic prepared statement caching and simple protocol use are now combined into query execution mode. +See documentation for `QueryExecMode`. + +## QueryRewriter Interface and NamedArgs + +pgx now supports named arguments with the `NamedArgs` type. This is implemented via the new `QueryRewriter` interface which +allows arbitrary rewriting of query SQL and arguments. + +## RowScanner Interface + +The `RowScanner` interface allows a single argument to Rows.Scan to scan the entire row. + +## Rows Result Helpers + +* `CollectRows` and `RowTo*` functions simplify collecting results into a slice. +* `CollectOneRow` collects one row using `RowTo*` functions. +* `ForEachRow` simplifies scanning each row and executing code using the scanned values. `ForEachRow` replaces `QueryFunc`. + +## Tx Helpers + +Rather than every type that implemented `Begin` or `BeginTx` methods also needing to implement `BeginFunc` and +`BeginTxFunc` these methods have been converted to functions that take a db that implements `Begin` or `BeginTx`. + +## Improved Batch Query Ergonomics + +Previously, the code for building a batch went in one place before the call to `SendBatch`, and the code for reading the +results went in one place after the call to `SendBatch`. This could make it difficult to match up the query and the code +to handle the results. Now `Queue` returns a `QueuedQuery` which has methods `Query`, `QueryRow`, and `Exec` which can +be used to register a callback function that will handle the result. Callback functions are called automatically when +`BatchResults.Close` is called. + +## SendBatch Uses Pipeline Mode When Appropriate + +Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1 +for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements +in a single network round trip. So it would only take 2 round trips. + +## Tracing and Logging + +Internal logging support has been replaced with tracing hooks. This allows custom tracing integration with tools like OpenTelemetry. Package tracelog provides an adapter for pgx v4 loggers to act as a tracer. + +All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency +tree. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..c975a9372 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,121 @@ +# Contributing + +## Discuss Significant Changes + +Before you invest a significant amount of time on a change, please create a discussion or issue describing your +proposal. This will help to ensure your proposed change has a reasonable chance of being merged. + +## Avoid Dependencies + +Adding a dependency is a big deal. While on occasion a new dependency may be accepted, the default answer to any change +that adds a dependency is no. + +## Development Environment Setup + +pgx tests naturally require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_DATABASE` +environment variable. The `PGX_TEST_DATABASE` environment variable can either be a URL or key-value pairs. In addition, +the standard `PG*` environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to +simplify environment variable handling. + +### Using an Existing PostgreSQL Cluster + +If you already have a PostgreSQL development server this is the quickest way to start and run the majority of the pgx +test suite. Some tests will be skipped that require server configuration changes (e.g. those testing different +authentication methods). + +Create and setup a test database: + +``` +export PGDATABASE=pgx_test +createdb +psql -c 'create extension hstore;' +psql -c 'create extension ltree;' +psql -c 'create domain uint64 as numeric(20,0);' +``` + +Ensure a `postgres` user exists. This happens by default in normal PostgreSQL installs, but some installation methods +such as Homebrew do not. + +``` +createuser -s postgres +``` + +Ensure your `PGX_TEST_DATABASE` environment variable points to the database you just created and run the tests. + +``` +export PGX_TEST_DATABASE="host=/private/tmp database=pgx_test" +go test ./... +``` + +This will run the vast majority of the tests, but some tests will be skipped (e.g. those testing different connection methods). + +### Creating a New PostgreSQL Cluster Exclusively for Testing + +The following environment variables need to be set both for initial setup and whenever the tests are run. (direnv is +highly recommended). Depending on your platform, you may need to change the host for `PGX_TEST_UNIX_SOCKET_CONN_STRING`. + +``` +export PGPORT=5015 +export PGUSER=postgres +export PGDATABASE=pgx_test +export POSTGRESQL_DATA_DIR=postgresql + +export PGX_TEST_DATABASE="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret" +export PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/private/tmp database=pgx_test" +export PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret" +export PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_scram password=secret database=pgx_test" +export PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret" +export PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_pw password=secret" +export PGX_TEST_TLS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem" +export PGX_SSL_PASSWORD=certpw +export PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem database=pgx_test sslcert=`pwd`/.testdb/pgx_sslcert.crt sslkey=`pwd`/.testdb/pgx_sslcert.key" +``` + +Create a new database cluster. + +``` +initdb --locale=en_US -E UTF-8 --username=postgres .testdb/$POSTGRESQL_DATA_DIR + +echo "listen_addresses = '127.0.0.1'" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf +echo "port = $PGPORT" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf +cat testsetup/postgresql_ssl.conf >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf +cp testsetup/pg_hba.conf .testdb/$POSTGRESQL_DATA_DIR/pg_hba.conf + +cd .testdb + +# Generate CA, server, and encrypted client certificates. +go run ../testsetup/generate_certs.go + +# Copy certificates to server directory and set permissions. +cp ca.pem $POSTGRESQL_DATA_DIR/root.crt +cp localhost.key $POSTGRESQL_DATA_DIR/server.key +chmod 600 $POSTGRESQL_DATA_DIR/server.key +cp localhost.crt $POSTGRESQL_DATA_DIR/server.crt + +cd .. +``` + + +Start the new cluster. This will be necessary whenever you are running pgx tests. + +``` +postgres -D .testdb/$POSTGRESQL_DATA_DIR +``` + +Setup the test database in the new cluster. + +``` +createdb +psql --no-psqlrc -f testsetup/postgresql_setup.sql +``` + +### PgBouncer + +There are tests specific for PgBouncer that will be executed if `PGX_TEST_PGBOUNCER_CONN_STRING` is set. + +### Optional Tests + +pgx supports multiple connection types and means of authentication. These tests are optional. They will only run if the +appropriate environment variables are set. In addition, there may be tests specific to particular PostgreSQL versions, +non-PostgreSQL servers (e.g. CockroachDB), or connection poolers (e.g. PgBouncer). `go test ./... -v | grep SKIP` to see +if any tests are being skipped. diff --git a/LICENSE b/LICENSE index 7dee3daf8..5c486c39a 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2013 Jack Christensen +Copyright (c) 2013-2021 Jack Christensen MIT License @@ -19,4 +19,4 @@ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md index 1acaabff8..3c6f8e6a7 100644 --- a/README.md +++ b/README.md @@ -1,151 +1,191 @@ -[![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://godoc.org/github.com/jackc/pgx) -[![Build Status](https://travis-ci.org/jackc/pgx.svg)](https://travis-ci.org/jackc/pgx) +[![Go Reference](https://pkg.go.dev/badge/github.com/jackc/pgx/v5.svg)](https://pkg.go.dev/github.com/jackc/pgx/v5) +[![Build Status](https://github.com/jackc/pgx/actions/workflows/ci.yml/badge.svg)](https://github.com/jackc/pgx/actions/workflows/ci.yml) # pgx - PostgreSQL Driver and Toolkit -pgx is a pure Go driver and toolkit for PostgreSQL. pgx is different from other drivers such as [pq](http://godoc.org/github.com/lib/pq) because, while it can operate as a database/sql compatible driver, pgx is also usable directly. It offers a native interface similar to database/sql that offers better performance and more features. +pgx is a pure Go driver and toolkit for PostgreSQL. +The pgx driver is a low-level, high performance interface that exposes PostgreSQL-specific features such as `LISTEN` / +`NOTIFY` and `COPY`. It also includes an adapter for the standard `database/sql` interface. + +The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol +and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers, +proxies, load balancers, logical replication clients, etc. + +## Example Usage ```go -var name string -var weight int64 -err := conn.QueryRow("select name, weight from widgets where id=$1", 42).Scan(&name, &weight) -if err != nil { - return err +package main + +import ( + "context" + "fmt" + "os" + + "github.com/jackc/pgx/v5" +) + +func main() { + // urlExample := "postgres://username:password@localhost:5432/database_name" + conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) + if err != nil { + fmt.Fprintf(os.Stderr, "Unable to connect to database: %v\n", err) + os.Exit(1) + } + defer conn.Close(context.Background()) + + var name string + var weight int64 + err = conn.QueryRow(context.Background(), "select name, weight from widgets where id=$1", 42).Scan(&name, &weight) + if err != nil { + fmt.Fprintf(os.Stderr, "QueryRow failed: %v\n", err) + os.Exit(1) + } + + fmt.Println(name, weight) } ``` -## Features +See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-started-with-pgx) for more information. -pgx supports many additional features beyond what is available through database/sql. +## Features -* Support for approximately 60 different PostgreSQL types +* Support for approximately 70 different PostgreSQL types +* Automatic statement preparation and caching * Batch queries * Single-round trip query mode * Full TLS connection control -* Binary format support for custom types (can be much faster) -* Copy protocol support for faster bulk data loads -* Extendable logging support including built-in support for log15 and logrus -* Connection pool with after connect hook to do arbitrary connection setup -* Listen / notify -* PostgreSQL array to Go slice mapping for integers, floats, and strings -* Hstore support -* JSON and JSONB support -* Maps inet and cidr PostgreSQL types to net.IPNet and net.IP +* Binary format support for custom types (allows for much quicker encoding/decoding) +* `COPY` protocol support for faster bulk data loads +* Tracing and logging support +* Connection pool with after-connect hook for arbitrary connection setup +* `LISTEN` / `NOTIFY` +* Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings +* `hstore` support +* `json` and `jsonb` support +* Maps `inet` and `cidr` PostgreSQL types to `netip.Addr` and `netip.Prefix` * Large object support -* NULL mapping to Null* struct or pointer to pointer. -* Supports database/sql.Scanner and database/sql/driver.Valuer interfaces for custom types -* Logical replication connections, including receiving WAL and sending standby status updates -* Notice response handling (this is different than listen / notify) +* NULL mapping to pointer to pointer +* Supports `database/sql.Scanner` and `database/sql/driver.Valuer` interfaces for custom types +* Notice response handling +* Simulated nested transactions with savepoints -## Performance +## Choosing Between the pgx and database/sql Interfaces -pgx performs roughly equivalent to [go-pg](https://github.com/go-pg/pg) and is almost always faster than [pq](http://godoc.org/github.com/lib/pq). When parsing large result sets the percentage difference can be significant (16483 queries/sec for pgx vs. 10106 queries/sec for pq -- 63% faster). +The pgx interface is faster. Many PostgreSQL specific features such as `LISTEN` / `NOTIFY` and `COPY` are not available +through the `database/sql` interface. -In many use cases a significant cause of latency is network round trips between the application and the server. pgx supports query batching to bundle multiple queries into a single round trip. Even in the case of a connection with the lowest possible latency, a local Unix domain socket, batching as few as three queries together can yield an improvement of 57%. With a typical network connection the results can be even more substantial. +The pgx interface is recommended when: -See this [gist](https://gist.github.com/jackc/4996e8648a0c59839bff644f49d6e434) for the underlying benchmark results or checkout [go_db_bench](https://github.com/jackc/go_db_bench) to run tests for yourself. +1. The application only targets PostgreSQL. +2. No other libraries that require `database/sql` are in use. -In addition to the native driver, pgx also includes a number of packages that provide additional functionality. +It is also possible to use the `database/sql` interface and convert a connection to the lower-level pgx interface as needed. -## github.com/jackc/pgx/stdlib +## Testing -database/sql compatibility layer for pgx. pgx can be used as a normal database/sql driver, but at any time the native interface may be acquired for more performance or PostgreSQL specific functionality. +See [CONTRIBUTING.md](./CONTRIBUTING.md) for setup instructions. -## github.com/jackc/pgx/pgtype +## Architecture -Approximately 60 PostgreSQL types are supported including uuid, hstore, json, bytea, numeric, interval, inet, and arrays. These types support database/sql interfaces and are usable even outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver. +See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.com/watch?v=sXMSWhcHCf8) for a description of pgx architecture. + +## Supported Go and PostgreSQL Versions + +pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.24 and higher and PostgreSQL 13 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). + +## Version Policy -## github.com/jackc/pgx/pgproto3 +pgx follows semantic versioning for the documented public API on stable releases. `v5` is the latest stable major version. -pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. This is useful for implementing very low level PostgreSQL tooling. +## PGX Family Libraries -## github.com/jackc/pgx/pgmock +### [github.com/jackc/pglogrepl](https://github.com/jackc/pglogrepl) + +pglogrepl provides functionality to act as a client for PostgreSQL logical replication. + +### [github.com/jackc/pgmock](https://github.com/jackc/pgmock) pgmock offers the ability to create a server that mocks the PostgreSQL wire protocol. This is used internally to test pgx by purposely inducing unusual errors. pgproto3 and pgmock together provide most of the foundational tooling required to implement a PostgreSQL proxy or MitM (such as for a custom connection pooler). -## Documentation +### [github.com/jackc/tern](https://github.com/jackc/tern) -pgx includes extensive documentation in the godoc format. It is viewable online at [godoc.org](https://godoc.org/github.com/jackc/pgx). +tern is a stand-alone SQL migration system. -## Testing +### [github.com/jackc/pgerrcode](https://github.com/jackc/pgerrcode) -pgx supports multiple connection and authentication types. Setting up a test -environment that can test all of them can be cumbersome. In particular, -Windows cannot test Unix domain socket connections. Because of this pgx will -skip tests for connection types that are not configured. +pgerrcode contains constants for the PostgreSQL error codes. -### Normal Test Environment +## Adapters for 3rd Party Types -To setup the normal test environment, first install these dependencies: +* [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid) +* [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal) +* [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos)) +* [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid) - go get github.com/cockroachdb/apd - go get github.com/hashicorp/go-version - go get github.com/jackc/fake - go get github.com/lib/pq - go get github.com/pkg/errors - go get github.com/satori/go.uuid - go get github.com/shopspring/decimal - go get github.com/sirupsen/logrus - go get go.uber.org/zap - go get gopkg.in/inconshreveable/log15.v2 -Then run the following SQL: +## Adapters for 3rd Party Tracers - create user pgx_md5 password 'secret'; - create user " tricky, ' } "" \ test user " password 'secret'; - create database pgx_test; - create user pgx_replication with replication password 'secret'; +* [github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer) +* [github.com/exaring/otelpgx](https://github.com/exaring/otelpgx) -Connect to database pgx_test and run: +## Adapters for 3rd Party Loggers - create extension hstore; +These adapters can be used with the tracelog package. -Next open conn_config_test.go.example and make a copy without the -.example. If your PostgreSQL server is accepting connections on 127.0.0.1, -then you are done. +* [github.com/jackc/pgx-go-kit-log](https://github.com/jackc/pgx-go-kit-log) +* [github.com/jackc/pgx-log15](https://github.com/jackc/pgx-log15) +* [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus) +* [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap) +* [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog) +* [github.com/mcosta74/pgx-slog](https://github.com/mcosta74/pgx-slog) +* [github.com/kataras/pgx-golog](https://github.com/kataras/pgx-golog) -### Connection and Authentication Test Environment +## 3rd Party Libraries with PGX Support -Complete the normal test environment setup and also do the following. +### [github.com/pashagolub/pgxmock](https://github.com/pashagolub/pgxmock) -Run the following SQL: +pgxmock is a mock library implementing pgx interfaces. +pgxmock has one and only purpose - to simulate pgx behavior in tests, without needing a real database connection. - create user pgx_none; - create user pgx_pw password 'secret'; +### [github.com/georgysavva/scany](https://github.com/georgysavva/scany) -Add the following to your pg_hba.conf: +Library for scanning data from a database into Go structs and more. -If you are developing on Unix with domain socket connections: +### [github.com/vingarcia/ksql](https://github.com/vingarcia/ksql) - local pgx_test pgx_none trust - local pgx_test pgx_pw password - local pgx_test pgx_md5 md5 +A carefully designed SQL client for making using SQL easier, +more productive, and less error-prone on Golang. -If you are developing on Windows with TCP connections: +### [github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5) - host pgx_test pgx_none 127.0.0.1/32 trust - host pgx_test pgx_pw 127.0.0.1/32 password - host pgx_test pgx_md5 127.0.0.1/32 md5 +Adds GSSAPI / Kerberos authentication support. -### Replication Test Environment +### [github.com/wcamarao/pmx](https://github.com/wcamarao/pmx) -Add a replication user: +Explicit data mapping and scanning library for Go structs and slices. - create user pgx_replication with replication password 'secret'; +### [github.com/stephenafamo/scan](https://github.com/stephenafamo/scan) -Add a replication line to your pg_hba.conf: +Type safe and flexible package for scanning database data into Go types. +Supports, structs, maps, slices and custom mapping functions. - host replication pgx_replication 127.0.0.1/32 md5 +### [github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx) -Change the following settings in your postgresql.conf: +Code first migration library for native pgx (no database/sql abstraction). - wal_level=logical - max_wal_senders=5 - max_replication_slots=5 +### [github.com/amirsalarsafaei/sqlc-pgx-monitoring](https://github.com/amirsalarsafaei/sqlc-pgx-monitoring) -Set `replicationConnConfig` appropriately in `conn_config_test.go`. +A database monitoring/metrics library for pgx and sqlc. Trace, log and monitor your sqlc query performance using OpenTelemetry. -## Version Policy +### [https://github.com/nikolayk812/pgx-outbox](https://github.com/nikolayk812/pgx-outbox) + +Simple Golang implementation for transactional outbox pattern for PostgreSQL using jackc/pgx driver. + +### [https://github.com/Arlandaren/pgxWrappy](https://github.com/Arlandaren/pgxWrappy) + +Simplifies working with the pgx library, providing convenient scanning of nested structures. + +### [https://github.com/KoNekoD/pgx-colon-query-rewriter](https://github.com/KoNekoD/pgx-colon-query-rewriter) -pgx follows semantic versioning for the documented public API on stable releases. Branch `v3` is the latest stable release. `master` can contain new features or behavior that will change or be removed before being merged to the stable `v3` branch (in practice, this occurs very rarely). `v2` is the previous stable release. +Implementation of the pgx query rewriter to use ':' instead of '@' in named query parameters. diff --git a/Rakefile b/Rakefile new file mode 100644 index 000000000..3e3aa5030 --- /dev/null +++ b/Rakefile @@ -0,0 +1,18 @@ +require "erb" + +rule '.go' => '.go.erb' do |task| + erb = ERB.new(File.read(task.source)) + File.write(task.name, "// Code generated from #{task.source}. DO NOT EDIT.\n\n" + erb.result(binding)) + sh "goimports", "-w", task.name +end + +generated_code_files = [ + "pgtype/int.go", + "pgtype/int_test.go", + "pgtype/integration_benchmark_test.go", + "pgtype/zeronull/int.go", + "pgtype/zeronull/int_test.go" +] + +desc "Generate code" +task generate: generated_code_files diff --git a/batch.go b/batch.go index 0d7f14cc8..d5e7dc8ec 100644 --- a/batch.go +++ b/batch.go @@ -2,312 +2,506 @@ package pgx import ( "context" + "errors" + "fmt" - "github.com/jackc/pgx/pgproto3" - "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/v5/pgconn" ) -type batchItem struct { - query string - arguments []interface{} - parameterOIDs []pgtype.OID - resultFormatCodes []int16 +// QueuedQuery is a query that has been queued for execution via a Batch. +type QueuedQuery struct { + SQL string + Arguments []any + Fn batchItemFunc + sd *pgconn.StatementDescription +} + +type batchItemFunc func(br BatchResults) error + +// Query sets fn to be called when the response to qq is received. +func (qq *QueuedQuery) Query(fn func(rows Rows) error) { + qq.Fn = func(br BatchResults) error { + rows, _ := br.Query() + defer rows.Close() + + err := fn(rows) + if err != nil { + return err + } + rows.Close() + + return rows.Err() + } +} + +// Query sets fn to be called when the response to qq is received. +func (qq *QueuedQuery) QueryRow(fn func(row Row) error) { + qq.Fn = func(br BatchResults) error { + row := br.QueryRow() + return fn(row) + } +} + +// Exec sets fn to be called when the response to qq is received. +// +// Note: for simple batch insert uses where it is not required to handle +// each potential error individually, it's sufficient to not set any callbacks, +// and just handle the return value of BatchResults.Close. +func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) { + qq.Fn = func(br BatchResults) error { + ct, err := br.Exec() + if err != nil { + return err + } + + return fn(ct) + } } // Batch queries are a way of bundling multiple queries together to avoid -// unnecessary network round trips. +// unnecessary network round trips. A Batch must only be sent once. type Batch struct { - conn *Conn - connPool *ConnPool - items []*batchItem - resultsRead int - pendingCommandComplete bool - ctx context.Context - err error - inTx bool + QueuedQueries []*QueuedQuery } -// BeginBatch returns a *Batch query for c. -func (c *Conn) BeginBatch() *Batch { - return &Batch{conn: c} +// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. The only pgx option +// argument that is supported is QueryRewriter. Queries are executed using the connection's DefaultQueryExecMode. +// +// While query can contain multiple statements if the connection's DefaultQueryExecMode is QueryModeSimple, this should +// be avoided. QueuedQuery.Fn must not be set as it will only be called for the first query. That is, QueuedQuery.Query, +// QueuedQuery.QueryRow, and QueuedQuery.Exec must not be called. In addition, any error messages or tracing that +// include the current query may reference the wrong query. +func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery { + qq := &QueuedQuery{ + SQL: query, + Arguments: arguments, + } + b.QueuedQueries = append(b.QueuedQueries, qq) + return qq } -// BeginBatch returns a *Batch query for tx. Since this *Batch is already part -// of a transaction it will not automatically be wrapped in a transaction. -func (tx *Tx) BeginBatch() *Batch { - return &Batch{conn: tx.conn, inTx: true} +// Len returns number of queries that have been queued so far. +func (b *Batch) Len() int { + return len(b.QueuedQueries) } -// Conn returns the underlying connection that b will or was performed on. -func (b *Batch) Conn() *Conn { - return b.conn +type BatchResults interface { + // Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer + // calling Exec on the QueuedQuery, or just calling Close. + Exec() (pgconn.CommandTag, error) + + // Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer + // calling Query on the QueuedQuery. + Query() (Rows, error) + + // QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow. + // Prefer calling QueryRow on the QueuedQuery. + QueryRow() Row + + // Close closes the batch operation. All unread results are read and any callback functions registered with + // QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an + // error or the batch encounters an error subsequent callback functions will not be called. + // + // For simple batch inserts inside a transaction or similar queries, it's sufficient to not set any callbacks, + // and just handle the return value of Close. + // + // Close must be called before the underlying connection can be used again. Any error that occurred during a batch + // operation may have made it impossible to resyncronize the connection with the server. In this case the underlying + // connection will have been closed. + // + // Close is safe to call multiple times. If it returns an error subsequent calls will return the same error. Callback + // functions will not be rerun. + Close() error } -// Queue queues a query to batch b. parameterOIDs are required if there are -// parameters and query is not the name of a prepared statement. -// resultFormatCodes are required if there is a result. -func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgtype.OID, resultFormatCodes []int16) { - b.items = append(b.items, &batchItem{ - query: query, - arguments: arguments, - parameterOIDs: parameterOIDs, - resultFormatCodes: resultFormatCodes, - }) +type batchResults struct { + ctx context.Context + conn *Conn + mrr *pgconn.MultiResultReader + err error + b *Batch + qqIdx int + closed bool + endTraced bool } -// Send sends all queued queries to the server at once. -// If the batch is created from a conn Object then All queries are wrapped -// in a transaction. The transaction can optionally be configured with -// txOptions. The context is in effect until the Batch is closed. -// -// Warning: Send writes all queued queries before reading any results. This can -// cause a deadlock if an excessive number of queries are queued. It is highly -// advisable to use a timeout context to protect against this possibility. -// Unfortunately, this excessive number can vary based on operating system, -// connection type (TCP or Unix domain socket), and type of query. Unix domain -// sockets seem to be much more susceptible to this issue than TCP connections. -// However, it usually is at least several thousand. -// -// The deadlock occurs when the batched queries to be sent are so large that the -// PostgreSQL server cannot receive it all at once. PostgreSQL received some of -// the queued queries and starts executing them. As PostgreSQL executes the -// queries it sends responses back. pgx will not read any of these responses -// until it has finished sending. Therefore, if all network buffers are full pgx -// will not be able to finish sending the queries and PostgreSQL will not be -// able to finish sending the responses. -// -// See https://github.com/jackc/pgx/issues/374. -func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { - if b.err != nil { - return b.err +// Exec reads the results from the next query in the batch as if the query has been sent with Exec. +func (br *batchResults) Exec() (pgconn.CommandTag, error) { + if br.err != nil { + return pgconn.CommandTag{}, br.err + } + if br.closed { + return pgconn.CommandTag{}, fmt.Errorf("batch already closed") } - b.ctx = ctx + query, arguments, _ := br.nextQueryAndArgs() + + if !br.mrr.NextResult() { + err := br.mrr.Close() + if err == nil { + err = errors.New("no more results in batch") + } + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + Err: err, + }) + } + return pgconn.CommandTag{}, err + } - err := b.conn.waitForPreviousCancelQuery(ctx) + commandTag, err := br.mrr.ResultReader().Close() if err != nil { - return err + br.err = err + br.mrr.Close() + } + + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + CommandTag: commandTag, + Err: br.err, + }) } - if err := b.conn.ensureConnectionReadyForQuery(); err != nil { - return err + return commandTag, br.err +} + +// Query reads the results from the next query in the batch as if the query has been sent with Query. +func (br *batchResults) Query() (Rows, error) { + query, arguments, ok := br.nextQueryAndArgs() + if !ok { + query = "batch query" } - buf := b.conn.wbuf - if !b.inTx { - buf = appendQuery(buf, txOptions.beginSQL()) + if br.err != nil { + return &baseRows{err: br.err, closed: true}, br.err } - err = b.conn.initContext(ctx) - if err != nil { - return err + if br.closed { + alreadyClosedErr := fmt.Errorf("batch already closed") + return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr } - for _, bi := range b.items { - var psName string - var psParameterOIDs []pgtype.OID + rows := br.conn.getRows(br.ctx, query, arguments) + rows.batchTracer = br.conn.batchTracer - if ps, ok := b.conn.preparedStatements[bi.query]; ok { - psName = ps.Name - psParameterOIDs = ps.ParameterOIDs - } else { - psParameterOIDs = bi.parameterOIDs - buf = appendParse(buf, "", bi.query, psParameterOIDs) + if !br.mrr.NextResult() { + rows.err = br.mrr.Close() + if rows.err == nil { + rows.err = errors.New("no more results in batch") } - - var err error - buf, err = appendBind(buf, "", psName, b.conn.ConnInfo, psParameterOIDs, bi.arguments, bi.resultFormatCodes) - if err != nil { - return err + rows.closed = true + + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + Err: rows.err, + }) } - buf = appendDescribe(buf, 'P', "") - buf = appendExecute(buf, "", 0) + return rows, rows.err } - buf = appendSync(buf) - b.conn.pendingReadyForQueryCount++ + rows.resultReader = br.mrr.ResultReader() + return rows, nil +} - if !b.inTx { - buf = appendQuery(buf, "commit") - b.conn.pendingReadyForQueryCount++ - } +// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. +func (br *batchResults) QueryRow() Row { + rows, _ := br.Query() + return (*connRow)(rows.(*baseRows)) +} - n, err := b.conn.conn.Write(buf) - if err != nil { - if fatalWriteErr(n, err) { - b.conn.die(err) +// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to +// resyncronize the connection with the server. In this case the underlying connection will have been closed. +func (br *batchResults) Close() error { + defer func() { + if !br.endTraced { + if br.conn != nil && br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err}) + } + br.endTraced = true } - return err + + invalidateCachesOnBatchResultsError(br.conn, br.b, br.err) + }() + + if br.err != nil { + return br.err } - for !b.inTx { - msg, err := b.conn.rxMsg() - if err != nil { - return err - } + if br.closed { + return nil + } - switch msg := msg.(type) { - case *pgproto3.ReadyForQuery: - return nil - default: - if err := b.conn.processContextFreeMsg(msg); err != nil { - return err + // Read and run fn for all remaining items + for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) { + if br.b.QueuedQueries[br.qqIdx].Fn != nil { + err := br.b.QueuedQueries[br.qqIdx].Fn(br) + if err != nil { + br.err = err } + } else { + br.Exec() } } - return nil -} + br.closed = true -// ExecResults reads the results from the next query in the batch as if the -// query has been sent with Exec. -func (b *Batch) ExecResults() (CommandTag, error) { - if b.err != nil { - return "", b.err + err := br.mrr.Close() + if br.err == nil { + br.err = err } - select { - case <-b.ctx.Done(): - b.die(b.ctx.Err()) - return "", b.ctx.Err() - default: - } + return br.err +} + +func (br *batchResults) earlyError() error { + return br.err +} - if err := b.ensureCommandComplete(); err != nil { - b.die(err) - return "", err +func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) { + if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) { + bi := br.b.QueuedQueries[br.qqIdx] + query = bi.SQL + args = bi.Arguments + ok = true + br.qqIdx++ } + return +} - b.resultsRead++ +type pipelineBatchResults struct { + ctx context.Context + conn *Conn + pipeline *pgconn.Pipeline + lastRows *baseRows + err error + b *Batch + qqIdx int + closed bool + endTraced bool +} - b.pendingCommandComplete = true +// Exec reads the results from the next query in the batch as if the query has been sent with Exec. +func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) { + if br.err != nil { + return pgconn.CommandTag{}, br.err + } + if br.closed { + return pgconn.CommandTag{}, fmt.Errorf("batch already closed") + } + if br.lastRows != nil && br.lastRows.err != nil { + return pgconn.CommandTag{}, br.err + } - for { - msg, err := b.conn.rxMsg() - if err != nil { - return "", err - } + query, arguments, err := br.nextQueryAndArgs() + if err != nil { + return pgconn.CommandTag{}, err + } - switch msg := msg.(type) { - case *pgproto3.CommandComplete: - b.pendingCommandComplete = false - return CommandTag(msg.CommandTag), nil - default: - if err := b.conn.processContextFreeMsg(msg); err != nil { - return "", err - } - } + results, err := br.pipeline.GetResults() + if err != nil { + br.err = err + return pgconn.CommandTag{}, br.err + } + var commandTag pgconn.CommandTag + switch results := results.(type) { + case *pgconn.ResultReader: + commandTag, br.err = results.Close() + default: + return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results) + } + + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + CommandTag: commandTag, + Err: br.err, + }) } -} -// QueryResults reads the results from the next query in the batch as if the -// query has been sent with Query. -func (b *Batch) QueryResults() (*Rows, error) { - rows := b.conn.getRows("batch query", nil) + return commandTag, br.err +} - if b.err != nil { - rows.fatal(b.err) - return rows, b.err +// Query reads the results from the next query in the batch as if the query has been sent with Query. +func (br *pipelineBatchResults) Query() (Rows, error) { + if br.err != nil { + return &baseRows{err: br.err, closed: true}, br.err } - select { - case <-b.ctx.Done(): - b.die(b.ctx.Err()) - rows.fatal(b.err) - return rows, b.ctx.Err() - default: + if br.closed { + alreadyClosedErr := fmt.Errorf("batch already closed") + return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr } - if err := b.ensureCommandComplete(); err != nil { - b.die(err) - rows.fatal(err) - return rows, err + if br.lastRows != nil && br.lastRows.err != nil { + br.err = br.lastRows.err + return &baseRows{err: br.err, closed: true}, br.err } - b.resultsRead++ + query, arguments, err := br.nextQueryAndArgs() + if err != nil { + return &baseRows{err: err, closed: true}, err + } - b.pendingCommandComplete = true + rows := br.conn.getRows(br.ctx, query, arguments) + rows.batchTracer = br.conn.batchTracer + br.lastRows = rows - fieldDescriptions, err := b.conn.readUntilRowDescription() + results, err := br.pipeline.GetResults() if err != nil { - b.die(err) - rows.fatal(b.err) - return rows, err + br.err = err + rows.err = err + rows.closed = true + + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + Err: err, + }) + } + } else { + switch results := results.(type) { + case *pgconn.ResultReader: + rows.resultReader = results + default: + err = fmt.Errorf("unexpected pipeline result: %T", results) + br.err = err + rows.err = err + rows.closed = true + } } - rows.batch = b - rows.fields = fieldDescriptions - return rows, nil + return rows, rows.err } -// QueryRowResults reads the results from the next query in the batch as if the -// query has been sent with QueryRow. -func (b *Batch) QueryRowResults() *Row { - rows, _ := b.QueryResults() - return (*Row)(rows) - +// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. +func (br *pipelineBatchResults) QueryRow() Row { + rows, _ := br.Query() + return (*connRow)(rows.(*baseRows)) } -// Close closes the batch operation. Any error that occured during a batch -// operation may have made it impossible to resyncronize the connection with the -// server. In this case the underlying connection will have been closed. -func (b *Batch) Close() (err error) { - if b.err != nil { - return b.err - } - +// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to +// resyncronize the connection with the server. In this case the underlying connection will have been closed. +func (br *pipelineBatchResults) Close() error { defer func() { - err = b.conn.termContext(err) - if b.conn != nil && b.connPool != nil { - b.connPool.Release(b.conn) + if !br.endTraced { + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err}) + } + br.endTraced = true } + + invalidateCachesOnBatchResultsError(br.conn, br.b, br.err) }() - for i := b.resultsRead; i < len(b.items); i++ { - if _, err = b.ExecResults(); err != nil { - return err + if br.err == nil && br.lastRows != nil && br.lastRows.err != nil { + br.err = br.lastRows.err + } + + if br.closed { + return br.err + } + + // Read and run fn for all remaining items + for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) { + if br.b.QueuedQueries[br.qqIdx].Fn != nil { + err := br.b.QueuedQueries[br.qqIdx].Fn(br) + if err != nil { + br.err = err + } + } else { + br.Exec() } } - if err = b.conn.ensureConnectionReadyForQuery(); err != nil { - return err + br.closed = true + + err := br.pipeline.Close() + if br.err == nil { + br.err = err } - return nil + return br.err } -func (b *Batch) die(err error) { - if b.err != nil { - return +func (br *pipelineBatchResults) earlyError() error { + return br.err +} + +func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, err error) { + if br.b == nil { + return "", nil, errors.New("no reference to batch") + } + + if br.qqIdx >= len(br.b.QueuedQueries) { + return "", nil, errors.New("no more results in batch") } - b.err = err - b.conn.die(err) + bi := br.b.QueuedQueries[br.qqIdx] + br.qqIdx++ + return bi.SQL, bi.Arguments, nil +} + +type emptyBatchResults struct { + conn *Conn + closed bool +} - if b.conn != nil && b.connPool != nil { - b.connPool.Release(b.conn) +// Exec reads the results from the next query in the batch as if the query has been sent with Exec. +func (br *emptyBatchResults) Exec() (pgconn.CommandTag, error) { + if br.closed { + return pgconn.CommandTag{}, fmt.Errorf("batch already closed") } + return pgconn.CommandTag{}, errors.New("no more results in batch") } -func (b *Batch) ensureCommandComplete() error { - for b.pendingCommandComplete { - msg, err := b.conn.rxMsg() - if err != nil { - return err +// Query reads the results from the next query in the batch as if the query has been sent with Query. +func (br *emptyBatchResults) Query() (Rows, error) { + if br.closed { + alreadyClosedErr := fmt.Errorf("batch already closed") + return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr + } + + rows := br.conn.getRows(context.Background(), "", nil) + rows.err = errors.New("no more results in batch") + rows.closed = true + return rows, rows.err +} + +// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. +func (br *emptyBatchResults) QueryRow() Row { + rows, _ := br.Query() + return (*connRow)(rows.(*baseRows)) +} + +// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to +// resyncronize the connection with the server. In this case the underlying connection will have been closed. +func (br *emptyBatchResults) Close() error { + br.closed = true + return nil +} + +// invalidates statement and description caches on batch results error +func invalidateCachesOnBatchResultsError(conn *Conn, b *Batch, err error) { + if err != nil && conn != nil && b != nil { + if sc := conn.statementCache; sc != nil { + for _, bi := range b.QueuedQueries { + sc.Invalidate(bi.SQL) + } } - switch msg := msg.(type) { - case *pgproto3.CommandComplete: - b.pendingCommandComplete = false - return nil - default: - err = b.conn.processContextFreeMsg(msg) - if err != nil { - return err + if sc := conn.descriptionCache; sc != nil { + for _, bi := range b.QueuedQueries { + sc.Invalidate(bi.SQL) } } } - - return nil } diff --git a/batch_test.go b/batch_test.go index 61bbe357a..b4c421e57 100644 --- a/batch_test.go +++ b/batch_test.go @@ -2,182 +2,458 @@ package pgx_test import ( "context" + "errors" + "fmt" + "os" "testing" + "time" - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestConnBeginBatch(t *testing.T) { +func TestConnSendBatch(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - sql := `create temporary table ledger( - id serial primary key, - description varchar not null, - amount int not null -);` - mustExec(t, conn, sql) - - batch := conn.BeginBatch() - batch.Queue("insert into ledger(description, amount) values($1, $2)", - []interface{}{"q1", 1}, - []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, - nil, - ) - batch.Queue("insert into ledger(description, amount) values($1, $2)", - []interface{}{"q2", 2}, - []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, - nil, - ) - batch.Queue("insert into ledger(description, amount) values($1, $2)", - []interface{}{"q3", 3}, - []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, - nil, - ) - batch.Queue("select id, description, amount from ledger order by id", - nil, - nil, - []int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode}, - ) - batch.Queue("select sum(amount) from ledger", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) - - err := batch.Send(context.Background(), nil) - if err != nil { - t.Fatal(err) - } + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test") - ct, err := batch.ExecResults() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 1 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) - } + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null + );` + mustExec(t, conn, sql) - ct, err = batch.ExecResults() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 1 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) - } + batch := &pgx.Batch{} + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1) + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2) + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3) + batch.Queue("select id, description, amount from ledger order by id") + batch.Queue("select id, description, amount from ledger order by id") + batch.Queue("select * from ledger where false") + batch.Queue("select sum(amount) from ledger") - rows, err := batch.QueryResults() - if err != nil { - t.Error(err) - } + br := conn.SendBatch(ctx, batch) - var id int32 - var description string - var amount int32 - if !rows.Next() { - t.Fatal("expected a row to be available") - } - if err := rows.Scan(&id, &description, &amount); err != nil { - t.Fatal(err) - } - if id != 1 { - t.Errorf("id => %v, want %v", id, 1) - } - if description != "q1" { - t.Errorf("description => %v, want %v", description, "q1") - } - if amount != 1 { - t.Errorf("amount => %v, want %v", amount, 1) - } + ct, err := br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } - if !rows.Next() { - t.Fatal("expected a row to be available") - } - if err := rows.Scan(&id, &description, &amount); err != nil { - t.Fatal(err) - } - if id != 2 { - t.Errorf("id => %v, want %v", id, 2) - } - if description != "q2" { - t.Errorf("description => %v, want %v", description, "q2") - } - if amount != 2 { - t.Errorf("amount => %v, want %v", amount, 2) - } + ct, err = br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } - if !rows.Next() { - t.Fatal("expected a row to be available") - } - if err := rows.Scan(&id, &description, &amount); err != nil { - t.Fatal(err) - } - if id != 3 { - t.Errorf("id => %v, want %v", id, 3) - } - if description != "q3" { - t.Errorf("description => %v, want %v", description, "q3") - } - if amount != 3 { - t.Errorf("amount => %v, want %v", amount, 3) - } + ct, err = br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } - if rows.Next() { - t.Fatal("did not expect a row to be available") - } + selectFromLedgerExpectedRows := []struct { + id int32 + description string + amount int32 + }{ + {1, "q1", 1}, + {2, "q2", 2}, + {3, "q3", 3}, + } - if rows.Err() != nil { - t.Fatal(rows.Err()) - } + rows, err := br.Query() + if err != nil { + t.Error(err) + } - err = batch.QueryRowResults().Scan(&amount) - if err != nil { - t.Error(err) - } - if amount != 6 { - t.Errorf("amount => %v, want %v", amount, 6) - } + var id int32 + var description string + var amount int32 + rowCount := 0 - err = batch.Close() - if err != nil { - t.Fatal(err) + for rows.Next() { + if rowCount >= len(selectFromLedgerExpectedRows) { + t.Fatalf("got too many rows: %d", rowCount) + } + + if err := rows.Scan(&id, &description, &amount); err != nil { + t.Fatalf("row %d: %v", rowCount, err) + } + + if id != selectFromLedgerExpectedRows[rowCount].id { + t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) + } + if description != selectFromLedgerExpectedRows[rowCount].description { + t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) + } + if amount != selectFromLedgerExpectedRows[rowCount].amount { + t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) + } + + rowCount++ + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + + rowCount = 0 + rows, _ = br.Query() + _, err = pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error { + if id != selectFromLedgerExpectedRows[rowCount].id { + t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) + } + if description != selectFromLedgerExpectedRows[rowCount].description { + t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) + } + if amount != selectFromLedgerExpectedRows[rowCount].amount { + t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) + } + + rowCount++ + + return nil + }) + if err != nil { + t.Error(err) + } + + err = br.QueryRow().Scan(&id, &description, &amount) + if !errors.Is(err, pgx.ErrNoRows) { + t.Errorf("expected pgx.ErrNoRows but got: %v", err) + } + + err = br.QueryRow().Scan(&amount) + if err != nil { + t.Error(err) + } + if amount != 6 { + t.Errorf("amount => %v, want %v", amount, 6) + } + + err = br.Close() + if err != nil { + t.Fatal(err) + } + }) +} + +func TestConnSendBatchQueuedQuery(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test") + + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null + );` + mustExec(t, conn, sql) + + batch := &pgx.Batch{} + + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1).Exec(func(ct pgconn.CommandTag) error { + assert.EqualValues(t, 1, ct.RowsAffected()) + return nil + }) + + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2).Exec(func(ct pgconn.CommandTag) error { + assert.EqualValues(t, 1, ct.RowsAffected()) + return nil + }) + + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3).Exec(func(ct pgconn.CommandTag) error { + assert.EqualValues(t, 1, ct.RowsAffected()) + return nil + }) + + selectFromLedgerExpectedRows := []struct { + id int32 + description string + amount int32 + }{ + {1, "q1", 1}, + {2, "q2", 2}, + {3, "q3", 3}, + } + + batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error { + rowCount := 0 + var id int32 + var description string + var amount int32 + _, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error { + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id) + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description) + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount) + rowCount++ + + return nil + }) + assert.NoError(t, err) + return nil + }) + + batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error { + rowCount := 0 + var id int32 + var description string + var amount int32 + _, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error { + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id) + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description) + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount) + rowCount++ + + return nil + }) + assert.NoError(t, err) + return nil + }) + + batch.Queue("select * from ledger where false").QueryRow(func(row pgx.Row) error { + err := row.Scan(nil, nil, nil) + assert.ErrorIs(t, err, pgx.ErrNoRows) + return nil + }) + + batch.Queue("select sum(amount) from ledger").QueryRow(func(row pgx.Row) error { + var sumAmount int32 + err := row.Scan(&sumAmount) + assert.NoError(t, err) + assert.EqualValues(t, 6, sumAmount) + return nil + }) + + err := conn.SendBatch(ctx, batch).Close() + assert.NoError(t, err) + }) +} + +func TestConnSendBatchMany(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null + );` + mustExec(t, conn, sql) + + batch := &pgx.Batch{} + + numInserts := 1000 + + for i := 0; i < numInserts; i++ { + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1) + } + batch.Queue("select count(*) from ledger") + + br := conn.SendBatch(ctx, batch) + + for i := 0; i < numInserts; i++ { + ct, err := br.Exec() + assert.NoError(t, err) + assert.EqualValues(t, 1, ct.RowsAffected()) + } + + var actualInserts int + err := br.QueryRow().Scan(&actualInserts) + assert.NoError(t, err) + assert.EqualValues(t, numInserts, actualInserts) + + err = br.Close() + require.NoError(t, err) + }) +} + +// https://github.com/jackc/pgx/issues/1801#issuecomment-2203784178 +func TestConnSendBatchReadResultsWhenNothingQueued(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + batch := &pgx.Batch{} + br := conn.SendBatch(ctx, batch) + commandTag, err := br.Exec() + require.Equal(t, "", commandTag.String()) + require.EqualError(t, err, "no more results in batch") + err = br.Close() + require.NoError(t, err) + }) +} + +func TestConnSendBatchReadMoreResultsThanQueriesSent(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + batch := &pgx.Batch{} + batch.Queue("select 1") + br := conn.SendBatch(ctx, batch) + commandTag, err := br.Exec() + require.Equal(t, "SELECT 1", commandTag.String()) + require.NoError(t, err) + commandTag, err = br.Exec() + require.Equal(t, "", commandTag.String()) + require.EqualError(t, err, "no more results in batch") + err = br.Close() + require.NoError(t, err) + }) +} + +func TestConnSendBatchWithPreparedStatement(t *testing.T) { + t.Parallel() + + modes := []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + pgx.QueryExecModeExec, + // Don't test simple mode with prepared statements. } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - ensureConnValid(t, conn) + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, modes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") + _, err := conn.Prepare(ctx, "ps1", "select n from generate_series(0,$1::int) n") + if err != nil { + t.Fatal(err) + } + + batch := &pgx.Batch{} + + queryCount := 3 + for i := 0; i < queryCount; i++ { + batch.Queue("ps1", 5) + } + + br := conn.SendBatch(ctx, batch) + + for i := 0; i < queryCount; i++ { + rows, err := br.Query() + if err != nil { + t.Fatal(err) + } + + for k := 0; rows.Next(); k++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Fatal(err) + } + if n != k { + t.Fatalf("n => %v, want %v", n, k) + } + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + } + + err = br.Close() + if err != nil { + t.Fatal(err) + } + }) } -func TestConnBeginBatchWithPreparedStatement(t *testing.T) { +func TestConnSendBatchWithQueryRewriter(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + batch := &pgx.Batch{} + batch.Queue("something to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{1}}) + batch.Queue("something else to be replaced", &testQueryRewriter{sql: "select $1::text", args: []any{"hello"}}) + batch.Queue("more to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{3}}) + + br := conn.SendBatch(ctx, batch) + + var n int32 + err := br.QueryRow().Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + + var s string + err = br.QueryRow().Scan(&s) + require.NoError(t, err) + require.Equal(t, "hello", s) + + err = br.QueryRow().Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 3, n) + + err = br.Close() + require.NoError(t, err) + }) +} + +// https://github.com/jackc/pgx/issues/856 +func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 + + conn := mustConnect(t, config) defer closeConn(t, conn) - _, err := conn.Prepare("ps1", "select n from generate_series(0,$1::int) n") + pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") + + _, err = conn.Prepare(ctx, "ps1", "select n from generate_series(0,$1::int) n") if err != nil { t.Fatal(err) } - batch := conn.BeginBatch() + batch := &pgx.Batch{} queryCount := 3 for i := 0; i < queryCount; i++ { - batch.Queue("ps1", - []interface{}{5}, - nil, - []int16{pgx.BinaryFormatCode}, - ) + batch.Queue("ps1", 5) } - err = batch.Send(context.Background(), nil) - if err != nil { - t.Fatal(err) - } + br := conn.SendBatch(ctx, batch) for i := 0; i < queryCount; i++ { - rows, err := batch.QueryResults() + rows, err := br.Query() if err != nil { t.Fatal(err) } @@ -197,7 +473,7 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) { } } - err = batch.Close() + err = br.Close() if err != nil { t.Fatal(err) } @@ -205,499 +481,640 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) { ensureConnValid(t, conn) } -func TestConnBeginBatchContextCancelBeforeExecResults(t *testing.T) { +func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - - sql := `create temporary table ledger( - id serial primary key, - description varchar not null, - amount int not null -);` - mustExec(t, conn, sql) - - batch := conn.BeginBatch() - batch.Queue("insert into ledger(description, amount) values($1, $2)", - []interface{}{"q1", 1}, - []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, - nil, - ) - batch.Queue("select pg_sleep(2)", - nil, - nil, - nil, - ) - - ctx, cancelFn := context.WithCancel(context.Background()) - - err := batch.Send(ctx, nil) - if err != nil { - t.Fatal(err) - } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - cancelFn() + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + batch := &pgx.Batch{} + batch.Queue("select n from generate_series(0,5) n") + batch.Queue("select n from generate_series(0,5) n") - _, err = batch.ExecResults() - if err != context.Canceled { - t.Errorf("err => %v, want %v", err, context.Canceled) - } + br := conn.SendBatch(ctx, batch) - if conn.IsAlive() { - t.Error("conn should be dead, but was alive") - } + rows, err := br.Query() + if err != nil { + t.Error(err) + } + + for i := 0; i < 3; i++ { + if !rows.Next() { + t.Error("expected a row to be available") + } + + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + rows.Close() + + rows, err = br.Query() + if err != nil { + t.Error(err) + } + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + if rows.Err() != nil { + t.Error(rows.Err()) + } + + err = br.Close() + if err != nil { + t.Fatal(err) + } + }) } -func TestConnBeginBatchContextCancelBeforeQueryResults(t *testing.T) { +func TestConnSendBatchQueryError(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - batch := conn.BeginBatch() - batch.Queue("select pg_sleep(2)", - nil, - nil, - nil, - ) - batch.Queue("select pg_sleep(2)", - nil, - nil, - nil, - ) + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + batch := &pgx.Batch{} + batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0") + batch.Queue("select n from generate_series(0,5) n") - ctx, cancelFn := context.WithCancel(context.Background()) + br := conn.SendBatch(ctx, batch) - err := batch.Send(ctx, nil) - if err != nil { - t.Fatal(err) - } + rows, err := br.Query() + if err != nil { + t.Error(err) + } - cancelFn() + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } - _, err = batch.QueryResults() - if err != context.Canceled { - t.Errorf("err => %v, want %v", err, context.Canceled) - } + if pgErr, ok := rows.Err().(*pgconn.PgError); !(ok && pgErr.Code == "22012") { + t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012) + } - if conn.IsAlive() { - t.Error("conn should be dead, but was alive") - } + err = br.Close() + if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") { + t.Errorf("br.Close() => %v, want error code %v", err, 22012) + } + }) } -func TestConnBeginBatchContextCancelBeforeFinish(t *testing.T) { +func TestConnSendBatchQuerySyntaxError(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - batch := conn.BeginBatch() - batch.Queue("select pg_sleep(2)", - nil, - nil, - nil, - ) - batch.Queue("select pg_sleep(2)", - nil, - nil, - nil, - ) + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + batch := &pgx.Batch{} + batch.Queue("select 1 1") - ctx, cancelFn := context.WithCancel(context.Background()) + br := conn.SendBatch(ctx, batch) - err := batch.Send(ctx, nil) - if err != nil { - t.Fatal(err) - } + var n int32 + err := br.QueryRow().Scan(&n) + if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") { + t.Errorf("rows.Err() => %v, want error code %v", err, 42601) + } - cancelFn() + err = br.Close() + if err == nil { + t.Error("Expected error") + } + }) +} - err = batch.Close() - if err != context.Canceled { - t.Errorf("err => %v, want %v", err, context.Canceled) - } +func TestConnSendBatchQueryRowInsert(t *testing.T) { + t.Parallel() - if conn.IsAlive() { - t.Error("conn should be dead, but was alive") - } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null + );` + mustExec(t, conn, sql) + + batch := &pgx.Batch{} + batch.Queue("select 1") + batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) + + br := conn.SendBatch(ctx, batch) + + var value int + err := br.QueryRow().Scan(&value) + if err != nil { + t.Error(err) + } + + ct, err := br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 2 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) + } + + br.Close() + }) } -func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) { +func TestConnSendBatchQueryPartialReadInsert(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - batch := conn.BeginBatch() - batch.Queue("select n from generate_series(0,5) n", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) - batch.Queue("select n from generate_series(0,5) n", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) - - err := batch.Send(context.Background(), nil) - if err != nil { - t.Fatal(err) - } + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null + );` + mustExec(t, conn, sql) - rows, err := batch.QueryResults() - if err != nil { - t.Error(err) - } + batch := &pgx.Batch{} + batch.Queue("select 1 union all select 2 union all select 3") + batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) + + br := conn.SendBatch(ctx, batch) - for i := 0; i < 3; i++ { - if !rows.Next() { - t.Error("expected a row to be available") + rows, err := br.Query() + if err != nil { + t.Error(err) } + rows.Close() - var n int - if err := rows.Scan(&n); err != nil { + ct, err := br.Exec() + if err != nil { t.Error(err) } - if n != i { - t.Errorf("n => %v, want %v", n, i) + if ct.RowsAffected() != 2 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) } - } - rows.Close() + br.Close() + }) +} - rows, err = batch.QueryResults() - if err != nil { - t.Error(err) - } +func TestTxSendBatch(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + sql := `create temporary table ledger1( + id serial primary key, + description varchar not null + );` + mustExec(t, conn, sql) - for i := 0; rows.Next(); i++ { - var n int - if err := rows.Scan(&n); err != nil { + sql = `create temporary table ledger2( + id int primary key, + amount int not null + );` + mustExec(t, conn, sql) + + tx, _ := conn.Begin(ctx) + batch := &pgx.Batch{} + batch.Queue("insert into ledger1(description) values($1) returning id", "q1") + + br := tx.SendBatch(context.Background(), batch) + + var id int + err := br.QueryRow().Scan(&id) + if err != nil { t.Error(err) } - if n != i { - t.Errorf("n => %v, want %v", n, i) + br.Close() + + batch = &pgx.Batch{} + batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2) + batch.Queue("select amount from ledger2 where id = $1", id) + + br = tx.SendBatch(ctx, batch) + + ct, err := br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) } - } - if rows.Err() != nil { - t.Error(rows.Err()) - } + var amount int + err = br.QueryRow().Scan(&amount) + if err != nil { + t.Error(err) + } - err = batch.Close() - if err != nil { - t.Fatal(err) - } + br.Close() + tx.Commit(ctx) - ensureConnValid(t, conn) + var count int + conn.QueryRow(ctx, "select count(1) from ledger1 where id = $1", id).Scan(&count) + if count != 1 { + t.Errorf("count => %v, want %v", count, 1) + } + + err = br.Close() + if err != nil { + t.Fatal(err) + } + }) } -func TestConnBeginBatchQueryError(t *testing.T) { +func TestTxSendBatchRollback(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - batch := conn.BeginBatch() - batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) - batch.Queue("select n from generate_series(0,5) n", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) - - err := batch.Send(context.Background(), nil) - if err != nil { - t.Fatal(err) - } + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + sql := `create temporary table ledger1( + id serial primary key, + description varchar not null + );` + mustExec(t, conn, sql) - rows, err := batch.QueryResults() - if err != nil { - t.Error(err) - } + tx, _ := conn.Begin(ctx) + batch := &pgx.Batch{} + batch.Queue("insert into ledger1(description) values($1) returning id", "q1") + + br := tx.SendBatch(ctx, batch) - for i := 0; rows.Next(); i++ { - var n int - if err := rows.Scan(&n); err != nil { + var id int + err := br.QueryRow().Scan(&id) + if err != nil { t.Error(err) } - if n != i { - t.Errorf("n => %v, want %v", n, i) + br.Close() + tx.Rollback(ctx) + + row := conn.QueryRow(ctx, "select count(1) from ledger1 where id = $1", id) + var count int + row.Scan(&count) + if count != 0 { + t.Errorf("count => %v, want %v", count, 0) } - } + }) +} - if pgErr, ok := rows.Err().(pgx.PgError); !(ok && pgErr.Code == "22012") { - t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012) - } +// https://github.com/jackc/pgx/issues/1578 +func TestSendBatchErrorWhileReadingResultsWithoutCallback(t *testing.T) { + t.Parallel() - err = batch.Close() - if pgErr, ok := err.(pgx.PgError); !(ok && pgErr.Code == "22012") { - t.Errorf("rows.Err() => %v, want error code %v", err, 22012) - } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - if conn.IsAlive() { - t.Error("conn should be dead, but was alive") - } + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + batch := &pgx.Batch{} + batch.Queue("select 4 / $1::int", 0) + + batchResult := conn.SendBatch(ctx, batch) + + _, execErr := batchResult.Exec() + require.Error(t, execErr) + + closeErr := batchResult.Close() + require.Equal(t, execErr, closeErr) + + // Try to use the connection. + _, err := conn.Exec(ctx, "select 1") + require.NoError(t, err) + }) } -func TestConnBeginBatchQuerySyntaxError(t *testing.T) { +func TestSendBatchErrorWhileReadingResultsWithExecWhereSomeRowsAreReturned(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - batch := conn.BeginBatch() - batch.Queue("select 1 1", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + batch := &pgx.Batch{} + batch.Queue("select 4 / n from generate_series(-2, 2) n") - err := batch.Send(context.Background(), nil) - if err != nil { - t.Fatal(err) - } + batchResult := conn.SendBatch(ctx, batch) - var n int32 - err = batch.QueryRowResults().Scan(&n) - if pgErr, ok := err.(pgx.PgError); !(ok && pgErr.Code == "42601") { - t.Errorf("rows.Err() => %v, want error code %v", err, 42601) - } + _, execErr := batchResult.Exec() + require.Error(t, execErr) - err = batch.Close() - if err == nil { - t.Error("Expected error") - } + closeErr := batchResult.Close() + require.Equal(t, execErr, closeErr) - if conn.IsAlive() { - t.Error("conn should be dead, but was alive") - } + // Try to use the connection. + _, err := conn.Exec(ctx, "select 1") + require.NoError(t, err) + }) } -func TestConnBeginBatchQueryRowInsert(t *testing.T) { +func TestConnBeginBatchDeferredError(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - sql := `create temporary table ledger( - id serial primary key, - description varchar not null, - amount int not null -);` - mustExec(t, conn, sql) - - batch := conn.BeginBatch() - batch.Queue("select 1", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) - batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", - []interface{}{"q1", 1}, - []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, - nil, - ) - - err := batch.Send(context.Background(), nil) - if err != nil { - t.Fatal(err) - } + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") - var value int - err = batch.QueryRowResults().Scan(&value) - if err != nil { - t.Error(err) - } + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - ct, err := batch.ExecResults() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 2 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) - } + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) - batch.Close() + batch := &pgx.Batch{} - ensureConnValid(t, conn) + batch.Queue(`update t set n=n+1 where id='b' returning *`) + + br := conn.SendBatch(ctx, batch) + + rows, err := br.Query() + if err != nil { + t.Error(err) + } + + for rows.Next() { + var id string + var n int32 + err = rows.Scan(&id, &n) + if err != nil { + t.Fatal(err) + } + } + + err = br.Close() + if err == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + }) } -func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) { - t.Parallel() +func TestConnSendBatchNoStatementCache(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnect(t, config) defer closeConn(t, conn) - sql := `create temporary table ledger( - id serial primary key, - description varchar not null, - amount int not null -);` - mustExec(t, conn, sql) - - batch := conn.BeginBatch() - batch.Queue("select 1 union all select 2 union all select 3", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) - batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", - []interface{}{"q1", 1}, - []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, - nil, - ) - - err := batch.Send(context.Background(), nil) - if err != nil { - t.Fatal(err) - } + testConnSendBatch(t, ctx, conn, 3) +} - rows, err := batch.QueryResults() - if err != nil { - t.Error(err) - } - rows.Close() +func TestConnSendBatchPrepareStatementCache(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - ct, err := batch.ExecResults() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 2 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) - } + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement + config.StatementCacheCapacity = 32 - batch.Close() + conn := mustConnect(t, config) + defer closeConn(t, conn) - ensureConnValid(t, conn) + testConnSendBatch(t, ctx, conn, 3) } -func TestTxBeginBatch(t *testing.T) { - t.Parallel() +func TestConnSendBatchDescribeStatementCache(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe + config.DescriptionCacheCapacity = 32 - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnect(t, config) defer closeConn(t, conn) - sql := `create temporary table ledger1( - id serial primary key, - description varchar not null -);` - mustExec(t, conn, sql) - - sql = `create temporary table ledger2( - id int primary key, - amount int not null -);` - mustExec(t, conn, sql) - - tx, _ := conn.Begin() - batch := tx.BeginBatch() - batch.Queue("insert into ledger1(description) values($1) returning id", - []interface{}{"q1"}, - []pgtype.OID{pgtype.VarcharOID}, - []int16{pgx.BinaryFormatCode}, - ) - - err := batch.Send(context.Background(), nil) - if err != nil { - t.Fatal(err) - } - var id int - err = batch.QueryRowResults().Scan(&id) - if err != nil { - t.Error(err) - } - batch.Close() - - batch = tx.BeginBatch() - batch.Queue("insert into ledger2(id,amount) values($1, $2)", - []interface{}{id, 2}, - []pgtype.OID{pgtype.Int4OID, pgtype.Int4OID}, - nil, - ) - - batch.Queue("select amount from ledger2 where id = $1", - []interface{}{id}, - []pgtype.OID{pgtype.Int4OID}, - nil, - ) - - err = batch.Send(context.Background(), nil) - if err != nil { - t.Fatal(err) - } - ct, err := batch.ExecResults() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 1 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) - } + testConnSendBatch(t, ctx, conn, 3) +} - var amout int - err = batch.QueryRowResults().Scan(&amout) - if err != nil { - t.Error(err) +func testConnSendBatch(t *testing.T, ctx context.Context, conn *pgx.Conn, queryCount int) { + batch := &pgx.Batch{} + for j := 0; j < queryCount; j++ { + batch.Queue("select n from generate_series(0,5) n") } - batch.Close() - tx.Commit() + br := conn.SendBatch(ctx, batch) - var count int - conn.QueryRow("select count(1) from ledger1 where id = $1", id).Scan(&count) - if count != 1 { - t.Errorf("count => %v, want %v", count, 1) - } + for j := 0; j < queryCount; j++ { + rows, err := br.Query() + require.NoError(t, err) - err = batch.Close() - if err != nil { - t.Fatal(err) + for k := 0; rows.Next(); k++ { + var n int + err := rows.Scan(&n) + require.NoError(t, err) + require.Equal(t, k, n) + } + + require.NoError(t, rows.Err()) } - ensureConnValid(t, conn) + err := br.Close() + require.NoError(t, err) } -func TestTxBeginBatchRollback(t *testing.T) { +func TestSendBatchSimpleProtocol(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol + + conn := mustConnect(t, config) defer closeConn(t, conn) - sql := `create temporary table ledger1( - id serial primary key, - description varchar not null -);` - mustExec(t, conn, sql) - - tx, _ := conn.Begin() - batch := tx.BeginBatch() - batch.Queue("insert into ledger1(description) values($1) returning id", - []interface{}{"q1"}, - []pgtype.OID{pgtype.VarcharOID}, - []int16{pgx.BinaryFormatCode}, - ) - - err := batch.Send(context.Background(), nil) + var batch pgx.Batch + batch.Queue("SELECT 1::int") + batch.Queue("SELECT 2::int; SELECT $1::int", 3) + results := conn.SendBatch(ctx, &batch) + rows, err := results.Query() + assert.NoError(t, err) + assert.True(t, rows.Next()) + values, err := rows.Values() + assert.NoError(t, err) + assert.EqualValues(t, 1, values[0]) + assert.False(t, rows.Next()) + + rows, err = results.Query() + assert.NoError(t, err) + assert.True(t, rows.Next()) + values, err = rows.Values() + assert.NoError(t, err) + assert.EqualValues(t, 2, values[0]) + assert.False(t, rows.Next()) + + rows, err = results.Query() + assert.NoError(t, err) + assert.True(t, rows.Next()) + values, err = rows.Values() + assert.NoError(t, err) + assert.EqualValues(t, 3, values[0]) + assert.False(t, rows.Next()) +} + +// https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887 +func TestConnSendBatchErrorDoesNotLeaveOrphanedPreparedStatement(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test") + + mustExec(t, conn, `create temporary table foo(col1 text primary key);`) + + batch := &pgx.Batch{} + batch.Queue("select col1 from foo") + batch.Queue("select col1 from baz") + err := conn.SendBatch(ctx, batch).Close() + require.EqualError(t, err, `ERROR: relation "baz" does not exist (SQLSTATE 42P01)`) + + mustExec(t, conn, `create temporary table baz(col1 text primary key);`) + + // Since table baz now exists, the batch should succeed. + + batch = &pgx.Batch{} + batch.Queue("select col1 from foo") + batch.Queue("select col1 from baz") + err = conn.SendBatch(ctx, batch).Close() + require.NoError(t, err) + }) +} + +func TestSendBatchStatementTimeout(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + batch := &pgx.Batch{} + batch.Queue("SET statement_timeout='1ms'") + batch.Queue("SELECT pg_sleep(10)") + + br := conn.SendBatch(context.Background(), batch) + // set statement_timeout + _, err := br.Exec() + assert.NoError(t, err) + + // get pg_sleep results + rows, err := br.Query() + assert.NoError(t, err) + + // Consume rows and check error + for rows.Next() { + } + err = rows.Err() + assert.ErrorContains(t, err, "(SQLSTATE 57014)") + rows.Close() + + // The last error should be repeated when closing the batch + err = br.Close() + assert.ErrorContains(t, err, "(SQLSTATE 57014)") + + // Connection should be usable after the statement timeout in pipeline + _, err = conn.Exec(context.Background(), "Select 1") + assert.NoError(t, err) + }) + +} + +func ExampleConn_SendBatch() { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) if err != nil { - t.Fatal(err) + fmt.Printf("Unable to establish connection: %v", err) + return } - var id int - err = batch.QueryRowResults().Scan(&id) + + batch := &pgx.Batch{} + batch.Queue("select 1 + 1").QueryRow(func(row pgx.Row) error { + var n int32 + err := row.Scan(&n) + if err != nil { + return err + } + + fmt.Println(n) + + return err + }) + + batch.Queue("select 1 + 2").QueryRow(func(row pgx.Row) error { + var n int32 + err := row.Scan(&n) + if err != nil { + return err + } + + fmt.Println(n) + + return err + }) + + batch.Queue("select 2 + 3").QueryRow(func(row pgx.Row) error { + var n int32 + err := row.Scan(&n) + if err != nil { + return err + } + + fmt.Println(n) + + return err + }) + + err = conn.SendBatch(ctx, batch).Close() if err != nil { - t.Error(err) - } - batch.Close() - tx.Rollback() - - row := conn.QueryRow("select count(1) from ledger1 where id = $1", id) - var count int - row.Scan(&count) - if count != 0 { - t.Errorf("count => %v, want %v", count, 0) + fmt.Printf("SendBatch error: %v", err) + return } - ensureConnValid(t, conn) + // Output: + // 2 + // 3 + // 5 } diff --git a/bench-tmp_test.go b/bench-tmp_test.go deleted file mode 100644 index a8e3f7dbc..000000000 --- a/bench-tmp_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package pgx_test - -import ( - "testing" -) - -func BenchmarkPgtypeInt4ParseBinary(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - - _, err := conn.Prepare("selectBinary", "select n::int4 from generate_series(1, 100) n") - if err != nil { - b.Fatal(err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - var n int32 - - rows, err := conn.Query("selectBinary") - if err != nil { - b.Fatal(err) - } - - for rows.Next() { - err := rows.Scan(&n) - if err != nil { - b.Fatal(err) - } - } - - if rows.Err() != nil { - b.Fatal(rows.Err()) - } - } -} - -func BenchmarkPgtypeInt4EncodeBinary(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - - _, err := conn.Prepare("encodeBinary", "select $1::int4, $2::int4, $3::int4, $4::int4, $5::int4, $6::int4, $7::int4") - if err != nil { - b.Fatal(err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - rows, err := conn.Query("encodeBinary", int32(i), int32(i), int32(i), int32(i), int32(i), int32(i), int32(i)) - if err != nil { - b.Fatal(err) - } - rows.Close() - } -} diff --git a/bench_test.go b/bench_test.go index 7f82891e3..a4440bc1c 100644 --- a/bench_test.go +++ b/bench_test.go @@ -4,118 +4,170 @@ import ( "bytes" "context" "fmt" + "io" + "net" + "os" + "strconv" "strings" "testing" "time" - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/require" ) -func BenchmarkConnPool(b *testing.B) { - config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5} - pool, err := pgx.NewConnPool(config) - if err != nil { - b.Fatalf("Unable to create connection pool: %v", err) +func BenchmarkConnectClose(b *testing.B) { + for i := 0; i < b.N; i++ { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + b.Fatal(err) + } + + err = conn.Close(context.Background()) + if err != nil { + b.Fatal(err) + } } - defer pool.Close() +} + +func BenchmarkMinimalUnpreparedSelectWithoutStatementCache(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 + + conn := mustConnect(b, config) + defer closeConn(b, conn) + + var n int64 b.ResetTimer() for i := 0; i < b.N; i++ { - var conn *pgx.Conn - if conn, err = pool.Acquire(); err != nil { - b.Fatalf("Unable to acquire connection: %v", err) + err := conn.QueryRow(context.Background(), "select $1::int8", i).Scan(&n) + if err != nil { + b.Fatal(err) + } + + if n != int64(i) { + b.Fatalf("expected %d, got %d", i, n) } - pool.Release(conn) } } -func BenchmarkConnPoolQueryRow(b *testing.B) { - config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5} - pool, err := pgx.NewConnPool(config) - if err != nil { - b.Fatalf("Unable to create connection pool: %v", err) +func BenchmarkMinimalUnpreparedSelectWithStatementCacheModeDescribe(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 32 + + conn := mustConnect(b, config) + defer closeConn(b, conn) + + var n int64 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := conn.QueryRow(context.Background(), "select $1::int8", i).Scan(&n) + if err != nil { + b.Fatal(err) + } + + if n != int64(i) { + b.Fatalf("expected %d, got %d", i, n) + } } - defer pool.Close() +} + +func BenchmarkMinimalUnpreparedSelectWithStatementCacheModePrepare(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement + config.StatementCacheCapacity = 32 + config.DescriptionCacheCapacity = 0 + + conn := mustConnect(b, config) + defer closeConn(b, conn) + + var n int64 b.ResetTimer() for i := 0; i < b.N; i++ { - num := float64(-1) - if err := pool.QueryRow("select random()").Scan(&num); err != nil { + err := conn.QueryRow(context.Background(), "select $1::int8", i).Scan(&n) + if err != nil { b.Fatal(err) } - if num < 0 { - b.Fatalf("expected `select random()` to return between 0 and 1 but it was: %v", num) + if n != int64(i) { + b.Fatalf("expected %d, got %d", i, n) } } } -func BenchmarkPointerPointerWithNullValues(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) +func BenchmarkMinimalPreparedSelect(b *testing.B) { + conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) - _, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', null::text, null::text, null::text, null::date, null::timestamptz") + _, err := conn.Prepare(context.Background(), "ps1", "select $1::int8") if err != nil { b.Fatal(err) } + var n int64 + b.ResetTimer() for i := 0; i < b.N; i++ { - var record struct { - id int32 - userName string - email *string - name *string - sex *string - birthDate *time.Time - lastLoginTime *time.Time - } - - err = conn.QueryRow("selectNulls").Scan( - &record.id, - &record.userName, - &record.email, - &record.name, - &record.sex, - &record.birthDate, - &record.lastLoginTime, - ) + err = conn.QueryRow(context.Background(), "ps1", i).Scan(&n) if err != nil { b.Fatal(err) } - // These checks both ensure that the correct data was returned - // and provide a benchmark of accessing the returned values. - if record.id != 1 { - b.Fatalf("bad value for id: %v", record.id) - } - if record.userName != "johnsmith" { - b.Fatalf("bad value for userName: %v", record.userName) - } - if record.email != nil { - b.Fatalf("bad value for email: %v", record.email) - } - if record.name != nil { - b.Fatalf("bad value for name: %v", record.name) + if n != int64(i) { + b.Fatalf("expected %d, got %d", i, n) } - if record.sex != nil { - b.Fatalf("bad value for sex: %v", record.sex) + } +} + +func BenchmarkMinimalPgConnPreparedSelect(b *testing.B) { + conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) + defer closeConn(b, conn) + + pgConn := conn.PgConn() + + _, err := pgConn.Prepare(context.Background(), "ps1", "select $1::int8", nil) + if err != nil { + b.Fatal(err) + } + + encodedBytes := make([]byte, 8) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + + rr := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{encodedBytes}, []int16{1}, []int16{1}) + if err != nil { + b.Fatal(err) } - if record.birthDate != nil { - b.Fatalf("bad value for birthDate: %v", record.birthDate) + + for rr.NextRow() { + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[0], encodedBytes) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], encodedBytes) + } + } } - if record.lastLoginTime != nil { - b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) + _, err = rr.Close() + if err != nil { + b.Fatal(err) } } } -func BenchmarkPointerPointerWithPresentValues(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) +func BenchmarkPointerPointerWithNullValues(b *testing.B) { + conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) - _, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz") + _, err := conn.Prepare(context.Background(), "selectNulls", "select 1::int4, 'johnsmith', null::text, null::text, null::text, null::date, null::timestamptz") if err != nil { b.Fatal(err) } @@ -132,7 +184,7 @@ func BenchmarkPointerPointerWithPresentValues(b *testing.B) { lastLoginTime *time.Time } - err = conn.QueryRow("selectNulls").Scan( + err = conn.QueryRow(context.Background(), "selectNulls").Scan( &record.id, &record.userName, &record.email, @@ -153,81 +205,29 @@ func BenchmarkPointerPointerWithPresentValues(b *testing.B) { if record.userName != "johnsmith" { b.Fatalf("bad value for userName: %v", record.userName) } - if record.email == nil || *record.email != "johnsmith@example.com" { + if record.email != nil { b.Fatalf("bad value for email: %v", record.email) } - if record.name == nil || *record.name != "John Smith" { + if record.name != nil { b.Fatalf("bad value for name: %v", record.name) } - if record.sex == nil || *record.sex != "male" { + if record.sex != nil { b.Fatalf("bad value for sex: %v", record.sex) } - if record.birthDate == nil || *record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local) { + if record.birthDate != nil { b.Fatalf("bad value for birthDate: %v", record.birthDate) } - if record.lastLoginTime == nil || *record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) { + if record.lastLoginTime != nil { b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) } } } -func BenchmarkSelectWithoutLogging(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - - benchmarkSelectWithLog(b, conn) -} - -type discardLogger struct{} - -func (dl discardLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {} - -func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - - var logger discardLogger - conn.SetLogger(logger) - conn.SetLogLevel(pgx.LogLevelTrace) - - benchmarkSelectWithLog(b, conn) -} - -func BenchmarkSelectWithLoggingDebugWithDiscard(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - - var logger discardLogger - conn.SetLogger(logger) - conn.SetLogLevel(pgx.LogLevelDebug) - - benchmarkSelectWithLog(b, conn) -} - -func BenchmarkSelectWithLoggingInfoWithDiscard(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - - var logger discardLogger - conn.SetLogger(logger) - conn.SetLogLevel(pgx.LogLevelInfo) - - benchmarkSelectWithLog(b, conn) -} - -func BenchmarkSelectWithLoggingErrorWithDiscard(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) +func BenchmarkPointerPointerWithPresentValues(b *testing.B) { + conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) - var logger discardLogger - conn.SetLogger(logger) - conn.SetLogLevel(pgx.LogLevelError) - - benchmarkSelectWithLog(b, conn) -} - -func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) { - _, err := conn.Prepare("test", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz") + _, err := conn.Prepare(context.Background(), "selectNulls", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz") if err != nil { b.Fatal(err) } @@ -237,14 +237,14 @@ func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) { var record struct { id int32 userName string - email string - name string - sex string - birthDate time.Time - lastLoginTime time.Time + email *string + name *string + sex *string + birthDate *time.Time + lastLoginTime *time.Time } - err = conn.QueryRow("test").Scan( + err = conn.QueryRow(context.Background(), "selectNulls").Scan( &record.id, &record.userName, &record.email, @@ -265,19 +265,19 @@ func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) { if record.userName != "johnsmith" { b.Fatalf("bad value for userName: %v", record.userName) } - if record.email != "johnsmith@example.com" { + if record.email == nil || *record.email != "johnsmith@example.com" { b.Fatalf("bad value for email: %v", record.email) } - if record.name != "John Smith" { + if record.name == nil || *record.name != "John Smith" { b.Fatalf("bad value for name: %v", record.name) } - if record.sex != "male" { + if record.sex == nil || *record.sex != "male" { b.Fatalf("bad value for sex: %v", record.sex) } - if record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local) { + if record.birthDate == nil || *record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) { b.Fatalf("bad value for birthDate: %v", record.birthDate) } - if record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) { + if record.lastLoginTime == nil || *record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) { b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) } } @@ -335,15 +335,16 @@ const benchmarkWriteTableInsertSQL = `insert into t( type benchmarkWriteTableCopyFromSrc struct { count int idx int - row []interface{} + row []any } func (s *benchmarkWriteTableCopyFromSrc) Next() bool { + next := s.idx < s.count s.idx++ - return s.idx < s.count + return next } -func (s *benchmarkWriteTableCopyFromSrc) Values() ([]interface{}, error) { +func (s *benchmarkWriteTableCopyFromSrc) Values() ([]any, error) { return s.row, nil } @@ -354,15 +355,15 @@ func (s *benchmarkWriteTableCopyFromSrc) Err() error { func newBenchmarkWriteTableCopyFromSrc(count int) pgx.CopyFromSource { return &benchmarkWriteTableCopyFromSrc{ count: count, - row: []interface{}{ + row: []any{ "varchar_1", "varchar_2", - pgtype.Text{}, + &pgtype.Text{}, time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), - pgtype.Date{}, + &pgtype.Date{}, 1, 2, - pgtype.Int4{}, + &pgtype.Int4{}, time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local), true, @@ -373,11 +374,11 @@ func newBenchmarkWriteTableCopyFromSrc(count int) pgx.CopyFromSource { } func benchmarkWriteNRowsViaInsert(b *testing.B, n int) { - conn := mustConnect(b, *defaultConnConfig) + conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) mustExec(b, conn, benchmarkWriteTableCreateSQL) - _, err := conn.Prepare("insert_t", benchmarkWriteTableInsertSQL) + _, err := conn.Prepare(context.Background(), "insert_t", benchmarkWriteTableInsertSQL) if err != nil { b.Fatal(err) } @@ -387,25 +388,60 @@ func benchmarkWriteNRowsViaInsert(b *testing.B, n int) { for i := 0; i < b.N; i++ { src := newBenchmarkWriteTableCopyFromSrc(n) - tx, err := conn.Begin() + tx, err := conn.Begin(context.Background()) if err != nil { b.Fatal(err) } for src.Next() { values, _ := src.Values() - if _, err = tx.Exec("insert_t", values...); err != nil { + if _, err = tx.Exec(context.Background(), "insert_t", values...); err != nil { b.Fatalf("Exec unexpectedly failed with: %v", err) } } - err = tx.Commit() + err = tx.Commit(context.Background()) + if err != nil { + b.Fatal(err) + } + } +} + +func benchmarkWriteNRowsViaBatchInsert(b *testing.B, n int) { + conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) + defer closeConn(b, conn) + + mustExec(b, conn, benchmarkWriteTableCreateSQL) + _, err := conn.Prepare(context.Background(), "insert_t", benchmarkWriteTableInsertSQL) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + src := newBenchmarkWriteTableCopyFromSrc(n) + + batch := &pgx.Batch{} + for src.Next() { + values, _ := src.Values() + batch.Queue("insert_t", values...) + } + + err = conn.SendBatch(context.Background(), batch).Close() if err != nil { b.Fatal(err) } } } +type queryArgs []any + +func (qa *queryArgs) Append(v any) string { + *qa = append(*qa, v) + return "$" + strconv.Itoa(len(*qa)) +} + // note this function is only used for benchmarks -- it doesn't escape tableName // or columnNames func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc pgx.CopyFromSource) (int, error) { @@ -414,7 +450,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc rowCount := 0 sqlBuf := &bytes.Buffer{} - args := make(pgx.QueryArgs, 0) + args := make(queryArgs, 0) resetQuery := func() { sqlBuf.Reset() @@ -426,11 +462,11 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc } resetQuery() - tx, err := conn.Begin() + tx, err := conn.Begin(context.Background()) if err != nil { return 0, err } - defer tx.Rollback() + defer tx.Rollback(context.Background()) for rowSrc.Next() { if rowsThisInsert > 0 { @@ -456,7 +492,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc rowsThisInsert++ if rowsThisInsert == maxRowsPerInsert { - _, err := tx.Exec(sqlBuf.String(), args...) + _, err := tx.Exec(context.Background(), sqlBuf.String(), args...) if err != nil { return 0, err } @@ -467,7 +503,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc } if rowsThisInsert > 0 { - _, err := tx.Exec(sqlBuf.String(), args...) + _, err := tx.Exec(context.Background(), sqlBuf.String(), args...) if err != nil { return 0, err } @@ -475,20 +511,19 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc rowCount += rowsThisInsert } - if err := tx.Commit(); err != nil { - return 0, nil + if err := tx.Commit(context.Background()); err != nil { + return 0, err } return rowCount, nil - } func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) { - conn := mustConnect(b, *defaultConnConfig) + conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) mustExec(b, conn, benchmarkWriteTableCreateSQL) - _, err := conn.Prepare("insert_t", benchmarkWriteTableInsertSQL) + _, err := conn.Prepare(context.Background(), "insert_t", benchmarkWriteTableInsertSQL) if err != nil { b.Fatal(err) } @@ -499,7 +534,8 @@ func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) { src := newBenchmarkWriteTableCopyFromSrc(n) _, err := multiInsert(conn, "t", - []string{"varchar_1", + []string{ + "varchar_1", "varchar_2", "varchar_null_1", "date_1", @@ -511,7 +547,8 @@ func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) { "tstz_2", "bool_1", "bool_2", - "bool_3"}, + "bool_3", + }, src) if err != nil { b.Fatal(err) @@ -520,7 +557,7 @@ func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) { } func benchmarkWriteNRowsViaCopy(b *testing.B, n int) { - conn := mustConnect(b, *defaultConnConfig) + conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) mustExec(b, conn, benchmarkWriteTableCreateSQL) @@ -530,8 +567,10 @@ func benchmarkWriteNRowsViaCopy(b *testing.B, n int) { for i := 0; i < b.N; i++ { src := newBenchmarkWriteTableCopyFromSrc(n) - _, err := conn.CopyFrom(pgx.Identifier{"t"}, - []string{"varchar_1", + _, err := conn.CopyFrom(context.Background(), + pgx.Identifier{"t"}, + []string{ + "varchar_1", "varchar_2", "varchar_null_1", "date_1", @@ -543,7 +582,8 @@ func benchmarkWriteNRowsViaCopy(b *testing.B, n int) { "tstz_2", "bool_1", "bool_2", - "bool_3"}, + "bool_3", + }, src) if err != nil { b.Fatal(err) @@ -551,6 +591,22 @@ func benchmarkWriteNRowsViaCopy(b *testing.B, n int) { } } +func BenchmarkWrite2RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 2) +} + +func BenchmarkWrite2RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 2) +} + +func BenchmarkWrite2RowsViaBatchInsert(b *testing.B) { + benchmarkWriteNRowsViaBatchInsert(b, 2) +} + +func BenchmarkWrite2RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 2) +} + func BenchmarkWrite5RowsViaInsert(b *testing.B) { benchmarkWriteNRowsViaInsert(b, 5) } @@ -559,6 +615,10 @@ func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) { benchmarkWriteNRowsViaMultiInsert(b, 5) } +func BenchmarkWrite5RowsViaBatchInsert(b *testing.B) { + benchmarkWriteNRowsViaBatchInsert(b, 5) +} + func BenchmarkWrite5RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 5) } @@ -571,6 +631,10 @@ func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) { benchmarkWriteNRowsViaMultiInsert(b, 10) } +func BenchmarkWrite10RowsViaBatchInsert(b *testing.B) { + benchmarkWriteNRowsViaBatchInsert(b, 10) +} + func BenchmarkWrite10RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 10) } @@ -583,6 +647,10 @@ func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) { benchmarkWriteNRowsViaMultiInsert(b, 100) } +func BenchmarkWrite100RowsViaBatchInsert(b *testing.B) { + benchmarkWriteNRowsViaBatchInsert(b, 100) +} + func BenchmarkWrite100RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 100) } @@ -595,6 +663,10 @@ func BenchmarkWrite1000RowsViaMultiInsert(b *testing.B) { benchmarkWriteNRowsViaMultiInsert(b, 1000) } +func BenchmarkWrite1000RowsViaBatchInsert(b *testing.B) { + benchmarkWriteNRowsViaBatchInsert(b, 1000) +} + func BenchmarkWrite1000RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 1000) } @@ -607,24 +679,55 @@ func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) { benchmarkWriteNRowsViaMultiInsert(b, 10000) } +func BenchmarkWrite10000RowsViaBatchInsert(b *testing.B) { + benchmarkWriteNRowsViaBatchInsert(b, 10000) +} + func BenchmarkWrite10000RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 10000) } -func BenchmarkMultipleQueriesNonBatch(b *testing.B) { - config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5} - pool, err := pgx.NewConnPool(config) - if err != nil { - b.Fatalf("Unable to create connection pool: %v", err) - } - defer pool.Close() +func BenchmarkMultipleQueriesNonBatchNoStatementCache(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 + + conn := mustConnect(b, config) + defer closeConn(b, conn) - queryCount := 3 + benchmarkMultipleQueriesNonBatch(b, conn, 3) +} + +func BenchmarkMultipleQueriesNonBatchPrepareStatementCache(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement + config.StatementCacheCapacity = 32 + config.DescriptionCacheCapacity = 0 + conn := mustConnect(b, config) + defer closeConn(b, conn) + + benchmarkMultipleQueriesNonBatch(b, conn, 3) +} + +func BenchmarkMultipleQueriesNonBatchDescribeStatementCache(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 32 + + conn := mustConnect(b, config) + defer closeConn(b, conn) + + benchmarkMultipleQueriesNonBatch(b, conn, 3) +} + +func benchmarkMultipleQueriesNonBatch(b *testing.B, conn *pgx.Conn, queryCount int) { b.ResetTimer() for i := 0; i < b.N; i++ { for j := 0; j < queryCount; j++ { - rows, err := pool.Query("select n from generate_series(0, 5) n") + rows, err := conn.Query(context.Background(), "select n from generate_series(0, 5) n") if err != nil { b.Fatal(err) } @@ -646,34 +749,54 @@ func BenchmarkMultipleQueriesNonBatch(b *testing.B) { } } -func BenchmarkMultipleQueriesBatch(b *testing.B) { - config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5} - pool, err := pgx.NewConnPool(config) - if err != nil { - b.Fatalf("Unable to create connection pool: %v", err) - } - defer pool.Close() +func BenchmarkMultipleQueriesBatchNoStatementCache(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 - queryCount := 3 + conn := mustConnect(b, config) + defer closeConn(b, conn) + benchmarkMultipleQueriesBatch(b, conn, 3) +} + +func BenchmarkMultipleQueriesBatchPrepareStatementCache(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement + config.StatementCacheCapacity = 32 + config.DescriptionCacheCapacity = 0 + + conn := mustConnect(b, config) + defer closeConn(b, conn) + + benchmarkMultipleQueriesBatch(b, conn, 3) +} + +func BenchmarkMultipleQueriesBatchDescribeStatementCache(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 32 + + conn := mustConnect(b, config) + defer closeConn(b, conn) + + benchmarkMultipleQueriesBatch(b, conn, 3) +} + +func benchmarkMultipleQueriesBatch(b *testing.B, conn *pgx.Conn, queryCount int) { b.ResetTimer() for i := 0; i < b.N; i++ { - batch := pool.BeginBatch() + batch := &pgx.Batch{} for j := 0; j < queryCount; j++ { - batch.Queue("select n from generate_series(0,5) n", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) + batch.Queue("select n from generate_series(0,5) n") } - err := batch.Send(context.Background(), nil) - if err != nil { - b.Fatal(err) - } + br := conn.SendBatch(context.Background(), batch) for j := 0; j < queryCount; j++ { - rows, err := batch.QueryResults() + rows, err := br.Query() if err != nil { b.Fatal(err) } @@ -693,9 +816,579 @@ func BenchmarkMultipleQueriesBatch(b *testing.B) { } } - err = batch.Close() + err := br.Close() + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSelectManyUnknownEnum(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + ctx := context.Background() + tx, err := conn.Begin(ctx) + require.NoError(b, err) + defer tx.Rollback(ctx) + + _, err = tx.Exec(context.Background(), "drop type if exists color;") + require.NoError(b, err) + + _, err = tx.Exec(ctx, `create type color as enum ('blue', 'green', 'orange')`) + require.NoError(b, err) + + b.ResetTimer() + var x, y, z string + for i := 0; i < b.N; i++ { + rows, err := conn.Query(ctx, "select 'blue'::color, 'green'::color, 'orange'::color from generate_series(1,10)") + if err != nil { + b.Fatal(err) + } + + for rows.Next() { + err = rows.Scan(&x, &y, &z) + if err != nil { + b.Fatal(err) + } + + if x != "blue" { + b.Fatal("unexpected result") + } + if y != "green" { + b.Fatal("unexpected result") + } + if z != "orange" { + b.Fatal("unexpected result") + } + } + + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + } +} + +func BenchmarkSelectManyRegisteredEnum(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + ctx := context.Background() + tx, err := conn.Begin(ctx) + require.NoError(b, err) + defer tx.Rollback(ctx) + + _, err = tx.Exec(context.Background(), "drop type if exists color;") + require.NoError(b, err) + + _, err = tx.Exec(ctx, `create type color as enum ('blue', 'green', 'orange')`) + require.NoError(b, err) + + var oid uint32 + err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "color").Scan(&oid) + require.NoError(b, err) + + conn.TypeMap().RegisterType(&pgtype.Type{Name: "color", OID: oid, Codec: &pgtype.EnumCodec{}}) + + b.ResetTimer() + var x, y, z string + for i := 0; i < b.N; i++ { + rows, err := conn.Query(ctx, "select 'blue'::color, 'green'::color, 'orange'::color from generate_series(1,10)") if err != nil { b.Fatal(err) } + + for rows.Next() { + err = rows.Scan(&x, &y, &z) + if err != nil { + b.Fatal(err) + } + + if x != "blue" { + b.Fatal("unexpected result") + } + if y != "green" { + b.Fatal("unexpected result") + } + if z != "orange" { + b.Fatal("unexpected result") + } + } + + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + } +} + +func getSelectRowsCounts(b *testing.B) []int64 { + var rowCounts []int64 + { + s := os.Getenv("PGX_BENCH_SELECT_ROWS_COUNTS") + if s != "" { + for _, p := range strings.Split(s, " ") { + n, err := strconv.ParseInt(p, 10, 64) + if err != nil { + b.Fatalf("Bad PGX_BENCH_SELECT_ROWS_COUNTS value: %v", err) + } + rowCounts = append(rowCounts, n) + } + } + } + + if len(rowCounts) == 0 { + rowCounts = []int64{1, 10, 100, 1000} + } + + return rowCounts +} + +type BenchRowSimple struct { + ID int32 + FirstName string + LastName string + Sex string + BirthDate time.Time + Weight int32 + Height int32 + Tags []string + UpdateTime time.Time +} + +func BenchmarkSelectRowsScanSimple(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + rowCounts := getSelectRowsCounts(b) + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + br := &BenchRowSimple{} + for i := 0; i < b.N; i++ { + rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) + if err != nil { + b.Fatal(err) + } + + for rows.Next() { + rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime) + } + + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + } + }) + } +} + +type BenchRowStringBytes struct { + ID int32 + FirstName []byte + LastName []byte + Sex []byte + BirthDate time.Time + Weight int32 + Height int32 + Tags []string + UpdateTime time.Time +} + +func BenchmarkSelectRowsScanStringBytes(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + rowCounts := getSelectRowsCounts(b) + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + br := &BenchRowStringBytes{} + for i := 0; i < b.N; i++ { + rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) + if err != nil { + b.Fatal(err) + } + + for rows.Next() { + rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime) + } + + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + } + }) + } +} + +type BenchRowDecoder struct { + ID pgtype.Int4 + FirstName pgtype.Text + LastName pgtype.Text + Sex pgtype.Text + BirthDate pgtype.Date + Weight pgtype.Int4 + Height pgtype.Int4 + Tags pgtype.FlatArray[string] + UpdateTime pgtype.Timestamptz +} + +func BenchmarkSelectRowsScanDecoder(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + rowCounts := getSelectRowsCounts(b) + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + formats := []struct { + name string + code int16 + }{ + {"text", pgx.TextFormatCode}, + {"binary", pgx.BinaryFormatCode}, + } + for _, format := range formats { + b.Run(format.name, func(b *testing.B) { + br := &BenchRowDecoder{} + for i := 0; i < b.N; i++ { + rows, err := conn.Query( + context.Background(), + "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", + pgx.QueryResultFormats{format.code}, + rowCount, + ) + if err != nil { + b.Fatal(err) + } + + for rows.Next() { + rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime) + } + + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + } + }) + } + }) + } +} + +func BenchmarkSelectRowsPgConnExecText(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + rowCounts := getSelectRowsCounts(b) + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + for i := 0; i < b.N; i++ { + mrr := conn.PgConn().Exec(context.Background(), fmt.Sprintf("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + %d) n", rowCount)) + for mrr.NextResult() { + rr := mrr.ResultReader() + for rr.NextRow() { + rr.Values() + } + } + + err := mrr.Close() + if err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkSelectRowsPgConnExecParams(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + rowCounts := getSelectRowsCounts(b) + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + formats := []struct { + name string + code int16 + }{ + {"text", pgx.TextFormatCode}, + {"binary - mostly", pgx.BinaryFormatCode}, + } + for _, format := range formats { + b.Run(format.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + rr := conn.PgConn().ExecParams( + context.Background(), + "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", + [][]byte{[]byte(strconv.FormatInt(rowCount, 10))}, + nil, + nil, + []int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code}, + ) + for rr.NextRow() { + rr.Values() + } + + _, err := rr.Close() + if err != nil { + b.Fatal(err) + } + } + }) + } + }) + } +} + +func BenchmarkSelectRowsSimpleCollectRowsRowToStructByPos(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + rowCounts := getSelectRowsCounts(b) + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + for i := 0; i < b.N; i++ { + rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) + benchRows, err := pgx.CollectRows(rows, pgx.RowToStructByPos[BenchRowSimple]) + if err != nil { + b.Fatal(err) + } + if len(benchRows) != int(rowCount) { + b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows)) + } + } + }) + } +} + +func BenchmarkSelectRowsSimpleAppendRowsRowToStructByPos(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + rowCounts := getSelectRowsCounts(b) + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + benchRows := make([]BenchRowSimple, 0, rowCount) + for i := 0; i < b.N; i++ { + benchRows = benchRows[:0] + rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) + var err error + benchRows, err = pgx.AppendRows(benchRows, rows, pgx.RowToStructByPos[BenchRowSimple]) + if err != nil { + b.Fatal(err) + } + if len(benchRows) != int(rowCount) { + b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows)) + } + } + }) + } +} + +func BenchmarkSelectRowsSimpleCollectRowsRowToStructByName(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + rowCounts := getSelectRowsCounts(b) + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + for i := 0; i < b.N; i++ { + rows, _ := conn.Query(context.Background(), "select n as id, 'Adam' as first_name, 'Smith ' || n as last_name, 'male' as sex, '1952-06-16'::date as birth_date, 258 as weight, 72 as height, '{foo,bar,baz}'::text[] as tags, '2001-01-28 01:02:03-05'::timestamptz as update_time from generate_series(100001, 100000 + $1) n", rowCount) + benchRows, err := pgx.CollectRows(rows, pgx.RowToStructByName[BenchRowSimple]) + if err != nil { + b.Fatal(err) + } + if len(benchRows) != int(rowCount) { + b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows)) + } + } + }) + } +} + +func BenchmarkSelectRowsSimpleAppendRowsRowToStructByName(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + rowCounts := getSelectRowsCounts(b) + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + benchRows := make([]BenchRowSimple, 0, rowCount) + for i := 0; i < b.N; i++ { + benchRows = benchRows[:0] + rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) + var err error + benchRows, err = pgx.AppendRows(benchRows, rows, pgx.RowToStructByPos[BenchRowSimple]) + if err != nil { + b.Fatal(err) + } + if len(benchRows) != int(rowCount) { + b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows)) + } + } + }) + } +} + +func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + rowCounts := getSelectRowsCounts(b) + + _, err := conn.PgConn().Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil) + if err != nil { + b.Fatal(err) + } + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + formats := []struct { + name string + code int16 + }{ + {"text", pgx.TextFormatCode}, + {"binary - mostly", pgx.BinaryFormatCode}, + } + for _, format := range formats { + b.Run(format.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + rr := conn.PgConn().ExecPrepared( + context.Background(), + "ps1", + [][]byte{[]byte(strconv.FormatInt(rowCount, 10))}, + nil, + []int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code}, + ) + for rr.NextRow() { + rr.Values() + } + + _, err := rr.Close() + if err != nil { + b.Fatal(err) + } + } + }) + } + }) + } +} + +type queryRecorder struct { + conn net.Conn + writeBuf []byte + readCount int +} + +func (qr *queryRecorder) Read(b []byte) (n int, err error) { + n, err = qr.conn.Read(b) + qr.readCount += n + return n, err +} + +func (qr *queryRecorder) Write(b []byte) (n int, err error) { + qr.writeBuf = append(qr.writeBuf, b...) + return qr.conn.Write(b) +} + +func (qr *queryRecorder) Close() error { + return qr.conn.Close() +} + +func (qr *queryRecorder) LocalAddr() net.Addr { + return qr.conn.LocalAddr() +} + +func (qr *queryRecorder) RemoteAddr() net.Addr { + return qr.conn.RemoteAddr() +} + +func (qr *queryRecorder) SetDeadline(t time.Time) error { + return qr.conn.SetDeadline(t) +} + +func (qr *queryRecorder) SetReadDeadline(t time.Time) error { + return qr.conn.SetReadDeadline(t) +} + +func (qr *queryRecorder) SetWriteDeadline(t time.Time) error { + return qr.conn.SetWriteDeadline(t) +} + +// BenchmarkSelectRowsRawPrepared hijacks a pgconn connection and inserts a queryRecorder. It then executes the query +// once. The benchmark is simply sending the exact query bytes over the wire to the server and reading the expected +// number of bytes back. It does nothing else. This should be the theoretical maximum performance a Go application +// could achieve. +func BenchmarkSelectRowsRawPrepared(b *testing.B) { + rowCounts := getSelectRowsCounts(b) + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + formats := []struct { + name string + code int16 + }{ + {"text", pgx.TextFormatCode}, + {"binary - mostly", pgx.BinaryFormatCode}, + } + for _, format := range formats { + b.Run(format.name, func(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")).PgConn() + defer conn.Close(context.Background()) + + _, err := conn.Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil) + if err != nil { + b.Fatal(err) + } + + hijackedConn, err := conn.Hijack() + require.NoError(b, err) + + qr := &queryRecorder{ + conn: hijackedConn.Conn, + } + + hijackedConn.Conn = qr + hijackedConn.Frontend = hijackedConn.Config.BuildFrontend(qr, qr) + conn, err = pgconn.Construct(hijackedConn) + require.NoError(b, err) + + { + rr := conn.ExecPrepared( + context.Background(), + "ps1", + [][]byte{[]byte(strconv.FormatInt(rowCount, 10))}, + nil, + []int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code}, + ) + _, err := rr.Close() + require.NoError(b, err) + } + + buf := make([]byte, qr.readCount) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := qr.conn.Write(qr.writeBuf) + if err != nil { + b.Fatal(err) + } + + _, err = io.ReadFull(qr.conn, buf) + if err != nil { + b.Fatal(err) + } + } + }) + } + }) } } diff --git a/chunkreader/chunkreader.go b/chunkreader/chunkreader.go deleted file mode 100644 index f8d437b2e..000000000 --- a/chunkreader/chunkreader.go +++ /dev/null @@ -1,89 +0,0 @@ -package chunkreader - -import ( - "io" -) - -type ChunkReader struct { - r io.Reader - - buf []byte - rp, wp int // buf read position and write position - - options Options -} - -type Options struct { - MinBufLen int // Minimum buffer length -} - -func NewChunkReader(r io.Reader) *ChunkReader { - cr, err := NewChunkReaderEx(r, Options{}) - if err != nil { - panic("default options can't be bad") - } - - return cr -} - -func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) { - if options.MinBufLen == 0 { - options.MinBufLen = 4096 - } - - return &ChunkReader{ - r: r, - buf: make([]byte, options.MinBufLen), - options: options, - }, nil -} - -// Next returns buf filled with the next n bytes. If an error occurs, buf will -// be nil. -func (r *ChunkReader) Next(n int) (buf []byte, err error) { - // n bytes already in buf - if (r.wp - r.rp) >= n { - buf = r.buf[r.rp : r.rp+n] - r.rp += n - return buf, err - } - - // available space in buf is less than n - if len(r.buf) < n { - r.copyBufContents(r.newBuf(n)) - } - - // buf is large enough, but need to shift filled area to start to make enough contiguous space - minReadCount := n - (r.wp - r.rp) - if (len(r.buf) - r.wp) < minReadCount { - newBuf := r.newBuf(n) - r.copyBufContents(newBuf) - } - - if err := r.appendAtLeast(minReadCount); err != nil { - return nil, err - } - - buf = r.buf[r.rp : r.rp+n] - r.rp += n - return buf, nil -} - -func (r *ChunkReader) appendAtLeast(fillLen int) error { - n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen) - r.wp += n - return err -} - -func (r *ChunkReader) newBuf(size int) []byte { - if size < r.options.MinBufLen { - size = r.options.MinBufLen - } - return make([]byte, size) -} - -func (r *ChunkReader) copyBufContents(dest []byte) { - r.wp = copy(dest, r.buf[r.rp:r.wp]) - r.rp = 0 - r.buf = dest -} diff --git a/chunkreader/chunkreader_test.go b/chunkreader/chunkreader_test.go deleted file mode 100644 index 3be07e3cf..000000000 --- a/chunkreader/chunkreader_test.go +++ /dev/null @@ -1,96 +0,0 @@ -package chunkreader - -import ( - "bytes" - "testing" -) - -func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { - server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) - if err != nil { - t.Fatal(err) - } - - src := []byte{1, 2, 3, 4} - server.Write(src) - - n1, err := r.Next(2) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n1, src[0:2]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1) - } - - n2, err := r.Next(2) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n2, src[2:4]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2) - } - - if bytes.Compare(r.buf, src) != 0 { - t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf) - } - if r.rp != 4 { - t.Fatalf("Expected r.rp to be %v, but it was %v", 4, r.rp) - } - if r.wp != 4 { - t.Fatalf("Expected r.wp to be %v, but it was %v", 4, r.wp) - } -} - -func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { - server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) - if err != nil { - t.Fatal(err) - } - - src := []byte{1, 2, 3, 4, 5, 6, 7, 8} - server.Write(src) - - n1, err := r.Next(5) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n1, src[0:5]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1) - } - if len(r.buf) != 5 { - t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 5, len(r.buf)) - } -} - -func TestChunkReaderDoesNotReuseBuf(t *testing.T) { - server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) - if err != nil { - t.Fatal(err) - } - - src := []byte{1, 2, 3, 4, 5, 6, 7, 8} - server.Write(src) - - n1, err := r.Next(4) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n1, src[0:4]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1) - } - - n2, err := r.Next(4) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n2, src[4:8]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2) - } - - if bytes.Compare(n1, src[0:4]) != 0 { - t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1) - } -} diff --git a/ci/setup_test.bash b/ci/setup_test.bash new file mode 100755 index 000000000..d591c512c --- /dev/null +++ b/ci/setup_test.bash @@ -0,0 +1,61 @@ +#!/usr/bin/env bash +set -eux + +if [[ "${PGVERSION-}" =~ ^[0-9.]+$ ]] +then + sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common + sudo rm -rf /var/lib/postgresql + wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - + sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list" + sudo apt-get update -qq + sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION + + sudo cp testsetup/pg_hba.conf /etc/postgresql/$PGVERSION/main/pg_hba.conf + sudo sh -c "echo \"listen_addresses = '127.0.0.1'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf" + sudo sh -c "cat testsetup/postgresql_ssl.conf >> /etc/postgresql/$PGVERSION/main/postgresql.conf" + + cd testsetup + + # Generate CA, server, and encrypted client certificates. + go run generate_certs.go + + # Copy certificates to server directory and set permissions. + sudo cp ca.pem /var/lib/postgresql/$PGVERSION/main/root.crt + sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/root.crt + sudo cp localhost.key /var/lib/postgresql/$PGVERSION/main/server.key + sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/server.key + sudo chmod 600 /var/lib/postgresql/$PGVERSION/main/server.key + sudo cp localhost.crt /var/lib/postgresql/$PGVERSION/main/server.crt + sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/server.crt + + cp ca.pem /tmp + cp pgx_sslcert.key /tmp + cp pgx_sslcert.crt /tmp + + cd .. + + sudo /etc/init.d/postgresql restart + + createdb -U postgres pgx_test + psql -U postgres -f testsetup/postgresql_setup.sql pgx_test +fi + +if [[ "${PGVERSION-}" =~ ^cockroach ]] +then + wget -qO- https://binaries.cockroachdb.com/cockroach-v24.3.3.linux-amd64.tgz | tar xvz + sudo mv cockroach-v24.3.3.linux-amd64/cockroach /usr/local/bin/ + cockroach start-single-node --insecure --background --listen-addr=localhost + cockroach sql --insecure -e 'create database pgx_test' +fi + +if [ "${CRATEVERSION-}" != "" ] +then + docker run \ + -p "6543:5432" \ + -d \ + crate:"$CRATEVERSION" \ + crate \ + -Cnetwork.host=0.0.0.0 \ + -Ctransport.host=localhost \ + -Clicense.enterprise=false +fi diff --git a/conn.go b/conn.go index 9cb325fe6..340bca5ae 100644 --- a/conn.go +++ b/conn.go @@ -2,183 +2,89 @@ package pgx import ( "context" - "crypto/md5" - "crypto/tls" - "crypto/x509" - "encoding/binary" + "crypto/sha256" + "database/sql" "encoding/hex" + "errors" "fmt" - "io" - "io/ioutil" - "net" - "net/url" - "os" - "os/user" - "path/filepath" - "regexp" "strconv" "strings" - "sync" - "sync/atomic" "time" - "github.com/pkg/errors" - - "github.com/jackc/pgx/pgio" - "github.com/jackc/pgx/pgproto3" - "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/v5/internal/sanitize" + "github.com/jackc/pgx/v5/internal/stmtcache" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" ) -const ( - connStatusUninitialized = iota - connStatusClosed - connStatusIdle - connStatusBusy -) +// ConnConfig contains all the options used to establish a connection. It must be created by ParseConfig and +// then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic. +type ConnConfig struct { + pgconn.Config -// minimalConnInfo has just enough static type information to establish the -// connection and retrieve the type data. -var minimalConnInfo *pgtype.ConnInfo - -func init() { - minimalConnInfo = pgtype.NewConnInfo() - minimalConnInfo.InitializeDataTypes(map[string]pgtype.OID{ - "int4": pgtype.Int4OID, - "name": pgtype.NameOID, - "oid": pgtype.OIDOID, - "text": pgtype.TextOID, - "varchar": pgtype.VarcharOID, - }) -} + Tracer QueryTracer -// NoticeHandler is a function that can handle notices received from the -// PostgreSQL server. Notices can be received at any time, usually during -// handling of a query response. The *Conn is provided so the handler is aware -// of the origin of the notice, but it must not invoke any query method. Be -// aware that this is distinct from LISTEN/NOTIFY notification. -type NoticeHandler func(*Conn, *Notice) + // Original connection string that was parsed into config. + connString string -// DialFunc is a function that can be used to connect to a PostgreSQL server -type DialFunc func(network, addr string) (net.Conn, error) + // StatementCacheCapacity is maximum size of the statement cache used when executing a query with "cache_statement" + // query exec mode. + StatementCacheCapacity int -// ConnConfig contains all the options used to establish a connection. -type ConnConfig struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 // default: 5432 - Database string - User string // default: OS user name - Password string - TLSConfig *tls.Config // config for TLS connection -- nil disables TLS - UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa - FallbackTLSConfig *tls.Config // config for fallback TLS connection (only used if UseFallBackTLS is true)-- nil disables TLS - Logger Logger - LogLevel int - Dial DialFunc - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) - OnNotice NoticeHandler // Callback function called when a notice response is received. - CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc. - - // PreferSimpleProtocol disables implicit prepared statement usage. By default - // pgx automatically uses the unnamed prepared statement for Query and - // QueryRow. It also uses a prepared statement when Exec has arguments. This - // can improve performance due to being able to use the binary format. It also - // does not rely on client side parameter sanitization. However, it does incur - // two round-trips per query and may be incompatible proxies such as - // PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be - // used by default. The same functionality can be controlled on a per query - // basis by setting QueryExOptions.SimpleProtocol. - PreferSimpleProtocol bool -} + // DescriptionCacheCapacity is the maximum size of the description cache used when executing a query with + // "cache_describe" query exec mode. + DescriptionCacheCapacity int -func (cc *ConnConfig) networkAddress() (network, address string) { - network = "tcp" - address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) - // See if host is a valid path, if yes connect with a socket - if _, err := os.Stat(cc.Host); err == nil { - // For backward compatibility accept socket file paths -- but directories are now preferred - network = "unix" - address = cc.Host - if !strings.Contains(address, "/.s.PGSQL.") { - address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) - } - } + // DefaultQueryExecMode controls the default mode for executing queries. By default pgx uses the extended protocol + // and automatically prepares and caches prepared statements. However, this may be incompatible with proxies such as + // PGBouncer. In this case it may be preferable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same + // functionality can be controlled on a per query basis by passing a QueryExecMode as the first query argument. + DefaultQueryExecMode QueryExecMode - return network, address + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } -// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. -// Use ConnPool to manage access to multiple database connections from multiple -// goroutines. -type Conn struct { - conn net.Conn // the underlying TCP or unix domain socket connection - lastActivityTime time.Time // the last time the connection was used - wbuf []byte - pid uint32 // backend pid - secretKey uint32 // key to use to send a cancel query message to the server - RuntimeParams map[string]string // parameters that have been reported by the server - config ConnConfig // config used when establishing this connection - txStatus byte - preparedStatements map[string]*PreparedStatement - channels map[string]struct{} - notifications []*Notification - logger Logger - logLevel int - fp *fastpath - poolResetCount int - preallocatedRows []Rows - onNotice NoticeHandler - - mux sync.Mutex - status byte // One of connStatus* constants - causeOfDeath error - - pendingReadyForQueryCount int // numer of ReadyForQuery messages expected - cancelQueryInProgress int32 - cancelQueryCompleted chan struct{} - - // context support - ctxInProgress bool - doneChan chan struct{} - closedChan chan error - - ConnInfo *pgtype.ConnInfo - - frontend *pgproto3.Frontend +// ParseConfigOptions contains options that control how a config is built such as getsslpassword. +type ParseConfigOptions struct { + pgconn.ParseConfigOptions } -// PreparedStatement is a description of a prepared statement -type PreparedStatement struct { - Name string - SQL string - FieldDescriptions []FieldDescription - ParameterOIDs []pgtype.OID +// Copy returns a deep copy of the config that is safe to use and modify. +// The only exception is the tls.Config: +// according to the tls.Config docs it must not be modified after creation. +func (cc *ConnConfig) Copy() *ConnConfig { + newConfig := new(ConnConfig) + *newConfig = *cc + newConfig.Config = *newConfig.Config.Copy() + return newConfig } -// PrepareExOptions is an option struct that can be passed to PrepareEx -type PrepareExOptions struct { - ParameterOIDs []pgtype.OID -} +// ConnString returns the connection string as parsed by pgx.ParseConfig into pgx.ConnConfig. +func (cc *ConnConfig) ConnString() string { return cc.connString } -// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system -type Notification struct { - PID uint32 // backend pid that sent the notification - Channel string // channel from which notification was received - Payload string -} +// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access +// to multiple database connections from multiple goroutines. +type Conn struct { + pgConn *pgconn.PgConn + config *ConnConfig // config used when establishing this connection + preparedStatements map[string]*pgconn.StatementDescription + statementCache stmtcache.Cache + descriptionCache stmtcache.Cache -// CommandTag is the result of an Exec function -type CommandTag string + queryTracer QueryTracer + batchTracer BatchTracer + copyFromTracer CopyFromTracer + prepareTracer PrepareTracer -// RowsAffected returns the number of rows affected. If the CommandTag was not -// for a row affecting command (such as "CREATE TABLE") then it returns 0 -func (ct CommandTag) RowsAffected() int64 { - s := string(ct) - index := strings.LastIndex(s, " ") - if index == -1 { - return 0 - } - n, _ := strconv.ParseInt(s[index+1:], 10, 64) - return n + notifications []*pgconn.Notification + + doneChan chan struct{} + closedChan chan error + + typeMap *pgtype.Map + + wbuf []byte + eqb ExtendedQueryBuilder } // Identifier a PostgreSQL identifier or name. Identifiers can be composed of @@ -189,1719 +95,1347 @@ type Identifier []string func (ident Identifier) Sanitize() string { parts := make([]string, len(ident)) for i := range ident { - parts[i] = `"` + strings.Replace(ident[i], `"`, `""`, -1) + `"` + s := strings.ReplaceAll(ident[i], string([]byte{0}), "") + parts[i] = `"` + strings.ReplaceAll(s, `"`, `""`) + `"` } return strings.Join(parts, ".") } -// ErrNoRows occurs when rows are expected but none are returned. -var ErrNoRows = errors.New("no rows in result set") +var ( + // ErrNoRows occurs when rows are expected but none are returned. + ErrNoRows = newProxyErr(sql.ErrNoRows, "no rows in result set") + // ErrTooManyRows occurs when more rows than expected are returned. + ErrTooManyRows = errors.New("too many rows in result set") +) -// ErrDeadConn occurs on an attempt to use a dead connection -var ErrDeadConn = errors.New("conn is dead") +func newProxyErr(background error, msg string) error { + return &proxyError{ + msg: msg, + background: background, + } +} -// ErrTLSRefused occurs when the connection attempt requires TLS and the -// PostgreSQL server refuses to use TLS -var ErrTLSRefused = errors.New("server refused TLS connection") +type proxyError struct { + msg string + background error +} -// ErrConnBusy occurs when the connection is busy (for example, in the middle of -// reading query results) and another action is attempted. -var ErrConnBusy = errors.New("conn is busy") +func (err *proxyError) Error() string { return err.msg } -// ErrInvalidLogLevel occurs on attempt to set an invalid log level. -var ErrInvalidLogLevel = errors.New("invalid log level") +func (err *proxyError) Unwrap() error { return err.background } -// ProtocolError occurs when unexpected data is received from PostgreSQL -type ProtocolError string +var ( + errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") + errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") +) -func (e ProtocolError) Error() string { - return string(e) +// Connect establishes a connection with a PostgreSQL server with a connection string. See +// pgconn.Connect for details. +func Connect(ctx context.Context, connString string) (*Conn, error) { + connConfig, err := ParseConfig(connString) + if err != nil { + return nil, err + } + return connect(ctx, connConfig) } -// Connect establishes a connection with a PostgreSQL server using config. -// config.Host must be specified. config.User will default to the OS user name. -// Other config fields are optional. -func Connect(config ConnConfig) (c *Conn, err error) { - return connect(config, minimalConnInfo) +// ConnectWithOptions behaves exactly like Connect with the addition of options. At the present options is only used to +// provide a GetSSLPassword function. +func ConnectWithOptions(ctx context.Context, connString string, options ParseConfigOptions) (*Conn, error) { + connConfig, err := ParseConfigWithOptions(connString, options) + if err != nil { + return nil, err + } + return connect(ctx, connConfig) } -func defaultDialer() *net.Dialer { - return &net.Dialer{KeepAlive: 5 * time.Minute} -} +// ConnectConfig establishes a connection with a PostgreSQL server with a configuration struct. +// connConfig must have been created by ParseConfig. +func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { + // In general this improves safety. In particular avoid the config.Config.OnNotification mutation from affecting other + // connections with the same config. See https://github.com/jackc/pgx/issues/618. + connConfig = connConfig.Copy() -func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) { - c = new(Conn) + return connect(ctx, connConfig) +} - c.config = config - c.ConnInfo = connInfo +// ParseConfigWithOptions behaves exactly as ParseConfig does with the addition of options. At the present options is +// only used to provide a GetSSLPassword function. +func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*ConnConfig, error) { + config, err := pgconn.ParseConfigWithOptions(connString, options.ParseConfigOptions) + if err != nil { + return nil, err + } - if c.config.LogLevel != 0 { - c.logLevel = c.config.LogLevel - } else { - // Preserve pre-LogLevel behavior by defaulting to LogLevelDebug - c.logLevel = LogLevelDebug + statementCacheCapacity := 512 + if s, ok := config.RuntimeParams["statement_cache_capacity"]; ok { + delete(config.RuntimeParams, "statement_cache_capacity") + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, pgconn.NewParseConfigError(connString, "cannot parse statement_cache_capacity", err) + } + statementCacheCapacity = int(n) } - c.logger = c.config.Logger - if c.config.User == "" { - user, err := user.Current() + descriptionCacheCapacity := 512 + if s, ok := config.RuntimeParams["description_cache_capacity"]; ok { + delete(config.RuntimeParams, "description_cache_capacity") + n, err := strconv.ParseInt(s, 10, 32) if err != nil { - return nil, err + return nil, pgconn.NewParseConfigError(connString, "cannot parse description_cache_capacity", err) } - c.config.User = user.Username - if c.shouldLog(LogLevelDebug) { - c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"User": c.config.User}) + descriptionCacheCapacity = int(n) + } + + defaultQueryExecMode := QueryExecModeCacheStatement + if s, ok := config.RuntimeParams["default_query_exec_mode"]; ok { + delete(config.RuntimeParams, "default_query_exec_mode") + switch s { + case "cache_statement": + defaultQueryExecMode = QueryExecModeCacheStatement + case "cache_describe": + defaultQueryExecMode = QueryExecModeCacheDescribe + case "describe_exec": + defaultQueryExecMode = QueryExecModeDescribeExec + case "exec": + defaultQueryExecMode = QueryExecModeExec + case "simple_protocol": + defaultQueryExecMode = QueryExecModeSimpleProtocol + default: + return nil, pgconn.NewParseConfigError(connString, "invalid default_query_exec_mode", err) } } - if c.config.Port == 0 { - c.config.Port = 5432 - if c.shouldLog(LogLevelDebug) { - c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"Port": c.config.Port}) - } + connConfig := &ConnConfig{ + Config: *config, + createdByParseConfig: true, + StatementCacheCapacity: statementCacheCapacity, + DescriptionCacheCapacity: descriptionCacheCapacity, + DefaultQueryExecMode: defaultQueryExecMode, + connString: connString, } - c.onNotice = config.OnNotice + return connConfig, nil +} + +// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that [pgconn.ParseConfig] +// does. In addition, it accepts the following options: +// +// - default_query_exec_mode. +// Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See +// QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement". +// +// - statement_cache_capacity. +// The maximum size of the statement cache used when executing a query with "cache_statement" query exec mode. +// Default: 512. +// +// - description_cache_capacity. +// The maximum size of the description cache used when executing a query with "cache_describe" query exec mode. +// Default: 512. +func ParseConfig(connString string) (*ConnConfig, error) { + return ParseConfigWithOptions(connString, ParseConfigOptions{}) +} - network, address := c.config.networkAddress() - if c.config.Dial == nil { - d := defaultDialer() - c.config.Dial = d.Dial +// connect connects to a database. connect takes ownership of config. The caller must not use or access it again. +func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { + if connectTracer, ok := config.Tracer.(ConnectTracer); ok { + ctx = connectTracer.TraceConnectStart(ctx, TraceConnectStartData{ConnConfig: config}) + defer func() { + connectTracer.TraceConnectEnd(ctx, TraceConnectEndData{Conn: c, Err: err}) + }() } - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address}) + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from + // zero values. + if !config.createdByParseConfig { + panic("config must be created by ParseConfig") } - err = c.connect(config, network, address, config.TLSConfig) - if err != nil && config.UseFallbackTLS { - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) - } - err = c.connect(config, network, address, config.FallbackTLSConfig) + + c = &Conn{ + config: config, + typeMap: pgtype.NewMap(), + queryTracer: config.Tracer, } - if err != nil { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) - } - return nil, err + if t, ok := c.queryTracer.(BatchTracer); ok { + c.batchTracer = t + } + if t, ok := c.queryTracer.(CopyFromTracer); ok { + c.copyFromTracer = t + } + if t, ok := c.queryTracer.(PrepareTracer); ok { + c.prepareTracer = t } - return c, nil -} + // Only install pgx notification system if no other callback handler is present. + if config.Config.OnNotification == nil { + config.Config.OnNotification = c.bufferNotifications + } -func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { - c.conn, err = c.config.Dial(network, address) + c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config) if err != nil { - return err + return nil, err } - defer func() { - if c != nil && err != nil { - c.conn.Close() - c.mux.Lock() - c.status = connStatusClosed - c.mux.Unlock() - } - }() - c.RuntimeParams = make(map[string]string) - c.preparedStatements = make(map[string]*PreparedStatement) - c.channels = make(map[string]struct{}) - c.lastActivityTime = time.Now() - c.cancelQueryCompleted = make(chan struct{}, 1) + c.preparedStatements = make(map[string]*pgconn.StatementDescription) c.doneChan = make(chan struct{}) c.closedChan = make(chan error) c.wbuf = make([]byte, 0, 1024) - c.mux.Lock() - c.status = connStatusIdle - c.mux.Unlock() - - if tlsConfig != nil { - if c.shouldLog(LogLevelDebug) { - c.log(LogLevelDebug, "starting TLS handshake", nil) - } - if err := c.startTLS(tlsConfig); err != nil { - return err - } - } - - c.frontend, err = pgproto3.NewFrontend(c.conn, c.conn) - if err != nil { - return err + if c.config.StatementCacheCapacity > 0 { + c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity) } - startupMsg := pgproto3.StartupMessage{ - ProtocolVersion: pgproto3.ProtocolVersionNumber, - Parameters: make(map[string]string), + if c.config.DescriptionCacheCapacity > 0 { + c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity) } - // Default to disabling TLS renegotiation. - // - // Go does not support (https://github.com/golang/go/issues/5742) - // PostgreSQL recommends disabling (http://www.postgresql.org/docs/9.4/static/runtime-config-connection.html#GUC-SSL-RENEGOTIATION-LIMIT) - if tlsConfig != nil { - startupMsg.Parameters["ssl_renegotiation_limit"] = "0" - } + return c, nil +} - // Copy default run-time params - for k, v := range config.RuntimeParams { - startupMsg.Parameters[k] = v +// Close closes a connection. It is safe to call Close on an already closed +// connection. +func (c *Conn) Close(ctx context.Context) error { + if c.IsClosed() { + return nil } - startupMsg.Parameters["user"] = c.config.User - if c.config.Database != "" { - startupMsg.Parameters["database"] = c.config.Database - } + err := c.pgConn.Close(ctx) + return err +} - if _, err := c.conn.Write(startupMsg.Encode(nil)); err != nil { - return err +// Prepare creates a prepared statement with name and sql. sql can contain placeholders for bound parameters. These +// placeholders are referenced positionally as $1, $2, etc. name can be used instead of sql with Query, QueryRow, and +// Exec to execute the statement. It can also be used with Batch.Queue. +// +// The underlying PostgreSQL identifier for the prepared statement will be name if name != sql or a digest of sql if +// name == sql. +// +// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same name and sql arguments. This +// allows a code path to Prepare and Query/Exec without concern for if the statement has already been prepared. +func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) { + if c.prepareTracer != nil { + ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql}) } - c.pendingReadyForQueryCount = 1 - - for { - msg, err := c.rxMsg() - if err != nil { - return err - } - - switch msg := msg.(type) { - case *pgproto3.BackendKeyData: - c.rxBackendKeyData(msg) - case *pgproto3.Authentication: - if err = c.rxAuthenticationX(msg); err != nil { - return err - } - case *pgproto3.ReadyForQuery: - c.rxReadyForQuery(msg) - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "connection established", nil) - } - - // Replication connections can't execute the queries to - // populate the c.PgTypes and c.pgsqlAfInet - if _, ok := config.RuntimeParams["replication"]; ok { - return nil - } - - if c.ConnInfo == minimalConnInfo { - err = c.initConnInfo() - if err != nil { - return err - } - } - - return nil - default: - if err = c.processContextFreeMsg(msg); err != nil { - return err + if name != "" { + var ok bool + if sd, ok = c.preparedStatements[name]; ok && sd.SQL == sql { + if c.prepareTracer != nil { + c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{AlreadyPrepared: true}) } + return sd, nil } } -} -func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) { - const ( - namedOIDQuery = `select t.oid, - case when nsp.nspname in ('pg_catalog', 'public') then t.typname - else nsp.nspname||'.'||t.typname - end -from pg_type t -left join pg_type base_type on t.typelem=base_type.oid -left join pg_namespace nsp on t.typnamespace=nsp.oid -where ( - t.typtype in('b', 'p', 'r', 'e') - and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) - )` - ) + if c.prepareTracer != nil { + defer func() { + c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{Err: err}) + }() + } + + var psName, psKey string + if name == sql { + digest := sha256.Sum256([]byte(sql)) + psName = "stmt_" + hex.EncodeToString(digest[0:24]) + psKey = sql + } else { + psName = name + psKey = name + } - nameOIDs, err := connInfoFromRows(c.Query(namedOIDQuery)) + sd, err = c.pgConn.Prepare(ctx, psName, sql, nil) if err != nil { return nil, err } - cinfo := pgtype.NewConnInfo() - cinfo.InitializeDataTypes(nameOIDs) - - if err = c.initConnInfoEnumArray(cinfo); err != nil { - return nil, err + if psKey != "" { + c.preparedStatements[psKey] = sd } - return cinfo, nil + return sd, nil } -func (c *Conn) initConnInfo() (err error) { - var ( - connInfo *pgtype.ConnInfo - ) - - if c.config.CustomConnInfo != nil { - if c.ConnInfo, err = c.config.CustomConnInfo(c); err != nil { - return err - } - - return nil +// Deallocate releases a prepared statement. Calling Deallocate on a non-existent prepared statement will succeed. +func (c *Conn) Deallocate(ctx context.Context, name string) error { + var psName string + sd := c.preparedStatements[name] + if sd != nil { + psName = sd.Name + } else { + psName = name } - if connInfo, err = initPostgresql(c); err == nil { - c.ConnInfo = connInfo + err := c.pgConn.Deallocate(ctx, psName) + if err != nil { return err } - // Check if CrateDB specific approach might still allow us to connect. - if connInfo, err = c.crateDBTypesQuery(err); err == nil { - c.ConnInfo = connInfo + if sd != nil { + delete(c.preparedStatements, name) } - return err + return nil } -// initConnInfoEnumArray introspects for arrays of enums and registers a data type for them. -func (c *Conn) initConnInfoEnumArray(cinfo *pgtype.ConnInfo) error { - nameOIDs := make(map[string]pgtype.OID, 16) - rows, err := c.Query(`select t.oid, t.typname -from pg_type t - join pg_type base_type on t.typelem=base_type.oid -where t.typtype = 'b' - and base_type.typtype = 'e'`) - if err != nil { - return err +// DeallocateAll releases all previously prepared statements from the server and client, where it also resets the statement and description cache. +func (c *Conn) DeallocateAll(ctx context.Context) error { + c.preparedStatements = map[string]*pgconn.StatementDescription{} + if c.config.StatementCacheCapacity > 0 { + c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity) } + if c.config.DescriptionCacheCapacity > 0 { + c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity) + } + _, err := c.pgConn.Exec(ctx, "deallocate all").ReadAll() + return err +} - for rows.Next() { - var oid pgtype.OID - var name pgtype.Text - if err := rows.Scan(&oid, &name); err != nil { - return err - } +func (c *Conn) bufferNotifications(_ *pgconn.PgConn, n *pgconn.Notification) { + c.notifications = append(c.notifications, n) +} - nameOIDs[name.String] = oid - } +// WaitForNotification waits for a PostgreSQL notification. It wraps the underlying pgconn notification system in a +// slightly more convenient form. +func (c *Conn) WaitForNotification(ctx context.Context) (*pgconn.Notification, error) { + var n *pgconn.Notification - if rows.Err() != nil { - return rows.Err() + // Return already received notification immediately + if len(c.notifications) > 0 { + n = c.notifications[0] + c.notifications = c.notifications[1:] + return n, nil } - for name, oid := range nameOIDs { - cinfo.RegisterDataType(pgtype.DataType{ - Value: &pgtype.EnumArray{}, - Name: name, - OID: oid, - }) + err := c.pgConn.WaitForNotification(ctx) + if len(c.notifications) > 0 { + n = c.notifications[0] + c.notifications = c.notifications[1:] } - - return nil + return n, err } -// crateDBTypesQuery checks if the given err is likely to be the result of -// CrateDB not implementing the pg_types table correctly. If yes, a CrateDB -// specific query against pg_types is executed and its results are returned. If -// not, the original error is returned. -func (c *Conn) crateDBTypesQuery(err error) (*pgtype.ConnInfo, error) { - // CrateDB 2.1.6 is a database that implements the PostgreSQL wire protocol, - // but not perfectly. In particular, the pg_catalog schema containing the - // pg_type table is not visible by default and the pg_type.typtype column is - // not implemented. Therefor the query above currently returns the following - // error: - // - // pgx.PgError{Severity:"ERROR", Code:"XX000", - // Message:"TableUnknownException: Table 'test.pg_type' unknown", - // Detail:"", Hint:"", Position:0, InternalPosition:0, InternalQuery:"", - // Where:"", SchemaName:"", TableName:"", ColumnName:"", DataTypeName:"", - // ConstraintName:"", File:"Schemas.java", Line:99, Routine:"getTableInfo"} - // - // If CrateDB was to fix the pg_type table visbility in the future, we'd - // still get this error until typtype column is implemented: - // - // pgx.PgError{Severity:"ERROR", Code:"XX000", - // Message:"ColumnUnknownException: Column typtype unknown", Detail:"", - // Hint:"", Position:0, InternalPosition:0, InternalQuery:"", Where:"", - // SchemaName:"", TableName:"", ColumnName:"", DataTypeName:"", - // ConstraintName:"", File:"FullQualifiedNameFieldProvider.java", Line:132, - // - // Additionally CrateDB doesn't implement Postgres error codes [2], and - // instead always returns "XX000" (internal_error). The code below uses all - // of this knowledge as a heuristic to detect CrateDB. If CrateDB is - // detected, a CrateDB specific pg_type query is executed instead. - // - // The heuristic is designed to still work even if CrateDB fixes [2] or - // renames its internal exception names. If both are changed but pg_types - // isn't fixed, this code will need to be changed. - // - // There is also a small chance the heuristic will yield a false positive for - // non-CrateDB databases (e.g. if a real Postgres instance returns a XX000 - // error), but hopefully there will be no harm in attempting the alternative - // query in this case. - // - // CrateDB also uses the type varchar for the typname column which required - // adding varchar to the minimalConnInfo init code. - // - // Also see the discussion here [3]. - // - // [1] https://crate.io/ - // [2] https://github.com/crate/crate/issues/5027 - // [3] https://github.com/jackc/pgx/issues/320 - - if pgErr, ok := err.(PgError); ok && - (pgErr.Code == "XX000" || - strings.Contains(pgErr.Message, "TableUnknownException") || - strings.Contains(pgErr.Message, "ColumnUnknownException")) { - var ( - nameOIDs map[string]pgtype.OID - ) - - if nameOIDs, err = connInfoFromRows(c.Query(`select oid, typname from pg_catalog.pg_type`)); err != nil { - return nil, err - } - - cinfo := pgtype.NewConnInfo() - cinfo.InitializeDataTypes(nameOIDs) +// IsClosed reports if the connection has been closed. +func (c *Conn) IsClosed() bool { + return c.pgConn.IsClosed() +} - return cinfo, err +func (c *Conn) die() { + if c.IsClosed() { + return } - return nil, err + ctx, cancel := context.WithCancel(context.Background()) + cancel() // force immediate hard cancel + c.pgConn.Close(ctx) } -// PID returns the backend PID for this connection. -func (c *Conn) PID() uint32 { - return c.pid +func quoteIdentifier(s string) string { + return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` } -// Close closes a connection. It is safe to call Close on a already closed -// connection. -func (c *Conn) Close() (err error) { - c.mux.Lock() - defer c.mux.Unlock() +// Ping delegates to the underlying *pgconn.PgConn.Ping. +func (c *Conn) Ping(ctx context.Context) error { + return c.pgConn.Ping(ctx) +} - if c.status < connStatusIdle { - return nil - } - c.status = connStatusClosed +// PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the +// PostgreSQL connection than pgx exposes. +// +// It is strongly recommended that the connection be idle (no in-progress queries) before the underlying *pgconn.PgConn +// is used and the connection must be returned to the same state before any *pgx.Conn methods are again used. +func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn } - defer func() { - c.conn.Close() - c.causeOfDeath = errors.New("Closed") - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "closed connection", nil) - } - }() +// TypeMap returns the connection info used for this connection. +func (c *Conn) TypeMap() *pgtype.Map { return c.typeMap } - err = c.conn.SetDeadline(time.Time{}) - if err != nil && c.shouldLog(LogLevelWarn) { - c.log(LogLevelWarn, "failed to clear deadlines to send close message", map[string]interface{}{"err": err}) - return err - } +// Config returns a copy of config that was used to establish this connection. +func (c *Conn) Config() *ConnConfig { return c.config.Copy() } - _, err = c.conn.Write([]byte{'X', 0, 0, 0, 4}) - if err != nil && c.shouldLog(LogLevelWarn) { - c.log(LogLevelWarn, "failed to send terminate message", map[string]interface{}{"err": err}) - return err +// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced +// positionally from the sql string as $1, $2, etc. +func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { + if c.queryTracer != nil { + ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: arguments}) } - err = c.conn.SetReadDeadline(time.Now().Add(5 * time.Second)) - if err != nil && c.shouldLog(LogLevelWarn) { - c.log(LogLevelWarn, "failed to set read deadline to finish closing", map[string]interface{}{"err": err}) - return err + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + return pgconn.CommandTag{}, err } - _, err = c.conn.Read(make([]byte, 1)) - if err != io.EOF { - return err + commandTag, err := c.exec(ctx, sql, arguments...) + + if c.queryTracer != nil { + c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{CommandTag: commandTag, Err: err}) } - return nil + return commandTag, err } -// Merge returns a new ConnConfig with the attributes of old and other -// combined. When an attribute is set on both, other takes precedence. -// -// As a security precaution, if the other TLSConfig is nil, all old TLS -// attributes will be preserved. -func (old ConnConfig) Merge(other ConnConfig) ConnConfig { - cc := old - - if other.Host != "" { - cc.Host = other.Host - } - if other.Port != 0 { - cc.Port = other.Port - } - if other.Database != "" { - cc.Database = other.Database - } - if other.User != "" { - cc.User = other.User - } - if other.Password != "" { - cc.Password = other.Password - } +func (c *Conn) exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) { + mode := c.config.DefaultQueryExecMode + var queryRewriter QueryRewriter - if other.TLSConfig != nil { - cc.TLSConfig = other.TLSConfig - cc.UseFallbackTLS = other.UseFallbackTLS - cc.FallbackTLSConfig = other.FallbackTLSConfig +optionLoop: + for len(arguments) > 0 { + switch arg := arguments[0].(type) { + case QueryExecMode: + mode = arg + arguments = arguments[1:] + case QueryRewriter: + queryRewriter = arg + arguments = arguments[1:] + default: + break optionLoop + } } - if other.Logger != nil { - cc.Logger = other.Logger - } - if other.LogLevel != 0 { - cc.LogLevel = other.LogLevel + if queryRewriter != nil { + sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments) + if err != nil { + return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %w", err) + } } - if other.Dial != nil { - cc.Dial = other.Dial + // Always use simple protocol when there are no arguments. + if len(arguments) == 0 { + mode = QueryExecModeSimpleProtocol } - cc.RuntimeParams = make(map[string]string) - for k, v := range old.RuntimeParams { - cc.RuntimeParams[k] = v + if sd, ok := c.preparedStatements[sql]; ok { + return c.execPrepared(ctx, sd, arguments) } - for k, v := range other.RuntimeParams { - cc.RuntimeParams[k] = v - } - - return cc -} - -// ParseURI parses a database URI into ConnConfig -// -// Query parameters not used by the connection process are parsed into ConnConfig.RuntimeParams. -func ParseURI(uri string) (ConnConfig, error) { - var cp ConnConfig - url, err := url.Parse(uri) - if err != nil { - return cp, err - } + switch mode { + case QueryExecModeCacheStatement: + if c.statementCache == nil { + return pgconn.CommandTag{}, errDisabledStatementCache + } + sd := c.statementCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql) + if err != nil { + return pgconn.CommandTag{}, err + } + c.statementCache.Put(sd) + } - if url.User != nil { - cp.User = url.User.Username() - cp.Password, _ = url.User.Password() - } + return c.execPrepared(ctx, sd, arguments) + case QueryExecModeCacheDescribe: + if c.descriptionCache == nil { + return pgconn.CommandTag{}, errDisabledDescriptionCache + } + sd := c.descriptionCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, "", sql) + if err != nil { + return pgconn.CommandTag{}, err + } + c.descriptionCache.Put(sd) + } - parts := strings.SplitN(url.Host, ":", 2) - cp.Host = parts[0] - if len(parts) == 2 { - p, err := strconv.ParseUint(parts[1], 10, 16) + return c.execParams(ctx, sd, arguments) + case QueryExecModeDescribeExec: + sd, err := c.Prepare(ctx, "", sql) if err != nil { - return cp, err + return pgconn.CommandTag{}, err } - cp.Port = uint16(p) + return c.execPrepared(ctx, sd, arguments) + case QueryExecModeExec: + return c.execSQLParams(ctx, sql, arguments) + case QueryExecModeSimpleProtocol: + return c.execSimpleProtocol(ctx, sql, arguments) + default: + return pgconn.CommandTag{}, fmt.Errorf("unknown QueryExecMode: %v", mode) } - cp.Database = strings.TrimLeft(url.Path, "/") +} - if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" { - timeout, err := strconv.ParseInt(pgtimeout, 10, 64) +func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []any) (commandTag pgconn.CommandTag, err error) { + if len(arguments) > 0 { + sql, err = c.sanitizeForSimpleQuery(sql, arguments...) if err != nil { - return cp, err + return pgconn.CommandTag{}, err } - d := defaultDialer() - d.Timeout = time.Duration(timeout) * time.Second - cp.Dial = d.Dial } - tlsArgs := configTLSArgs{ - sslCert: url.Query().Get("sslcert"), - sslKey: url.Query().Get("sslkey"), - sslMode: url.Query().Get("sslmode"), - sslRootCert: url.Query().Get("sslrootcert"), + mrr := c.pgConn.Exec(ctx, sql) + for mrr.NextResult() { + commandTag, _ = mrr.ResultReader().Close() } - err = configTLS(tlsArgs, &cp) + err = mrr.Close() + return commandTag, err +} + +func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) { + err := c.eqb.Build(c.typeMap, sd, arguments) if err != nil { - return cp, err + return pgconn.CommandTag{}, err } - ignoreKeys := map[string]struct{}{ - "connect_timeout": {}, - "sslcert": {}, - "sslkey": {}, - "sslmode": {}, - "sslrootcert": {}, - } + result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + return result.CommandTag, result.Err +} - cp.RuntimeParams = make(map[string]string) +func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) { + err := c.eqb.Build(c.typeMap, sd, arguments) + if err != nil { + return pgconn.CommandTag{}, err + } - for k, v := range url.Query() { - if _, ok := ignoreKeys[k]; ok { - continue - } + result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + return result.CommandTag, result.Err +} - cp.RuntimeParams[k] = v[0] - } - if cp.Password == "" { - pgpass(&cp) +func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) { + err := c.eqb.Build(c.typeMap, nil, args) + if err != nil { + return pgconn.CommandTag{}, err } - return cp, nil + + result := c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + return result.CommandTag, result.Err } -var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) +func (c *Conn) getRows(ctx context.Context, sql string, args []any) *baseRows { + r := &baseRows{} -// ParseDSN parses a database DSN (data source name) into a ConnConfig -// -// e.g. ParseDSN("user=username password=password host=1.2.3.4 port=5432 dbname=mydb sslmode=disable") -// -// Any options not used by the connection process are parsed into ConnConfig.RuntimeParams. -// -// e.g. ParseDSN("application_name=pgxtest search_path=admin user=username password=password host=1.2.3.4 dbname=mydb") -// -// ParseDSN tries to match libpq behavior with regard to sslmode. See comments -// for ParseEnvLibpq for more information on the security implications of -// sslmode options. -func ParseDSN(s string) (ConnConfig, error) { - var cp ConnConfig - - m := dsnRegexp.FindAllStringSubmatch(s, -1) - - tlsArgs := configTLSArgs{} - - cp.RuntimeParams = make(map[string]string) - - for _, b := range m { - switch b[1] { - case "user": - cp.User = b[2] - case "password": - cp.Password = b[2] - case "host": - cp.Host = b[2] - case "port": - p, err := strconv.ParseUint(b[2], 10, 16) - if err != nil { - return cp, err - } - cp.Port = uint16(p) - case "dbname": - cp.Database = b[2] - case "sslmode": - tlsArgs.sslMode = b[2] - case "sslrootcert": - tlsArgs.sslRootCert = b[2] - case "sslcert": - tlsArgs.sslCert = b[2] - case "sslkey": - tlsArgs.sslKey = b[2] - case "connect_timeout": - timeout, err := strconv.ParseInt(b[2], 10, 64) - if err != nil { - return cp, err - } - d := defaultDialer() - d.Timeout = time.Duration(timeout) * time.Second - cp.Dial = d.Dial - default: - cp.RuntimeParams[b[1]] = b[2] - } - } + r.ctx = ctx + r.queryTracer = c.queryTracer + r.typeMap = c.typeMap + r.startTime = time.Now() + r.sql = sql + r.args = args + r.conn = c - err := configTLS(tlsArgs, &cp) - if err != nil { - return cp, err - } - if cp.Password == "" { - pgpass(&cp) - } - return cp, nil + return r } -// ParseConnectionString parses either a URI or a DSN connection string. -// see ParseURI and ParseDSN for details. -func ParseConnectionString(s string) (ConnConfig, error) { - if u, err := url.Parse(s); err == nil && u.Scheme != "" { - return ParseURI(s) +type QueryExecMode int32 + +const ( + _ QueryExecMode = iota + + // Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single round + // trip after the statement is cached. This is the default. If the database schema is modified or the search_path is + // changed after a statement is cached then the first execution of a previously cached query may fail. e.g. If the + // number of columns returned by a "SELECT *" changes or the type of a column is changed. + QueryExecModeCacheStatement + + // Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the extended + // protocol. Queries are executed in a single round trip after the description is cached. If the database schema is + // modified or the search_path is changed after a statement is cached then the first execution of a previously cached + // query may fail. e.g. If the number of columns returned by a "SELECT *" changes or the type of a column is changed. + QueryExecModeCacheDescribe + + // Get the statement description on every execution. This uses the extended protocol. Queries require two round trips + // to execute. It does not use named prepared statements. But it does use the unnamed prepared statement to get the + // statement description on the first round trip and then uses it to execute the query on the second round trip. This + // may cause problems with connection poolers that switch the underlying connection between round trips. It is safe + // even when the database schema is modified concurrently. + QueryExecModeDescribeExec + + // Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol + // with text formatted parameters and results. Queries are executed in a single round trip. Type mappings can be + // registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are + // unregistered or ambiguous. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know + // the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot. + // + // On rare occasions user defined types may behave differently when encoded in the text format instead of the binary + // format. For example, this could happen if a "type RomanNumeral int32" implements fmt.Stringer to format integers as + // Roman numerals (e.g. 7 is VII). The binary format would properly encode the integer 7 as the binary value for 7. + // But the text format would encode the integer 7 as the string "VII". As QueryExecModeExec uses the text format, it + // is possible that changing query mode from another mode to QueryExecModeExec could change the behavior of the query. + // This should not occur with types pgx supports directly and can be avoided by registering the types with + // pgtype.Map.RegisterDefaultPgType and implementing the appropriate type interfaces. In the cas of RomanNumeral, it + // should implement pgtype.Int64Valuer. + QueryExecModeExec + + // Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. This is + // especially significant for []byte values. []byte values are encoded as PostgreSQL bytea. string must be used + // instead for text type values including json and jsonb. Type mappings can be registered with + // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambiguous. + // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use a + // map[string]string directly as an argument. This mode cannot. Queries are executed in a single round trip. + // + // QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec. This includes + // the warning regarding differences in text format and binary format encoding with user defined types. There may be + // other minor exceptions such as behavior when multiple result returning queries are erroneously sent in a single + // string. + // + // QueryExecModeSimpleProtocol uses client side parameter interpolation. All values are quoted and escaped. Prefer + // QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol should + // only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does not + // support the extended protocol. + QueryExecModeSimpleProtocol +) + +func (m QueryExecMode) String() string { + switch m { + case QueryExecModeCacheStatement: + return "cache statement" + case QueryExecModeCacheDescribe: + return "cache describe" + case QueryExecModeDescribeExec: + return "describe exec" + case QueryExecModeExec: + return "exec" + case QueryExecModeSimpleProtocol: + return "simple protocol" + default: + return "invalid" } - return ParseDSN(s) } -// ParseEnvLibpq parses the environment like libpq does into a ConnConfig -// -// See http://www.postgresql.org/docs/9.4/static/libpq-envars.html for details -// on the meaning of environment variables. +// QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position. +type QueryResultFormats []int16 + +// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID. +type QueryResultFormatsByOID map[uint32]int16 + +// QueryRewriter rewrites a query when used as the first arguments to a query method. +type QueryRewriter interface { + RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) +} + +// Query sends a query to the server and returns a Rows to read the results. Only errors encountered sending the query +// and initializing Rows will be returned. Err() on the returned Rows must be checked after the Rows is closed to +// determine if the query executed successfully. // -// ParseEnvLibpq currently recognizes the following environment variables: -// PGHOST -// PGPORT -// PGDATABASE -// PGUSER -// PGPASSWORD -// PGSSLMODE -// PGSSLCERT -// PGSSLKEY -// PGSSLROOTCERT -// PGAPPNAME -// PGCONNECT_TIMEOUT +// The returned Rows must be closed before the connection can be used again. It is safe to attempt to read from the +// returned Rows even if an error is returned. The error will be the available in rows.Err() after rows are closed. It +// is allowed to ignore the error returned from Query and handle it in Rows. // -// Important TLS Security Notes: -// ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This -// includes defaulting to "prefer" behavior if no environment variable is set. +// It is possible for a call of FieldDescriptions on the returned Rows to return nil even if the Query call did not +// return an error. // -// See http://www.postgresql.org/docs/9.4/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION -// for details on what level of security each sslmode provides. +// It is possible for a query to return one or more rows before encountering an error. In most cases the rows should be +// collected before processing rather than processed while receiving each row. This avoids the possibility of the +// application processing rows from a query that the server rejected. The CollectRows function is useful here. // -// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger -// security guarantees than it would with libpq. Do not rely on this behavior as it -// may be possible to match libpq in the future. If you need full security use -// "verify-full". +// An implementor of QueryRewriter may be passed as the first element of args. It can rewrite the sql and change or +// replace args. For example, NamedArgs is QueryRewriter that implements named arguments. // -// Several of the PGSSLMODE options (including the default behavior of "prefer") -// will set UseFallbackTLS to true and FallbackTLSConfig to a disabled or -// weakened TLS mode. This means that if ParseEnvLibpq is used, but TLSConfig is -// later set from a different source that UseFallbackTLS MUST be set false to -// avoid the possibility of falling back to weaker or disabled security. -func ParseEnvLibpq() (ConnConfig, error) { - var cc ConnConfig - - cc.Host = os.Getenv("PGHOST") - - if pgport := os.Getenv("PGPORT"); pgport != "" { - if port, err := strconv.ParseUint(pgport, 10, 16); err == nil { - cc.Port = uint16(port) - } else { - return cc, err - } +// For extra control over how the query is executed, the types QueryExecMode, QueryResultFormats, and +// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely +// needed. See the documentation for those types for details. +func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) { + if c.queryTracer != nil { + ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: args}) } - cc.Database = os.Getenv("PGDATABASE") - cc.User = os.Getenv("PGUSER") - cc.Password = os.Getenv("PGPASSWORD") - - if pgtimeout := os.Getenv("PGCONNECT_TIMEOUT"); pgtimeout != "" { - if timeout, err := strconv.ParseInt(pgtimeout, 10, 64); err == nil { - d := defaultDialer() - d.Timeout = time.Duration(timeout) * time.Second - cc.Dial = d.Dial - } else { - return cc, err + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + if c.queryTracer != nil { + c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{Err: err}) + } + return &baseRows{err: err, closed: true}, err + } + + var resultFormats QueryResultFormats + var resultFormatsByOID QueryResultFormatsByOID + mode := c.config.DefaultQueryExecMode + var queryRewriter QueryRewriter + +optionLoop: + for len(args) > 0 { + switch arg := args[0].(type) { + case QueryResultFormats: + resultFormats = arg + args = args[1:] + case QueryResultFormatsByOID: + resultFormatsByOID = arg + args = args[1:] + case QueryExecMode: + mode = arg + args = args[1:] + case QueryRewriter: + queryRewriter = arg + args = args[1:] + default: + break optionLoop } } - tlsArgs := configTLSArgs{ - sslMode: os.Getenv("PGSSLMODE"), - sslKey: os.Getenv("PGSSLKEY"), - sslCert: os.Getenv("PGSSLCERT"), - sslRootCert: os.Getenv("PGSSLROOTCERT"), + if queryRewriter != nil { + var err error + originalSQL := sql + originalArgs := args + sql, args, err = queryRewriter.RewriteQuery(ctx, c, sql, args) + if err != nil { + rows := c.getRows(ctx, originalSQL, originalArgs) + err = fmt.Errorf("rewrite query failed: %w", err) + rows.fatal(err) + return rows, err + } } - err := configTLS(tlsArgs, &cc) - if err != nil { - return cc, err + // Bypass any statement caching. + if sql == "" { + mode = QueryExecModeSimpleProtocol } - cc.RuntimeParams = make(map[string]string) - if appname := os.Getenv("PGAPPNAME"); appname != "" { - cc.RuntimeParams["application_name"] = appname - } - if cc.Password == "" { - pgpass(&cc) - } - return cc, nil -} + c.eqb.reset() + rows := c.getRows(ctx, sql, args) -type configTLSArgs struct { - sslMode string - sslRootCert string - sslCert string - sslKey string -} + var err error + sd, explicitPreparedStatement := c.preparedStatements[sql] + if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec { + if sd == nil { + sd, err = c.getStatementDescription(ctx, mode, sql) + if err != nil { + rows.fatal(err) + return rows, err + } + } -// configTLS uses lib/pq's TLS parameters to reconstruct a coherent tls.Config. -// Inputs are parsed out and provided by ParseDSN() or ParseURI(). -func configTLS(args configTLSArgs, cc *ConnConfig) error { - // Match libpq default behavior - if args.sslMode == "" { - args.sslMode = "prefer" - } + if len(sd.ParamOIDs) != len(args) { + rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args))) + return rows, rows.err + } - switch args.sslMode { - case "disable": - cc.UseFallbackTLS = false - cc.TLSConfig = nil - cc.FallbackTLSConfig = nil - return nil - case "allow": - cc.UseFallbackTLS = true - cc.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true} - case "prefer": - cc.TLSConfig = &tls.Config{InsecureSkipVerify: true} - cc.UseFallbackTLS = true - cc.FallbackTLSConfig = nil - case "require": - cc.TLSConfig = &tls.Config{InsecureSkipVerify: true} - case "verify-ca", "verify-full": - cc.TLSConfig = &tls.Config{ - ServerName: cc.Host, - } - default: - return errors.New("sslmode is invalid") - } - - if args.sslRootCert != "" { - caCertPool := x509.NewCertPool() + rows.sql = sd.SQL - caPath := args.sslRootCert - caCert, err := ioutil.ReadFile(caPath) + err = c.eqb.Build(c.typeMap, sd, args) if err != nil { - return errors.Wrapf(err, "unable to read CA file %q", caPath) + rows.fatal(err) + return rows, rows.err } - if !caCertPool.AppendCertsFromPEM(caCert) { - return errors.Wrap(err, "unable to add CA to cert pool") + if resultFormatsByOID != nil { + resultFormats = make([]int16, len(sd.Fields)) + for i := range resultFormats { + resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)] + } } - cc.TLSConfig.RootCAs = caCertPool - cc.TLSConfig.ClientCAs = caCertPool - } - - sslcert := args.sslCert - sslkey := args.sslKey - - if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { - return fmt.Errorf(`both "sslcert" and "sslkey" are required`) - } + if resultFormats == nil { + resultFormats = c.eqb.ResultFormats + } - if sslcert != "" && sslkey != "" { - cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + if !explicitPreparedStatement && mode == QueryExecModeCacheDescribe { + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, resultFormats) + } else { + rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats) + } + } else if mode == QueryExecModeExec { + err := c.eqb.Build(c.typeMap, nil, args) if err != nil { - return errors.Wrap(err, "unable to read cert") + rows.fatal(err) + return rows, rows.err } - cc.TLSConfig.Certificates = []tls.Certificate{cert} - } - - return nil -} + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) + } else if mode == QueryExecModeSimpleProtocol { + sql, err = c.sanitizeForSimpleQuery(sql, args...) + if err != nil { + rows.fatal(err) + return rows, err + } -// Prepare creates a prepared statement with name and sql. sql can contain placeholders -// for bound parameters. These placeholders are referenced positional as $1, $2, etc. -// -// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same -// name and sql arguments. This allows a code path to Prepare and Query/Exec without -// concern for if the statement has already been prepared. -func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { - return c.PrepareEx(context.Background(), name, sql, nil) -} + mrr := c.pgConn.Exec(ctx, sql) + if mrr.NextResult() { + rows.resultReader = mrr.ResultReader() + rows.multiResultReader = mrr + } else { + err = mrr.Close() + rows.fatal(err) + return rows, err + } -// PrepareEx creates a prepared statement with name and sql. sql can contain placeholders -// for bound parameters. These placeholders are referenced positional as $1, $2, etc. -// It defers from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct -// -// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same -// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without -// concern for if the statement has already been prepared. -func (c *Conn) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { - err = c.waitForPreviousCancelQuery(ctx) - if err != nil { - return nil, err + return rows, nil + } else { + err = fmt.Errorf("unknown QueryExecMode: %v", mode) + rows.fatal(err) + return rows, rows.err } - err = c.initContext(ctx) - if err != nil { - return nil, err - } + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. - ps, err = c.prepareEx(name, sql, opts) - err = c.termContext(err) - return ps, err + return rows, rows.err } -func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { - if name != "" { - if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql { - return ps, nil +// getStatementDescription returns the statement description of the sql query +// according to the given mode. +// +// If the mode is one that doesn't require to know the param and result OIDs +// then nil is returned without error. +func (c *Conn) getStatementDescription( + ctx context.Context, + mode QueryExecMode, + sql string, +) (sd *pgconn.StatementDescription, err error) { + switch mode { + case QueryExecModeCacheStatement: + if c.statementCache == nil { + return nil, errDisabledStatementCache } - } - - if err := c.ensureConnectionReadyForQuery(); err != nil { - return nil, err - } - - if c.shouldLog(LogLevelError) { - defer func() { + sd = c.statementCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql) if err != nil { - c.log(LogLevelError, "prepareEx failed", map[string]interface{}{"err": err, "name": name, "sql": sql}) + return nil, err } - }() - } - - if opts == nil { - opts = &PrepareExOptions{} - } - - if len(opts.ParameterOIDs) > 65535 { - return nil, errors.Errorf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs)) - } - - buf := appendParse(c.wbuf, name, sql, opts.ParameterOIDs) - buf = appendDescribe(buf, 'S', name) - buf = appendSync(buf) - - n, err := c.conn.Write(buf) - if err != nil { - if fatalWriteErr(n, err) { - c.die(err) + c.statementCache.Put(sd) } - return nil, err - } - c.pendingReadyForQueryCount++ - - ps = &PreparedStatement{Name: name, SQL: sql} - - var softErr error - - for { - msg, err := c.rxMsg() - if err != nil { - return nil, err + case QueryExecModeCacheDescribe: + if c.descriptionCache == nil { + return nil, errDisabledDescriptionCache } - - switch msg := msg.(type) { - case *pgproto3.ParameterDescription: - ps.ParameterOIDs = c.rxParameterDescription(msg) - - if len(ps.ParameterOIDs) > 65535 && softErr == nil { - softErr = errors.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOIDs)) - } - case *pgproto3.RowDescription: - ps.FieldDescriptions = c.rxRowDescription(msg) - for i := range ps.FieldDescriptions { - if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok { - ps.FieldDescriptions[i].DataTypeName = dt.Name - if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { - ps.FieldDescriptions[i].FormatCode = BinaryFormatCode - } else { - ps.FieldDescriptions[i].FormatCode = TextFormatCode - } - } else { - return nil, errors.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType) - } - } - case *pgproto3.ReadyForQuery: - c.rxReadyForQuery(msg) - - if softErr == nil { - c.preparedStatements[name] = ps - } - - return ps, softErr - default: - if e := c.processContextFreeMsg(msg); e != nil && softErr == nil { - softErr = e + sd = c.descriptionCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, "", sql) + if err != nil { + return nil, err } + c.descriptionCache.Put(sd) } + case QueryExecModeDescribeExec: + return c.Prepare(ctx, "", sql) } + return sd, err } -// Deallocate released a prepared statement -func (c *Conn) Deallocate(name string) error { - return c.deallocateContext(context.Background(), name) +// QueryRow is a convenience wrapper over Query. Any error that occurs while +// querying is deferred until calling Scan on the returned Row. That Row will +// error with ErrNoRows if no rows are returned. +func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row { + rows, _ := c.Query(ctx, sql, args...) + return (*connRow)(rows.(*baseRows)) } -// TODO - consider making this public -func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) { - err = c.waitForPreviousCancelQuery(ctx) - if err != nil { - return err +// SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless +// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection +// is used again. +// +// Depending on the QueryExecMode, all queries may be prepared before any are executed. This means that creating a table +// and using it in a subsequent query in the same batch can fail. +func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { + if len(b.QueuedQueries) == 0 { + return &emptyBatchResults{conn: c} } - err = c.initContext(ctx) - if err != nil { - return err + if c.batchTracer != nil { + ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b}) + defer func() { + err := br.(interface{ earlyError() error }).earlyError() + if err != nil { + c.batchTracer.TraceBatchEnd(ctx, c, TraceBatchEndData{Err: err}) + } + }() } - defer func() { - err = c.termContext(err) - }() - if err := c.ensureConnectionReadyForQuery(); err != nil { - return err + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} } - delete(c.preparedStatements, name) - - // close - buf := c.wbuf - buf = append(buf, 'C') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, 'S') - buf = append(buf, name...) - buf = append(buf, 0) - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - // flush - buf = append(buf, 'H') - buf = pgio.AppendInt32(buf, 4) - - _, err = c.conn.Write(buf) - if err != nil { - c.die(err) - return err - } + for _, bi := range b.QueuedQueries { + var queryRewriter QueryRewriter + sql := bi.SQL + arguments := bi.Arguments - for { - msg, err := c.rxMsg() - if err != nil { - return err + optionLoop: + for len(arguments) > 0 { + // Update Batch.Queue function comment when additional options are implemented + switch arg := arguments[0].(type) { + case QueryRewriter: + queryRewriter = arg + arguments = arguments[1:] + default: + break optionLoop + } } - switch msg.(type) { - case *pgproto3.CloseComplete: - return nil - default: - err = c.processContextFreeMsg(msg) + if queryRewriter != nil { + var err error + sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments) if err != nil { - return err + return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("rewrite query failed: %w", err)} } } - } -} -// Listen establishes a PostgreSQL listen/notify to channel -func (c *Conn) Listen(channel string) error { - _, err := c.Exec("listen " + quoteIdentifier(channel)) - if err != nil { - return err + bi.SQL = sql + bi.Arguments = arguments } - c.channels[channel] = struct{}{} - - return nil -} - -// Unlisten unsubscribes from a listen channel -func (c *Conn) Unlisten(channel string) error { - _, err := c.Exec("unlisten " + quoteIdentifier(channel)) - if err != nil { - return err - } - - delete(c.channels, channel) - return nil -} - -// WaitForNotification waits for a PostgreSQL notification. -func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notification, err error) { - // Return already received notification immediately - if len(c.notifications) > 0 { - notification := c.notifications[0] - c.notifications = c.notifications[1:] - return notification, nil - } - - err = c.waitForPreviousCancelQuery(ctx) - if err != nil { - return nil, err - } - - err = c.initContext(ctx) - if err != nil { - return nil, err + // TODO: changing mode per batch? Update Batch.Queue function comment when implemented + mode := c.config.DefaultQueryExecMode + if mode == QueryExecModeSimpleProtocol { + return c.sendBatchQueryExecModeSimpleProtocol(ctx, b) } - defer func() { - err = c.termContext(err) - }() - if err = c.lock(); err != nil { - return nil, err - } - defer func() { - if unlockErr := c.unlock(); unlockErr != nil && err == nil { - err = unlockErr + // All other modes use extended protocol and thus can use prepared statements. + for _, bi := range b.QueuedQueries { + if sd, ok := c.preparedStatements[bi.SQL]; ok { + bi.sd = sd } - }() + } - if err := c.ensureConnectionReadyForQuery(); err != nil { - return nil, err + switch mode { + case QueryExecModeExec: + return c.sendBatchQueryExecModeExec(ctx, b) + case QueryExecModeCacheStatement: + return c.sendBatchQueryExecModeCacheStatement(ctx, b) + case QueryExecModeCacheDescribe: + return c.sendBatchQueryExecModeCacheDescribe(ctx, b) + case QueryExecModeDescribeExec: + return c.sendBatchQueryExecModeDescribeExec(ctx, b) + default: + panic("unknown QueryExecMode") } +} - for { - msg, err := c.rxMsg() - if err != nil { - return nil, err +func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults { + var sb strings.Builder + for i, bi := range b.QueuedQueries { + if i > 0 { + sb.WriteByte(';') } - - err = c.processContextFreeMsg(msg) + sql, err := c.sanitizeForSimpleQuery(bi.SQL, bi.Arguments...) if err != nil { - return nil, err - } - - if len(c.notifications) > 0 { - notification := c.notifications[0] - c.notifications = c.notifications[1:] - return notification, nil + return &batchResults{ctx: ctx, conn: c, err: err} } + sb.WriteString(sql) } -} - -func (c *Conn) IsAlive() bool { - c.mux.Lock() - defer c.mux.Unlock() - return c.status >= connStatusIdle -} - -func (c *Conn) CauseOfDeath() error { - c.mux.Lock() - defer c.mux.Unlock() - return c.causeOfDeath -} - -func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) { - if ps, present := c.preparedStatements[sql]; present { - return c.sendPreparedQuery(ps, arguments...) + mrr := c.pgConn.Exec(ctx, sb.String()) + return &batchResults{ + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + qqIdx: 0, } - return c.sendSimpleQuery(sql, arguments...) } -func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { - if err := c.ensureConnectionReadyForQuery(); err != nil { - return err - } +func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults { + batch := &pgconn.Batch{} - if len(args) == 0 { - buf := appendQuery(c.wbuf, sql) + for _, bi := range b.QueuedQueries { + sd := bi.sd + if sd != nil { + err := c.eqb.Build(c.typeMap, sd, bi.Arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } - _, err := c.conn.Write(buf) - if err != nil { - c.die(err) - return err + batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) + } else { + err := c.eqb.Build(c.typeMap, nil, bi.Arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + batch.ExecParams(bi.SQL, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) } - c.pendingReadyForQueryCount++ - - return nil - } - - ps, err := c.Prepare("", sql) - if err != nil { - return err } - return c.sendPreparedQuery(ps, args...) -} + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. -func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) { - if len(ps.ParameterOIDs) != len(arguments) { - return errors.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOIDs), len(arguments)) - } + mrr := c.pgConn.ExecBatch(ctx, batch) - if err := c.ensureConnectionReadyForQuery(); err != nil { - return err + return &batchResults{ + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + qqIdx: 0, } +} - resultFormatCodes := make([]int16, len(ps.FieldDescriptions)) - for i, fd := range ps.FieldDescriptions { - resultFormatCodes[i] = fd.FormatCode - } - buf, err := appendBind(c.wbuf, "", ps.Name, c.ConnInfo, ps.ParameterOIDs, arguments, resultFormatCodes) - if err != nil { - return err +func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { + if c.statementCache == nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache, closed: true} } - buf = appendExecute(buf, "", 0) - buf = appendSync(buf) + distinctNewQueries := []*pgconn.StatementDescription{} + distinctNewQueriesIdxMap := make(map[string]int) - n, err := c.conn.Write(buf) - if err != nil { - if fatalWriteErr(n, err) { - c.die(err) + for _, bi := range b.QueuedQueries { + if bi.sd == nil { + sd := c.statementCache.Get(bi.SQL) + if sd != nil { + bi.sd = sd + } else { + if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present { + bi.sd = distinctNewQueries[idx] + } else { + sd = &pgconn.StatementDescription{ + Name: stmtcache.StatementName(bi.SQL), + SQL: bi.SQL, + } + distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) + distinctNewQueries = append(distinctNewQueries, sd) + bi.sd = sd + } + } } - return err } - c.pendingReadyForQueryCount++ - return nil + return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.statementCache) } -// fatalWriteError takes the response of a net.Conn.Write and determines if it is fatal -func fatalWriteErr(bytesWritten int, err error) bool { - // Partial writes break the connection - if bytesWritten > 0 { - return true +func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { + if c.descriptionCache == nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache, closed: true} } - netErr, is := err.(net.Error) - return !(is && netErr.Timeout()) -} - -// Exec executes sql. sql can be either a prepared statement name or an SQL string. -// arguments should be referenced positionally from the sql string as $1, $2, etc. -func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { - return c.ExecEx(context.Background(), sql, nil, arguments...) -} + distinctNewQueries := []*pgconn.StatementDescription{} + distinctNewQueriesIdxMap := make(map[string]int) -// Processes messages that are not exclusive to one context such as -// authentication or query response. The response to these messages is the same -// regardless of when they occur. It also ignores messages that are only -// meaningful in a given context. These messages can occur due to a context -// deadline interrupting message processing. For example, an interrupted query -// may have left DataRow messages on the wire. -func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) { - switch msg := msg.(type) { - case *pgproto3.ErrorResponse: - return c.rxErrorResponse(msg) - case *pgproto3.NoticeResponse: - c.rxNoticeResponse(msg) - case *pgproto3.NotificationResponse: - c.rxNotificationResponse(msg) - case *pgproto3.ReadyForQuery: - c.rxReadyForQuery(msg) - case *pgproto3.ParameterStatus: - c.rxParameterStatus(msg) + for _, bi := range b.QueuedQueries { + if bi.sd == nil { + sd := c.descriptionCache.Get(bi.SQL) + if sd != nil { + bi.sd = sd + } else { + if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present { + bi.sd = distinctNewQueries[idx] + } else { + sd = &pgconn.StatementDescription{ + SQL: bi.SQL, + } + distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) + distinctNewQueries = append(distinctNewQueries, sd) + bi.sd = sd + } + } + } } - return nil + return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.descriptionCache) } -func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) { - if !c.IsAlive() { - return nil, ErrDeadConn - } +func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { + distinctNewQueries := []*pgconn.StatementDescription{} + distinctNewQueriesIdxMap := make(map[string]int) - msg, err := c.frontend.Receive() - if err != nil { - if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { - c.die(err) + for _, bi := range b.QueuedQueries { + if bi.sd == nil { + if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present { + bi.sd = distinctNewQueries[idx] + } else { + sd := &pgconn.StatementDescription{ + SQL: bi.SQL, + } + distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) + distinctNewQueries = append(distinctNewQueries, sd) + bi.sd = sd + } } - return nil, err } - c.lastActivityTime = time.Now() - - // fmt.Printf("rxMsg: %#v\n", msg) - - return msg, nil + return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, nil) } -func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { - switch msg.Type { - case pgproto3.AuthTypeOk: - case pgproto3.AuthTypeCleartextPassword: - err = c.txPasswordMessage(c.config.Password) - case pgproto3.AuthTypeMD5Password: - digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+string(msg.Salt[:])) - err = c.txPasswordMessage(digestedPassword) - default: - err = errors.New("Received unknown authentication message") - } - - return -} +func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) { + pipeline := c.pgConn.StartPipeline(ctx) + defer func() { + if pbr != nil && pbr.err != nil { + pipeline.Close() + } + }() -func hexMD5(s string) string { - hash := md5.New() - io.WriteString(hash, s) - return hex.EncodeToString(hash.Sum(nil)) -} + // Prepare any needed queries + if len(distinctNewQueries) > 0 { + err := func() (err error) { + for _, sd := range distinctNewQueries { + pipeline.SendPrepare(sd.Name, sd.SQL, nil) + } -func (c *Conn) rxParameterStatus(msg *pgproto3.ParameterStatus) { - c.RuntimeParams[msg.Name] = msg.Value -} + // Store all statements we are preparing into the cache. It's fine if it overflows because HandleInvalidated will + // clean them up later. + if sdCache != nil { + for _, sd := range distinctNewQueries { + sdCache.Put(sd) + } + } -func (c *Conn) rxErrorResponse(msg *pgproto3.ErrorResponse) PgError { - err := PgError{ - Severity: msg.Severity, - Code: msg.Code, - Message: msg.Message, - Detail: msg.Detail, - Hint: msg.Hint, - Position: msg.Position, - InternalPosition: msg.InternalPosition, - InternalQuery: msg.InternalQuery, - Where: msg.Where, - SchemaName: msg.SchemaName, - TableName: msg.TableName, - ColumnName: msg.ColumnName, - DataTypeName: msg.DataTypeName, - ConstraintName: msg.ConstraintName, - File: msg.File, - Line: msg.Line, - Routine: msg.Routine, - } - - if err.Severity == "FATAL" { - c.die(err) - } + // If something goes wrong preparing the statements, we need to invalidate the cache entries we just added. + defer func() { + if err != nil && sdCache != nil { + for _, sd := range distinctNewQueries { + sdCache.Invalidate(sd.SQL) + } + } + }() - return err -} + err = pipeline.Sync() + if err != nil { + return err + } -func (c *Conn) rxNoticeResponse(msg *pgproto3.NoticeResponse) { - if c.onNotice == nil { - return - } + for _, sd := range distinctNewQueries { + results, err := pipeline.GetResults() + if err != nil { + return err + } - notice := &Notice{ - Severity: msg.Severity, - Code: msg.Code, - Message: msg.Message, - Detail: msg.Detail, - Hint: msg.Hint, - Position: msg.Position, - InternalPosition: msg.InternalPosition, - InternalQuery: msg.InternalQuery, - Where: msg.Where, - SchemaName: msg.SchemaName, - TableName: msg.TableName, - ColumnName: msg.ColumnName, - DataTypeName: msg.DataTypeName, - ConstraintName: msg.ConstraintName, - File: msg.File, - Line: msg.Line, - Routine: msg.Routine, - } - - c.onNotice(c, notice) -} + resultSD, ok := results.(*pgconn.StatementDescription) + if !ok { + return fmt.Errorf("expected statement description, got %T", results) + } -func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) { - c.pid = msg.ProcessID - c.secretKey = msg.SecretKey -} + // Fill in the previously empty / pending statement descriptions. + sd.ParamOIDs = resultSD.ParamOIDs + sd.Fields = resultSD.Fields + } -func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) { - c.pendingReadyForQueryCount-- - c.txStatus = msg.TxStatus -} + results, err := pipeline.GetResults() + if err != nil { + return err + } -func (c *Conn) rxRowDescription(msg *pgproto3.RowDescription) []FieldDescription { - fields := make([]FieldDescription, len(msg.Fields)) - for i := 0; i < len(fields); i++ { - fields[i].Name = msg.Fields[i].Name - fields[i].Table = pgtype.OID(msg.Fields[i].TableOID) - fields[i].AttributeNumber = msg.Fields[i].TableAttributeNumber - fields[i].DataType = pgtype.OID(msg.Fields[i].DataTypeOID) - fields[i].DataTypeSize = msg.Fields[i].DataTypeSize - fields[i].Modifier = msg.Fields[i].TypeModifier - fields[i].FormatCode = msg.Fields[i].Format - } - return fields -} + _, ok := results.(*pgconn.PipelineSync) + if !ok { + return fmt.Errorf("expected sync, got %T", results) + } -func (c *Conn) rxParameterDescription(msg *pgproto3.ParameterDescription) []pgtype.OID { - parameters := make([]pgtype.OID, len(msg.ParameterOIDs)) - for i := 0; i < len(parameters); i++ { - parameters[i] = pgtype.OID(msg.ParameterOIDs[i]) + return nil + }() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} + } } - return parameters -} -func (c *Conn) rxNotificationResponse(msg *pgproto3.NotificationResponse) { - n := new(Notification) - n.PID = msg.PID - n.Channel = msg.Channel - n.Payload = msg.Payload - c.notifications = append(c.notifications, n) -} + // Queue the queries. + for _, bi := range b.QueuedQueries { + err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments) + if err != nil { + // we wrap the error so we the user can understand which query failed inside the batch + err = fmt.Errorf("error building query %s: %w", bi.SQL, err) + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} + } -func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) { - err = binary.Write(c.conn, binary.BigEndian, []int32{8, 80877103}) - if err != nil { - return + if bi.sd.Name == "" { + pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats) + } else { + pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) + } } - response := make([]byte, 1) - if _, err = io.ReadFull(c.conn, response); err != nil { - return + err := pipeline.Sync() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } - if response[0] != 'S' { - return ErrTLSRefused + return &pipelineBatchResults{ + ctx: ctx, + conn: c, + pipeline: pipeline, + b: b, } - - c.conn = tls.Client(c.conn, tlsConfig) - - return nil -} - -func (c *Conn) txPasswordMessage(password string) (err error) { - buf := c.wbuf - buf = append(buf, 'p') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, password...) - buf = append(buf, 0) - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - _, err = c.conn.Write(buf) - - return err } -func (c *Conn) die(err error) { - c.mux.Lock() - defer c.mux.Unlock() - - if c.status == connStatusClosed { - return +func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) { + if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" { + return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on") } - c.status = connStatusClosed - c.causeOfDeath = err - c.conn.Close() -} - -func (c *Conn) lock() error { - c.mux.Lock() - defer c.mux.Unlock() - - if c.status != connStatusIdle { - return ErrConnBusy + if c.pgConn.ParameterStatus("client_encoding") != "UTF8" { + return "", errors.New("simple protocol queries must be run with client_encoding=UTF8") } - c.status = connStatusBusy - return nil -} - -func (c *Conn) unlock() error { - c.mux.Lock() - defer c.mux.Unlock() - - if c.status != connStatusBusy { - return errors.New("unlock conn that is not busy") + var err error + valueArgs := make([]any, len(args)) + for i, a := range args { + valueArgs[i], err = convertSimpleArgument(c.typeMap, a) + if err != nil { + return "", err + } } - c.status = connStatusIdle - return nil + return sanitize.SanitizeSQL(sql, valueArgs...) } -func (c *Conn) shouldLog(lvl int) bool { - return c.logger != nil && c.logLevel >= lvl -} +// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. typeName must be +// the name of a type where the underlying type(s) is already understood by pgx. It is for derived types. In particular, +// typeName must be one of the following: +// - An array type name of a type that is already registered. e.g. "_foo" when "foo" is registered. +// - A composite type name where all field types are already registered. +// - A domain type name where the base type is already registered. +// - An enum type name. +// - A range type name where the element type is already registered. +// - A multirange type name where the element type is already registered. +func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) { + var oid uint32 -func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) { - if data == nil { - data = map[string]interface{}{} - } - if c.pid != 0 { - data["pid"] = c.pid + err := c.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid) + if err != nil { + return nil, err } - c.logger.Log(lvl, msg, data) -} - -// SetLogger replaces the current logger and returns the previous logger. -func (c *Conn) SetLogger(logger Logger) Logger { - oldLogger := c.logger - c.logger = logger - return oldLogger -} - -// SetLogLevel replaces the current log level and returns the previous log -// level. -func (c *Conn) SetLogLevel(lvl int) (int, error) { - oldLvl := c.logLevel + var typtype string + var typbasetype uint32 - if lvl < LogLevelNone || lvl > LogLevelTrace { - return oldLvl, ErrInvalidLogLevel + err = c.QueryRow(ctx, "select typtype::text, typbasetype from pg_type where oid=$1", oid).Scan(&typtype, &typbasetype) + if err != nil { + return nil, err } - c.logLevel = lvl - return lvl, nil -} - -func quoteIdentifier(s string) string { - return `"` + strings.Replace(s, `"`, `""`, -1) + `"` -} - -// cancelQuery sends a cancel request to the PostgreSQL server. It returns an -// error if unable to deliver the cancel request, but lack of an error does not -// ensure that the query was canceled. As specified in the documentation, there -// is no way to be sure a query was canceled. See -// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861 -func (c *Conn) cancelQuery() { - if !atomic.CompareAndSwapInt32(&c.cancelQueryInProgress, 0, 1) { - panic("cancelQuery when cancelQueryInProgress") - } + switch typtype { + case "b": // array + elementOID, err := c.getArrayElementOID(ctx, oid) + if err != nil { + return nil, err + } - if err := c.conn.SetDeadline(time.Now()); err != nil { - c.Close() // Close connection if unable to set deadline - return - } + dt, ok := c.TypeMap().TypeForOID(elementOID) + if !ok { + return nil, errors.New("array element OID not registered") + } - doCancel := func() error { - network, address := c.config.networkAddress() - cancelConn, err := c.config.Dial(network, address) + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}, nil + case "c": // composite + fields, err := c.getCompositeFields(ctx, oid) if err != nil { - return err + return nil, err } - defer cancelConn.Close() - // If server doesn't process cancellation request in bounded time then abort. - err = cancelConn.SetDeadline(time.Now().Add(15 * time.Second)) - if err != nil { - return err + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil + case "d": // domain + dt, ok := c.TypeMap().TypeForOID(typbasetype) + if !ok { + return nil, errors.New("domain base type OID not registered") } - buf := make([]byte, 16) - binary.BigEndian.PutUint32(buf[0:4], 16) - binary.BigEndian.PutUint32(buf[4:8], 80877102) - binary.BigEndian.PutUint32(buf[8:12], uint32(c.pid)) - binary.BigEndian.PutUint32(buf[12:16], uint32(c.secretKey)) - _, err = cancelConn.Write(buf) + return &pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec}, nil + case "e": // enum + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil + case "r": // range + elementOID, err := c.getRangeElementOID(ctx, oid) if err != nil { - return err + return nil, err } - _, err = cancelConn.Read(buf) - if err != io.EOF { - return errors.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf) + dt, ok := c.TypeMap().TypeForOID(elementOID) + if !ok { + return nil, errors.New("range element OID not registered") } - return nil - } - - go func() { - err := doCancel() + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.RangeCodec{ElementType: dt}}, nil + case "m": // multirange + elementOID, err := c.getMultiRangeElementOID(ctx, oid) if err != nil { - c.Close() // Something is very wrong. Terminate the connection. + return nil, err } - c.cancelQueryCompleted <- struct{}{} - }() -} -func (c *Conn) Ping(ctx context.Context) error { - _, err := c.ExecEx(ctx, ";", nil) - return err -} - -func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) { - err := c.waitForPreviousCancelQuery(ctx) - if err != nil { - return "", err - } + dt, ok := c.TypeMap().TypeForOID(elementOID) + if !ok { + return nil, errors.New("multirange element OID not registered") + } - if err := c.lock(); err != nil { - return "", err + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}, nil + default: + return &pgtype.Type{}, errors.New("unknown typtype") } - defer c.unlock() +} - startTime := time.Now() - c.lastActivityTime = startTime +func (c *Conn) getArrayElementOID(ctx context.Context, oid uint32) (uint32, error) { + var typelem uint32 - commandTag, err := c.execEx(ctx, sql, options, arguments...) + err := c.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem) if err != nil { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err}) - } - return commandTag, err - } - - if c.shouldLog(LogLevelInfo) { - endTime := time.Now() - c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) + return 0, err } - return commandTag, err + return typelem, nil } -func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { - err = c.initContext(ctx) +func (c *Conn) getRangeElementOID(ctx context.Context, oid uint32) (uint32, error) { + var typelem uint32 + + err := c.QueryRow(ctx, "select rngsubtype from pg_range where rngtypid=$1", oid).Scan(&typelem) if err != nil { - return "", err + return 0, err } - defer func() { - err = c.termContext(err) - }() - - if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) { - err = c.sanitizeAndSendSimpleQuery(sql, arguments...) - if err != nil { - return "", err - } - } else if options != nil && len(options.ParameterOIDs) > 0 { - if err := c.ensureConnectionReadyForQuery(); err != nil { - return "", err - } - - buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments) - if err != nil { - return "", err - } - buf = appendSync(buf) + return typelem, nil +} - n, err := c.conn.Write(buf) - if err != nil && fatalWriteErr(n, err) { - c.die(err) - return "", err - } - c.pendingReadyForQueryCount++ - } else { - if len(arguments) > 0 { - ps, ok := c.preparedStatements[sql] - if !ok { - var err error - ps, err = c.prepareEx("", sql, nil) - if err != nil { - return "", err - } - } +func (c *Conn) getMultiRangeElementOID(ctx context.Context, oid uint32) (uint32, error) { + var typelem uint32 - err = c.sendPreparedQuery(ps, arguments...) - if err != nil { - return "", err - } - } else { - if err = c.sendQuery(sql, arguments...); err != nil { - return - } - } + err := c.QueryRow(ctx, "select rngtypid from pg_range where rngmultitypid=$1", oid).Scan(&typelem) + if err != nil { + return 0, err } - var softErr error - - for { - msg, err := c.rxMsg() - if err != nil { - return commandTag, err - } - - switch msg := msg.(type) { - case *pgproto3.ReadyForQuery: - c.rxReadyForQuery(msg) - return commandTag, softErr - case *pgproto3.CommandComplete: - commandTag = CommandTag(msg.CommandTag) - default: - if e := c.processContextFreeMsg(msg); e != nil && softErr == nil { - softErr = e - } - } - } + return typelem, nil } -func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) { - if len(arguments) != len(options.ParameterOIDs) { - return nil, errors.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs)) - } +func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.CompositeCodecField, error) { + var typrelid uint32 - if len(options.ParameterOIDs) > 65535 { - return nil, errors.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs)) - } - - buf = appendParse(buf, "", sql, options.ParameterOIDs) - buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOIDs, arguments, nil) + err := c.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid) if err != nil { return nil, err } - buf = appendExecute(buf, "", 0) - - return buf, nil -} - -func (c *Conn) initContext(ctx context.Context) error { - if c.ctxInProgress { - return errors.New("ctx already in progress") - } - if ctx.Done() == nil { + var fields []pgtype.CompositeCodecField + var fieldName string + var fieldOID uint32 + rows, _ := c.Query(ctx, `select attname, atttypid +from pg_attribute +where attrelid=$1 + and not attisdropped + and attnum > 0 +order by attnum`, + typrelid, + ) + _, err = ForEachRow(rows, []any{&fieldName, &fieldOID}, func() error { + dt, ok := c.TypeMap().TypeForOID(fieldOID) + if !ok { + return fmt.Errorf("unknown composite type field OID: %v", fieldOID) + } + fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt}) return nil + }) + if err != nil { + return nil, err } - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - c.ctxInProgress = true - - go c.contextHandler(ctx) - - return nil + return fields, nil } -func (c *Conn) termContext(opErr error) error { - if !c.ctxInProgress { - return opErr +func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error { + if txStatus := c.pgConn.TxStatus(); txStatus != 'I' && txStatus != 'T' { + return nil } - var err error - - select { - case err = <-c.closedChan: - if opErr == nil { - err = nil - } - case c.doneChan <- struct{}{}: - err = opErr + if c.descriptionCache != nil { + c.descriptionCache.RemoveInvalidated() } - c.ctxInProgress = false - return err -} - -func (c *Conn) contextHandler(ctx context.Context) { - select { - case <-ctx.Done(): - c.cancelQuery() - c.closedChan <- ctx.Err() - case <-c.doneChan: + var invalidatedStatements []*pgconn.StatementDescription + if c.statementCache != nil { + invalidatedStatements = c.statementCache.GetInvalidated() } -} -func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error { - if atomic.LoadInt32(&c.cancelQueryInProgress) == 0 { + if len(invalidatedStatements) == 0 { return nil } - select { - case <-c.cancelQueryCompleted: - atomic.StoreInt32(&c.cancelQueryInProgress, 0) - if err := c.conn.SetDeadline(time.Time{}); err != nil { - c.Close() // Close connection if unable to disable deadline - return err - } - return nil - case <-ctx.Done(): - return ctx.Err() - } -} - -func (c *Conn) ensureConnectionReadyForQuery() error { - for c.pendingReadyForQueryCount > 0 { - msg, err := c.rxMsg() - if err != nil { - return err - } + pipeline := c.pgConn.StartPipeline(ctx) + defer pipeline.Close() - switch msg := msg.(type) { - case *pgproto3.ErrorResponse: - pgErr := c.rxErrorResponse(msg) - if pgErr.Severity == "FATAL" { - return pgErr - } - default: - err = c.processContextFreeMsg(msg) - if err != nil { - return err - } - } + for _, sd := range invalidatedStatements { + pipeline.SendDeallocate(sd.Name) } - return nil -} - -func connInfoFromRows(rows *Rows, err error) (map[string]pgtype.OID, error) { + err := pipeline.Sync() if err != nil { - return nil, err + return fmt.Errorf("failed to deallocate cached statement(s): %w", err) } - defer rows.Close() - nameOIDs := make(map[string]pgtype.OID, 256) - for rows.Next() { - var oid pgtype.OID - var name pgtype.Text - if err = rows.Scan(&oid, &name); err != nil { - return nil, err - } - - nameOIDs[name.String] = oid + err = pipeline.Close() + if err != nil { + return fmt.Errorf("failed to deallocate cached statement(s): %w", err) } - if err = rows.Err(); err != nil { - return nil, err + c.statementCache.RemoveInvalidated() + for _, sd := range invalidatedStatements { + delete(c.preparedStatements, sd.Name) } - return nameOIDs, err + return nil } diff --git a/conn_config_test.go.example b/conn_config_test.go.example deleted file mode 100644 index 096e13548..000000000 --- a/conn_config_test.go.example +++ /dev/null @@ -1,79 +0,0 @@ -package pgx_test - -import ( - // "crypto/tls" - // "crypto/x509" - // "fmt" - // "go/build" - // "io/ioutil" - // "path" - - "github.com/jackc/pgx" -) - -var defaultConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} - -// To skip tests for specific connection / authentication types set that connection param to nil -var tcpConnConfig *pgx.ConnConfig = nil -var unixSocketConnConfig *pgx.ConnConfig = nil -var md5ConnConfig *pgx.ConnConfig = nil -var plainPasswordConnConfig *pgx.ConnConfig = nil -var invalidUserConnConfig *pgx.ConnConfig = nil -var tlsConnConfig *pgx.ConnConfig = nil -var customDialerConnConfig *pgx.ConnConfig = nil -var replicationConnConfig *pgx.ConnConfig = nil -var cratedbConnConfig *pgx.ConnConfig = nil - -// var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} -// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"} -// var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} -// var plainPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"} -// var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"} -// var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} -// var replicationConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_replication", Password: "secret", Database: "pgx_test"} - -// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}} -// -//// or to test client certs: -// -// var tlsConnConfig *pgx.ConnConfig -// -// func init() { -// homeDir := build.Default.GOPATH -// tlsConnConfig = &pgx.ConnConfig{ -// Host: "127.0.0.1", -// User: "pgx_md5", -// Password: "secret", -// Database: "pgx_test", -// TLSConfig: &tls.Config{ -// InsecureSkipVerify: true, -// }, -// } -// caCertPool := x509.NewCertPool() -// -// caPath := path.Join(homeDir, "/src/github.com/jackc/pgx/rootCA.pem") -// caCert, err := ioutil.ReadFile(caPath) -// if err != nil { -// panic(fmt.Sprintf("unable to read CA file: %v", err)) -// } -// -// if !caCertPool.AppendCertsFromPEM(caCert) { -// panic("unable to add CA to cert pool") -// } -// -// tlsConnConfig.TLSConfig.RootCAs = caCertPool -// tlsConnConfig.TLSConfig.ClientCAs = caCertPool -// -// sslCert := path.Join(homeDir, "/src/github.com/jackc/pgx/pg_md5.crt") -// sslKey := path.Join(homeDir, "/src/github.com/jackc/pgx/pg_md5.key") -// if (sslCert != "" && sslKey == "") || (sslCert == "" && sslKey != "") { -// panic(`both "sslcert" and "sslkey" are required`) -// } -// -// cert, err := tls.LoadX509KeyPair(sslCert, sslKey) -// if err != nil { -// panic(fmt.Sprintf("unable to read cert: %v", err)) -// } -// -// tlsConnConfig.TLSConfig.Certificates = []tls.Certificate{cert} -// } diff --git a/conn_config_test.go.travis b/conn_config_test.go.travis deleted file mode 100644 index cf29a7437..000000000 --- a/conn_config_test.go.travis +++ /dev/null @@ -1,36 +0,0 @@ -package pgx_test - -import ( - "crypto/tls" - "github.com/jackc/pgx" - "os" - "strconv" -) - -var defaultConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} -var tcpConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} -var unixSocketConnConfig = &pgx.ConnConfig{Host: "/var/run/postgresql", User: "postgres", Database: "pgx_test"} -var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} -var plainPasswordConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"} -var invalidUserConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"} -var tlsConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_ssl", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}} -var customDialerConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} -var replicationConnConfig *pgx.ConnConfig = nil -var cratedbConnConfig *pgx.ConnConfig = nil - -func init() { - pgVersion := os.Getenv("PGVERSION") - - if len(pgVersion) > 0 { - v, err := strconv.ParseFloat(pgVersion, 64) - if err == nil && v >= 9.6 { - replicationConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_replication", Password: "secret", Database: "pgx_test"} - } - } - - crateVersion := os.Getenv("CRATEVERSION") - if crateVersion != "" { - cratedbConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", Port: 6543, User: "pgx", Password: "", Database: "pgx_test"} - } -} - diff --git a/conn_internal_test.go b/conn_internal_test.go new file mode 100644 index 000000000..d3127ef7d --- /dev/null +++ b/conn_internal_test.go @@ -0,0 +1,55 @@ +package pgx + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func mustParseConfig(t testing.TB, connString string) *ConnConfig { + config, err := ParseConfig(connString) + require.Nil(t, err) + return config +} + +func mustConnect(t testing.TB, config *ConnConfig) *Conn { + conn, err := ConnectConfig(context.Background(), config) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + return conn +} + +// Ensures the connection limits the size of its cached objects. +// This test examines the internals of *Conn so must be in the same package. +func TestStmtCacheSizeLimit(t *testing.T) { + const cacheLimit = 16 + + connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + connConfig.StatementCacheCapacity = cacheLimit + conn := mustConnect(t, connConfig) + defer func() { + err := conn.Close(context.Background()) + if err != nil { + t.Fatal(err) + } + }() + + // run a set of unique queries that should overflow the cache + ctx := context.Background() + for i := 0; i < cacheLimit*2; i++ { + uniqueString := fmt.Sprintf("unique %d", i) + uniqueSQL := fmt.Sprintf("select '%s'", uniqueString) + var output string + err := conn.QueryRow(ctx, uniqueSQL).Scan(&output) + require.NoError(t, err) + require.Equal(t, uniqueString, output) + } + // preparedStatements contains cacheLimit+1 because deallocation happens before the query + assert.Len(t, conn.preparedStatements, cacheLimit+1) + assert.Equal(t, cacheLimit, conn.statementCache.Len()) +} diff --git a/conn_pool.go b/conn_pool.go deleted file mode 100644 index 6ca0ee018..000000000 --- a/conn_pool.go +++ /dev/null @@ -1,549 +0,0 @@ -package pgx - -import ( - "context" - "sync" - "time" - - "github.com/pkg/errors" - - "github.com/jackc/pgx/pgtype" -) - -type ConnPoolConfig struct { - ConnConfig - MaxConnections int // max simultaneous connections to use, default 5, must be at least 2 - AfterConnect func(*Conn) error // function to call on every new connection - AcquireTimeout time.Duration // max wait time when all connections are busy (0 means no timeout) -} - -type ConnPool struct { - allConnections []*Conn - availableConnections []*Conn - cond *sync.Cond - config ConnConfig // config used when establishing connection - inProgressConnects int - maxConnections int - resetCount int - afterConnect func(*Conn) error - logger Logger - logLevel int - closed bool - preparedStatements map[string]*PreparedStatement - acquireTimeout time.Duration - connInfo *pgtype.ConnInfo -} - -type ConnPoolStat struct { - MaxConnections int // max simultaneous connections to use - CurrentConnections int // current live connections - AvailableConnections int // unused live connections -} - -// ErrAcquireTimeout occurs when an attempt to acquire a connection times out. -var ErrAcquireTimeout = errors.New("timeout acquiring connection from pool") - -// ErrClosedPool occurs on an attempt to acquire a connection from a closed pool. -var ErrClosedPool = errors.New("cannot acquire from closed pool") - -// NewConnPool creates a new ConnPool. config.ConnConfig is passed through to -// Connect directly. -func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { - p = new(ConnPool) - p.config = config.ConnConfig - p.connInfo = minimalConnInfo - p.maxConnections = config.MaxConnections - if p.maxConnections == 0 { - p.maxConnections = 5 - } - if p.maxConnections < 1 { - return nil, errors.New("MaxConnections must be at least 1") - } - p.acquireTimeout = config.AcquireTimeout - if p.acquireTimeout < 0 { - return nil, errors.New("AcquireTimeout must be equal to or greater than 0") - } - - p.afterConnect = config.AfterConnect - - if config.LogLevel != 0 { - p.logLevel = config.LogLevel - } else { - // Preserve pre-LogLevel behavior by defaulting to LogLevelDebug - p.logLevel = LogLevelDebug - } - p.logger = config.Logger - if p.logger == nil { - p.logLevel = LogLevelNone - } - - p.allConnections = make([]*Conn, 0, p.maxConnections) - p.availableConnections = make([]*Conn, 0, p.maxConnections) - p.preparedStatements = make(map[string]*PreparedStatement) - p.cond = sync.NewCond(new(sync.Mutex)) - - // Initially establish one connection - var c *Conn - c, err = p.createConnection() - if err != nil { - return - } - p.allConnections = append(p.allConnections, c) - p.availableConnections = append(p.availableConnections, c) - p.connInfo = c.ConnInfo.DeepCopy() - - return -} - -// Acquire takes exclusive use of a connection until it is released. -func (p *ConnPool) Acquire() (*Conn, error) { - p.cond.L.Lock() - c, err := p.acquire(nil) - p.cond.L.Unlock() - return c, err -} - -// deadlinePassed returns true if the given deadline has passed. -func (p *ConnPool) deadlinePassed(deadline *time.Time) bool { - return deadline != nil && time.Now().After(*deadline) -} - -// acquire performs acquision assuming pool is already locked -func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { - if p.closed { - return nil, ErrClosedPool - } - - // A connection is available - if len(p.availableConnections) > 0 { - c := p.availableConnections[len(p.availableConnections)-1] - c.poolResetCount = p.resetCount - p.availableConnections = p.availableConnections[:len(p.availableConnections)-1] - return c, nil - } - - // Set initial timeout/deadline value. If the method (acquire) happens to - // recursively call itself the deadline should retain its value. - if deadline == nil && p.acquireTimeout > 0 { - tmp := time.Now().Add(p.acquireTimeout) - deadline = &tmp - } - - // Make sure the deadline (if it is) has not passed yet - if p.deadlinePassed(deadline) { - return nil, ErrAcquireTimeout - } - - // If there is a deadline then start a timeout timer - var timer *time.Timer - if deadline != nil { - timer = time.AfterFunc(deadline.Sub(time.Now()), func() { - p.cond.Broadcast() - }) - defer timer.Stop() - } - - // No connections are available, but we can create more - if len(p.allConnections)+p.inProgressConnects < p.maxConnections { - // Create a new connection. - // Careful here: createConnectionUnlocked() removes the current lock, - // creates a connection and then locks it back. - c, err := p.createConnectionUnlocked() - if err != nil { - return nil, err - } - c.poolResetCount = p.resetCount - p.allConnections = append(p.allConnections, c) - return c, nil - } - // All connections are in use and we cannot create more - if p.logLevel >= LogLevelWarn { - p.logger.Log(LogLevelWarn, "waiting for available connection", nil) - } - - // Wait until there is an available connection OR room to create a new connection - for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections { - if p.deadlinePassed(deadline) { - return nil, ErrAcquireTimeout - } - p.cond.Wait() - } - - // Stop the timer so that we do not spawn it on every acquire call. - if timer != nil { - timer.Stop() - } - return p.acquire(deadline) -} - -// Release gives up use of a connection. -func (p *ConnPool) Release(conn *Conn) { - if conn.ctxInProgress { - panic("should never release when context is in progress") - } - - if conn.txStatus != 'I' { - conn.Exec("rollback") - } - - if len(conn.channels) > 0 { - if err := conn.Unlisten("*"); err != nil { - conn.die(err) - } - conn.channels = make(map[string]struct{}) - } - conn.notifications = nil - - p.cond.L.Lock() - - if conn.poolResetCount != p.resetCount { - conn.Close() - p.cond.L.Unlock() - p.cond.Signal() - return - } - - if conn.IsAlive() { - p.availableConnections = append(p.availableConnections, conn) - } else { - p.removeFromAllConnections(conn) - } - p.cond.L.Unlock() - p.cond.Signal() -} - -// removeFromAllConnections Removes the given connection from the list. -// It returns true if the connection was found and removed or false otherwise. -func (p *ConnPool) removeFromAllConnections(conn *Conn) bool { - for i, c := range p.allConnections { - if conn == c { - p.allConnections = append(p.allConnections[:i], p.allConnections[i+1:]...) - return true - } - } - return false -} - -// Close ends the use of a connection pool. It prevents any new connections from -// being acquired and closes available underlying connections. Any acquired -// connections will be closed when they are released. -func (p *ConnPool) Close() { - p.cond.L.Lock() - defer p.cond.L.Unlock() - - p.closed = true - - for _, c := range p.availableConnections { - _ = c.Close() - } - - // This will cause any checked out connections to be closed on release - p.resetCount++ -} - -// Reset closes all open connections, but leaves the pool open. It is intended -// for use when an error is detected that would disrupt all connections (such as -// a network interruption or a server state change). -// -// It is safe to reset a pool while connections are checked out. Those -// connections will be closed when they are returned to the pool. -func (p *ConnPool) Reset() { - p.cond.L.Lock() - defer p.cond.L.Unlock() - - p.resetCount++ - p.allConnections = p.allConnections[0:0] - - for _, conn := range p.availableConnections { - conn.Close() - } - - p.availableConnections = p.availableConnections[0:0] -} - -// invalidateAcquired causes all acquired connections to be closed when released. -// The pool must already be locked. -func (p *ConnPool) invalidateAcquired() { - p.resetCount++ - - for _, c := range p.availableConnections { - c.poolResetCount = p.resetCount - } - - p.allConnections = p.allConnections[:len(p.availableConnections)] - copy(p.allConnections, p.availableConnections) -} - -// Stat returns connection pool statistics -func (p *ConnPool) Stat() (s ConnPoolStat) { - p.cond.L.Lock() - defer p.cond.L.Unlock() - - s.MaxConnections = p.maxConnections - s.CurrentConnections = len(p.allConnections) - s.AvailableConnections = len(p.availableConnections) - return -} - -func (p *ConnPool) createConnection() (*Conn, error) { - c, err := connect(p.config, p.connInfo) - if err != nil { - return nil, err - } - return p.afterConnectionCreated(c) -} - -// createConnectionUnlocked Removes the current lock, creates a new connection, and -// then locks it back. -// Here is the point: lets say our pool dialer's OpenTimeout is set to 3 seconds. -// And we have a pool with 20 connections in it, and we try to acquire them all at -// startup. -// If it happens that the remote server is not accessible, then the first connection -// in the pool blocks all the others for 3 secs, before it gets the timeout. Then -// connection #2 holds the lock and locks everything for the next 3 secs until it -// gets OpenTimeout err, etc. And the very last 20th connection will fail only after -// 3 * 20 = 60 secs. -// To avoid this we put Connect(p.config) outside of the lock (it is thread safe) -// what would allow us to make all the 20 connection in parallel (more or less). -func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { - p.inProgressConnects++ - p.cond.L.Unlock() - c, err := Connect(p.config) - p.cond.L.Lock() - p.inProgressConnects-- - - if err != nil { - return nil, err - } - return p.afterConnectionCreated(c) -} - -// afterConnectionCreated executes (if it is) afterConnect() callback and prepares -// all the known statements for the new connection. -func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) { - if p.afterConnect != nil { - err := p.afterConnect(c) - if err != nil { - c.die(err) - return nil, err - } - } - - for _, ps := range p.preparedStatements { - if _, err := c.Prepare(ps.Name, ps.SQL); err != nil { - c.die(err) - return nil, err - } - } - - return c, nil -} - -// Exec acquires a connection, delegates the call to that connection, and releases the connection -func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { - var c *Conn - if c, err = p.Acquire(); err != nil { - return - } - defer p.Release(c) - - return c.Exec(sql, arguments...) -} - -func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { - var c *Conn - if c, err = p.Acquire(); err != nil { - return - } - defer p.Release(c) - - return c.ExecEx(ctx, sql, options, arguments...) -} - -// Query acquires a connection and delegates the call to that connection. When -// *Rows are closed, the connection is released automatically. -func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) { - c, err := p.Acquire() - if err != nil { - // Because checking for errors can be deferred to the *Rows, build one with the error - return &Rows{closed: true, err: err}, err - } - - rows, err := c.Query(sql, args...) - if err != nil { - p.Release(c) - return rows, err - } - - rows.connPool = p - - return rows, nil -} - -func (p *ConnPool) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) { - c, err := p.Acquire() - if err != nil { - // Because checking for errors can be deferred to the *Rows, build one with the error - return &Rows{closed: true, err: err}, err - } - - rows, err := c.QueryEx(ctx, sql, options, args...) - if err != nil { - p.Release(c) - return rows, err - } - - rows.connPool = p - - return rows, nil -} - -// QueryRow acquires a connection and delegates the call to that connection. The -// connection is released automatically after Scan is called on the returned -// *Row. -func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row { - rows, _ := p.Query(sql, args...) - return (*Row)(rows) -} - -func (p *ConnPool) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row { - rows, _ := p.QueryEx(ctx, sql, options, args...) - return (*Row)(rows) -} - -// Begin acquires a connection and begins a transaction on it. When the -// transaction is closed the connection will be automatically released. -func (p *ConnPool) Begin() (*Tx, error) { - return p.BeginEx(context.Background(), nil) -} - -// Prepare creates a prepared statement on a connection in the pool to test the -// statement is valid. If it succeeds all connections accessed through the pool -// will have the statement available. -// -// Prepare creates a prepared statement with name and sql. sql can contain -// placeholders for bound parameters. These placeholders are referenced -// positional as $1, $2, etc. -// -// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with -// the same name and sql arguments. This allows a code path to Prepare and -// Query/Exec/PrepareEx without concern for if the statement has already been prepared. -func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) { - return p.PrepareEx(context.Background(), name, sql, nil) -} - -// PrepareEx creates a prepared statement on a connection in the pool to test the -// statement is valid. If it succeeds all connections accessed through the pool -// will have the statement available. -// -// PrepareEx creates a prepared statement with name and sql. sql can contain placeholders -// for bound parameters. These placeholders are referenced positional as $1, $2, etc. -// It defers from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct -// -// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same -// name and sql arguments. This allows a code path to PrepareEx and Query/Exec/Prepare without -// concern for if the statement has already been prepared. -func (p *ConnPool) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) { - p.cond.L.Lock() - defer p.cond.L.Unlock() - - if ps, ok := p.preparedStatements[name]; ok && ps.SQL == sql { - return ps, nil - } - - c, err := p.acquire(nil) - if err != nil { - return nil, err - } - - p.availableConnections = append(p.availableConnections, c) - - // Double check that the statement was not prepared by someone else - // while we were acquiring the connection (since acquire is not fully - // blocking now, see createConnectionUnlocked()) - if ps, ok := p.preparedStatements[name]; ok && ps.SQL == sql { - return ps, nil - } - - ps, err := c.PrepareEx(ctx, name, sql, opts) - if err != nil { - return nil, err - } - - for _, c := range p.availableConnections { - _, err := c.PrepareEx(ctx, name, sql, opts) - if err != nil { - return nil, err - } - } - - p.invalidateAcquired() - p.preparedStatements[name] = ps - - return ps, err -} - -// Deallocate releases a prepared statement from all connections in the pool. -func (p *ConnPool) Deallocate(name string) (err error) { - p.cond.L.Lock() - defer p.cond.L.Unlock() - - for _, c := range p.availableConnections { - if err := c.Deallocate(name); err != nil { - return err - } - } - - p.invalidateAcquired() - delete(p.preparedStatements, name) - - return nil -} - -// BeginEx acquires a connection and starts a transaction with txOptions -// determining the transaction mode. When the transaction is closed the -// connection will be automatically released. -func (p *ConnPool) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) { - for { - c, err := p.Acquire() - if err != nil { - return nil, err - } - - tx, err := c.BeginEx(ctx, txOptions) - if err != nil { - alive := c.IsAlive() - p.Release(c) - - // If connection is still alive then the error is not something trying - // again on a new connection would fix, so just return the error. But - // if the connection is dead try to acquire a new connection and try - // again. - if alive || ctx.Err() != nil { - return nil, err - } - continue - } - - tx.connPool = p - return tx, nil - } -} - -// CopyFrom acquires a connection, delegates the call to that connection, and releases the connection -func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { - c, err := p.Acquire() - if err != nil { - return 0, err - } - defer p.Release(c) - - return c.CopyFrom(tableName, columnNames, rowSrc) -} - -// BeginBatch acquires a connection and begins a batch on that connection. When -// *Batch is finished, the connection is released automatically. -func (p *ConnPool) BeginBatch() *Batch { - c, err := p.Acquire() - return &Batch{conn: c, connPool: p, err: err} -} diff --git a/conn_pool_private_test.go b/conn_pool_private_test.go deleted file mode 100644 index ef0ec1dea..000000000 --- a/conn_pool_private_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package pgx - -import ( - "testing" -) - -func compareConnSlices(slice1, slice2 []*Conn) bool { - if len(slice1) != len(slice2) { - return false - } - for i, c := range slice1 { - if c != slice2[i] { - return false - } - } - return true -} - -func TestConnPoolRemoveFromAllConnections(t *testing.T) { - t.Parallel() - pool := ConnPool{} - conn1 := &Conn{} - conn2 := &Conn{} - conn3 := &Conn{} - - // First element - pool.allConnections = []*Conn{conn1, conn2, conn3} - pool.removeFromAllConnections(conn1) - if !compareConnSlices(pool.allConnections, []*Conn{conn2, conn3}) { - t.Fatal("First element test failed") - } - // Element somewhere in the middle - pool.allConnections = []*Conn{conn1, conn2, conn3} - pool.removeFromAllConnections(conn2) - if !compareConnSlices(pool.allConnections, []*Conn{conn1, conn3}) { - t.Fatal("Middle element test failed") - } - // Last element - pool.allConnections = []*Conn{conn1, conn2, conn3} - pool.removeFromAllConnections(conn3) - if !compareConnSlices(pool.allConnections, []*Conn{conn1, conn2}) { - t.Fatal("Last element test failed") - } -} diff --git a/conn_pool_test.go b/conn_pool_test.go deleted file mode 100644 index 84a74aed3..000000000 --- a/conn_pool_test.go +++ /dev/null @@ -1,1083 +0,0 @@ -package pgx_test - -import ( - "context" - "fmt" - "net" - "sync" - "testing" - "time" - - "github.com/pkg/errors" - - "github.com/jackc/pgx" -) - -func createConnPool(t *testing.T, maxConnections int) *pgx.ConnPool { - config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: maxConnections} - pool, err := pgx.NewConnPool(config) - if err != nil { - t.Fatalf("Unable to create connection pool: %v", err) - } - return pool -} - -func acquireAllConnections(t *testing.T, pool *pgx.ConnPool, maxConnections int) []*pgx.Conn { - connections := make([]*pgx.Conn, maxConnections) - for i := 0; i < maxConnections; i++ { - var err error - if connections[i], err = pool.Acquire(); err != nil { - t.Fatalf("Unable to acquire connection: %v", err) - } - } - return connections -} - -func releaseAllConnections(pool *pgx.ConnPool, connections []*pgx.Conn) { - for _, c := range connections { - pool.Release(c) - } -} - -func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) { - startTime := time.Now() - c, err := pool.Acquire() - return c, time.Since(startTime), err -} - -func TestNewConnPool(t *testing.T) { - t.Parallel() - - var numCallbacks int - afterConnect := func(c *pgx.Conn) error { - numCallbacks++ - return nil - } - - config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 2, AfterConnect: afterConnect} - pool, err := pgx.NewConnPool(config) - if err != nil { - t.Fatal("Unable to establish connection pool") - } - defer pool.Close() - - // It initially connects once - stat := pool.Stat() - if stat.CurrentConnections != 1 { - t.Errorf("Expected 1 connection to be established immediately, but %v were", numCallbacks) - } - - // Pool creation returns an error if any AfterConnect callback does - errAfterConnect := errors.New("Some error") - afterConnect = func(c *pgx.Conn) error { - return errAfterConnect - } - - config = pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 2, AfterConnect: afterConnect} - pool, err = pgx.NewConnPool(config) - if err != errAfterConnect { - t.Errorf("Expected errAfterConnect but received unexpected: %v", err) - } -} - -func TestNewConnPoolDefaultsTo5MaxConnections(t *testing.T) { - t.Parallel() - - config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig} - pool, err := pgx.NewConnPool(config) - if err != nil { - t.Fatal("Unable to establish connection pool") - } - defer pool.Close() - - if n := pool.Stat().MaxConnections; n != 5 { - t.Fatalf("Expected pool to default to 5 max connections, but it was %d", n) - } -} - -func TestPoolAcquireAndReleaseCycle(t *testing.T) { - t.Parallel() - - maxConnections := 2 - incrementCount := int32(100) - completeSync := make(chan int) - pool := createConnPool(t, maxConnections) - defer pool.Close() - - allConnections := acquireAllConnections(t, pool, maxConnections) - - for _, c := range allConnections { - mustExec(t, c, "create temporary table t(counter integer not null)") - mustExec(t, c, "insert into t(counter) values(0);") - } - - releaseAllConnections(pool, allConnections) - - f := func() { - conn, err := pool.Acquire() - if err != nil { - t.Fatal("Unable to acquire connection") - } - defer pool.Release(conn) - - // Increment counter... - mustExec(t, conn, "update t set counter = counter + 1") - completeSync <- 0 - } - - for i := int32(0); i < incrementCount; i++ { - go f() - } - - // Wait for all f() to complete - for i := int32(0); i < incrementCount; i++ { - <-completeSync - } - - // Check that temp table in each connection has been incremented some number of times - actualCount := int32(0) - allConnections = acquireAllConnections(t, pool, maxConnections) - - for _, c := range allConnections { - var n int32 - c.QueryRow("select counter from t").Scan(&n) - if n == 0 { - t.Error("A connection was never used") - } - - actualCount += n - } - - if actualCount != incrementCount { - fmt.Println(actualCount) - t.Error("Wrong number of increments") - } - - releaseAllConnections(pool, allConnections) -} - -func TestPoolNonBlockingConnections(t *testing.T) { - t.Parallel() - - var dialCountLock sync.Mutex - dialCount := 0 - openTimeout := 1 * time.Second - testDialer := func(network, address string) (net.Conn, error) { - var firstDial bool - dialCountLock.Lock() - dialCount++ - firstDial = dialCount == 1 - dialCountLock.Unlock() - - if firstDial { - return net.Dial(network, address) - } else { - time.Sleep(openTimeout) - return nil, &net.OpError{Op: "dial", Net: "tcp"} - } - } - - maxConnections := 3 - config := pgx.ConnPoolConfig{ - ConnConfig: *defaultConnConfig, - MaxConnections: maxConnections, - } - config.ConnConfig.Dial = testDialer - - pool, err := pgx.NewConnPool(config) - if err != nil { - t.Fatalf("Expected NewConnPool not to fail, instead it failed with: %v", err) - } - defer pool.Close() - - // NewConnPool establishes an initial connection - // so we need to close that for the rest of the test - if conn, err := pool.Acquire(); err == nil { - conn.Close() - pool.Release(conn) - } else { - t.Fatalf("pool.Acquire unexpectedly failed: %v", err) - } - - var wg sync.WaitGroup - wg.Add(maxConnections) - - startedAt := time.Now() - for i := 0; i < maxConnections; i++ { - go func() { - _, err := pool.Acquire() - wg.Done() - if err == nil { - t.Fatal("Acquire() expected to fail but it did not") - } - }() - } - wg.Wait() - - // Prior to createConnectionUnlocked() use the test took - // maxConnections * openTimeout seconds to complete. - // With createConnectionUnlocked() it takes ~ 1 * openTimeout seconds. - timeTaken := time.Since(startedAt) - if timeTaken > openTimeout+1*time.Second { - t.Fatalf("Expected all Acquire() to run in parallel and take about %v, instead it took '%v'", openTimeout, timeTaken) - } - -} - -func TestAcquireTimeoutSanity(t *testing.T) { - t.Parallel() - - config := pgx.ConnPoolConfig{ - ConnConfig: *defaultConnConfig, - MaxConnections: 1, - } - - // case 1: default 0 value - pool, err := pgx.NewConnPool(config) - if err != nil { - t.Fatalf("Expected NewConnPool with default config.AcquireTimeout not to fail, instead it failed with '%v'", err) - } - pool.Close() - - // case 2: negative value - config.AcquireTimeout = -1 * time.Second - _, err = pgx.NewConnPool(config) - if err == nil { - t.Fatal("Expected NewConnPool with negative config.AcquireTimeout to fail, instead it did not") - } - - // case 3: positive value - config.AcquireTimeout = 1 * time.Second - pool, err = pgx.NewConnPool(config) - if err != nil { - t.Fatalf("Expected NewConnPool with positive config.AcquireTimeout not to fail, instead it failed with '%v'", err) - } - defer pool.Close() -} - -func TestPoolWithAcquireTimeoutSet(t *testing.T) { - t.Parallel() - - connAllocTimeout := 2 * time.Second - config := pgx.ConnPoolConfig{ - ConnConfig: *defaultConnConfig, - MaxConnections: 1, - AcquireTimeout: connAllocTimeout, - } - - pool, err := pgx.NewConnPool(config) - if err != nil { - t.Fatalf("Unable to create connection pool: %v", err) - } - defer pool.Close() - - // Consume all connections ... - allConnections := acquireAllConnections(t, pool, config.MaxConnections) - defer releaseAllConnections(pool, allConnections) - - // ... then try to consume 1 more. It should fail after a short timeout. - _, timeTaken, err := acquireWithTimeTaken(pool) - - if err == nil || err != pgx.ErrAcquireTimeout { - t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) - } - if timeTaken < connAllocTimeout { - t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken) - } -} - -func TestPoolWithoutAcquireTimeoutSet(t *testing.T) { - t.Parallel() - - maxConnections := 1 - pool := createConnPool(t, maxConnections) - defer pool.Close() - - // Consume all connections ... - allConnections := acquireAllConnections(t, pool, maxConnections) - - // ... then try to consume 1 more. It should hang forever. - // To unblock it we release the previously taken connection in a goroutine. - stopDeadWaitTimeout := 5 * time.Second - timer := time.AfterFunc(stopDeadWaitTimeout+100*time.Millisecond, func() { - releaseAllConnections(pool, allConnections) - }) - defer timer.Stop() - - conn, timeTaken, err := acquireWithTimeTaken(pool) - if err == nil { - pool.Release(conn) - } else { - t.Fatalf("Expected error to be nil, instead it was '%v'", err) - } - if timeTaken < stopDeadWaitTimeout { - t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", stopDeadWaitTimeout, timeTaken) - } -} - -func TestPoolErrClosedPool(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 1) - // Intentionaly close the pool now so we can test ErrClosedPool - pool.Close() - - c, err := pool.Acquire() - if c != nil { - t.Fatalf("Expected acquired connection to be nil, instead it was '%v'", c) - } - - if err == nil || err != pgx.ErrClosedPool { - t.Fatalf("Expected error to be pgx.ErrClosedPool, instead it was '%v'", err) - } -} - -func TestPoolReleaseWithTransactions(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 2) - defer pool.Close() - - conn, err := pool.Acquire() - if err != nil { - t.Fatalf("Unable to acquire connection: %v", err) - } - mustExec(t, conn, "begin") - if _, err = conn.Exec("selct"); err == nil { - t.Fatal("Did not receive expected error") - } - - if conn.TxStatus() != 'E' { - t.Fatalf("Expected TxStatus to be 'E', instead it was '%c'", conn.TxStatus()) - } - - pool.Release(conn) - - if conn.TxStatus() != 'I' { - t.Fatalf("Expected release to rollback errored transaction, but it did not: '%c'", conn.TxStatus()) - } - - conn, err = pool.Acquire() - if err != nil { - t.Fatalf("Unable to acquire connection: %v", err) - } - mustExec(t, conn, "begin") - if conn.TxStatus() != 'T' { - t.Fatalf("Expected txStatus to be 'T', instead it was '%c'", conn.TxStatus()) - } - - pool.Release(conn) - - if conn.TxStatus() != 'I' { - t.Fatalf("Expected release to rollback uncommitted transaction, but it did not: '%c'", conn.TxStatus()) - } -} - -func TestPoolAcquireAndReleaseCycleAutoConnect(t *testing.T) { - t.Parallel() - - maxConnections := 3 - pool := createConnPool(t, maxConnections) - defer pool.Close() - - doSomething := func() { - c, err := pool.Acquire() - if err != nil { - t.Fatalf("Unable to Acquire: %v", err) - } - rows, _ := c.Query("select 1, pg_sleep(0.02)") - rows.Close() - pool.Release(c) - } - - for i := 0; i < 10; i++ { - doSomething() - } - - stat := pool.Stat() - if stat.CurrentConnections != 1 { - t.Fatalf("Pool shouldn't have established more connections when no contention: %v", stat.CurrentConnections) - } - - var wg sync.WaitGroup - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - doSomething() - }() - } - wg.Wait() - - stat = pool.Stat() - if stat.CurrentConnections != stat.MaxConnections { - t.Fatalf("Pool should have used all possible connections: %v", stat.CurrentConnections) - } -} - -func TestPoolReleaseDiscardsDeadConnections(t *testing.T) { - t.Parallel() - - // Run timing sensitive test many times - for i := 0; i < 50; i++ { - func() { - maxConnections := 3 - pool := createConnPool(t, maxConnections) - defer pool.Close() - - var c1, c2 *pgx.Conn - var err error - var stat pgx.ConnPoolStat - - if c1, err = pool.Acquire(); err != nil { - t.Fatalf("Unexpected error acquiring connection: %v", err) - } - defer func() { - if c1 != nil { - pool.Release(c1) - } - }() - - if c2, err = pool.Acquire(); err != nil { - t.Fatalf("Unexpected error acquiring connection: %v", err) - } - defer func() { - if c2 != nil { - pool.Release(c2) - } - }() - - if _, err = c2.Exec("select pg_terminate_backend($1)", c1.PID()); err != nil { - t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) - } - - // do something with the connection so it knows it's dead - rows, _ := c1.Query("select 1") - rows.Close() - if rows.Err() == nil { - t.Fatal("Expected error but none occurred") - } - - if c1.IsAlive() { - t.Fatal("Expected connection to be dead but it wasn't") - } - - stat = pool.Stat() - if stat.CurrentConnections != 2 { - t.Fatalf("Unexpected CurrentConnections: %v", stat.CurrentConnections) - } - if stat.AvailableConnections != 0 { - t.Fatalf("Unexpected AvailableConnections: %v", stat.CurrentConnections) - } - - pool.Release(c1) - c1 = nil // so it doesn't get released again by the defer - - stat = pool.Stat() - if stat.CurrentConnections != 1 { - t.Fatalf("Unexpected CurrentConnections: %v", stat.CurrentConnections) - } - if stat.AvailableConnections != 0 { - t.Fatalf("Unexpected AvailableConnections: %v", stat.CurrentConnections) - } - }() - } -} - -func TestConnPoolResetClosesCheckedOutConnectionsOnRelease(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 5) - defer pool.Close() - - inProgressRows := []*pgx.Rows{} - var inProgressPIDs []int32 - - // Start some queries and reset pool while they are in progress - for i := 0; i < 10; i++ { - rows, err := pool.Query("select pg_backend_pid() union all select 1 union all select 2") - if err != nil { - t.Fatal(err) - } - - rows.Next() - var pid int32 - rows.Scan(&pid) - inProgressPIDs = append(inProgressPIDs, pid) - - inProgressRows = append(inProgressRows, rows) - pool.Reset() - } - - // Check that the queries are completed - for _, rows := range inProgressRows { - var expectedN int32 - - for rows.Next() { - expectedN++ - var n int32 - err := rows.Scan(&n) - if err != nil { - t.Fatal(err) - } - if expectedN != n { - t.Fatalf("Expected n to be %d, but it was %d", expectedN, n) - } - } - - if err := rows.Err(); err != nil { - t.Fatal(err) - } - } - - // pool should be in fresh state due to previous reset - stats := pool.Stat() - if stats.CurrentConnections != 0 || stats.AvailableConnections != 0 { - t.Fatalf("Unexpected connection pool stats: %v", stats) - } - - var connCount int - err := pool.QueryRow("select count(*) from pg_stat_activity where pid = any($1::int4[])", inProgressPIDs).Scan(&connCount) - if err != nil { - t.Fatal(err) - } - if connCount != 0 { - t.Fatalf("%d connections not closed", connCount) - } -} - -func TestConnPoolResetClosesCheckedInConnections(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 5) - defer pool.Close() - - inProgressRows := []*pgx.Rows{} - var inProgressPIDs []int32 - - // Start some queries and reset pool while they are in progress - for i := 0; i < 5; i++ { - rows, err := pool.Query("select pg_backend_pid()") - if err != nil { - t.Fatal(err) - } - - inProgressRows = append(inProgressRows, rows) - } - - // Check that the queries are completed - for _, rows := range inProgressRows { - for rows.Next() { - var pid int32 - err := rows.Scan(&pid) - if err != nil { - t.Fatal(err) - } - inProgressPIDs = append(inProgressPIDs, pid) - - } - - if err := rows.Err(); err != nil { - t.Fatal(err) - } - } - - // Ensure pool is fully connected and available - stats := pool.Stat() - if stats.CurrentConnections != 5 || stats.AvailableConnections != 5 { - t.Fatalf("Unexpected connection pool stats: %v", stats) - } - - pool.Reset() - - // Pool should be empty after reset - stats = pool.Stat() - if stats.CurrentConnections != 0 || stats.AvailableConnections != 0 { - t.Fatalf("Unexpected connection pool stats: %v", stats) - } - - var connCount int - err := pool.QueryRow("select count(*) from pg_stat_activity where pid = any($1::int4[])", inProgressPIDs).Scan(&connCount) - if err != nil { - t.Fatal(err) - } - if connCount != 0 { - t.Fatalf("%d connections not closed", connCount) - } -} - -func TestConnPoolTransaction(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 2) - defer pool.Close() - - stats := pool.Stat() - if stats.CurrentConnections != 1 || stats.AvailableConnections != 1 { - t.Fatalf("Unexpected connection pool stats: %v", stats) - } - - tx, err := pool.Begin() - if err != nil { - t.Fatalf("pool.Begin failed: %v", err) - } - defer tx.Rollback() - - var n int32 - err = tx.QueryRow("select 40+$1", 2).Scan(&n) - if err != nil { - t.Fatalf("tx.QueryRow Scan failed: %v", err) - } - if n != 42 { - t.Errorf("Expected 42, got %d", n) - } - - stats = pool.Stat() - if stats.CurrentConnections != 1 || stats.AvailableConnections != 0 { - t.Fatalf("Unexpected connection pool stats: %v", stats) - } - - err = tx.Rollback() - if err != nil { - t.Fatalf("tx.Rollback failed: %v", err) - } - - stats = pool.Stat() - if stats.CurrentConnections != 1 || stats.AvailableConnections != 1 { - t.Fatalf("Unexpected connection pool stats: %v", stats) - } -} - -func TestConnPoolTransactionIso(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 2) - defer pool.Close() - - tx, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) - if err != nil { - t.Fatalf("pool.BeginEx failed: %v", err) - } - defer tx.Rollback() - - var level string - err = tx.QueryRow("select current_setting('transaction_isolation')").Scan(&level) - if err != nil { - t.Fatalf("tx.QueryRow failed: %v", level) - } - - if level != "serializable" { - t.Errorf("Expected to be in isolation level %v but was %v", "serializable", level) - } -} - -func TestConnPoolBeginRetry(t *testing.T) { - t.Parallel() - - // Run timing sensitive test many times - for i := 0; i < 50; i++ { - func() { - pool := createConnPool(t, 2) - defer pool.Close() - - killerConn, err := pool.Acquire() - if err != nil { - t.Fatal(err) - } - defer pool.Release(killerConn) - - victimConn, err := pool.Acquire() - if err != nil { - t.Fatal(err) - } - pool.Release(victimConn) - - // Terminate connection that was released to pool - if _, err = killerConn.Exec("select pg_terminate_backend($1)", victimConn.PID()); err != nil { - t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) - } - - // Since victimConn is the only available connection in the pool, pool.Begin should - // try to use it, fail, and allocate another connection - tx, err := pool.Begin() - if err != nil { - t.Fatalf("pool.Begin failed: %v", err) - } - defer tx.Rollback() - - var txPID uint32 - err = tx.QueryRow("select pg_backend_pid()").Scan(&txPID) - if err != nil { - t.Fatalf("tx.QueryRow Scan failed: %v", err) - } - if txPID == victimConn.PID() { - t.Error("Expected txPID to defer from killed conn pid, but it didn't") - } - }() - } -} - -func TestConnPoolQuery(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 2) - defer pool.Close() - - var sum, rowCount int32 - - rows, err := pool.Query("select generate_series(1,$1)", 10) - if err != nil { - t.Fatalf("pool.Query failed: %v", err) - } - - stats := pool.Stat() - if stats.CurrentConnections != 1 || stats.AvailableConnections != 0 { - t.Fatalf("Unexpected connection pool stats: %v", stats) - } - - for rows.Next() { - var n int32 - rows.Scan(&n) - sum += n - rowCount++ - } - - if rows.Err() != nil { - t.Fatalf("conn.Query failed: %v", err) - } - - if rowCount != 10 { - t.Error("Select called onDataRow wrong number of times") - } - if sum != 55 { - t.Error("Wrong values returned") - } - - stats = pool.Stat() - if stats.CurrentConnections != 1 || stats.AvailableConnections != 1 { - t.Fatalf("Unexpected connection pool stats: %v", stats) - } -} - -func TestConnPoolQueryConcurrentLoad(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 10) - defer pool.Close() - - n := 100 - done := make(chan bool) - - for i := 0; i < n; i++ { - go func() { - defer func() { done <- true }() - var rowCount int32 - - rows, err := pool.Query("select generate_series(1,$1)", 1000) - if err != nil { - t.Fatalf("pool.Query failed: %v", err) - } - defer rows.Close() - - for rows.Next() { - var n int32 - err = rows.Scan(&n) - if err != nil { - t.Fatalf("rows.Scan failed: %v", err) - } - if n != rowCount+1 { - t.Fatalf("Expected n to be %d, but it was %d", rowCount+1, n) - } - rowCount++ - } - - if rows.Err() != nil { - t.Fatalf("conn.Query failed: %v", rows.Err()) - } - - if rowCount != 1000 { - t.Error("Select called onDataRow wrong number of times") - } - - _, err = pool.Exec("--;") - if err != nil { - t.Fatalf("pool.Exec failed: %v", err) - } - }() - } - - for i := 0; i < n; i++ { - <-done - } -} - -func TestConnPoolQueryRow(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 2) - defer pool.Close() - - var n int32 - err := pool.QueryRow("select 40+$1", 2).Scan(&n) - if err != nil { - t.Fatalf("pool.QueryRow Scan failed: %v", err) - } - - if n != 42 { - t.Errorf("Expected 42, got %d", n) - } - - stats := pool.Stat() - if stats.CurrentConnections != 1 || stats.AvailableConnections != 1 { - t.Fatalf("Unexpected connection pool stats: %v", stats) - } -} - -func TestConnPoolExec(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 2) - defer pool.Close() - - results, err := pool.Exec("create temporary table foo(id integer primary key);") - if err != nil { - t.Fatalf("Unexpected error from pool.Exec: %v", err) - } - if results != "CREATE TABLE" { - t.Errorf("Unexpected results from Exec: %v", results) - } - - results, err = pool.Exec("insert into foo(id) values($1)", 1) - if err != nil { - t.Fatalf("Unexpected error from pool.Exec: %v", err) - } - if results != "INSERT 0 1" { - t.Errorf("Unexpected results from Exec: %v", results) - } - - results, err = pool.Exec("drop table foo;") - if err != nil { - t.Fatalf("Unexpected error from pool.Exec: %v", err) - } - if results != "DROP TABLE" { - t.Errorf("Unexpected results from Exec: %v", results) - } -} - -func TestConnPoolPrepare(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 2) - defer pool.Close() - - _, err := pool.Prepare("test", "select $1::varchar") - if err != nil { - t.Fatalf("Unable to prepare statement: %v", err) - } - - var s string - err = pool.QueryRow("test", "hello").Scan(&s) - if err != nil { - t.Errorf("Executing prepared statement failed: %v", err) - } - - if s != "hello" { - t.Errorf("Prepared statement did not return expected value: %v", s) - } - - err = pool.Deallocate("test") - if err != nil { - t.Errorf("Deallocate failed: %v", err) - } - - err = pool.QueryRow("test", "hello").Scan(&s) - if err, ok := err.(pgx.PgError); !(ok && err.Code == "42601") { - t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err) - } -} - -func TestConnPoolPrepareDeallocatePrepare(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 2) - defer pool.Close() - - _, err := pool.Prepare("test", "select $1::varchar") - if err != nil { - t.Fatalf("Unable to prepare statement: %v", err) - } - err = pool.Deallocate("test") - if err != nil { - t.Fatalf("Unable to deallocate statement: %v", err) - } - _, err = pool.Prepare("test", "select $1::varchar") - if err != nil { - t.Fatalf("Unable to prepare statement: %v", err) - } - - var s string - err = pool.QueryRow("test", "hello").Scan(&s) - if err != nil { - t.Fatalf("Executing prepared statement failed: %v", err) - } - - if s != "hello" { - t.Errorf("Prepared statement did not return expected value: %v", s) - } -} - -func TestConnPoolPrepareWhenConnIsAlreadyAcquired(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 2) - defer pool.Close() - - testPreparedStatement := func(db queryRower, desc string) { - var s string - err := db.QueryRow("test", "hello").Scan(&s) - if err != nil { - t.Fatalf("%s. Executing prepared statement failed: %v", desc, err) - } - - if s != "hello" { - t.Fatalf("%s. Prepared statement did not return expected value: %v", desc, s) - } - } - - newReleaseOnce := func(c *pgx.Conn) func() { - var once sync.Once - return func() { - once.Do(func() { pool.Release(c) }) - } - } - - c1, err := pool.Acquire() - if err != nil { - t.Fatalf("Unable to acquire connection: %v", err) - } - c1Release := newReleaseOnce(c1) - defer c1Release() - - _, err = pool.Prepare("test", "select $1::varchar") - if err != nil { - t.Fatalf("Unable to prepare statement: %v", err) - } - - testPreparedStatement(pool, "pool") - - c1Release() - - c2, err := pool.Acquire() - if err != nil { - t.Fatalf("Unable to acquire connection: %v", err) - } - c2Release := newReleaseOnce(c2) - defer c2Release() - - // This conn will not be available and will be connection at this point - c3, err := pool.Acquire() - if err != nil { - t.Fatalf("Unable to acquire connection: %v", err) - } - c3Release := newReleaseOnce(c3) - defer c3Release() - - testPreparedStatement(c2, "c2") - testPreparedStatement(c3, "c3") - - c2Release() - c3Release() - - err = pool.Deallocate("test") - if err != nil { - t.Errorf("Deallocate failed: %v", err) - } - - var s string - err = pool.QueryRow("test", "hello").Scan(&s) - if err, ok := err.(pgx.PgError); !(ok && err.Code == "42601") { - t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err) - } -} - -func TestConnPoolBeginBatch(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 2) - defer pool.Close() - - batch := pool.BeginBatch() - batch.Queue("select n from generate_series(0,5) n", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) - batch.Queue("select n from generate_series(0,5) n", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) - - err := batch.Send(context.Background(), nil) - if err != nil { - t.Fatal(err) - } - - rows, err := batch.QueryResults() - if err != nil { - t.Error(err) - } - - for i := 0; rows.Next(); i++ { - var n int - if err := rows.Scan(&n); err != nil { - t.Error(err) - } - if n != i { - t.Errorf("n => %v, want %v", n, i) - } - } - - if rows.Err() != nil { - t.Error(rows.Err()) - } - - rows, err = batch.QueryResults() - if err != nil { - t.Error(err) - } - - for i := 0; rows.Next(); i++ { - var n int - if err := rows.Scan(&n); err != nil { - t.Error(err) - } - if n != i { - t.Errorf("n => %v, want %v", n, i) - } - } - - if rows.Err() != nil { - t.Error(rows.Err()) - } - - err = batch.Close() - if err != nil { - t.Fatal(err) - } -} - -func TestConnPoolBeginEx(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 2) - defer pool.Close() - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - tx, err := pool.BeginEx(ctx, nil) - if err == nil || tx != nil { - t.Fatal("Should not be able to create a tx") - } -} diff --git a/conn_test.go b/conn_test.go index 6f1d41eab..10959f062 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,2075 +1,1590 @@ package pgx_test import ( + "bytes" "context" - "crypto/tls" - "fmt" - "net" + "database/sql" "os" - "reflect" - "strconv" "strings" "sync" "testing" "time" - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCrateDBConnect(t *testing.T) { t.Parallel() - if cratedbConnConfig == nil { - t.Skip("Skipping due to undefined cratedbConnConfig") + connString := os.Getenv("PGX_TEST_CRATEDB_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_CRATEDB_CONN_STRING") } - conn, err := pgx.Connect(*cratedbConnConfig) - if err != nil { - t.Fatalf("Unable to establish connection: %v", err) - } + conn, err := pgx.Connect(context.Background(), connString) + require.Nil(t, err) + defer closeConn(t, conn) + + assert.Equal(t, connString, conn.Config().ConnString()) var result int - err = conn.QueryRow("select 1 +1").Scan(&result) + err = conn.QueryRow(context.Background(), "select 1 +1").Scan(&result) if err != nil { t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) } if result != 2 { t.Errorf("bad result: %d", result) } - - err = conn.Close() - if err != nil { - t.Fatal("Unable to close connection") - } } func TestConnect(t *testing.T) { t.Parallel() - conn, err := pgx.Connect(*defaultConnConfig) + connString := os.Getenv("PGX_TEST_DATABASE") + config := mustParseConfig(t, connString) + + conn, err := pgx.ConnectConfig(context.Background(), config) if err != nil { t.Fatalf("Unable to establish connection: %v", err) } - if _, present := conn.RuntimeParams["server_version"]; !present { - t.Error("Runtime parameters not stored") - } - - if conn.PID() == 0 { - t.Error("Backend PID not stored") - } + assertConfigsEqual(t, config, conn.Config(), "Conn.Config() returns original config") var currentDB string - err = conn.QueryRow("select current_database()").Scan(¤tDB) + err = conn.QueryRow(context.Background(), "select current_database()").Scan(¤tDB) if err != nil { t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) } - if currentDB != defaultConnConfig.Database { - t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + if currentDB != config.Config.Database { + t.Errorf("Did not connect to specified database (%v)", config.Config.Database) } var user string - err = conn.QueryRow("select current_user").Scan(&user) + err = conn.QueryRow(context.Background(), "select current_user").Scan(&user) if err != nil { t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) } - if user != defaultConnConfig.User { - t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + if user != config.Config.User { + t.Errorf("Did not connect as specified user (%v)", config.Config.User) } - err = conn.Close() + err = conn.Close(context.Background()) if err != nil { t.Fatal("Unable to close connection") } } -func TestConnectWithUnixSocketDirectory(t *testing.T) { +func TestConnectWithPreferSimpleProtocol(t *testing.T) { t.Parallel() - // /.s.PGSQL.5432 - if unixSocketConnConfig == nil { - t.Skip("Skipping due to undefined unixSocketConnConfig") - } + connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + connConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol - conn, err := pgx.Connect(*unixSocketConnConfig) - if err != nil { - t.Fatalf("Unable to establish connection: %v", err) - } + conn := mustConnect(t, connConfig) + defer closeConn(t, conn) - err = conn.Close() - if err != nil { - t.Fatal("Unable to close connection") - } + // If simple protocol is used we should be able to correctly scan the result + // into a pgtype.Text as the integer will have been encoded in text. + + var s pgtype.Text + err := conn.QueryRow(context.Background(), "select $1::int4", 42).Scan(&s) + require.NoError(t, err) + require.Equal(t, pgtype.Text{String: "42", Valid: true}, s) + + ensureConnValid(t, conn) +} + +func TestConnectConfigRequiresConnConfigFromParseConfig(t *testing.T) { + config := &pgx.ConnConfig{} + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { + pgx.ConnectConfig(context.Background(), config) + }) +} + +func TestConfigContainsConnStr(t *testing.T) { + connStr := os.Getenv("PGX_TEST_DATABASE") + config, err := pgx.ParseConfig(connStr) + require.NoError(t, err) + assert.Equal(t, connStr, config.ConnString()) +} + +func TestConfigCopyReturnsEqualConfig(t *testing.T) { + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" + original, err := pgx.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assertConfigsEqual(t, original, copied, t.Name()) } -func TestConnectWithUnixSocketFile(t *testing.T) { +func TestConfigCopyCanBeUsedToConnect(t *testing.T) { + connString := os.Getenv("PGX_TEST_DATABASE") + original, err := pgx.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assert.NotPanics(t, func() { + _, err = pgx.ConnectConfig(context.Background(), copied) + }) + assert.NoError(t, err) +} + +func TestParseConfigExtractsStatementCacheOptions(t *testing.T) { t.Parallel() - if unixSocketConnConfig == nil { - t.Skip("Skipping due to undefined unixSocketConnConfig") - } + config, err := pgx.ParseConfig("statement_cache_capacity=0") + require.NoError(t, err) + require.EqualValues(t, 0, config.StatementCacheCapacity) - connParams := *unixSocketConnConfig - connParams.Host = connParams.Host + "/.s.PGSQL.5432" - conn, err := pgx.Connect(connParams) - if err != nil { - t.Fatalf("Unable to establish connection: %v", err) - } + config, err = pgx.ParseConfig("statement_cache_capacity=42") + require.NoError(t, err) + require.EqualValues(t, 42, config.StatementCacheCapacity) - err = conn.Close() - if err != nil { - t.Fatal("Unable to close connection") - } + config, err = pgx.ParseConfig("description_cache_capacity=0") + require.NoError(t, err) + require.EqualValues(t, 0, config.DescriptionCacheCapacity) + + config, err = pgx.ParseConfig("description_cache_capacity=42") + require.NoError(t, err) + require.EqualValues(t, 42, config.DescriptionCacheCapacity) + + // default_query_exec_mode + // Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See + + config, err = pgx.ParseConfig("default_query_exec_mode=cache_statement") + require.NoError(t, err) + require.Equal(t, pgx.QueryExecModeCacheStatement, config.DefaultQueryExecMode) + + config, err = pgx.ParseConfig("default_query_exec_mode=cache_describe") + require.NoError(t, err) + require.Equal(t, pgx.QueryExecModeCacheDescribe, config.DefaultQueryExecMode) + + config, err = pgx.ParseConfig("default_query_exec_mode=describe_exec") + require.NoError(t, err) + require.Equal(t, pgx.QueryExecModeDescribeExec, config.DefaultQueryExecMode) + + config, err = pgx.ParseConfig("default_query_exec_mode=exec") + require.NoError(t, err) + require.Equal(t, pgx.QueryExecModeExec, config.DefaultQueryExecMode) + + config, err = pgx.ParseConfig("default_query_exec_mode=simple_protocol") + require.NoError(t, err) + require.Equal(t, pgx.QueryExecModeSimpleProtocol, config.DefaultQueryExecMode) } -func TestConnectWithTcp(t *testing.T) { +func TestParseConfigExtractsDefaultQueryExecMode(t *testing.T) { t.Parallel() - if tcpConnConfig == nil { - t.Skip("Skipping due to undefined tcpConnConfig") + for _, tt := range []struct { + connString string + defaultQueryExecMode pgx.QueryExecMode + }{ + {"", pgx.QueryExecModeCacheStatement}, + {"default_query_exec_mode=cache_statement", pgx.QueryExecModeCacheStatement}, + {"default_query_exec_mode=cache_describe", pgx.QueryExecModeCacheDescribe}, + {"default_query_exec_mode=describe_exec", pgx.QueryExecModeDescribeExec}, + {"default_query_exec_mode=exec", pgx.QueryExecModeExec}, + {"default_query_exec_mode=simple_protocol", pgx.QueryExecModeSimpleProtocol}, + } { + config, err := pgx.ParseConfig(tt.connString) + require.NoError(t, err) + require.Equalf(t, tt.defaultQueryExecMode, config.DefaultQueryExecMode, "connString: `%s`", tt.connString) + require.Empty(t, config.RuntimeParams["default_query_exec_mode"]) } +} - conn, err := pgx.Connect(*tcpConnConfig) - if err != nil { - t.Fatal("Unable to establish connection: " + err.Error()) - } +func TestParseConfigErrors(t *testing.T) { + t.Parallel() - err = conn.Close() - if err != nil { - t.Fatal("Unable to close connection") + for _, tt := range []struct { + connString string + expectedErrSubstring string + }{ + {"default_query_exec_mode=does_not_exist", "does_not_exist"}, + } { + config, err := pgx.ParseConfig(tt.connString) + require.Nil(t, config) + require.ErrorContains(t, err, tt.expectedErrSubstring) } } -func TestConnectWithTLS(t *testing.T) { +func TestExec(t *testing.T) { t.Parallel() - if tlsConnConfig == nil { - t.Skip("Skipping due to undefined tlsConnConfig") - } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - conn, err := pgx.Connect(*tlsConnConfig) - if err != nil { - t.Fatal("Unable to establish connection: " + err.Error()) - } + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results.String() != "CREATE TABLE" { + t.Error("Unexpected results from Exec") + } - err = conn.Close() - if err != nil { - t.Fatal("Unable to close connection") - } + // Accept parameters + if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); results.String() != "INSERT 0 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } + + if results := mustExec(t, conn, "drop table foo;"); results.String() != "DROP TABLE" { + t.Error("Unexpected results from Exec") + } + + // Multiple statements can be executed -- last command tag is returned + if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); results.String() != "DROP TABLE" { + t.Error("Unexpected results from Exec") + } + + // Can execute longer SQL strings than sharedBufferSize + if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); results.String() != "SELECT 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } + + // Exec no-op which does not return a command tag + if results := mustExec(t, conn, "--;"); results.String() != "" { + t.Errorf("Unexpected results from Exec: %v", results) + } + }) +} + +type testQueryRewriter struct { + sql string + args []any } -func TestConnectWithInvalidUser(t *testing.T) { +func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return qr.sql, qr.args, nil +} + +func TestExecWithQueryRewriter(t *testing.T) { t.Parallel() - if invalidUserConnConfig == nil { - t.Skip("Skipping due to undefined invalidUserConnConfig") - } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - _, err := pgx.Connect(*invalidUserConnConfig) - pgErr, ok := err.(pgx.PgError) - if !ok { - t.Fatalf("Expected to receive a PgError with code 28000, instead received: %v", err) - } - if pgErr.Code != "28000" && pgErr.Code != "28P01" { - t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) - } + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + qr := testQueryRewriter{sql: "select $1::int", args: []any{42}} + _, err := conn.Exec(ctx, "should be replaced", &qr) + require.NoError(t, err) + }) } -func TestConnectWithPlainTextPassword(t *testing.T) { +func TestExecFailure(t *testing.T) { t.Parallel() - if plainPasswordConnConfig == nil { - t.Skip("Skipping due to undefined plainPasswordConnConfig") - } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - conn, err := pgx.Connect(*plainPasswordConnConfig) - if err != nil { - t.Fatal("Unable to establish connection: " + err.Error()) - } + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + if _, err := conn.Exec(context.Background(), "selct;"); err == nil { + t.Fatal("Expected SQL syntax error") + } - err = conn.Close() - if err != nil { - t.Fatal("Unable to close connection") - } + rows, _ := conn.Query(context.Background(), "select 1") + rows.Close() + if rows.Err() != nil { + t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err()) + } + }) } -func TestConnectWithMD5Password(t *testing.T) { +func TestExecFailureWithArguments(t *testing.T) { t.Parallel() - if md5ConnConfig == nil { - t.Skip("Skipping due to undefined md5ConnConfig") - } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - conn, err := pgx.Connect(*md5ConnConfig) - if err != nil { - t.Fatal("Unable to establish connection: " + err.Error()) - } + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(context.Background(), "selct $1;", 1) + if err == nil { + t.Fatal("Expected SQL syntax error") + } + assert.False(t, pgconn.SafeToRetry(err)) - err = conn.Close() - if err != nil { - t.Fatal("Unable to close connection") - } + _, err = conn.Exec(context.Background(), "select $1::varchar(1);", "1", "2") + require.Error(t, err) + }) } -func TestConnectWithTLSFallback(t *testing.T) { +func TestExecContextWithoutCancelation(t *testing.T) { t.Parallel() - if tlsConnConfig == nil { - t.Skip("Skipping due to undefined tlsConnConfig") - } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - connConfig := *tlsConnConfig - connConfig.TLSConfig = &tls.Config{ServerName: "bogus.local"} // bogus ServerName should ensure certificate validation failure + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + ctx, cancelFunc := context.WithCancel(ctx) + defer cancelFunc() - conn, err := pgx.Connect(connConfig) - if err == nil { - t.Fatal("Expected failed connection, but succeeded") - } + commandTag, err := conn.Exec(ctx, "create temporary table foo(id integer primary key);") + if err != nil { + t.Fatal(err) + } + if commandTag.String() != "CREATE TABLE" { + t.Fatalf("Unexpected results from Exec: %v", commandTag) + } + assert.False(t, pgconn.SafeToRetry(err)) + }) +} - connConfig.UseFallbackTLS = true - connConfig.FallbackTLSConfig = tlsConnConfig.TLSConfig - connConfig.FallbackTLSConfig.InsecureSkipVerify = true +func TestExecContextFailureWithoutCancelation(t *testing.T) { + t.Parallel() - conn, err = pgx.Connect(connConfig) - if err != nil { - t.Fatal("Unable to establish connection: " + err.Error()) - } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - err = conn.Close() - if err != nil { - t.Fatal("Unable to close connection") - } + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + ctx, cancelFunc := context.WithCancel(ctx) + defer cancelFunc() + + _, err := conn.Exec(ctx, "selct;") + if err == nil { + t.Fatal("Expected SQL syntax error") + } + assert.False(t, pgconn.SafeToRetry(err)) + + rows, _ := conn.Query(context.Background(), "select 1") + rows.Close() + if rows.Err() != nil { + t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err()) + } + assert.False(t, pgconn.SafeToRetry(err)) + }) } -func TestConnectWithConnectionRefused(t *testing.T) { +func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) { t.Parallel() - // Presumably nothing is listening on 127.0.0.1:1 - bad := *defaultConnConfig - bad.Host = "127.0.0.1" - bad.Port = 1 + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - _, err := pgx.Connect(bad) - if err == nil { - t.Fatal("Expected error establishing connection to bad port") - } + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + ctx, cancelFunc := context.WithCancel(ctx) + defer cancelFunc() + + _, err := conn.Exec(ctx, "selct $1;", 1) + if err == nil { + t.Fatal("Expected SQL syntax error") + } + assert.False(t, pgconn.SafeToRetry(err)) + }) } -func TestConnectWithPreferSimpleProtocol(t *testing.T) { +func TestExecFailureCloseBefore(t *testing.T) { t.Parallel() - connConfig := *defaultConnConfig - connConfig.PreferSimpleProtocol = true + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + closeConn(t, conn) - conn := mustConnect(t, connConfig) + _, err := conn.Exec(context.Background(), "select 1") + require.Error(t, err) + assert.True(t, pgconn.SafeToRetry(err)) +} + +func TestExecPerQuerySimpleProtocol(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - // If simple protocol is used we should be able to correctly scan the result - // into a pgtype.Text as the integer will have been encoded in text. + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() - var s pgtype.Text - err := conn.QueryRow("select $1::int4", 42).Scan(&s) + commandTag, err := conn.Exec(ctx, "create temporary table foo(name varchar primary key);") if err != nil { t.Fatal(err) } - - if s.Get() != "42" { - t.Fatalf(`expected "42", got %v`, s) + if commandTag.String() != "CREATE TABLE" { + t.Fatalf("Unexpected results from Exec: %v", commandTag) } - ensureConnValid(t, conn) + commandTag, err = conn.Exec(ctx, + "insert into foo(name) values($1);", + pgx.QueryExecModeSimpleProtocol, + "bar'; drop table foo;--", + ) + if err != nil { + t.Fatal(err) + } + if commandTag.String() != "INSERT 0 1" { + t.Fatalf("Unexpected results from Exec: %v", commandTag) + } } -func TestConnectCustomDialer(t *testing.T) { +func TestPrepare(t *testing.T) { t.Parallel() - if customDialerConnConfig == nil { - t.Skip("Skipping due to undefined customDialerConnConfig") - } + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) - dialled := false - conf := *customDialerConnConfig - conf.Dial = func(network, address string) (net.Conn, error) { - dialled = true - return net.Dial(network, address) + _, err := conn.Prepare(context.Background(), "test", "select $1::varchar") + if err != nil { + t.Errorf("Unable to prepare statement: %v", err) + return } - conn, err := pgx.Connect(conf) + var s string + err = conn.QueryRow(context.Background(), "test", "hello").Scan(&s) if err != nil { - t.Fatalf("Unable to establish connection: %s", err) + t.Errorf("Executing prepared statement failed: %v", err) } - if !dialled { - t.Fatal("Connect did not use custom dialer") + + if s != "hello" { + t.Errorf("Prepared statement did not return expected value: %v", s) } - err = conn.Close() + err = conn.Deallocate(context.Background(), "test") if err != nil { - t.Fatal("Unable to close connection") + t.Errorf("conn.Deallocate failed: %v", err) } -} -func TestConnectWithRuntimeParams(t *testing.T) { - t.Parallel() - - connConfig := *defaultConnConfig - connConfig.RuntimeParams = map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - } + // Create another prepared statement to ensure Deallocate left the connection + // in a working state and that we can reuse the prepared statement name. - conn, err := pgx.Connect(connConfig) + _, err = conn.Prepare(context.Background(), "test", "select $1::integer") if err != nil { - t.Fatalf("Unable to establish connection: %v", err) + t.Errorf("Unable to prepare statement: %v", err) + return } - defer conn.Close() - var s string - err = conn.QueryRow("show application_name").Scan(&s) + var n int32 + err = conn.QueryRow(context.Background(), "test", int32(1)).Scan(&n) if err != nil { - t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + t.Errorf("Executing prepared statement failed: %v", err) } - if s != "pgxtest" { - t.Errorf("Expected application_name to be %s, but it was %s", "pgxtest", s) + + if n != 1 { + t.Errorf("Prepared statement did not return expected value: %v", s) } - err = conn.QueryRow("show search_path").Scan(&s) + err = conn.DeallocateAll(context.Background()) if err != nil { - t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) - } - if s != "myschema" { - t.Errorf("Expected search_path to be %s, but it was %s", "myschema", s) + t.Errorf("conn.Deallocate failed: %v", err) } } -func TestParseURI(t *testing.T) { +func TestPrepareBadSQLFailure(t *testing.T) { t.Parallel() - tests := []struct { - url string - connParams pgx.ConnConfig - }{ - { - url: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - UseFallbackTLS: false, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack:secret@localhost:5432/mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgresql://jack:secret@localhost:5432/mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack@localhost:5432/mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack@localhost/mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack@localhost/mydb?application_name=pgxtest&search_path=myschema", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - }, - }, - }, + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + if _, err := conn.Prepare(context.Background(), "badSQL", "select foo"); err == nil { + t.Fatal("Prepare should have failed with syntax error") } - for i, tt := range tests { - connParams, err := pgx.ParseURI(tt.url) - if err != nil { - t.Errorf("%d. Unexpected error from pgx.ParseURL(%q) => %v", i, tt.url, err) - continue + ensureConnValid(t, conn) +} + +func TestPrepareIdempotency(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + for i := 0; i < 2; i++ { + _, err := conn.Prepare(context.Background(), "test", "select 42::integer") + if err != nil { + t.Fatalf("%d. Unable to prepare statement: %v", i, err) + } + + var n int32 + err = conn.QueryRow(context.Background(), "test").Scan(&n) + if err != nil { + t.Errorf("%d. Executing prepared statement failed: %v", i, err) + } + + if n != int32(42) { + t.Errorf("%d. Prepared statement did not return expected value: %v", i, n) + } } - if !reflect.DeepEqual(connParams, tt.connParams) { - t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams) + _, err := conn.Prepare(context.Background(), "test", "select 'fail'::varchar") + if err == nil { + t.Fatalf("Prepare statement with same name but different SQL should have failed but it didn't") + return } - } + }) } -func TestParseDSN(t *testing.T) { +func TestPrepareStatementCacheModes(t *testing.T) { t.Parallel() - tests := []struct { - url string - connParams pgx.ConnConfig - }{ - { - url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=prefer", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack password=secret host=localhost port=5432 dbname=mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack host=localhost port=5432 dbname=mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack host=localhost dbname=mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack host=localhost dbname=mydb application_name=pgxtest search_path=myschema", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - }, - }, - }, - { - url: "user=jack host=localhost dbname=mydb connect_timeout=10", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - for i, tt := range tests { - actual, err := pgx.ParseDSN(tt.url) - if err != nil { - t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err) - continue - } + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Prepare(context.Background(), "test", "select $1::text") + require.NoError(t, err) - testConnConfigEquals(t, tt.connParams, actual, strconv.Itoa(i)) - } + var s string + err = conn.QueryRow(context.Background(), "test", "hello").Scan(&s) + require.NoError(t, err) + require.Equal(t, "hello", s) + }) } -func TestParseConnectionString(t *testing.T) { +func TestPrepareWithDigestedName(t *testing.T) { t.Parallel() - tests := []struct { - url string - connParams pgx.ConnConfig - }{ - { - url: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - UseFallbackTLS: false, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack:secret@localhost:5432/mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgresql://jack:secret@localhost:5432/mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack@localhost:5432/mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack@localhost/mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack@localhost/mydb?application_name=pgxtest&search_path=myschema", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - }, - }, - }, - { - url: "postgres://jack@localhost/mydb?connect_timeout=10", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=prefer", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack password=secret host=localhost port=5432 dbname=mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack host=localhost port=5432 dbname=mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack host=localhost dbname=mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack host=localhost dbname=mydb application_name=pgxtest search_path=myschema", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - }, - }, - }, - } - - for i, tt := range tests { - actual, err := pgx.ParseConnectionString(tt.url) - if err != nil { - t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err) - continue - } - - testConnConfigEquals(t, tt.connParams, actual, strconv.Itoa(i)) - } -} - -func testConnConfigEquals(t *testing.T, expected pgx.ConnConfig, actual pgx.ConnConfig, testName string) { - if actual.Host != expected.Host { - t.Errorf("%s: expected Host to be %v got %v", testName, expected.Host, actual.Host) - } - if actual.Database != expected.Database { - t.Errorf("%s: expected Database to be %v got %v", testName, expected.Database, actual.Database) - } - if actual.Port != expected.Port { - t.Errorf("%s: expected Port to be %v got %v", testName, expected.Port, actual.Port) - } - if actual.Port != expected.Port { - t.Errorf("%s: expected Port to be %v got %v", testName, expected.Port, actual.Port) - } - if actual.User != expected.User { - t.Errorf("%s: expected User to be %v got %v", testName, expected.User, actual.User) - } - if actual.Password != expected.Password { - t.Errorf("%s: expected Password to be %v got %v", testName, expected.Password, actual.Password) - } - // Cannot test value of underlying Dialer stuct but can at least test if Dial func is set. - if (actual.Dial != nil) != (expected.Dial != nil) { - t.Errorf("%s: expected Dial mismatch", testName) - } - - if !reflect.DeepEqual(actual.RuntimeParams, expected.RuntimeParams) { - t.Errorf("%s: expected RuntimeParams to be %#v got %#v", testName, expected.RuntimeParams, actual.RuntimeParams) - } - - tlsTests := []struct { - name string - expected *tls.Config - actual *tls.Config - }{ - { - name: "TLSConfig", - expected: expected.TLSConfig, - actual: actual.TLSConfig, - }, - { - name: "FallbackTLSConfig", - expected: expected.FallbackTLSConfig, - actual: actual.FallbackTLSConfig, - }, - } - for _, tlsTest := range tlsTests { - name := tlsTest.name - expected := tlsTest.expected - actual := tlsTest.actual - - if expected == nil && actual != nil { - t.Errorf("%s / %s: expected nil, but it was set", testName, name) - } else if expected != nil && actual == nil { - t.Errorf("%s / %s: expected to be set, but got nil", testName, name) - } else if expected != nil && actual != nil { - if actual.InsecureSkipVerify != expected.InsecureSkipVerify { - t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", testName, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify) - } - - if actual.ServerName != expected.ServerName { - t.Errorf("%s / %s: expected ServerName to be %v got %v", testName, name, expected.ServerName, actual.ServerName) - } - } - } - - if actual.UseFallbackTLS != expected.UseFallbackTLS { - t.Errorf("%s: expected UseFallbackTLS to be %v got %v", testName, expected.UseFallbackTLS, actual.UseFallbackTLS) - } -} - -func TestParseEnvLibpq(t *testing.T) { - pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"} - - savedEnv := make(map[string]string) - for _, n := range pgEnvvars { - savedEnv[n] = os.Getenv(n) - } - defer func() { - for k, v := range savedEnv { - err := os.Setenv(k, v) - if err != nil { - t.Fatalf("Unable to restore environment: %v", err) - } - } - }() - - tests := []struct { - name string - envvars map[string]string - config pgx.ConnConfig - }{ - { - name: "No environment", - envvars: map[string]string{}, - config: pgx.ConnConfig{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "Normal PG vars", - envvars: map[string]string{ - "PGHOST": "123.123.123.123", - "PGPORT": "7777", - "PGDATABASE": "foo", - "PGUSER": "bar", - "PGPASSWORD": "baz", - "PGCONNECT_TIMEOUT": "10", - }, - config: pgx.ConnConfig{ - Host: "123.123.123.123", - Port: 7777, - Database: "foo", - User: "bar", - Password: "baz", - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "application_name", - envvars: map[string]string{ - "PGAPPNAME": "pgxtest", - }, - config: pgx.ConnConfig{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{"application_name": "pgxtest"}, - }, - }, - { - name: "sslmode=disable", - envvars: map[string]string{ - "PGSSLMODE": "disable", - }, - config: pgx.ConnConfig{ - TLSConfig: nil, - UseFallbackTLS: false, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "sslmode=allow", - envvars: map[string]string{ - "PGSSLMODE": "allow", - }, - config: pgx.ConnConfig{ - TLSConfig: nil, - UseFallbackTLS: true, - FallbackTLSConfig: &tls.Config{InsecureSkipVerify: true}, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "sslmode=prefer", - envvars: map[string]string{ - "PGSSLMODE": "prefer", - }, - config: pgx.ConnConfig{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "sslmode=require", - envvars: map[string]string{ - "PGSSLMODE": "require", - }, - config: pgx.ConnConfig{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - UseFallbackTLS: false, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "sslmode=verify-ca", - envvars: map[string]string{ - "PGSSLMODE": "verify-ca", - }, - config: pgx.ConnConfig{ - TLSConfig: &tls.Config{}, - UseFallbackTLS: false, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "sslmode=verify-full", - envvars: map[string]string{ - "PGSSLMODE": "verify-full", - }, - config: pgx.ConnConfig{ - TLSConfig: &tls.Config{}, - UseFallbackTLS: false, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "sslmode=verify-full with host", - envvars: map[string]string{ - "PGHOST": "pgx.example", - "PGSSLMODE": "verify-full", - }, - config: pgx.ConnConfig{ - Host: "pgx.example", - TLSConfig: &tls.Config{ - ServerName: "pgx.example", - }, - UseFallbackTLS: false, - RuntimeParams: map[string]string{}, - }, - }, - } - - for _, tt := range tests { - for _, n := range pgEnvvars { - err := os.Unsetenv(n) - if err != nil { - t.Fatalf("%s: Unable to clear environment: %v", tt.name, err) - } - } - - for k, v := range tt.envvars { - err := os.Setenv(k, v) - if err != nil { - t.Fatalf("%s: Unable to set environment: %v", tt.name, err) - } - } - - actual, err := pgx.ParseEnvLibpq() - if err != nil { - t.Errorf("%s: Unexpected error from pgx.ParseLibpq() => %v", tt.name, err) - continue - } - - testConnConfigEquals(t, tt.config, actual, tt.name) - } -} - -func TestExec(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results != "CREATE TABLE" { - t.Error("Unexpected results from Exec") - } - - // Accept parameters - if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); results != "INSERT 0 1" { - t.Errorf("Unexpected results from Exec: %v", results) - } - - if results := mustExec(t, conn, "drop table foo;"); results != "DROP TABLE" { - t.Error("Unexpected results from Exec") - } - - // Multiple statements can be executed -- last command tag is returned - if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); results != "DROP TABLE" { - t.Error("Unexpected results from Exec") - } - - // Can execute longer SQL strings than sharedBufferSize - if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); results != "SELECT 1" { - t.Errorf("Unexpected results from Exec: %v", results) - } - - // Exec no-op which does not return a command tag - if results := mustExec(t, conn, "--;"); results != "" { - t.Errorf("Unexpected results from Exec: %v", results) - } -} - -func TestExecFailure(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - if _, err := conn.Exec("selct;"); err == nil { - t.Fatal("Expected SQL syntax error") - } - - rows, _ := conn.Query("select 1") - rows.Close() - if rows.Err() != nil { - t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err()) - } -} - -func TestExecExContextWithoutCancelation(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - commandTag, err := conn.ExecEx(ctx, "create temporary table foo(id integer primary key);", nil) - if err != nil { - t.Fatal(err) - } - if commandTag != "CREATE TABLE" { - t.Fatalf("Unexpected results from ExecEx: %v", commandTag) - } -} - -func TestExecExContextFailureWithoutCancelation(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - if _, err := conn.ExecEx(ctx, "selct;", nil); err == nil { - t.Fatal("Expected SQL syntax error") - } - - rows, _ := conn.Query("select 1") - rows.Close() - if rows.Err() != nil { - t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err()) - } -} - -func TestExecExContextCancelationCancelsQuery(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - ctx, cancelFunc := context.WithCancel(context.Background()) - go func() { - time.Sleep(500 * time.Millisecond) - cancelFunc() - }() - - _, err := conn.ExecEx(ctx, "select pg_sleep(60)", nil) - if err != context.Canceled { - t.Fatalf("Expected context.Canceled err, got %v", err) - } - - ensureConnValid(t, conn) -} - -func TestExecExExtendedProtocol(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil) - if err != nil { - t.Fatal(err) - } - if commandTag != "CREATE TABLE" { - t.Fatalf("Unexpected results from ExecEx: %v", commandTag) - } - - commandTag, err = conn.ExecEx( - ctx, - "insert into foo(name) values($1);", - nil, - "bar", - ) - if err != nil { - t.Fatal(err) - } - if commandTag != "INSERT 0 1" { - t.Fatalf("Unexpected results from ExecEx: %v", commandTag) - } - - ensureConnValid(t, conn) -} - -func TestExecExSimpleProtocol(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil) - if err != nil { - t.Fatal(err) - } - if commandTag != "CREATE TABLE" { - t.Fatalf("Unexpected results from ExecEx: %v", commandTag) - } - - commandTag, err = conn.ExecEx( - ctx, - "insert into foo(name) values($1);", - &pgx.QueryExOptions{SimpleProtocol: true}, - "bar'; drop table foo;--", - ) - if err != nil { - t.Fatal(err) - } - if commandTag != "INSERT 0 1" { - t.Fatalf("Unexpected results from ExecEx: %v", commandTag) - } -} - -func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustExec(t, conn, "create temporary table foo(name varchar primary key);") - - commandTag, err := conn.ExecEx( - context.Background(), - "insert into foo(name) values($1);", - &pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.VarcharOID}}, - "bar'; drop table foo;--", - ) - if err != nil { - t.Fatal(err) - } - if commandTag != "INSERT 0 1" { - t.Fatalf("Unexpected results from ExecEx: %v", commandTag) - } -} - -func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustExec(t, conn, "create temporary table foo(name varchar primary key);") - - _, err := conn.ExecEx( - context.Background(), - "insert into foo(name) values($1);", - &pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.Int4OID}}, - "bar'; drop table foo;--", - ) - if err == nil { - t.Fatal("expected error but got none") - } -} - -func TestConnExecExIncorrectParameterOIDsAfterAnotherQuery(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustExec(t, conn, "create temporary table foo(name varchar primary key);") - - var s string - err := conn.QueryRow("insert into foo(name) values('baz') returning name;").Scan(&s) - if err != nil { - t.Errorf("Executing query failed: %v", err) - } - if s != "baz" { - t.Errorf("Query did not return expected value: %v", s) - } - - _, err = conn.ExecEx( - context.Background(), - "insert into foo(name) values($1);", - &pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.Int4OID}}, - "bar'; drop table foo;--", - ) - if err == nil { - t.Fatal("expected error but got none") - } -} - -func TestPrepare(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - _, err := conn.Prepare("test", "select $1::varchar") - if err != nil { - t.Errorf("Unable to prepare statement: %v", err) - return - } - - var s string - err = conn.QueryRow("test", "hello").Scan(&s) - if err != nil { - t.Errorf("Executing prepared statement failed: %v", err) - } - - if s != "hello" { - t.Errorf("Prepared statement did not return expected value: %v", s) - } - - err = conn.Deallocate("test") - if err != nil { - t.Errorf("conn.Deallocate failed: %v", err) - } - - // Create another prepared statement to ensure Deallocate left the connection - // in a working state and that we can reuse the prepared statement name. - - _, err = conn.Prepare("test", "select $1::integer") - if err != nil { - t.Errorf("Unable to prepare statement: %v", err) - return - } - - var n int32 - err = conn.QueryRow("test", int32(1)).Scan(&n) - if err != nil { - t.Errorf("Executing prepared statement failed: %v", err) - } - - if n != 1 { - t.Errorf("Prepared statement did not return expected value: %v", s) - } - - err = conn.Deallocate("test") - if err != nil { - t.Errorf("conn.Deallocate failed: %v", err) - } -} - -func TestPrepareBadSQLFailure(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - if _, err := conn.Prepare("badSQL", "select foo"); err == nil { - t.Fatal("Prepare should have failed with syntax error") - } - - ensureConnValid(t, conn) -} - -func TestPrepareQueryManyParameters(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - tests := []struct { - count int - succeed bool - }{ - { - count: 65534, - succeed: true, - }, - { - count: 65535, - succeed: true, - }, - { - count: 65536, - succeed: false, - }, - { - count: 65537, - succeed: false, - }, - } - - for i, tt := range tests { - params := make([]string, 0, tt.count) - args := make([]interface{}, 0, tt.count) - for j := 0; j < tt.count; j++ { - params = append(params, fmt.Sprintf("($%d::text)", j+1)) - args = append(args, strconv.Itoa(j)) - } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - sql := "values" + strings.Join(params, ", ") + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + sql := "select $1::text" + sd, err := conn.Prepare(ctx, sql, sql) + require.NoError(t, err) + require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name) - psName := fmt.Sprintf("manyParams%d", i) - _, err := conn.Prepare(psName, sql) - if err != nil { - if tt.succeed { - t.Errorf("%d. %v", i, err) - } - continue - } - if !tt.succeed { - t.Errorf("%d. Expected error but succeeded", i) - continue - } + var s string + err = conn.QueryRow(ctx, sql, "hello").Scan(&s) + require.NoError(t, err) + require.Equal(t, "hello", s) - rows, err := conn.Query(psName, args...) - if err != nil { - t.Errorf("conn.Query failed: %v", err) - continue - } + err = conn.Deallocate(ctx, sql) + require.NoError(t, err) + }) +} - for rows.Next() { - var s string - rows.Scan(&s) - } +// https://github.com/jackc/pgx/pull/1795 +func TestDeallocateInAbortedTransaction(t *testing.T) { + t.Parallel() - if rows.Err() != nil { - t.Errorf("Reading query result failed: %v", err) - } - } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - ensureConnValid(t, conn) -} + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + tx, err := conn.Begin(ctx) + require.NoError(t, err) -func TestPrepareIdempotency(t *testing.T) { - t.Parallel() + sql := "select $1::text" + sd, err := tx.Prepare(ctx, sql, sql) + require.NoError(t, err) + require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name) - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + var s string + err = tx.QueryRow(ctx, sql, "hello").Scan(&s) + require.NoError(t, err) + require.Equal(t, "hello", s) - for i := 0; i < 2; i++ { - _, err := conn.Prepare("test", "select 42::integer") - if err != nil { - t.Fatalf("%d. Unable to prepare statement: %v", i, err) - } + _, err = tx.Exec(ctx, "select 1/0") // abort transaction with divide by zero error + require.Error(t, err) - var n int32 - err = conn.QueryRow("test").Scan(&n) - if err != nil { - t.Errorf("%d. Executing prepared statement failed: %v", i, err) - } + err = conn.Deallocate(ctx, sql) + require.NoError(t, err) - if n != int32(42) { - t.Errorf("%d. Prepared statement did not return expected value: %v", i, n) - } - } + err = tx.Rollback(ctx) + require.NoError(t, err) - _, err := conn.Prepare("test", "select 'fail'::varchar") - if err == nil { - t.Fatalf("Prepare statement with same name but different SQL should have failed but it didn't") - return - } + sd, err = conn.Prepare(ctx, sql, sql) + require.NoError(t, err) + require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name) + }) } -func TestPrepareEx(t *testing.T) { +func TestDeallocateMissingPreparedStatementStillClearsFromPreparedStatementMap(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - _, err := conn.PrepareEx(context.Background(), "test", "select $1", &pgx.PrepareExOptions{ParameterOIDs: []pgtype.OID{pgtype.TextOID}}) - if err != nil { - t.Errorf("Unable to prepare statement: %v", err) - return - } + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Prepare(ctx, "ps", "select $1::text") + require.NoError(t, err) - var s string - err = conn.QueryRow("test", "hello").Scan(&s) - if err != nil { - t.Errorf("Executing prepared statement failed: %v", err) - } + _, err = conn.Exec(ctx, "deallocate ps") + require.NoError(t, err) - if s != "hello" { - t.Errorf("Prepared statement did not return expected value: %v", s) - } + err = conn.Deallocate(ctx, "ps") + require.NoError(t, err) - err = conn.Deallocate("test") - if err != nil { - t.Errorf("conn.Deallocate failed: %v", err) - } + _, err = conn.Prepare(ctx, "ps", "select $1::text, $2::text") + require.NoError(t, err) + + var s1, s2 string + err = conn.QueryRow(ctx, "ps", "hello", "world").Scan(&s1, &s2) + require.NoError(t, err) + require.Equal(t, "hello", s1) + require.Equal(t, "world", s2) + }) } func TestListenNotify(t *testing.T) { t.Parallel() - listener := mustConnect(t, *defaultConnConfig) + listener := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, listener) - if err := listener.Listen("chat"); err != nil { - t.Fatalf("Unable to start listening: %v", err) + if listener.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") } - notifier := mustConnect(t, *defaultConnConfig) + mustExec(t, listener, "listen chat") + + notifier := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, notifier) mustExec(t, notifier, "notify chat") // when notification is waiting on the socket to be read notification, err := listener.WaitForNotification(context.Background()) - if err != nil { - t.Fatalf("Unexpected error on WaitForNotification: %v", err) - } - if notification.Channel != "chat" { - t.Errorf("Did not receive notification on expected channel: %v", notification.Channel) - } + require.NoError(t, err) + assert.Equal(t, "chat", notification.Channel) // when notification has already been read during previous query mustExec(t, notifier, "notify chat") - rows, _ := listener.Query("select 1") + rows, _ := listener.Query(context.Background(), "select 1") rows.Close() - if rows.Err() != nil { - t.Fatalf("Unexpected error on Query: %v", rows.Err()) - } + require.NoError(t, rows.Err()) ctx, cancelFn := context.WithCancel(context.Background()) cancelFn() notification, err = listener.WaitForNotification(ctx) - if err != nil { - t.Fatalf("Unexpected error on WaitForNotification: %v", err) - } - if notification.Channel != "chat" { - t.Errorf("Did not receive notification on expected channel: %v", notification.Channel) - } + require.NoError(t, err) + assert.Equal(t, "chat", notification.Channel) // when timeout occurs - ctx, _ = context.WithTimeout(context.Background(), time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() notification, err = listener.WaitForNotification(ctx) - if err != context.DeadlineExceeded { - t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) - } - if notification != nil { - t.Errorf("WaitForNotification returned an unexpected notification: %v", notification) - } + assert.True(t, pgconn.Timeout(err)) + assert.Nil(t, notification) // listener can listen again after a timeout mustExec(t, notifier, "notify chat") notification, err = listener.WaitForNotification(context.Background()) - if err != nil { - t.Fatalf("Unexpected error on WaitForNotification: %v", err) - } - if notification.Channel != "chat" { - t.Errorf("Did not receive notification on expected channel: %v", notification.Channel) - } + require.NoError(t, err) + assert.Equal(t, "chat", notification.Channel) } -func TestUnlistenSpecificChannel(t *testing.T) { +func TestListenNotifyWhileBusyIsSafe(t *testing.T) { t.Parallel() - listener := mustConnect(t, *defaultConnConfig) - defer closeConn(t, listener) + func() { + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + pgxtest.SkipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") + }() - if err := listener.Listen("unlisten_test"); err != nil { - t.Fatalf("Unable to start listening: %v", err) - } + listenerDone := make(chan bool) + notifierDone := make(chan bool) + listening := make(chan bool) + go func() { + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + defer func() { + listenerDone <- true + }() - notifier := mustConnect(t, *defaultConnConfig) - defer closeConn(t, notifier) + mustExec(t, conn, "listen busysafe") + listening <- true - mustExec(t, notifier, "notify unlisten_test") + for i := 0; i < 5000; i++ { + var sum int32 + var rowCount int32 - // when notification is waiting on the socket to be read - notification, err := listener.WaitForNotification(context.Background()) - if err != nil { - t.Fatalf("Unexpected error on WaitForNotification: %v", err) - } - if notification.Channel != "unlisten_test" { - t.Errorf("Did not receive notification on expected channel: %v", notification.Channel) - } + rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 100) + if err != nil { + t.Errorf("conn.Query failed: %v", err) + return + } - err = listener.Unlisten("unlisten_test") - if err != nil { - t.Fatalf("Unexpected error on Unlisten: %v", err) - } + for rows.Next() { + var n int32 + if err := rows.Scan(&n); err != nil { + t.Errorf("Row scan failed: %v", err) + return + } + sum += n + rowCount++ + } - // when notification has already been read during previous query - mustExec(t, notifier, "notify unlisten_test") - rows, _ := listener.Query("select 1") + if rows.Err() != nil { + t.Errorf("conn.Query failed: %v", rows.Err()) + return + } + + if sum != 5050 { + t.Errorf("Wrong rows sum: %v", sum) + return + } + + if rowCount != 100 { + t.Errorf("Wrong number of rows: %v", rowCount) + return + } + } + }() + + go func() { + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + defer func() { + notifierDone <- true + }() + + <-listening + + for i := 0; i < 100000; i++ { + mustExec(t, conn, "notify busysafe, 'hello'") + } + }() + + <-listenerDone + <-notifierDone +} + +func TestListenNotifySelfNotification(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + pgxtest.SkipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") + + mustExec(t, conn, "listen self") + + // Notify self and WaitForNotification immediately + mustExec(t, conn, "notify self") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + notification, err := conn.WaitForNotification(ctx) + require.NoError(t, err) + assert.Equal(t, "self", notification.Channel) + + // Notify self and do something else before WaitForNotification + mustExec(t, conn, "notify self") + + rows, _ := conn.Query(context.Background(), "select 1") rows.Close() if rows.Err() != nil { t.Fatalf("Unexpected error on Query: %v", rows.Err()) } - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) - notification, err = listener.WaitForNotification(ctx) - if err != context.DeadlineExceeded { - t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) - } + ctx, cncl := context.WithTimeout(context.Background(), time.Second) + defer cncl() + notification, err = conn.WaitForNotification(ctx) + require.NoError(t, err) + assert.Equal(t, "self", notification.Channel) } -func TestListenNotifyWhileBusyIsSafe(t *testing.T) { +func TestFatalRxError(t *testing.T) { t.Parallel() - listenerDone := make(chan bool) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + pgxtest.SkipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") + + var wg sync.WaitGroup + wg.Add(1) go func() { - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - defer func() { - listenerDone <- true + defer wg.Done() + var n int32 + var s string + err := conn.QueryRow(context.Background(), "select 1::int4, pg_sleep(10)::varchar").Scan(&n, &s) + if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Severity == "FATAL" { + } else { + t.Errorf("Expected QueryRow Scan to return fatal PgError, but instead received %v", err) + return + } + }() + + otherConn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer otherConn.Close(context.Background()) + + if _, err := otherConn.Exec(context.Background(), "select pg_terminate_backend($1)", conn.PgConn().PID()); err != nil { + t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) + } + + wg.Wait() + + if !conn.IsClosed() { + t.Fatal("Connection should be closed") + } +} + +func TestFatalTxError(t *testing.T) { + t.Parallel() + + // Run timing sensitive test many times + for i := 0; i < 50; i++ { + func() { + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + pgxtest.SkipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") + + otherConn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer otherConn.Close(context.Background()) + + _, err := otherConn.Exec(context.Background(), "select pg_terminate_backend($1)", conn.PgConn().PID()) + if err != nil { + t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) + } + + err = conn.QueryRow(context.Background(), "select 1").Scan(nil) + if err == nil { + t.Fatal("Expected error but none occurred") + } + + if !conn.IsClosed() { + t.Fatalf("Connection should be closed but isn't. Previous Query err: %v", err) + } }() + } +} + +func TestInsertBoolArray(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results.String() != "CREATE TABLE" { + t.Error("Unexpected results from Exec") + } + + // Accept parameters + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results.String() != "INSERT 0 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } + }) +} - if err := conn.Listen("busysafe"); err != nil { - t.Fatalf("Unable to start listening: %v", err) +func TestInsertTimestampArray(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results.String() != "CREATE TABLE" { + t.Error("Unexpected results from Exec") + } + + // Accept parameters + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results.String() != "INSERT 0 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } + }) +} + +func TestIdentifierSanitize(t *testing.T) { + t.Parallel() + + tests := []struct { + ident pgx.Identifier + expected string + }{ + { + ident: pgx.Identifier{`foo`}, + expected: `"foo"`, + }, + { + ident: pgx.Identifier{`select`}, + expected: `"select"`, + }, + { + ident: pgx.Identifier{`foo`, `bar`}, + expected: `"foo"."bar"`, + }, + { + ident: pgx.Identifier{`you should " not do this`}, + expected: `"you should "" not do this"`, + }, + { + ident: pgx.Identifier{`you should " not do this`, `please don't`}, + expected: `"you should "" not do this"."please don't"`, + }, + { + ident: pgx.Identifier{`you should ` + string([]byte{0}) + `not do this`}, + expected: `"you should not do this"`, + }, + } + + for i, tt := range tests { + qval := tt.ident.Sanitize() + if qval != tt.expected { + t.Errorf("%d. Expected Sanitize %v to return %v but it was %v", i, tt.ident, tt.expected, qval) + } + } +} + +func TestConnInitTypeMap(t *testing.T) { + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + // spot check that the standard postgres type names aren't qualified + nameOIDs := map[string]uint32{ + "_int8": pgtype.Int8ArrayOID, + "int8": pgtype.Int8OID, + "json": pgtype.JSONOID, + "text": pgtype.TextOID, + } + for name, oid := range nameOIDs { + dtByName, ok := conn.TypeMap().TypeForName(name) + if !ok { + t.Fatalf("Expected type named %v to be present", name) + } + dtByOID, ok := conn.TypeMap().TypeForOID(oid) + if !ok { + t.Fatalf("Expected type OID %v to be present", oid) + } + if dtByName != dtByOID { + t.Fatalf("Expected type named %v to be the same as type OID %v", name, oid) } + } - for i := 0; i < 5000; i++ { - var sum int32 - var rowCount int32 - - rows, err := conn.Query("select generate_series(1,$1)", 100) - if err != nil { - t.Fatalf("conn.Query failed: %v", err) - } - - for rows.Next() { - var n int32 - rows.Scan(&n) - sum += n - rowCount++ - } - - if rows.Err() != nil { - t.Fatalf("conn.Query failed: %v", err) - } + ensureConnValid(t, conn) +} - if sum != 5050 { - t.Fatalf("Wrong rows sum: %v", sum) - } +func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - if rowCount != 100 { - t.Fatalf("Wrong number of rows: %v", rowCount) - } + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") - time.Sleep(1 * time.Microsecond) + var n uint64 + err := conn.QueryRow(context.Background(), "select $1::uint64", "42").Scan(&n) + if err != nil { + t.Fatal(err) } - }() - - go func() { - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - for i := 0; i < 100000; i++ { - mustExec(t, conn, "notify busysafe, 'hello'") - time.Sleep(1 * time.Microsecond) + if n != 42 { + t.Fatalf("Expected n to be 42, but was %v", n) } - }() - - <-listenerDone + }) } -func TestListenNotifySelfNotification(t *testing.T) { - t.Parallel() +func TestDomainType(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") - if err := conn.Listen("self"); err != nil { - t.Fatalf("Unable to start listening: %v", err) - } + // Domain type uint64 is a PostgreSQL domain of underlying type numeric. - // Notify self and WaitForNotification immediately - mustExec(t, conn, "notify self") + // In the extended protocol preparing "select $1::uint64" appears to create a statement that expects a param OID of + // uint64 but a result OID of the underlying numeric. - ctx, _ := context.WithTimeout(context.Background(), time.Second) - notification, err := conn.WaitForNotification(ctx) - if err != nil { - t.Fatalf("Unexpected error on WaitForNotification: %v", err) - } - if notification.Channel != "self" { - t.Errorf("Did not receive notification on expected channel: %v", notification.Channel) - } + var s string + err := conn.QueryRow(ctx, "select $1::uint64", "24").Scan(&s) + require.NoError(t, err) + require.Equal(t, "24", s) - // Notify self and do something else before WaitForNotification - mustExec(t, conn, "notify self") + // Register type + uint64Type, err := conn.LoadType(ctx, "uint64") + require.NoError(t, err) + conn.TypeMap().RegisterType(uint64Type) - rows, _ := conn.Query("select 1") - rows.Close() - if rows.Err() != nil { - t.Fatalf("Unexpected error on Query: %v", rows.Err()) - } + var n uint64 + err = conn.QueryRow(ctx, "select $1::uint64", uint64(24)).Scan(&n) + require.NoError(t, err) - ctx, _ = context.WithTimeout(context.Background(), time.Second) - notification, err = conn.WaitForNotification(ctx) - if err != nil { - t.Fatalf("Unexpected error on WaitForNotification: %v", err) - } - if notification.Channel != "self" { - t.Errorf("Did not receive notification on expected channel: %v", notification.Channel) - } + // String is still an acceptable argument after registration + err = conn.QueryRow(ctx, "select $1::uint64", "7").Scan(&n) + if err != nil { + t.Fatal(err) + } + if n != 7 { + t.Fatalf("Expected n to be 7, but was %v", n) + } + }) } -func TestListenUnlistenSpecialCharacters(t *testing.T) { - t.Parallel() +func TestLoadTypeSameNameInDifferentSchemas(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does support composite types (https://github.com/cockroachdb/cockroach/issues/27792)") + + tx, err := conn.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + _, err = tx.Exec(ctx, `create schema pgx_a; +create type pgx_a.point as (a text, b text); +create schema pgx_b; +create type pgx_b.point as (c text); +`) + require.NoError(t, err) + + // Register types + for _, typename := range []string{"pgx_a.point", "pgx_b.point"} { + // Obviously using conn while a tx is in use and registering a type after the connection has been established are + // really bad practices, but for the sake of convenience we do it in the test here. + dt, err := conn.LoadType(ctx, typename) + require.NoError(t, err) + conn.TypeMap().RegisterType(dt) + } - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + type aPoint struct { + A string + B string + } - chanName := "special characters !@#{$%^&*()}" - if err := conn.Listen(chanName); err != nil { - t.Fatalf("Unable to start listening: %v", err) - } + type bPoint struct { + C string + } - if err := conn.Unlisten(chanName); err != nil { - t.Fatalf("Unable to stop listening: %v", err) - } + var a aPoint + var b bPoint + err = tx.QueryRow(ctx, `select '(foo,bar)'::pgx_a.point, '(baz)'::pgx_b.point`).Scan(&a, &b) + require.NoError(t, err) + require.Equal(t, aPoint{"foo", "bar"}, a) + require.Equal(t, bPoint{"baz"}, b) + }) } -func TestFatalRxError(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) +func TestLoadCompositeType(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - var n int32 - var s string - err := conn.QueryRow("select 1::int4, pg_sleep(10)::varchar").Scan(&n, &s) - if err == pgx.ErrDeadConn { - } else if pgErr, ok := err.(pgx.PgError); ok && pgErr.Severity == "FATAL" { - } else { - t.Fatalf("Expected QueryRow Scan to return fatal PgError or ErrDeadConn, but instead received %v", err) - } - }() + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does support composite types (https://github.com/cockroachdb/cockroach/issues/27792)") - otherConn, err := pgx.Connect(*defaultConnConfig) - if err != nil { - t.Fatalf("Unable to establish connection: %v", err) - } - defer otherConn.Close() + tx, err := conn.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) - if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.PID()); err != nil { - t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) - } + _, err = tx.Exec(ctx, "create type compositetype as (attr1 int, attr2 int)") + require.NoError(t, err) - wg.Wait() + _, err = tx.Exec(ctx, "alter type compositetype drop attribute attr1") + require.NoError(t, err) - if conn.IsAlive() { - t.Fatal("Connection should not be live but was") - } + _, err = conn.LoadType(ctx, "compositetype") + require.NoError(t, err) + }) } -func TestFatalTxError(t *testing.T) { - t.Parallel() +func TestLoadRangeType(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - // Run timing sensitive test many times - for i := 0; i < 50; i++ { - func() { - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does support range types") - otherConn, err := pgx.Connect(*defaultConnConfig) - if err != nil { - t.Fatalf("Unable to establish connection: %v", err) - } - defer otherConn.Close() + tx, err := conn.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) - _, err = otherConn.Exec("select pg_terminate_backend($1)", conn.PID()) - if err != nil { - t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) - } + _, err = tx.Exec(ctx, "create type examplefloatrange as range (subtype=float8, subtype_diff=float8mi)") + require.NoError(t, err) - _, err = conn.Query("select 1") - if err == nil { - t.Fatal("Expected error but none occurred") - } + // Register types + newRangeType, err := conn.LoadType(ctx, "examplefloatrange") + require.NoError(t, err) + conn.TypeMap().RegisterType(newRangeType) + conn.TypeMap().RegisterDefaultPgType(pgtype.Range[float64]{}, "examplefloatrange") - if conn.IsAlive() { - t.Fatalf("Connection should not be live but was. Previous Query err: %v", err) - } - }() - } + inputRangeType := pgtype.Range[float64]{ + Lower: 1.0, + Upper: 2.0, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Inclusive, + Valid: true, + } + var outputRangeType pgtype.Range[float64] + err = tx.QueryRow(ctx, "SELECT $1::examplefloatrange", inputRangeType).Scan(&outputRangeType) + require.NoError(t, err) + require.Equal(t, inputRangeType, outputRangeType) + }) } -func TestCommandTag(t *testing.T) { - t.Parallel() - - var tests = []struct { - commandTag pgx.CommandTag - rowsAffected int64 - }{ - {commandTag: "INSERT 0 5", rowsAffected: 5}, - {commandTag: "UPDATE 0", rowsAffected: 0}, - {commandTag: "UPDATE 1", rowsAffected: 1}, - {commandTag: "DELETE 0", rowsAffected: 0}, - {commandTag: "DELETE 1", rowsAffected: 1}, - {commandTag: "CREATE TABLE", rowsAffected: 0}, - {commandTag: "ALTER TABLE", rowsAffected: 0}, - {commandTag: "DROP TABLE", rowsAffected: 0}, - } - - for i, tt := range tests { - actual := tt.commandTag.RowsAffected() - if tt.rowsAffected != actual { - t.Errorf(`%d. "%s" should have affected %d rows but it was %d`, i, tt.commandTag, tt.rowsAffected, actual) +func TestLoadMultiRangeType(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does support range types") + pgxtest.SkipPostgreSQLVersionLessThan(t, conn, 14) // multirange data type was added in 14 postgresql + + tx, err := conn.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + _, err = tx.Exec(ctx, "create type examplefloatrange as range (subtype=float8, subtype_diff=float8mi, multirange_type_name=examplefloatmultirange)") + require.NoError(t, err) + + // Register types + newRangeType, err := conn.LoadType(ctx, "examplefloatrange") + require.NoError(t, err) + conn.TypeMap().RegisterType(newRangeType) + conn.TypeMap().RegisterDefaultPgType(pgtype.Range[float64]{}, "examplefloatrange") + + newMultiRangeType, err := conn.LoadType(ctx, "examplefloatmultirange") + require.NoError(t, err) + conn.TypeMap().RegisterType(newMultiRangeType) + conn.TypeMap().RegisterDefaultPgType(pgtype.Multirange[pgtype.Range[float64]]{}, "examplefloatmultirange") + + inputMultiRangeType := pgtype.Multirange[pgtype.Range[float64]]{ + { + Lower: 1.0, + Upper: 2.0, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Inclusive, + Valid: true, + }, + { + Lower: 3.0, + Upper: 4.0, + LowerType: pgtype.Exclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, } - } + var outputMultiRangeType pgtype.Multirange[pgtype.Range[float64]] + err = tx.QueryRow(ctx, "SELECT $1::examplefloatmultirange", inputMultiRangeType).Scan(&outputMultiRangeType) + require.NoError(t, err) + require.Equal(t, inputMultiRangeType, outputMultiRangeType) + }) } -func TestInsertBoolArray(t *testing.T) { - t.Parallel() +func TestStmtCacheInvalidationConn(t *testing.T) { + ctx := context.Background() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results != "CREATE TABLE" { - t.Error("Unexpected results from Exec") + // create a table and fill it with some data + _, err := conn.Exec(ctx, ` + DROP TABLE IF EXISTS drop_cols; + CREATE TABLE drop_cols ( + id SERIAL PRIMARY KEY NOT NULL, + f1 int NOT NULL, + f2 int NOT NULL + ); + `) + require.NoError(t, err) + _, err = conn.Exec(ctx, "INSERT INTO drop_cols (f1, f2) VALUES (1, 2)") + require.NoError(t, err) + + getSQL := "SELECT * FROM drop_cols WHERE id = $1" + + // This query will populate the statement cache. We don't care about the result. + rows, err := conn.Query(ctx, getSQL, 1) + require.NoError(t, err) + rows.Close() + require.NoError(t, rows.Err()) + + // Now, change the schema of the table out from under the statement, making it invalid. + _, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") + require.NoError(t, err) + + // We must get an error the first time we try to re-execute a bad statement. + // It is up to the application to determine if it wants to try again. We punt to + // the application because there is no clear recovery path in the case of failed transactions + // or batch operations and because automatic retry is tricky and we don't want to get + // it wrong at such an importaint layer of the stack. + rows, err = conn.Query(ctx, getSQL, 1) + require.NoError(t, err) + rows.Next() + nextErr := rows.Err() + rows.Close() + for _, err := range []error{nextErr, rows.Err()} { + if err == nil { + t.Fatal(`expected "cached plan must not change result type": no error`) + } + if !strings.Contains(err.Error(), "cached plan must not change result type") { + t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error()) + } } - // Accept parameters - if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results != "INSERT 0 1" { - t.Errorf("Unexpected results from Exec: %v", results) - } + // On retry, the statement should have been flushed from the cache. + rows, err = conn.Query(ctx, getSQL, 1) + require.NoError(t, err) + rows.Next() + err = rows.Err() + require.NoError(t, err) + rows.Close() + require.NoError(t, rows.Err()) + + ensureConnValid(t, conn) } -func TestInsertTimestampArray(t *testing.T) { - t.Parallel() +func TestStmtCacheInvalidationTx(t *testing.T) { + ctx := context.Background() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results != "CREATE TABLE" { - t.Error("Unexpected results from Exec") + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip("Server has non-standard prepare in errored transaction behavior (https://github.com/cockroachdb/cockroach/issues/84140)") } - // Accept parameters - if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results != "INSERT 0 1" { - t.Errorf("Unexpected results from Exec: %v", results) - } -} + // create a table and fill it with some data + _, err := conn.Exec(ctx, ` + DROP TABLE IF EXISTS drop_cols; + CREATE TABLE drop_cols ( + id SERIAL PRIMARY KEY NOT NULL, + f1 int NOT NULL, + f2 int NOT NULL + ); + `) + require.NoError(t, err) + _, err = conn.Exec(ctx, "INSERT INTO drop_cols (f1, f2) VALUES (1, 2)") + require.NoError(t, err) -func TestCatchSimultaneousConnectionQueries(t *testing.T) { - t.Parallel() + tx, err := conn.Begin(ctx) + require.NoError(t, err) - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + getSQL := "SELECT * FROM drop_cols WHERE id = $1" - rows1, err := conn.Query("select generate_series(1,$1)", 10) - if err != nil { - t.Fatalf("conn.Query failed: %v", err) + // This query will populate the statement cache. We don't care about the result. + rows, err := tx.Query(ctx, getSQL, 1) + require.NoError(t, err) + rows.Close() + require.NoError(t, rows.Err()) + + // Now, change the schema of the table out from under the statement, making it invalid. + _, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") + require.NoError(t, err) + + // We must get an error the first time we try to re-execute a bad statement. + // It is up to the application to determine if it wants to try again. We punt to + // the application because there is no clear recovery path in the case of failed transactions + // or batch operations and because automatic retry is tricky and we don't want to get + // it wrong at such an importaint layer of the stack. + rows, err = tx.Query(ctx, getSQL, 1) + require.NoError(t, err) + rows.Next() + nextErr := rows.Err() + rows.Close() + for _, err := range []error{nextErr, rows.Err()} { + if err == nil { + t.Fatal(`expected "cached plan must not change result type": no error`) + } + if !strings.Contains(err.Error(), "cached plan must not change result type") { + t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error()) + } } - defer rows1.Close() - _, err = conn.Query("select generate_series(1,$1)", 10) - if err != pgx.ErrConnBusy { - t.Fatalf("conn.Query should have failed with pgx.ErrConnBusy, but it was %v", err) - } + rows, _ = tx.Query(ctx, getSQL, 1) + rows.Close() + err = rows.Err() + // Retries within the same transaction are errors (really anything except a rollback + // will be an error in this transaction). + require.Error(t, err) + rows.Close() + + err = tx.Rollback(ctx) + require.NoError(t, err) + + // once we've rolled back, retries will work + rows, err = conn.Query(ctx, getSQL, 1) + require.NoError(t, err) + rows.Next() + err = rows.Err() + require.NoError(t, err) + rows.Close() + + ensureConnValid(t, conn) } -func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) { - t.Parallel() +func TestStmtCacheInvalidationConnWithBatch(t *testing.T) { + ctx := context.Background() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - rows, err := conn.Query("select generate_series(1,$1)", 10) - if err != nil { - t.Fatalf("conn.Query failed: %v", err) + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip("Test fails due to different CRDB behavior") } - defer rows.Close() - _, err = conn.Exec("create temporary table foo(spice timestamp[])") - if err != pgx.ErrConnBusy { - t.Fatalf("conn.Exec should have failed with pgx.ErrConnBusy, but it was %v", err) - } -} + // create a table and fill it with some data + _, err := conn.Exec(ctx, ` + DROP TABLE IF EXISTS drop_cols; + CREATE TABLE drop_cols ( + id SERIAL PRIMARY KEY NOT NULL, + f1 int NOT NULL, + f2 int NOT NULL + ); + `) + require.NoError(t, err) + _, err = conn.Exec(ctx, "INSERT INTO drop_cols (f1, f2) VALUES (1, 2)") + require.NoError(t, err) -type testLog struct { - lvl pgx.LogLevel - msg string - data map[string]interface{} -} + getSQL := "SELECT * FROM drop_cols WHERE id = $1" -type testLogger struct { - logs []testLog -} + // This query will populate the statement cache. We don't care about the result. + rows, err := conn.Query(ctx, getSQL, 1) + require.NoError(t, err) + rows.Close() + require.NoError(t, rows.Err()) + + // Now, change the schema of the table out from under the statement, making it invalid. + _, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") + require.NoError(t, err) + + // We must get an error the first time we try to re-execute a bad statement. + // It is up to the application to determine if it wants to try again. We punt to + // the application because there is no clear recovery path in the case of failed transactions + // or batch operations and because automatic retry is tricky and we don't want to get + // it wrong at such an importaint layer of the stack. + batch := &pgx.Batch{} + batch.Queue(getSQL, 1) + br := conn.SendBatch(ctx, batch) + rows, err = br.Query() + require.Error(t, err) + rows.Next() + nextErr := rows.Err() + rows.Close() + err = br.Close() + require.Error(t, err) + for _, err := range []error{nextErr, rows.Err()} { + if err == nil { + t.Fatal(`expected "cached plan must not change result type": no error`) + } + if !strings.Contains(err.Error(), "cached plan must not change result type") { + t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error()) + } + } + + // On retry, the statement should have been flushed from the cache. + batch = &pgx.Batch{} + batch.Queue(getSQL, 1) + br = conn.SendBatch(ctx, batch) + rows, err = br.Query() + require.NoError(t, err) + rows.Next() + err = rows.Err() + require.NoError(t, err) + rows.Close() + require.NoError(t, rows.Err()) + err = br.Close() + require.NoError(t, err) -func (l *testLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { - l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) + ensureConnValid(t, conn) } -func TestSetLogger(t *testing.T) { - t.Parallel() +func TestStmtCacheInvalidationTxWithBatch(t *testing.T) { + ctx := context.Background() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - l1 := &testLogger{} - oldLogger := conn.SetLogger(l1) - if oldLogger != nil { - t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", nil, oldLogger) + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip("Server has non-standard prepare in errored transaction behavior (https://github.com/cockroachdb/cockroach/issues/84140)") } - if err := conn.Listen("foo"); err != nil { - t.Fatal(err) - } + // create a table and fill it with some data + _, err := conn.Exec(ctx, ` + DROP TABLE IF EXISTS drop_cols; + CREATE TABLE drop_cols ( + id SERIAL PRIMARY KEY NOT NULL, + f1 int NOT NULL, + f2 int NOT NULL + ); + `) + require.NoError(t, err) + _, err = conn.Exec(ctx, "INSERT INTO drop_cols (f1, f2) VALUES (1, 2)") + require.NoError(t, err) - if len(l1.logs) == 0 { - t.Fatal("Expected new logger l1 to be called, but it wasn't") - } + tx, err := conn.Begin(ctx) + require.NoError(t, err) - l2 := &testLogger{} - oldLogger = conn.SetLogger(l2) - if oldLogger != l1 { - t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", l1, oldLogger) - } + getSQL := "SELECT * FROM drop_cols WHERE id = $1" - if err := conn.Listen("bar"); err != nil { - t.Fatal(err) + // This query will populate the statement cache. We don't care about the result. + rows, err := tx.Query(ctx, getSQL, 1) + require.NoError(t, err) + rows.Close() + require.NoError(t, rows.Err()) + + // Now, change the schema of the table out from under the statement, making it invalid. + _, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") + require.NoError(t, err) + + // We must get an error the first time we try to re-execute a bad statement. + // It is up to the application to determine if it wants to try again. We punt to + // the application because there is no clear recovery path in the case of failed transactions + // or batch operations and because automatic retry is tricky and we don't want to get + // it wrong at such an importaint layer of the stack. + batch := &pgx.Batch{} + batch.Queue(getSQL, 1) + br := tx.SendBatch(ctx, batch) + rows, err = br.Query() + require.Error(t, err) + rows.Next() + nextErr := rows.Err() + rows.Close() + err = br.Close() + require.Error(t, err) + for _, err := range []error{nextErr, rows.Err()} { + if err == nil { + t.Fatal(`expected "cached plan must not change result type": no error`) + } + if !strings.Contains(err.Error(), "cached plan must not change result type") { + t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error()) + } } - if len(l2.logs) == 0 { - t.Fatal("Expected new logger l2 to be called, but it wasn't") - } + batch = &pgx.Batch{} + batch.Queue(getSQL, 1) + br = tx.SendBatch(ctx, batch) + rows, err = br.Query() + require.Error(t, err) + rows.Close() + err = rows.Err() + // Retries within the same transaction are errors (really anything except a rollback + // will be an error in this transaction). + require.Error(t, err) + rows.Close() + err = br.Close() + require.Error(t, err) + + err = tx.Rollback(ctx) + require.NoError(t, err) + + // once we've rolled back, retries will work + batch = &pgx.Batch{} + batch.Queue(getSQL, 1) + br = conn.SendBatch(ctx, batch) + rows, err = br.Query() + require.NoError(t, err) + rows.Next() + err = rows.Err() + require.NoError(t, err) + rows.Close() + err = br.Close() + require.NoError(t, err) + + ensureConnValid(t, conn) } -func TestSetLogLevel(t *testing.T) { - t.Parallel() +func TestInsertDurationInterval(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(context.Background(), "create temporary table t(duration INTERVAL(0) NOT NULL)") + require.NoError(t, err) - logger := &testLogger{} - conn.SetLogger(logger) + result, err := conn.Exec(context.Background(), "insert into t(duration) values($1)", time.Minute) + require.NoError(t, err) - if _, err := conn.SetLogLevel(0); err != pgx.ErrInvalidLogLevel { - t.Fatal("SetLogLevel with invalid level did not return error") - } + n := result.RowsAffected() + require.EqualValues(t, 1, n) + }) +} - if _, err := conn.SetLogLevel(pgx.LogLevelNone); err != nil { - t.Fatal(err) - } +func TestRawValuesUnderlyingMemoryReused(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var buf []byte - if err := conn.Listen("foo"); err != nil { - t.Fatal(err) - } + rows, err := conn.Query(ctx, `select 1::int`) + require.NoError(t, err) - if len(logger.logs) != 0 { - t.Fatalf("Expected logger not to be called, but it was: %v", logger.logs) - } + for rows.Next() { + buf = rows.RawValues()[0] + } - if _, err := conn.SetLogLevel(pgx.LogLevelTrace); err != nil { - t.Fatal(err) - } + require.NoError(t, rows.Err()) - if err := conn.Listen("bar"); err != nil { - t.Fatal(err) - } + original := make([]byte, len(buf)) + copy(original, buf) - if len(logger.logs) == 0 { - t.Fatal("Expected logger to be called, but it wasn't") - } + for i := 0; i < 1_000_000; i++ { + rows, err := conn.Query(ctx, `select $1::int`, i) + require.NoError(t, err) + rows.Close() + require.NoError(t, rows.Err()) + + if !bytes.Equal(original, buf) { + return + } + } + + t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not") + }) } -func TestIdentifierSanitize(t *testing.T) { - t.Parallel() +// https://github.com/jackc/pgx/issues/1847 +func TestConnDeallocateInvalidatedCachedStatementsWhenCanceled(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - tests := []struct { - ident pgx.Identifier - expected string - }{ - { - ident: pgx.Identifier{`foo`}, - expected: `"foo"`, - }, - { - ident: pgx.Identifier{`select`}, - expected: `"select"`, - }, - { - ident: pgx.Identifier{`foo`, `bar`}, - expected: `"foo"."bar"`, - }, - { - ident: pgx.Identifier{`you should " not do this`}, - expected: `"you should "" not do this"`, - }, - { - ident: pgx.Identifier{`you should " not do this`, `please don't`}, - expected: `"you should "" not do this"."please don't"`, - }, - } + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "CockroachDB returns decimal instead of integer for integer division") - for i, tt := range tests { - qval := tt.ident.Sanitize() - if qval != tt.expected { - t.Errorf("%d. Expected Sanitize %v to return %v but it was %v", i, tt.ident, tt.expected, qval) - } - } + var n int32 + err := conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + + // Divide by zero causes an error. baseRows.Close() calls Invalidate on the statement cache whenever an error was + // encountered by the query. Use this to purposely invalidate the query. If we had access to private fields of conn + // we could call conn.statementCache.InvalidateAll() instead. + err = conn.QueryRow(ctx, "select 1 / $1::int", 0).Scan(&n) + require.Error(t, err) + + ctx2, cancel2 := context.WithCancel(ctx) + cancel2() + err = conn.QueryRow(ctx2, "select 1 / $1::int", 1).Scan(&n) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + + err = conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + }) } -func TestConnOnNotice(t *testing.T) { +// https://github.com/jackc/pgx/issues/1847 +func TestConnDeallocateInvalidatedCachedStatementsInTransactionWithBatch(t *testing.T) { t.Parallel() - var msg string + ctx := context.Background() - connConfig := *defaultConnConfig - connConfig.OnNotice = func(c *pgx.Conn, notice *pgx.Notice) { - msg = notice.Message - } - conn := mustConnect(t, connConfig) - defer closeConn(t, conn) + connString := os.Getenv("PGX_TEST_DATABASE") + config := mustParseConfig(t, connString) + config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement + config.StatementCacheCapacity = 2 - _, err := conn.Exec(`do $$ -begin - raise notice 'hello, world'; -end$$;`) - if err != nil { - t.Fatal(err) - } + conn, err := pgx.ConnectConfig(ctx, config) + require.NoError(t, err) - if msg != "hello, world" { - t.Errorf("msg => %v, want %v", msg, "hello, world") - } + tx, err := conn.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + _, err = tx.Exec(ctx, "select $1::int + 1", 1) + require.NoError(t, err) + + _, err = tx.Exec(ctx, "select $1::int + 2", 1) + require.NoError(t, err) + + // This should invalidate the first cached statement. + _, err = tx.Exec(ctx, "select $1::int + 3", 1) + require.NoError(t, err) + + batch := &pgx.Batch{} + batch.Queue("select $1::int + 1", 1) + err = tx.SendBatch(ctx, batch).Close() + require.NoError(t, err) + + err = tx.Rollback(ctx) + require.NoError(t, err) ensureConnValid(t, conn) } -func TestConnInitConnInfo(t *testing.T) { - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) +func TestErrNoRows(t *testing.T) { + t.Parallel() - // spot check that the standard postgres type names aren't qualified - nameOIDs := map[string]pgtype.OID{ - "_int8": pgtype.Int8ArrayOID, - "int8": pgtype.Int8OID, - "json": pgtype.JSONOID, - "text": pgtype.TextOID, - } - for name, oid := range nameOIDs { - dtByName, ok := conn.ConnInfo.DataTypeForName(name) - if !ok { - t.Fatalf("Expected type named %v to be present", name) - } - dtByOID, ok := conn.ConnInfo.DataTypeForOID(oid) - if !ok { - t.Fatalf("Expected type OID %v to be present", oid) - } - if dtByName != dtByOID { - t.Fatalf("Expected type named %v to be the same as type OID %v", name, oid) - } - } + // ensure we preserve old error message + require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error()) - ensureConnValid(t, conn) + require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows") } diff --git a/copy_from.go b/copy_from.go index 8b7c3d5bd..abcd22396 100644 --- a/copy_from.go +++ b/copy_from.go @@ -2,21 +2,22 @@ package pgx import ( "bytes" + "context" "fmt" + "io" - "github.com/jackc/pgx/pgio" - "github.com/jackc/pgx/pgproto3" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgconn" ) // CopyFromRows returns a CopyFromSource interface over the provided rows slice // making it usable by *Conn.CopyFrom. -func CopyFromRows(rows [][]interface{}) CopyFromSource { +func CopyFromRows(rows [][]any) CopyFromSource { return ©FromRows{rows: rows, idx: -1} } type copyFromRows struct { - rows [][]interface{} + rows [][]any idx int } @@ -25,7 +26,7 @@ func (ctr *copyFromRows) Next() bool { return ctr.idx < len(ctr.rows) } -func (ctr *copyFromRows) Values() ([]interface{}, error) { +func (ctr *copyFromRows) Values() ([]any, error) { return ctr.rows[ctr.idx], nil } @@ -33,6 +34,63 @@ func (ctr *copyFromRows) Err() error { return nil } +// CopyFromSlice returns a CopyFromSource interface over a dynamic func +// making it usable by *Conn.CopyFrom. +func CopyFromSlice(length int, next func(int) ([]any, error)) CopyFromSource { + return ©FromSlice{next: next, idx: -1, len: length} +} + +type copyFromSlice struct { + next func(int) ([]any, error) + idx int + len int + err error +} + +func (cts *copyFromSlice) Next() bool { + cts.idx++ + return cts.idx < cts.len +} + +func (cts *copyFromSlice) Values() ([]any, error) { + values, err := cts.next(cts.idx) + if err != nil { + cts.err = err + } + return values, err +} + +func (cts *copyFromSlice) Err() error { + return cts.err +} + +// CopyFromFunc returns a CopyFromSource interface that relies on nxtf for values. +// nxtf returns rows until it either signals an 'end of data' by returning row=nil and err=nil, +// or it returns an error. If nxtf returns an error, the copy is aborted. +func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource { + return ©FromFunc{next: nxtf} +} + +type copyFromFunc struct { + next func() ([]any, error) + valueRow []any + err error +} + +func (g *copyFromFunc) Next() bool { + g.valueRow, g.err = g.next() + // only return true if valueRow exists and no error + return g.valueRow != nil && g.err == nil +} + +func (g *copyFromFunc) Values() ([]any, error) { + return g.valueRow, g.err +} + +func (g *copyFromFunc) Err() error { + return g.err +} + // CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data. type CopyFromSource interface { // Next returns true if there is another row and makes the next row data @@ -41,7 +99,7 @@ type CopyFromSource interface { Next() bool // Values returns the values for the current row. - Values() ([]interface{}, error) + Values() ([]any, error) // Err returns any error that has been encountered by the CopyFromSource. If // this is not nil *Conn.CopyFrom will abort the copy. @@ -54,42 +112,17 @@ type copyFrom struct { columnNames []string rowSrc CopyFromSource readerErrChan chan error + mode QueryExecMode } -func (ct *copyFrom) readUntilReadyForQuery() { - for { - msg, err := ct.conn.rxMsg() - if err != nil { - ct.readerErrChan <- err - close(ct.readerErrChan) - return - } - - switch msg := msg.(type) { - case *pgproto3.ReadyForQuery: - ct.conn.rxReadyForQuery(msg) - close(ct.readerErrChan) - return - case *pgproto3.CommandComplete: - case *pgproto3.ErrorResponse: - ct.readerErrChan <- ct.conn.rxErrorResponse(msg) - default: - err = ct.conn.processContextFreeMsg(msg) - if err != nil { - ct.readerErrChan <- ct.conn.processContextFreeMsg(msg) - } - } +func (ct *copyFrom) run(ctx context.Context) (int64, error) { + if ct.conn.copyFromTracer != nil { + ctx = ct.conn.copyFromTracer.TraceCopyFromStart(ctx, ct.conn, TraceCopyFromStartData{ + TableName: ct.tableName, + ColumnNames: ct.columnNames, + }) } -} -func (ct *copyFrom) waitForReaderDone() error { - var err error - for err = range ct.readerErrChan { - } - return err -} - -func (ct *copyFrom) run() (int, error) { quotedTableName := ct.tableName.Sanitize() cbuf := &bytes.Buffer{} for i, cn := range ct.columnNames { @@ -100,152 +133,144 @@ func (ct *copyFrom) run() (int, error) { } quotedColumnNames := cbuf.String() - ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) - if err != nil { - return 0, err - } - - err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) - if err != nil { - return 0, err - } - - err = ct.conn.readUntilCopyInResponse() - if err != nil { - return 0, err + var sd *pgconn.StatementDescription + switch ct.mode { + case QueryExecModeExec, QueryExecModeSimpleProtocol: + // These modes don't support the binary format. Before the inclusion of the + // QueryExecModes, Conn.Prepare was called on every COPY operation to get + // the OIDs. These prepared statements were not cached. + // + // Since that's the same behavior provided by QueryExecModeDescribeExec, + // we'll default to that mode. + ct.mode = QueryExecModeDescribeExec + fallthrough + case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec: + var err error + sd, err = ct.conn.getStatementDescription( + ctx, + ct.mode, + fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName), + ) + if err != nil { + return 0, fmt.Errorf("statement description failed: %w", err) + } + default: + return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode) } - go ct.readUntilReadyForQuery() - defer ct.waitForReaderDone() - - buf := ct.conn.wbuf - buf = append(buf, copyData) - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) + r, w := io.Pipe() + doneChan := make(chan struct{}) - buf = append(buf, "PGCOPY\n\377\r\n\000"...) - buf = pgio.AppendInt32(buf, 0) - buf = pgio.AppendInt32(buf, 0) + go func() { + defer close(doneChan) - var sentCount int + // Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283. + buf := ct.conn.wbuf - for ct.rowSrc.Next() { - select { - case err = <-ct.readerErrChan: - return 0, err - default: - } + buf = append(buf, "PGCOPY\n\377\r\n\000"...) + buf = pgio.AppendInt32(buf, 0) + buf = pgio.AppendInt32(buf, 0) - if len(buf) > 65536 { - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err = ct.conn.conn.Write(buf) + moreRows := true + for moreRows { + var err error + moreRows, buf, err = ct.buildCopyBuf(buf, sd) if err != nil { - ct.conn.die(err) - return 0, err + w.CloseWithError(err) + return } - // Directly manipulate wbuf to reset to reuse the same buffer - buf = buf[0:5] - } - - sentCount++ - - values, err := ct.rowSrc.Values() - if err != nil { - ct.cancelCopyIn() - return 0, err - } - if len(values) != len(ct.columnNames) { - ct.cancelCopyIn() - return 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) - } + if ct.rowSrc.Err() != nil { + w.CloseWithError(ct.rowSrc.Err()) + return + } - buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) - for i, val := range values { - buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val) - if err != nil { - ct.cancelCopyIn() - return 0, err + if len(buf) > 0 { + _, err = w.Write(buf) + if err != nil { + w.Close() + return + } } + buf = buf[:0] } - } - if ct.rowSrc.Err() != nil { - ct.cancelCopyIn() - return 0, ct.rowSrc.Err() - } + w.Close() + }() - buf = pgio.AppendInt16(buf, -1) // terminate the copy stream - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) - buf = append(buf, copyDone) - buf = pgio.AppendInt32(buf, 4) + r.Close() + <-doneChan - _, err = ct.conn.conn.Write(buf) - if err != nil { - ct.conn.die(err) - return 0, err + if ct.conn.copyFromTracer != nil { + ct.conn.copyFromTracer.TraceCopyFromEnd(ctx, ct.conn, TraceCopyFromEndData{ + CommandTag: commandTag, + Err: err, + }) } - err = ct.waitForReaderDone() - if err != nil { - return 0, err - } - return sentCount, nil + return commandTag.RowsAffected(), err } -func (c *Conn) readUntilCopyInResponse() error { - for { - msg, err := c.rxMsg() +func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) { + const sendBufSize = 65536 - 5 // The packet has a 5-byte header + lastBufLen := 0 + largestRowLen := 0 + + for ct.rowSrc.Next() { + lastBufLen = len(buf) + + values, err := ct.rowSrc.Values() if err != nil { - return err + return false, nil, err + } + if len(values) != len(ct.columnNames) { + return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) } - switch msg := msg.(type) { - case *pgproto3.CopyInResponse: - return nil - default: - err = c.processContextFreeMsg(msg) + buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) + for i, val := range values { + buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val) if err != nil { - return err + return false, nil, err } } - } -} -func (ct *copyFrom) cancelCopyIn() error { - buf := ct.conn.wbuf - buf = append(buf, copyFail) - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, "client error: abort"...) - buf = append(buf, 0) - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + rowLen := len(buf) - lastBufLen + if rowLen > largestRowLen { + largestRowLen = rowLen + } - _, err := ct.conn.conn.Write(buf) - if err != nil { - ct.conn.die(err) - return err + // Try not to overflow size of the buffer PgConn.CopyFrom will be reading into. If that happens then the nature of + // io.Pipe means that the next Read will be short. This can lead to pathological send sizes such as 65531, 13, 65531 + // 13, 65531, 13, 65531, 13. + if len(buf) > sendBufSize-largestRowLen { + return true, buf, nil + } } - return nil + return false, buf, nil } -// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. -// It returns the number of rows copied and an error. +// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. It returns the number of rows copied and +// an error. +// +// CopyFrom requires all values use the binary format. A pgtype.Type that supports the binary format must be registered +// for the type of each column. Almost all types implemented by pgx support the binary format. // -// CopyFrom requires all values use the binary format. Almost all types -// implemented by pgx use the binary format by default. Types implementing -// Encoder can only be used if they encode to the binary format. -func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { +// Even though enum types appear to be strings they still must be registered to use with CopyFrom. This can be done with +// Conn.LoadType and pgtype.Map.RegisterType. +func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { ct := ©From{ conn: c, tableName: tableName, columnNames: columnNames, rowSrc: rowSrc, readerErrChan: make(chan error), + mode: c.config.DefaultQueryExecMode, } - return ct.run() + return ct.run(ctx) } diff --git a/copy_from_test.go b/copy_from_test.go index ec6748559..2c8986a5c 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -1,18 +1,150 @@ package pgx_test import ( + "context" + "fmt" + "os" "reflect" "testing" "time" - "github.com/jackc/pgx" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" ) +func TestConnCopyWithAllQueryExecModes(t *testing.T) { + for _, mode := range pgxtest.AllQueryExecModes { + t.Run(mode.String(), func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + cfg.DefaultQueryExecMode = mode + conn := mustConnect(t, cfg) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d text, + e timestamptz + )`) + + tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) + + inputRows := [][]any{ + {int16(0), int32(1), int64(2), "abc", tzedTime}, + {nil, nil, nil, nil, nil}, + } + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e"}, pgx.CopyFromRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + if int(copyCount) != len(inputRows) { + t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query(ctx, "select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]any + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) + }) + } +} + +func TestConnCopyWithKnownOIDQueryExecModes(t *testing.T) { + for _, mode := range pgxtest.KnownOIDQueryExecModes { + t.Run(mode.String(), func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + cfg.DefaultQueryExecMode = mode + conn := mustConnect(t, cfg) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz + )`) + + tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) + + inputRows := [][]any{ + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, + {nil, nil, nil, nil, nil, nil, nil}, + } + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + if int(copyCount) != len(inputRows) { + t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query(ctx, "select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]any + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) + }) + } +} + func TestConnCopyFromSmall(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) mustExec(t, conn, `create temporary table foo( @@ -25,25 +157,89 @@ func TestConnCopyFromSmall(t *testing.T) { g timestamptz )`) - inputRows := [][]interface{}{ - {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, + tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) + + inputRows := [][]any{ + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, + {nil, nil, nil, nil, nil, nil, nil}, + } + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + if int(copyCount) != len(inputRows) { + t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query(ctx, "select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]any + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyFromSliceSmall(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz + )`) + + tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) + + inputRows := [][]any{ + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, {nil, nil, nil, nil, nil, nil, nil}, } - copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, + pgx.CopyFromSlice(len(inputRows), func(i int) ([]any, error) { + return inputRows[i], nil + })) if err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } - if copyCount != len(inputRows) { + if int(copyCount) != len(inputRows) { t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) } - rows, err := conn.Query("select * from foo") + rows, err := conn.Query(ctx, "select * from foo") if err != nil { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -66,7 +262,10 @@ func TestConnCopyFromSmall(t *testing.T) { func TestConnCopyFromLarge(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) mustExec(t, conn, `create temporary table foo( @@ -80,26 +279,28 @@ func TestConnCopyFromLarge(t *testing.T) { h bytea )`) - inputRows := [][]interface{}{} + tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) + + inputRows := [][]any{} for i := 0; i < 10000; i++ { - inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) + inputRows = append(inputRows, []any{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}}) } - copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows)) + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows)) if err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } - if copyCount != len(inputRows) { + if int(copyCount) != len(inputRows) { t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) } - rows, err := conn.Query("select * from foo") + rows, err := conn.Query(ctx, "select * from foo") if err != nil { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -119,14 +320,91 @@ func TestConnCopyFromLarge(t *testing.T) { ensureConnValid(t, conn) } +func TestConnCopyFromEnum(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + tx, err := conn.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + _, err = tx.Exec(ctx, `drop type if exists color`) + require.NoError(t, err) + + _, err = tx.Exec(ctx, `drop type if exists fruit`) + require.NoError(t, err) + + _, err = tx.Exec(ctx, `create type color as enum ('blue', 'green', 'orange')`) + require.NoError(t, err) + + _, err = tx.Exec(ctx, `create type fruit as enum ('apple', 'orange', 'grape')`) + require.NoError(t, err) + + // Obviously using conn while a tx is in use and registering a type after the connection has been established are + // really bad practices, but for the sake of convenience we do it in the test here. + for _, name := range []string{"fruit", "color"} { + typ, err := conn.LoadType(ctx, name) + require.NoError(t, err) + conn.TypeMap().RegisterType(typ) + } + + _, err = tx.Exec(ctx, `create temporary table foo( + a text, + b color, + c fruit, + d color, + e fruit, + f text + )`) + require.NoError(t, err) + + inputRows := [][]any{ + {"abc", "blue", "grape", "orange", "orange", "def"}, + {nil, nil, nil, nil, nil, nil}, + } + + copyCount, err := tx.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + + rows, err := tx.Query(ctx, "select * from foo") + require.NoError(t, err) + + var outputRows [][]any + for rows.Next() { + row, err := rows.Values() + require.NoError(t, err) + outputRows = append(outputRows, row) + } + + require.NoError(t, rows.Err()) + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + err = tx.Rollback(ctx) + require.NoError(t, err) + + ensureConnValid(t, conn) +} + func TestConnCopyFromJSON(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) for _, typeName := range []string{"json", "jsonb"} { - if _, ok := conn.ConnInfo.DataTypeForName(typeName); !ok { + if _, ok := conn.TypeMap().TypeForName(typeName); !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } } @@ -136,25 +414,25 @@ func TestConnCopyFromJSON(t *testing.T) { b jsonb )`) - inputRows := [][]interface{}{ - {map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}}, + inputRows := [][]any{ + {map[string]any{"foo": "bar"}, map[string]any{"bar": "quz"}}, {nil, nil}, } - copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) if err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } - if copyCount != len(inputRows) { + if int(copyCount) != len(inputRows) { t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) } - rows, err := conn.Query("select * from foo") + rows, err := conn.Query(ctx, "select * from foo") if err != nil { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -184,12 +462,12 @@ func (cfs *clientFailSource) Next() bool { return cfs.count < 100 } -func (cfs *clientFailSource) Values() ([]interface{}, error) { +func (cfs *clientFailSource) Values() ([]any, error) { if cfs.count == 3 { - cfs.err = errors.Errorf("client error") + cfs.err = fmt.Errorf("client error") return nil, cfs.err } - return []interface{}{make([]byte, 100000)}, nil + return []any{make([]byte, 100000)}, nil } func (cfs *clientFailSource) Err() error { @@ -199,7 +477,10 @@ func (cfs *clientFailSource) Err() error { func TestConnCopyFromFailServerSideMidway(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) mustExec(t, conn, `create temporary table foo( @@ -207,29 +488,29 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) { b varchar not null )`) - inputRows := [][]interface{}{ + inputRows := [][]any{ {int32(1), "abc"}, {int32(2), nil}, // this row should trigger a failure {int32(3), "def"}, } - copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) if err == nil { t.Errorf("Expected CopyFrom return error, but it did not") } - if _, ok := err.(pgx.PgError); !ok { + if _, ok := err.(*pgconn.PgError); !ok { t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) } if copyCount != 0 { t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) } - rows, err := conn.Query("select * from foo") + rows, err := conn.Query(ctx, "select * from foo") if err != nil { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -246,6 +527,8 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) { t.Errorf("Expected 0 rows, but got %v", outputRows) } + mustExec(t, conn, "truncate foo") + ensureConnValid(t, conn) } @@ -259,11 +542,11 @@ func (fs *failSource) Next() bool { return fs.count < 100 } -func (fs *failSource) Values() ([]interface{}, error) { +func (fs *failSource) Values() ([]any, error) { if fs.count == 3 { - return []interface{}{nil}, nil + return []any{nil}, nil } - return []interface{}{make([]byte, 100000)}, nil + return []any{make([]byte, 100000)}, nil } func (fs *failSource) Err() error { @@ -273,20 +556,25 @@ func (fs *failSource) Err() error { func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + pgxtest.SkipCockroachDB(t, conn, "Server copy error does not fail fast") + mustExec(t, conn, `create temporary table foo( a bytea not null )`) startTime := time.Now() - copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &failSource{}) + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, &failSource{}) if err == nil { t.Errorf("Expected CopyFrom return error, but it did not") } - if _, ok := err.(pgx.PgError); !ok { + if _, ok := err.(*pgconn.PgError); !ok { t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) } if copyCount != 0 { @@ -299,12 +587,12 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { t.Errorf("Failing CopyFrom shouldn't have taken so long: %v", copyTime) } - rows, err := conn.Query("select * from foo") + rows, err := conn.Query(ctx, "select * from foo") if err != nil { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -324,17 +612,69 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { ensureConnValid(t, conn) } +type slowFailRaceSource struct { + count int +} + +func (fs *slowFailRaceSource) Next() bool { + time.Sleep(time.Millisecond) + fs.count++ + return fs.count < 1000 +} + +func (fs *slowFailRaceSource) Values() ([]any, error) { + if fs.count == 500 { + return []any{nil, nil}, nil + } + return []any{1, make([]byte, 1000)}, nil +} + +func (fs *slowFailRaceSource) Err() error { + return nil +} + +func TestConnCopyFromSlowFailRace(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int not null, + b bytea not null + )`) + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, &slowFailRaceSource{}) + if err == nil { + t.Errorf("Expected CopyFrom return error, but it did not") + } + if _, ok := err.(*pgconn.PgError); !ok { + t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) + } + + ensureConnValid(t, conn) +} + func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) mustExec(t, conn, `create temporary table foo( a bytea not null )`) - copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{}) + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{}) if err == nil { t.Errorf("Expected CopyFrom return error, but it did not") } @@ -342,12 +682,12 @@ func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) { t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) } - rows, err := conn.Query("select * from foo") + rows, err := conn.Query(ctx, "select * from foo") if err != nil { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -361,7 +701,7 @@ func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) { } if len(outputRows) != 0 { - t.Errorf("Expected 0 rows, but got %v", outputRows) + t.Errorf("Expected 0 rows, but got %v", len(outputRows)) } ensureConnValid(t, conn) @@ -376,25 +716,28 @@ func (cfs *clientFinalErrSource) Next() bool { return cfs.count < 5 } -func (cfs *clientFinalErrSource) Values() ([]interface{}, error) { - return []interface{}{make([]byte, 100000)}, nil +func (cfs *clientFinalErrSource) Values() ([]any, error) { + return []any{make([]byte, 100000)}, nil } func (cfs *clientFinalErrSource) Err() error { - return errors.Errorf("final error") + return fmt.Errorf("final error") } func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) mustExec(t, conn, `create temporary table foo( a bytea not null )`) - copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{}) + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{}) if err == nil { t.Errorf("Expected CopyFrom return error, but it did not") } @@ -402,12 +745,12 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) } - rows, err := conn.Query("select * from foo") + rows, err := conn.Query(ctx, "select * from foo") if err != nil { t.Errorf("Unexpected error for Query: %v", err) } - var outputRows [][]interface{} + var outputRows [][]any for rows.Next() { row, err := rows.Values() if err != nil { @@ -426,3 +769,125 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { ensureConnValid(t, conn) } + +func TestConnCopyFromAutomaticStringConversion(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int8 + )`) + + inputRows := [][]interface{}{ + {"42"}, + {"7"}, + {8}, + } + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + + rows, _ := conn.Query(ctx, "select * from foo") + nums, err := pgx.CollectRows(rows, pgx.RowTo[int64]) + require.NoError(t, err) + + require.Equal(t, []int64{42, 7, 8}, nums) + + ensureConnValid(t, conn) +} + +// https://github.com/jackc/pgx/discussions/1891 +func TestConnCopyFromAutomaticStringConversionArray(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a numeric[] + )`) + + inputRows := [][]interface{}{ + {[]string{"42"}}, + {[]string{"7"}}, + {[]string{"8", "9"}}, + {[][]string{{"10", "11"}, {"12", "13"}}}, + } + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + + // Test reads as int64 and flattened array for simplicity. + rows, _ := conn.Query(ctx, "select * from foo") + nums, err := pgx.CollectRows(rows, pgx.RowTo[[]int64]) + require.NoError(t, err) + require.Equal(t, [][]int64{{42}, {7}, {8, 9}, {10, 11, 12, 13}}, nums) + + ensureConnValid(t, conn) +} + +func TestCopyFromFunc(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int + )`) + + dataCh := make(chan int, 1) + + const channelItems = 10 + go func() { + for i := 0; i < channelItems; i++ { + dataCh <- i + } + close(dataCh) + }() + + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, + pgx.CopyFromFunc(func() ([]any, error) { + v, ok := <-dataCh + if !ok { + return nil, nil + } + return []any{v}, nil + })) + + require.ErrorIs(t, err, nil) + require.EqualValues(t, channelItems, copyCount) + + rows, err := conn.Query(context.Background(), "select * from foo order by a") + require.NoError(t, err) + nums, err := pgx.CollectRows(rows, pgx.RowTo[int64]) + require.NoError(t, err) + require.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, nums) + + // simulate a failure + copyCount, err = conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, + pgx.CopyFromFunc(func() func() ([]any, error) { + x := 9 + return func() ([]any, error) { + x++ + if x > 100 { + return nil, fmt.Errorf("simulated error") + } + return []any{x}, nil + } + }())) + require.NotErrorIs(t, err, nil) + require.EqualValues(t, 0, copyCount) // no change, due to error + + ensureConnValid(t, conn) +} diff --git a/derived_types.go b/derived_types.go new file mode 100644 index 000000000..72c0a2423 --- /dev/null +++ b/derived_types.go @@ -0,0 +1,256 @@ +package pgx + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/pgtype" +) + +/* +buildLoadDerivedTypesSQL generates the correct query for retrieving type information. + + pgVersion: the major version of the PostgreSQL server + typeNames: the names of the types to load. If nil, load all types. +*/ +func buildLoadDerivedTypesSQL(pgVersion int64, typeNames []string) string { + supportsMultirange := (pgVersion >= 14) + var typeNamesClause string + + if typeNames == nil { + // This should not occur; this will not return any types + typeNamesClause = "= ''" + } else { + typeNamesClause = "= ANY($1)" + } + parts := make([]string, 0, 10) + + // Each of the type names provided might be found in pg_class or pg_type. + // Additionally, it may or may not include a schema portion. + parts = append(parts, ` +WITH RECURSIVE +-- find the OIDs in pg_class which match one of the provided type names +selected_classes(oid,reltype) AS ( + -- this query uses the namespace search path, so will match type names without a schema prefix + SELECT pg_class.oid, pg_class.reltype + FROM pg_catalog.pg_class + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace + WHERE pg_catalog.pg_table_is_visible(pg_class.oid) + AND relname `, typeNamesClause, ` +UNION ALL + -- this query will only match type names which include the schema prefix + SELECT pg_class.oid, pg_class.reltype + FROM pg_class + INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid) + WHERE nspname || '.' || relname `, typeNamesClause, ` +), +selected_types(oid) AS ( + -- collect the OIDs from pg_types which correspond to the selected classes + SELECT reltype AS oid + FROM selected_classes +UNION ALL + -- as well as any other type names which match our criteria + SELECT pg_type.oid + FROM pg_type + LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid) + WHERE typname `, typeNamesClause, ` + OR nspname || '.' || typname `, typeNamesClause, ` +), +-- this builds a parent/child mapping of objects, allowing us to know +-- all the child (ie: dependent) types that a parent (type) requires +-- As can be seen, there are 3 ways this can occur (the last of which +-- is due to being a composite class, where the composite fields are children) +pc(parent, child) AS ( + SELECT parent.oid, parent.typelem + FROM pg_type parent + WHERE parent.typtype = 'b' AND parent.typelem != 0 +UNION ALL + SELECT parent.oid, parent.typbasetype + FROM pg_type parent + WHERE parent.typtypmod = -1 AND parent.typbasetype != 0 +UNION ALL + SELECT pg_type.oid, atttypid + FROM pg_attribute + INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid) + INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype) + WHERE NOT attisdropped + AND attnum > 0 +), +-- Now construct a recursive query which includes a 'depth' element. +-- This is used to ensure that the "youngest" children are registered before +-- their parents. +relationships(parent, child, depth) AS ( + SELECT DISTINCT 0::OID, selected_types.oid, 0 + FROM selected_types +UNION ALL + SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1 + FROM selected_classes c + inner join pg_type ON (c.reltype = pg_type.oid) + inner join pg_attribute on (c.oid = pg_attribute.attrelid) +UNION ALL + SELECT pc.parent, pc.child, relationships.depth + 1 + FROM pc + INNER JOIN relationships ON (pc.parent = relationships.child) +), +-- composite fields need to be encapsulated as a couple of arrays to provide the required information for registration +composite AS ( + SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids + FROM pg_attribute + INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid) + INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype) + WHERE NOT attisdropped + AND attnum > 0 + GROUP BY pg_type.oid +) +-- Bring together this information, showing all the information which might possibly be required +-- to complete the registration, applying filters to only show the items which relate to the selected +-- types/classes. +SELECT typname, + pg_namespace.nspname, + typtype, + typbasetype, + typelem, + pg_type.oid,`) + if supportsMultirange { + parts = append(parts, ` + COALESCE(multirange.rngtypid, 0) AS rngtypid,`) + } else { + parts = append(parts, ` + 0 AS rngtypid,`) + } + parts = append(parts, ` + COALESCE(pg_range.rngsubtype, 0) AS rngsubtype, + attnames, atttypids + FROM relationships + INNER JOIN pg_type ON (pg_type.oid = relationships.child) + LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`) + if supportsMultirange { + parts = append(parts, ` + LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`) + } + + parts = append(parts, ` + LEFT OUTER JOIN composite USING (oid) + LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid) + WHERE NOT (typtype = 'b' AND typelem = 0)`) + parts = append(parts, ` + GROUP BY typname, pg_namespace.nspname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`) + if supportsMultirange { + parts = append(parts, ` + multirange.rngtypid,`) + } + parts = append(parts, ` + attnames, atttypids + ORDER BY MAX(depth) desc, typname;`) + return strings.Join(parts, "") +} + +type derivedTypeInfo struct { + Oid, Typbasetype, Typelem, Rngsubtype, Rngtypid uint32 + TypeName, Typtype, NspName string + Attnames []string + Atttypids []uint32 +} + +// LoadTypes performs a single (complex) query, returning all the required +// information to register the named types, as well as any other types directly +// or indirectly required to complete the registration. +// The result of this call can be passed into RegisterTypes to complete the process. +func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Type, error) { + m := c.TypeMap() + if len(typeNames) == 0 { + return nil, fmt.Errorf("No type names were supplied.") + } + + // Disregard server version errors. This will result in + // the SQL not support recent structures such as multirange + serverVersion, _ := serverVersion(c) + sql := buildLoadDerivedTypesSQL(serverVersion, typeNames) + rows, err := c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames) + if err != nil { + return nil, fmt.Errorf("While generating load types query: %w", err) + } + defer rows.Close() + result := make([]*pgtype.Type, 0, 100) + for rows.Next() { + ti := derivedTypeInfo{} + err = rows.Scan(&ti.TypeName, &ti.NspName, &ti.Typtype, &ti.Typbasetype, &ti.Typelem, &ti.Oid, &ti.Rngtypid, &ti.Rngsubtype, &ti.Attnames, &ti.Atttypids) + if err != nil { + return nil, fmt.Errorf("While scanning type information: %w", err) + } + var type_ *pgtype.Type + switch ti.Typtype { + case "b": // array + dt, ok := m.TypeForOID(ti.Typelem) + if !ok { + return nil, fmt.Errorf("Array element OID %v not registered while loading pgtype %q", ti.Typelem, ti.TypeName) + } + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.ArrayCodec{ElementType: dt}} + case "c": // composite + var fields []pgtype.CompositeCodecField + for i, fieldName := range ti.Attnames { + dt, ok := m.TypeForOID(ti.Atttypids[i]) + if !ok { + return nil, fmt.Errorf("Unknown field for composite type %q: field %q (OID %v) is not already registered.", ti.TypeName, fieldName, ti.Atttypids[i]) + } + fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt}) + } + + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.CompositeCodec{Fields: fields}} + case "d": // domain + dt, ok := m.TypeForOID(ti.Typbasetype) + if !ok { + return nil, fmt.Errorf("Domain base type OID %v was not already registered, needed for %q", ti.Typbasetype, ti.TypeName) + } + + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: dt.Codec} + case "e": // enum + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.EnumCodec{}} + case "r": // range + dt, ok := m.TypeForOID(ti.Rngsubtype) + if !ok { + return nil, fmt.Errorf("Range element OID %v was not already registered, needed for %q", ti.Rngsubtype, ti.TypeName) + } + + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.RangeCodec{ElementType: dt}} + case "m": // multirange + dt, ok := m.TypeForOID(ti.Rngtypid) + if !ok { + return nil, fmt.Errorf("Multirange element OID %v was not already registered, needed for %q", ti.Rngtypid, ti.TypeName) + } + + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}} + default: + return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName) + } + + // the type_ is imposible to be null + m.RegisterType(type_) + if ti.NspName != "" { + nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec} + m.RegisterType(nspType) + result = append(result, nspType) + } + result = append(result, type_) + } + return result, nil +} + +// serverVersion returns the postgresql server version. +func serverVersion(c *Conn) (int64, error) { + serverVersionStr := c.PgConn().ParameterStatus("server_version") + serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) + // if not PostgreSQL do nothing + if serverVersionStr == "" { + return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr) + } + + version, err := strconv.ParseInt(serverVersionStr, 10, 64) + if err != nil { + return 0, fmt.Errorf("postgres version parsing failed: %w", err) + } + return version, nil +} diff --git a/derived_types_test.go b/derived_types_test.go new file mode 100644 index 000000000..6fb6e1d36 --- /dev/null +++ b/derived_types_test.go @@ -0,0 +1,40 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" +) + +func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, ` +drop type if exists dtype_test; +drop domain if exists anotheruint64; + +create domain anotheruint64 as numeric(20,0); +create type dtype_test as ( + a text, + b int4, + c anotheruint64, + d anotheruint64[] +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type dtype_test") + defer conn.Exec(ctx, "drop domain anotheruint64") + + types, err := conn.LoadTypes(ctx, []string{"dtype_test"}) + require.NoError(t, err) + require.Len(t, types, 6) + require.Equal(t, types[0].Name, "public.anotheruint64") + require.Equal(t, types[1].Name, "anotheruint64") + require.Equal(t, types[2].Name, "public._anotheruint64") + require.Equal(t, types[3].Name, "_anotheruint64") + require.Equal(t, types[4].Name, "public.dtype_test") + require.Equal(t, types[5].Name, "dtype_test") + }) +} diff --git a/doc.go b/doc.go index 51f1edc5c..5d2ae3889 100644 --- a/doc.go +++ b/doc.go @@ -1,58 +1,66 @@ // Package pgx is a PostgreSQL database driver. /* -pgx provides lower level access to PostgreSQL than the standard database/sql. -It remains as similar to the database/sql interface as possible while -providing better speed and access to PostgreSQL specific features. Import -github.com/jackc/pgx/stdlib to use pgx as a database/sql compatible driver. +pgx provides a native PostgreSQL driver and can act as a database/sql driver. The native PostgreSQL interface is similar +to the database/sql interface while providing better speed and access to PostgreSQL specific features. Use +github.com/jackc/pgx/v5/stdlib to use pgx as a database/sql compatible driver. See that package's documentation for +details. + +Establishing a Connection + +The primary way of establishing a connection is with [pgx.Connect]: + + conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) + +The database connection string can be in URL or key/value format. Both PostgreSQL settings and pgx settings can be +specified here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the +connection with [ConnectConfig] to configure settings such as tracing that cannot be configured with a connection +string. + +Connection Pool + +[*pgx.Conn] represents a single connection to the database and is not concurrency safe. Use package +github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool. Query Interface -pgx implements Query and Scan in the familiar database/sql style. +pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and +ForEachRow that are a simpler and safer way of processing rows than manually calling defer rows.Close(), rows.Next(), +rows.Scan, and rows.Err(). - var sum int32 +CollectRows can be used collect all returned rows into a slice. - // Send the query to the server. The returned rows MUST be closed - // before conn can be used again. - rows, err := conn.Query("select generate_series(1,$1)", 10) + rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 5) + numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32]) if err != nil { - return err + return err } + // numbers => [1 2 3 4 5] - // rows.Close is called by rows.Next when all rows are read - // or an error occurs in Next or Scan. So it may optionally be - // omitted if nothing in the rows.Next loop can panic. It is - // safe to close rows multiple times. - defer rows.Close() - - // Iterate through the result set - for rows.Next() { - var n int32 - err = rows.Scan(&n) - if err != nil { - return err - } - sum += n - } +ForEachRow can be used to execute a callback function for every row. This is often easier than iterating over rows +directly. - // Any errors encountered by rows.Next or rows.Scan will be returned here - if rows.Err() != nil { - return err + var sum, n int32 + rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 10) + _, err := pgx.ForEachRow(rows, []any{&n}, func() error { + sum += n + return nil + }) + if err != nil { + return err } - // No errors found - do something with sum - pgx also implements QueryRow in the same style as database/sql. var name string var weight int64 - err := conn.QueryRow("select name, weight from widgets where id=$1", 42).Scan(&name, &weight) + err := conn.QueryRow(context.Background(), "select name, weight from widgets where id=$1", 42).Scan(&name, &weight) if err != nil { return err } Use Exec to execute a query that does not return a result set. - commandTag, err := conn.Exec("delete from widgets where id=$1", 42) + commandTag, err := conn.Exec(context.Background(), "delete from widgets where id=$1", 42) if err != nil { return err } @@ -60,180 +68,127 @@ Use Exec to execute a query that does not return a result set. return errors.New("No row found to delete") } -Connection Pool +PostgreSQL Data Types -Connection pool usage is explicit and configurable. In pgx, a connection can be -created and managed directly, or a connection pool with a configurable maximum -connections can be used. The connection pool offers an after connect hook that -allows every connection to be automatically setup before being made available in -the connection pool. +pgx uses the pgtype package to converting Go values to and from PostgreSQL values. It supports many PostgreSQL types +directly and is customizable and extendable. User defined data types such as enums, domains, and composite types may +require type registration. See that package's documentation for details. -It delegates methods such as QueryRow to an automatically checked out and -released connection so you can avoid manually acquiring and releasing -connections when you do not need that level of control. +Transactions - var name string - var weight int64 - err := pool.QueryRow("select name, weight from widgets where id=$1", 42).Scan(&name, &weight) +Transactions are started by calling Begin. + + tx, err := conn.Begin(context.Background()) if err != nil { return err } + // Rollback is safe to call even if the tx is already closed, so if + // the tx commits successfully, this is a no-op + defer tx.Rollback(context.Background()) -Base Type Mapping - -pgx maps between all common base types directly between Go and PostgreSQL. In -particular: - - Go PostgreSQL - ----------------------- - string varchar - text - - // Integers are automatically be converted to any other integer type if - // it can be done without overflow or underflow. - int8 - int16 smallint - int32 int - int64 bigint - int - uint8 - uint16 - uint32 - uint64 - uint - - // Floats are strict and do not automatically convert like integers. - float32 float4 - float64 float8 - - time.Time date - timestamp - timestamptz - - []byte bytea - - -Null Mapping - -pgx can map nulls in two ways. The first is package pgtype provides types that -have a data field and a status field. They work in a similar fashion to -database/sql. The second is to use a pointer to a pointer. - - var foo pgtype.Varchar - var bar *string - err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&a, &b) + _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") if err != nil { return err } -Array Mapping - -pgx maps between int16, int32, int64, float32, float64, and string Go slices -and the equivalent PostgreSQL array type. Go slices of native types do not -support nulls, so if a PostgreSQL array that contains a null is read into a -native Go slice an error will occur. The pgtype package includes many more -array types for PostgreSQL types that do not directly map to native Go types. - -JSON and JSONB Mapping - -pgx includes built-in support to marshal and unmarshal between Go types and -the PostgreSQL JSON and JSONB. - -Inet and CIDR Mapping - -pgx encodes from net.IPNet to and from inet and cidr PostgreSQL types. In -addition, as a convenience pgx will encode from a net.IP; it will assume a /32 -netmask for IPv4 and a /128 for IPv6. - -Custom Type Support - -pgx includes support for the common data types like integers, floats, strings, -dates, and times that have direct mappings between Go and SQL. In addition, -pgx uses the github.com/jackc/pgx/pgtype library to support more types. See -documention for that library for instructions on how to implement custom -types. - -See example_custom_type_test.go for an example of a custom type for the -PostgreSQL point type. - -pgx also includes support for custom types implementing the database/sql.Scanner -and database/sql/driver.Valuer interfaces. - -Raw Bytes Mapping + err = tx.Commit(context.Background()) + if err != nil { + return err + } -[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified -to PostgreSQL. +The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions. +These are internally implemented with savepoints. -Transactions +Use BeginTx to control the transaction mode. BeginTx also can be used to ensure a new transaction is created instead of +a pseudo nested transaction. -Transactions are started by calling Begin or BeginEx. The BeginEx variant -can create a transaction with a specified isolation level. +BeginFunc and BeginTxFunc are functions that begin a transaction, execute a function, and commit or rollback the +transaction depending on the return value of the function. These can be simpler and less error prone to use. - tx, err := conn.Begin() - if err != nil { + err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error { + _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") return err - } - // Rollback is safe to call even if the tx is already closed, so if - // the tx commits successfully, this is a no-op - defer tx.Rollback() - - _, err = tx.Exec("insert into foo(id) values (1)") + }) if err != nil { return err } - err = tx.Commit() - if err != nil { - return err - } +Prepared Statements + +Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx +includes an automatic statement cache by default. Queries run through the normal Query, QueryRow, and Exec functions are +automatically prepared on first execution and the prepared statement is reused on subsequent executions. See ParseConfig +for information on how to customize or disable the statement cache. Copy Protocol -Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL -copy protocol. CopyFrom accepts a CopyFromSource interface. If the data is already -in a [][]interface{} use CopyFromRows to wrap it in a CopyFromSource interface. Or -implement CopyFromSource to avoid buffering the entire data set in memory. +Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. CopyFrom accepts a +CopyFromSource interface. If the data is already in a [][]any use CopyFromRows to wrap it in a CopyFromSource interface. +Or implement CopyFromSource to avoid buffering the entire data set in memory. - rows := [][]interface{}{ + rows := [][]any{ {"John", "Smith", int32(36)}, {"Jane", "Doe", int32(29)}, } copyCount, err := conn.CopyFrom( + context.Background(), pgx.Identifier{"people"}, []string{"first_name", "last_name", "age"}, pgx.CopyFromRows(rows), ) +When you already have a typed array using CopyFromSlice can be more convenient. + + rows := []User{ + {"John", "Smith", 36}, + {"Jane", "Doe", 29}, + } + + copyCount, err := conn.CopyFrom( + context.Background(), + pgx.Identifier{"people"}, + []string{"first_name", "last_name", "age"}, + pgx.CopyFromSlice(len(rows), func(i int) ([]any, error) { + return []any{rows[i].FirstName, rows[i].LastName, rows[i].Age}, nil + }), + ) + CopyFrom can be faster than an insert with as few as 5 rows. Listen and Notify -pgx can listen to the PostgreSQL notification system with the -WaitForNotification function. It takes a maximum time to wait for a -notification. +pgx can listen to the PostgreSQL notification system with the `Conn.WaitForNotification` method. It blocks until a +notification is received or the context is canceled. - err := conn.Listen("channelname") + _, err := conn.Exec(context.Background(), "listen channelname") if err != nil { - return nil + return err } - if notification, err := conn.WaitForNotification(time.Second); err != nil { - // do something with notification + notification, err := conn.WaitForNotification(context.Background()) + if err != nil { + return err } + // do something with notification + + +Tracing and Logging + +pgx supports tracing by setting ConnConfig.Tracer. To combine several tracers you can use the multitracer.Tracer. + +In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer. + +For debug tracing of the actual PostgreSQL wire protocol messages see github.com/jackc/pgx/v5/pgproto3. -TLS +Lower Level PostgreSQL Functionality -The pgx ConnConfig struct has a TLSConfig field. If this field is -nil, then TLS will be disabled. If it is present, then it will be used to -configure the TLS connection. This allows total configuration of the TLS -connection. +github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn is +implemented on top of pgconn. The Conn.PgConn() method can be used to access this lower layer. -Logging +PgBouncer -pgx defines a simple logger interface. Connections optionally accept a logger -that satisfies this interface. Set LogLevel to control logging verbosity. -Adapters for github.com/inconshreveable/log15, github.com/sirupsen/logrus, and -the testing log are provided in the log directory. +By default pgx automatically uses prepared statements. Prepared statements are incompatible with PgBouncer. This can be +disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode. */ package pgx diff --git a/example_custom_type_test.go b/example_custom_type_test.go deleted file mode 100644 index d3cc90853..000000000 --- a/example_custom_type_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package pgx_test - -import ( - "fmt" - "regexp" - "strconv" - - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" - "github.com/pkg/errors" -) - -var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) - -// Point represents a point that may be null. -type Point struct { - X, Y float64 // Coordinates of point - Status pgtype.Status -} - -func (dst *Point) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Point", src) -} - -func (dst *Point) Get() interface{} { - switch dst.Status { - case pgtype.Present: - return dst - case pgtype.Null: - return nil - default: - return dst.Status - } -} - -func (src *Point) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - *dst = Point{Status: pgtype.Null} - return nil - } - - s := string(src) - match := pointRegexp.FindStringSubmatch(s) - if match == nil { - return errors.Errorf("Received invalid point: %v", s) - } - - x, err := strconv.ParseFloat(match[1], 64) - if err != nil { - return errors.Errorf("Received invalid point: %v", s) - } - y, err := strconv.ParseFloat(match[2], 64) - if err != nil { - return errors.Errorf("Received invalid point: %v", s) - } - - *dst = Point{X: x, Y: y, Status: pgtype.Present} - - return nil -} - -func (src *Point) String() string { - if src.Status == pgtype.Null { - return "null point" - } - - return fmt.Sprintf("%.1f, %.1f", src.X, src.Y) -} - -func Example_CustomType() { - conn, err := pgx.Connect(*defaultConnConfig) - if err != nil { - fmt.Printf("Unable to establish connection: %v", err) - return - } - - // Override registered handler for point - conn.ConnInfo.RegisterDataType(pgtype.DataType{ - Value: &Point{}, - Name: "point", - OID: 600, - }) - - p := &Point{} - err = conn.QueryRow("select null::point").Scan(p) - if err != nil { - fmt.Println(err) - return - } - fmt.Println(p) - - err = conn.QueryRow("select point(1.5,2.5)").Scan(p) - if err != nil { - fmt.Println(err) - return - } - fmt.Println(p) - // Output: - // null point - // 1.5, 2.5 -} diff --git a/examples/README.md b/examples/README.md index 6a97bc09e..410ebc32f 100644 --- a/examples/README.md +++ b/examples/README.md @@ -3,5 +3,5 @@ * chat is a command line chat program using listen/notify. * todo is a command line todo list that demonstrates basic CRUD actions. * url_shortener contains a simple example of using pgx in a web context. -* [Tern](https://github.com/jackc/tern) is a migration tool that uses pgx (uses v1 of pgx). +* [Tern](https://github.com/jackc/tern) is a migration tool that uses pgx. * [The Pithy Reader](https://github.com/jackc/tpr) is a RSS aggregator that uses pgx. diff --git a/examples/chat/README.md b/examples/chat/README.md index 4b73eb510..4e68df489 100644 --- a/examples/chat/README.md +++ b/examples/chat/README.md @@ -8,12 +8,7 @@ between them. ## Connection configuration -The database connection is configured via the standard PostgreSQL environment variables. - -* PGHOST - defaults to localhost -* PGUSER - defaults to current OS user -* PGPASSWORD - defaults to empty string -* PGDATABASE - defaults to user name +The database connection is configured via DATABASE_URL and standard PostgreSQL environment variables (PGHOST, PGUSER, etc.) You can either export them then run chat: diff --git a/examples/chat/main.go b/examples/chat/main.go index 83b16c029..5adbb3b62 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -6,19 +6,14 @@ import ( "fmt" "os" - "github.com/jackc/pgx" + "github.com/jackc/pgx/v5/pgxpool" ) -var pool *pgx.ConnPool +var pool *pgxpool.Pool func main() { - config, err := pgx.ParseEnvLibpq() - if err != nil { - fmt.Fprintln(os.Stderr, "Unable to parse environment:", err) - os.Exit(1) - } - - pool, err = pgx.NewConnPool(pgx.ConnPoolConfig{ConnConfig: config}) + var err error + pool, err = pgxpool.New(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { fmt.Fprintln(os.Stderr, "Unable to connect to database:", err) os.Exit(1) @@ -40,7 +35,7 @@ Type "exit" to quit.`) os.Exit(0) } - _, err = pool.Exec("select pg_notify('chat', $1)", msg) + _, err = pool.Exec(context.Background(), "select pg_notify('chat', $1)", msg) if err != nil { fmt.Fprintln(os.Stderr, "Error sending notification:", err) os.Exit(1) @@ -53,17 +48,21 @@ Type "exit" to quit.`) } func listen() { - conn, err := pool.Acquire() + conn, err := pool.Acquire(context.Background()) if err != nil { fmt.Fprintln(os.Stderr, "Error acquiring connection:", err) os.Exit(1) } - defer pool.Release(conn) + defer conn.Release() - conn.Listen("chat") + _, err = conn.Exec(context.Background(), "listen chat") + if err != nil { + fmt.Fprintln(os.Stderr, "Error listening to chat channel:", err) + os.Exit(1) + } for { - notification, err := conn.WaitForNotification(context.Background()) + notification, err := conn.Conn().WaitForNotification(context.Background()) if err != nil { fmt.Fprintln(os.Stderr, "Error waiting for notification:", err) os.Exit(1) diff --git a/examples/todo/README.md b/examples/todo/README.md index 32c32aa28..ecd2a3ca9 100644 --- a/examples/todo/README.md +++ b/examples/todo/README.md @@ -19,12 +19,7 @@ Build todo: ## Connection configuration -The database connection is configured via enviroment variables. - -* PGHOST - defaults to localhost -* PGUSER - defaults to current OS user -* PGPASSWORD - defaults to empty string -* PGDATABASE - defaults to user name +The database connection is configured via DATABASE_URL and standard PostgreSQL environment variables (PGHOST, PGUSER, etc.) You can either export them then run todo: @@ -45,7 +40,7 @@ Or you can prefix the todo execution with the environment variables: ## Update a task - ./todo add 1 'Learn more go' + ./todo update 1 'Learn more go' ## Delete a task diff --git a/examples/todo/main.go b/examples/todo/main.go index 70cbe14c2..6c644edec 100644 --- a/examples/todo/main.go +++ b/examples/todo/main.go @@ -1,22 +1,19 @@ package main import ( + "context" "fmt" - "github.com/jackc/pgx" "os" "strconv" + + "github.com/jackc/pgx/v5" ) var conn *pgx.Conn func main() { - config, err := pgx.ParseEnvLibpq() - if err != nil { - fmt.Fprintln(os.Stderr, "Unable to parse environment:", err) - os.Exit(1) - } - - conn, err = pgx.Connect(config) + var err error + conn, err = pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { fmt.Fprintf(os.Stderr, "Unable to connection to database: %v\n", err) os.Exit(1) @@ -74,7 +71,7 @@ func main() { } func listTasks() error { - rows, _ := conn.Query("select * from tasks") + rows, _ := conn.Query(context.Background(), "select * from tasks") for rows.Next() { var id int32 @@ -90,17 +87,17 @@ func listTasks() error { } func addTask(description string) error { - _, err := conn.Exec("insert into tasks(description) values($1)", description) + _, err := conn.Exec(context.Background(), "insert into tasks(description) values($1)", description) return err } func updateTask(itemNum int32, description string) error { - _, err := conn.Exec("update tasks set description=$1 where id=$2", description, itemNum) + _, err := conn.Exec(context.Background(), "update tasks set description=$1 where id=$2", description, itemNum) return err } func removeTask(itemNum int32) error { - _, err := conn.Exec("delete from tasks where id=$1", itemNum) + _, err := conn.Exec(context.Background(), "delete from tasks where id=$1", itemNum) return err } diff --git a/examples/url_shortener/README.md b/examples/url_shortener/README.md index cc04d6007..beb1802be 100644 --- a/examples/url_shortener/README.md +++ b/examples/url_shortener/README.md @@ -6,20 +6,28 @@ This is a sample REST URL shortener service implemented using pgx as the connect Create a PostgreSQL database and run structure.sql into it to create the necessary data schema. -Edit connectionOptions in main.go with the location and credentials for your database. +Configure the database connection with `DATABASE_URL` or standard PostgreSQL (`PG*`) environment variables or Run main.go: - go run main.go +``` +go run main.go +``` ## Create or Update a Shortened URL - curl -X PUT -d 'http://www.google.com/' http://localhost:8080/google +``` +curl -X PUT -d 'http://www.google.com/' http://localhost:8080/google +``` ## Get a Shortened URL - curl http://localhost:8080/google +``` +curl http://localhost:8080/google +``` ## Delete a Shortened URL - curl -X DELETE http://localhost:8080/google +``` +curl -X DELETE http://localhost:8080/google +``` diff --git a/examples/url_shortener/main.go b/examples/url_shortener/main.go index c6576a3a7..0887a0198 100644 --- a/examples/url_shortener/main.go +++ b/examples/url_shortener/main.go @@ -1,43 +1,21 @@ package main import ( - "io/ioutil" + "context" + "io" + "log" "net/http" "os" - "github.com/jackc/pgx" - "github.com/jackc/pgx/log/log15adapter" - log "gopkg.in/inconshreveable/log15.v2" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) -var pool *pgx.ConnPool - -// afterConnect creates the prepared statements that this application uses -func afterConnect(conn *pgx.Conn) (err error) { - _, err = conn.Prepare("getUrl", ` - select url from shortened_urls where id=$1 - `) - if err != nil { - return - } - - _, err = conn.Prepare("deleteUrl", ` - delete from shortened_urls where id=$1 - `) - if err != nil { - return - } - - _, err = conn.Prepare("putUrl", ` - insert into shortened_urls(id, url) values ($1, $2) - on conflict (id) do update set url=excluded.url - `) - return -} +var db *pgxpool.Pool func getUrlHandler(w http.ResponseWriter, req *http.Request) { var url string - err := pool.QueryRow("getUrl", req.URL.Path).Scan(&url) + err := db.QueryRow(context.Background(), "select url from shortened_urls where id=$1", req.URL.Path).Scan(&url) switch err { case nil: http.Redirect(w, req, url, http.StatusSeeOther) @@ -51,14 +29,15 @@ func getUrlHandler(w http.ResponseWriter, req *http.Request) { func putUrlHandler(w http.ResponseWriter, req *http.Request) { id := req.URL.Path var url string - if body, err := ioutil.ReadAll(req.Body); err == nil { + if body, err := io.ReadAll(req.Body); err == nil { url = string(body) } else { http.Error(w, "Internal server error", http.StatusInternalServerError) return } - if _, err := pool.Exec("putUrl", id, url); err == nil { + if _, err := db.Exec(context.Background(), `insert into shortened_urls(id, url) values ($1, $2) + on conflict (id) do update set url=excluded.url`, id, url); err == nil { w.WriteHeader(http.StatusOK) } else { http.Error(w, "Internal server error", http.StatusInternalServerError) @@ -66,7 +45,7 @@ func putUrlHandler(w http.ResponseWriter, req *http.Request) { } func deleteUrlHandler(w http.ResponseWriter, req *http.Request) { - if _, err := pool.Exec("deleteUrl", req.URL.Path); err == nil { + if _, err := db.Exec(context.Background(), "delete from shortened_urls where id=$1", req.URL.Path); err == nil { w.WriteHeader(http.StatusOK) } else { http.Error(w, "Internal server error", http.StatusInternalServerError) @@ -91,32 +70,21 @@ func urlHandler(w http.ResponseWriter, req *http.Request) { } func main() { - logger := log15adapter.NewLogger(log.New("module", "pgx")) - - var err error - connPoolConfig := pgx.ConnPoolConfig{ - ConnConfig: pgx.ConnConfig{ - Host: "127.0.0.1", - User: "jack", - Password: "jack", - Database: "url_shortener", - Logger: logger, - }, - MaxConnections: 5, - AfterConnect: afterConnect, + poolConfig, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL")) + if err != nil { + log.Fatalln("Unable to parse DATABASE_URL:", err) } - pool, err = pgx.NewConnPool(connPoolConfig) + + db, err = pgxpool.NewWithConfig(context.Background(), poolConfig) if err != nil { - log.Crit("Unable to create connection pool", "error", err) - os.Exit(1) + log.Fatalln("Unable to create connection pool:", err) } http.HandleFunc("/", urlHandler) - log.Info("Starting URL shortener on localhost:8080") + log.Println("Starting URL shortener on localhost:8080") err = http.ListenAndServe("localhost:8080", nil) if err != nil { - log.Crit("Unable to start web server", "error", err) - os.Exit(1) + log.Fatalln("Unable to start web server:", err) } } diff --git a/extended_query_builder.go b/extended_query_builder.go new file mode 100644 index 000000000..526b0e953 --- /dev/null +++ b/extended_query_builder.go @@ -0,0 +1,146 @@ +package pgx + +import ( + "fmt" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" +) + +// ExtendedQueryBuilder is used to choose the parameter formats, to format the parameters and to choose the result +// formats for an extended query. +type ExtendedQueryBuilder struct { + ParamValues [][]byte + paramValueBytes []byte + ParamFormats []int16 + ResultFormats []int16 +} + +// Build sets ParamValues, ParamFormats, and ResultFormats for use with *PgConn.ExecParams or *PgConn.ExecPrepared. If +// sd is nil then QueryExecModeExec behavior will be used. +func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error { + eqb.reset() + + if sd == nil { + for i := range args { + err := eqb.appendParam(m, 0, pgtype.TextFormatCode, args[i]) + if err != nil { + err = fmt.Errorf("failed to encode args[%d]: %w", i, err) + return err + } + } + return nil + } + + if len(sd.ParamOIDs) != len(args) { + return fmt.Errorf("mismatched param and argument count") + } + + for i := range args { + err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i]) + if err != nil { + err = fmt.Errorf("failed to encode args[%d]: %w", i, err) + return err + } + } + + for i := range sd.Fields { + eqb.appendResultFormat(m.FormatCodeForOID(sd.Fields[i].DataTypeOID)) + } + + return nil +} + +// appendParam appends a parameter to the query. format may be -1 to automatically choose the format. If arg is nil it +// must be an untyped nil. +func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error { + if format == -1 { + preferredFormat := eqb.chooseParameterFormatCode(m, oid, arg) + preferredErr := eqb.appendParam(m, oid, preferredFormat, arg) + if preferredErr == nil { + return nil + } + + var otherFormat int16 + if preferredFormat == TextFormatCode { + otherFormat = BinaryFormatCode + } else { + otherFormat = TextFormatCode + } + + otherErr := eqb.appendParam(m, oid, otherFormat, arg) + if otherErr == nil { + return nil + } + + return preferredErr // return the error from the preferred format + } + + v, err := eqb.encodeExtendedParamValue(m, oid, format, arg) + if err != nil { + return err + } + + eqb.ParamFormats = append(eqb.ParamFormats, format) + eqb.ParamValues = append(eqb.ParamValues, v) + + return nil +} + +// appendResultFormat appends a result format to the query. +func (eqb *ExtendedQueryBuilder) appendResultFormat(format int16) { + eqb.ResultFormats = append(eqb.ResultFormats, format) +} + +// reset readies eqb to build another query. +func (eqb *ExtendedQueryBuilder) reset() { + eqb.ParamValues = eqb.ParamValues[0:0] + eqb.paramValueBytes = eqb.paramValueBytes[0:0] + eqb.ParamFormats = eqb.ParamFormats[0:0] + eqb.ResultFormats = eqb.ResultFormats[0:0] + + if cap(eqb.ParamValues) > 64 { + eqb.ParamValues = make([][]byte, 0, 64) + } + + if cap(eqb.paramValueBytes) > 256 { + eqb.paramValueBytes = make([]byte, 0, 256) + } + + if cap(eqb.ParamFormats) > 64 { + eqb.ParamFormats = make([]int16, 0, 64) + } + if cap(eqb.ResultFormats) > 64 { + eqb.ResultFormats = make([]int16, 0, 64) + } +} + +func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) { + if eqb.paramValueBytes == nil { + eqb.paramValueBytes = make([]byte, 0, 128) + } + + pos := len(eqb.paramValueBytes) + + buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + eqb.paramValueBytes = buf + return eqb.paramValueBytes[pos:], nil +} + +// chooseParameterFormatCode determines the correct format code for an +// argument to a prepared statement. It defaults to TextFormatCode if no +// determination can be made. +func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg any) int16 { + switch arg.(type) { + case string, *string: + return TextFormatCode + } + + return m.FormatCodeForOID(oid) +} diff --git a/fastpath.go b/fastpath.go deleted file mode 100644 index 06e1354ab..000000000 --- a/fastpath.go +++ /dev/null @@ -1,117 +0,0 @@ -package pgx - -import ( - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/jackc/pgx/pgproto3" - "github.com/jackc/pgx/pgtype" -) - -func newFastpath(cn *Conn) *fastpath { - return &fastpath{cn: cn, fns: make(map[string]pgtype.OID)} -} - -type fastpath struct { - cn *Conn - fns map[string]pgtype.OID -} - -func (f *fastpath) functionOID(name string) pgtype.OID { - return f.fns[name] -} - -func (f *fastpath) addFunction(name string, oid pgtype.OID) { - f.fns[name] = oid -} - -func (f *fastpath) addFunctions(rows *Rows) error { - for rows.Next() { - var name string - var oid pgtype.OID - if err := rows.Scan(&name, &oid); err != nil { - return err - } - f.addFunction(name, oid) - } - return rows.Err() -} - -type fpArg []byte - -func fpIntArg(n int32) fpArg { - res := make([]byte, 4) - binary.BigEndian.PutUint32(res, uint32(n)) - return res -} - -func fpInt64Arg(n int64) fpArg { - res := make([]byte, 8) - binary.BigEndian.PutUint64(res, uint64(n)) - return res -} - -func (f *fastpath) Call(oid pgtype.OID, args []fpArg) (res []byte, err error) { - if err := f.cn.ensureConnectionReadyForQuery(); err != nil { - return nil, err - } - - buf := f.cn.wbuf - buf = append(buf, 'F') // function call - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf = pgio.AppendInt32(buf, int32(oid)) // function object id - buf = pgio.AppendInt16(buf, 1) // # of argument format codes - buf = pgio.AppendInt16(buf, 1) // format code: binary - buf = pgio.AppendInt16(buf, int16(len(args))) // # of arguments - for _, arg := range args { - buf = pgio.AppendInt32(buf, int32(len(arg))) // length of argument - buf = append(buf, arg...) // argument value - } - buf = pgio.AppendInt16(buf, 1) // response format code (binary) - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - if _, err := f.cn.conn.Write(buf); err != nil { - return nil, err - } - - for { - msg, err := f.cn.rxMsg() - if err != nil { - return nil, err - } - switch msg := msg.(type) { - case *pgproto3.FunctionCallResponse: - res = make([]byte, len(msg.Result)) - copy(res, msg.Result) - case *pgproto3.ReadyForQuery: - f.cn.rxReadyForQuery(msg) - // done - return res, err - default: - if err := f.cn.processContextFreeMsg(msg); err != nil { - return nil, err - } - } - } -} - -func (f *fastpath) CallFn(fn string, args []fpArg) ([]byte, error) { - return f.Call(f.functionOID(fn), args) -} - -func fpInt32(data []byte, err error) (int32, error) { - if err != nil { - return 0, err - } - n := int32(binary.BigEndian.Uint32(data)) - return n, nil -} - -func fpInt64(data []byte, err error) (int64, error) { - if err != nil { - return 0, err - } - return int64(binary.BigEndian.Uint64(data)), nil -} diff --git a/go.mod b/go.mod new file mode 100644 index 000000000..9ef749935 --- /dev/null +++ b/go.mod @@ -0,0 +1,20 @@ +module github.com/jackc/pgx/v5 + +go 1.24.0 + +require ( + github.com/jackc/pgpassfile v1.0.0 + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 + github.com/jackc/puddle/v2 v2.2.2 + github.com/stretchr/testify v1.11.1 + golang.org/x/sync v0.17.0 + golang.org/x/text v0.29.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/pretty v0.3.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 000000000..1d17aab79 --- /dev/null +++ b/go.sum @@ -0,0 +1,39 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= +golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go_stdlib.go b/go_stdlib.go deleted file mode 100644 index 9372f9efa..000000000 --- a/go_stdlib.go +++ /dev/null @@ -1,61 +0,0 @@ -package pgx - -import ( - "database/sql/driver" - "reflect" -) - -// This file contains code copied from the Go standard library due to the -// required function not being public. - -// Copyright (c) 2009 The Go Authors. All rights reserved. - -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: - -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above -// copyright notice, this list of conditions and the following disclaimer -// in the documentation and/or other materials provided with the -// distribution. -// * Neither the name of Google Inc. nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. - -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -// From database/sql/convert.go - -var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() - -// callValuerValue returns vr.Value(), with one exception: -// If vr.Value is an auto-generated method on a pointer type and the -// pointer is nil, it would panic at runtime in the panicwrap -// method. Treat it like nil instead. -// Issue 8415. -// -// This is so people can implement driver.Value on value types and -// still use nil pointers to those types to mean nil/NULL, just like -// string/*string. -// -// This function is mirrored in the database/sql/driver package. -func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { - if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && - rv.IsNil() && - rv.Type().Elem().Implements(valuerReflectType) { - return nil, nil - } - return vr.Value() -} diff --git a/helper_test.go b/helper_test.go index 78063107e..c5b650766 100644 --- a/helper_test.go +++ b/helper_test.go @@ -1,54 +1,71 @@ package pgx_test import ( + "context" + "os" "testing" - "github.com/jackc/pgx" + "github.com/stretchr/testify/assert" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" ) -func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn { - conn, err := pgx.Connect(config) - if err != nil { - t.Fatalf("Unable to establish connection: %v", err) +var defaultConnTestRunner pgxtest.ConnTestRunner + +func init() { + defaultConnTestRunner = pgxtest.DefaultConnTestRunner() + defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + return config } - return conn } -func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.ReplicationConn { - conn, err := pgx.ReplicationConnect(config) +func mustConnectString(t testing.TB, connString string) *pgx.Conn { + conn, err := pgx.Connect(context.Background(), connString) if err != nil { t.Fatalf("Unable to establish connection: %v", err) } return conn } -func closeConn(t testing.TB, conn *pgx.Conn) { - err := conn.Close() +func mustParseConfig(t testing.TB, connString string) *pgx.ConnConfig { + config, err := pgx.ParseConfig(connString) + require.Nil(t, err) + return config +} + +func mustConnect(t testing.TB, config *pgx.ConnConfig) *pgx.Conn { + conn, err := pgx.ConnectConfig(context.Background(), config) if err != nil { - t.Fatalf("conn.Close unexpectedly failed: %v", err) + t.Fatalf("Unable to establish connection: %v", err) } + return conn } -func closeReplicationConn(t testing.TB, conn *pgx.ReplicationConn) { - err := conn.Close() +func closeConn(t testing.TB, conn *pgx.Conn) { + err := conn.Close(context.Background()) if err != nil { t.Fatalf("conn.Close unexpectedly failed: %v", err) } } -func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgx.CommandTag) { +func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...any) (commandTag pgconn.CommandTag) { var err error - if commandTag, err = conn.Exec(sql, arguments...); err != nil { + if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil { t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err) } return } // Do a simple query to ensure the connection is still usable -func ensureConnValid(t *testing.T, conn *pgx.Conn) { +func ensureConnValid(t testing.TB, conn *pgx.Conn) { var sum, rowCount int32 - rows, err := conn.Query("select generate_series(1,$1)", 10) + rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } @@ -62,7 +79,7 @@ func ensureConnValid(t *testing.T, conn *pgx.Conn) { } if rows.Err() != nil { - t.Fatalf("conn.Query failed: %v", err) + t.Fatalf("conn.Query failed: %v", rows.Err()) } if rowCount != 10 { @@ -72,3 +89,50 @@ func ensureConnValid(t *testing.T, conn *pgx.Conn) { t.Error("Wrong values returned") } } + +func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) { + if !assert.NotNil(t, expected) { + return + } + if !assert.NotNil(t, actual) { + return + } + + assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName) + assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) + assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName) + assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName) + assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName) + assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) + assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) + assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) + assert.Equalf(t, expected.User, actual.User, "%s - User", testName) + assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) + assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName) + assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) + + // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) + assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) + + if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { + if expected.TLSConfig != nil { + assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) + } + } + + if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { + for i := range expected.Fallbacks { + assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) + assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) + + if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { + if expected.Fallbacks[i].TLSConfig != nil { + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) + } + } + } + } +} diff --git a/internal/iobufpool/iobufpool.go b/internal/iobufpool/iobufpool.go new file mode 100644 index 000000000..89e0c2273 --- /dev/null +++ b/internal/iobufpool/iobufpool.go @@ -0,0 +1,70 @@ +// Package iobufpool implements a global segregated-fit pool of buffers for IO. +// +// It uses *[]byte instead of []byte to avoid the sync.Pool allocation with Put. Unfortunately, using a pointer to avoid +// an allocation is purposely not documented. https://github.com/golang/go/issues/16323 +package iobufpool + +import "sync" + +const minPoolExpOf2 = 8 + +var pools [18]*sync.Pool + +func init() { + for i := range pools { + bufLen := 1 << (minPoolExpOf2 + i) + pools[i] = &sync.Pool{ + New: func() any { + buf := make([]byte, bufLen) + return &buf + }, + } + } +} + +// Get gets a []byte of len size with cap <= size*2. +func Get(size int) *[]byte { + i := getPoolIdx(size) + if i >= len(pools) { + buf := make([]byte, size) + return &buf + } + + ptrBuf := (pools[i].Get().(*[]byte)) + *ptrBuf = (*ptrBuf)[:size] + + return ptrBuf +} + +func getPoolIdx(size int) int { + size-- + size >>= minPoolExpOf2 + i := 0 + for size > 0 { + size >>= 1 + i++ + } + + return i +} + +// Put returns buf to the pool. +func Put(buf *[]byte) { + i := putPoolIdx(cap(*buf)) + if i < 0 { + return + } + + pools[i].Put(buf) +} + +func putPoolIdx(size int) int { + minPoolSize := 1 << minPoolExpOf2 + for i := range pools { + if size == minPoolSize< %#v, e.want => %#v", msg, e.want) + } + + return nil +} + +type expectStartupMessageStep struct { + want *pgproto3.StartupMessage + any bool +} + +func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error { + msg, err := backend.ReceiveStartupMessage() + if err != nil { + return err + } + + if e.any { + return nil + } + + if !reflect.DeepEqual(msg, e.want) { + return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want) + } + + return nil +} + +func ExpectMessage(want pgproto3.FrontendMessage) Step { + return expectMessage(want, false) +} + +func ExpectAnyMessage(want pgproto3.FrontendMessage) Step { + return expectMessage(want, true) +} + +func expectMessage(want pgproto3.FrontendMessage, any bool) Step { + if want, ok := want.(*pgproto3.StartupMessage); ok { + return &expectStartupMessageStep{want: want, any: any} + } + + return &expectMessageStep{want: want, any: any} +} + +type sendMessageStep struct { + msg pgproto3.BackendMessage +} + +func (e *sendMessageStep) Step(backend *pgproto3.Backend) error { + backend.Send(e.msg) + return backend.Flush() +} + +func SendMessage(msg pgproto3.BackendMessage) Step { + return &sendMessageStep{msg: msg} +} + +type waitForCloseMessageStep struct{} + +func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error { + for { + msg, err := backend.Receive() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if _, ok := msg.(*pgproto3.Terminate); ok { + return nil + } + } +} + +func WaitForClose() Step { + return &waitForCloseMessageStep{} +} + +func AcceptUnauthenticatedConnRequestSteps() []Step { + return []Step{ + ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + SendMessage(&pgproto3.AuthenticationOk{}), + SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + } +} diff --git a/internal/pgmock/pgmock_test.go b/internal/pgmock/pgmock_test.go new file mode 100644 index 000000000..7bc2fdeff --- /dev/null +++ b/internal/pgmock/pgmock_test.go @@ -0,0 +1,91 @@ +package pgmock_test + +import ( + "context" + "fmt" + "net" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5/internal/pgmock" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestScript(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: "select 42"})) + script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + { + Name: []byte("?column?"), + TableOID: 0, + TableAttributeNumber: 0, + DataTypeOID: 23, + DataTypeSize: 4, + TypeModifier: -1, + Format: 0, + }, + }, + })) + script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.DataRow{ + Values: [][]byte{[]byte("42")}, + })) + script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")})) + script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})) + script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Terminate{})) + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(time.Second)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(conn, conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + host, port, _ := strings.Cut(ln.Addr().String(), ":") + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + pgConn, err := pgconn.Connect(ctx, connStr) + require.NoError(t, err) + results, err := pgConn.Exec(ctx, "select 42").ReadAll() + assert.NoError(t, err) + + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "42", string(results[0].Rows[0][0])) + + pgConn.Close(ctx) + + assert.NoError(t, <-serverErrChan) +} diff --git a/internal/sanitize/benchmmark.sh b/internal/sanitize/benchmmark.sh new file mode 100644 index 000000000..ec0f7b03a --- /dev/null +++ b/internal/sanitize/benchmmark.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +current_branch=$(git rev-parse --abbrev-ref HEAD) +if [ "$current_branch" == "HEAD" ]; then + current_branch=$(git rev-parse HEAD) +fi + +restore_branch() { + echo "Restoring original branch/commit: $current_branch" + git checkout "$current_branch" +} +trap restore_branch EXIT + +# Check if there are uncommitted changes +if ! git diff --quiet || ! git diff --cached --quiet; then + echo "There are uncommitted changes. Please commit or stash them before running this script." + exit 1 +fi + +# Ensure that at least one commit argument is passed +if [ "$#" -lt 1 ]; then + echo "Usage: $0 ... " + exit 1 +fi + +commits=("$@") +benchmarks_dir=benchmarks + +if ! mkdir -p "${benchmarks_dir}"; then + echo "Unable to create dir for benchmarks data" + exit 1 +fi + +# Benchmark results +bench_files=() + +# Run benchmark for each listed commit +for i in "${!commits[@]}"; do + commit="${commits[i]}" + git checkout "$commit" || { + echo "Failed to checkout $commit" + exit 1 + } + + # Sanitized commmit message + commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_') + + # Benchmark data will go there + bench_file="${benchmarks_dir}/${i}_${commit_message}.bench" + + if ! go test -bench=. -count=10 >"$bench_file"; then + echo "Benchmarking failed for commit $commit" + exit 1 + fi + + bench_files+=("$bench_file") +done + +# go install golang.org/x/perf/cmd/benchstat[@latest] +benchstat "${bench_files[@]}" diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index 53543b891..b516817cb 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -3,97 +3,209 @@ package sanitize import ( "bytes" "encoding/hex" + "fmt" + "slices" "strconv" "strings" + "sync" "time" "unicode/utf8" - - "github.com/pkg/errors" ) // Part is either a string or an int. A string is raw SQL. An int is a // argument placeholder. -type Part interface{} +type Part any type Query struct { Parts []Part } -func (q *Query) Sanitize(args ...interface{}) (string, error) { +// utf.DecodeRune returns the utf8.RuneError for errors. But that is actually rune U+FFFD -- the unicode replacement +// character. utf8.RuneError is not an error if it is also width 3. +// +// https://github.com/jackc/pgx/issues/1380 +const replacementcharacterwidth = 3 + +const maxBufSize = 16384 // 16 Ki + +var bufPool = &pool[*bytes.Buffer]{ + new: func() *bytes.Buffer { + return &bytes.Buffer{} + }, + reset: func(b *bytes.Buffer) bool { + n := b.Len() + b.Reset() + return n < maxBufSize + }, +} + +var null = []byte("null") + +func (q *Query) Sanitize(args ...any) (string, error) { argUse := make([]bool, len(args)) - buf := &bytes.Buffer{} + buf := bufPool.get() + defer bufPool.put(buf) for _, part := range q.Parts { - var str string switch part := part.(type) { case string: - str = part + buf.WriteString(part) case int: argIdx := part - 1 + var p []byte + if argIdx < 0 { + return "", fmt.Errorf("first sql argument must be > 0") + } + if argIdx >= len(args) { - return "", errors.Errorf("insufficient arguments") + return "", fmt.Errorf("insufficient arguments") } + + // Prevent SQL injection via Line Comment Creation + // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p + buf.WriteByte(' ') + arg := args[argIdx] switch arg := arg.(type) { case nil: - str = "null" + p = null case int64: - str = strconv.FormatInt(arg, 10) + p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10) case float64: - str = strconv.FormatFloat(arg, 'f', -1, 64) + p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64) case bool: - str = strconv.FormatBool(arg) + p = strconv.AppendBool(buf.AvailableBuffer(), arg) case []byte: - str = QuoteBytes(arg) + p = QuoteBytes(buf.AvailableBuffer(), arg) case string: - str = QuoteString(arg) + p = QuoteString(buf.AvailableBuffer(), arg) case time.Time: - str = arg.Format("'2006-01-02 15:04:05.999999999Z07:00:00'") + p = arg.Truncate(time.Microsecond). + AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'") default: - return "", errors.Errorf("invalid arg type: %T", arg) + return "", fmt.Errorf("invalid arg type: %T", arg) } argUse[argIdx] = true + + buf.Write(p) + + // Prevent SQL injection via Line Comment Creation + // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p + buf.WriteByte(' ') default: - return "", errors.Errorf("invalid Part type: %T", part) + return "", fmt.Errorf("invalid Part type: %T", part) } - buf.WriteString(str) } for i, used := range argUse { if !used { - return "", errors.Errorf("unused argument: %d", i) + return "", fmt.Errorf("unused argument: %d", i) } } return buf.String(), nil } func NewQuery(sql string) (*Query, error) { - l := &sqlLexer{ - src: sql, - stateFn: rawState, + query := &Query{} + query.init(sql) + + return query, nil +} + +var sqlLexerPool = &pool[*sqlLexer]{ + new: func() *sqlLexer { + return &sqlLexer{} + }, + reset: func(sl *sqlLexer) bool { + *sl = sqlLexer{} + return true + }, +} + +func (q *Query) init(sql string) { + parts := q.Parts[:0] + if parts == nil { + // dirty, but fast heuristic to preallocate for ~90% usecases + n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1 + parts = make([]Part, 0, n) } + l := sqlLexerPool.get() + defer sqlLexerPool.put(l) + + l.src = sql + l.stateFn = rawState + l.parts = parts + for l.stateFn != nil { l.stateFn = l.stateFn(l) } - query := &Query{Parts: l.parts} - - return query, nil + q.Parts = l.parts } -func QuoteString(str string) string { - return "'" + strings.Replace(str, "'", "''", -1) + "'" +func QuoteString(dst []byte, str string) []byte { + const quote = '\'' + + // Preallocate space for the worst case scenario + dst = slices.Grow(dst, len(str)*2+2) + + // Add opening quote + dst = append(dst, quote) + + // Iterate through the string without allocating + for i := 0; i < len(str); i++ { + if str[i] == quote { + dst = append(dst, quote, quote) + } else { + dst = append(dst, str[i]) + } + } + + // Add closing quote + dst = append(dst, quote) + + return dst } -func QuoteBytes(buf []byte) string { - return `'\x` + hex.EncodeToString(buf) + "'" +func QuoteBytes(dst, buf []byte) []byte { + if len(buf) == 0 { + return append(dst, `'\x'`...) + } + + // Calculate required length + requiredLen := 3 + hex.EncodedLen(len(buf)) + 1 + + // Ensure dst has enough capacity + if cap(dst)-len(dst) < requiredLen { + newDst := make([]byte, len(dst), len(dst)+requiredLen) + copy(newDst, dst) + dst = newDst + } + + // Record original length and extend slice + origLen := len(dst) + dst = dst[:origLen+requiredLen] + + // Add prefix + dst[origLen] = '\'' + dst[origLen+1] = '\\' + dst[origLen+2] = 'x' + + // Encode bytes directly into dst + hex.Encode(dst[origLen+3:len(dst)-1], buf) + + // Add suffix + dst[len(dst)-1] = '\'' + + return dst } type sqlLexer struct { src string start int pos int + nested int // multiline comment nesting level. stateFn stateFn parts []Part } @@ -125,12 +237,26 @@ func rawState(l *sqlLexer) stateFn { l.start = l.pos return placeholderState } + case '-': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '-' { + l.pos += width + return oneLineCommentState + } + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + return multilineCommentState + } case utf8.RuneError: - if l.pos-l.start > 0 { - l.parts = append(l.parts, l.src[l.start:l.pos]) - l.start = l.pos + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil } - return nil } } } @@ -148,11 +274,13 @@ func singleQuoteState(l *sqlLexer) stateFn { } l.pos += width case utf8.RuneError: - if l.pos-l.start > 0 { - l.parts = append(l.parts, l.src[l.start:l.pos]) - l.start = l.pos + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil } - return nil } } } @@ -170,11 +298,13 @@ func doubleQuoteState(l *sqlLexer) stateFn { } l.pos += width case utf8.RuneError: - if l.pos-l.start > 0 { - l.parts = append(l.parts, l.src[l.start:l.pos]) - l.start = l.pos + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil } - return nil } } } @@ -216,22 +346,115 @@ func escapeStringState(l *sqlLexer) stateFn { } l.pos += width case utf8.RuneError: - if l.pos-l.start > 0 { - l.parts = append(l.parts, l.src[l.start:l.pos]) - l.start = l.pos + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +func oneLineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\n', '\r': + return rawState + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil } - return nil } } } +func multilineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + l.nested++ + } + case '*': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '/' { + continue + } + + l.pos += width + if l.nested == 0 { + return rawState + } + l.nested-- + + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +var queryPool = &pool[*Query]{ + new: func() *Query { + return &Query{} + }, + reset: func(q *Query) bool { + n := len(q.Parts) + q.Parts = q.Parts[:0] + return n < 64 // drop too large queries + }, +} + // SanitizeSQL replaces placeholder values with args. It quotes and escapes args // as necessary. This function is only safe when standard_conforming_strings is // on. -func SanitizeSQL(sql string, args ...interface{}) (string, error) { - query, err := NewQuery(sql) - if err != nil { - return "", err - } +func SanitizeSQL(sql string, args ...any) (string, error) { + query := queryPool.get() + query.init(sql) + defer queryPool.put(query) + return query.Sanitize(args...) } + +type pool[E any] struct { + p sync.Pool + new func() E + reset func(E) bool +} + +func (pool *pool[E]) get() E { + v, ok := pool.p.Get().(E) + if !ok { + v = pool.new() + } + + return v +} + +func (p *pool[E]) put(v E) { + if p.reset(v) { + p.p.Put(v) + } +} diff --git a/internal/sanitize/sanitize_bench_test.go b/internal/sanitize/sanitize_bench_test.go new file mode 100644 index 000000000..baa742b11 --- /dev/null +++ b/internal/sanitize/sanitize_bench_test.go @@ -0,0 +1,62 @@ +// sanitize_benchmark_test.go +package sanitize_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/v5/internal/sanitize" +) + +var benchmarkSanitizeResult string + +const benchmarkQuery = "" + + `SELECT * + FROM "water_containers" + WHERE NOT "id" = $1 -- int64 + AND "tags" NOT IN $2 -- nil + AND "volume" > $3 -- float64 + AND "transportable" = $4 -- bool + AND position($5 IN "sign") -- bytes + AND "label" LIKE $6 -- string + AND "created_at" > $7; -- time.Time` + +var benchmarkArgs = []any{ + int64(12345), + nil, + float64(500), + true, + []byte("8BADF00D"), + "kombucha's han'dy awokowa", + time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC), +} + +func BenchmarkSanitize(b *testing.B) { + query, err := sanitize.NewQuery(benchmarkQuery) + if err != nil { + b.Fatalf("failed to create query: %v", err) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...) + if err != nil { + b.Fatalf("failed to sanitize query: %v", err) + } + } +} + +var benchmarkNewSQLResult string + +func BenchmarkSanitizeSQL(b *testing.B) { + b.ReportAllocs() + var err error + for i := 0; i < b.N; i++ { + benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...) + if err != nil { + b.Fatalf("failed to sanitize SQL: %v", err) + } + } +} diff --git a/internal/sanitize/sanitize_fuzz_test.go b/internal/sanitize/sanitize_fuzz_test.go new file mode 100644 index 000000000..2f0c41223 --- /dev/null +++ b/internal/sanitize/sanitize_fuzz_test.go @@ -0,0 +1,55 @@ +package sanitize_test + +import ( + "strings" + "testing" + + "github.com/jackc/pgx/v5/internal/sanitize" +) + +func FuzzQuoteString(f *testing.F) { + const prefix = "prefix" + f.Add("new\nline") + f.Add("sample text") + f.Add("sample q'u'o't'e's") + f.Add("select 'quoted $42', $1") + + f.Fuzz(func(t *testing.T, input string) { + got := string(sanitize.QuoteString([]byte(prefix), input)) + want := oldQuoteString(input) + + quoted, ok := strings.CutPrefix(got, prefix) + if !ok { + t.Fatalf("result has no prefix") + } + + if want != quoted { + t.Errorf("got %q", got) + t.Fatalf("want %q", want) + } + }) +} + +func FuzzQuoteBytes(f *testing.F) { + const prefix = "prefix" + f.Add([]byte(nil)) + f.Add([]byte("\n")) + f.Add([]byte("sample text")) + f.Add([]byte("sample q'u'o't'e's")) + f.Add([]byte("select 'quoted $42', $1")) + + f.Fuzz(func(t *testing.T, input []byte) { + got := string(sanitize.QuoteBytes([]byte(prefix), input)) + want := oldQuoteBytes(input) + + quoted, ok := strings.CutPrefix(got, prefix) + if !ok { + t.Fatalf("result has no prefix") + } + + if want != quoted { + t.Errorf("got %q", got) + t.Fatalf("want %q", want) + } + }) +} diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go index 9597840ee..926751534 100644 --- a/internal/sanitize/sanitize_test.go +++ b/internal/sanitize/sanitize_test.go @@ -1,9 +1,12 @@ package sanitize_test import ( + "encoding/hex" + "strings" "testing" + "time" - "github.com/jackc/pgx/internal/sanitize" + "github.com/jackc/pgx/v5/internal/sanitize" ) func TestNewQuery(t *testing.T) { @@ -59,6 +62,44 @@ func TestNewQuery(t *testing.T) { sql: `select e'escape string\' $42', $1`, expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}}, }, + { + sql: `select /* a baby's toy */ 'barbie', $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select /* a baby's toy */ 'barbie', `, 1}}, + }, + { + sql: `select /* *_* */ $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select /* *_* */ `, 1}}, + }, + { + sql: `select 42 /* /* /* 42 */ */ */, $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select 42 /* /* /* 42 */ */ */, `, 1}}, + }, + { + sql: "select -- a baby's toy\n'barbie', $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select -- a baby's toy\n'barbie', ", 1}}, + }, + { + sql: "select 42 -- is a Deep Thought's favorite number", + expected: sanitize.Query{Parts: []sanitize.Part{"select 42 -- is a Deep Thought's favorite number"}}, + }, + { + sql: "select 42, -- \\nis a Deep Thought's favorite number\n$1", + expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\n", 1}}, + }, + { + sql: "select 42, -- \\nis a Deep Thought's favorite number\r$1", + expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\r", 1}}, + }, + { + // https://github.com/jackc/pgx/issues/1380 + sql: "select 'hello w�rld'", + expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello w�rld'"}}, + }, + { + // Unterminated quoted string + sql: "select 'hello world", + expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello world"}}, + }, } for i, tt := range successTests { @@ -82,53 +123,68 @@ func TestNewQuery(t *testing.T) { func TestQuerySanitize(t *testing.T) { successfulTests := []struct { query sanitize.Query - args []interface{} + args []any expected string }{ { query: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, - args: []interface{}{}, + args: []any{}, expected: `select 42`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{int64(42)}, - expected: `select 42`, + args: []any{int64(42)}, + expected: `select 42 `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{float64(1.23)}, - expected: `select 1.23`, + args: []any{float64(1.23)}, + expected: `select 1.23 `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{true}, - expected: `select true`, + args: []any{true}, + expected: `select true `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{[]byte{0, 1, 2, 3, 255}}, - expected: `select '\x00010203ff'`, + args: []any{[]byte{0, 1, 2, 3, 255}}, + expected: `select '\x00010203ff' `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{nil}, - expected: `select null`, + args: []any{nil}, + expected: `select null `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{"foobar"}, - expected: `select 'foobar'`, + args: []any{"foobar"}, + expected: `select 'foobar' `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{"foo'bar"}, - expected: `select 'foo''bar'`, + args: []any{"foo'bar"}, + expected: `select 'foo''bar' `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{`foo\'bar`}, - expected: `select 'foo\''bar'`, + args: []any{`foo\'bar`}, + expected: `select 'foo\''bar' `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}}, + args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)}, + expected: `insert '2020-03-01 23:59:59.999999Z' `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}}, + args: []any{int64(-1)}, + expected: `select 1- -1 `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}}, + args: []any{float64(-1)}, + expected: `select 1- -1 `, }, } @@ -146,22 +202,22 @@ func TestQuerySanitize(t *testing.T) { errorTests := []struct { query sanitize.Query - args []interface{} + args []any expected string }{ { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}}, - args: []interface{}{int64(42)}, + args: []any{int64(42)}, expected: `insufficient arguments`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}}, - args: []interface{}{int64(42)}, + args: []any{int64(42)}, expected: `unused argument: 0`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{42}, + args: []any{42}, expected: `invalid arg type: int`, }, } @@ -173,3 +229,55 @@ func TestQuerySanitize(t *testing.T) { } } } + +func TestQuoteString(t *testing.T) { + tc := func(name, input string) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := string(sanitize.QuoteString(nil, input)) + want := oldQuoteString(input) + + if got != want { + t.Errorf("got: %s", got) + t.Fatalf("want: %s", want) + } + }) + } + + tc("empty", "") + tc("text", "abcd") + tc("with quotes", `one's hat is always a cat`) +} + +// This function was used before optimizations. +// You should keep for testing purposes - we want to ensure there are no breaking changes. +func oldQuoteString(str string) string { + return "'" + strings.ReplaceAll(str, "'", "''") + "'" +} + +func TestQuoteBytes(t *testing.T) { + tc := func(name string, input []byte) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := string(sanitize.QuoteBytes(nil, input)) + want := oldQuoteBytes(input) + + if got != want { + t.Errorf("got: %s", got) + t.Fatalf("want: %s", want) + } + }) + } + + tc("nil", nil) + tc("empty", []byte{}) + tc("text", []byte("abcd")) +} + +// This function was used before optimizations. +// You should keep for testing purposes - we want to ensure there are no breaking changes. +func oldQuoteBytes(buf []byte) string { + return `'\x` + hex.EncodeToString(buf) + "'" +} diff --git a/internal/stmtcache/lru_cache.go b/internal/stmtcache/lru_cache.go new file mode 100644 index 000000000..17fec937b --- /dev/null +++ b/internal/stmtcache/lru_cache.go @@ -0,0 +1,111 @@ +package stmtcache + +import ( + "container/list" + + "github.com/jackc/pgx/v5/pgconn" +) + +// LRUCache implements Cache with a Least Recently Used (LRU) cache. +type LRUCache struct { + cap int + m map[string]*list.Element + l *list.List + invalidStmts []*pgconn.StatementDescription +} + +// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache. +func NewLRUCache(cap int) *LRUCache { + return &LRUCache{ + cap: cap, + m: make(map[string]*list.Element), + l: list.New(), + } +} + +// Get returns the statement description for sql. Returns nil if not found. +func (c *LRUCache) Get(key string) *pgconn.StatementDescription { + if el, ok := c.m[key]; ok { + c.l.MoveToFront(el) + return el.Value.(*pgconn.StatementDescription) + } + + return nil +} + +// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or +// sd.SQL has been invalidated and HandleInvalidated has not been called yet. +func (c *LRUCache) Put(sd *pgconn.StatementDescription) { + if sd.SQL == "" { + panic("cannot store statement description with empty SQL") + } + + if _, present := c.m[sd.SQL]; present { + return + } + + // The statement may have been invalidated but not yet handled. Do not readd it to the cache. + for _, invalidSD := range c.invalidStmts { + if invalidSD.SQL == sd.SQL { + return + } + } + + if c.l.Len() == c.cap { + c.invalidateOldest() + } + + el := c.l.PushFront(sd) + c.m[sd.SQL] = el +} + +// Invalidate invalidates statement description identified by sql. Does nothing if not found. +func (c *LRUCache) Invalidate(sql string) { + if el, ok := c.m[sql]; ok { + delete(c.m, sql) + c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) + c.l.Remove(el) + } +} + +// InvalidateAll invalidates all statement descriptions. +func (c *LRUCache) InvalidateAll() { + el := c.l.Front() + for el != nil { + c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) + el = el.Next() + } + + c.m = make(map[string]*list.Element) + c.l = list.New() +} + +// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. +func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription { + return c.invalidStmts +} + +// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a +// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were +// never seen by the call to GetInvalidated. +func (c *LRUCache) RemoveInvalidated() { + c.invalidStmts = nil +} + +// Len returns the number of cached prepared statement descriptions. +func (c *LRUCache) Len() int { + return c.l.Len() +} + +// Cap returns the maximum number of cached prepared statement descriptions. +func (c *LRUCache) Cap() int { + return c.cap +} + +func (c *LRUCache) invalidateOldest() { + oldest := c.l.Back() + sd := oldest.Value.(*pgconn.StatementDescription) + c.invalidStmts = append(c.invalidStmts, sd) + delete(c.m, sd.SQL) + c.l.Remove(oldest) +} diff --git a/internal/stmtcache/stmtcache.go b/internal/stmtcache/stmtcache.go new file mode 100644 index 000000000..d57bdd29e --- /dev/null +++ b/internal/stmtcache/stmtcache.go @@ -0,0 +1,45 @@ +// Package stmtcache is a cache for statement descriptions. +package stmtcache + +import ( + "crypto/sha256" + "encoding/hex" + + "github.com/jackc/pgx/v5/pgconn" +) + +// StatementName returns a statement name that will be stable for sql across multiple connections and program +// executions. +func StatementName(sql string) string { + digest := sha256.Sum256([]byte(sql)) + return "stmtcache_" + hex.EncodeToString(digest[0:24]) +} + +// Cache caches statement descriptions. +type Cache interface { + // Get returns the statement description for sql. Returns nil if not found. + Get(sql string) *pgconn.StatementDescription + + // Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. + Put(sd *pgconn.StatementDescription) + + // Invalidate invalidates statement description identified by sql. Does nothing if not found. + Invalidate(sql string) + + // InvalidateAll invalidates all statement descriptions. + InvalidateAll() + + // GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. + GetInvalidated() []*pgconn.StatementDescription + + // RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a + // call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were + // never seen by the call to GetInvalidated. + RemoveInvalidated() + + // Len returns the number of cached prepared statement descriptions. + Len() int + + // Cap returns the maximum number of cached prepared statement descriptions. + Cap() int +} diff --git a/internal/stmtcache/unlimited_cache.go b/internal/stmtcache/unlimited_cache.go new file mode 100644 index 000000000..696413291 --- /dev/null +++ b/internal/stmtcache/unlimited_cache.go @@ -0,0 +1,77 @@ +package stmtcache + +import ( + "math" + + "github.com/jackc/pgx/v5/pgconn" +) + +// UnlimitedCache implements Cache with no capacity limit. +type UnlimitedCache struct { + m map[string]*pgconn.StatementDescription + invalidStmts []*pgconn.StatementDescription +} + +// NewUnlimitedCache creates a new UnlimitedCache. +func NewUnlimitedCache() *UnlimitedCache { + return &UnlimitedCache{ + m: make(map[string]*pgconn.StatementDescription), + } +} + +// Get returns the statement description for sql. Returns nil if not found. +func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription { + return c.m[sql] +} + +// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. +func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) { + if sd.SQL == "" { + panic("cannot store statement description with empty SQL") + } + + if _, present := c.m[sd.SQL]; present { + return + } + + c.m[sd.SQL] = sd +} + +// Invalidate invalidates statement description identified by sql. Does nothing if not found. +func (c *UnlimitedCache) Invalidate(sql string) { + if sd, ok := c.m[sql]; ok { + delete(c.m, sql) + c.invalidStmts = append(c.invalidStmts, sd) + } +} + +// InvalidateAll invalidates all statement descriptions. +func (c *UnlimitedCache) InvalidateAll() { + for _, sd := range c.m { + c.invalidStmts = append(c.invalidStmts, sd) + } + + c.m = make(map[string]*pgconn.StatementDescription) +} + +// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. +func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription { + return c.invalidStmts +} + +// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a +// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were +// never seen by the call to GetInvalidated. +func (c *UnlimitedCache) RemoveInvalidated() { + c.invalidStmts = nil +} + +// Len returns the number of cached prepared statement descriptions. +func (c *UnlimitedCache) Len() int { + return len(c.m) +} + +// Cap returns the maximum number of cached prepared statement descriptions. +func (c *UnlimitedCache) Cap() int { + return math.MaxInt +} diff --git a/large_objects.go b/large_objects.go index e109bce25..9d21afdce 100644 --- a/large_objects.go +++ b/large_objects.go @@ -1,56 +1,24 @@ package pgx import ( + "context" + "errors" "io" - "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/v5/pgtype" ) -// LargeObjects is a structure used to access the large objects API. It is only -// valid within the transaction where it was created. +// The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of +// PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data +// in the message,maxLargeObjectMessageLength should be no larger than 1 GB - 1 KB. +var maxLargeObjectMessageLength = 1024*1024*1024 - 1024 + +// LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it +// was created. // // For more details see: http://www.postgresql.org/docs/current/static/largeobjects.html type LargeObjects struct { - // Has64 is true if the server is capable of working with 64-bit numbers - Has64 bool - fp *fastpath -} - -const largeObjectFns = `select proname, oid from pg_catalog.pg_proc -where proname in ( -'lo_open', -'lo_close', -'lo_create', -'lo_unlink', -'lo_lseek', -'lo_lseek64', -'lo_tell', -'lo_tell64', -'lo_truncate', -'lo_truncate64', -'loread', -'lowrite') -and pronamespace = (select oid from pg_catalog.pg_namespace where nspname = 'pg_catalog')` - -// LargeObjects returns a LargeObjects instance for the transaction. -func (tx *Tx) LargeObjects() (*LargeObjects, error) { - if tx.conn.fp == nil { - tx.conn.fp = newFastpath(tx.conn) - } - if _, exists := tx.conn.fp.fns["lo_open"]; !exists { - res, err := tx.Query(largeObjectFns) - if err != nil { - return nil, err - } - if err := tx.conn.fp.addFunctions(res); err != nil { - return nil, err - } - } - - lo := &LargeObjects{fp: tx.conn.fp} - _, lo.Has64 = lo.fp.fns["lo_lseek64"] - - return lo, nil + tx Tx } type LargeObjectMode int32 @@ -60,90 +28,134 @@ const ( LargeObjectModeRead LargeObjectMode = 0x40000 ) -// Create creates a new large object. If id is zero, the server assigns an -// unused OID. -func (o *LargeObjects) Create(id pgtype.OID) (pgtype.OID, error) { - newOID, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))})) - return pgtype.OID(newOID), err +// Create creates a new large object. If oid is zero, the server assigns an unused OID. +func (o *LargeObjects) Create(ctx context.Context, oid uint32) (uint32, error) { + err := o.tx.QueryRow(ctx, "select lo_create($1)", oid).Scan(&oid) + return oid, err } -// Open opens an existing large object with the given mode. -func (o *LargeObjects) Open(oid pgtype.OID, mode LargeObjectMode) (*LargeObject, error) { - fd, err := fpInt32(o.fp.CallFn("lo_open", []fpArg{fpIntArg(int32(oid)), fpIntArg(int32(mode))})) - return &LargeObject{fd: fd, lo: o}, err +// Open opens an existing large object with the given mode. ctx will also be used for all operations on the opened large +// object. +func (o *LargeObjects) Open(ctx context.Context, oid uint32, mode LargeObjectMode) (*LargeObject, error) { + var fd int32 + err := o.tx.QueryRow(ctx, "select lo_open($1, $2)", oid, mode).Scan(&fd) + if err != nil { + return nil, err + } + return &LargeObject{fd: fd, tx: o.tx, ctx: ctx}, nil } // Unlink removes a large object from the database. -func (o *LargeObjects) Unlink(oid pgtype.OID) error { - _, err := o.fp.CallFn("lo_unlink", []fpArg{fpIntArg(int32(oid))}) - return err +func (o *LargeObjects) Unlink(ctx context.Context, oid uint32) error { + var result int32 + err := o.tx.QueryRow(ctx, "select lo_unlink($1)", oid).Scan(&result) + if err != nil { + return err + } + + if result != 1 { + return errors.New("failed to remove large object") + } + + return nil } -// A LargeObject is a large object stored on the server. It is only valid within -// the transaction that it was initialized in. It implements these interfaces: +// A LargeObject is a large object stored on the server. It is only valid within the transaction that it was initialized +// in. It uses the context it was initialized with for all operations. It implements these interfaces: // -// io.Writer -// io.Reader -// io.Seeker -// io.Closer +// io.Writer +// io.Reader +// io.Seeker +// io.Closer type LargeObject struct { - fd int32 - lo *LargeObjects + ctx context.Context + tx Tx + fd int32 } -// Write writes p to the large object and returns the number of bytes written -// and an error if not all of p was written. +// Write writes p to the large object and returns the number of bytes written and an error if not all of p was written. func (o *LargeObject) Write(p []byte) (int, error) { - n, err := fpInt32(o.lo.fp.CallFn("lowrite", []fpArg{fpIntArg(o.fd), p})) - return int(n), err + nTotal := 0 + for { + expected := len(p) - nTotal + if expected == 0 { + break + } else if expected > maxLargeObjectMessageLength { + expected = maxLargeObjectMessageLength + } + + var n int + err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p[nTotal:nTotal+expected]).Scan(&n) + if err != nil { + return nTotal, err + } + + if n < 0 { + return nTotal, errors.New("failed to write to large object") + } + + nTotal += n + + if n < expected { + return nTotal, errors.New("short write to large object") + } else if n > expected { + return nTotal, errors.New("invalid write to large object") + } + } + + return nTotal, nil } // Read reads up to len(p) bytes into p returning the number of bytes read. func (o *LargeObject) Read(p []byte) (int, error) { - res, err := o.lo.fp.CallFn("loread", []fpArg{fpIntArg(o.fd), fpIntArg(int32(len(p)))}) - if len(res) < len(p) { - err = io.EOF + nTotal := 0 + for { + expected := len(p) - nTotal + if expected == 0 { + break + } else if expected > maxLargeObjectMessageLength { + expected = maxLargeObjectMessageLength + } + + res := pgtype.PreallocBytes(p[nTotal:]) + err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res) + // We compute expected so that it always fits into p, so it should never happen + // that PreallocBytes's ScanBytes had to allocate a new slice. + nTotal += len(res) + if err != nil { + return nTotal, err + } + + if len(res) < expected { + return nTotal, io.EOF + } else if len(res) > expected { + return nTotal, errors.New("invalid read of large object") + } } - return copy(p, res), err + + return nTotal, nil } // Seek moves the current location pointer to the new location specified by offset. func (o *LargeObject) Seek(offset int64, whence int) (n int64, err error) { - if o.lo.Has64 { - n, err = fpInt64(o.lo.fp.CallFn("lo_lseek64", []fpArg{fpIntArg(o.fd), fpInt64Arg(offset), fpIntArg(int32(whence))})) - } else { - var n32 int32 - n32, err = fpInt32(o.lo.fp.CallFn("lo_lseek", []fpArg{fpIntArg(o.fd), fpIntArg(int32(offset)), fpIntArg(int32(whence))})) - n = int64(n32) - } - return + err = o.tx.QueryRow(o.ctx, "select lo_lseek64($1, $2, $3)", o.fd, offset, whence).Scan(&n) + return n, err } -// Tell returns the current read or write location of the large object -// descriptor. +// Tell returns the current read or write location of the large object descriptor. func (o *LargeObject) Tell() (n int64, err error) { - if o.lo.Has64 { - n, err = fpInt64(o.lo.fp.CallFn("lo_tell64", []fpArg{fpIntArg(o.fd)})) - } else { - var n32 int32 - n32, err = fpInt32(o.lo.fp.CallFn("lo_tell", []fpArg{fpIntArg(o.fd)})) - n = int64(n32) - } - return + err = o.tx.QueryRow(o.ctx, "select lo_tell64($1)", o.fd).Scan(&n) + return n, err } -// Trunctes the large object to size. +// Truncate the large object to size. func (o *LargeObject) Truncate(size int64) (err error) { - if o.lo.Has64 { - _, err = o.lo.fp.CallFn("lo_truncate64", []fpArg{fpIntArg(o.fd), fpInt64Arg(size)}) - } else { - _, err = o.lo.fp.CallFn("lo_truncate", []fpArg{fpIntArg(o.fd), fpIntArg(int32(size))}) - } - return + _, err = o.tx.Exec(o.ctx, "select lo_truncate64($1, $2)", o.fd, size) + return err } -// Close closees the large object descriptor. +// Close the large object descriptor. func (o *LargeObject) Close() error { - _, err := o.lo.fp.CallFn("lo_close", []fpArg{fpIntArg(o.fd)}) + _, err := o.tx.Exec(o.ctx, "select lo_close($1)", o.fd) return err } diff --git a/large_objects_private_test.go b/large_objects_private_test.go new file mode 100644 index 000000000..36eca8f06 --- /dev/null +++ b/large_objects_private_test.go @@ -0,0 +1,20 @@ +package pgx + +import ( + "testing" +) + +// SetMaxLargeObjectMessageLength sets internal maxLargeObjectMessageLength variable +// to the given length for the duration of the test. +// +// Tests using this helper should not use t.Parallel(). +func SetMaxLargeObjectMessageLength(t *testing.T, length int) { + t.Helper() + + original := maxLargeObjectMessageLength + t.Cleanup(func() { + maxLargeObjectMessageLength = original + }) + + maxLargeObjectMessageLength = length +} diff --git a/large_objects_test.go b/large_objects_test.go index a19c851d2..de2eed0d8 100644 --- a/large_objects_test.go +++ b/large_objects_test.go @@ -1,36 +1,77 @@ package pgx_test import ( + "context" "io" + "os" "testing" + "time" - "github.com/jackc/pgx" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" ) func TestLargeObjects(t *testing.T) { - t.Parallel() + // We use a very short limit to test chunking logic. + pgx.SetMaxLargeObjectMessageLength(t, 2) - conn, err := pgx.Connect(*defaultConnConfig) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) if err != nil { t.Fatal(err) } - tx, err := conn.Begin() + pgxtest.SkipCockroachDB(t, conn, "Server does support large objects") + + tx, err := conn.Begin(ctx) if err != nil { t.Fatal(err) } - lo, err := tx.LargeObjects() + testLargeObjects(t, ctx, tx) +} + +func TestLargeObjectsSimpleProtocol(t *testing.T) { + // We use a very short limit to test chunking logic. + pgx.SetMaxLargeObjectMessageLength(t, 2) + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) if err != nil { t.Fatal(err) } - id, err := lo.Create(0) + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol + + conn, err := pgx.ConnectConfig(ctx, config) if err != nil { t.Fatal(err) } - obj, err := lo.Open(id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite) + pgxtest.SkipCockroachDB(t, conn, "Server does support large objects") + + tx, err := conn.Begin(ctx) + if err != nil { + t.Fatal(err) + } + + testLargeObjects(t, ctx, tx) +} + +func testLargeObjects(t *testing.T, ctx context.Context, tx pgx.Tx) { + lo := tx.LargeObjects() + + id, err := lo.Create(ctx, 0) + if err != nil { + t.Fatal(err) + } + + obj, err := lo.Open(ctx, id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite) if err != nil { t.Fatal(err) } @@ -109,13 +150,157 @@ func TestLargeObjects(t *testing.T) { t.Fatal(err) } - err = lo.Unlink(id) + err = lo.Unlink(ctx, id) + if err != nil { + t.Fatal(err) + } + + _, err = lo.Open(ctx, id, pgx.LargeObjectModeRead) + if e, ok := err.(*pgconn.PgError); !ok || e.Code != "42704" { + t.Errorf("Expected undefined_object error (42704), got %#v", err) + } +} + +func TestLargeObjectsMultipleTransactions(t *testing.T) { + // We use a very short limit to test chunking logic. + pgx.SetMaxLargeObjectMessageLength(t, 2) + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + + pgxtest.SkipCockroachDB(t, conn, "Server does support large objects") + + tx, err := conn.Begin(ctx) + if err != nil { + t.Fatal(err) + } + + lo := tx.LargeObjects() + + id, err := lo.Create(ctx, 0) + if err != nil { + t.Fatal(err) + } + + obj, err := lo.Open(ctx, id, pgx.LargeObjectModeWrite) + if err != nil { + t.Fatal(err) + } + + n, err := obj.Write([]byte("testing")) + if err != nil { + t.Fatal(err) + } + if n != 7 { + t.Errorf("Expected n to be 7, got %d", n) + } + + // Commit the first transaction + err = tx.Commit(ctx) + if err != nil { + t.Fatal(err) + } + + // IMPORTANT: Use the same connection for another query + query := `select n from generate_series(1,10) n` + rows, err := conn.Query(ctx, query) + if err != nil { + t.Fatal(err) + } + rows.Close() + + // Start a new transaction + tx2, err := conn.Begin(ctx) + if err != nil { + t.Fatal(err) + } + + lo2 := tx2.LargeObjects() + + // Reopen the large object in the new transaction + obj2, err := lo2.Open(ctx, id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite) + if err != nil { + t.Fatal(err) + } + + pos, err := obj2.Seek(1, 0) + if err != nil { + t.Fatal(err) + } + if pos != 1 { + t.Errorf("Expected pos to be 1, got %d", pos) + } + + res := make([]byte, 6) + n, err = obj2.Read(res) + if err != nil { + t.Fatal(err) + } + if string(res) != "esting" { + t.Errorf(`Expected res to be "esting", got %q`, res) + } + if n != 6 { + t.Errorf("Expected n to be 6, got %d", n) + } + + n, err = obj2.Read(res) + if err != io.EOF { + t.Error("Expected io.EOF, go nil") + } + if n != 0 { + t.Errorf("Expected n to be 0, got %d", n) + } + + pos, err = obj2.Tell() + if err != nil { + t.Fatal(err) + } + if pos != 7 { + t.Errorf("Expected pos to be 7, got %d", pos) + } + + err = obj2.Truncate(1) + if err != nil { + t.Fatal(err) + } + + pos, err = obj2.Seek(-1, 2) + if err != nil { + t.Fatal(err) + } + if pos != 0 { + t.Errorf("Expected pos to be 0, got %d", pos) + } + + res = make([]byte, 2) + n, err = obj2.Read(res) + if err != io.EOF { + t.Errorf("Expected err to be io.EOF, got %v", err) + } + if n != 1 { + t.Errorf("Expected n to be 1, got %d", n) + } + if res[0] != 't' { + t.Errorf("Expected res[0] to be 't', got %v", res[0]) + } + + err = obj2.Close() + if err != nil { + t.Fatal(err) + } + + err = lo2.Unlink(ctx, id) if err != nil { t.Fatal(err) } - _, err = lo.Open(id, pgx.LargeObjectModeRead) - if e, ok := err.(pgx.PgError); !ok || e.Code != "42704" { + _, err = lo2.Open(ctx, id, pgx.LargeObjectModeRead) + if e, ok := err.(*pgconn.PgError); !ok || e.Code != "42704" { t.Errorf("Expected undefined_object error (42704), got %#v", err) } } diff --git a/log/log15adapter/adapter.go b/log/log15adapter/adapter.go deleted file mode 100644 index 8623a3806..000000000 --- a/log/log15adapter/adapter.go +++ /dev/null @@ -1,47 +0,0 @@ -// Package log15adapter provides a logger that writes to a github.com/inconshreveable/log15.Logger -// log. -package log15adapter - -import ( - "github.com/jackc/pgx" -) - -// Log15Logger interface defines the subset of -// github.com/inconshreveable/log15.Logger that this adapter uses. -type Log15Logger interface { - Debug(msg string, ctx ...interface{}) - Info(msg string, ctx ...interface{}) - Warn(msg string, ctx ...interface{}) - Error(msg string, ctx ...interface{}) - Crit(msg string, ctx ...interface{}) -} - -type Logger struct { - l Log15Logger -} - -func NewLogger(l Log15Logger) *Logger { - return &Logger{l: l} -} - -func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { - logArgs := make([]interface{}, 0, len(data)) - for k, v := range data { - logArgs = append(logArgs, k, v) - } - - switch level { - case pgx.LogLevelTrace: - l.l.Debug(msg, append(logArgs, "PGX_LOG_LEVEL", level)...) - case pgx.LogLevelDebug: - l.l.Debug(msg, logArgs...) - case pgx.LogLevelInfo: - l.l.Info(msg, logArgs...) - case pgx.LogLevelWarn: - l.l.Warn(msg, logArgs...) - case pgx.LogLevelError: - l.l.Error(msg, logArgs...) - default: - l.l.Error(msg, append(logArgs, "INVALID_PGX_LOG_LEVEL", level)...) - } -} diff --git a/log/logrusadapter/adapter.go b/log/logrusadapter/adapter.go deleted file mode 100644 index 0ee0da0bb..000000000 --- a/log/logrusadapter/adapter.go +++ /dev/null @@ -1,40 +0,0 @@ -// Package logrusadapter provides a logger that writes to a github.com/sirupsen/logrus.Logger -// log. -package logrusadapter - -import ( - "github.com/jackc/pgx" - "github.com/sirupsen/logrus" -) - -type Logger struct { - l logrus.FieldLogger -} - -func NewLogger(l logrus.FieldLogger) *Logger { - return &Logger{l: l} -} - -func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { - var logger logrus.FieldLogger - if data != nil { - logger = l.l.WithFields(data) - } else { - logger = l.l - } - - switch level { - case pgx.LogLevelTrace: - logger.WithField("PGX_LOG_LEVEL", level).Debug(msg) - case pgx.LogLevelDebug: - logger.Debug(msg) - case pgx.LogLevelInfo: - logger.Info(msg) - case pgx.LogLevelWarn: - logger.Warn(msg) - case pgx.LogLevelError: - logger.Error(msg) - default: - logger.WithField("INVALID_PGX_LOG_LEVEL", level).Error(msg) - } -} diff --git a/log/testingadapter/adapter.go b/log/testingadapter/adapter.go index 6c9cde838..c901a6a65 100644 --- a/log/testingadapter/adapter.go +++ b/log/testingadapter/adapter.go @@ -3,15 +3,16 @@ package testingadapter import ( + "context" "fmt" - "github.com/jackc/pgx" + "github.com/jackc/pgx/v5/tracelog" ) // TestingLogger interface defines the subset of testing.TB methods used by this // adapter. type TestingLogger interface { - Log(args ...interface{}) + Log(args ...any) } type Logger struct { @@ -22,8 +23,8 @@ func NewLogger(l TestingLogger) *Logger { return &Logger{l: l} } -func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { - logArgs := make([]interface{}, 0, 2+len(data)) +func (l *Logger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) { + logArgs := make([]any, 0, 2+len(data)) logArgs = append(logArgs, level, msg) for k, v := range data { logArgs = append(logArgs, fmt.Sprintf("%s=%v", k, v)) diff --git a/log/zapadapter/adapter.go b/log/zapadapter/adapter.go deleted file mode 100644 index 82263b6e7..000000000 --- a/log/zapadapter/adapter.go +++ /dev/null @@ -1,40 +0,0 @@ -// Package zapadapter provides a logger that writes to a go.uber.org/zap.Logger. -package zapadapter - -import ( - "github.com/jackc/pgx" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -type Logger struct { - logger *zap.Logger -} - -func NewLogger(logger *zap.Logger) *Logger { - return &Logger{logger: logger.WithOptions(zap.AddCallerSkip(1))} -} - -func (pl *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { - fields := make([]zapcore.Field, len(data)) - i := 0 - for k, v := range data { - fields[i] = zap.Reflect(k, v) - i++ - } - - switch level { - case pgx.LogLevelTrace: - pl.logger.Debug(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...) - case pgx.LogLevelDebug: - pl.logger.Debug(msg, fields...) - case pgx.LogLevelInfo: - pl.logger.Info(msg, fields...) - case pgx.LogLevelWarn: - pl.logger.Warn(msg, fields...) - case pgx.LogLevelError: - pl.logger.Error(msg, fields...) - default: - pl.logger.Error(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...) - } -} diff --git a/logger.go b/logger.go deleted file mode 100644 index 528698b19..000000000 --- a/logger.go +++ /dev/null @@ -1,98 +0,0 @@ -package pgx - -import ( - "encoding/hex" - "fmt" - - "github.com/pkg/errors" -) - -// The values for log levels are chosen such that the zero value means that no -// log level was specified. -const ( - LogLevelTrace = 6 - LogLevelDebug = 5 - LogLevelInfo = 4 - LogLevelWarn = 3 - LogLevelError = 2 - LogLevelNone = 1 -) - -// LogLevel represents the pgx logging level. See LogLevel* constants for -// possible values. -type LogLevel int - -func (ll LogLevel) String() string { - switch ll { - case LogLevelTrace: - return "trace" - case LogLevelDebug: - return "debug" - case LogLevelInfo: - return "info" - case LogLevelWarn: - return "warn" - case LogLevelError: - return "error" - case LogLevelNone: - return "none" - default: - return fmt.Sprintf("invalid level %d", ll) - } -} - -// Logger is the interface used to get logging from pgx internals. -type Logger interface { - // Log a message at the given level with data key/value pairs. data may be nil. - Log(level LogLevel, msg string, data map[string]interface{}) -} - -// LogLevelFromString converts log level string to constant -// -// Valid levels: -// trace -// debug -// info -// warn -// error -// none -func LogLevelFromString(s string) (LogLevel, error) { - switch s { - case "trace": - return LogLevelTrace, nil - case "debug": - return LogLevelDebug, nil - case "info": - return LogLevelInfo, nil - case "warn": - return LogLevelWarn, nil - case "error": - return LogLevelError, nil - case "none": - return LogLevelNone, nil - default: - return 0, errors.New("invalid log level") - } -} - -func logQueryArgs(args []interface{}) []interface{} { - logArgs := make([]interface{}, 0, len(args)) - - for _, a := range args { - switch v := a.(type) { - case []byte: - if len(v) < 64 { - a = hex.EncodeToString(v) - } else { - a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64) - } - case string: - if len(v) > 64 { - a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64) - } - } - logArgs = append(logArgs, a) - } - - return logArgs -} diff --git a/messages.go b/messages.go deleted file mode 100644 index 97e89295e..000000000 --- a/messages.go +++ /dev/null @@ -1,213 +0,0 @@ -package pgx - -import ( - "math" - "reflect" - "time" - - "github.com/jackc/pgx/pgio" - "github.com/jackc/pgx/pgtype" -) - -const ( - copyData = 'd' - copyFail = 'f' - copyDone = 'c' - varHeaderSize = 4 -) - -type FieldDescription struct { - Name string - Table pgtype.OID - AttributeNumber uint16 - DataType pgtype.OID - DataTypeSize int16 - DataTypeName string - Modifier uint32 - FormatCode int16 -} - -func (fd FieldDescription) Length() (int64, bool) { - switch fd.DataType { - case pgtype.TextOID, pgtype.ByteaOID: - return math.MaxInt64, true - case pgtype.VarcharOID, pgtype.BPCharArrayOID: - return int64(fd.Modifier - varHeaderSize), true - default: - return 0, false - } -} - -func (fd FieldDescription) PrecisionScale() (precision, scale int64, ok bool) { - switch fd.DataType { - case pgtype.NumericOID: - mod := fd.Modifier - varHeaderSize - precision = int64((mod >> 16) & 0xffff) - scale = int64(mod & 0xffff) - return precision, scale, true - default: - return 0, 0, false - } -} - -func (fd FieldDescription) Type() reflect.Type { - switch fd.DataType { - case pgtype.Int8OID: - return reflect.TypeOf(int64(0)) - case pgtype.Int4OID: - return reflect.TypeOf(int32(0)) - case pgtype.Int2OID: - return reflect.TypeOf(int16(0)) - case pgtype.VarcharOID, pgtype.BPCharArrayOID, pgtype.TextOID: - return reflect.TypeOf("") - case pgtype.BoolOID: - return reflect.TypeOf(false) - case pgtype.NumericOID: - return reflect.TypeOf(float64(0)) - case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID: - return reflect.TypeOf(time.Time{}) - case pgtype.ByteaOID: - return reflect.TypeOf([]byte(nil)) - default: - return reflect.TypeOf(new(interface{})).Elem() - } -} - -// PgError represents an error reported by the PostgreSQL server. See -// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for -// detailed field description. -type PgError struct { - Severity string - Code string - Message string - Detail string - Hint string - Position int32 - InternalPosition int32 - InternalQuery string - Where string - SchemaName string - TableName string - ColumnName string - DataTypeName string - ConstraintName string - File string - Line int32 - Routine string -} - -func (pe PgError) Error() string { - return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" -} - -// Notice represents a notice response message reported by the PostgreSQL -// server. Be aware that this is distinct from LISTEN/NOTIFY notification. -type Notice PgError - -// appendParse appends a PostgreSQL wire protocol parse message to buf and returns it. -func appendParse(buf []byte, name string, query string, parameterOIDs []pgtype.OID) []byte { - buf = append(buf, 'P') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, name...) - buf = append(buf, 0) - buf = append(buf, query...) - buf = append(buf, 0) - - buf = pgio.AppendInt16(buf, int16(len(parameterOIDs))) - for _, oid := range parameterOIDs { - buf = pgio.AppendUint32(buf, uint32(oid)) - } - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - return buf -} - -// appendDescribe appends a PostgreSQL wire protocol describe message to buf and returns it. -func appendDescribe(buf []byte, objectType byte, name string) []byte { - buf = append(buf, 'D') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, objectType) - buf = append(buf, name...) - buf = append(buf, 0) - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - return buf -} - -// appendSync appends a PostgreSQL wire protocol sync message to buf and returns it. -func appendSync(buf []byte) []byte { - buf = append(buf, 'S') - buf = pgio.AppendInt32(buf, 4) - - return buf -} - -// appendBind appends a PostgreSQL wire protocol bind message to buf and returns it. -func appendBind( - buf []byte, - destinationPortal, - preparedStatement string, - connInfo *pgtype.ConnInfo, - parameterOIDs []pgtype.OID, - arguments []interface{}, - resultFormatCodes []int16, -) ([]byte, error) { - buf = append(buf, 'B') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, destinationPortal...) - buf = append(buf, 0) - buf = append(buf, preparedStatement...) - buf = append(buf, 0) - - buf = pgio.AppendInt16(buf, int16(len(parameterOIDs))) - for i, oid := range parameterOIDs { - buf = pgio.AppendInt16(buf, chooseParameterFormatCode(connInfo, oid, arguments[i])) - } - - buf = pgio.AppendInt16(buf, int16(len(arguments))) - for i, oid := range parameterOIDs { - var err error - buf, err = encodePreparedStatementArgument(connInfo, buf, oid, arguments[i]) - if err != nil { - return nil, err - } - } - - buf = pgio.AppendInt16(buf, int16(len(resultFormatCodes))) - for _, fc := range resultFormatCodes { - buf = pgio.AppendInt16(buf, fc) - } - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - return buf, nil -} - -// appendExecute appends a PostgreSQL wire protocol execute message to buf and returns it. -func appendExecute(buf []byte, portal string, maxRows uint32) []byte { - buf = append(buf, 'E') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf = append(buf, portal...) - buf = append(buf, 0) - buf = pgio.AppendUint32(buf, maxRows) - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - return buf -} - -// appendQuery appends a PostgreSQL wire protocol query message to buf and returns it. -func appendQuery(buf []byte, query string) []byte { - buf = append(buf, 'Q') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, query...) - buf = append(buf, 0) - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - return buf -} diff --git a/multitracer/tracer.go b/multitracer/tracer.go new file mode 100644 index 000000000..acff17398 --- /dev/null +++ b/multitracer/tracer.go @@ -0,0 +1,152 @@ +// Package multitracer provides a Tracer that can combine several tracers into one. +package multitracer + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +// Tracer can combine several tracers into one. +// You can use New to automatically split tracers by interface. +type Tracer struct { + QueryTracers []pgx.QueryTracer + BatchTracers []pgx.BatchTracer + CopyFromTracers []pgx.CopyFromTracer + PrepareTracers []pgx.PrepareTracer + ConnectTracers []pgx.ConnectTracer + PoolAcquireTracers []pgxpool.AcquireTracer + PoolReleaseTracers []pgxpool.ReleaseTracer +} + +// New returns new Tracer from tracers with automatically split tracers by interface. +func New(tracers ...pgx.QueryTracer) *Tracer { + var t Tracer + + for _, tracer := range tracers { + t.QueryTracers = append(t.QueryTracers, tracer) + + if batchTracer, ok := tracer.(pgx.BatchTracer); ok { + t.BatchTracers = append(t.BatchTracers, batchTracer) + } + + if copyFromTracer, ok := tracer.(pgx.CopyFromTracer); ok { + t.CopyFromTracers = append(t.CopyFromTracers, copyFromTracer) + } + + if prepareTracer, ok := tracer.(pgx.PrepareTracer); ok { + t.PrepareTracers = append(t.PrepareTracers, prepareTracer) + } + + if connectTracer, ok := tracer.(pgx.ConnectTracer); ok { + t.ConnectTracers = append(t.ConnectTracers, connectTracer) + } + + if poolAcquireTracer, ok := tracer.(pgxpool.AcquireTracer); ok { + t.PoolAcquireTracers = append(t.PoolAcquireTracers, poolAcquireTracer) + } + + if poolReleaseTracer, ok := tracer.(pgxpool.ReleaseTracer); ok { + t.PoolReleaseTracers = append(t.PoolReleaseTracers, poolReleaseTracer) + } + } + + return &t +} + +func (t *Tracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + for _, tracer := range t.QueryTracers { + ctx = tracer.TraceQueryStart(ctx, conn, data) + } + + return ctx +} + +func (t *Tracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + for _, tracer := range t.QueryTracers { + tracer.TraceQueryEnd(ctx, conn, data) + } +} + +func (t *Tracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + for _, tracer := range t.BatchTracers { + ctx = tracer.TraceBatchStart(ctx, conn, data) + } + + return ctx +} + +func (t *Tracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + for _, tracer := range t.BatchTracers { + tracer.TraceBatchQuery(ctx, conn, data) + } +} + +func (t *Tracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + for _, tracer := range t.BatchTracers { + tracer.TraceBatchEnd(ctx, conn, data) + } +} + +func (t *Tracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + for _, tracer := range t.CopyFromTracers { + ctx = tracer.TraceCopyFromStart(ctx, conn, data) + } + + return ctx +} + +func (t *Tracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + for _, tracer := range t.CopyFromTracers { + tracer.TraceCopyFromEnd(ctx, conn, data) + } +} + +func (t *Tracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context { + for _, tracer := range t.PrepareTracers { + ctx = tracer.TracePrepareStart(ctx, conn, data) + } + + return ctx +} + +func (t *Tracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + for _, tracer := range t.PrepareTracers { + tracer.TracePrepareEnd(ctx, conn, data) + } +} + +func (t *Tracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context { + for _, tracer := range t.ConnectTracers { + ctx = tracer.TraceConnectStart(ctx, data) + } + + return ctx +} + +func (t *Tracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) { + for _, tracer := range t.ConnectTracers { + tracer.TraceConnectEnd(ctx, data) + } +} + +func (t *Tracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context { + for _, tracer := range t.PoolAcquireTracers { + ctx = tracer.TraceAcquireStart(ctx, pool, data) + } + + return ctx +} + +func (t *Tracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) { + for _, tracer := range t.PoolAcquireTracers { + tracer.TraceAcquireEnd(ctx, pool, data) + } +} + +func (t *Tracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) { + for _, tracer := range t.PoolReleaseTracers { + tracer.TraceRelease(pool, data) + } +} diff --git a/multitracer/tracer_test.go b/multitracer/tracer_test.go new file mode 100644 index 000000000..aa5ccd080 --- /dev/null +++ b/multitracer/tracer_test.go @@ -0,0 +1,115 @@ +package multitracer_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/multitracer" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/require" +) + +type testFullTracer struct{} + +func (tt *testFullTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + return ctx +} + +func (tt *testFullTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { +} + +func (tt *testFullTracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + return ctx +} + +func (tt *testFullTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { +} + +func (tt *testFullTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { +} + +func (tt *testFullTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + return ctx +} + +func (tt *testFullTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { +} + +func (tt *testFullTracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context { + return ctx +} + +func (tt *testFullTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { +} + +func (tt *testFullTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context { + return ctx +} + +func (tt *testFullTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) { +} + +func (tt *testFullTracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context { + return ctx +} + +func (tt *testFullTracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) { +} + +func (tt *testFullTracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) { +} + +type testCopyTracer struct{} + +func (tt *testCopyTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + return ctx +} + +func (tt *testCopyTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { +} + +func (tt *testCopyTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + return ctx +} + +func (tt *testCopyTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { +} + +func TestNew(t *testing.T) { + t.Parallel() + + fullTracer := &testFullTracer{} + copyTracer := &testCopyTracer{} + + mt := multitracer.New(fullTracer, copyTracer) + require.Equal( + t, + &multitracer.Tracer{ + QueryTracers: []pgx.QueryTracer{ + fullTracer, + copyTracer, + }, + BatchTracers: []pgx.BatchTracer{ + fullTracer, + }, + CopyFromTracers: []pgx.CopyFromTracer{ + fullTracer, + copyTracer, + }, + PrepareTracers: []pgx.PrepareTracer{ + fullTracer, + }, + ConnectTracers: []pgx.ConnectTracer{ + fullTracer, + }, + PoolAcquireTracers: []pgxpool.AcquireTracer{ + fullTracer, + }, + PoolReleaseTracers: []pgxpool.ReleaseTracer{ + fullTracer, + }, + }, + mt, + ) +} diff --git a/named_args.go b/named_args.go new file mode 100644 index 000000000..c88991ee4 --- /dev/null +++ b/named_args.go @@ -0,0 +1,295 @@ +package pgx + +import ( + "context" + "fmt" + "strconv" + "strings" + "unicode/utf8" +) + +// NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$' +// ordinal placeholder and construct the appropriate arguments. +// +// For example, the following two queries are equivalent: +// +// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2}) +// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2) +// +// Named placeholders are case sensitive and must start with a letter or underscore. Subsequent characters can be +// letters, numbers, or underscores. +type NamedArgs map[string]any + +// RewriteQuery implements the QueryRewriter interface. +func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return rewriteQuery(na, sql, false) +} + +// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all +// named arguments that the sql query uses, and no extra arguments. +type StrictNamedArgs map[string]any + +// RewriteQuery implements the QueryRewriter interface. +func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return rewriteQuery(sna, sql, true) +} + +type namedArg string + +type sqlLexer struct { + src string + start int + pos int + nested int // multiline comment nesting level. + stateFn stateFn + parts []any + + nameToOrdinal map[namedArg]int +} + +type stateFn func(*sqlLexer) stateFn + +func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) { + l := &sqlLexer{ + src: sql, + stateFn: rawState, + nameToOrdinal: make(map[namedArg]int, len(na)), + } + + for l.stateFn != nil { + l.stateFn = l.stateFn(l) + } + + sb := strings.Builder{} + for _, p := range l.parts { + switch p := p.(type) { + case string: + sb.WriteString(p) + case namedArg: + sb.WriteRune('$') + sb.WriteString(strconv.Itoa(l.nameToOrdinal[p])) + } + } + + newArgs = make([]any, len(l.nameToOrdinal)) + for name, ordinal := range l.nameToOrdinal { + var found bool + newArgs[ordinal-1], found = na[string(name)] + if isStrict && !found { + return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name) + } + } + + if isStrict { + for name := range na { + if _, found := l.nameToOrdinal[namedArg(name)]; !found { + return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name) + } + } + } + + return sb.String(), newArgs, nil +} + +func rawState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case 'e', 'E': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '\'' { + l.pos += width + return escapeStringState + } + case '\'': + return singleQuoteState + case '"': + return doubleQuoteState + case '@': + nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) + if isLetter(nextRune) || nextRune == '_' { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos-width]) + } + l.start = l.pos + return namedArgState + } + case '-': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '-' { + l.pos += width + return oneLineCommentState + } + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + return multilineCommentState + } + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func isLetter(r rune) bool { + return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') +} + +func namedArgState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + if r == utf8.RuneError { + if l.pos-l.start > 0 { + na := namedArg(l.src[l.start:l.pos]) + if _, found := l.nameToOrdinal[na]; !found { + l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1 + } + l.parts = append(l.parts, na) + l.start = l.pos + } + return nil + } else if !(isLetter(r) || (r >= '0' && r <= '9') || r == '_') { + l.pos -= width + na := namedArg(l.src[l.start:l.pos]) + if _, found := l.nameToOrdinal[na]; !found { + l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1 + } + l.parts = append(l.parts, namedArg(na)) + l.start = l.pos + return rawState + } + } +} + +func singleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func doubleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '"': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '"' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func escapeStringState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func oneLineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\n', '\r': + return rawState + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func multilineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + l.nested++ + } + case '*': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '/' { + continue + } + + l.pos += width + if l.nested == 0 { + return rawState + } + l.nested-- + + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} diff --git a/named_args_test.go b/named_args_test.go new file mode 100644 index 000000000..8cab2f4d2 --- /dev/null +++ b/named_args_test.go @@ -0,0 +1,162 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNamedArgsRewriteQuery(t *testing.T) { + t.Parallel() + + for i, tt := range []struct { + sql string + args []any + namedArgs pgx.NamedArgs + expectedSQL string + expectedArgs []any + }{ + { + sql: "select * from users where id = @id", + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: "select * from users where id = $1", + expectedArgs: []any{int32(42)}, + }, + { + sql: "select * from t where foo < @abc and baz = @def and bar < @abc", + namedArgs: pgx.NamedArgs{"abc": int32(42), "def": int32(1)}, + expectedSQL: "select * from t where foo < $1 and baz = $2 and bar < $1", + expectedArgs: []any{int32(42), int32(1)}, + }, + { + sql: "select @a::int, @b::text", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "select $1::int, $2::text", + expectedArgs: []any{int32(42), "foo"}, + }, + { + sql: "select @Abc::int, @b_4::text, @_c::int", + namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo", "_c": int32(1)}, + expectedSQL: "select $1::int, $2::text, $3::int", + expectedArgs: []any{int32(42), "foo", int32(1)}, + }, + { + sql: "at end @", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "at end @", + expectedArgs: []any{}, + }, + { + sql: "ignores without valid character after @ foo bar", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "ignores without valid character after @ foo bar", + expectedArgs: []any{}, + }, + { + sql: "name cannot start with number @1 foo bar", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "name cannot start with number @1 foo bar", + expectedArgs: []any{}, + }, + { + sql: `select *, '@foo' as "@bar" from users where id = @id`, + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: `select *, '@foo' as "@bar" from users where id = $1`, + expectedArgs: []any{int32(42)}, + }, + { + sql: `select * -- @foo + from users -- @single line comments + where id = @id;`, + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: `select * -- @foo + from users -- @single line comments + where id = $1;`, + expectedArgs: []any{int32(42)}, + }, + { + sql: `select * /* @multi line + @comment + */ + /* /* with @nesting */ */ + from users + where id = @id;`, + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: `select * /* @multi line + @comment + */ + /* /* with @nesting */ */ + from users + where id = $1;`, + expectedArgs: []any{int32(42)}, + }, + { + sql: "extra provided argument", + namedArgs: pgx.NamedArgs{"extra": int32(1)}, + expectedSQL: "extra provided argument", + expectedArgs: []any{}, + }, + { + sql: "@missing argument", + namedArgs: pgx.NamedArgs{}, + expectedSQL: "$1 argument", + expectedArgs: []any{nil}, + }, + + // test comments and quotes + } { + sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args) + require.NoError(t, err) + assert.Equalf(t, tt.expectedSQL, sql, "%d", i) + assert.Equalf(t, tt.expectedArgs, args, "%d", i) + } +} + +func TestStrictNamedArgsRewriteQuery(t *testing.T) { + t.Parallel() + + for i, tt := range []struct { + sql string + namedArgs pgx.StrictNamedArgs + expectedSQL string + expectedArgs []any + isExpectedError bool + }{ + { + sql: "no arguments", + namedArgs: pgx.StrictNamedArgs{}, + expectedSQL: "no arguments", + expectedArgs: []any{}, + isExpectedError: false, + }, + { + sql: "@all @matches", + namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)}, + expectedSQL: "$1 $2", + expectedArgs: []any{int32(1), int32(2)}, + isExpectedError: false, + }, + { + sql: "extra provided argument", + namedArgs: pgx.StrictNamedArgs{"extra": int32(1)}, + isExpectedError: true, + }, + { + sql: "@missing argument", + namedArgs: pgx.StrictNamedArgs{}, + isExpectedError: true, + }, + } { + sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil) + if tt.isExpectedError { + assert.Errorf(t, err, "%d", i) + } else { + require.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expectedSQL, sql, "%d", i) + assert.Equalf(t, tt.expectedArgs, args, "%d", i) + } + } +} diff --git a/pgbouncer_test.go b/pgbouncer_test.go new file mode 100644 index 000000000..a9af09666 --- /dev/null +++ b/pgbouncer_test.go @@ -0,0 +1,75 @@ +package pgx_test + +import ( + "context" + "os" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPgbouncerStatementCacheDescribe(t *testing.T) { + connString := os.Getenv("PGX_TEST_PGBOUNCER_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_PGBOUNCER_CONN_STRING") + } + + config := mustParseConfig(t, connString) + config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe + config.DescriptionCacheCapacity = 1024 + + testPgbouncer(t, config, 10, 100) +} + +func TestPgbouncerSimpleProtocol(t *testing.T) { + connString := os.Getenv("PGX_TEST_PGBOUNCER_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_PGBOUNCER_CONN_STRING") + } + + config := mustParseConfig(t, connString) + config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol + + testPgbouncer(t, config, 10, 100) +} + +func testPgbouncer(t *testing.T, config *pgx.ConnConfig, workers, iterations int) { + doneChan := make(chan struct{}) + + for i := 0; i < workers; i++ { + go func() { + defer func() { doneChan <- struct{}{} }() + conn, err := pgx.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, conn) + + for i := 0; i < iterations; i++ { + var i32 int32 + var i64 int64 + var f32 float32 + var s string + var s2 string + err = conn.QueryRow(context.Background(), "select 1::int4, 2::int8, 3::float4, 'hi'::text").Scan(&i32, &i64, &f32, &s) + require.NoError(t, err) + assert.Equal(t, int32(1), i32) + assert.Equal(t, int64(2), i64) + assert.Equal(t, float32(3), f32) + assert.Equal(t, "hi", s) + + err = conn.QueryRow(context.Background(), "select 1::int8, 2::float4, 'bye'::text, 4::int4, 'whatever'::text").Scan(&i64, &f32, &s, &i32, &s2) + require.NoError(t, err) + assert.Equal(t, int64(1), i64) + assert.Equal(t, float32(2), f32) + assert.Equal(t, "bye", s) + assert.Equal(t, int32(4), i32) + assert.Equal(t, "whatever", s2) + } + }() + } + + for i := 0; i < workers; i++ { + <-doneChan + } +} diff --git a/pgconn/README.md b/pgconn/README.md new file mode 100644 index 000000000..1fe15c268 --- /dev/null +++ b/pgconn/README.md @@ -0,0 +1,29 @@ +# pgconn + +Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq. +It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx. +Applications should handle normal queries with a higher level library and only use pgconn directly when required for +low-level access to PostgreSQL functionality. + +## Example Usage + +```go +pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) +if err != nil { + log.Fatalln("pgconn failed to connect:", err) +} +defer pgConn.Close(context.Background()) + +result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +for result.NextRow() { + fmt.Println("User 123 has email:", string(result.Values()[0])) +} +_, err = result.Close() +if err != nil { + log.Fatalln("failed reading result:", err) +} +``` + +## Testing + +See CONTRIBUTING.md for setup instructions. diff --git a/pgconn/auth_scram.go b/pgconn/auth_scram.go new file mode 100644 index 000000000..2adf1fdd4 --- /dev/null +++ b/pgconn/auth_scram.go @@ -0,0 +1,276 @@ +// SCRAM-SHA-256 authentication +// +// Resources: +// https://tools.ietf.org/html/rfc5802 +// https://tools.ietf.org/html/rfc8265 +// https://www.postgresql.org/docs/current/sasl-authentication.html +// +// Inspiration drawn from other implementations: +// https://github.com/lib/pq/pull/608 +// https://github.com/lib/pq/pull/788 +// https://github.com/lib/pq/pull/833 + +package pgconn + +import ( + "bytes" + "crypto/hmac" + "crypto/pbkdf2" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "strconv" + + "github.com/jackc/pgx/v5/pgproto3" + "golang.org/x/text/secure/precis" +) + +const clientNonceLen = 18 + +// Perform SCRAM authentication. +func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { + sc, err := newScramClient(serverAuthMechanisms, c.config.Password) + if err != nil { + return err + } + + // Send client-first-message in a SASLInitialResponse + saslInitialResponse := &pgproto3.SASLInitialResponse{ + AuthMechanism: "SCRAM-SHA-256", + Data: sc.clientFirstMessage(), + } + c.frontend.Send(saslInitialResponse) + err = c.flushWithPotentialWriteReadDeadlock() + if err != nil { + return err + } + + // Receive server-first-message payload in an AuthenticationSASLContinue. + saslContinue, err := c.rxSASLContinue() + if err != nil { + return err + } + err = sc.recvServerFirstMessage(saslContinue.Data) + if err != nil { + return err + } + + // Send client-final-message in a SASLResponse + saslResponse := &pgproto3.SASLResponse{ + Data: []byte(sc.clientFinalMessage()), + } + c.frontend.Send(saslResponse) + err = c.flushWithPotentialWriteReadDeadlock() + if err != nil { + return err + } + + // Receive server-final-message payload in an AuthenticationSASLFinal. + saslFinal, err := c.rxSASLFinal() + if err != nil { + return err + } + return sc.recvServerFinalMessage(saslFinal.Data) +} + +func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + switch m := msg.(type) { + case *pgproto3.AuthenticationSASLContinue: + return m, nil + case *pgproto3.ErrorResponse: + return nil, ErrorResponseToPgError(m) + } + + return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg) +} + +func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + switch m := msg.(type) { + case *pgproto3.AuthenticationSASLFinal: + return m, nil + case *pgproto3.ErrorResponse: + return nil, ErrorResponseToPgError(m) + } + + return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg) +} + +type scramClient struct { + serverAuthMechanisms []string + password string + clientNonce []byte + + clientFirstMessageBare []byte + + serverFirstMessage []byte + clientAndServerNonce []byte + salt []byte + iterations int + + saltedPassword []byte + authMessage []byte +} + +func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) { + sc := &scramClient{ + serverAuthMechanisms: serverAuthMechanisms, + } + + // Ensure server supports SCRAM-SHA-256 + hasScramSHA256 := false + for _, mech := range sc.serverAuthMechanisms { + if mech == "SCRAM-SHA-256" { + hasScramSHA256 = true + break + } + } + if !hasScramSHA256 { + return nil, errors.New("server does not support SCRAM-SHA-256") + } + + // precis.OpaqueString is equivalent to SASLprep for password. + var err error + sc.password, err = precis.OpaqueString.String(password) + if err != nil { + // PostgreSQL allows passwords invalid according to SCRAM / SASLprep. + sc.password = password + } + + buf := make([]byte, clientNonceLen) + _, err = rand.Read(buf) + if err != nil { + return nil, err + } + sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf))) + base64.RawStdEncoding.Encode(sc.clientNonce, buf) + + return sc, nil +} + +func (sc *scramClient) clientFirstMessage() []byte { + sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce)) + return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare)) +} + +func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { + sc.serverFirstMessage = serverFirstMessage + buf := serverFirstMessage + if !bytes.HasPrefix(buf, []byte("r=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include r=") + } + buf = buf[2:] + + idx := bytes.IndexByte(buf, ',') + if idx == -1 { + return errors.New("invalid SCRAM server-first-message received from server: did not include s=") + } + sc.clientAndServerNonce = buf[:idx] + buf = buf[idx+1:] + + if !bytes.HasPrefix(buf, []byte("s=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include s=") + } + buf = buf[2:] + + idx = bytes.IndexByte(buf, ',') + if idx == -1 { + return errors.New("invalid SCRAM server-first-message received from server: did not include i=") + } + saltStr := buf[:idx] + buf = buf[idx+1:] + + if !bytes.HasPrefix(buf, []byte("i=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include i=") + } + buf = buf[2:] + iterationsStr := buf + + var err error + sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr)) + if err != nil { + return fmt.Errorf("invalid SCRAM salt received from server: %w", err) + } + + sc.iterations, err = strconv.Atoi(string(iterationsStr)) + if err != nil || sc.iterations <= 0 { + return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err) + } + + if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) { + return errors.New("invalid SCRAM nonce: did not start with client nonce") + } + + if len(sc.clientAndServerNonce) <= len(sc.clientNonce) { + return errors.New("invalid SCRAM nonce: did not include server nonce") + } + + return nil +} + +func (sc *scramClient) clientFinalMessage() string { + clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce)) + + var err error + sc.saltedPassword, err = pbkdf2.Key(sha256.New, sc.password, sc.salt, sc.iterations, 32) + if err != nil { + panic(err) // This should never happen. + } + sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(",")) + + clientProof := computeClientProof(sc.saltedPassword, sc.authMessage) + + return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof) +} + +func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error { + if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) { + return errors.New("invalid SCRAM server-final-message received from server") + } + + serverSignature := serverFinalMessage[2:] + + if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) { + return errors.New("invalid SCRAM ServerSignature received from server") + } + + return nil +} + +func computeHMAC(key, msg []byte) []byte { + mac := hmac.New(sha256.New, key) + mac.Write(msg) + return mac.Sum(nil) +} + +func computeClientProof(saltedPassword, authMessage []byte) []byte { + clientKey := computeHMAC(saltedPassword, []byte("Client Key")) + storedKey := sha256.Sum256(clientKey) + clientSignature := computeHMAC(storedKey[:], authMessage) + + clientProof := make([]byte, len(clientSignature)) + for i := 0; i < len(clientSignature); i++ { + clientProof[i] = clientKey[i] ^ clientSignature[i] + } + + buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof))) + base64.StdEncoding.Encode(buf, clientProof) + return buf +} + +func computeServerSignature(saltedPassword, authMessage []byte) []byte { + serverKey := computeHMAC(saltedPassword, []byte("Server Key")) + serverSignature := computeHMAC(serverKey, authMessage) + buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) + base64.StdEncoding.Encode(buf, serverSignature) + return buf +} diff --git a/pgconn/benchmark_private_test.go b/pgconn/benchmark_private_test.go new file mode 100644 index 000000000..9ea036ec7 --- /dev/null +++ b/pgconn/benchmark_private_test.go @@ -0,0 +1,73 @@ +package pgconn + +import ( + "strings" + "testing" +) + +func BenchmarkCommandTagRowsAffected(b *testing.B) { + benchmarks := []struct { + commandTag string + rowsAffected int64 + }{ + {"UPDATE 1", 1}, + {"UPDATE 123456789", 123456789}, + {"INSERT 0 1", 1}, + {"INSERT 0 123456789", 123456789}, + } + + for _, bm := range benchmarks { + ct := CommandTag{s: bm.commandTag} + b.Run(bm.commandTag, func(b *testing.B) { + var n int64 + for i := 0; i < b.N; i++ { + n = ct.RowsAffected() + } + if n != bm.rowsAffected { + b.Errorf("expected %d got %d", bm.rowsAffected, n) + } + }) + } +} + +func BenchmarkCommandTagTypeFromString(b *testing.B) { + ct := CommandTag{s: "UPDATE 1"} + + var update bool + for i := 0; i < b.N; i++ { + update = strings.HasPrefix(ct.String(), "UPDATE") + } + if !update { + b.Error("expected update") + } +} + +func BenchmarkCommandTagInsert(b *testing.B) { + benchmarks := []struct { + commandTag string + is bool + }{ + {"INSERT 1", true}, + {"INSERT 1234567890", true}, + {"UPDATE 1", false}, + {"UPDATE 1234567890", false}, + {"DELETE 1", false}, + {"DELETE 1234567890", false}, + {"SELECT 1", false}, + {"SELECT 1234567890", false}, + {"UNKNOWN 1234567890", false}, + } + + for _, bm := range benchmarks { + ct := CommandTag{s: bm.commandTag} + b.Run(bm.commandTag, func(b *testing.B) { + var is bool + for i := 0; i < b.N; i++ { + is = ct.Insert() + } + if is != bm.is { + b.Errorf("expected %v got %v", bm.is, is) + } + }) + } +} diff --git a/pgconn/benchmark_test.go b/pgconn/benchmark_test.go new file mode 100644 index 000000000..16e71cef7 --- /dev/null +++ b/pgconn/benchmark_test.go @@ -0,0 +1,250 @@ +package pgconn_test + +import ( + "bytes" + "context" + "os" + "testing" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/require" +) + +func BenchmarkConnect(b *testing.B) { + benchmarks := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + } + + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + connString := os.Getenv(bm.env) + if connString == "" { + b.Skipf("Skipping due to missing environment variable %v", bm.env) + } + + for i := 0; i < b.N; i++ { + conn, err := pgconn.Connect(context.Background(), connString) + require.Nil(b, err) + + err = conn.Close(context.Background()) + require.Nil(b, err) + } + }) + } +} + +func BenchmarkExec(b *testing.B) { + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + benchmarks := []struct { + name string + ctx context.Context + }{ + // Using an empty context other than context.Background() to compare + // performance + {"background context", context.Background()}, + {"empty context", context.TODO()}, + } + + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_DATABASE")) + require.Nil(b, err) + defer closeConn(b, conn) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + mrr := conn.Exec(bm.ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + + for mrr.NextResult() { + rr := mrr.ResultReader() + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } + } + + err := mrr.Close() + if err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkExecPossibleToCancel(b *testing.B) { + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(b, err) + defer closeConn(b, conn) + + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + + b.ResetTimer() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for i := 0; i < b.N; i++ { + mrr := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + + for mrr.NextResult() { + rr := mrr.ResultReader() + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } + } + + err := mrr.Close() + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkExecPrepared(b *testing.B) { + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + + benchmarks := []struct { + name string + ctx context.Context + }{ + // Using an empty context other than context.Background() to compare + // performance + {"background context", context.Background()}, + {"empty context", context.TODO()}, + } + + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_DATABASE")) + require.Nil(b, err) + defer closeConn(b, conn) + + _, err = conn.Prepare(bm.ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + require.Nil(b, err) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr := conn.ExecPrepared(bm.ctx, "ps1", nil, nil, nil) + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } + } + }) + } +} + +func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(b, err) + defer closeConn(b, conn) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + require.Nil(b, err) + + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr := conn.ExecPrepared(ctx, "ps1", nil, nil, nil) + + rowCount := 0 + for rr.NextRow() { + rowCount += 1 + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } + } +} + +// func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) { +// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) +// require.Nil(b, err) +// defer closeConn(b, conn) + +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() + +// b.ResetTimer() + +// for i := 0; i < b.N; i++ { +// conn.ChanToSetDeadline().Watch(ctx) +// conn.ChanToSetDeadline().Ignore() +// } +// } diff --git a/pgconn/config.go b/pgconn/config.go new file mode 100644 index 000000000..3937dc407 --- /dev/null +++ b/pgconn/config.go @@ -0,0 +1,953 @@ +package pgconn + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "io" + "math" + "net" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/jackc/pgpassfile" + "github.com/jackc/pgservicefile" + "github.com/jackc/pgx/v5/pgconn/ctxwatch" + "github.com/jackc/pgx/v5/pgproto3" +) + +type ( + AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error + ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error + GetSSLPasswordFunc func(ctx context.Context) string +) + +// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A +// manually initialized Config will cause ConnectConfig to panic. +type Config struct { + Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + Database string + User string + Password string + TLSConfig *tls.Config // nil disables TLS + ConnectTimeout time.Duration + DialFunc DialFunc // e.g. net.Dialer.DialContext + LookupFunc LookupFunc // e.g. net.Resolver.LookupHost + BuildFrontend BuildFrontendFunc + + // BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called + // when a context passed to a PgConn method is canceled. + BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler + + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + + KerberosSrvName string + KerberosSpn string + Fallbacks []*FallbackConfig + + SSLNegotiation string // sslnegotiation=postgres or sslnegotiation=direct + + // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. + // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next + // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. + ValidateConnect ValidateConnectFunc + + // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables + // or prepare statements). If this returns an error the connection attempt fails. + AfterConnect AfterConnectFunc + + // OnNotice is a callback function called when a notice response is received. + OnNotice NoticeHandler + + // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. + OnNotification NotificationHandler + + // OnPgError is a callback function called when a Postgres error is received by the server. The default handler will close + // the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure + // that you close on FATAL errors by returning false. + OnPgError PgErrorHandler + + createdByParseConfig bool // Used to enforce created by ParseConfig rule. +} + +// ParseConfigOptions contains options that control how a config is built such as GetSSLPassword. +type ParseConfigOptions struct { + // GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the libpq function + // PQsetSSLKeyPassHook_OpenSSL. + GetSSLPassword GetSSLPasswordFunc +} + +// Copy returns a deep copy of the config that is safe to use and modify. +// The only exception is the TLSConfig field: +// according to the tls.Config docs it must not be modified after creation. +func (c *Config) Copy() *Config { + newConf := new(Config) + *newConf = *c + if newConf.TLSConfig != nil { + newConf.TLSConfig = c.TLSConfig.Clone() + } + if newConf.RuntimeParams != nil { + newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams)) + for k, v := range c.RuntimeParams { + newConf.RuntimeParams[k] = v + } + } + if newConf.Fallbacks != nil { + newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks)) + for i, fallback := range c.Fallbacks { + newFallback := new(FallbackConfig) + *newFallback = *fallback + if newFallback.TLSConfig != nil { + newFallback.TLSConfig = fallback.TLSConfig.Clone() + } + newConf.Fallbacks[i] = newFallback + } + } + return newConf +} + +// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a +// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections. +type FallbackConfig struct { + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + TLSConfig *tls.Config // nil disables TLS +} + +// connectOneConfig is the configuration for a single attempt to connect to a single host. +type connectOneConfig struct { + network string + address string + originalHostname string // original hostname before resolving + tlsConfig *tls.Config // nil disables TLS +} + +// isAbsolutePath checks if the provided value is an absolute path either +// beginning with a forward slash (as on Linux-based systems) or with a capital +// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows). +func isAbsolutePath(path string) bool { + isWindowsPath := func(p string) bool { + if len(p) < 3 { + return false + } + drive := p[0] + colon := p[1] + backslash := p[2] + if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' { + return true + } + return false + } + return strings.HasPrefix(path, "/") || isWindowsPath(path) +} + +// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with +// net.Dial. +func NetworkAddress(host string, port uint16) (network, address string) { + if isAbsolutePath(host) { + network = "unix" + address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) + } else { + network = "tcp" + address = net.JoinHostPort(host, strconv.Itoa(int(port))) + } + return network, address +} + +// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It +// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely +// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format. See +// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be empty +// to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. +// +// # Example Keyword/Value +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca +// +// # Example URL +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca +// +// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done +// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be +// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should +// not be modified individually. They should all be modified or all left unchanged. +// +// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated +// values that will be tried in order. This can be used as part of a high availability system. See +// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. +// +// # Example URL +// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb +// +// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed +// via database URL or keyword/value: +// +// PGHOST +// PGPORT +// PGDATABASE +// PGUSER +// PGPASSWORD +// PGPASSFILE +// PGSERVICE +// PGSERVICEFILE +// PGSSLMODE +// PGSSLCERT +// PGSSLKEY +// PGSSLROOTCERT +// PGSSLPASSWORD +// PGOPTIONS +// PGAPPNAME +// PGCONNECT_TIMEOUT +// PGTARGETSESSIONATTRS +// PGTZ +// +// See http://www.postgresql.org/docs/current/static/libpq-envars.html for details on the meaning of environment variables. +// +// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are +// usually but not always the environment variable name downcased and without the "PG" prefix. +// +// Important Security Notes: +// +// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if +// not set. +// +// See http://www.postgresql.org/docs/current/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of +// security each sslmode provides. +// +// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of +// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of +// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback +// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually +// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting +// TLSConfig. +// +// Other known differences with libpq: +// +// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn +// does not. +// +// In addition, ParseConfig accepts the following options: +// +// - servicefile. +// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a +// part of the connection string. +func ParseConfig(connString string) (*Config, error) { + var parseConfigOptions ParseConfigOptions + return ParseConfigWithOptions(connString, parseConfigOptions) +} + +// ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard +// C library libpq. options contains settings that cannot be specified in a connString such as providing a function to +// get the SSL password. +func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) { + defaultSettings := defaultSettings() + envSettings := parseEnvSettings() + + connStringSettings := make(map[string]string) + if connString != "" { + var err error + // connString may be a database URL or in PostgreSQL keyword/value format + if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { + connStringSettings, err = parseURLSettings(connString) + if err != nil { + return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err} + } + } else { + connStringSettings, err = parseKeywordValueSettings(connString) + if err != nil { + return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as keyword/value", err: err} + } + } + } + + settings := mergeSettings(defaultSettings, envSettings, connStringSettings) + if service, present := settings["service"]; present { + serviceSettings, err := parseServiceSettings(settings["servicefile"], service) + if err != nil { + return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err} + } + + settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) + } + + config := &Config{ + createdByParseConfig: true, + Database: settings["database"], + User: settings["user"], + Password: settings["password"], + RuntimeParams: make(map[string]string), + BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { + return pgproto3.NewFrontend(r, w) + }, + BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler { + return &DeadlineContextWatcherHandler{Conn: pgConn.conn} + }, + OnPgError: func(_ *PgConn, pgErr *PgError) bool { + // we want to automatically close any fatal errors + if strings.EqualFold(pgErr.Severity, "FATAL") { + return false + } + return true + }, + } + + if connectTimeoutSetting, present := settings["connect_timeout"]; present { + connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting) + if err != nil { + return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err} + } + config.ConnectTimeout = connectTimeout + config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout) + } else { + defaultDialer := makeDefaultDialer() + config.DialFunc = defaultDialer.DialContext + } + + config.LookupFunc = makeDefaultResolver().LookupHost + + notRuntimeParams := map[string]struct{}{ + "host": {}, + "port": {}, + "database": {}, + "user": {}, + "password": {}, + "passfile": {}, + "connect_timeout": {}, + "sslmode": {}, + "sslkey": {}, + "sslcert": {}, + "sslrootcert": {}, + "sslnegotiation": {}, + "sslpassword": {}, + "sslsni": {}, + "krbspn": {}, + "krbsrvname": {}, + "target_session_attrs": {}, + "service": {}, + "servicefile": {}, + } + + // Adding kerberos configuration + if _, present := settings["krbsrvname"]; present { + config.KerberosSrvName = settings["krbsrvname"] + } + if _, present := settings["krbspn"]; present { + config.KerberosSpn = settings["krbspn"] + } + + for k, v := range settings { + if _, present := notRuntimeParams[k]; present { + continue + } + config.RuntimeParams[k] = v + } + + fallbacks := []*FallbackConfig{} + + hosts := strings.Split(settings["host"], ",") + ports := strings.Split(settings["port"], ",") + + for i, host := range hosts { + var portStr string + if i < len(ports) { + portStr = ports[i] + } else { + portStr = ports[0] + } + + port, err := parsePort(portStr) + if err != nil { + return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err} + } + + var tlsConfigs []*tls.Config + + // Ignore TLS settings if Unix domain socket like libpq + if network, _ := NetworkAddress(host, port); network == "unix" { + tlsConfigs = append(tlsConfigs, nil) + } else { + var err error + tlsConfigs, err = configTLS(settings, host, options) + if err != nil { + return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err} + } + } + + for _, tlsConfig := range tlsConfigs { + fallbacks = append(fallbacks, &FallbackConfig{ + Host: host, + Port: port, + TLSConfig: tlsConfig, + }) + } + } + + config.Host = fallbacks[0].Host + config.Port = fallbacks[0].Port + config.TLSConfig = fallbacks[0].TLSConfig + config.Fallbacks = fallbacks[1:] + config.SSLNegotiation = settings["sslnegotiation"] + + passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) + if err == nil { + if config.Password == "" { + host := config.Host + if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { + host = "localhost" + } + + config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User) + } + } + + switch tsa := settings["target_session_attrs"]; tsa { + case "read-write": + config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite + case "read-only": + config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly + case "primary": + config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary + case "standby": + config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby + case "prefer-standby": + config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby + case "any": + // do nothing + default: + return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} + } + + return config, nil +} + +func mergeSettings(settingSets ...map[string]string) map[string]string { + settings := make(map[string]string) + + for _, s2 := range settingSets { + for k, v := range s2 { + settings[k] = v + } + } + + return settings +} + +func parseEnvSettings() map[string]string { + settings := make(map[string]string) + + nameMap := map[string]string{ + "PGHOST": "host", + "PGPORT": "port", + "PGDATABASE": "database", + "PGUSER": "user", + "PGPASSWORD": "password", + "PGPASSFILE": "passfile", + "PGAPPNAME": "application_name", + "PGCONNECT_TIMEOUT": "connect_timeout", + "PGSSLMODE": "sslmode", + "PGSSLKEY": "sslkey", + "PGSSLCERT": "sslcert", + "PGSSLSNI": "sslsni", + "PGSSLROOTCERT": "sslrootcert", + "PGSSLPASSWORD": "sslpassword", + "PGSSLNEGOTIATION": "sslnegotiation", + "PGTARGETSESSIONATTRS": "target_session_attrs", + "PGSERVICE": "service", + "PGSERVICEFILE": "servicefile", + "PGTZ": "timezone", + "PGOPTIONS": "options", + } + + for envname, realname := range nameMap { + value := os.Getenv(envname) + if value != "" { + settings[realname] = value + } + } + + return settings +} + +func parseURLSettings(connString string) (map[string]string, error) { + settings := make(map[string]string) + + parsedURL, err := url.Parse(connString) + if err != nil { + if urlErr := new(url.Error); errors.As(err, &urlErr) { + return nil, urlErr.Err + } + return nil, err + } + + if parsedURL.User != nil { + settings["user"] = parsedURL.User.Username() + if password, present := parsedURL.User.Password(); present { + settings["password"] = password + } + } + + // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. + var hosts []string + var ports []string + for _, host := range strings.Split(parsedURL.Host, ",") { + if host == "" { + continue + } + if isIPOnly(host) { + hosts = append(hosts, strings.Trim(host, "[]")) + continue + } + h, p, err := net.SplitHostPort(host) + if err != nil { + return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err) + } + if h != "" { + hosts = append(hosts, h) + } + if p != "" { + ports = append(ports, p) + } + } + if len(hosts) > 0 { + settings["host"] = strings.Join(hosts, ",") + } + if len(ports) > 0 { + settings["port"] = strings.Join(ports, ",") + } + + database := strings.TrimLeft(parsedURL.Path, "/") + if database != "" { + settings["database"] = database + } + + nameMap := map[string]string{ + "dbname": "database", + } + + for k, v := range parsedURL.Query() { + if k2, present := nameMap[k]; present { + k = k2 + } + + settings[k] = v[0] + } + + return settings, nil +} + +func isIPOnly(host string) bool { + return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":") +} + +var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} + +func parseKeywordValueSettings(s string) (map[string]string, error) { + settings := make(map[string]string) + + nameMap := map[string]string{ + "dbname": "database", + } + + for len(s) > 0 { + var key, val string + eqIdx := strings.IndexRune(s, '=') + if eqIdx < 0 { + return nil, errors.New("invalid keyword/value") + } + + key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") + s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f") + if len(s) == 0 { + } else if s[0] != '\'' { + end := 0 + for ; end < len(s); end++ { + if asciiSpace[s[end]] == 1 { + break + } + if s[end] == '\\' { + end++ + if end == len(s) { + return nil, errors.New("invalid backslash") + } + } + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } else { // quoted string + s = s[1:] + end := 0 + for ; end < len(s); end++ { + if s[end] == '\'' { + break + } + if s[end] == '\\' { + end++ + } + } + if end == len(s) { + return nil, errors.New("unterminated quoted string in connection info string") + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } + + if k, ok := nameMap[key]; ok { + key = k + } + + if key == "" { + return nil, errors.New("invalid keyword/value") + } + + settings[key] = val + } + + return settings, nil +} + +func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) { + servicefile, err := pgservicefile.ReadServicefile(servicefilePath) + if err != nil { + return nil, fmt.Errorf("failed to read service file: %v", servicefilePath) + } + + service, err := servicefile.GetService(serviceName) + if err != nil { + return nil, fmt.Errorf("unable to find service: %v", serviceName) + } + + nameMap := map[string]string{ + "dbname": "database", + } + + settings := make(map[string]string, len(service.Settings)) + for k, v := range service.Settings { + if k2, present := nameMap[k]; present { + k = k2 + } + settings[k] = v + } + + return settings, nil +} + +// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is +// necessary to allow returning multiple TLS configs as sslmode "allow" and +// "prefer" allow fallback. +func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) { + host := thisHost + sslmode := settings["sslmode"] + sslrootcert := settings["sslrootcert"] + sslcert := settings["sslcert"] + sslkey := settings["sslkey"] + sslpassword := settings["sslpassword"] + sslsni := settings["sslsni"] + sslnegotiation := settings["sslnegotiation"] + + // Match libpq default behavior + if sslmode == "" { + sslmode = "prefer" + } + if sslsni == "" { + sslsni = "1" + } + + tlsConfig := &tls.Config{} + + if sslnegotiation == "direct" { + tlsConfig.NextProtos = []string{"postgresql"} + if sslmode == "prefer" { + sslmode = "require" + } + } + + if sslrootcert != "" { + var caCertPool *x509.CertPool + + if sslrootcert == "system" { + var err error + + caCertPool, err = x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("unable to load system certificate pool: %w", err) + } + + sslmode = "verify-full" + } else { + caCertPool = x509.NewCertPool() + + caPath := sslrootcert + caCert, err := os.ReadFile(caPath) + if err != nil { + return nil, fmt.Errorf("unable to read CA file: %w", err) + } + + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, errors.New("unable to add CA to cert pool") + } + } + + tlsConfig.RootCAs = caCertPool + tlsConfig.ClientCAs = caCertPool + } + + switch sslmode { + case "disable": + return []*tls.Config{nil}, nil + case "allow", "prefer": + tlsConfig.InsecureSkipVerify = true + case "require": + // According to PostgreSQL documentation, if a root CA file exists, + // the behavior of sslmode=require should be the same as that of verify-ca + // + // See https://www.postgresql.org/docs/current/libpq-ssl.html + if sslrootcert != "" { + goto nextCase + } + tlsConfig.InsecureSkipVerify = true + break + nextCase: + fallthrough + case "verify-ca": + // Don't perform the default certificate verification because it + // will verify the hostname. Instead, verify the server's + // certificate chain ourselves in VerifyPeerCertificate and + // ignore the server name. This emulates libpq's verify-ca + // behavior. + // + // See https://github.com/golang/go/issues/21971#issuecomment-332693931 + // and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate + // for more info. + tlsConfig.InsecureSkipVerify = true + tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error { + certs := make([]*x509.Certificate, len(certificates)) + for i, asn1Data := range certificates { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + return errors.New("failed to parse certificate from server: " + err.Error()) + } + certs[i] = cert + } + + // Leave DNSName empty to skip hostname verification. + opts := x509.VerifyOptions{ + Roots: tlsConfig.RootCAs, + Intermediates: x509.NewCertPool(), + } + // Skip the first cert because it's the leaf. All others + // are intermediates. + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + _, err := certs[0].Verify(opts) + return err + } + case "verify-full": + tlsConfig.ServerName = host + default: + return nil, errors.New("sslmode is invalid") + } + + if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { + return nil, errors.New(`both "sslcert" and "sslkey" are required`) + } + + if sslcert != "" && sslkey != "" { + buf, err := os.ReadFile(sslkey) + if err != nil { + return nil, fmt.Errorf("unable to read sslkey: %w", err) + } + block, _ := pem.Decode(buf) + if block == nil { + return nil, errors.New("failed to decode sslkey") + } + var pemKey []byte + var decryptedKey []byte + var decryptedError error + // If PEM is encrypted, attempt to decrypt using pass phrase + if x509.IsEncryptedPEMBlock(block) { + // Attempt decryption with pass phrase + // NOTE: only supports RSA (PKCS#1) + if sslpassword != "" { + decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) + } + // if sslpassword not provided or has decryption error when use it + // try to find sslpassword with callback function + if sslpassword == "" || decryptedError != nil { + if parseConfigOptions.GetSSLPassword != nil { + sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) + } + if sslpassword == "" { + return nil, fmt.Errorf("unable to find sslpassword") + } + } + decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) + // Should we also provide warning for PKCS#1 needed? + if decryptedError != nil { + return nil, fmt.Errorf("unable to decrypt key: %w", err) + } + + pemBytes := pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: decryptedKey, + } + pemKey = pem.EncodeToMemory(&pemBytes) + } else { + pemKey = pem.EncodeToMemory(block) + } + certfile, err := os.ReadFile(sslcert) + if err != nil { + return nil, fmt.Errorf("unable to read cert: %w", err) + } + cert, err := tls.X509KeyPair(certfile, pemKey) + if err != nil { + return nil, fmt.Errorf("unable to load cert: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + // Set Server Name Indication (SNI), if enabled by connection parameters. + // Per RFC 6066, do not set it if the host is a literal IP address (IPv4 + // or IPv6). + if sslsni == "1" && net.ParseIP(host) == nil { + tlsConfig.ServerName = host + } + + switch sslmode { + case "allow": + return []*tls.Config{nil, tlsConfig}, nil + case "prefer": + return []*tls.Config{tlsConfig, nil}, nil + case "require", "verify-ca", "verify-full": + return []*tls.Config{tlsConfig}, nil + default: + panic("BUG: bad sslmode should already have been caught") + } +} + +func parsePort(s string) (uint16, error) { + port, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return 0, err + } + if port < 1 || port > math.MaxUint16 { + return 0, errors.New("outside range") + } + return uint16(port), nil +} + +func makeDefaultDialer() *net.Dialer { + // rely on GOLANG KeepAlive settings + return &net.Dialer{} +} + +func makeDefaultResolver() *net.Resolver { + return net.DefaultResolver +} + +func parseConnectTimeoutSetting(s string) (time.Duration, error) { + timeout, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0, err + } + if timeout < 0 { + return 0, errors.New("negative timeout") + } + return time.Duration(timeout) * time.Second, nil +} + +func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc { + d := makeDefaultDialer() + d.Timeout = timeout + return d.DialContext +} + +// ValidateConnectTargetSessionAttrsReadWrite is a ValidateConnectFunc that implements libpq compatible +// target_session_attrs=read-write. +func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { + result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll() + if err != nil { + return err + } + + if string(result[0].Rows[0][0]) == "on" { + return errors.New("read only connection") + } + + return nil +} + +// ValidateConnectTargetSessionAttrsReadOnly is a ValidateConnectFunc that implements libpq compatible +// target_session_attrs=read-only. +func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error { + result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll() + if err != nil { + return err + } + + if string(result[0].Rows[0][0]) != "on" { + return errors.New("connection is not read only") + } + + return nil +} + +// ValidateConnectTargetSessionAttrsStandby is a ValidateConnectFunc that implements libpq compatible +// target_session_attrs=standby. +func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error { + result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll() + if err != nil { + return err + } + + if string(result[0].Rows[0][0]) != "t" { + return errors.New("server is not in hot standby mode") + } + + return nil +} + +// ValidateConnectTargetSessionAttrsPrimary is a ValidateConnectFunc that implements libpq compatible +// target_session_attrs=primary. +func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error { + result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll() + if err != nil { + return err + } + + if string(result[0].Rows[0][0]) == "t" { + return errors.New("server is in standby mode") + } + + return nil +} + +// ValidateConnectTargetSessionAttrsPreferStandby is a ValidateConnectFunc that implements libpq compatible +// target_session_attrs=prefer-standby. +func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error { + result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll() + if err != nil { + return err + } + + if string(result[0].Rows[0][0]) != "t" { + return &NotPreferredError{err: errors.New("server is not in hot standby mode")} + } + + return nil +} diff --git a/pgconn/config_test.go b/pgconn/config_test.go new file mode 100644 index 000000000..ed719ece9 --- /dev/null +++ b/pgconn/config_test.go @@ -0,0 +1,1141 @@ +package pgconn_test + +import ( + "context" + "crypto/tls" + "fmt" + "os" + "os/user" + "path/filepath" + "runtime" + "strconv" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func skipOnWindows(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("FIXME: skipping on Windows, investigate why this test fails in CI environment") + } +} + +func getDefaultPort(t *testing.T) uint16 { + if envPGPORT := os.Getenv("PGPORT"); envPGPORT != "" { + p, err := strconv.ParseUint(envPGPORT, 10, 16) + require.NoError(t, err) + return uint16(p) + } + return 5432 +} + +func getDefaultUser(t *testing.T) string { + if pguser := os.Getenv("PGUSER"); pguser != "" { + return pguser + } + + var osUserName string + osUser, err := user.Current() + if err == nil { + // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, + // but the libpq default is just the `user` portion, so we strip off the first part. + if runtime.GOOS == "windows" && strings.Contains(osUser.Username, "\\") { + osUserName = osUser.Username[strings.LastIndex(osUser.Username, "\\")+1:] + } else { + osUserName = osUser.Username + } + } + + return osUserName +} + +func TestParseConfig(t *testing.T) { + skipOnWindows(t) + t.Parallel() + + config, err := pgconn.ParseConfig("") + require.NoError(t, err) + defaultHost := config.Host + + defaultUser := getDefaultUser(t) + defaultPort := getDefaultPort(t) + + tests := []struct { + name string + connString string + config *pgconn.Config + }{ + // Test all sslmodes + { + name: "sslmode not set (prefer)", + connString: "postgres://jack:secret@localhost:5432/mydb", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "localhost", + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + { + Host: "localhost", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "sslmode disable", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode allow", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=allow", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + { + Host: "localhost", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "localhost", + }, + }, + }, + }, + }, + { + name: "sslmode prefer", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "localhost", + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + { + Host: "localhost", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "sslmode require", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=require", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "localhost", + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode verify-ca", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "localhost", + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode verify-full", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-full", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ServerName: "localhost"}, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url everything", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + ConnectTimeout: 5 * time.Second, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + { + name: "database url missing password", + connString: "postgres://jack@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url missing user and password", + connString: "postgres://localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: defaultUser, + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url missing port", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url unix domain socket host", + connString: "postgres:///foo?host=/tmp", + config: &pgconn.Config{ + User: defaultUser, + Host: "/tmp", + Port: defaultPort, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url unix domain socket host on windows", + connString: "postgres:///foo?host=C:\\tmp", + config: &pgconn.Config{ + User: defaultUser, + Host: "C:\\tmp", + Port: defaultPort, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url dbname", + connString: "postgres://localhost/?dbname=foo&sslmode=disable", + config: &pgconn.Config{ + User: defaultUser, + Host: "localhost", + Port: defaultPort, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url postgresql protocol", + connString: "postgresql://jack@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url IPv4 with port", + connString: "postgresql://jack@127.0.0.1:5433/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "127.0.0.1", + Port: 5433, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url IPv6 with port", + connString: "postgresql://jack@[2001:db8::1]:5433/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "2001:db8::1", + Port: 5433, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url IPv6 no port", + connString: "postgresql://jack@[2001:db8::1]/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "2001:db8::1", + Port: defaultPort, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "Key/value everything", + connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + ConnectTimeout: 5 * time.Second, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + { + name: "Key/value with escaped single quote", + connString: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack's", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "Key/value with escaped backslash", + connString: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "sooper\\secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "Key/value with single quoted values", + connString: "user='jack' host='localhost' dbname='mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: defaultPort, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "Key/value with single quoted value with escaped single quote", + connString: "user='jack\\'s' host='localhost' dbname='mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack's", + Host: "localhost", + Port: defaultPort, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "Key/value with empty single quoted value", + connString: "user='jack' password='' host='localhost' dbname='mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: defaultPort, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "Key/value with space between key and value", + connString: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: defaultPort, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "URL multiple hosts", + connString: "postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: defaultPort, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + { + Host: "bar", + Port: defaultPort, + TLSConfig: nil, + }, + { + Host: "baz", + Port: defaultPort, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "URL multiple hosts and ports", + connString: "postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + { + Host: "bar", + Port: 2, + TLSConfig: nil, + }, + { + Host: "baz", + Port: 3, + TLSConfig: nil, + }, + }, + }, + }, + // https://github.com/jackc/pgconn/issues/72 + { + name: "URL without host but with port still uses default host", + connString: "postgres://jack:secret@:1/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: defaultHost, + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "Key/value multiple hosts one port", + connString: "user=jack password=secret host=foo,bar,baz port=5432 dbname=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + { + Host: "bar", + Port: 5432, + TLSConfig: nil, + }, + { + Host: "baz", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "Key/value multiple hosts multiple ports", + connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 dbname=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + { + Host: "bar", + Port: 2, + TLSConfig: nil, + }, + { + Host: "baz", + Port: 3, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "multiple hosts and fallback tls", + connString: "user=jack password=secret host=foo,bar,baz dbname=mydb sslmode=prefer", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: defaultPort, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "foo", + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + { + Host: "foo", + Port: defaultPort, + TLSConfig: nil, + }, + { + Host: "bar", + Port: defaultPort, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "bar", + }, + }, + { + Host: "bar", + Port: defaultPort, + TLSConfig: nil, + }, + { + Host: "baz", + Port: defaultPort, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "baz", + }, + }, + { + Host: "baz", + Port: defaultPort, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "target_session_attrs read-write", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadWrite, + }, + }, + { + name: "target_session_attrs read-only", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-only", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadOnly, + }, + }, + { + name: "target_session_attrs primary", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=primary", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPrimary, + }, + }, + { + name: "target_session_attrs standby", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=standby", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsStandby, + }, + }, + { + name: "target_session_attrs prefer-standby", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=prefer-standby", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPreferStandby, + }, + }, + { + name: "target_session_attrs any", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=any", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "target_session_attrs not set (any)", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "SNI is set by default", + connString: "postgres://jack:secret@sni.test:5432/mydb?sslmode=require", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "sni.test", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "sni.test", + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "SNI is not set for IPv4", + connString: "postgres://jack:secret@1.1.1.1:5432/mydb?sslmode=require", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "1.1.1.1", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "SNI is not set for IPv6", + connString: "postgres://jack:secret@[::1]:5432/mydb?sslmode=require", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "::1", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "SNI is not set when disabled (URL-style)", + connString: "postgres://jack:secret@sni.test:5432/mydb?sslmode=require&sslsni=0", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "sni.test", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "SNI is not set when disabled (key/value style)", + connString: "user=jack password=secret host=sni.test dbname=mydb sslmode=require sslsni=0", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "sni.test", + Port: defaultPort, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + } + + for i, tt := range tests { + config, err := pgconn.ParseConfig(tt.connString) + if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + +// https://github.com/jackc/pgconn/issues/47 +func TestParseConfigKVWithTrailingEmptyEqualDoesNotPanic(t *testing.T) { + _, err := pgconn.ParseConfig("host= user= password= port= database=") + require.NoError(t, err) +} + +func TestParseConfigKVLeadingEqual(t *testing.T) { + _, err := pgconn.ParseConfig("= user=jack") + require.Error(t, err) +} + +// https://github.com/jackc/pgconn/issues/49 +func TestParseConfigKVTrailingBackslash(t *testing.T) { + _, err := pgconn.ParseConfig(`x=x\`) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid backslash") +} + +func TestConfigCopyReturnsEqualConfig(t *testing.T) { + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") +} + +func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) { + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5&sslmode=prefer" + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") + + copied.Port = uint16(5433) + copied.RuntimeParams["foo"] = "bar" + copied.Fallbacks[0].Port = uint16(5433) + + assert.Equal(t, uint16(5432), original.Port) + assert.Equal(t, "", original.RuntimeParams["foo"]) + assert.Equal(t, uint16(5432), original.Fallbacks[0].Port) +} + +func TestConfigCopyCanBeUsedToConnect(t *testing.T) { + connString := os.Getenv("PGX_TEST_DATABASE") + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assert.NotPanics(t, func() { + _, err = pgconn.ConnectConfig(context.Background(), copied) + }) + assert.NoError(t, err) +} + +func TestNetworkAddress(t *testing.T) { + tests := []struct { + name string + host string + wantNet string + }{ + { + name: "Default Unix socket address", + host: "/var/run/postgresql", + wantNet: "unix", + }, + { + name: "Windows Unix socket address (standard drive name)", + host: "C:\\tmp", + wantNet: "unix", + }, + { + name: "Windows Unix socket address (first drive name)", + host: "A:\\tmp", + wantNet: "unix", + }, + { + name: "Windows Unix socket address (last drive name)", + host: "Z:\\tmp", + wantNet: "unix", + }, + { + name: "Assume TCP for unknown formats", + host: "a/tmp", + wantNet: "tcp", + }, + { + name: "loopback interface", + host: "localhost", + wantNet: "tcp", + }, + { + name: "IP address", + host: "127.0.0.1", + wantNet: "tcp", + }, + } + for i, tt := range tests { + gotNet, _ := pgconn.NetworkAddress(tt.host, 5432) + + assert.Equalf(t, tt.wantNet, gotNet, "Test %d (%s)", i, tt.name) + } +} + +func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { + if !assert.NotNil(t, expected) { + return + } + if !assert.NotNil(t, actual) { + return + } + + assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) + assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) + assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) + assert.Equalf(t, expected.User, actual.User, "%s - User", testName) + assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) + assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName) + assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) + + // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) + assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) + + if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { + if expected.TLSConfig != nil { + assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) + } + } + + if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { + for i := range expected.Fallbacks { + assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) + assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) + + if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { + if expected.Fallbacks[i].TLSConfig != nil { + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) + } + } + } + } +} + +func TestParseConfigEnvLibpq(t *testing.T) { + var osUserName string + osUser, err := user.Current() + if err == nil { + // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, + // but the libpq default is just the `user` portion, so we strip off the first part. + if runtime.GOOS == "windows" && strings.Contains(osUser.Username, "\\") { + osUserName = osUser.Username[strings.LastIndex(osUser.Username, "\\")+1:] + } else { + osUserName = osUser.Username + } + } + + pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT", "PGSSLSNI", "PGTZ", "PGOPTIONS"} + + tests := []struct { + name string + envvars map[string]string + config *pgconn.Config + }{ + { + // not testing no environment at all as that would use default host and that can vary. + name: "PGHOST only", + envvars: map[string]string{"PGHOST": "123.123.123.123"}, + config: &pgconn.Config{ + User: osUserName, + Host: "123.123.123.123", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + { + Host: "123.123.123.123", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "All non-TLS environment", + envvars: map[string]string{ + "PGHOST": "123.123.123.123", + "PGPORT": "7777", + "PGDATABASE": "foo", + "PGUSER": "bar", + "PGPASSWORD": "baz", + "PGCONNECT_TIMEOUT": "10", + "PGSSLMODE": "disable", + "PGAPPNAME": "pgxtest", + "PGTZ": "America/New_York", + "PGOPTIONS": "-c search_path=myschema", + }, + config: &pgconn.Config{ + Host: "123.123.123.123", + Port: 7777, + Database: "foo", + User: "bar", + Password: "baz", + ConnectTimeout: 10 * time.Second, + TLSConfig: nil, + RuntimeParams: map[string]string{"application_name": "pgxtest", "timezone": "America/New_York", "options": "-c search_path=myschema"}, + }, + }, + { + name: "SNI can be disabled via environment variable", + envvars: map[string]string{ + "PGHOST": "test.foo", + "PGSSLMODE": "require", + "PGSSLSNI": "0", + }, + config: &pgconn.Config{ + User: osUserName, + Host: "test.foo", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + } + + for i, tt := range tests { + for _, env := range pgEnvvars { + t.Setenv(env, tt.envvars[env]) + } + + config, err := pgconn.ParseConfig("") + if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + +func TestParseConfigReadsPgPassfile(t *testing.T) { + skipOnWindows(t) + t.Parallel() + + tfName := filepath.Join(t.TempDir(), "config") + err := os.WriteFile(tfName, []byte("test1:5432:curlydb:curly:nyuknyuknyuk"), 0o600) + require.NoError(t, err) + + connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tfName) + expected := &pgconn.Config{ + User: "curly", + Password: "nyuknyuknyuk", + Host: "test1", + Port: 5432, + Database: "curlydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + } + + actual, err := pgconn.ParseConfig(connString) + assert.NoError(t, err) + + assertConfigsEqual(t, expected, actual, "passfile") +} + +func TestParseConfigReadsPgServiceFile(t *testing.T) { + skipOnWindows(t) + t.Parallel() + + tfName := filepath.Join(t.TempDir(), "config") + + err := os.WriteFile(tfName, []byte(` +[abc] +host=abc.example.com +port=9999 +dbname=abcdb +user=abcuser + +[def] +host = def.example.com +dbname = defdb +user = defuser +application_name = spaced string +`), 0o600) + require.NoError(t, err) + + defaultPort := getDefaultPort(t) + + tests := []struct { + name string + connString string + config *pgconn.Config + }{ + { + name: "abc", + connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tfName, "abc"), + config: &pgconn.Config{ + Host: "abc.example.com", + Database: "abcdb", + User: "abcuser", + Port: 9999, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "abc.example.com", + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + { + Host: "abc.example.com", + Port: 9999, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "def", + connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tfName, "def"), + config: &pgconn.Config{ + Host: "def.example.com", + Port: defaultPort, + Database: "defdb", + User: "defuser", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: "def.example.com", + }, + RuntimeParams: map[string]string{"application_name": "spaced string"}, + Fallbacks: []*pgconn.FallbackConfig{ + { + Host: "def.example.com", + Port: defaultPort, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "conn string has precedence", + connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tfName, "abc"), + config: &pgconn.Config{ + Host: "other.example.com", + Database: "abcdb", + User: "abcuser", + Port: 7777, + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + } + + for i, tt := range tests { + config, err := pgconn.ParseConfig(tt.connString) + if !assert.NoErrorf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} diff --git a/pgconn/ctxwatch/context_watcher.go b/pgconn/ctxwatch/context_watcher.go new file mode 100644 index 000000000..db8884eb8 --- /dev/null +++ b/pgconn/ctxwatch/context_watcher.go @@ -0,0 +1,80 @@ +package ctxwatch + +import ( + "context" + "sync" +) + +// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a +// time. +type ContextWatcher struct { + handler Handler + unwatchChan chan struct{} + + lock sync.Mutex + watchInProgress bool + onCancelWasCalled bool +} + +// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. +// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and +// onCancel called. +func NewContextWatcher(handler Handler) *ContextWatcher { + cw := &ContextWatcher{ + handler: handler, + unwatchChan: make(chan struct{}), + } + + return cw +} + +// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called. +func (cw *ContextWatcher) Watch(ctx context.Context) { + cw.lock.Lock() + defer cw.lock.Unlock() + + if cw.watchInProgress { + panic("Watch already in progress") + } + + cw.onCancelWasCalled = false + + if ctx.Done() != nil { + cw.watchInProgress = true + go func() { + select { + case <-ctx.Done(): + cw.handler.HandleCancel(ctx) + cw.onCancelWasCalled = true + <-cw.unwatchChan + case <-cw.unwatchChan: + } + }() + } else { + cw.watchInProgress = false + } +} + +// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was +// called then onUnwatchAfterCancel will also be called. +func (cw *ContextWatcher) Unwatch() { + cw.lock.Lock() + defer cw.lock.Unlock() + + if cw.watchInProgress { + cw.unwatchChan <- struct{}{} + if cw.onCancelWasCalled { + cw.handler.HandleUnwatchAfterCancel() + } + cw.watchInProgress = false + } +} + +type Handler interface { + // HandleCancel is called when the context that a ContextWatcher is currently watching is canceled. canceledCtx is the + // context that was canceled. + HandleCancel(canceledCtx context.Context) + + // HandleUnwatchAfterCancel is called when a ContextWatcher that called HandleCancel on this Handler is unwatched. + HandleUnwatchAfterCancel() +} diff --git a/pgconn/ctxwatch/context_watcher_test.go b/pgconn/ctxwatch/context_watcher_test.go new file mode 100644 index 000000000..a18e7339e --- /dev/null +++ b/pgconn/ctxwatch/context_watcher_test.go @@ -0,0 +1,185 @@ +package ctxwatch_test + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn/ctxwatch" + "github.com/stretchr/testify/require" +) + +type testHandler struct { + handleCancel func(context.Context) + handleUnwatchAfterCancel func() +} + +func (h *testHandler) HandleCancel(ctx context.Context) { + h.handleCancel(ctx) +} + +func (h *testHandler) HandleUnwatchAfterCancel() { + h.handleUnwatchAfterCancel() +} + +func TestContextWatcherContextCancelled(t *testing.T) { + canceledChan := make(chan struct{}) + cleanupCalled := false + cw := ctxwatch.NewContextWatcher(&testHandler{ + handleCancel: func(context.Context) { + canceledChan <- struct{}{} + }, handleUnwatchAfterCancel: func() { + cleanupCalled = true + }, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cancel() + + select { + case <-canceledChan: + case <-time.NewTimer(time.Second).C: + t.Fatal("Timed out waiting for cancel func to be called") + } + + cw.Unwatch() + + require.True(t, cleanupCalled, "Cleanup func was not called") +} + +func TestContextWatcherUnwatchedBeforeContextCancelled(t *testing.T) { + cw := ctxwatch.NewContextWatcher(&testHandler{ + handleCancel: func(context.Context) { + t.Error("cancel func should not have been called") + }, handleUnwatchAfterCancel: func() { + t.Error("cleanup func should not have been called") + }, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cw.Unwatch() + cancel() +} + +func TestContextWatcherMultipleWatchPanics(t *testing.T) { + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cw.Watch(ctx) + defer cw.Unwatch() + + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times") +} + +func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) { + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) + cw.Unwatch() // unwatch when not / never watching + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cw.Watch(ctx) + cw.Unwatch() + cw.Unwatch() // double unwatch +} + +func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) { + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + cw.Watch(ctx) + + go cw.Unwatch() + go cw.Unwatch() + + <-ctx.Done() +} + +func TestContextWatcherStress(t *testing.T) { + var cancelFuncCalls int64 + var cleanupFuncCalls int64 + + cw := ctxwatch.NewContextWatcher(&testHandler{ + handleCancel: func(context.Context) { + atomic.AddInt64(&cancelFuncCalls, 1) + }, handleUnwatchAfterCancel: func() { + atomic.AddInt64(&cleanupFuncCalls, 1) + }, + }) + + cycleCount := 100000 + + for i := 0; i < cycleCount; i++ { + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + if i%2 == 0 { + cancel() + } + + // Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix. + if i%333 == 0 { + // on Windows Sleep takes more time than expected so we try to get here less frequently to avoid + // the CI takes a long time + time.Sleep(time.Nanosecond) + } + + cw.Unwatch() + if i%2 == 1 { + cancel() + } + } + + actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls) + actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls) + + if actualCancelFuncCalls == 0 { + t.Fatal("actualCancelFuncCalls == 0") + } + + maxCancelFuncCalls := int64(cycleCount) / 2 + if actualCancelFuncCalls > maxCancelFuncCalls { + t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls) + } + + if actualCancelFuncCalls != actualCleanupFuncCalls { + t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls) + } +} + +func BenchmarkContextWatcherUncancellable(b *testing.B) { + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) + + for i := 0; i < b.N; i++ { + cw.Watch(context.Background()) + cw.Unwatch() + } +} + +func BenchmarkContextWatcherCancelled(b *testing.B) { + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) + + for i := 0; i < b.N; i++ { + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cancel() + cw.Unwatch() + } +} + +func BenchmarkContextWatcherCancellable(b *testing.B) { + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for i := 0; i < b.N; i++ { + cw.Watch(ctx) + cw.Unwatch() + } +} diff --git a/pgconn/defaults.go b/pgconn/defaults.go new file mode 100644 index 000000000..1dd514ff4 --- /dev/null +++ b/pgconn/defaults.go @@ -0,0 +1,63 @@ +//go:build !windows +// +build !windows + +package pgconn + +import ( + "os" + "os/user" + "path/filepath" +) + +func defaultSettings() map[string]string { + settings := make(map[string]string) + + settings["host"] = defaultHost() + settings["port"] = "5432" + + // Default to the OS user name. Purposely ignoring err getting user name from + // OS. The client application will simply have to specify the user in that + // case (which they typically will be doing anyway). + user, err := user.Current() + if err == nil { + settings["user"] = user.Username + settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") + settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") + sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") + sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") + if _, err := os.Stat(sslcert); err == nil { + if _, err := os.Stat(sslkey); err == nil { + // Both the cert and key must be present to use them, or do not use either + settings["sslcert"] = sslcert + settings["sslkey"] = sslkey + } + } + sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt") + if _, err := os.Stat(sslrootcert); err == nil { + settings["sslrootcert"] = sslrootcert + } + } + + settings["target_session_attrs"] = "any" + + return settings +} + +// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost +// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it +// checks the existence of common locations. +func defaultHost() string { + candidatePaths := []string{ + "/var/run/postgresql", // Debian + "/private/tmp", // OSX - homebrew + "/tmp", // standard PostgreSQL + } + + for _, path := range candidatePaths { + if _, err := os.Stat(path); err == nil { + return path + } + } + + return "localhost" +} diff --git a/pgconn/defaults_windows.go b/pgconn/defaults_windows.go new file mode 100644 index 000000000..33b4a1ff8 --- /dev/null +++ b/pgconn/defaults_windows.go @@ -0,0 +1,57 @@ +package pgconn + +import ( + "os" + "os/user" + "path/filepath" + "strings" +) + +func defaultSettings() map[string]string { + settings := make(map[string]string) + + settings["host"] = defaultHost() + settings["port"] = "5432" + + // Default to the OS user name. Purposely ignoring err getting user name from + // OS. The client application will simply have to specify the user in that + // case (which they typically will be doing anyway). + user, err := user.Current() + appData := os.Getenv("APPDATA") + if err == nil { + // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, + // but the libpq default is just the `user` portion, so we strip off the first part. + username := user.Username + if strings.Contains(username, "\\") { + username = username[strings.LastIndex(username, "\\")+1:] + } + + settings["user"] = username + settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf") + settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") + sslcert := filepath.Join(appData, "postgresql", "postgresql.crt") + sslkey := filepath.Join(appData, "postgresql", "postgresql.key") + if _, err := os.Stat(sslcert); err == nil { + if _, err := os.Stat(sslkey); err == nil { + // Both the cert and key must be present to use them, or do not use either + settings["sslcert"] = sslcert + settings["sslkey"] = sslkey + } + } + sslrootcert := filepath.Join(appData, "postgresql", "root.crt") + if _, err := os.Stat(sslrootcert); err == nil { + settings["sslrootcert"] = sslrootcert + } + } + + settings["target_session_attrs"] = "any" + + return settings +} + +// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost +// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it +// checks the existence of common locations. +func defaultHost() string { + return "localhost" +} diff --git a/pgconn/doc.go b/pgconn/doc.go new file mode 100644 index 000000000..701375019 --- /dev/null +++ b/pgconn/doc.go @@ -0,0 +1,38 @@ +// Package pgconn is a low-level PostgreSQL database driver. +/* +pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at +nearly the same level is the C library libpq. + +Establishing a Connection + +Use Connect to establish a connection. It accepts a connection string in URL or keyword/value format and will read the +environment for libpq style environment variables. + +Executing a Query + +ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method +reads all rows into memory. + +Executing Multiple Queries in a Single Round Trip + +Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query +result. The ReadAll method reads all query results into memory. + +Pipeline Mode + +Pipeline mode allows sending queries without having read the results of previously sent queries. It allows control of +exactly how many and when network round trips occur. + +Context Support + +All potentially blocking operations take a context.Context. The default behavior when a context is canceled is for the +method to immediately return. In most circumstances, this will also close the underlying connection. This behavior can +be customized by using BuildContextWatcherHandler on the Config to create a ctxwatch.Handler with different behavior. +This can be especially useful when queries that are frequently canceled and the overhead of creating new connections is +a problem. DeadlineContextWatcherHandler and CancelRequestContextWatcherHandler can be used to introduce a delay before +interrupting the query in such a way as to close the connection. + +The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the +client to abort. +*/ +package pgconn diff --git a/pgconn/errors.go b/pgconn/errors.go new file mode 100644 index 000000000..d968d3f03 --- /dev/null +++ b/pgconn/errors.go @@ -0,0 +1,256 @@ +package pgconn + +import ( + "context" + "errors" + "fmt" + "net" + "net/url" + "regexp" + "strings" +) + +// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server. +func SafeToRetry(err error) bool { + var retryableErr interface{ SafeToRetry() bool } + if errors.As(err, &retryableErr) { + return retryableErr.SafeToRetry() + } + return false +} + +// Timeout checks if err was caused by a timeout. To be specific, it is true if err was caused within pgconn by a +// context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. +func Timeout(err error) bool { + var timeoutErr *errTimeout + return errors.As(err, &timeoutErr) +} + +// PgError represents an error reported by the PostgreSQL server. See +// http://www.postgresql.org/docs/current/static/protocol-error-fields.html for +// detailed field description. +type PgError struct { + Severity string + SeverityUnlocalized string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string +} + +func (pe *PgError) Error() string { + return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" +} + +// SQLState returns the SQLState of the error. +func (pe *PgError) SQLState() string { + return pe.Code +} + +// ConnectError is the error returned when a connection attempt fails. +type ConnectError struct { + Config *Config // The configuration that was used in the connection attempt. + err error +} + +func (e *ConnectError) Error() string { + prefix := fmt.Sprintf("failed to connect to `user=%s database=%s`:", e.Config.User, e.Config.Database) + details := e.err.Error() + if strings.Contains(details, "\n") { + return prefix + "\n\t" + strings.ReplaceAll(details, "\n", "\n\t") + } else { + return prefix + " " + details + } +} + +func (e *ConnectError) Unwrap() error { + return e.err +} + +type perDialConnectError struct { + address string + originalHostname string + err error +} + +func (e *perDialConnectError) Error() string { + return fmt.Sprintf("%s (%s): %s", e.address, e.originalHostname, e.err.Error()) +} + +func (e *perDialConnectError) Unwrap() error { + return e.err +} + +type connLockError struct { + status string +} + +func (e *connLockError) SafeToRetry() bool { + return true // a lock failure by definition happens before the connection is used. +} + +func (e *connLockError) Error() string { + return e.status +} + +// ParseConfigError is the error returned when a connection string cannot be parsed. +type ParseConfigError struct { + ConnString string // The connection string that could not be parsed. + msg string + err error +} + +func NewParseConfigError(conn, msg string, err error) error { + return &ParseConfigError{ + ConnString: conn, + msg: msg, + err: err, + } +} + +func (e *ParseConfigError) Error() string { + // Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better only + // return a static string. That would ensure that the error message cannot leak a password. The ConnString field would + // allow access to the original string if desired and Unwrap would allow access to the underlying error. + connString := redactPW(e.ConnString) + if e.err == nil { + return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg) + } + return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error()) +} + +func (e *ParseConfigError) Unwrap() error { + return e.err +} + +func normalizeTimeoutError(ctx context.Context, err error) error { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + if ctx.Err() == context.Canceled { + // Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error. + return context.Canceled + } else if ctx.Err() == context.DeadlineExceeded { + return &errTimeout{err: ctx.Err()} + } else { + return &errTimeout{err: netErr} + } + } + return err +} + +type pgconnError struct { + msg string + err error + safeToRetry bool +} + +func (e *pgconnError) Error() string { + if e.msg == "" { + return e.err.Error() + } + if e.err == nil { + return e.msg + } + return fmt.Sprintf("%s: %s", e.msg, e.err.Error()) +} + +func (e *pgconnError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *pgconnError) Unwrap() error { + return e.err +} + +// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is +// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true. +type errTimeout struct { + err error +} + +func (e *errTimeout) Error() string { + return fmt.Sprintf("timeout: %s", e.err.Error()) +} + +func (e *errTimeout) SafeToRetry() bool { + return SafeToRetry(e.err) +} + +func (e *errTimeout) Unwrap() error { + return e.err +} + +type contextAlreadyDoneError struct { + err error +} + +func (e *contextAlreadyDoneError) Error() string { + return fmt.Sprintf("context already done: %s", e.err.Error()) +} + +func (e *contextAlreadyDoneError) SafeToRetry() bool { + return true +} + +func (e *contextAlreadyDoneError) Unwrap() error { + return e.err +} + +// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`. +func newContextAlreadyDoneError(ctx context.Context) (err error) { + return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}} +} + +func redactPW(connString string) string { + if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { + if u, err := url.Parse(connString); err == nil { + return redactURL(u) + } + } + quotedKV := regexp.MustCompile(`password='[^']*'`) + connString = quotedKV.ReplaceAllLiteralString(connString, "password=xxxxx") + plainKV := regexp.MustCompile(`password=[^ ]*`) + connString = plainKV.ReplaceAllLiteralString(connString, "password=xxxxx") + brokenURL := regexp.MustCompile(`:[^:@]+?@`) + connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@") + return connString +} + +func redactURL(u *url.URL) string { + if u == nil { + return "" + } + if _, pwSet := u.User.Password(); pwSet { + u.User = url.UserPassword(u.User.Username(), "xxxxx") + } + return u.String() +} + +type NotPreferredError struct { + err error + safeToRetry bool +} + +func (e *NotPreferredError) Error() string { + return fmt.Sprintf("standby server not found: %s", e.err.Error()) +} + +func (e *NotPreferredError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *NotPreferredError) Unwrap() error { + return e.err +} diff --git a/pgconn/errors_test.go b/pgconn/errors_test.go new file mode 100644 index 000000000..bbbfeb3c3 --- /dev/null +++ b/pgconn/errors_test.go @@ -0,0 +1,54 @@ +package pgconn_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/assert" +) + +func TestConfigError(t *testing.T) { + tests := []struct { + name string + err error + expectedMsg string + }{ + { + name: "url with password", + err: pgconn.NewParseConfigError("postgresql://foo:password@host", "msg", nil), + expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg", + }, + { + name: "keyword/value with password unquoted", + err: pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil), + expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg", + }, + { + name: "keyword/value with password quoted", + err: pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil), + expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg", + }, + { + name: "weird url", + err: pgconn.NewParseConfigError("postgresql://foo::password@host:1:", "msg", nil), + expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg", + }, + { + name: "weird url with slash in password", + err: pgconn.NewParseConfigError("postgres://user:pass/word@host:5432/db_name", "msg", nil), + expectedMsg: "cannot parse `postgres://user:xxxxxx@host:5432/db_name`: msg", + }, + { + name: "url without password", + err: pgconn.NewParseConfigError("postgresql://other@host/db", "msg", nil), + expectedMsg: "cannot parse `postgresql://other@host/db`: msg", + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.EqualError(t, tt.err, tt.expectedMsg) + }) + } +} diff --git a/pgconn/export_test.go b/pgconn/export_test.go new file mode 100644 index 000000000..9c0e02e74 --- /dev/null +++ b/pgconn/export_test.go @@ -0,0 +1,3 @@ +// File export_test exports some methods for better testing. + +package pgconn diff --git a/pgconn/helper_test.go b/pgconn/helper_test.go new file mode 100644 index 000000000..1ab7579ad --- /dev/null +++ b/pgconn/helper_test.go @@ -0,0 +1,36 @@ +package pgconn_test + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func closeConn(t testing.TB, conn *pgconn.PgConn) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + require.NoError(t, conn.Close(ctx)) + select { + case <-conn.CleanupDone(): + case <-time.After(30 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +// Do a simple query to ensure the connection is still usable +func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read() + cancel() + + require.Nil(t, result.Err) + assert.Equal(t, 3, len(result.Rows)) + assert.Equal(t, "1", string(result.Rows[0][0])) + assert.Equal(t, "2", string(result.Rows[1][0])) + assert.Equal(t, "3", string(result.Rows[2][0])) +} diff --git a/pgconn/internal/bgreader/bgreader.go b/pgconn/internal/bgreader/bgreader.go new file mode 100644 index 000000000..e65c2c2bf --- /dev/null +++ b/pgconn/internal/bgreader/bgreader.go @@ -0,0 +1,139 @@ +// Package bgreader provides a io.Reader that can optionally buffer reads in the background. +package bgreader + +import ( + "io" + "sync" + + "github.com/jackc/pgx/v5/internal/iobufpool" +) + +const ( + StatusStopped = iota + StatusRunning + StatusStopping +) + +// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use. +type BGReader struct { + r io.Reader + + cond *sync.Cond + status int32 + readResults []readResult +} + +type readResult struct { + buf *[]byte + err error +} + +// Start starts the backgrounder reader. If the background reader is already running this is a no-op. The background +// reader will stop automatically when the underlying reader returns an error. +func (r *BGReader) Start() { + r.cond.L.Lock() + defer r.cond.L.Unlock() + + switch r.status { + case StatusStopped: + r.status = StatusRunning + go r.bgRead() + case StatusRunning: + // no-op + case StatusStopping: + r.status = StatusRunning + } +} + +// Stop tells the background reader to stop after the in progress Read returns. It is safe to call Stop when the +// background reader is not running. +func (r *BGReader) Stop() { + r.cond.L.Lock() + defer r.cond.L.Unlock() + + switch r.status { + case StatusStopped: + // no-op + case StatusRunning: + r.status = StatusStopping + case StatusStopping: + // no-op + } +} + +// Status returns the current status of the background reader. +func (r *BGReader) Status() int32 { + r.cond.L.Lock() + defer r.cond.L.Unlock() + return r.status +} + +func (r *BGReader) bgRead() { + keepReading := true + for keepReading { + buf := iobufpool.Get(8192) + n, err := r.r.Read(*buf) + *buf = (*buf)[:n] + + r.cond.L.Lock() + r.readResults = append(r.readResults, readResult{buf: buf, err: err}) + if r.status == StatusStopping || err != nil { + r.status = StatusStopped + keepReading = false + } + r.cond.L.Unlock() + r.cond.Broadcast() + } +} + +// Read implements the io.Reader interface. +func (r *BGReader) Read(p []byte) (int, error) { + r.cond.L.Lock() + defer r.cond.L.Unlock() + + if len(r.readResults) > 0 { + return r.readFromReadResults(p) + } + + // There are no unread background read results and the background reader is stopped. + if r.status == StatusStopped { + return r.r.Read(p) + } + + // Wait for results from the background reader + for len(r.readResults) == 0 { + r.cond.Wait() + } + return r.readFromReadResults(p) +} + +// readBackgroundResults reads a result previously read by the background reader. r.cond.L must be held. +func (r *BGReader) readFromReadResults(p []byte) (int, error) { + buf := r.readResults[0].buf + var err error + + n := copy(p, *buf) + if n == len(*buf) { + err = r.readResults[0].err + iobufpool.Put(buf) + if len(r.readResults) == 1 { + r.readResults = nil + } else { + r.readResults = r.readResults[1:] + } + } else { + *buf = (*buf)[n:] + r.readResults[0].buf = buf + } + + return n, err +} + +func New(r io.Reader) *BGReader { + return &BGReader{ + r: r, + cond: &sync.Cond{ + L: &sync.Mutex{}, + }, + } +} diff --git a/pgconn/internal/bgreader/bgreader_test.go b/pgconn/internal/bgreader/bgreader_test.go new file mode 100644 index 000000000..f787e2f1d --- /dev/null +++ b/pgconn/internal/bgreader/bgreader_test.go @@ -0,0 +1,140 @@ +package bgreader_test + +import ( + "bytes" + "errors" + "io" + "math/rand" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn/internal/bgreader" + "github.com/stretchr/testify/require" +) + +func TestBGReaderReadWhenStopped(t *testing.T) { + r := bytes.NewReader([]byte("foo bar baz")) + bgr := bgreader.New(r) + buf, err := io.ReadAll(bgr) + require.NoError(t, err) + require.Equal(t, []byte("foo bar baz"), buf) +} + +func TestBGReaderReadWhenStarted(t *testing.T) { + r := bytes.NewReader([]byte("foo bar baz")) + bgr := bgreader.New(r) + bgr.Start() + buf, err := io.ReadAll(bgr) + require.NoError(t, err) + require.Equal(t, []byte("foo bar baz"), buf) +} + +type mockReadFunc func(p []byte) (int, error) + +type mockReader struct { + readFuncs []mockReadFunc +} + +func (r *mockReader) Read(p []byte) (int, error) { + if len(r.readFuncs) == 0 { + return 0, io.EOF + } + + fn := r.readFuncs[0] + r.readFuncs = r.readFuncs[1:] + + return fn(p) +} + +func TestBGReaderReadWaitsForBackgroundRead(t *testing.T) { + rr := &mockReader{ + readFuncs: []mockReadFunc{ + func(p []byte) (int, error) { time.Sleep(1 * time.Second); return copy(p, []byte("foo")), nil }, + func(p []byte) (int, error) { return copy(p, []byte("bar")), nil }, + func(p []byte) (int, error) { return copy(p, []byte("baz")), nil }, + }, + } + bgr := bgreader.New(rr) + bgr.Start() + buf := make([]byte, 3) + n, err := bgr.Read(buf) + require.NoError(t, err) + require.EqualValues(t, 3, n) + require.Equal(t, []byte("foo"), buf) +} + +func TestBGReaderErrorWhenStarted(t *testing.T) { + rr := &mockReader{ + readFuncs: []mockReadFunc{ + func(p []byte) (int, error) { return copy(p, []byte("foo")), nil }, + func(p []byte) (int, error) { return copy(p, []byte("bar")), nil }, + func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") }, + }, + } + + bgr := bgreader.New(rr) + bgr.Start() + buf, err := io.ReadAll(bgr) + require.Equal(t, []byte("foobarbaz"), buf) + require.EqualError(t, err, "oops") +} + +func TestBGReaderErrorWhenStopped(t *testing.T) { + rr := &mockReader{ + readFuncs: []mockReadFunc{ + func(p []byte) (int, error) { return copy(p, []byte("foo")), nil }, + func(p []byte) (int, error) { return copy(p, []byte("bar")), nil }, + func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") }, + }, + } + + bgr := bgreader.New(rr) + buf, err := io.ReadAll(bgr) + require.Equal(t, []byte("foobarbaz"), buf) + require.EqualError(t, err, "oops") +} + +type numberReader struct { + v uint8 + rng *rand.Rand +} + +func (nr *numberReader) Read(p []byte) (int, error) { + n := nr.rng.Intn(len(p)) + for i := 0; i < n; i++ { + p[i] = nr.v + nr.v++ + } + + return n, nil +} + +// TestBGReaderStress stress tests BGReader by reading a lot of bytes in random sizes while randomly starting and +// stopping the background worker from other goroutines. +func TestBGReaderStress(t *testing.T) { + nr := &numberReader{rng: rand.New(rand.NewSource(0))} + bgr := bgreader.New(nr) + + bytesRead := 0 + var expected uint8 + buf := make([]byte, 10_000) + rng := rand.New(rand.NewSource(0)) + + for bytesRead < 1_000_000 { + randomNumber := rng.Intn(100) + switch { + case randomNumber < 10: + go bgr.Start() + case randomNumber < 20: + go bgr.Stop() + default: + n, err := bgr.Read(buf) + require.NoError(t, err) + for i := 0; i < n; i++ { + require.Equal(t, expected, buf[i]) + expected++ + } + bytesRead += n + } + } +} diff --git a/pgconn/krb5.go b/pgconn/krb5.go new file mode 100644 index 000000000..efb0d61b8 --- /dev/null +++ b/pgconn/krb5.go @@ -0,0 +1,100 @@ +package pgconn + +import ( + "errors" + "fmt" + + "github.com/jackc/pgx/v5/pgproto3" +) + +// NewGSSFunc creates a GSS authentication provider, for use with +// RegisterGSSProvider. +type NewGSSFunc func() (GSS, error) + +var newGSS NewGSSFunc + +// RegisterGSSProvider registers a GSS authentication provider. For example, if +// you need to use Kerberos to authenticate with your server, add this to your +// main package: +// +// import "github.com/otan/gopgkrb5" +// +// func init() { +// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() }) +// } +func RegisterGSSProvider(newGSSArg NewGSSFunc) { + newGSS = newGSSArg +} + +// GSS provides GSSAPI authentication (e.g., Kerberos). +type GSS interface { + GetInitToken(host, service string) ([]byte, error) + GetInitTokenFromSPN(spn string) ([]byte, error) + Continue(inToken []byte) (done bool, outToken []byte, err error) +} + +func (c *PgConn) gssAuth() error { + if newGSS == nil { + return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5") + } + cli, err := newGSS() + if err != nil { + return err + } + + var nextData []byte + if c.config.KerberosSpn != "" { + // Use the supplied SPN if provided. + nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn) + } else { + // Allow the kerberos service name to be overridden + service := "postgres" + if c.config.KerberosSrvName != "" { + service = c.config.KerberosSrvName + } + nextData, err = cli.GetInitToken(c.config.Host, service) + } + if err != nil { + return err + } + + for { + gssResponse := &pgproto3.GSSResponse{ + Data: nextData, + } + c.frontend.Send(gssResponse) + err = c.flushWithPotentialWriteReadDeadlock() + if err != nil { + return err + } + resp, err := c.rxGSSContinue() + if err != nil { + return err + } + var done bool + done, nextData, err = cli.Continue(resp.Data) + if err != nil { + return err + } + if done { + break + } + } + return nil +} + +func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + + switch m := msg.(type) { + case *pgproto3.AuthenticationGSSContinue: + return m, nil + case *pgproto3.ErrorResponse: + return nil, ErrorResponseToPgError(m) + } + + return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg) +} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go new file mode 100644 index 000000000..321656f98 --- /dev/null +++ b/pgconn/pgconn.go @@ -0,0 +1,2509 @@ +package pgconn + +import ( + "container/list" + "context" + "crypto/md5" + "crypto/tls" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "math" + "net" + "strconv" + "strings" + "sync" + "time" + + "github.com/jackc/pgx/v5/internal/iobufpool" + "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgconn/ctxwatch" + "github.com/jackc/pgx/v5/pgconn/internal/bgreader" + "github.com/jackc/pgx/v5/pgproto3" +) + +const ( + connStatusUninitialized = iota + connStatusConnecting + connStatusClosed + connStatusIdle + connStatusBusy +) + +// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from +// LISTEN/NOTIFY notification. +type Notice PgError + +// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system +type Notification struct { + PID uint32 // backend pid that sent the notification + Channel string // channel from which notification was received + Payload string +} + +// DialFunc is a function that can be used to connect to a PostgreSQL server. +type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) + +// LookupFunc is a function that can be used to lookup IPs addrs from host. Optionally an ip:port combination can be +// returned in order to override the connection string's port. +type LookupFunc func(ctx context.Context, host string) (addrs []string, err error) + +// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. +type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend + +// PgErrorHandler is a function that handles errors returned from Postgres. This function must return true to keep +// the connection open. Returning false will cause the connection to be closed immediately. You should return +// false on any FATAL-severity errors. This will not receive network errors. The *PgConn is provided so the handler is +// aware of the origin of the error, but it must not invoke any query method. +type PgErrorHandler func(*PgConn, *PgError) bool + +// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at +// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin +// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY +// notification. +type NoticeHandler func(*PgConn, *Notice) + +// NotificationHandler is a function that can handle notifications received from the PostgreSQL server. Notifications +// can be received at any time, usually during handling of a query response. The *PgConn is provided so the handler is +// aware of the origin of the notice, but it must not invoke any query method. Be aware that this is distinct from a +// notice event. +type NotificationHandler func(*PgConn, *Notification) + +// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. +type PgConn struct { + conn net.Conn + pid uint32 // backend pid + secretKey uint32 // key to use to send a cancel query message to the server + parameterStatuses map[string]string // parameters that have been reported by the server + txStatus byte + frontend *pgproto3.Frontend + bgReader *bgreader.BGReader + slowWriteTimer *time.Timer + bgReaderStarted chan struct{} + + customData map[string]any + + config *Config + + status byte // One of connStatus* constants + + bufferingReceive bool + bufferingReceiveMux sync.Mutex + bufferingReceiveMsg pgproto3.BackendMessage + bufferingReceiveErr error + + peekedMsg pgproto3.BackendMessage + + // Reusable / preallocated resources + resultReader ResultReader + multiResultReader MultiResultReader + pipeline Pipeline + contextWatcher *ctxwatch.ContextWatcher + fieldDescriptions [16]FieldDescription + + cleanupDone chan struct{} +} + +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or keyword/value +// format) to provide configuration. See documentation for [ParseConfig] for details. ctx can be used to cancel a +// connect attempt. +func Connect(ctx context.Context, connString string) (*PgConn, error) { + config, err := ParseConfig(connString) + if err != nil { + return nil, err + } + + return ConnectConfig(ctx, config) +} + +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or keyword/value +// format) and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig] for details. +// ctx can be used to cancel a connect attempt. +func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) { + config, err := ParseConfigWithOptions(connString, parseConfigOptions) + if err != nil { + return nil, err + } + + return ConnectConfig(ctx, config) +} + +// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with +// [ParseConfig]. ctx can be used to cancel a connect attempt. +// +// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An +// authentication error will terminate the chain of attempts (like libpq: +// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. +func ConnectConfig(ctx context.Context, config *Config) (*PgConn, error) { + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from + // zero values. + if !config.createdByParseConfig { + panic("config must be created by ParseConfig") + } + + var allErrors []error + + connectConfigs, errs := buildConnectOneConfigs(ctx, config) + if len(errs) > 0 { + allErrors = append(allErrors, errs...) + } + + if len(connectConfigs) == 0 { + return nil, &ConnectError{Config: config, err: fmt.Errorf("hostname resolving error: %w", errors.Join(allErrors...))} + } + + pgConn, errs := connectPreferred(ctx, config, connectConfigs) + if len(errs) > 0 { + allErrors = append(allErrors, errs...) + return nil, &ConnectError{Config: config, err: errors.Join(allErrors...)} + } + + if config.AfterConnect != nil { + err := config.AfterConnect(ctx, pgConn) + if err != nil { + pgConn.conn.Close() + return nil, &ConnectError{Config: config, err: fmt.Errorf("AfterConnect error: %w", err)} + } + } + + return pgConn, nil +} + +// buildConnectOneConfigs resolves hostnames and builds a list of connectOneConfigs to try connecting to. It returns a +// slice of successfully resolved connectOneConfigs and a slice of errors. It is possible for both slices to contain +// values if some hosts were successfully resolved and others were not. +func buildConnectOneConfigs(ctx context.Context, config *Config) ([]*connectOneConfig, []error) { + // Simplify usage by treating primary config and fallbacks the same. + fallbackConfigs := []*FallbackConfig{ + { + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + } + fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) + + var configs []*connectOneConfig + + var allErrors []error + + for _, fb := range fallbackConfigs { + // skip resolve for unix sockets + if isAbsolutePath(fb.Host) { + network, address := NetworkAddress(fb.Host, fb.Port) + configs = append(configs, &connectOneConfig{ + network: network, + address: address, + originalHostname: fb.Host, + tlsConfig: fb.TLSConfig, + }) + + continue + } + + ips, err := config.LookupFunc(ctx, fb.Host) + if err != nil { + allErrors = append(allErrors, err) + continue + } + + for _, ip := range ips { + splitIP, splitPort, err := net.SplitHostPort(ip) + if err == nil { + port, err := strconv.ParseUint(splitPort, 10, 16) + if err != nil { + return nil, []error{fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err)} + } + network, address := NetworkAddress(splitIP, uint16(port)) + configs = append(configs, &connectOneConfig{ + network: network, + address: address, + originalHostname: fb.Host, + tlsConfig: fb.TLSConfig, + }) + } else { + network, address := NetworkAddress(ip, fb.Port) + configs = append(configs, &connectOneConfig{ + network: network, + address: address, + originalHostname: fb.Host, + tlsConfig: fb.TLSConfig, + }) + } + } + } + + return configs, allErrors +} + +// connectPreferred attempts to connect to the preferred host from connectOneConfigs. The connections are attempted in +// order. If a connection is successful it is returned. If no connection is successful then all errors are returned. If +// a connection attempt returns a [NotPreferredError], then that host will be used if no other hosts are successful. +func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []*connectOneConfig) (*PgConn, []error) { + octx := ctx + var allErrors []error + + var fallbackConnectOneConfig *connectOneConfig + for i, c := range connectOneConfigs { + // ConnectTimeout restricts the whole connection process. + if config.ConnectTimeout != 0 { + // create new context first time or when previous host was different + if i == 0 || (connectOneConfigs[i].address != connectOneConfigs[i-1].address) { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) + defer cancel() + } + } else { + ctx = octx + } + + pgConn, err := connectOne(ctx, config, c, false) + if pgConn != nil { + return pgConn, nil + } + + allErrors = append(allErrors, err) + + var pgErr *PgError + if errors.As(err, &pgErr) { + // pgx will try next host even if libpq does not in certain cases (see #2246) + // consider change for the next major version + + const ERRCODE_INVALID_PASSWORD = "28P01" + const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist + const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege + + // auth failed due to invalid password, db does not exist or user has no permission + if pgErr.Code == ERRCODE_INVALID_PASSWORD || + pgErr.Code == ERRCODE_INVALID_CATALOG_NAME || + pgErr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { + return nil, allErrors + } + } + + var npErr *NotPreferredError + if errors.As(err, &npErr) { + fallbackConnectOneConfig = c + } + } + + if fallbackConnectOneConfig != nil { + pgConn, err := connectOne(ctx, config, fallbackConnectOneConfig, true) + if err == nil { + return pgConn, nil + } + allErrors = append(allErrors, err) + } + + return nil, allErrors +} + +// connectOne makes one connection attempt to a single host. +func connectOne(ctx context.Context, config *Config, connectConfig *connectOneConfig, + ignoreNotPreferredErr bool, +) (*PgConn, error) { + pgConn := new(PgConn) + pgConn.config = config + pgConn.cleanupDone = make(chan struct{}) + pgConn.customData = make(map[string]any) + + var err error + + newPerDialConnectError := func(msg string, err error) *perDialConnectError { + err = normalizeTimeoutError(ctx, err) + e := &perDialConnectError{address: connectConfig.address, originalHostname: connectConfig.originalHostname, err: fmt.Errorf("%s: %w", msg, err)} + return e + } + + pgConn.conn, err = config.DialFunc(ctx, connectConfig.network, connectConfig.address) + if err != nil { + return nil, newPerDialConnectError("dial error", err) + } + + if connectConfig.tlsConfig != nil { + pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn}) + pgConn.contextWatcher.Watch(ctx) + var ( + tlsConn net.Conn + err error + ) + if config.SSLNegotiation == "direct" { + tlsConn = tls.Client(pgConn.conn, connectConfig.tlsConfig) + } else { + tlsConn, err = startTLS(pgConn.conn, connectConfig.tlsConfig) + } + pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. + if err != nil { + pgConn.conn.Close() + return nil, newPerDialConnectError("tls error", err) + } + + pgConn.conn = tlsConn + } + + pgConn.contextWatcher = ctxwatch.NewContextWatcher(config.BuildContextWatcherHandler(pgConn)) + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + + pgConn.parameterStatuses = make(map[string]string) + pgConn.status = connStatusConnecting + pgConn.bgReader = bgreader.New(pgConn.conn) + pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), + func() { + pgConn.bgReader.Start() + pgConn.bgReaderStarted <- struct{}{} + }, + ) + pgConn.slowWriteTimer.Stop() + pgConn.bgReaderStarted = make(chan struct{}) + pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn) + + startupMsg := pgproto3.StartupMessage{ + ProtocolVersion: pgproto3.ProtocolVersionNumber, + Parameters: make(map[string]string), + } + + // Copy default run-time params + for k, v := range config.RuntimeParams { + startupMsg.Parameters[k] = v + } + + startupMsg.Parameters["user"] = config.User + if config.Database != "" { + startupMsg.Parameters["database"] = config.Database + } + + pgConn.frontend.Send(&startupMsg) + if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil { + pgConn.conn.Close() + return nil, newPerDialConnectError("failed to write startup message", err) + } + + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.conn.Close() + if err, ok := err.(*PgError); ok { + return nil, newPerDialConnectError("server error", err) + } + return nil, newPerDialConnectError("failed to receive message", err) + } + + switch msg := msg.(type) { + case *pgproto3.BackendKeyData: + pgConn.pid = msg.ProcessID + pgConn.secretKey = msg.SecretKey + + case *pgproto3.AuthenticationOk: + case *pgproto3.AuthenticationCleartextPassword: + err = pgConn.txPasswordMessage(pgConn.config.Password) + if err != nil { + pgConn.conn.Close() + return nil, newPerDialConnectError("failed to write password message", err) + } + case *pgproto3.AuthenticationMD5Password: + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) + err = pgConn.txPasswordMessage(digestedPassword) + if err != nil { + pgConn.conn.Close() + return nil, newPerDialConnectError("failed to write password message", err) + } + case *pgproto3.AuthenticationSASL: + err = pgConn.scramAuth(msg.AuthMechanisms) + if err != nil { + pgConn.conn.Close() + return nil, newPerDialConnectError("failed SASL auth", err) + } + case *pgproto3.AuthenticationGSS: + err = pgConn.gssAuth() + if err != nil { + pgConn.conn.Close() + return nil, newPerDialConnectError("failed GSS auth", err) + } + case *pgproto3.ReadyForQuery: + pgConn.status = connStatusIdle + if config.ValidateConnect != nil { + // ValidateConnect may execute commands that cause the context to be watched again. Unwatch first to avoid + // the watch already in progress panic. This is that last thing done by this method so there is no need to + // restart the watch after ValidateConnect returns. + // + // See https://github.com/jackc/pgconn/issues/40. + pgConn.contextWatcher.Unwatch() + + err := config.ValidateConnect(ctx, pgConn) + if err != nil { + if _, ok := err.(*NotPreferredError); ignoreNotPreferredErr && ok { + return pgConn, nil + } + pgConn.conn.Close() + return nil, newPerDialConnectError("ValidateConnect failed", err) + } + } + return pgConn, nil + case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse: + // handled by ReceiveMessage + case *pgproto3.ErrorResponse: + pgConn.conn.Close() + return nil, newPerDialConnectError("server error", ErrorResponseToPgError(msg)) + default: + pgConn.conn.Close() + return nil, newPerDialConnectError("received unexpected message", err) + } + } +} + +func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { + err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) + if err != nil { + return nil, err + } + + response := make([]byte, 1) + if _, err = io.ReadFull(conn, response); err != nil { + return nil, err + } + + if response[0] != 'S' { + return nil, errors.New("server refused TLS connection") + } + + return tls.Client(conn, tlsConfig), nil +} + +func (pgConn *PgConn) txPasswordMessage(password string) (err error) { + pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password}) + return pgConn.flushWithPotentialWriteReadDeadlock() +} + +func hexMD5(s string) string { + hash := md5.New() + io.WriteString(hash, s) + return hex.EncodeToString(hash.Sum(nil)) +} + +func (pgConn *PgConn) signalMessage() chan struct{} { + if pgConn.bufferingReceive { + panic("BUG: signalMessage when already in progress") + } + + pgConn.bufferingReceive = true + pgConn.bufferingReceiveMux.Lock() + + ch := make(chan struct{}) + go func() { + pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() + pgConn.bufferingReceiveMux.Unlock() + close(ch) + }() + + return ch +} + +// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the +// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages +// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger +// the OnNotification callback. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + msg, err := pgConn.receiveMessage() + if err != nil { + err = &pgconnError{ + msg: "receive message failed", + err: normalizeTimeoutError(ctx, err), + safeToRetry: true, + } + } + return msg, err +} + +// peekMessage peeks at the next message without setting up context cancellation. +func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { + if pgConn.peekedMsg != nil { + return pgConn.peekedMsg, nil + } + + var msg pgproto3.BackendMessage + var err error + if pgConn.bufferingReceive { + pgConn.bufferingReceiveMux.Lock() + msg = pgConn.bufferingReceiveMsg + err = pgConn.bufferingReceiveErr + pgConn.bufferingReceiveMux.Unlock() + pgConn.bufferingReceive = false + + // If a timeout error happened in the background try the read again. + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + msg, err = pgConn.frontend.Receive() + } + } else { + msg, err = pgConn.frontend.Receive() + } + + if err != nil { + // Close on anything other than timeout error - everything else is fatal + var netErr net.Error + isNetErr := errors.As(err, &netErr) + if !(isNetErr && netErr.Timeout()) { + pgConn.asyncClose() + } + + return nil, err + } + + pgConn.peekedMsg = msg + return msg, nil +} + +// receiveMessage receives a message without setting up context cancellation +func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := pgConn.peekMessage() + if err != nil { + return nil, err + } + pgConn.peekedMsg = nil + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + pgConn.txStatus = msg.TxStatus + case *pgproto3.ParameterStatus: + pgConn.parameterStatuses[msg.Name] = msg.Value + case *pgproto3.ErrorResponse: + err := ErrorResponseToPgError(msg) + if pgConn.config.OnPgError != nil && !pgConn.config.OnPgError(pgConn, err) { + pgConn.status = connStatusClosed + pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. + close(pgConn.cleanupDone) + return nil, err + } + case *pgproto3.NoticeResponse: + if pgConn.config.OnNotice != nil { + pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg)) + } + case *pgproto3.NotificationResponse: + if pgConn.config.OnNotification != nil { + pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) + } + } + + return msg, nil +} + +// Conn returns the underlying net.Conn. This rarely necessary. If the connection will be directly used for reading or +// writing then SyncConn should usually be called before Conn. +func (pgConn *PgConn) Conn() net.Conn { + return pgConn.conn +} + +// PID returns the backend PID. +func (pgConn *PgConn) PID() uint32 { + return pgConn.pid +} + +// TxStatus returns the current TxStatus as reported by the server in the ReadyForQuery message. +// +// Possible return values: +// +// 'I' - idle / not in transaction +// 'T' - in a transaction +// 'E' - in a failed transaction +// +// See https://www.postgresql.org/docs/current/protocol-message-formats.html. +func (pgConn *PgConn) TxStatus() byte { + return pgConn.txStatus +} + +// SecretKey returns the backend secret key used to send a cancel query message to the server. +func (pgConn *PgConn) SecretKey() uint32 { + return pgConn.secretKey +} + +// Frontend returns the underlying *pgproto3.Frontend. This rarely necessary. +func (pgConn *PgConn) Frontend() *pgproto3.Frontend { + return pgConn.frontend +} + +// Close closes a connection. It is safe to call Close on an already closed connection. Close attempts a clean close by +// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The +// underlying net.Conn.Close() will always be called regardless of any other errors. +func (pgConn *PgConn) Close(ctx context.Context) error { + if pgConn.status == connStatusClosed { + return nil + } + pgConn.status = connStatusClosed + + defer close(pgConn.cleanupDone) + defer pgConn.conn.Close() + + if ctx != context.Background() { + // Close may be called while a cancellable query is in progress. This will most often be triggered by panic when + // a defer closes the connection (possibly indirectly via a transaction or a connection pool). Unwatch to end any + // previous watch. It is safe to Unwatch regardless of whether a watch is already is progress. + // + // See https://github.com/jackc/pgconn/issues/29 + pgConn.contextWatcher.Unwatch() + + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + // Ignore any errors sending Terminate message and waiting for server to close connection. + // This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully + // ignores errors. + // + // See https://github.com/jackc/pgx/issues/637 + pgConn.frontend.Send(&pgproto3.Terminate{}) + pgConn.flushWithPotentialWriteReadDeadlock() + + return pgConn.conn.Close() +} + +// asyncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying +// connection. +func (pgConn *PgConn) asyncClose() { + if pgConn.status == connStatusClosed { + return + } + pgConn.status = connStatusClosed + + go func() { + defer close(pgConn.cleanupDone) + defer pgConn.conn.Close() + + deadline := time.Now().Add(time.Second * 15) + + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + pgConn.CancelRequest(ctx) + + pgConn.conn.SetDeadline(deadline) + + pgConn.frontend.Send(&pgproto3.Terminate{}) + pgConn.flushWithPotentialWriteReadDeadlock() + }() +} + +// CleanupDone returns a channel that will be closed after all underlying resources have been cleaned up. A closed +// connection is no longer usable, but underlying resources, in particular the net.Conn, may not have finished closing +// yet. This is because certain errors such as a context cancellation require that the interrupted function call return +// immediately, but the error may also cause the connection to be closed. In these cases the underlying resources are +// closed asynchronously. +// +// This is only likely to be useful to connection pools. It gives them a way avoid establishing a new connection while +// an old connection is still being cleaned up and thereby exceeding the maximum pool size. +func (pgConn *PgConn) CleanupDone() chan (struct{}) { + return pgConn.cleanupDone +} + +// IsClosed reports if the connection has been closed. +// +// CleanupDone() can be used to determine if all cleanup has been completed. +func (pgConn *PgConn) IsClosed() bool { + return pgConn.status < connStatusIdle +} + +// IsBusy reports if the connection is busy. +func (pgConn *PgConn) IsBusy() bool { + return pgConn.status == connStatusBusy +} + +// lock locks the connection. +func (pgConn *PgConn) lock() error { + switch pgConn.status { + case connStatusBusy: + return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug. + case connStatusClosed: + return &connLockError{status: "conn closed"} + case connStatusUninitialized: + return &connLockError{status: "conn uninitialized"} + } + pgConn.status = connStatusBusy + return nil +} + +func (pgConn *PgConn) unlock() { + switch pgConn.status { + case connStatusBusy: + pgConn.status = connStatusIdle + case connStatusClosed: + default: + panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package. + } +} + +// ParameterStatus returns the value of a parameter reported by the server (e.g. +// server_version). Returns an empty string for unknown parameters. +func (pgConn *PgConn) ParameterStatus(key string) string { + return pgConn.parameterStatuses[key] +} + +// CommandTag is the status text returned by PostgreSQL for a query. +type CommandTag struct { + s string +} + +// NewCommandTag makes a CommandTag from s. +func NewCommandTag(s string) CommandTag { + return CommandTag{s: s} +} + +// RowsAffected returns the number of rows affected. If the CommandTag was not +// for a row affecting command (e.g. "CREATE TABLE") then it returns 0. +func (ct CommandTag) RowsAffected() int64 { + // Find last non-digit + idx := -1 + for i := len(ct.s) - 1; i >= 0; i-- { + if ct.s[i] >= '0' && ct.s[i] <= '9' { + idx = i + } else { + break + } + } + + if idx == -1 { + return 0 + } + + var n int64 + for _, b := range ct.s[idx:] { + n = n*10 + int64(b-'0') + } + + return n +} + +func (ct CommandTag) String() string { + return ct.s +} + +// Insert is true if the command tag starts with "INSERT". +func (ct CommandTag) Insert() bool { + return strings.HasPrefix(ct.s, "INSERT") +} + +// Update is true if the command tag starts with "UPDATE". +func (ct CommandTag) Update() bool { + return strings.HasPrefix(ct.s, "UPDATE") +} + +// Delete is true if the command tag starts with "DELETE". +func (ct CommandTag) Delete() bool { + return strings.HasPrefix(ct.s, "DELETE") +} + +// Select is true if the command tag starts with "SELECT". +func (ct CommandTag) Select() bool { + return strings.HasPrefix(ct.s, "SELECT") +} + +type FieldDescription struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + Format int16 +} + +func (pgConn *PgConn) convertRowDescription(dst []FieldDescription, rd *pgproto3.RowDescription) []FieldDescription { + if cap(dst) >= len(rd.Fields) { + dst = dst[:len(rd.Fields):len(rd.Fields)] + } else { + dst = make([]FieldDescription, len(rd.Fields)) + } + + for i := range rd.Fields { + dst[i].Name = string(rd.Fields[i].Name) + dst[i].TableOID = rd.Fields[i].TableOID + dst[i].TableAttributeNumber = rd.Fields[i].TableAttributeNumber + dst[i].DataTypeOID = rd.Fields[i].DataTypeOID + dst[i].DataTypeSize = rd.Fields[i].DataTypeSize + dst[i].TypeModifier = rd.Fields[i].TypeModifier + dst[i].Format = rd.Fields[i].Format + } + + return dst +} + +type StatementDescription struct { + Name string + SQL string + ParamOIDs []uint32 + Fields []FieldDescription +} + +// Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This +// allows Prepare to also to describe statements without creating a server-side prepared statement. +// +// Prepare does not send a PREPARE statement to the server. It uses the PostgreSQL Parse and Describe protocol messages +// directly. +func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) + pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) + pgConn.frontend.SendSync(&pgproto3.Sync{}) + err := pgConn.flushWithPotentialWriteReadDeadlock() + if err != nil { + pgConn.asyncClose() + return nil, err + } + + psd := &StatementDescription{Name: name, SQL: sql} + + var parseErr error + +readloop: + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return nil, normalizeTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) + case *pgproto3.RowDescription: + psd.Fields = pgConn.convertRowDescription(nil, msg) + case *pgproto3.ErrorResponse: + parseErr = ErrorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + break readloop + } + } + + if parseErr != nil { + return nil, parseErr + } + return psd, nil +} + +// Deallocate deallocates a prepared statement. +// +// Deallocate does not send a DEALLOCATE statement to the server. It uses the PostgreSQL Close protocol message +// directly. This has slightly different behavior than executing DEALLOCATE statement. +// - Deallocate can succeed in an aborted transaction. +// - Deallocating a non-existent prepared statement is not an error. +func (pgConn *PgConn) Deallocate(ctx context.Context, name string) error { + if err := pgConn.lock(); err != nil { + return err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + pgConn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name}) + pgConn.frontend.SendSync(&pgproto3.Sync{}) + err := pgConn.flushWithPotentialWriteReadDeadlock() + if err != nil { + pgConn.asyncClose() + return err + } + + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return normalizeTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + return ErrorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + return nil + } + } +} + +// ErrorResponseToPgError converts a wire protocol error message to a *PgError. +func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { + return &PgError{ + Severity: msg.Severity, + SeverityUnlocalized: msg.SeverityUnlocalized, + Code: string(msg.Code), + Message: string(msg.Message), + Detail: string(msg.Detail), + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: string(msg.InternalQuery), + Where: string(msg.Where), + SchemaName: string(msg.SchemaName), + TableName: string(msg.TableName), + ColumnName: string(msg.ColumnName), + DataTypeName: string(msg.DataTypeName), + ConstraintName: msg.ConstraintName, + File: string(msg.File), + Line: msg.Line, + Routine: string(msg.Routine), + } +} + +func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { + pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg)) + return (*Notice)(pgerr) +} + +// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel +// request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there +// is no way to be sure a query was canceled. +// See https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-CANCELING-REQUESTS +func (pgConn *PgConn) CancelRequest(ctx context.Context) error { + // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing + // the connection config. This is important in high availability configurations where fallback connections may be + // specified or DNS may be used to load balance. + serverAddr := pgConn.conn.RemoteAddr() + var serverNetwork string + var serverAddress string + if serverAddr.Network() == "unix" { + // for unix sockets, RemoteAddr() calls getpeername() which returns the name the + // server passed to bind(). For Postgres, this is always a relative path "./.s.PGSQL.5432" + // so connecting to it will fail. Fall back to the config's value + serverNetwork, serverAddress = NetworkAddress(pgConn.config.Host, pgConn.config.Port) + } else { + serverNetwork, serverAddress = serverAddr.Network(), serverAddr.String() + } + cancelConn, err := pgConn.config.DialFunc(ctx, serverNetwork, serverAddress) + if err != nil { + // In case of unix sockets, RemoteAddr() returns only the file part of the path. If the + // first connect failed, try the config. + if serverAddr.Network() != "unix" { + return err + } + serverNetwork, serverAddr := NetworkAddress(pgConn.config.Host, pgConn.config.Port) + cancelConn, err = pgConn.config.DialFunc(ctx, serverNetwork, serverAddr) + if err != nil { + return err + } + } + defer cancelConn.Close() + + if ctx != context.Background() { + contextWatcher := ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: cancelConn}) + contextWatcher.Watch(ctx) + defer contextWatcher.Unwatch() + } + + buf := make([]byte, 16) + binary.BigEndian.PutUint32(buf[0:4], 16) + binary.BigEndian.PutUint32(buf[4:8], 80877102) + binary.BigEndian.PutUint32(buf[8:12], pgConn.pid) + binary.BigEndian.PutUint32(buf[12:16], pgConn.secretKey) + + if _, err := cancelConn.Write(buf); err != nil { + return fmt.Errorf("write to connection for cancellation: %w", err) + } + + // Wait for the cancel request to be acknowledged by the server. + // It copies the behavior of the libpq: https://github.com/postgres/postgres/blob/REL_16_0/src/interfaces/libpq/fe-connect.c#L4946-L4960 + _, _ = cancelConn.Read(buf) + + return nil +} + +// WaitForNotification waits for a LISTEN/NOTIFY message to be received. It returns an error if a notification was not +// received. +func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { + if err := pgConn.lock(); err != nil { + return err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return newContextAlreadyDoneError(ctx) + default: + } + + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + for { + msg, err := pgConn.receiveMessage() + if err != nil { + return normalizeTimeoutError(ctx, err) + } + + switch msg.(type) { + case *pgproto3.NotificationResponse: + return nil + } + } +} + +// Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is +// implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control +// statements. +// +// Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. +func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } + + pgConn.multiResultReader = MultiResultReader{ + pgConn: pgConn, + ctx: ctx, + } + multiResult := &pgConn.multiResultReader + if ctx != context.Background() { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = newContextAlreadyDoneError(ctx) + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) + err := pgConn.flushWithPotentialWriteReadDeadlock() + if err != nil { + pgConn.asyncClose() + pgConn.contextWatcher.Unwatch() + multiResult.closed = true + multiResult.err = err + pgConn.unlock() + return multiResult + } + + return multiResult +} + +// ExecParams executes a command via the PostgreSQL extended query protocol. +// +// sql is a SQL command string. It may only contain one query. Parameter substitution is positional using $1, $2, $3, +// etc. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for +// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. +// ExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all params are text format. ExecParams will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text format. +// +// ResultReader must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) *ResultReader { + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { + return result + } + + pgConn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) + pgConn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + + pgConn.execExtendedSuffix(result) + + return result +} + +// ExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all params are text format. ExecPrepared will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text format. +// +// ResultReader must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) *ResultReader { + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { + return result + } + + pgConn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + + pgConn.execExtendedSuffix(result) + + return result +} + +func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { + pgConn.resultReader = ResultReader{ + pgConn: pgConn, + ctx: ctx, + } + result := &pgConn.resultReader + + if err := pgConn.lock(); err != nil { + result.concludeCommand(CommandTag{}, err) + result.closed = true + return result + } + + if len(paramValues) > math.MaxUint16 { + result.concludeCommand(CommandTag{}, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.closed = true + pgConn.unlock() + return result + } + + if ctx != context.Background() { + select { + case <-ctx.Done(): + result.concludeCommand(CommandTag{}, newContextAlreadyDoneError(ctx)) + result.closed = true + pgConn.unlock() + return result + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + return result +} + +func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { + pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) + pgConn.frontend.SendExecute(&pgproto3.Execute{}) + pgConn.frontend.SendSync(&pgproto3.Sync{}) + + err := pgConn.flushWithPotentialWriteReadDeadlock() + if err != nil { + pgConn.asyncClose() + result.concludeCommand(CommandTag{}, err) + pgConn.contextWatcher.Unwatch() + result.closed = true + pgConn.unlock() + return + } + + result.readUntilRowDescription() +} + +// CopyTo executes the copy command sql and copies the results to w. +func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { + if err := pgConn.lock(); err != nil { + return CommandTag{}, err + } + + if ctx != context.Background() { + select { + case <-ctx.Done(): + pgConn.unlock() + return CommandTag{}, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + // Send copy to command + pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) + + err := pgConn.flushWithPotentialWriteReadDeadlock() + if err != nil { + pgConn.asyncClose() + pgConn.unlock() + return CommandTag{}, err + } + + // Read results + var commandTag CommandTag + var pgErr error + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return CommandTag{}, normalizeTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.CopyDone: + case *pgproto3.CopyData: + _, err := w.Write(msg.Data) + if err != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + case *pgproto3.ReadyForQuery: + pgConn.unlock() + return commandTag, pgErr + case *pgproto3.CommandComplete: + commandTag = pgConn.makeCommandTag(msg.CommandTag) + case *pgproto3.ErrorResponse: + pgErr = ErrorResponseToPgError(msg) + } + } +} + +// CopyFrom executes the copy command sql and copies all of r to the PostgreSQL server. +// +// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r +// could still block. +func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { + if err := pgConn.lock(); err != nil { + return CommandTag{}, err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return CommandTag{}, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + // Send copy from query + pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) + err := pgConn.flushWithPotentialWriteReadDeadlock() + if err != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + + // Send copy data + abortCopyChan := make(chan struct{}) + copyErrChan := make(chan error, 1) + signalMessageChan := pgConn.signalMessage() + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + buf := iobufpool.Get(65536) + defer iobufpool.Put(buf) + (*buf)[0] = 'd' + + for { + n, readErr := r.Read((*buf)[5:cap(*buf)]) + if n > 0 { + *buf = (*buf)[0 : n+5] + pgio.SetInt32((*buf)[1:], int32(n+4)) + + writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf) + if writeErr != nil { + // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. Not + // setting pgConn.status or closing pgConn.cleanupDone for the same reason. + pgConn.conn.Close() + + copyErrChan <- writeErr + return + } + } + if readErr != nil { + copyErrChan <- readErr + return + } + + select { + case <-abortCopyChan: + return + default: + } + } + }() + + var pgErr error + var copyErr error + for copyErr == nil && pgErr == nil { + select { + case copyErr = <-copyErrChan: + case <-signalMessageChan: + // If pgConn.receiveMessage encounters an error it will call pgConn.asyncClose. But that is a race condition with + // the goroutine. So instead check pgConn.bufferingReceiveErr which will have been set by the signalMessage. If an + // error is found then forcibly close the connection without sending the Terminate message. + if err := pgConn.bufferingReceiveErr; err != nil { + pgConn.status = connStatusClosed + pgConn.conn.Close() + close(pgConn.cleanupDone) + return CommandTag{}, normalizeTimeoutError(ctx, err) + } + // peekMessage never returns err in the bufferingReceive mode - it only forwards the bufferingReceive variables. + // Therefore, the only case for receiveMessage to return err is during handling of the ErrorResponse message type + // and using pgOnError handler to determine the connection is no longer valid (and thus closing the conn). + msg, serverError := pgConn.receiveMessage() + if serverError != nil { + close(abortCopyChan) + return CommandTag{}, serverError + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + pgErr = ErrorResponseToPgError(msg) + default: + signalMessageChan = pgConn.signalMessage() + } + } + } + close(abortCopyChan) + // Make sure io goroutine finishes before writing. + wg.Wait() + + if copyErr == io.EOF || pgErr != nil { + pgConn.frontend.Send(&pgproto3.CopyDone{}) + } else { + pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()}) + } + err = pgConn.flushWithPotentialWriteReadDeadlock() + if err != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + + // Read results + var commandTag CommandTag + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return CommandTag{}, normalizeTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + return commandTag, pgErr + case *pgproto3.CommandComplete: + commandTag = pgConn.makeCommandTag(msg.CommandTag) + case *pgproto3.ErrorResponse: + pgErr = ErrorResponseToPgError(msg) + } + } +} + +// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. +type MultiResultReader struct { + pgConn *PgConn + ctx context.Context + + rr *ResultReader + + closed bool + err error +} + +// ReadAll reads all available results. Calling ReadAll is mutually exclusive with all other MultiResultReader methods. +func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { + var results []*Result + + for mrr.NextResult() { + results = append(results, mrr.ResultReader().Read()) + } + err := mrr.Close() + + return results, err +} + +func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := mrr.pgConn.receiveMessage() + if err != nil { + mrr.pgConn.contextWatcher.Unwatch() + mrr.err = normalizeTimeoutError(mrr.ctx, err) + mrr.closed = true + mrr.pgConn.asyncClose() + return nil, mrr.err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + mrr.closed = true + mrr.pgConn.contextWatcher.Unwatch() + mrr.pgConn.unlock() + case *pgproto3.ErrorResponse: + mrr.err = ErrorResponseToPgError(msg) + } + + return msg, nil +} + +// NextResult returns advances the MultiResultReader to the next result and returns true if a result is available. +func (mrr *MultiResultReader) NextResult() bool { + for !mrr.closed && mrr.err == nil { + msg, err := mrr.receiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + mrr.pgConn.resultReader = ResultReader{ + pgConn: mrr.pgConn, + multiResultReader: mrr, + ctx: mrr.ctx, + fieldDescriptions: mrr.pgConn.convertRowDescription(mrr.pgConn.fieldDescriptions[:], msg), + } + + mrr.rr = &mrr.pgConn.resultReader + return true + case *pgproto3.CommandComplete: + mrr.pgConn.resultReader = ResultReader{ + commandTag: mrr.pgConn.makeCommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + } + mrr.rr = &mrr.pgConn.resultReader + return true + case *pgproto3.EmptyQueryResponse: + mrr.pgConn.resultReader = ResultReader{ + commandConcluded: true, + closed: true, + } + mrr.rr = &mrr.pgConn.resultReader + return true + } + } + + return false +} + +// ResultReader returns the current ResultReader. +func (mrr *MultiResultReader) ResultReader() *ResultReader { + return mrr.rr +} + +// Close closes the MultiResultReader and returns the first error that occurred during the MultiResultReader's use. +func (mrr *MultiResultReader) Close() error { + for !mrr.closed { + _, err := mrr.receiveMessage() + if err != nil { + return mrr.err + } + } + + return mrr.err +} + +// ResultReader is a reader for the result of a single query. +type ResultReader struct { + pgConn *PgConn + multiResultReader *MultiResultReader + pipeline *Pipeline + ctx context.Context + + fieldDescriptions []FieldDescription + rowValues [][]byte + commandTag CommandTag + commandConcluded bool + closed bool + err error +} + +// Result is the saved query response that is returned by calling Read on a ResultReader. +type Result struct { + FieldDescriptions []FieldDescription + Rows [][][]byte + CommandTag CommandTag + Err error +} + +// Read saves the query response to a Result. +func (rr *ResultReader) Read() *Result { + br := &Result{} + + for rr.NextRow() { + if br.FieldDescriptions == nil { + br.FieldDescriptions = make([]FieldDescription, len(rr.FieldDescriptions())) + copy(br.FieldDescriptions, rr.FieldDescriptions()) + } + + values := rr.Values() + row := make([][]byte, len(values)) + for i := range row { + if values[i] != nil { + row[i] = make([]byte, len(values[i])) + copy(row[i], values[i]) + } + } + br.Rows = append(br.Rows, row) + } + + br.CommandTag, br.Err = rr.Close() + + return br +} + +// NextRow advances the ResultReader to the next row and returns true if a row is available. +func (rr *ResultReader) NextRow() bool { + for !rr.commandConcluded { + msg, err := rr.receiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.DataRow: + rr.rowValues = msg.Values + return true + } + } + + return false +} + +// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until +// the ResultReader is closed. It may return nil (for example, if the query did not return a result set or an error was +// encountered.) +func (rr *ResultReader) FieldDescriptions() []FieldDescription { + return rr.fieldDescriptions +} + +// Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only +// valid until the next NextRow call or the ResultReader is closed. +func (rr *ResultReader) Values() [][]byte { + return rr.rowValues +} + +// Close consumes any remaining result data and returns the command tag or +// error. +func (rr *ResultReader) Close() (CommandTag, error) { + if rr.closed { + return rr.commandTag, rr.err + } + rr.closed = true + + for !rr.commandConcluded { + _, err := rr.receiveMessage() + if err != nil { + return CommandTag{}, rr.err + } + } + + if rr.multiResultReader == nil && rr.pipeline == nil { + for { + msg, err := rr.receiveMessage() + if err != nil { + return CommandTag{}, rr.err + } + + switch msg := msg.(type) { + // Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete. + case *pgproto3.ErrorResponse: + rr.err = ErrorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + rr.pgConn.contextWatcher.Unwatch() + rr.pgConn.unlock() + return rr.commandTag, rr.err + } + } + } + + return rr.commandTag, rr.err +} + +// readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any +// error will be stored in the ResultReader. +func (rr *ResultReader) readUntilRowDescription() { + for !rr.commandConcluded { + // Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method. + // This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are + // manually used to construct a query that does not issue a describe statement. + msg, _ := rr.pgConn.peekMessage() + if _, ok := msg.(*pgproto3.DataRow); ok { + return + } + + // Consume the message + msg, _ = rr.receiveMessage() + if _, ok := msg.(*pgproto3.RowDescription); ok { + return + } + } +} + +func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { + if rr.multiResultReader == nil { + msg, err = rr.pgConn.receiveMessage() + } else { + msg, err = rr.multiResultReader.receiveMessage() + } + + if err != nil { + err = normalizeTimeoutError(rr.ctx, err) + rr.concludeCommand(CommandTag{}, err) + rr.pgConn.contextWatcher.Unwatch() + rr.closed = true + if rr.multiResultReader == nil { + rr.pgConn.asyncClose() + } + + return nil, rr.err + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + rr.fieldDescriptions = rr.pgConn.convertRowDescription(rr.pgConn.fieldDescriptions[:], msg) + case *pgproto3.CommandComplete: + rr.concludeCommand(rr.pgConn.makeCommandTag(msg.CommandTag), nil) + case *pgproto3.EmptyQueryResponse: + rr.concludeCommand(CommandTag{}, nil) + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + if rr.pipeline != nil { + rr.pipeline.state.HandleError(pgErr) + } + rr.concludeCommand(CommandTag{}, pgErr) + } + + return msg, nil +} + +func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { + // Keep the first error that is recorded. Store the error before checking if the command is already concluded to + // allow for receiving an error after CommandComplete but before ReadyForQuery. + if err != nil && rr.err == nil { + rr.err = err + } + + if rr.commandConcluded { + return + } + + rr.commandTag = commandTag + rr.rowValues = nil + rr.commandConcluded = true +} + +// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. +type Batch struct { + buf []byte + err error +} + +// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. +func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) { + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) + if batch.err != nil { + return + } + batch.ExecPrepared("", paramValues, paramFormats, resultFormats) +} + +// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. +func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) { + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf) + if batch.err != nil { + return + } +} + +// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a +// transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing +// multiple queries in a single round trip than using pipeline mode. +func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { + if batch.err != nil { + return &MultiResultReader{ + closed: true, + err: batch.err, + } + } + + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } + + pgConn.multiResultReader = MultiResultReader{ + pgConn: pgConn, + ctx: ctx, + } + multiResult := &pgConn.multiResultReader + + if ctx != context.Background() { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = newContextAlreadyDoneError(ctx) + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf) + if batch.err != nil { + pgConn.contextWatcher.Unwatch() + multiResult.err = normalizeTimeoutError(multiResult.ctx, batch.err) + multiResult.closed = true + pgConn.asyncClose() + return multiResult + } + + pgConn.enterPotentialWriteReadDeadlock() + defer pgConn.exitPotentialWriteReadDeadlock() + _, err := pgConn.conn.Write(batch.buf) + if err != nil { + pgConn.contextWatcher.Unwatch() + multiResult.err = normalizeTimeoutError(multiResult.ctx, err) + multiResult.closed = true + pgConn.asyncClose() + return multiResult + } + + return multiResult +} + +// EscapeString escapes a string such that it can safely be interpolated into a SQL command string. It does not include +// the surrounding single quotes. +// +// The current implementation requires that standard_conforming_strings=on and client_encoding="UTF8". If these +// conditions are not met an error will be returned. It is possible these restrictions will be lifted in the future. +func (pgConn *PgConn) EscapeString(s string) (string, error) { + if pgConn.ParameterStatus("standard_conforming_strings") != "on" { + return "", errors.New("EscapeString must be run with standard_conforming_strings=on") + } + + if pgConn.ParameterStatus("client_encoding") != "UTF8" { + return "", errors.New("EscapeString must be run with client_encoding=UTF8") + } + + return strings.Replace(s, "'", "''", -1), nil +} + +// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by doing a read +// with a very short deadline. This can be useful because a TCP connection can be broken such that a write will appear +// to succeed even though it will never actually reach the server. Reading immediately before a write will detect this +// condition. If this is done immediately before sending a query it reduces the chances a query will be sent that fails +// without the client knowing whether the server received it or not. +// +// Deprecated: CheckConn is deprecated in favor of Ping. CheckConn cannot detect all types of broken connections where +// the write would still appear to succeed. Prefer Ping unless on a high latency connection. +func (pgConn *PgConn) CheckConn() error { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + _, err := pgConn.ReceiveMessage(ctx) + if err != nil { + if !Timeout(err) { + return err + } + } + + return nil +} + +// Ping pings the server. This can be useful because a TCP connection can be broken such that a write will appear to +// succeed even though it will never actually reach the server. Pinging immediately before sending a query reduces the +// chances a query will be sent that fails without the client knowing whether the server received it or not. +func (pgConn *PgConn) Ping(ctx context.Context) error { + return pgConn.Exec(ctx, "-- ping").Close() +} + +// makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory. +func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag { + return CommandTag{s: string(buf)} +} + +// enterPotentialWriteReadDeadlock must be called before a write that could deadlock if the server is simultaneously +// blocked writing to us. +func (pgConn *PgConn) enterPotentialWriteReadDeadlock() { + // The time to wait is somewhat arbitrary. A Write should only take as long as the syscall and memcpy to the OS + // outbound network buffer unless the buffer is full (which potentially is a block). It needs to be long enough for + // the normal case, but short enough not to kill performance if a block occurs. + // + // In addition, on Windows the default timer resolution is 15.6ms. So setting the timer to less than that is + // ineffective. + if pgConn.slowWriteTimer.Reset(15 * time.Millisecond) { + panic("BUG: slow write timer already active") + } +} + +// exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock. +func (pgConn *PgConn) exitPotentialWriteReadDeadlock() { + if !pgConn.slowWriteTimer.Stop() { + // The timer starts its function in a separate goroutine. It is necessary to ensure the background reader has + // started before calling Stop. Otherwise, the background reader may not be stopped. That on its own is not a + // serious problem. But what is a serious problem is that the background reader may start at an inopportune time in + // a subsequent query. For example, if a subsequent query was canceled then a deadline may be set on the net.Conn to + // interrupt an in-progress read. After the read is interrupted, but before the deadline is cleared, the background + // reader could start and read a deadline error. Then the next query would receive the an unexpected deadline error. + <-pgConn.bgReaderStarted + pgConn.bgReader.Stop() + } +} + +func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error { + pgConn.enterPotentialWriteReadDeadlock() + defer pgConn.exitPotentialWriteReadDeadlock() + err := pgConn.frontend.Flush() + return err +} + +// SyncConn prepares the underlying net.Conn for direct use. PgConn may internally buffer reads or use goroutines for +// background IO. This means that any direct use of the underlying net.Conn may be corrupted if a read is already +// buffered or a read is in progress. SyncConn drains read buffers and stops background IO. In some cases this may +// require sending a ping to the server. ctx can be used to cancel this operation. This should be called before any +// operation that will use the underlying net.Conn directly. e.g. Before Conn() or Hijack(). +// +// This should not be confused with the PostgreSQL protocol Sync message. +func (pgConn *PgConn) SyncConn(ctx context.Context) error { + for i := 0; i < 10; i++ { + if pgConn.bgReader.Status() == bgreader.StatusStopped && pgConn.frontend.ReadBufferLen() == 0 { + return nil + } + + err := pgConn.Ping(ctx) + if err != nil { + return fmt.Errorf("SyncConn: Ping failed while syncing conn: %w", err) + } + } + + // This should never happen. Only way I can imagine this occurring is if the server is constantly sending data such as + // LISTEN/NOTIFY or log notifications such that we never can get an empty buffer. + return errors.New("SyncConn: conn never synchronized") +} + +// CustomData returns a map that can be used to associate custom data with the connection. +func (pgConn *PgConn) CustomData() map[string]any { + return pgConn.customData +} + +// HijackedConn is the result of hijacking a connection. +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +type HijackedConn struct { + Conn net.Conn + PID uint32 // backend pid + SecretKey uint32 // key to use to send a cancel query message to the server + ParameterStatuses map[string]string // parameters that have been reported by the server + TxStatus byte + Frontend *pgproto3.Frontend + Config *Config + CustomData map[string]any +} + +// Hijack extracts the internal connection data. pgConn must be in an idle state. SyncConn should be called immediately +// before Hijack. pgConn is unusable after hijacking. Hijacking is typically only useful when using pgconn to establish +// a connection, but taking complete control of the raw connection after that (e.g. a load balancer or proxy). +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +func (pgConn *PgConn) Hijack() (*HijackedConn, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + pgConn.status = connStatusClosed + + return &HijackedConn{ + Conn: pgConn.conn, + PID: pgConn.pid, + SecretKey: pgConn.secretKey, + ParameterStatuses: pgConn.parameterStatuses, + TxStatus: pgConn.txStatus, + Frontend: pgConn.frontend, + Config: pgConn.config, + CustomData: pgConn.customData, + }, nil +} + +// Construct created a PgConn from an already established connection to a PostgreSQL server. This is the inverse of +// PgConn.Hijack. The connection must be in an idle state. +// +// hc.Frontend is replaced by a new pgproto3.Frontend built by hc.Config.BuildFrontend. +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +func Construct(hc *HijackedConn) (*PgConn, error) { + pgConn := &PgConn{ + conn: hc.Conn, + pid: hc.PID, + secretKey: hc.SecretKey, + parameterStatuses: hc.ParameterStatuses, + txStatus: hc.TxStatus, + frontend: hc.Frontend, + config: hc.Config, + customData: hc.CustomData, + + status: connStatusIdle, + + cleanupDone: make(chan struct{}), + } + + pgConn.contextWatcher = ctxwatch.NewContextWatcher(hc.Config.BuildContextWatcherHandler(pgConn)) + pgConn.bgReader = bgreader.New(pgConn.conn) + pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), + func() { + pgConn.bgReader.Start() + pgConn.bgReaderStarted <- struct{}{} + }, + ) + pgConn.slowWriteTimer.Stop() + pgConn.bgReaderStarted = make(chan struct{}) + pgConn.frontend = hc.Config.BuildFrontend(pgConn.bgReader, pgConn.conn) + + return pgConn, nil +} + +// Pipeline represents a connection in pipeline mode. +// +// SendPrepare, SendQueryParams, and SendQueryPrepared queue requests to the server. These requests are not written until +// pipeline is flushed by Flush or Sync. Sync must be called after the last request is queued. Requests between +// synchronization points are implicitly transactional unless explicit transaction control statements have been issued. +// +// The context the pipeline was started with is in effect for the entire life of the Pipeline. +// +// For a deeper understanding of pipeline mode see the PostgreSQL documentation for the extended query protocol +// (https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY) and the libpq pipeline mode +// (https://www.postgresql.org/docs/current/libpq-pipeline-mode.html). +type Pipeline struct { + conn *PgConn + ctx context.Context + + state pipelineState + err error + closed bool +} + +// PipelineSync is returned by GetResults when a ReadyForQuery message is received. +type PipelineSync struct{} + +// CloseComplete is returned by GetResults when a CloseComplete message is received. +type CloseComplete struct{} + +type pipelineRequestType int + +const ( + pipelineNil pipelineRequestType = iota + pipelinePrepare + pipelineQueryParams + pipelineQueryPrepared + pipelineDeallocate + pipelineSyncRequest + pipelineFlushRequest +) + +type pipelineRequestEvent struct { + RequestType pipelineRequestType + WasSentToServer bool + BeforeFlushOrSync bool +} + +type pipelineState struct { + requestEventQueue list.List + lastRequestType pipelineRequestType + pgErr *PgError + expectedReadyForQueryCount int +} + +func (s *pipelineState) Init() { + s.requestEventQueue.Init() + s.lastRequestType = pipelineNil +} + +func (s *pipelineState) RegisterSendingToServer() { + for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() { + val := elem.Value.(pipelineRequestEvent) + if val.WasSentToServer { + return + } + val.WasSentToServer = true + elem.Value = val + } +} + +func (s *pipelineState) registerFlushingBufferOnServer() { + for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() { + val := elem.Value.(pipelineRequestEvent) + if val.BeforeFlushOrSync { + return + } + val.BeforeFlushOrSync = true + elem.Value = val + } +} + +func (s *pipelineState) PushBackRequestType(req pipelineRequestType) { + if req == pipelineNil { + return + } + + if req != pipelineFlushRequest { + s.requestEventQueue.PushBack(pipelineRequestEvent{RequestType: req}) + } + if req == pipelineFlushRequest || req == pipelineSyncRequest { + s.registerFlushingBufferOnServer() + } + s.lastRequestType = req + + if req == pipelineSyncRequest { + s.expectedReadyForQueryCount++ + } +} + +func (s *pipelineState) ExtractFrontRequestType() pipelineRequestType { + for { + elem := s.requestEventQueue.Front() + if elem == nil { + return pipelineNil + } + val := elem.Value.(pipelineRequestEvent) + if !(val.WasSentToServer && val.BeforeFlushOrSync) { + return pipelineNil + } + + s.requestEventQueue.Remove(elem) + if val.RequestType == pipelineSyncRequest { + s.pgErr = nil + } + if s.pgErr == nil { + return val.RequestType + } + } +} + +func (s *pipelineState) HandleError(err *PgError) { + s.pgErr = err +} + +func (s *pipelineState) HandleReadyForQuery() { + s.expectedReadyForQueryCount-- +} + +func (s *pipelineState) PendingSync() bool { + var notPendingSync bool + + if elem := s.requestEventQueue.Back(); elem != nil { + val := elem.Value.(pipelineRequestEvent) + notPendingSync = (val.RequestType == pipelineSyncRequest) && val.WasSentToServer + } else { + notPendingSync = (s.lastRequestType == pipelineSyncRequest) || (s.lastRequestType == pipelineNil) + } + + return !notPendingSync +} + +func (s *pipelineState) ExpectedReadyForQuery() int { + return s.expectedReadyForQueryCount +} + +// StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent +// to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection +// to normal mode. While in pipeline mode, no methods that communicate with the server may be called except +// CancelRequest and Close. ctx is in effect for entire life of the *Pipeline. +// +// Prefer ExecBatch when only sending one group of queries at once. +func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline { + if err := pgConn.lock(); err != nil { + pipeline := &Pipeline{ + closed: true, + err: err, + } + pipeline.state.Init() + + return pipeline + } + + pgConn.pipeline = Pipeline{ + conn: pgConn, + ctx: ctx, + } + pgConn.pipeline.state.Init() + + pipeline := &pgConn.pipeline + + if ctx != context.Background() { + select { + case <-ctx.Done(): + pipeline.closed = true + pipeline.err = newContextAlreadyDoneError(ctx) + pgConn.unlock() + return pipeline + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + return pipeline +} + +// SendPrepare is the pipeline version of *PgConn.Prepare. +func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) { + if p.closed { + return + } + + p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) + p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) + p.state.PushBackRequestType(pipelinePrepare) +} + +// SendDeallocate deallocates a prepared statement. +func (p *Pipeline) SendDeallocate(name string) { + if p.closed { + return + } + + p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name}) + p.state.PushBackRequestType(pipelineDeallocate) +} + +// SendQueryParams is the pipeline version of *PgConn.QueryParams. +func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) { + if p.closed { + return + } + + p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) + p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) + p.conn.frontend.SendExecute(&pgproto3.Execute{}) + p.state.PushBackRequestType(pipelineQueryParams) +} + +// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared. +func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) { + if p.closed { + return + } + + p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) + p.conn.frontend.SendExecute(&pgproto3.Execute{}) + p.state.PushBackRequestType(pipelineQueryPrepared) +} + +// SendFlushRequest sends a request for the server to flush its output buffer. +// +// The server flushes its output buffer automatically as a result of Sync being called, +// or on any request when not in pipeline mode; this function is useful to cause the server +// to flush its output buffer in pipeline mode without establishing a synchronization point. +// Note that the request is not itself flushed to the server automatically; use Flush if +// necessary. This copies the behavior of libpq PQsendFlushRequest. +func (p *Pipeline) SendFlushRequest() { + if p.closed { + return + } + + p.conn.frontend.Send(&pgproto3.Flush{}) + p.state.PushBackRequestType(pipelineFlushRequest) +} + +// SendPipelineSync marks a synchronization point in a pipeline by sending a sync message +// without flushing the send buffer. This serves as the delimiter of an implicit +// transaction and an error recovery point. +// +// Note that the request is not itself flushed to the server automatically; use Flush if +// necessary. This copies the behavior of libpq PQsendPipelineSync. +func (p *Pipeline) SendPipelineSync() { + if p.closed { + return + } + + p.conn.frontend.SendSync(&pgproto3.Sync{}) + p.state.PushBackRequestType(pipelineSyncRequest) +} + +// Flush flushes the queued requests without establishing a synchronization point. +func (p *Pipeline) Flush() error { + if p.closed { + if p.err != nil { + return p.err + } + return errors.New("pipeline closed") + } + + err := p.conn.flushWithPotentialWriteReadDeadlock() + if err != nil { + err = normalizeTimeoutError(p.ctx, err) + + p.conn.asyncClose() + + p.conn.contextWatcher.Unwatch() + p.conn.unlock() + p.closed = true + p.err = err + return err + } + + p.state.RegisterSendingToServer() + return nil +} + +// Sync establishes a synchronization point and flushes the queued requests. +func (p *Pipeline) Sync() error { + p.SendPipelineSync() + return p.Flush() +} + +// GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or +// *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. If no +// results are available, results and err will both be nil. +func (p *Pipeline) GetResults() (results any, err error) { + if p.closed { + if p.err != nil { + return nil, p.err + } + return nil, errors.New("pipeline closed") + } + + if p.state.ExtractFrontRequestType() == pipelineNil { + return nil, nil + } + + return p.getResults() +} + +func (p *Pipeline) getResults() (results any, err error) { + for { + msg, err := p.conn.receiveMessage() + if err != nil { + p.closed = true + p.err = err + p.conn.asyncClose() + return nil, normalizeTimeoutError(p.ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + p.conn.resultReader = ResultReader{ + pgConn: p.conn, + pipeline: p, + ctx: p.ctx, + fieldDescriptions: p.conn.convertRowDescription(p.conn.fieldDescriptions[:], msg), + } + return &p.conn.resultReader, nil + case *pgproto3.CommandComplete: + p.conn.resultReader = ResultReader{ + commandTag: p.conn.makeCommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + } + return &p.conn.resultReader, nil + case *pgproto3.ParseComplete: + peekedMsg, err := p.conn.peekMessage() + if err != nil { + p.conn.asyncClose() + return nil, normalizeTimeoutError(p.ctx, err) + } + if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok { + return p.getResultsPrepare() + } + case *pgproto3.CloseComplete: + return &CloseComplete{}, nil + case *pgproto3.ReadyForQuery: + p.state.HandleReadyForQuery() + return &PipelineSync{}, nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + return nil, pgErr + } + } +} + +func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { + psd := &StatementDescription{} + + for { + msg, err := p.conn.receiveMessage() + if err != nil { + p.conn.asyncClose() + return nil, normalizeTimeoutError(p.ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) + case *pgproto3.RowDescription: + psd.Fields = p.conn.convertRowDescription(nil, msg) + return psd, nil + + // NoData is returned instead of RowDescription when there is no expected result. e.g. An INSERT without a RETURNING + // clause. + case *pgproto3.NoData: + return psd, nil + + // These should never happen here. But don't take chances that could lead to a deadlock. + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) + return nil, pgErr + case *pgproto3.CommandComplete: + p.conn.asyncClose() + return nil, errors.New("BUG: received CommandComplete while handling Describe") + case *pgproto3.ReadyForQuery: + p.conn.asyncClose() + return nil, errors.New("BUG: received ReadyForQuery while handling Describe") + } + } +} + +// Close closes the pipeline and returns the connection to normal mode. +func (p *Pipeline) Close() error { + if p.closed { + return p.err + } + + p.closed = true + + if p.state.PendingSync() { + p.conn.asyncClose() + p.err = errors.New("pipeline has unsynced requests") + p.conn.contextWatcher.Unwatch() + p.conn.unlock() + + return p.err + } + + for p.state.ExpectedReadyForQuery() > 0 { + _, err := p.getResults() + if err != nil { + p.err = err + var pgErr *PgError + if !errors.As(err, &pgErr) { + p.conn.asyncClose() + break + } + } + } + + p.conn.contextWatcher.Unwatch() + p.conn.unlock() + + return p.err +} + +// DeadlineContextWatcherHandler handles canceled contexts by setting a deadline on a net.Conn. +type DeadlineContextWatcherHandler struct { + Conn net.Conn + + // DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled. + DeadlineDelay time.Duration +} + +func (h *DeadlineContextWatcherHandler) HandleCancel(ctx context.Context) { + h.Conn.SetDeadline(time.Now().Add(h.DeadlineDelay)) +} + +func (h *DeadlineContextWatcherHandler) HandleUnwatchAfterCancel() { + h.Conn.SetDeadline(time.Time{}) +} + +// CancelRequestContextWatcherHandler handles canceled contexts by sending a cancel request to the server. It also sets +// a deadline on a net.Conn as a fallback. +type CancelRequestContextWatcherHandler struct { + Conn *PgConn + + // CancelRequestDelay is the delay before sending the cancel request to the server. + CancelRequestDelay time.Duration + + // DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled. + DeadlineDelay time.Duration + + cancelFinishedChan chan struct{} + handleUnwatchAfterCancelCalled func() +} + +func (h *CancelRequestContextWatcherHandler) HandleCancel(context.Context) { + h.cancelFinishedChan = make(chan struct{}) + var handleUnwatchedAfterCancelCalledCtx context.Context + handleUnwatchedAfterCancelCalledCtx, h.handleUnwatchAfterCancelCalled = context.WithCancel(context.Background()) + + deadline := time.Now().Add(h.DeadlineDelay) + h.Conn.conn.SetDeadline(deadline) + + go func() { + defer close(h.cancelFinishedChan) + + select { + case <-handleUnwatchedAfterCancelCalledCtx.Done(): + return + case <-time.After(h.CancelRequestDelay): + } + + cancelRequestCtx, cancel := context.WithDeadline(handleUnwatchedAfterCancelCalledCtx, deadline) + defer cancel() + h.Conn.CancelRequest(cancelRequestCtx) + + // CancelRequest is inherently racy. Even though the cancel request has been received by the server at this point, + // it hasn't necessarily been delivered to the other connection. If we immediately return and the connection is + // immediately used then it is possible the CancelRequest will actually cancel our next query. The + // TestCancelRequestContextWatcherHandler Stress test can produce this error without the sleep below. The sleep time + // is arbitrary, but should be sufficient to prevent this error case. + time.Sleep(100 * time.Millisecond) + }() +} + +func (h *CancelRequestContextWatcherHandler) HandleUnwatchAfterCancel() { + h.handleUnwatchAfterCancelCalled() + <-h.cancelFinishedChan + + h.Conn.conn.SetDeadline(time.Time{}) +} diff --git a/pgconn/pgconn_private_test.go b/pgconn/pgconn_private_test.go new file mode 100644 index 000000000..a0c15c27a --- /dev/null +++ b/pgconn/pgconn_private_test.go @@ -0,0 +1,41 @@ +package pgconn + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCommandTag(t *testing.T) { + t.Parallel() + + tests := []struct { + commandTag CommandTag + rowsAffected int64 + isInsert bool + isUpdate bool + isDelete bool + isSelect bool + }{ + {commandTag: CommandTag{s: "INSERT 0 5"}, rowsAffected: 5, isInsert: true}, + {commandTag: CommandTag{s: "UPDATE 0"}, rowsAffected: 0, isUpdate: true}, + {commandTag: CommandTag{s: "UPDATE 1"}, rowsAffected: 1, isUpdate: true}, + {commandTag: CommandTag{s: "DELETE 0"}, rowsAffected: 0, isDelete: true}, + {commandTag: CommandTag{s: "DELETE 1"}, rowsAffected: 1, isDelete: true}, + {commandTag: CommandTag{s: "DELETE 1234567890"}, rowsAffected: 1234567890, isDelete: true}, + {commandTag: CommandTag{s: "SELECT 1"}, rowsAffected: 1, isSelect: true}, + {commandTag: CommandTag{s: "SELECT 99999999999"}, rowsAffected: 99999999999, isSelect: true}, + {commandTag: CommandTag{s: "CREATE TABLE"}, rowsAffected: 0}, + {commandTag: CommandTag{s: "ALTER TABLE"}, rowsAffected: 0}, + {commandTag: CommandTag{s: "DROP TABLE"}, rowsAffected: 0}, + } + + for i, tt := range tests { + ct := tt.commandTag + assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag) + } +} diff --git a/pgconn/pgconn_stress_test.go b/pgconn/pgconn_stress_test.go new file mode 100644 index 000000000..1bfd98585 --- /dev/null +++ b/pgconn/pgconn_stress_test.go @@ -0,0 +1,90 @@ +package pgconn_test + +import ( + "context" + "math/rand" + "os" + "runtime" + "strconv" + "testing" + + "github.com/jackc/pgx/v5/pgconn" + + "github.com/stretchr/testify/require" +) + +func TestConnStress(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + actionCount := 10000 + if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" { + stressFactor, err := strconv.ParseInt(s, 10, 64) + require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR") + actionCount *= int(stressFactor) + } + + setupStressDB(t, pgConn) + + actions := []struct { + name string + fn func(*pgconn.PgConn) error + }{ + {"Exec Select", stressExecSelect}, + {"ExecParams Select", stressExecParamsSelect}, + {"Batch", stressBatch}, + } + + for i := 0; i < actionCount; i++ { + action := actions[rand.Intn(len(actions))] + err := action.fn(pgConn) + require.Nilf(t, err, "%d: %s", i, action.name) + } + + // Each call with a context starts a goroutine. Ensure they are cleaned up when context is not canceled. + numGoroutine := runtime.NumGoroutine() + require.Truef(t, numGoroutine < 1000, "goroutines appear to be orphaned: %d in process", numGoroutine) +} + +func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { + _, err := pgConn.Exec(context.Background(), ` + create temporary table widgets( + id serial primary key, + name varchar not null, + description text, + creation_time timestamptz default now() + ); + + insert into widgets(name, description) values + ('Foo', 'bar'), + ('baz', 'Something really long Something really long Something really long Something really long Something really long'), + ('a', 'b')`).ReadAll() + require.NoError(t, err) +} + +func stressExecSelect(pgConn *pgconn.PgConn) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := pgConn.Exec(ctx, "select * from widgets").ReadAll() + return err +} + +func stressExecParamsSelect(pgConn *pgconn.PgConn) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + result := pgConn.ExecParams(ctx, "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() + return result.Err +} + +func stressBatch(pgConn *pgconn.PgConn) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + batch := &pgconn.Batch{} + + batch.ExecParams("select * from widgets", nil, nil, nil, nil) + batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + _, err := pgConn.ExecBatch(ctx, batch).ReadAll() + return err +} diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go new file mode 100644 index 000000000..4d6770f81 --- /dev/null +++ b/pgconn/pgconn_test.go @@ -0,0 +1,4289 @@ +package pgconn_test + +import ( + "bytes" + "compress/gzip" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "log" + "math" + "net" + "os" + "strconv" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/internal/pgmock" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgconn/ctxwatch" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/jackc/pgx/v5/pgtype" +) + +const pgbouncerConnStringEnvVar = "PGX_TEST_PGBOUNCER_CONN_STRING" + +func TestConnect(t *testing.T) { + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } + + conn, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) + + closeConn(t, conn) + }) + } +} + +func TestConnectWithOptions(t *testing.T) { + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } + var sslOptions pgconn.ParseConfigOptions + sslOptions.GetSSLPassword = GetSSLPassword + conn, err := pgconn.ConnectWithOptions(ctx, connString, sslOptions) + require.NoError(t, err) + + closeConn(t, conn) + }) + } +} + +// TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure +// connection. +func TestConnectTLS(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } + + conn, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) + + result := conn.ExecParams(ctx, `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, 1) + require.Len(t, result.Rows[0], 1) + require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection") + + closeConn(t, conn) +} + +func TestConnectTLSPasswordProtectedClientCertWithSSLPassword(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + connString := os.Getenv("PGX_TEST_TLS_CLIENT_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CLIENT_CONN_STRING") + } + if os.Getenv("PGX_SSL_PASSWORD") == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_SSL_PASSWORD") + } + + connString += " sslpassword=" + os.Getenv("PGX_SSL_PASSWORD") + + conn, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) + + result := conn.ExecParams(ctx, `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, 1) + require.Len(t, result.Rows[0], 1) + require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection") + + closeConn(t, conn) +} + +func TestConnectTLSPasswordProtectedClientCertWithGetSSLPasswordConfigOption(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + connString := os.Getenv("PGX_TEST_TLS_CLIENT_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CLIENT_CONN_STRING") + } + if os.Getenv("PGX_SSL_PASSWORD") == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_SSL_PASSWORD") + } + + var sslOptions pgconn.ParseConfigOptions + sslOptions.GetSSLPassword = GetSSLPassword + config, err := pgconn.ParseConfigWithOptions(connString, sslOptions) + require.Nil(t, err) + + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + + result := conn.ExecParams(ctx, `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, 1) + require.Len(t, result.Rows[0], 1) + require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection") + + closeConn(t, conn) +} + +type pgmockWaitStep time.Duration + +func (s pgmockWaitStep) Step(*pgproto3.Backend) error { + time.Sleep(time.Duration(s)) + return nil +} + +func TestConnectTimeout(t *testing.T) { + t.Parallel() + tests := []struct { + name string + connect func(connStr string) error + }{ + { + name: "via context that times out", + connect: func(connStr string) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancel() + _, err := pgconn.Connect(ctx, connStr) + return err + }, + }, + { + name: "via config ConnectTimeout", + connect: func(connStr string) error { + conf, err := pgconn.ParseConfig(connStr) + require.NoError(t, err) + conf.ConnectTimeout = time.Microsecond * 50 + _, err = pgconn.ConnectConfig(context.Background(), conf) + return err + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + script := &pgmock.Script{ + Steps: []pgmock.Step{ + pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + pgmock.SendMessage(&pgproto3.AuthenticationOk{}), + pgmockWaitStep(time.Millisecond * 500), + pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + }, + } + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(conn, conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + host, port, _ := strings.Cut(ln.Addr().String(), ":") + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + tooLate := time.Now().Add(time.Millisecond * 500) + + err = tt.connect(connStr) + require.True(t, pgconn.Timeout(err), err) + require.True(t, time.Now().Before(tooLate)) + }) + } +} + +func TestConnectTimeoutStuckOnTLSHandshake(t *testing.T) { + t.Parallel() + tests := []struct { + name string + connect func(connStr string) error + }{ + { + name: "via context that times out", + connect: func(connStr string) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + defer cancel() + _, err := pgconn.Connect(ctx, connStr) + return err + }, + }, + { + name: "via config ConnectTimeout", + connect: func(connStr string) error { + conf, err := pgconn.ParseConfig(connStr) + require.NoError(t, err) + conf.ConnectTimeout = time.Millisecond * 10 + _, err = pgconn.ConnectConfig(context.Background(), conf) + return err + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + var buf []byte + _, err = conn.Read(buf) + if err != nil { + serverErrChan <- err + return + } + + // Sleeping to hang the TLS handshake. + time.Sleep(time.Minute) + }() + + host, port, _ := strings.Cut(ln.Addr().String(), ":") + connStr := fmt.Sprintf("host=%s port=%s", host, port) + + errChan := make(chan error) + go func() { + err := tt.connect(connStr) + errChan <- err + }() + + select { + case err = <-errChan: + require.True(t, pgconn.Timeout(err), err) + case err = <-serverErrChan: + t.Fatalf("server failed with error: %s", err) + case <-time.After(time.Millisecond * 500): + t.Fatal("exceeded connection timeout without erroring out") + } + }) + } +} + +func TestConnectInvalidUser(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + config.User = "pgxinvalidusertest" + + _, err = pgconn.ConnectConfig(ctx, config) + require.Error(t, err) + var pgErr *pgconn.PgError + require.ErrorAs(t, err, &pgErr) + if pgErr.Code != "28000" && pgErr.Code != "28P01" { + t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) + } +} + +func TestConnectWithConnectionRefused(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + // Presumably nothing is listening on 127.0.0.1:1 + conn, err := pgconn.Connect(ctx, "host=127.0.0.1 port=1") + if err == nil { + conn.Close(ctx) + t.Fatal("Expected error establishing connection to bad port") + } +} + +func TestConnectCustomDialer(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + dialed := false + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialed = true + return net.Dial(network, address) + } + + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + require.True(t, dialed) + closeConn(t, conn) +} + +func TestConnectCustomLookup(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + looked := false + config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { + looked = true + return net.LookupHost(host) + } + + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) +} + +func TestConnectCustomLookupWithPort(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + origPort := config.Port + // Change the config an invalid port so it will fail if used + config.Port = 0 + + looked := false + config.LookupFunc = func(ctx context.Context, host string) ([]string, error) { + looked = true + addrs, err := net.LookupHost(host) + if err != nil { + return nil, err + } + for i := range addrs { + addrs[i] = net.JoinHostPort(addrs[i], strconv.FormatUint(uint64(origPort), 10)) + } + return addrs, nil + } + + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) +} + +func TestConnectWithRuntimeParams(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.RuntimeParams = map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + } + + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, conn) + + result := conn.ExecParams(ctx, "show application_name", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "pgxtest", string(result.Rows[0][0])) + + result = conn.ExecParams(ctx, "show search_path", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "myschema", string(result.Rows[0][0])) +} + +func TestConnectWithFallback(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + // Prepend current primary config to fallbacks + config.Fallbacks = append([]*pgconn.FallbackConfig{ + { + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) + + // Make primary config bad + config.Host = "localhost" + config.Port = 1 // presumably nothing listening here + + // Prepend bad first fallback + config.Fallbacks = append([]*pgconn.FallbackConfig{ + { + Host: "localhost", + Port: 1, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + closeConn(t, conn) +} + +func TestConnectFailsWithResolveFailureAndFailedConnectionAttempts(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, err := pgconn.Connect(ctx, "host=localhost,127.0.0.1,foo.invalid port=1,2,3 sslmode=disable") + require.Error(t, err) + require.Nil(t, conn) + + require.ErrorContains(t, err, "lookup foo.invalid") + // Not testing the entire string as depending on IPv4 or IPv6 support localhost may resolve to 127.0.0.1 or ::1. + require.ErrorContains(t, err, ":1 (localhost): dial error:") + require.ErrorContains(t, err, ":2 (127.0.0.1): dial error:") +} + +func TestConnectWithValidateConnect(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + dialCount := 0 + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialCount++ + return net.Dial(network, address) + } + + acceptConnCount := 0 + config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + acceptConnCount++ + if acceptConnCount < 2 { + return errors.New("reject first conn") + } + return nil + } + + // Append current primary config to fallbacks + config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }) + + // Repeat fallbacks + config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + closeConn(t, conn) + + assert.True(t, dialCount > 1) + assert.True(t, acceptConnCount > 1) +} + +func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite + config.RuntimeParams["default_transaction_read_only"] = "on" + + conn, err := pgconn.ConnectConfig(ctx, config) + if !assert.NotNil(t, err) { + conn.Close(ctx) + } +} + +func TestConnectWithAfterConnect(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() + return err + } + + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + + results, err := conn.Exec(ctx, "show search_path;").ReadAll() + require.NoError(t, err) + defer closeConn(t, conn) + + assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) +} + +func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config := &pgconn.Config{} + + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(ctx, config) }) +} + +func TestConnPrepareSyntaxError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + psd, err := pgConn.Prepare(ctx, "ps1", "SYNTAX ERROR", nil) + require.Nil(t, psd) + require.NotNil(t, err) + + ensureConnValid(t, pgConn) +} + +func TestConnPrepareContextPrecanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + cancel() + + psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil) + assert.Nil(t, psd) + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(err)) + + ensureConnValid(t, pgConn) +} + +func TestConnDeallocate(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(ctx, "ps1", "select 1", nil) + require.NoError(t, err) + + _, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close() + require.NoError(t, err) + + err = pgConn.Deallocate(ctx, "ps1") + require.NoError(t, err) + + _, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close() + require.Error(t, err) + var pgErr *pgconn.PgError + require.ErrorAs(t, err, &pgErr) + require.Equal(t, "26000", pgErr.Code) + + ensureConnValid(t, pgConn) +} + +func TestConnDeallocateSucceedsInAbortedTransaction(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + err = pgConn.Exec(ctx, "begin").Close() + require.NoError(t, err) + + _, err = pgConn.Prepare(ctx, "ps1", "select 1", nil) + require.NoError(t, err) + + _, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close() + require.NoError(t, err) + + err = pgConn.Exec(ctx, "select 1/0").Close() // break transaction with divide by 0 error + require.Error(t, err) + var pgErr *pgconn.PgError + require.ErrorAs(t, err, &pgErr) + require.Equal(t, "22012", pgErr.Code) + + err = pgConn.Deallocate(ctx, "ps1") + require.NoError(t, err) + + err = pgConn.Exec(ctx, "rollback").Close() + require.NoError(t, err) + + _, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close() + require.Error(t, err) + require.ErrorAs(t, err, &pgErr) + require.Equal(t, "26000", pgErr.Code) + + ensureConnValid(t, pgConn) +} + +func TestConnDeallocateNonExistantStatementSucceeds(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + err = pgConn.Deallocate(ctx, "ps1") + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestConnExec(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + results, err := pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.NoError(t, err) + + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + ensureConnValid(t, pgConn) +} + +func TestConnExecEmpty(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + results, err := pgConn.Exec(ctx, ";").ReadAll() + require.NoError(t, err) + require.Len(t, results, 1) + require.Nil(t, results[0].Err) + require.Equal(t, "", results[0].CommandTag.String()) + require.Nil(t, results[0].FieldDescriptions) + + ensureConnValid(t, pgConn) +} + +func TestConnExecMultipleQueries(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + results, err := pgConn.Exec(ctx, "select 'Hello, world'; select 1").ReadAll() + assert.NoError(t, err) + + assert.Len(t, results, 2) + + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + assert.Nil(t, results[1].Err) + assert.Equal(t, "SELECT 1", results[1].CommandTag.String()) + assert.Len(t, results[1].Rows, 1) + assert.Equal(t, "1", string(results[1].Rows[0][0])) + + ensureConnValid(t, pgConn) +} + +func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + mrr := pgConn.Exec(ctx, "select 'Hello, world' as msg; select 1 as num") + + require.True(t, mrr.NextResult()) + require.Len(t, mrr.ResultReader().FieldDescriptions(), 1) + assert.Equal(t, "msg", mrr.ResultReader().FieldDescriptions()[0].Name) + _, err = mrr.ResultReader().Close() + require.NoError(t, err) + + require.True(t, mrr.NextResult()) + require.Len(t, mrr.ResultReader().FieldDescriptions(), 1) + assert.Equal(t, "num", mrr.ResultReader().FieldDescriptions()[0].Name) + _, err = mrr.ResultReader().Close() + require.NoError(t, err) + + require.False(t, mrr.NextResult()) + + require.NoError(t, mrr.Close()) + + ensureConnValid(t, pgConn) +} + +func TestConnExecMultipleQueriesError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + results, err := pgConn.Exec(ctx, "select 1; select 1/0; select 1").ReadAll() + require.NotNil(t, err) + if pgErr, ok := err.(*pgconn.PgError); ok { + assert.Equal(t, "22012", pgErr.Code) + } else { + t.Errorf("unexpected error: %v", err) + } + + if pgConn.ParameterStatus("crdb_version") != "" { + // CockroachDB starts the second query result set and then sends the divide by zero error. + require.Len(t, results, 2) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) + assert.Len(t, results[1].Rows, 0) + } else { + // PostgreSQL sends the divide by zero and never sends the second query result set. + require.Len(t, results, 1) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) + } + + ensureConnValid(t, pgConn) +} + +func TestConnExecDeferredError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + } + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(ctx, setupSQL).ReadAll() + assert.NoError(t, err) + + _, err = pgConn.Exec(ctx, `update t set n=n+1 where id='b' returning *`).ReadAll() + require.NotNil(t, err) + + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + +func TestConnExecContextCanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + cancel() + + ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(1)") + + for multiResult.NextResult() { + } + err = multiResult.Close() + assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +func TestConnExecContextPrecanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + cancel() + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(err)) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParams(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(ctx, "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, "msg", result.FieldDescriptions()[0].Name) + + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", commandTag.String()) + assert.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsDeferredError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + } + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(ctx, setupSQL).ReadAll() + assert.NoError(t, err) + + result := pgConn.ExecParams(ctx, `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() + require.NotNil(t, result.Err) + var pgErr *pgconn.PgError + require.True(t, errors.As(result.Err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsMaxNumberOfParams(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsTooManyParams(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsCanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + } + assert.Equal(t, 0, rowCount) + commandTag, err := result.Close() + assert.Equal(t, pgconn.CommandTag{}, commandTag) + assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) + + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +func TestConnExecParamsPrecanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + cancel() + result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read() + require.Error(t, result.Err) + assert.True(t, errors.Is(result.Err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(result.Err)) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsEmptySQL(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read() + assert.Equal(t, pgconn.CommandTag{}, result.CommandTag) + assert.Len(t, result.Rows, 0) + assert.NoError(t, result.Err) + + ensureConnValid(t, pgConn) +} + +// https://github.com/jackc/pgx/issues/859 +func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(ctx, "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, "msg", result.FieldDescriptions()[0].Name) + + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + assert.Equal(t, len(result.Values()[0]), cap(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", commandTag.String()) + assert.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +// https://github.com/jackc/pgx/issues/1987 +func TestResultReaderReadNil(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(ctx, "select null::text", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Nil(t, result.Rows[0][0]) + + ensureConnValid(t, pgConn) +} + +func TestConnExecPrepared(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + psd, err := pgConn.Prepare(ctx, "ps1", "select $1::text as msg", nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPrepared(ctx, "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, "msg", result.FieldDescriptions()[0].Name) + + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", commandTag.String()) + assert.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) + + ensureConnValid(t, pgConn) +} + +func TestConnExecPreparedTooManyParams(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) + if pgConn.ParameterStatus("crdb_version") != "" { + // CockroachDB rejects preparing a statement with more than 65535 parameters. + require.EqualError(t, err, "ERROR: more than 65535 arguments to prepared statement: 65536 (SQLSTATE 08P01)") + } else { + // PostgreSQL accepts preparing a statement with more than 65535 parameters and only fails when executing it through the extended protocol. + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() + require.EqualError(t, result.Err, "extended protocol limited to 65535 parameters") + } + + ensureConnValid(t, pgConn) +} + +func TestConnExecPreparedCanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(ctx, "ps1", "select current_database(), pg_sleep(1)", nil) + require.NoError(t, err) + + ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + } + assert.Equal(t, 0, rowCount) + commandTag, err := result.Close() + assert.Equal(t, pgconn.CommandTag{}, commandTag) + assert.True(t, pgconn.Timeout(err)) + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +func TestConnExecPreparedPrecanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(ctx, "ps1", "select current_database(), pg_sleep(1)", nil) + require.NoError(t, err) + + cancel() + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() + require.Error(t, result.Err) + assert.True(t, errors.Is(result.Err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(result.Err)) + + ensureConnValid(t, pgConn) +} + +func TestConnExecPreparedEmptySQL(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(ctx, "ps1", "", nil) + require.NoError(t, err) + + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() + assert.Equal(t, pgconn.CommandTag{}, result.CommandTag) + assert.Len(t, result.Rows, 0) + assert.NoError(t, result.Err) + + ensureConnValid(t, pgConn) +} + +func TestConnExecBatch(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil) + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + results, err := pgConn.ExecBatch(ctx, batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, 3) + + require.Len(t, results[0].Rows, 1) + require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) + + require.Len(t, results[1].Rows, 1) + require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) + assert.Equal(t, "SELECT 1", results[1].CommandTag.String()) + + require.Len(t, results[2].Rows, 1) + require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) + assert.Equal(t, "SELECT 1", results[2].CommandTag.String()) +} + +type mockConnection struct { + net.Conn + writeLatency *time.Duration +} + +func (m mockConnection) Write(b []byte) (n int, err error) { + time.Sleep(*m.writeLatency) + return m.Conn.Write(b) +} + +func TestConnExecBatchWriteError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + var mockConn mockConnection + writeLatency := 0 * time.Second + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + conn, err := net.Dial(network, address) + mockConn = mockConnection{conn, &writeLatency} + return mockConn, err + } + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + batch := &pgconn.Batch{} + pgConn.Conn() + + ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel2() + + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + writeLatency = 2 * time.Second + mrr := pgConn.ExecBatch(ctx2, batch) + err = mrr.Close() + require.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) + require.True(t, pgConn.IsClosed()) +} + +func TestConnExecBatchDeferredError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + } + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(ctx, setupSQL).ReadAll() + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) + _, err = pgConn.ExecBatch(ctx, batch).ReadAll() + require.NotNil(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + +func TestConnExecBatchPrecanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil) + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + + cancel() + _, err = pgConn.ExecBatch(ctx, batch).ReadAll() + require.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(err)) + + ensureConnValid(t, pgConn) +} + +// Without concurrent reading and writing large batches can deadlock. +// +// See https://github.com/jackc/pgx/issues/374. +func TestConnExecBatchHuge(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + batch := &pgconn.Batch{} + + queryCount := 100000 + args := make([]string, queryCount) + + for i := range args { + args[i] = strconv.Itoa(i) + batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) + } + + results, err := pgConn.ExecBatch(ctx, batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, queryCount) + + for i := range args { + require.Len(t, results[i].Rows, 1) + require.Equal(t, args[i], string(results[i].Rows[0][0])) + assert.Equal(t, "SELECT 1", results[i].CommandTag.String()) + } +} + +func TestConnExecBatchImplicitTransaction(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/44803)") + } + + _, err = pgConn.Exec(ctx, "create temporary table t(id int)").ReadAll() + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) + batch.ExecParams("select 1/0", nil, nil, nil, nil) + _, err = pgConn.ExecBatch(ctx, batch).ReadAll() + require.Error(t, err) + + result := pgConn.ExecParams(ctx, "select count(*) from t", nil, nil, nil, nil).Read() + require.Equal(t, "0", string(result.Rows[0][0])) +} + +func TestConnLocking(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + mrr := pgConn.Exec(ctx, "select 'Hello, world'") + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "conn busy", err.Error()) + assert.True(t, pgconn.SafeToRetry(err)) + + results, err := mrr.ReadAll() + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + ensureConnValid(t, pgConn) +} + +func TestConnOnNotice(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + var notice *pgconn.Notice + config.OnNotice = func(c *pgconn.PgConn, n *pgconn.Notice) { + notice = n + } + config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the message we expect. + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support PL/PGSQL (https://github.com/cockroachdb/cockroach/issues/17511)") + } + + multiResult := pgConn.Exec(ctx, `do $$ +begin + raise notice 'hello, world'; +end$$;`) + err = multiResult.Close() + require.NoError(t, err) + assert.Equal(t, "NOTICE", notice.SeverityUnlocalized) + assert.Equal(t, "hello, world", notice.Message) + + ensureConnValid(t, pgConn) +} + +func TestConnOnNotification(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") + } + + _, err = pgConn.Exec(ctx, "listen foo").ReadAll() + require.NoError(t, err) + + notifier, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(ctx, "select 1").ReadAll() + require.NoError(t, err) + + assert.Equal(t, "bar", msg) + + ensureConnValid(t, pgConn) +} + +func TestConnWaitForNotification(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") + } + + _, err = pgConn.Exec(ctx, "listen foo").ReadAll() + require.NoError(t, err) + + notifier, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() + require.NoError(t, err) + + err = pgConn.WaitForNotification(ctx) + require.NoError(t, err) + + assert.Equal(t, "bar", msg) + + ensureConnValid(t, pgConn) +} + +func TestConnWaitForNotificationPrecanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + cancel() + err = pgConn.WaitForNotification(ctx) + require.ErrorIs(t, err, context.Canceled) + + ensureConnValid(t, pgConn) +} + +func TestConnWaitForNotificationTimeout(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel = context.WithTimeout(ctx, 5*time.Millisecond) + err = pgConn.WaitForNotification(ctx) + cancel() + assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToSmall(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does support COPY TO") + } + + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json + )`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(ctx, `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() + require.NoError(t, err) + + inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") + + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + + res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") + require.NoError(t, err) + + assert.Equal(t, int64(2), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToLarge(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does support COPY TO") + } + + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json, + h bytea + )`).ReadAll() + require.NoError(t, err) + + inputBytes := make([]byte, 0) + + for i := 0; i < 1000; i++ { + _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() + require.NoError(t, err) + inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) + } + + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + + res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") + require.NoError(t, err) + + assert.Equal(t, int64(1000), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToQueryError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + outputWriter := bytes.NewBuffer(make([]byte, 0)) + + res, err := pgConn.CopyTo(ctx, outputWriter, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToCanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") + } + + outputWriter := &bytes.Buffer{} + + ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") + assert.Error(t, err) + assert.Equal(t, pgconn.CommandTag{}, res) + + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +func TestConnCopyToPrecanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + outputWriter := &bytes.Buffer{} + + cancel() + res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") + require.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(err)) + assert.Equal(t, pgconn.CommandTag{}, res) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFrom(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + srcBuf := &bytes.Buffer{} + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } + + copySql := "COPY foo FROM STDIN WITH (FORMAT csv)" + if pgConn.ParameterStatus("crdb_version") != "" { + copySql = "COPY foo FROM STDIN WITH CSV" + } + ct, err := pgConn.CopyFrom(ctx, srcBuf, copySql) + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromBinary(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + buf := []byte{} + buf = append(buf, "PGCOPY\n\377\r\n\000"...) + buf = pgio.AppendInt32(buf, 0) + buf = pgio.AppendInt32(buf, 0) + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + // Number of elements in the tuple + buf = pgio.AppendInt16(buf, int16(2)) + a := i + + // Length of element for column `a int4` + buf = pgio.AppendInt32(buf, 4) + buf, err = pgtype.NewMap().Encode(pgtype.Int4OID, pgx.BinaryFormatCode, a, buf) + require.NoError(t, err) + + b := "foo " + strconv.Itoa(a) + " bar" + lenB := int32(len([]byte(b))) + // Length of element for column `b varchar` + buf = pgio.AppendInt32(buf, lenB) + buf, err = pgtype.NewMap().Encode(pgtype.VarcharOID, pgx.BinaryFormatCode, b, buf) + require.NoError(t, err) + + inputRows = append(inputRows, [][]byte{[]byte(strconv.Itoa(a)), []byte(b)}) + } + + srcBuf := &bytes.Buffer{} + srcBuf.Write(buf) + ct, err := pgConn.CopyFrom(ctx, srcBuf, "COPY foo (a, b) FROM STDIN BINARY;") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromCanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + r, w := io.Pipe() + go func() { + for i := 0; i < 1000000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + if err != nil { + return + } + time.Sleep(time.Microsecond) + } + }() + + ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond) + copySql := "COPY foo FROM STDIN WITH (FORMAT csv)" + if pgConn.ParameterStatus("crdb_version") != "" { + copySql = "COPY foo FROM STDIN WITH CSV" + } + ct, err := pgConn.CopyFrom(ctx, r, copySql) + cancel() + assert.Equal(t, int64(0), ct.RowsAffected()) + assert.Error(t, err) + + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +func TestConnCopyFromPrecanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + r, w := io.Pipe() + go func() { + for i := 0; i < 1000000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + if err != nil { + return + } + time.Sleep(time.Microsecond) + } + }() + + ctx, cancel = context.WithCancel(ctx) + cancel() + ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(err)) + assert.Equal(t, pgconn.CommandTag{}, ct) + + ensureConnValid(t, pgConn) +} + +// https://github.com/jackc/pgx/issues/2364 +func TestConnCopyFromConnectionTerminated(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support pg_terminate_backend") + } + + closerConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + time.AfterFunc(500*time.Millisecond, func() { + // defer inside of AfterFunc instead of outer test function because outer function can finish while Read is still in + // progress which could cause closerConn to be closed too soon. + defer closeConn(t, closerConn) + err := closerConn.ExecParams(ctx, "select pg_terminate_backend($1)", [][]byte{[]byte(fmt.Sprintf("%d", pgConn.PID()))}, nil, nil, nil).Read().Err + require.NoError(t, err) + }) + + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + r, w := io.Pipe() + go func() { + for i := 0; i < 5_000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + if err != nil { + return + } + time.Sleep(time.Millisecond) + } + }() + + copySql := "COPY foo FROM STDIN WITH (FORMAT csv)" + ct, err := pgConn.CopyFrom(ctx, r, copySql) + assert.Equal(t, int64(0), ct.RowsAffected()) + assert.Error(t, err) + + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +func TestConnCopyFromGzipReader(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)") + } + + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + f, err := os.CreateTemp(t.TempDir(), "*") + require.NoError(t, err) + defer f.Close() + + gw := gzip.NewWriter(f) + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } + + err = gw.Close() + require.NoError(t, err) + + _, err = f.Seek(0, 0) + require.NoError(t, err) + + gr, err := gzip.NewReader(f) + require.NoError(t, err) + + copySql := "COPY foo FROM STDIN WITH (FORMAT csv)" + if pgConn.ParameterStatus("crdb_version") != "" { + copySql = "COPY foo FROM STDIN WITH CSV" + } + ct, err := pgConn.CopyFrom(ctx, gr, copySql) + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + err = gr.Close() + require.NoError(t, err) + + result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromQuerySyntaxError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + srcBuf := &bytes.Buffer{} + + // Send data even though the COPY FROM command will be rejected with a syntax error. This ensures that this does not + // break the connection. See https://github.com/jackc/pgconn/pull/127 for context. + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } + + res, err := pgConn.CopyFrom(ctx, srcBuf, "cropy foo FROM STDIN WITH (FORMAT csv)") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromQueryNoTableError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + srcBuf := &bytes.Buffer{} + + res, err := pgConn.CopyFrom(ctx, srcBuf, "copy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + +// https://github.com/jackc/pgconn/issues/21 +func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support triggers (https://github.com/cockroachdb/cockroach/issues/28296)") + } + + _, err = pgConn.Exec(ctx, `create temporary table sentences( + t text, + ts tsvector + )`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(ctx, `create function pg_temp.sentences_trigger() returns trigger as $$ + begin + new.ts := to_tsvector(new.t); + return new; + end + $$ language plpgsql;`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(ctx, `create trigger sentences_update before insert on sentences for each row execute procedure pg_temp.sentences_trigger();`).ReadAll() + require.NoError(t, err) + + longString := make([]byte, 10001) + for i := range longString { + longString[i] = 'x' + } + + buf := &bytes.Buffer{} + for i := 0; i < 1000; i++ { + buf.Write([]byte(fmt.Sprintf("%s\n", string(longString)))) + } + + _, err = pgConn.CopyFrom(ctx, buf, "COPY sentences(t) FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) +} + +type delayedReader struct { + r io.Reader +} + +func (d delayedReader) Read(p []byte) (int, error) { + // W/o sleep test passes, with sleep it fails. + time.Sleep(time.Millisecond) + return d.r.Read(p) +} + +// https://github.com/jackc/pgconn/issues/128 +func TestConnCopyFromDataWriteAfterErrorAndReturn(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + connString := os.Getenv("PGX_TEST_DATABASE") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_DATABASE") + } + + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not fully support COPY FROM") + } + + setupSQL := `create temporary table t ( + id text primary key, + n int not null + );` + + _, err = pgConn.Exec(ctx, setupSQL).ReadAll() + assert.NoError(t, err) + + r1 := delayedReader{r: strings.NewReader(`id 0\n`)} + // Generate an error with a bogus COPY command + _, err = pgConn.CopyFrom(ctx, r1, "COPY nosuchtable FROM STDIN ") + assert.Error(t, err) + + r2 := delayedReader{r: strings.NewReader(`id 0\n`)} + _, err = pgConn.CopyFrom(ctx, r2, "COPY t FROM STDIN") + assert.NoError(t, err) +} + +func TestConnEscapeString(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + tests := []struct { + in string + out string + }{ + {in: "", out: ""}, + {in: "42", out: "42"}, + {in: "'", out: "''"}, + {in: "hi'there", out: "hi''there"}, + {in: "'hi there'", out: "''hi there''"}, + } + + for i, tt := range tests { + value, err := pgConn.EscapeString(tt.in) + if assert.NoErrorf(t, err, "%d.", i) { + assert.Equalf(t, tt.out, value, "%d.", i) + } + } + + ensureConnValid(t, pgConn) +} + +func TestConnCancelRequest(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") + } + + multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(25)") + + errChan := make(chan error) + go func() { + // The query is actually sent when multiResult.NextResult() is called. So wait to ensure it is sent. + // Once Flush is available this could use that instead. + time.Sleep(1 * time.Second) + + err := pgConn.CancelRequest(ctx) + errChan <- err + }() + + for multiResult.NextResult() { + } + err = multiResult.Close() + + require.IsType(t, &pgconn.PgError{}, err) + require.Equal(t, "57014", err.(*pgconn.PgError).Code) + + err = <-errChan + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +// https://github.com/jackc/pgx/issues/659 +func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) { + t.Parallel() + + t.Run("postgres", func(t *testing.T) { + t.Parallel() + + testConnContextCanceledCancelsRunningQueryOnServer(t, os.Getenv("PGX_TEST_DATABASE"), "postgres") + }) + + t.Run("pgbouncer", func(t *testing.T) { + t.Parallel() + + connString := os.Getenv(pgbouncerConnStringEnvVar) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", pgbouncerConnStringEnvVar) + } + + testConnContextCanceledCancelsRunningQueryOnServer(t, connString, "pgbouncer") + }) +} + +func testConnContextCanceledCancelsRunningQueryOnServer(t *testing.T, connString, dbType string) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + + // Getting the actual PostgreSQL server process ID (PID) from a query executed through pgbouncer is not straightforward + // because pgbouncer abstracts the underlying database connections, and it doesn't expose the PID of the PostgreSQL + // server process to clients. However, we can check if the query is running by checking the generated query ID. + queryID := fmt.Sprintf("%s testConnContextCanceled %d", dbType, time.Now().UnixNano()) + + multiResult := pgConn.Exec(ctx, fmt.Sprintf(` + -- %v + select 'Hello, world', pg_sleep(30) + `, queryID)) + + for multiResult.NextResult() { + } + err = multiResult.Close() + assert.True(t, pgconn.Timeout(err)) + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } + + ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + otherConn, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) + defer closeConn(t, otherConn) + + ctx, cancel = context.WithTimeout(ctx, time.Second*5) + defer cancel() + + for { + result := otherConn.ExecParams(ctx, + `select 1 from pg_stat_activity where query like $1`, + [][]byte{[]byte("%" + queryID + "%")}, + nil, + nil, + nil, + ).Read() + require.NoError(t, result.Err) + + if len(result.Rows) == 0 { + break + } + } +} + +func TestHijackAndConstruct(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + origConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + err = origConn.SyncConn(ctx) + require.NoError(t, err) + + hc, err := origConn.Hijack() + require.NoError(t, err) + + _, err = origConn.Exec(ctx, "select 'Hello, world'").ReadAll() + require.Error(t, err) + + newConn, err := pgconn.Construct(hc) + require.NoError(t, err) + + defer closeConn(t, newConn) + + results, err := newConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.NoError(t, err) + + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + ensureConnValid(t, newConn) +} + +func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + pgConn.Exec(ctx, "select n from generate_series(1,10) n") + + closeCtx, _ := context.WithCancel(ctx) + pgConn.Close(closeCtx) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +// https://github.com/jackc/pgx/issues/800 +func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + steps := pgmock.AcceptUnauthenticatedConnRequestSteps() + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Bind{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Execute{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Sync{})) + steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + {Name: []byte("mock")}, + }})) + steps = append(steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 0")})) + steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"})) + + script := &pgmock.Script{Steps: steps} + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(conn, conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + host, port, _ := strings.Cut(ln.Addr().String(), ":") + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + + ctx, cancel = context.WithTimeout(ctx, 5*time.Second) + defer cancel() + conn, err := pgconn.Connect(ctx, connStr) + require.NoError(t, err) + + rr := conn.ExecParams(ctx, "mocked...", nil, nil, nil, nil) + + for rr.NextRow() { + } + + _, err = rr.Close() + require.Error(t, err) +} + +// https://github.com/jackc/pgconn/issues/27 +func TestConnLargeResponseWhileWritingDoesNotDeadlock(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(ctx, "set client_min_messages = debug5").ReadAll() + require.NoError(t, err) + + // The actual contents of this test aren't important. What's important is a large amount of data to be written and + // because of client_min_messages = debug5 the server will return a large amount of data. + + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) + + ensureConnValid(t, pgConn) +} + +func TestConnCheckConn(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + // Intentionally using TCP connection for more predictable close behavior. (Not sure if Unix domain sockets would behave subtly different.) + + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + c1, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) + defer c1.Close(ctx) + + if c1.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") + } + + err = c1.CheckConn() + require.NoError(t, err) + + c2, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) + defer c2.Close(ctx) + + _, err = c2.Exec(ctx, fmt.Sprintf("select pg_terminate_backend(%d)", c1.PID())).ReadAll() + require.NoError(t, err) + + // It may take a while for the server to kill the backend. Retry until the error is detected or the test context is + // canceled. + for err == nil && ctx.Err() == nil { + time.Sleep(50 * time.Millisecond) + err = c1.CheckConn() + } + require.Error(t, err) +} + +func TestConnPing(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + // Intentionally using TCP connection for more predictable close behavior. (Not sure if Unix domain sockets would behave subtly different.) + + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + c1, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) + defer c1.Close(ctx) + + if c1.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") + } + + err = c1.Exec(ctx, "set log_statement = 'all'").Close() + require.NoError(t, err) + + err = c1.Ping(ctx) + require.NoError(t, err) + + c2, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) + defer c2.Close(ctx) + + _, err = c2.Exec(ctx, fmt.Sprintf("select pg_terminate_backend(%d)", c1.PID())).ReadAll() + require.NoError(t, err) + + // Give a little time for the signal to actually kill the backend. + time.Sleep(500 * time.Millisecond) + + err = c1.Ping(ctx) + require.Error(t, err) +} + +func TestPipelinePrepare(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(ctx, `create temporary table t (id text primary key)`, nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil) + pipeline.SendPrepare("selectText", "select $1::text as b", nil) + pipeline.SendPrepare("selectNoParams", "select 42 as c", nil) + pipeline.SendPrepare("insertNoResults", "insert into t (id) values ($1)", nil) + pipeline.SendPrepare("insertNoParamsOrResults", "insert into t (id) values ('foo')", nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, "a", string(sd.Fields[0].Name)) + require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + sd, ok = results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, "b", string(sd.Fields[0].Name)) + require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + sd, ok = results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, "c", string(sd.Fields[0].Name)) + require.Equal(t, []uint32{}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + sd, ok = results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 0) + require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + sd, ok = results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 0) + require.Len(t, sd.ParamOIDs, 0) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelinePrepareError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil) + pipeline.SendPrepare("selectError", "bad", nil) + pipeline.SendPrepare("selectText", "select $1::text as b", nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, "a", string(sd.Fields[0].Name)) + require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + var pgErr *pgconn.PgError + require.ErrorAs(t, err, &pgErr) + require.Nil(t, results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelinePrepareAndDeallocate(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil) + pipeline.SendDeallocate("selectInt") + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, "a", string(sd.Fields[0].Name)) + require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.CloseComplete) + require.Truef(t, ok, "expected CloseComplete, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineQuery(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "2", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "3", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "4", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "5", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelinePrepareQuery(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendPrepare("ps", "select $1::text as msg", nil) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("goodbye")}, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, "msg", string(sd.Fields[0].Name)) + require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "hello", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "goodbye", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineQueryErrorBetweenSyncs(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 6`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "2", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "3", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + var pgErr *pgconn.PgError + require.ErrorAs(t, readResult.Err, &pgErr) + require.Equal(t, "22012", pgErr.Code) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "5", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "6", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineFlushForSingleRequests(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + + pipeline.SendPrepare("ps", "select $1::text as msg", nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, "msg", string(sd.Fields[0].Name)) + require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "hello", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendDeallocate("ps") + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.CloseComplete) + require.Truef(t, ok, "expected CloseComplete, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Sync() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineFlushForRequestSeries(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendPrepare("ps", "select $1::bigint as num", nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, "num", string(sd.Fields[0].Name)) + require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("1")}, nil, nil) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("2")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "2", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("3")}, nil, nil) + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("4")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "3", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "4", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("5")}, nil, nil) + pipeline.SendFlushRequest() + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("6")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "5", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "6", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Sync() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineFlushWithError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + var pgErr *pgconn.PgError + require.ErrorAs(t, readResult.Err, &pgErr) + require.Equal(t, "22012", pgErr.Code) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendPipelineSync() + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "5", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Sync() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineCloseReadsUnreadResults(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + err = pipeline.Close() + require.EqualError(t, err, "pipeline has unsynced requests") +} + +func TestConnOnPgError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.OnPgError = func(c *pgconn.PgConn, pgErr *pgconn.PgError) bool { + require.NotNil(t, c) + require.NotNil(t, pgErr) + // close connection on undefined tables only + if pgErr.Code == "42P01" { + return false + } + return true + } + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.NoError(t, err) + assert.False(t, pgConn.IsClosed()) + + _, err = pgConn.Exec(ctx, "select 1/0").ReadAll() + assert.Error(t, err) + assert.False(t, pgConn.IsClosed()) + + _, err = pgConn.Exec(ctx, "select * from non_existant_table").ReadAll() + assert.Error(t, err) + assert.True(t, pgConn.IsClosed()) +} + +func TestConnCustomData(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pgConn.CustomData()["foo"] = "bar" + assert.Equal(t, "bar", pgConn.CustomData()["foo"]) + + ensureConnValid(t, pgConn) +} + +func Example() { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + log.Fatalln(err) + } + defer pgConn.Close(ctx) + + result := pgConn.ExecParams(ctx, "select generate_series(1,3)", nil, nil, nil, nil).Read() + if result.Err != nil { + log.Fatalln(result.Err) + } + + for _, row := range result.Rows { + fmt.Println(string(row[0])) + } + + fmt.Println(result.CommandTag) + // Output: + // 1 + // 2 + // 3 + // SELECT 3 +} + +func GetSSLPassword(ctx context.Context) string { + connString := os.Getenv("PGX_SSL_PASSWORD") + return connString +} + +var rsaCertPEM = `-----BEGIN CERTIFICATE----- +MIIDCTCCAfGgAwIBAgIUQDlN1g1bzxIJ8KWkayNcQY5gzMEwDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMDgxNTIxNDgyNloXDTIzMDgx +NTIxNDgyNlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF +AAOCAQ8AMIIBCgKCAQEA0vOppiT8zE+076acRORzD5JVbRYKMK3XlWLVrHua4+ct +Rm54WyP+3XsYU4JGGGKgb8E+u2UosGJYcSM+b+U1/5XPTcpuumS+pCiD9WP++A39 +tsukYwR7m65cgpiI4dlLEZI3EWpAW+Bb3230KiYW4sAmQ0Ih4PrN+oPvzcs86F4d +9Y03CqVUxRKLBLaClZQAg8qz2Pawwj1FKKjDX7u2fRVR0wgOugpCMOBJMcCgz9pp +0HSa4x3KZDHEZY7Pah5XwWrCfAEfRWsSTGcNaoN8gSxGFM1JOEJa8SAuPGjFcYIv +MmVWdw0FXCgYlSDL02fzLE0uyvXBDibzSqOk770JhQIDAQABo1MwUTAdBgNVHQ4E +FgQUiJ8JLENJ+2k1Xl4o6y2Lc/qHTh0wHwYDVR0jBBgwFoAUiJ8JLENJ+2k1Xl4o +6y2Lc/qHTh0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAwjn2 +gnNAhFvh58VqLIjU6ftvn6rhz5B9dg2+XyY8sskLhhkO1nL9339BVZsRt+eI3a7I +81GNIm9qHVM3MUAcQv3SZy+0UPVUT8DNH2LwHT3CHnYTBP8U+8n8TDNGSTMUhIBB +Rx+6KwODpwLdI79VGT3IkbU9bZwuepB9I9nM5t/tt5kS4gHmJFlO0aLJFCTO4Scf +hp/WLPv4XQUH+I3cPfaJRxz2j0Kc8iOzMhFmvl1XOGByjX6X33LnOzY/LVeTSGyS +VgC32BGtnMwuy5XZYgFAeUx9HKy4tG4OH2Ux6uPF/WAhsug6PXSjV7BK6wYT5i27 +MlascjupnaptKX/wMA== +-----END CERTIFICATE----- +` + +var rsaKeyPEM = testingKey(`-----BEGIN TESTING KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDS86mmJPzMT7Tv +ppxE5HMPklVtFgowrdeVYtWse5rj5y1GbnhbI/7dexhTgkYYYqBvwT67ZSiwYlhx +Iz5v5TX/lc9Nym66ZL6kKIP1Y/74Df22y6RjBHubrlyCmIjh2UsRkjcRakBb4Fvf +bfQqJhbiwCZDQiHg+s36g+/NyzzoXh31jTcKpVTFEosEtoKVlACDyrPY9rDCPUUo +qMNfu7Z9FVHTCA66CkIw4EkxwKDP2mnQdJrjHcpkMcRljs9qHlfBasJ8AR9FaxJM +Zw1qg3yBLEYUzUk4QlrxIC48aMVxgi8yZVZ3DQVcKBiVIMvTZ/MsTS7K9cEOJvNK +o6TvvQmFAgMBAAECggEAKzTK54Ol33bn2TnnwdiElIjlRE2CUswYXrl6iDRc2hbs +WAOiVRB/T/+5UMla7/2rXJhY7+rdNZs/ABU24ZYxxCJ77jPrD/Q4c8j0lhsgCtBa +ycjV543wf0dsHTd+ubtWu8eVzdRUUD0YtB+CJevdPh4a+CWgaMMV0xyYzi61T+Yv +Z7Uc3awIAiT4Kw9JRmJiTnyMJg5vZqW3BBAX4ZIvS/54ipwEU+9sWLcuH7WmCR0B +QCTqS6hfJDLm//dGC89Iyno57zfYuiT3PYCWH5crr/DH3LqnwlNaOGSBkhkXuIL+ +QvOaUMe2i0pjqxDrkBx05V554vyy9jEvK7i330HL4QKBgQDUJmouEr0+o7EMBApC +CPPu58K04qY5t9aGciG/pOurN42PF99yNZ1CnynH6DbcnzSl8rjc6Y65tzTlWods +bjwVfcmcokG7sPcivJvVjrjKpSQhL8xdZwSAjcqjN4yoJ/+ghm9w+SRmZr6oCQZ3 +1jREfJKT+PGiWTEjYcExPWUD2QKBgQD+jdgq4c3tFavU8Hjnlf75xbStr5qu+fp2 +SGLRRbX+msQwVbl2ZM9AJLoX9MTCl7D9zaI3ONhheMmfJ77lDTa3VMFtr3NevGA6 +MxbiCEfRtQpNkJnsqCixLckx3bskj5+IF9BWzw7y7nOzdhoWVFv/+TltTm3RB51G +McdlmmVjjQKBgQDSFAw2/YV6vtu2O1XxGC591/Bd8MaMBziev+wde3GHhaZfGVPC +I8dLTpMwCwowpFKdNeLLl1gnHX161I+f1vUWjw4TVjVjaBUBx+VEr2Tb/nXtiwiD +QV0a883CnGJjreAblKRMKdpasMmBWhaWmn39h6Iad3zHuCzJjaaiXNpn2QKBgQCf +k1Q8LanmQnuh1c41f7aD5gjKCRezMUpt9BrejhD1NxheJJ9LNQ8nat6uPedLBcUS +lmJms+AR2qKqf0QQWyQ98YgAtshgTz8TvQtPT1mWgSOgVFHqJdC8obNK63FyDgc4 +TZVxlgQNDqbBjfv0m5XA9f+mIlB9hYR2iKYzb4K30QKBgQC+LEJYZh00zsXttGHr +5wU1RzbgDIEsNuu+nZ4MxsaCik8ILNRHNXdeQbnADKuo6ATfhdmDIQMVZLG8Mivi +UwnwLd1GhizvqvLHa3ULnFphRyMGFxaLGV48axTT2ADoMX67ILrIY/yjycLqRZ3T +z3w+CgS20UrbLIR1YXfqUXge1g== +-----END TESTING KEY----- +`) + +func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } + +func TestSNISupport(t *testing.T) { + t.Parallel() + tests := []struct { + name string + sni_param string + sni_set bool + }{ + { + name: "SNI is passed by default", + sni_param: "", + sni_set: true, + }, + { + name: "SNI is passed when asked for", + sni_param: "sslsni=1", + sni_set: true, + }, + { + name: "SNI is not passed when disabled", + sni_param: "sslsni=0", + sni_set: false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + serverSNINameChan := make(chan string, 1) + defer close(serverErrChan) + defer close(serverSNINameChan) + + go func() { + var sniHost string + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + serverErrChan <- err + return + } + + backend := pgproto3.NewBackend(conn, conn) + startupMessage, err := backend.ReceiveStartupMessage() + if err != nil { + serverErrChan <- err + return + } + + switch startupMessage.(type) { + case *pgproto3.SSLRequest: + _, err = conn.Write([]byte("S")) + if err != nil { + serverErrChan <- err + return + } + default: + serverErrChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage) + return + } + + cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM)) + if err != nil { + serverErrChan <- err + return + } + + srv := tls.Server(conn, &tls.Config{ + Certificates: []tls.Certificate{cert}, + GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { + sniHost = argHello.ServerName + return nil, nil + }, + }) + defer srv.Close() + + if err := srv.Handshake(); err != nil { + serverErrChan <- fmt.Errorf("handshake: %w", err) + return + } + + srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil))) + srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil))) + srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))) + + serverSNINameChan <- sniHost + }() + + _, port, _ := strings.Cut(ln.Addr().String(), ":") + connStr := fmt.Sprintf("sslmode=require host=localhost port=%s %s", port, tt.sni_param) + _, err = pgconn.Connect(ctx, connStr) + + select { + case sniHost := <-serverSNINameChan: + if tt.sni_set { + require.Equal(t, "localhost", sniHost) + } else { + require.Equal(t, "", sniHost) + } + case err = <-serverErrChan: + t.Fatalf("server failed with error: %+v", err) + case <-time.After(time.Millisecond * 100): + t.Fatal("exceeded connection timeout without erroring out") + } + }) + } +} + +func TestConnectWithDirectSSLNegotiation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + connString string + expectDirectNego bool + }{ + { + name: "Default negotiation (postgres)", + connString: "sslmode=require", + expectDirectNego: false, + }, + { + name: "Direct negotiation", + connString: "sslmode=require sslnegotiation=direct", + expectDirectNego: true, + }, + { + name: "Explicit postgres negotiation", + connString: "sslmode=require sslnegotiation=postgres", + expectDirectNego: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + _, port, err := net.SplitHostPort(ln.Addr().String()) + require.NoError(t, err) + + var directNegoObserved atomic.Bool + + serverErrCh := make(chan error, 1) + go func() { + defer close(serverErrCh) + + conn, err := ln.Accept() + if err != nil { + serverErrCh <- fmt.Errorf("accept error: %w", err) + return + } + defer conn.Close() + + conn.SetDeadline(time.Now().Add(5 * time.Second)) + + firstByte := make([]byte, 1) + _, err = conn.Read(firstByte) + if err != nil { + serverErrCh <- fmt.Errorf("read first byte error: %w", err) + return + } + + // Check if TLS Client Hello (direct) or PostgreSQL SSLRequest + isDirect := firstByte[0] >= 20 && firstByte[0] <= 23 + directNegoObserved.Store(isDirect) + + var tlsConn *tls.Conn + + if !isDirect { + // Handle standard PostgreSQL SSL negotiation + // Read the rest of the SSL request message + sslRequestRemainder := make([]byte, 7) + _, err = io.ReadFull(conn, sslRequestRemainder) + if err != nil { + serverErrCh <- fmt.Errorf("read ssl request remainder error: %w", err) + return + } + + // Send SSL acceptance response + _, err = conn.Write([]byte("S")) + if err != nil { + serverErrCh <- fmt.Errorf("write ssl acceptance error: %w", err) + return + } + + // Setup TLS server without needing to reuse the first byte + cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM)) + if err != nil { + serverErrCh <- fmt.Errorf("cert error: %w", err) + return + } + + tlsConn = tls.Server(conn, &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + } else { + // Handle direct TLS negotiation + // Setup TLS server with the first byte already read + cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM)) + if err != nil { + serverErrCh <- fmt.Errorf("cert error: %w", err) + return + } + + // Use a wrapper to inject the first byte back into the TLS handshake + bufConn := &prefixConn{ + Conn: conn, + prefixData: firstByte, + } + + tlsConn = tls.Server(bufConn, &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + } + + // Complete TLS handshake + if err := tlsConn.Handshake(); err != nil { + serverErrCh <- fmt.Errorf("TLS handshake error: %w", err) + return + } + defer tlsConn.Close() + + err = script.Run(pgproto3.NewBackend(tlsConn, tlsConn)) + if err != nil { + serverErrCh <- fmt.Errorf("pgmock run error: %w", err) + return + } + }() + + connStr := fmt.Sprintf("%s host=localhost port=%s sslmode=require sslinsecure=1", + tt.connString, port) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + conn, err := pgconn.Connect(ctx, connStr) + + require.NoError(t, err) + + defer conn.Close(ctx) + + err = <-serverErrCh + require.NoError(t, err) + + require.Equal(t, tt.expectDirectNego, directNegoObserved.Load()) + }) + } +} + +// prefixConn implements a net.Conn that prepends some data to the first Read +type prefixConn struct { + net.Conn + prefixData []byte + prefixConsumed bool +} + +func (c *prefixConn) Read(b []byte) (n int, err error) { + if !c.prefixConsumed && len(c.prefixData) > 0 { + n = copy(b, c.prefixData) + c.prefixData = c.prefixData[n:] + c.prefixConsumed = len(c.prefixData) == 0 + return n, nil + } + return c.Conn.Read(b) +} + +// https://github.com/jackc/pgx/issues/1920 +func TestFatalErrorReceivedInPipelineMode(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + steps := pgmock.AcceptUnauthenticatedConnRequestSteps() + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) + steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + {Name: []byte("mock")}, + }})) + steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"})) + // We shouldn't get anything after the first fatal error. But the reported issue was with PgBouncer so maybe that + // causes the issue. Anyway, a FATAL error after the connection had already been killed could cause a panic. + steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"})) + + script := &pgmock.Script{Steps: steps} + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverKeepAlive := make(chan struct{}) + defer close(serverKeepAlive) + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(59 * time.Second)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(conn, conn)) + if err != nil { + serverErrChan <- err + return + } + + <-serverKeepAlive + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + + ctx, cancel = context.WithTimeout(ctx, 59*time.Second) + defer cancel() + conn, err := pgconn.Connect(ctx, connStr) + require.NoError(t, err) + + pipeline := conn.StartPipeline(ctx) + pipeline.SendPrepare("s1", "select 1", nil) + pipeline.SendPrepare("s2", "select 2", nil) + pipeline.SendPrepare("s3", "select 3", nil) + err = pipeline.Sync() + require.NoError(t, err) + + _, err = pipeline.GetResults() + require.NoError(t, err) + _, err = pipeline.GetResults() + require.Error(t, err) + + err = pipeline.Close() + require.Error(t, err) +} + +func mustEncode(buf []byte, err error) []byte { + if err != nil { + panic(err) + } + return buf +} + +func TestDeadlineContextWatcherHandler(t *testing.T) { + t.Parallel() + + t.Run("DeadlineExceeded with zero DeadlineDelay", func(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.DeadlineContextWatcherHandler{Conn: conn.Conn()} + } + config.ConnectTimeout = 5 * time.Second + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = pgConn.Exec(ctx, "select 1, pg_sleep(1)").ReadAll() + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.True(t, pgConn.IsClosed()) + }) + + t.Run("DeadlineExceeded with DeadlineDelay", func(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.DeadlineContextWatcherHandler{Conn: conn.Conn(), DeadlineDelay: 500 * time.Millisecond} + } + config.ConnectTimeout = 5 * time.Second + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, err = pgConn.Exec(ctx, "select 1, pg_sleep(0.250)").ReadAll() + require.NoError(t, err) + + ensureConnValid(t, pgConn) + }) +} + +func TestCancelRequestContextWatcherHandler(t *testing.T) { + t.Parallel() + + t.Run("DeadlineExceeded cancels request after CancelRequestDelay", func(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.CancelRequestContextWatcherHandler{ + Conn: conn, + CancelRequestDelay: 250 * time.Millisecond, + DeadlineDelay: 5000 * time.Millisecond, + } + } + config.ConnectTimeout = 5 * time.Second + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = pgConn.Exec(ctx, "select 1, pg_sleep(3)").ReadAll() + require.Error(t, err) + var pgErr *pgconn.PgError + require.ErrorAs(t, err, &pgErr) + + ensureConnValid(t, pgConn) + }) + + t.Run("DeadlineExceeded - do not send cancel request when query finishes in grace period", func(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.CancelRequestContextWatcherHandler{ + Conn: conn, + CancelRequestDelay: 1000 * time.Millisecond, + DeadlineDelay: 5000 * time.Millisecond, + } + } + config.ConnectTimeout = 5 * time.Second + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, err = pgConn.Exec(ctx, "select 1, pg_sleep(0.250)").ReadAll() + require.NoError(t, err) + + ensureConnValid(t, pgConn) + }) + + t.Run("DeadlineExceeded sets conn deadline with DeadlineDelay", func(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.CancelRequestContextWatcherHandler{ + Conn: conn, + CancelRequestDelay: 5000 * time.Millisecond, // purposely setting this higher than DeadlineDelay to ensure the cancel request never happens. + DeadlineDelay: 250 * time.Millisecond, + } + } + config.ConnectTimeout = 5 * time.Second + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = pgConn.Exec(ctx, "select 1, pg_sleep(1)").ReadAll() + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.True(t, pgConn.IsClosed()) + }) + + for i := 0; i < 10; i++ { + t.Run(fmt.Sprintf("Stress %d", i), func(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.CancelRequestContextWatcherHandler{ + Conn: conn, + CancelRequestDelay: 5 * time.Millisecond, + DeadlineDelay: 1000 * time.Millisecond, + } + } + config.ConnectTimeout = 5 * time.Second + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + for i := 0; i < 20; i++ { + func() { + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond) + defer cancel() + pgConn.Exec(ctx, "select 1, pg_sleep(0.010)").ReadAll() + time.Sleep(100 * time.Millisecond) // ensure a cancel request that was a little late doesn't interrupt ensureConnValid. + ensureConnValid(t, pgConn) + }() + } + }) + } +} diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go deleted file mode 100644 index 71f18852e..000000000 --- a/pgmock/pgmock.go +++ /dev/null @@ -1,552 +0,0 @@ -package pgmock - -import ( - "io" - "net" - "reflect" - - "github.com/pkg/errors" - - "github.com/jackc/pgx/pgproto3" - "github.com/jackc/pgx/pgtype" -) - -type Server struct { - ln net.Listener - controller Controller -} - -func NewServer(controller Controller) (*Server, error) { - ln, err := net.Listen("tcp", "127.0.0.1:") - if err != nil { - return nil, err - } - - server := &Server{ - ln: ln, - controller: controller, - } - - return server, nil -} - -func (s *Server) Addr() net.Addr { - return s.ln.Addr() -} - -func (s *Server) ServeOne() error { - conn, err := s.ln.Accept() - if err != nil { - return err - } - defer conn.Close() - - s.Close() - - backend, err := pgproto3.NewBackend(conn, conn) - if err != nil { - conn.Close() - return err - } - - return s.controller.Serve(backend) -} - -func (s *Server) Close() error { - err := s.ln.Close() - if err != nil { - return err - } - - return nil -} - -type Controller interface { - Serve(backend *pgproto3.Backend) error -} - -type Step interface { - Step(*pgproto3.Backend) error -} - -type Script struct { - Steps []Step -} - -func (s *Script) Run(backend *pgproto3.Backend) error { - for _, step := range s.Steps { - err := step.Step(backend) - if err != nil { - return err - } - } - - return nil -} - -func (s *Script) Serve(backend *pgproto3.Backend) error { - for _, step := range s.Steps { - err := step.Step(backend) - if err != nil { - return err - } - } - - return nil -} - -func (s *Script) Step(backend *pgproto3.Backend) error { - return s.Serve(backend) -} - -type expectMessageStep struct { - want pgproto3.FrontendMessage - any bool -} - -func (e *expectMessageStep) Step(backend *pgproto3.Backend) error { - msg, err := backend.Receive() - if err != nil { - return err - } - - if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) { - return nil - } - - if !reflect.DeepEqual(msg, e.want) { - return errors.Errorf("msg => %#v, e.want => %#v", msg, e.want) - } - - return nil -} - -type expectStartupMessageStep struct { - want *pgproto3.StartupMessage - any bool -} - -func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error { - msg, err := backend.ReceiveStartupMessage() - if err != nil { - return err - } - - if e.any { - return nil - } - - if !reflect.DeepEqual(msg, e.want) { - return errors.Errorf("msg => %#v, e.want => %#v", msg, e.want) - } - - return nil -} - -func ExpectMessage(want pgproto3.FrontendMessage) Step { - return expectMessage(want, false) -} - -func ExpectAnyMessage(want pgproto3.FrontendMessage) Step { - return expectMessage(want, true) -} - -func expectMessage(want pgproto3.FrontendMessage, any bool) Step { - if want, ok := want.(*pgproto3.StartupMessage); ok { - return &expectStartupMessageStep{want: want, any: any} - } - - return &expectMessageStep{want: want, any: any} -} - -type sendMessageStep struct { - msg pgproto3.BackendMessage -} - -func (e *sendMessageStep) Step(backend *pgproto3.Backend) error { - return backend.Send(e.msg) -} - -func SendMessage(msg pgproto3.BackendMessage) Step { - return &sendMessageStep{msg: msg} -} - -type waitForCloseMessageStep struct{} - -func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error { - for { - msg, err := backend.Receive() - if err == io.EOF { - return nil - } else if err != nil { - return err - } - - if _, ok := msg.(*pgproto3.Terminate); ok { - return nil - } - } -} - -func WaitForClose() Step { - return &waitForCloseMessageStep{} -} - -func AcceptUnauthenticatedConnRequestSteps() []Step { - return []Step{ - ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), - SendMessage(&pgproto3.Authentication{Type: pgproto3.AuthTypeOk}), - SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), - SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), - } -} - -func PgxInitSteps() []Step { - steps := []Step{ - ExpectMessage(&pgproto3.Parse{ - Query: `select t.oid, - case when nsp.nspname in ('pg_catalog', 'public') then t.typname - else nsp.nspname||'.'||t.typname - end -from pg_type t -left join pg_type base_type on t.typelem=base_type.oid -left join pg_namespace nsp on t.typnamespace=nsp.oid -where ( - t.typtype in('b', 'p', 'r', 'e') - and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) - )`, - }), - ExpectMessage(&pgproto3.Describe{ - ObjectType: 'S', - }), - ExpectMessage(&pgproto3.Sync{}), - SendMessage(&pgproto3.ParseComplete{}), - SendMessage(&pgproto3.ParameterDescription{}), - SendMessage(&pgproto3.RowDescription{ - Fields: []pgproto3.FieldDescription{ - {Name: "oid", - TableOID: 1247, - TableAttributeNumber: 65534, - DataTypeOID: 26, - DataTypeSize: 4, - TypeModifier: 4294967295, - Format: 0, - }, - {Name: "typname", - TableOID: 1247, - TableAttributeNumber: 1, - DataTypeOID: 19, - DataTypeSize: 64, - TypeModifier: 4294967295, - Format: 0, - }, - }, - }), - SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), - ExpectMessage(&pgproto3.Bind{ - ResultFormatCodes: []int16{1, 1}, - }), - ExpectMessage(&pgproto3.Execute{}), - ExpectMessage(&pgproto3.Sync{}), - SendMessage(&pgproto3.BindComplete{}), - } - - rowVals := []struct { - oid pgtype.OID - name string - }{ - {16, "bool"}, - {17, "bytea"}, - {18, "char"}, - {19, "name"}, - {20, "int8"}, - {21, "int2"}, - {22, "int2vector"}, - {23, "int4"}, - {24, "regproc"}, - {25, "text"}, - {26, "oid"}, - {27, "tid"}, - {28, "xid"}, - {29, "cid"}, - {30, "oidvector"}, - {114, "json"}, - {142, "xml"}, - {143, "_xml"}, - {199, "_json"}, - {194, "pg_node_tree"}, - {32, "pg_ddl_command"}, - {210, "smgr"}, - {600, "point"}, - {601, "lseg"}, - {602, "path"}, - {603, "box"}, - {604, "polygon"}, - {628, "line"}, - {629, "_line"}, - {700, "float4"}, - {701, "float8"}, - {702, "abstime"}, - {703, "reltime"}, - {704, "tinterval"}, - {705, "unknown"}, - {718, "circle"}, - {719, "_circle"}, - {790, "money"}, - {791, "_money"}, - {829, "macaddr"}, - {869, "inet"}, - {650, "cidr"}, - {1000, "_bool"}, - {1001, "_bytea"}, - {1002, "_char"}, - {1003, "_name"}, - {1005, "_int2"}, - {1006, "_int2vector"}, - {1007, "_int4"}, - {1008, "_regproc"}, - {1009, "_text"}, - {1028, "_oid"}, - {1010, "_tid"}, - {1011, "_xid"}, - {1012, "_cid"}, - {1013, "_oidvector"}, - {1014, "_bpchar"}, - {1015, "_varchar"}, - {1016, "_int8"}, - {1017, "_point"}, - {1018, "_lseg"}, - {1019, "_path"}, - {1020, "_box"}, - {1021, "_float4"}, - {1022, "_float8"}, - {1023, "_abstime"}, - {1024, "_reltime"}, - {1025, "_tinterval"}, - {1027, "_polygon"}, - {1033, "aclitem"}, - {1034, "_aclitem"}, - {1040, "_macaddr"}, - {1041, "_inet"}, - {651, "_cidr"}, - {1263, "_cstring"}, - {1042, "bpchar"}, - {1043, "varchar"}, - {1082, "date"}, - {1083, "time"}, - {1114, "timestamp"}, - {1115, "_timestamp"}, - {1182, "_date"}, - {1183, "_time"}, - {1184, "timestamptz"}, - {1185, "_timestamptz"}, - {1186, "interval"}, - {1187, "_interval"}, - {1231, "_numeric"}, - {1266, "timetz"}, - {1270, "_timetz"}, - {1560, "bit"}, - {1561, "_bit"}, - {1562, "varbit"}, - {1563, "_varbit"}, - {1700, "numeric"}, - {1790, "refcursor"}, - {2201, "_refcursor"}, - {2202, "regprocedure"}, - {2203, "regoper"}, - {2204, "regoperator"}, - {2205, "regclass"}, - {2206, "regtype"}, - {4096, "regrole"}, - {4089, "regnamespace"}, - {2207, "_regprocedure"}, - {2208, "_regoper"}, - {2209, "_regoperator"}, - {2210, "_regclass"}, - {2211, "_regtype"}, - {4097, "_regrole"}, - {4090, "_regnamespace"}, - {2950, "uuid"}, - {2951, "_uuid"}, - {3220, "pg_lsn"}, - {3221, "_pg_lsn"}, - {3614, "tsvector"}, - {3642, "gtsvector"}, - {3615, "tsquery"}, - {3734, "regconfig"}, - {3769, "regdictionary"}, - {3643, "_tsvector"}, - {3644, "_gtsvector"}, - {3645, "_tsquery"}, - {3735, "_regconfig"}, - {3770, "_regdictionary"}, - {3802, "jsonb"}, - {3807, "_jsonb"}, - {2970, "txid_snapshot"}, - {2949, "_txid_snapshot"}, - {3904, "int4range"}, - {3905, "_int4range"}, - {3906, "numrange"}, - {3907, "_numrange"}, - {3908, "tsrange"}, - {3909, "_tsrange"}, - {3910, "tstzrange"}, - {3911, "_tstzrange"}, - {3912, "daterange"}, - {3913, "_daterange"}, - {3926, "int8range"}, - {3927, "_int8range"}, - {2249, "record"}, - {2287, "_record"}, - {2275, "cstring"}, - {2276, "any"}, - {2277, "anyarray"}, - {2278, "void"}, - {2279, "trigger"}, - {3838, "event_trigger"}, - {2280, "language_handler"}, - {2281, "internal"}, - {2282, "opaque"}, - {2283, "anyelement"}, - {2776, "anynonarray"}, - {3500, "anyenum"}, - {3115, "fdw_handler"}, - {325, "index_am_handler"}, - {3310, "tsm_handler"}, - {3831, "anyrange"}, - {51367, "gbtreekey4"}, - {51370, "_gbtreekey4"}, - {51371, "gbtreekey8"}, - {51374, "_gbtreekey8"}, - {51375, "gbtreekey16"}, - {51378, "_gbtreekey16"}, - {51379, "gbtreekey32"}, - {51382, "_gbtreekey32"}, - {51383, "gbtreekey_var"}, - {51386, "_gbtreekey_var"}, - {51921, "hstore"}, - {51926, "_hstore"}, - {52005, "ghstore"}, - {52008, "_ghstore"}, - } - - for _, rv := range rowVals { - step := SendMessage(mustBuildDataRow([]interface{}{rv.oid, rv.name}, []int16{pgproto3.BinaryFormat})) - steps = append(steps, step) - } - - steps = append(steps, SendMessage(&pgproto3.CommandComplete{CommandTag: "SELECT 163"})) - steps = append(steps, SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})) - - steps = append(steps, []Step{ - ExpectMessage(&pgproto3.Parse{ - Query: "select t.oid, t.typname\nfrom pg_type t\n join pg_type base_type on t.typelem=base_type.oid\nwhere t.typtype = 'b'\n and base_type.typtype = 'e'", - }), - ExpectMessage(&pgproto3.Describe{ - ObjectType: 'S', - }), - ExpectMessage(&pgproto3.Sync{}), - SendMessage(&pgproto3.ParseComplete{}), - SendMessage(&pgproto3.ParameterDescription{}), - SendMessage(&pgproto3.RowDescription{ - Fields: []pgproto3.FieldDescription{ - {Name: "oid", - TableOID: 1247, - TableAttributeNumber: 65534, - DataTypeOID: 26, - DataTypeSize: 4, - TypeModifier: 4294967295, - Format: 0, - }, - {Name: "typname", - TableOID: 1247, - TableAttributeNumber: 1, - DataTypeOID: 19, - DataTypeSize: 64, - TypeModifier: 4294967295, - Format: 0, - }, - }, - }), - SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), - ExpectMessage(&pgproto3.Bind{ - ResultFormatCodes: []int16{1, 1}, - }), - ExpectMessage(&pgproto3.Execute{}), - ExpectMessage(&pgproto3.Sync{}), - SendMessage(&pgproto3.BindComplete{}), - SendMessage(&pgproto3.CommandComplete{CommandTag: "SELECT 0"}), - SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), - }...) - - return steps -} - -type dataRowValue struct { - Value interface{} - FormatCode int16 -} - -func mustBuildDataRow(values []interface{}, formatCodes []int16) *pgproto3.DataRow { - dr, err := buildDataRow(values, formatCodes) - if err != nil { - panic(err) - } - - return dr -} - -func buildDataRow(values []interface{}, formatCodes []int16) (*pgproto3.DataRow, error) { - dr := &pgproto3.DataRow{ - Values: make([][]byte, len(values)), - } - - if len(formatCodes) == 1 { - for i := 1; i < len(values); i++ { - formatCodes = append(formatCodes, formatCodes[0]) - } - } - - for i := range values { - switch v := values[i].(type) { - case string: - values[i] = &pgtype.Text{String: v, Status: pgtype.Present} - case int16: - values[i] = &pgtype.Int2{Int: v, Status: pgtype.Present} - case int32: - values[i] = &pgtype.Int4{Int: v, Status: pgtype.Present} - case int64: - values[i] = &pgtype.Int8{Int: v, Status: pgtype.Present} - } - } - - for i := range values { - switch formatCodes[i] { - case pgproto3.TextFormat: - if e, ok := values[i].(pgtype.TextEncoder); ok { - buf, err := e.EncodeText(nil, nil) - if err != nil { - return nil, errors.Errorf("failed to encode values[%d]", i) - } - dr.Values[i] = buf - } else { - return nil, errors.Errorf("values[%d] does not implement TextExcoder", i) - } - - case pgproto3.BinaryFormat: - if e, ok := values[i].(pgtype.BinaryEncoder); ok { - buf, err := e.EncodeBinary(nil, nil) - if err != nil { - return nil, errors.Errorf("failed to encode values[%d]", i) - } - dr.Values[i] = buf - } else { - return nil, errors.Errorf("values[%d] does not implement BinaryEncoder", i) - } - default: - return nil, errors.New("unknown FormatCode") - } - } - - return dr, nil -} diff --git a/pgpass.go b/pgpass.go deleted file mode 100644 index b6f028d27..000000000 --- a/pgpass.go +++ /dev/null @@ -1,85 +0,0 @@ -package pgx - -import ( - "bufio" - "fmt" - "os" - "os/user" - "path/filepath" - "strings" -) - -func parsepgpass(cfg *ConnConfig, line string) *string { - const ( - backslash = "\r" - colon = "\n" - ) - const ( - host int = iota - port - database - username - pw - ) - line = strings.Replace(line, `\:`, colon, -1) - line = strings.Replace(line, `\\`, backslash, -1) - parts := strings.Split(line, `:`) - if len(parts) != 5 { - return nil - } - for i := range parts { - if parts[i] == `*` { - continue - } - parts[i] = strings.Replace(strings.Replace(parts[i], backslash, `\`, -1), colon, `:`, -1) - switch i { - case host: - if parts[i] != cfg.Host { - return nil - } - case port: - portstr := fmt.Sprintf(`%v`, cfg.Port) - if portstr == "0" { - portstr = "5432" - } - if parts[i] != portstr { - return nil - } - case database: - if parts[i] != cfg.Database { - return nil - } - case username: - if parts[i] != cfg.User { - return nil - } - } - } - return &parts[4] -} - -func pgpass(cfg *ConnConfig) (found bool) { - passfile := os.Getenv("PGPASSFILE") - if passfile == "" { - u, err := user.Current() - if err != nil { - return - } - passfile = filepath.Join(u.HomeDir, ".pgpass") - } - f, err := os.Open(passfile) - if err != nil { - return - } - defer f.Close() - scanner := bufio.NewScanner(f) - var pw *string - for scanner.Scan() { - pw = parsepgpass(cfg, scanner.Text()) - if pw != nil { - cfg.Password = *pw - return true - } - } - return false -} diff --git a/pgpass_test.go b/pgpass_test.go deleted file mode 100644 index d36e811af..000000000 --- a/pgpass_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package pgx - -import ( - "fmt" - "io/ioutil" - "os" - "strings" - "testing" -) - -func unescape(s string) string { - s = strings.Replace(s, `\:`, `:`, -1) - s = strings.Replace(s, `\\`, `\`, -1) - return s -} - -var passfile = [][]string{ - {"test1", "5432", "larrydb", "larry", "whatstheidea"}, - {"test1", "5432", "moedb", "moe", "imbecile"}, - {"test1", "5432", "curlydb", "curly", "nyuknyuknyuk"}, - {"test2", "5432", "*", "shemp", "heymoe"}, - {"test2", "5432", "*", "*", `test\\ing\:`}, -} - -func TestPGPass(t *testing.T) { - tf, err := ioutil.TempFile("", "") - if err != nil { - t.Fatal(err) - } - defer tf.Close() - defer os.Remove(tf.Name()) - os.Setenv("PGPASSFILE", tf.Name()) - for _, l := range passfile { - _, err := fmt.Fprintln(tf, strings.Join(l, `:`)) - if err != nil { - t.Fatal(err) - } - } - if err = tf.Close(); err != nil { - t.Fatal(err) - } - for i, l := range passfile { - cfg := ConnConfig{Host: l[0], Database: l[2], User: l[3]} - found := pgpass(&cfg) - if !found { - t.Fatalf("Entry %v not found", i) - } - if cfg.Password != unescape(l[4]) { - t.Fatalf(`Password mismatch entry %v want %s got %s`, i, unescape(l[4]), cfg.Password) - } - } - cfg := ConnConfig{Host: "derp", Database: "herp", User: "joe"} - found := pgpass(&cfg) - if found { - t.Fatal("bad found") - } -} diff --git a/pgproto3/README.md b/pgproto3/README.md new file mode 100644 index 000000000..7a26f1cbd --- /dev/null +++ b/pgproto3/README.md @@ -0,0 +1,7 @@ +# pgproto3 + +Package pgproto3 is an encoder and decoder of the PostgreSQL wire protocol version 3. + +pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more. + +See example/pgfortune for a playful example of a fake PostgreSQL server. diff --git a/pgproto3/authentication.go b/pgproto3/authentication.go deleted file mode 100644 index 77750b862..000000000 --- a/pgproto3/authentication.go +++ /dev/null @@ -1,54 +0,0 @@ -package pgproto3 - -import ( - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -const ( - AuthTypeOk = 0 - AuthTypeCleartextPassword = 3 - AuthTypeMD5Password = 5 -) - -type Authentication struct { - Type uint32 - - // MD5Password fields - Salt [4]byte -} - -func (*Authentication) Backend() {} - -func (dst *Authentication) Decode(src []byte) error { - *dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])} - - switch dst.Type { - case AuthTypeOk: - case AuthTypeCleartextPassword: - case AuthTypeMD5Password: - copy(dst.Salt[:], src[4:8]) - default: - return errors.Errorf("unknown authentication type: %d", dst.Type) - } - - return nil -} - -func (src *Authentication) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - dst = pgio.AppendUint32(dst, src.Type) - - switch src.Type { - case AuthTypeMD5Password: - dst = append(dst, src.Salt[:]...) - } - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} diff --git a/pgproto3/authentication_cleartext_password.go b/pgproto3/authentication_cleartext_password.go new file mode 100644 index 000000000..415e1a24a --- /dev/null +++ b/pgproto3/authentication_cleartext_password.go @@ -0,0 +1,50 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required. +type AuthenticationCleartextPassword struct{} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationCleartextPassword) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationCleartextPassword) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationCleartextPassword) Decode(src []byte) error { + if len(src) != 4 { + return errors.New("bad authentication message size") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeCleartextPassword { + return errors.New("bad auth type") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') + dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword) + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationCleartextPassword) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "AuthenticationCleartextPassword", + }) +} diff --git a/pgproto3/authentication_gss.go b/pgproto3/authentication_gss.go new file mode 100644 index 000000000..178ef31d8 --- /dev/null +++ b/pgproto3/authentication_gss.go @@ -0,0 +1,58 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type AuthenticationGSS struct{} + +func (a *AuthenticationGSS) Backend() {} + +func (a *AuthenticationGSS) AuthenticationResponse() {} + +func (a *AuthenticationGSS) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeGSS { + return errors.New("bad auth type") + } + return nil +} + +func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') + dst = pgio.AppendUint32(dst, AuthTypeGSS) + return finishMessage(dst, sp) +} + +func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "AuthenticationGSS", + }) +} + +func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + return nil +} diff --git a/pgproto3/authentication_gss_continue.go b/pgproto3/authentication_gss_continue.go new file mode 100644 index 000000000..2ba3f3b3e --- /dev/null +++ b/pgproto3/authentication_gss_continue.go @@ -0,0 +1,67 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type AuthenticationGSSContinue struct { + Data []byte +} + +func (a *AuthenticationGSSContinue) Backend() {} + +func (a *AuthenticationGSSContinue) AuthenticationResponse() {} + +func (a *AuthenticationGSSContinue) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeGSSCont { + return errors.New("bad auth type") + } + + a.Data = src[4:] + return nil +} + +func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') + dst = pgio.AppendUint32(dst, AuthTypeGSSCont) + dst = append(dst, a.Data...) + return finishMessage(dst, sp) +} + +func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "AuthenticationGSSContinue", + Data: a.Data, + }) +} + +func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Data []byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + a.Data = msg.Data + return nil +} diff --git a/pgproto3/authentication_md5_password.go b/pgproto3/authentication_md5_password.go new file mode 100644 index 000000000..854c6404e --- /dev/null +++ b/pgproto3/authentication_md5_password.go @@ -0,0 +1,76 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required. +type AuthenticationMD5Password struct { + Salt [4]byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationMD5Password) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationMD5Password) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationMD5Password) Decode(src []byte) error { + if len(src) != 8 { + return errors.New("bad authentication message size") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeMD5Password { + return errors.New("bad auth type") + } + + copy(dst.Salt[:], src[4:8]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') + dst = pgio.AppendUint32(dst, AuthTypeMD5Password) + dst = append(dst, src.Salt[:]...) + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationMD5Password) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Salt [4]byte + }{ + Type: "AuthenticationMD5Password", + Salt: src.Salt, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Salt [4]byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Salt = msg.Salt + return nil +} diff --git a/pgproto3/authentication_ok.go b/pgproto3/authentication_ok.go new file mode 100644 index 000000000..98c0b2d66 --- /dev/null +++ b/pgproto3/authentication_ok.go @@ -0,0 +1,50 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// AuthenticationOk is a message sent from the backend indicating that authentication was successful. +type AuthenticationOk struct{} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationOk) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationOk) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationOk) Decode(src []byte) error { + if len(src) != 4 { + return errors.New("bad authentication message size") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeOk { + return errors.New("bad auth type") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') + dst = pgio.AppendUint32(dst, AuthTypeOk) + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationOk) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "AuthenticationOK", + }) +} diff --git a/pgproto3/authentication_sasl.go b/pgproto3/authentication_sasl.go new file mode 100644 index 000000000..e66580f44 --- /dev/null +++ b/pgproto3/authentication_sasl.go @@ -0,0 +1,72 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required. +type AuthenticationSASL struct { + AuthMechanisms []string +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationSASL) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationSASL) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationSASL) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeSASL { + return errors.New("bad auth type") + } + + authMechanisms := src[4:] + for len(authMechanisms) > 1 { + idx := bytes.IndexByte(authMechanisms, 0) + if idx == -1 { + return &invalidMessageFormatErr{messageType: "AuthenticationSASL", details: "unterminated string"} + } + dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx])) + authMechanisms = authMechanisms[idx+1:] + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') + dst = pgio.AppendUint32(dst, AuthTypeSASL) + + for _, s := range src.AuthMechanisms { + dst = append(dst, []byte(s)...) + dst = append(dst, 0) + } + dst = append(dst, 0) + + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationSASL) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + AuthMechanisms []string + }{ + Type: "AuthenticationSASL", + AuthMechanisms: src.AuthMechanisms, + }) +} diff --git a/pgproto3/authentication_sasl_continue.go b/pgproto3/authentication_sasl_continue.go new file mode 100644 index 000000000..70fba4a67 --- /dev/null +++ b/pgproto3/authentication_sasl_continue.go @@ -0,0 +1,75 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge. +type AuthenticationSASLContinue struct { + Data []byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationSASLContinue) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationSASLContinue) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationSASLContinue) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeSASLContinue { + return errors.New("bad auth type") + } + + dst.Data = src[4:] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') + dst = pgio.AppendUint32(dst, AuthTypeSASLContinue) + dst = append(dst, src.Data...) + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationSASLContinue) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "AuthenticationSASLContinue", + Data: string(src.Data), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil +} diff --git a/pgproto3/authentication_sasl_final.go b/pgproto3/authentication_sasl_final.go new file mode 100644 index 000000000..84976c2a3 --- /dev/null +++ b/pgproto3/authentication_sasl_final.go @@ -0,0 +1,75 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed. +type AuthenticationSASLFinal struct { + Data []byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationSASLFinal) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationSASLFinal) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationSASLFinal) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeSASLFinal { + return errors.New("bad auth type") + } + + dst.Data = src[4:] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') + dst = pgio.AppendUint32(dst, AuthTypeSASLFinal) + dst = append(dst, src.Data...) + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Unmarshaler. +func (src AuthenticationSASLFinal) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "AuthenticationSASLFinal", + Data: string(src.Data), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil +} diff --git a/pgproto3/backend.go b/pgproto3/backend.go index 8f3c3478a..28cff049a 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -1,74 +1,190 @@ package pgproto3 import ( + "bytes" "encoding/binary" + "fmt" "io" - - "github.com/jackc/pgx/chunkreader" - "github.com/pkg/errors" ) +// Backend acts as a server for the PostgreSQL wire protocol version 3. type Backend struct { - cr *chunkreader.ChunkReader + cr *chunkReader w io.Writer + // tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced + // before it is actually transmitted (i.e. before Flush). + tracer *tracer + + wbuf []byte + encodeError error + // Frontend message flyweights - bind Bind - _close Close - describe Describe - execute Execute - flush Flush - parse Parse - passwordMessage PasswordMessage - query Query - startupMessage StartupMessage - sync Sync - terminate Terminate + bind Bind + cancelRequest CancelRequest + _close Close + copyFail CopyFail + copyData CopyData + copyDone CopyDone + describe Describe + execute Execute + flush Flush + functionCall FunctionCall + gssEncRequest GSSEncRequest + parse Parse + query Query + sslRequest SSLRequest + startupMessage StartupMessage + sync Sync + terminate Terminate bodyLen int + maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error. msgType byte partialMsg bool + authType uint32 } -func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { - cr := chunkreader.NewChunkReader(r) - return &Backend{cr: cr, w: w}, nil +const ( + minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code. + maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source. +) + +// NewBackend creates a new Backend. +func NewBackend(r io.Reader, w io.Writer) *Backend { + cr := newChunkReader(r, 0) + return &Backend{cr: cr, w: w} } -func (b *Backend) Send(msg BackendMessage) error { - _, err := b.w.Write(msg.Encode(nil)) - return err +// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error +// encountered will be returned from Flush. +func (b *Backend) Send(msg BackendMessage) { + if b.encodeError != nil { + return + } + + prevLen := len(b.wbuf) + newBuf, err := msg.Encode(b.wbuf) + if err != nil { + b.encodeError = err + return + } + b.wbuf = newBuf + + if b.tracer != nil { + b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg) + } } -func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { +// Flush writes any pending messages to the frontend (i.e. the client). +func (b *Backend) Flush() error { + if err := b.encodeError; err != nil { + b.encodeError = nil + b.wbuf = b.wbuf[:0] + return &writeError{err: err, safeToRetry: true} + } + + n, err := b.w.Write(b.wbuf) + + const maxLen = 1024 + if len(b.wbuf) > maxLen { + b.wbuf = make([]byte, 0, maxLen) + } else { + b.wbuf = b.wbuf[:0] + } + + if err != nil { + return &writeError{err: err, safeToRetry: n == 0} + } + + return nil +} + +// Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function +// PQtrace. +func (b *Backend) Trace(w io.Writer, options TracerOptions) { + b.tracer = &tracer{ + w: w, + buf: &bytes.Buffer{}, + TracerOptions: options, + } +} + +// Untrace stops tracing. +func (b *Backend) Untrace() { + b.tracer = nil +} + +// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method +// because the initial connection message is "special" and does not include the message type as the first byte. This +// will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest. +func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { buf, err := b.cr.Next(4) if err != nil { return nil, err } msgSize := int(binary.BigEndian.Uint32(buf) - 4) - buf, err = b.cr.Next(msgSize) - if err != nil { - return nil, err + if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen { + return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize) } - err = b.startupMessage.Decode(buf) + buf, err = b.cr.Next(msgSize) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } - return &b.startupMessage, nil + code := binary.BigEndian.Uint32(buf) + + switch code { + case ProtocolVersionNumber: + err = b.startupMessage.Decode(buf) + if err != nil { + return nil, err + } + return &b.startupMessage, nil + case sslRequestNumber: + err = b.sslRequest.Decode(buf) + if err != nil { + return nil, err + } + return &b.sslRequest, nil + case cancelRequestCode: + err = b.cancelRequest.Decode(buf) + if err != nil { + return nil, err + } + return &b.cancelRequest, nil + case gssEncReqNumber: + err = b.gssEncRequest.Decode(buf) + if err != nil { + return nil, err + } + return &b.gssEncRequest, nil + default: + return nil, fmt.Errorf("unknown startup message code: %d", code) + } } +// Receive receives a message from the frontend. The returned message is only valid until the next call to Receive. func (b *Backend) Receive() (FrontendMessage, error) { if !b.partialMsg { header, err := b.cr.Next(5) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } b.msgType = header[0] - b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + + msgLength := int(binary.BigEndian.Uint32(header[1:])) + if msgLength < 4 { + return nil, fmt.Errorf("invalid message length: %d", msgLength) + } + + b.bodyLen = msgLength - 4 + if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen { + return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen} + } b.partialMsg = true } @@ -82,12 +198,34 @@ func (b *Backend) Receive() (FrontendMessage, error) { msg = &b.describe case 'E': msg = &b.execute + case 'F': + msg = &b.functionCall + case 'f': + msg = &b.copyFail + case 'd': + msg = &b.copyData + case 'c': + msg = &b.copyDone case 'H': msg = &b.flush case 'P': msg = &b.parse case 'p': - msg = &b.passwordMessage + switch b.authType { + case AuthTypeSASL: + msg = &SASLInitialResponse{} + case AuthTypeSASLContinue: + msg = &SASLResponse{} + case AuthTypeSASLFinal: + msg = &SASLResponse{} + case AuthTypeGSS, AuthTypeGSSCont: + msg = &GSSResponse{} + case AuthTypeCleartextPassword, AuthTypeMD5Password: + fallthrough + default: + // to maintain backwards compatibility + msg = &PasswordMessage{} + } case 'Q': msg = &b.query case 'S': @@ -95,16 +233,67 @@ func (b *Backend) Receive() (FrontendMessage, error) { case 'X': msg = &b.terminate default: - return nil, errors.Errorf("unknown message type: %c", b.msgType) + return nil, fmt.Errorf("unknown message type: %c", b.msgType) } msgBody, err := b.cr.Next(b.bodyLen) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } b.partialMsg = false err = msg.Decode(msgBody) - return msg, err + if err != nil { + return nil, err + } + + if b.tracer != nil { + b.tracer.traceMessage('F', int32(5+len(msgBody)), msg) + } + + return msg, nil +} + +// SetAuthType sets the authentication type in the backend. +// Since multiple message types can start with 'p', SetAuthType allows +// contextual identification of FrontendMessages. For example, in the +// PG message flow documentation for PasswordMessage: +// +// Byte1('p') +// +// Identifies the message as a password response. Note that this is also used for +// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from +// the context. +// +// Since the Frontend does not know about the state of a backend, it is important +// to call SetAuthType() after an authentication request is received by the Frontend. +func (b *Backend) SetAuthType(authType uint32) error { + switch authType { + case AuthTypeOk, + AuthTypeCleartextPassword, + AuthTypeMD5Password, + AuthTypeSCMCreds, + AuthTypeGSS, + AuthTypeGSSCont, + AuthTypeSSPI, + AuthTypeSASL, + AuthTypeSASLContinue, + AuthTypeSASLFinal: + b.authType = authType + default: + return fmt.Errorf("authType not recognized: %d", authType) + } + + return nil +} + +// SetMaxBodyLen sets the maximum length of a message body in octets. +// If a message body exceeds this length, Receive will return an error. +// This is useful for protecting against malicious clients that send +// large messages with the intent of causing memory exhaustion. +// The default value is 0. +// If maxBodyLen is 0, then no maximum is enforced. +func (b *Backend) SetMaxBodyLen(maxBodyLen int) { + b.maxBodyLen = maxBodyLen } diff --git a/pgproto3/backend_key_data.go b/pgproto3/backend_key_data.go index 5a478f107..23f5da677 100644 --- a/pgproto3/backend_key_data.go +++ b/pgproto3/backend_key_data.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type BackendKeyData struct { @@ -12,8 +12,11 @@ type BackendKeyData struct { SecretKey uint32 } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*BackendKeyData) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *BackendKeyData) Decode(src []byte) error { if len(src) != 8 { return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} @@ -25,15 +28,16 @@ func (dst *BackendKeyData) Decode(src []byte) error { return nil } -func (src *BackendKeyData) Encode(dst []byte) []byte { - dst = append(dst, 'K') - dst = pgio.AppendUint32(dst, 12) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'K') dst = pgio.AppendUint32(dst, src.ProcessID) dst = pgio.AppendUint32(dst, src.SecretKey) - return dst + return finishMessage(dst, sp) } -func (src *BackendKeyData) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src BackendKeyData) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ProcessID uint32 diff --git a/pgproto3/backend_test.go b/pgproto3/backend_test.go index 02a5e9cae..5107ef76a 100644 --- a/pgproto3/backend_test.go +++ b/pgproto3/backend_test.go @@ -1,9 +1,13 @@ package pgproto3_test import ( + "io" "testing" - "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBackendReceiveInterrupted(t *testing.T) { @@ -12,10 +16,7 @@ func TestBackendReceiveInterrupted(t *testing.T) { server := &interruptReader{} server.push([]byte{'Q', 0, 0, 0, 6}) - backend, err := pgproto3.NewBackend(server, nil) - if err != nil { - t.Fatal(err) - } + backend := pgproto3.NewBackend(server, nil) msg, err := backend.Receive() if err == nil { @@ -35,3 +36,105 @@ func TestBackendReceiveInterrupted(t *testing.T) { t.Fatalf("unexpected msg: %v", msg) } } + +func TestBackendReceiveUnexpectedEOF(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Q', 0, 0, 0, 6}) + + backend := pgproto3.NewBackend(server, nil) + + // Receive regular msg + msg, err := backend.Receive() + assert.Nil(t, msg) + assert.Equal(t, io.ErrUnexpectedEOF, err) + + // Receive StartupMessage msg + dst := []byte{} + dst = pgio.AppendUint32(dst, 1000) // tell the backend we expect 1000 bytes to be read + dst = pgio.AppendUint32(dst, 1) // only send 1 byte + server.push(dst) + + msg, err = backend.ReceiveStartupMessage() + assert.Nil(t, msg) + assert.Equal(t, io.ErrUnexpectedEOF, err) +} + +func TestStartupMessage(t *testing.T) { + t.Parallel() + + t.Run("valid StartupMessage", func(t *testing.T) { + want := &pgproto3.StartupMessage{ + ProtocolVersion: pgproto3.ProtocolVersionNumber, + Parameters: map[string]string{ + "username": "tester", + }, + } + dst, err := want.Encode([]byte{}) + require.NoError(t, err) + + server := &interruptReader{} + server.push(dst) + + backend := pgproto3.NewBackend(server, nil) + + msg, err := backend.ReceiveStartupMessage() + require.NoError(t, err) + require.Equal(t, want, msg) + }) + + t.Run("invalid packet length", func(t *testing.T) { + wantErr := "invalid length of startup packet" + tests := []struct { + name string + packetLen uint32 + }{ + { + name: "large packet length", + // Since the StartupMessage contains the "Length of message contents + // in bytes, including self", the max startup packet length is actually + // 10000+4. Therefore, let's go past the limit with 10005 + packetLen: 10005, + }, + { + name: "short packet length", + packetLen: 3, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := &interruptReader{} + dst := []byte{} + dst = pgio.AppendUint32(dst, tt.packetLen) + dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber) + server.push(dst) + + backend := pgproto3.NewBackend(server, nil) + + msg, err := backend.ReceiveStartupMessage() + require.Error(t, err) + require.Nil(t, msg) + require.Contains(t, err.Error(), wantErr) + }) + } + }) +} + +func TestBackendReceiveExceededMaxBodyLen(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Q', 0, 0, 10, 10}) + + backend := pgproto3.NewBackend(server, nil) + + // Set max body len to 5 + backend.SetMaxBodyLen(5) + + // Receive regular msg + msg, err := backend.Receive() + assert.Nil(t, msg) + var invalidBodyLenErr *pgproto3.ExceededMaxBodyLenErr + assert.ErrorAs(t, err, &invalidBodyLenErr) +} diff --git a/pgproto3/bind.go b/pgproto3/bind.go index cceee6abd..ad6ac48bf 100644 --- a/pgproto3/bind.go +++ b/pgproto3/bind.go @@ -5,8 +5,11 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "errors" + "fmt" + "math" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type Bind struct { @@ -17,8 +20,11 @@ type Bind struct { ResultFormatCodes []int16 } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Bind) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Bind) Decode(src []byte) error { *dst = Bind{} @@ -103,21 +109,26 @@ func (dst *Bind) Decode(src []byte) error { return nil } -func (src *Bind) Encode(dst []byte) []byte { - dst = append(dst, 'B') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Bind) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'B') dst = append(dst, src.DestinationPortal...) dst = append(dst, 0) dst = append(dst, src.PreparedStatement...) dst = append(dst, 0) + if len(src.ParameterFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many parameter format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes))) for _, fc := range src.ParameterFormatCodes { dst = pgio.AppendInt16(dst, fc) } + if len(src.Parameters) > math.MaxUint16 { + return nil, errors.New("too many parameters") + } dst = pgio.AppendUint16(dst, uint16(len(src.Parameters))) for _, p := range src.Parameters { if p == nil { @@ -129,24 +140,33 @@ func (src *Bind) Encode(dst []byte) []byte { dst = append(dst, p...) } + if len(src.ResultFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many result format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes))) for _, fc := range src.ResultFormatCodes { dst = pgio.AppendInt16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } -func (src *Bind) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src Bind) MarshalJSON() ([]byte, error) { formattedParameters := make([]map[string]string, len(src.Parameters)) for i, p := range src.Parameters { if p == nil { continue } - if src.ParameterFormatCodes[i] == 0 { + textFormat := true + if len(src.ParameterFormatCodes) == 1 { + textFormat = src.ParameterFormatCodes[0] == 0 + } else if len(src.ParameterFormatCodes) > 1 { + textFormat = src.ParameterFormatCodes[i] == 0 + } + + if textFormat { formattedParameters[i] = map[string]string{"text": string(p)} } else { formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)} @@ -169,3 +189,35 @@ func (src *Bind) MarshalJSON() ([]byte, error) { ResultFormatCodes: src.ResultFormatCodes, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Bind) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters []map[string]string + ResultFormatCodes []int16 + } + err := json.Unmarshal(data, &msg) + if err != nil { + return err + } + dst.DestinationPortal = msg.DestinationPortal + dst.PreparedStatement = msg.PreparedStatement + dst.ParameterFormatCodes = msg.ParameterFormatCodes + dst.Parameters = make([][]byte, len(msg.Parameters)) + dst.ResultFormatCodes = msg.ResultFormatCodes + for n, parameter := range msg.Parameters { + dst.Parameters[n], err = getValueFromJSON(parameter) + if err != nil { + return fmt.Errorf("cannot get param %d: %w", n, err) + } + } + return nil +} diff --git a/pgproto3/bind_complete.go b/pgproto3/bind_complete.go index 603605195..bacf30d88 100644 --- a/pgproto3/bind_complete.go +++ b/pgproto3/bind_complete.go @@ -6,8 +6,11 @@ import ( type BindComplete struct{} +// Backend identifies this message as sendable by the PostgreSQL backend. func (*BindComplete) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *BindComplete) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)} @@ -16,11 +19,13 @@ func (dst *BindComplete) Decode(src []byte) error { return nil } -func (src *BindComplete) Encode(dst []byte) []byte { - return append(dst, '2', 0, 0, 0, 4) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *BindComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '2', 0, 0, 0, 4), nil } -func (src *BindComplete) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src BindComplete) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/pgproto3/bind_test.go b/pgproto3/bind_test.go new file mode 100644 index 000000000..6ec0e0245 --- /dev/null +++ b/pgproto3/bind_test.go @@ -0,0 +1,20 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +func TestBindBiggerThanMaxMessageBodyLen(t *testing.T) { + t.Parallel() + + // Maximum allowed size. + _, err := (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-16)}}).Encode(nil) + require.NoError(t, err) + + // 1 byte too big + _, err = (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-15)}}).Encode(nil) + require.Error(t, err) +} diff --git a/pgproto3/cancel_request.go b/pgproto3/cancel_request.go new file mode 100644 index 000000000..6b52dd977 --- /dev/null +++ b/pgproto3/cancel_request.go @@ -0,0 +1,58 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +const cancelRequestCode = 80877102 + +type CancelRequest struct { + ProcessID uint32 + SecretKey uint32 +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*CancelRequest) Frontend() {} + +func (dst *CancelRequest) Decode(src []byte) error { + if len(src) != 12 { + return errors.New("bad cancel request size") + } + + requestCode := binary.BigEndian.Uint32(src) + + if requestCode != cancelRequestCode { + return errors.New("bad cancel request code") + } + + dst.ProcessID = binary.BigEndian.Uint32(src[4:]) + dst.SecretKey = binary.BigEndian.Uint32(src[8:]) + + return nil +} + +// Encode encodes src into dst. dst will include the 4 byte message length. +func (src *CancelRequest) Encode(dst []byte) ([]byte, error) { + dst = pgio.AppendInt32(dst, 16) + dst = pgio.AppendInt32(dst, cancelRequestCode) + dst = pgio.AppendUint32(dst, src.ProcessID) + dst = pgio.AppendUint32(dst, src.SecretKey) + return dst, nil +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CancelRequest) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProcessID uint32 + SecretKey uint32 + }{ + Type: "CancelRequest", + ProcessID: src.ProcessID, + SecretKey: src.SecretKey, + }) +} diff --git a/pgproto3/chunkreader.go b/pgproto3/chunkreader.go new file mode 100644 index 000000000..fc0fa61e9 --- /dev/null +++ b/pgproto3/chunkreader.go @@ -0,0 +1,90 @@ +package pgproto3 + +import ( + "io" + + "github.com/jackc/pgx/v5/internal/iobufpool" +) + +// chunkReader is a io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and +// will read as much as will fit in the current buffer in a single call regardless of how large a read is actually +// requested. The memory returned via Next is only valid until the next call to Next. +// +// This is roughly equivalent to a bufio.Reader that only uses Peek and Discard to never copy bytes. +type chunkReader struct { + r io.Reader + + buf *[]byte + rp, wp int // buf read position and write position + + minBufSize int +} + +// newChunkReader creates and returns a new chunkReader for r with default configuration. If minBufSize is <= 0 it uses +// a default value. +func newChunkReader(r io.Reader, minBufSize int) *chunkReader { + if minBufSize <= 0 { + // By historical reasons Postgres currently has 8KB send buffer inside, + // so here we want to have at least the same size buffer. + // @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134 + // @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru + // + // In addition, testing has found no benefit of any larger buffer. + minBufSize = 8192 + } + + return &chunkReader{ + r: r, + minBufSize: minBufSize, + buf: iobufpool.Get(minBufSize), + } +} + +// Next returns buf filled with the next n bytes. buf is only valid until next call of Next. If an error occurs, buf +// will be nil. +func (r *chunkReader) Next(n int) (buf []byte, err error) { + // Reset the buffer if it is empty + if r.rp == r.wp { + if len(*r.buf) != r.minBufSize { + iobufpool.Put(r.buf) + r.buf = iobufpool.Get(r.minBufSize) + } + r.rp = 0 + r.wp = 0 + } + + // n bytes already in buf + if (r.wp - r.rp) >= n { + buf = (*r.buf)[r.rp : r.rp+n : r.rp+n] + r.rp += n + return buf, err + } + + // buf is smaller than requested number of bytes + if len(*r.buf) < n { + bigBuf := iobufpool.Get(n) + r.wp = copy((*bigBuf), (*r.buf)[r.rp:r.wp]) + r.rp = 0 + iobufpool.Put(r.buf) + r.buf = bigBuf + } + + // buf is large enough, but need to shift filled area to start to make enough contiguous space + minReadCount := n - (r.wp - r.rp) + if (len(*r.buf) - r.wp) < minReadCount { + r.wp = copy((*r.buf), (*r.buf)[r.rp:r.wp]) + r.rp = 0 + } + + // Read at least the required number of bytes from the underlying io.Reader + readBytesCount, err := io.ReadAtLeast(r.r, (*r.buf)[r.wp:], minReadCount) + r.wp += readBytesCount + // fmt.Println("read", n) + if err != nil { + return nil, err + } + + buf = (*r.buf)[r.rp : r.rp+n : r.rp+n] + r.rp += n + return buf, nil +} diff --git a/pgproto3/chunkreader_test.go b/pgproto3/chunkreader_test.go new file mode 100644 index 000000000..f509fb6eb --- /dev/null +++ b/pgproto3/chunkreader_test.go @@ -0,0 +1,75 @@ +package pgproto3 + +import ( + "bytes" + "math/rand" + "testing" +) + +func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { + server := &bytes.Buffer{} + r := newChunkReader(server, 4) + + src := []byte{1, 2, 3, 4} + server.Write(src) + + n1, err := r.Next(2) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(n1, src[0:2]) { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1) + } + + n2, err := r.Next(2) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(n2, src[2:4]) { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2) + } + + if !bytes.Equal((*r.buf)[:len(src)], src) { + t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf) + } + + _, err = r.Next(0) // Trigger the buffer reset. + if err != nil { + t.Fatal(err) + } + + if r.rp != 0 { + t.Fatalf("Expected r.rp to be %v, but it was %v", 0, r.rp) + } + if r.wp != 0 { + t.Fatalf("Expected r.wp to be %v, but it was %v", 0, r.wp) + } +} + +type randomReader struct { + rnd *rand.Rand +} + +// Read reads a random number of random bytes. +func (r *randomReader) Read(p []byte) (n int, err error) { + n = r.rnd.Intn(len(p) + 1) + return r.rnd.Read(p[:n]) +} + +func TestChunkReaderNextFuzz(t *testing.T) { + rr := &randomReader{rnd: rand.New(rand.NewSource(1))} + r := newChunkReader(rr, 8192) + + randomSizes := rand.New(rand.NewSource(0)) + + for i := 0; i < 100000; i++ { + size := randomSizes.Intn(16384) + 1 + buf, err := r.Next(size) + if err != nil { + t.Fatal(err) + } + if len(buf) != size { + t.Fatalf("Expected to get %v bytes but got %v bytes", size, len(buf)) + } + } +} diff --git a/pgproto3/close.go b/pgproto3/close.go index 5ff4c8861..0b50f27cb 100644 --- a/pgproto3/close.go +++ b/pgproto3/close.go @@ -3,8 +3,7 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/pgio" + "errors" ) type Close struct { @@ -12,8 +11,11 @@ type Close struct { Name string } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Close) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Close) Decode(src []byte) error { if len(src) < 2 { return &invalidMessageFormatErr{messageType: "Close"} @@ -32,21 +34,17 @@ func (dst *Close) Decode(src []byte) error { return nil } -func (src *Close) Encode(dst []byte) []byte { - dst = append(dst, 'C') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Close) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'C') dst = append(dst, src.ObjectType) dst = append(dst, src.Name...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } -func (src *Close) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src Close) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ObjectType string @@ -57,3 +55,27 @@ func (src *Close) MarshalJSON() ([]byte, error) { Name: src.Name, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Close) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + ObjectType string + Name string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.ObjectType) != 1 { + return errors.New("invalid length for Close.ObjectType") + } + + dst.ObjectType = byte(msg.ObjectType[0]) + dst.Name = msg.Name + return nil +} diff --git a/pgproto3/close_complete.go b/pgproto3/close_complete.go index db793c94c..833f7a12c 100644 --- a/pgproto3/close_complete.go +++ b/pgproto3/close_complete.go @@ -6,8 +6,11 @@ import ( type CloseComplete struct{} +// Backend identifies this message as sendable by the PostgreSQL backend. func (*CloseComplete) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *CloseComplete) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)} @@ -16,11 +19,13 @@ func (dst *CloseComplete) Decode(src []byte) error { return nil } -func (src *CloseComplete) Encode(dst []byte) []byte { - return append(dst, '3', 0, 0, 0, 4) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CloseComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '3', 0, 0, 0, 4), nil } -func (src *CloseComplete) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src CloseComplete) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/pgproto3/command_complete.go b/pgproto3/command_complete.go index 858485326..eba70947d 100644 --- a/pgproto3/command_complete.go +++ b/pgproto3/command_complete.go @@ -3,46 +3,64 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/pgio" ) type CommandComplete struct { - CommandTag string + CommandTag []byte } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*CommandComplete) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *CommandComplete) Decode(src []byte) error { idx := bytes.IndexByte(src, 0) + if idx == -1 { + return &invalidMessageFormatErr{messageType: "CommandComplete", details: "unterminated string"} + } if idx != len(src)-1 { - return &invalidMessageFormatErr{messageType: "CommandComplete"} + return &invalidMessageFormatErr{messageType: "CommandComplete", details: "string terminated too early"} } - dst.CommandTag = string(src[:idx]) + dst.CommandTag = src[:idx] return nil } -func (src *CommandComplete) Encode(dst []byte) []byte { - dst = append(dst, 'C') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CommandComplete) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'C') dst = append(dst, src.CommandTag...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } -func (src *CommandComplete) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src CommandComplete) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string CommandTag string }{ Type: "CommandComplete", - CommandTag: src.CommandTag, + CommandTag: string(src.CommandTag), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CommandComplete) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + CommandTag string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.CommandTag = []byte(msg.CommandTag) + return nil +} diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go index 2862a34f6..99e1afea4 100644 --- a/pgproto3/copy_both_response.go +++ b/pgproto3/copy_both_response.go @@ -4,8 +4,10 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type CopyBothResponse struct { @@ -13,8 +15,11 @@ type CopyBothResponse struct { ColumnFormatCodes []uint16 } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*CopyBothResponse) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *CopyBothResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) @@ -39,22 +44,23 @@ func (dst *CopyBothResponse) Decode(src []byte) error { return nil } -func (src *CopyBothResponse) Encode(dst []byte) []byte { - dst = append(dst, 'W') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'W') + dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } -func (src *CopyBothResponse) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyBothResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ColumnFormatCodes []uint16 @@ -63,3 +69,27 @@ func (src *CopyBothResponse) MarshalJSON() ([]byte, error) { ColumnFormatCodes: src.ColumnFormatCodes, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyBothResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/pgproto3/copy_both_response_test.go b/pgproto3/copy_both_response_test.go new file mode 100644 index 000000000..1c988f21d --- /dev/null +++ b/pgproto3/copy_both_response_test.go @@ -0,0 +1,20 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncodeDecode(t *testing.T) { + srcBytes := []byte{'W', 0x00, 0x00, 0x00, 0x0b, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01} + dstResp := pgproto3.CopyBothResponse{} + err := dstResp.Decode(srcBytes[5:]) + assert.NoError(t, err, "No errors on decode") + dstBytes := []byte{} + dstBytes, err = dstResp.Encode(dstBytes) + require.NoError(t, err) + assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match") +} diff --git a/pgproto3/copy_data.go b/pgproto3/copy_data.go index fab139e61..89ecdd4dd 100644 --- a/pgproto3/copy_data.go +++ b/pgproto3/copy_data.go @@ -3,30 +3,34 @@ package pgproto3 import ( "encoding/hex" "encoding/json" - - "github.com/jackc/pgx/pgio" ) type CopyData struct { Data []byte } -func (*CopyData) Backend() {} +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*CopyData) Backend() {} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*CopyData) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *CopyData) Decode(src []byte) error { dst.Data = src return nil } -func (src *CopyData) Encode(dst []byte) []byte { - dst = append(dst, 'd') - dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyData) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'd') dst = append(dst, src.Data...) - return dst + return finishMessage(dst, sp) } -func (src *CopyData) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyData) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string Data string @@ -35,3 +39,21 @@ func (src *CopyData) MarshalJSON() ([]byte, error) { Data: hex.EncodeToString(src.Data), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyData) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil +} diff --git a/pgproto3/copy_done.go b/pgproto3/copy_done.go new file mode 100644 index 000000000..c3421a9b5 --- /dev/null +++ b/pgproto3/copy_done.go @@ -0,0 +1,37 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type CopyDone struct{} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*CopyDone) Backend() {} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*CopyDone) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CopyDone) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "CopyDone", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyDone) Encode(dst []byte) ([]byte, error) { + return append(dst, 'c', 0, 0, 0, 4), nil +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyDone) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "CopyDone", + }) +} diff --git a/pgproto3/copy_fail.go b/pgproto3/copy_fail.go new file mode 100644 index 000000000..72a85fd09 --- /dev/null +++ b/pgproto3/copy_fail.go @@ -0,0 +1,45 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" +) + +type CopyFail struct { + Message string +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*CopyFail) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CopyFail) Decode(src []byte) error { + idx := bytes.IndexByte(src, 0) + if idx != len(src)-1 { + return &invalidMessageFormatErr{messageType: "CopyFail"} + } + + dst.Message = string(src[:idx]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyFail) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'f') + dst = append(dst, src.Message...) + dst = append(dst, 0) + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyFail) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Message string + }{ + Type: "CopyFail", + Message: src.Message, + }) +} diff --git a/pgproto3/copy_in_response.go b/pgproto3/copy_in_response.go index 54083cd63..06cf99ced 100644 --- a/pgproto3/copy_in_response.go +++ b/pgproto3/copy_in_response.go @@ -4,8 +4,10 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type CopyInResponse struct { @@ -13,8 +15,11 @@ type CopyInResponse struct { ColumnFormatCodes []uint16 } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*CopyInResponse) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *CopyInResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) @@ -39,22 +44,24 @@ func (dst *CopyInResponse) Decode(src []byte) error { return nil } -func (src *CopyInResponse) Encode(dst []byte) []byte { - dst = append(dst, 'G') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'G') + dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } -func (src *CopyInResponse) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyInResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ColumnFormatCodes []uint16 @@ -63,3 +70,27 @@ func (src *CopyInResponse) MarshalJSON() ([]byte, error) { ColumnFormatCodes: src.ColumnFormatCodes, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyInResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyInResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go index eaa33b8bd..549e916c1 100644 --- a/pgproto3/copy_out_response.go +++ b/pgproto3/copy_out_response.go @@ -4,8 +4,10 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type CopyOutResponse struct { @@ -15,6 +17,8 @@ type CopyOutResponse struct { func (*CopyOutResponse) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *CopyOutResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) @@ -39,22 +43,25 @@ func (dst *CopyOutResponse) Decode(src []byte) error { return nil } -func (src *CopyOutResponse) Encode(dst []byte) []byte { - dst = append(dst, 'H') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'H') + dst = append(dst, src.OverallFormat) + + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } -func (src *CopyOutResponse) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyOutResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ColumnFormatCodes []uint16 @@ -63,3 +70,27 @@ func (src *CopyOutResponse) MarshalJSON() ([]byte, error) { ColumnFormatCodes: src.ColumnFormatCodes, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyOutResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go index e46d3cc0a..fdfb0f7f6 100644 --- a/pgproto3/data_row.go +++ b/pgproto3/data_row.go @@ -4,16 +4,21 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "errors" + "math" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type DataRow struct { Values [][]byte } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*DataRow) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *DataRow) Decode(src []byte) error { if len(src) < 2 { return &invalidMessageFormatErr{messageType: "DataRow"} @@ -40,30 +45,32 @@ func (dst *DataRow) Decode(src []byte) error { return &invalidMessageFormatErr{messageType: "DataRow"} } - msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) + valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 // null - if msgSize == -1 { + if valueLen == -1 { dst.Values[i] = nil } else { - if len(src[rp:]) < msgSize { + if len(src[rp:]) < valueLen || valueLen < 0 { return &invalidMessageFormatErr{messageType: "DataRow"} } - dst.Values[i] = src[rp : rp+msgSize] - rp += msgSize + dst.Values[i] = src[rp : rp+valueLen : rp+valueLen] + rp += valueLen } } return nil } -func (src *DataRow) Encode(dst []byte) []byte { - dst = append(dst, 'D') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *DataRow) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'D') + if len(src.Values) > math.MaxUint16 { + return nil, errors.New("too many values") + } dst = pgio.AppendUint16(dst, uint16(len(src.Values))) for _, v := range src.Values { if v == nil { @@ -75,12 +82,11 @@ func (src *DataRow) Encode(dst []byte) []byte { dst = append(dst, v...) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } -func (src *DataRow) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src DataRow) MarshalJSON() ([]byte, error) { formattedValues := make([]map[string]string, len(src.Values)) for i, v := range src.Values { if v == nil { @@ -110,3 +116,28 @@ func (src *DataRow) MarshalJSON() ([]byte, error) { Values: formattedValues, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *DataRow) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Values []map[string]string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Values = make([][]byte, len(msg.Values)) + for n, parameter := range msg.Values { + var err error + dst.Values[n], err = getValueFromJSON(parameter) + if err != nil { + return err + } + } + return nil +} diff --git a/pgproto3/describe.go b/pgproto3/describe.go index bb7bc0563..89feff215 100644 --- a/pgproto3/describe.go +++ b/pgproto3/describe.go @@ -3,8 +3,7 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/pgio" + "errors" ) type Describe struct { @@ -12,8 +11,11 @@ type Describe struct { Name string } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Describe) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Describe) Decode(src []byte) error { if len(src) < 2 { return &invalidMessageFormatErr{messageType: "Describe"} @@ -32,21 +34,17 @@ func (dst *Describe) Decode(src []byte) error { return nil } -func (src *Describe) Encode(dst []byte) []byte { - dst = append(dst, 'D') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Describe) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'D') dst = append(dst, src.ObjectType) dst = append(dst, src.Name...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } -func (src *Describe) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src Describe) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ObjectType string @@ -57,3 +55,26 @@ func (src *Describe) MarshalJSON() ([]byte, error) { Name: src.Name, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Describe) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + ObjectType string + Name string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if len(msg.ObjectType) != 1 { + return errors.New("invalid length for Describe.ObjectType") + } + + dst.ObjectType = byte(msg.ObjectType[0]) + dst.Name = msg.Name + return nil +} diff --git a/pgproto3/doc.go b/pgproto3/doc.go new file mode 100644 index 000000000..0afd18e29 --- /dev/null +++ b/pgproto3/doc.go @@ -0,0 +1,11 @@ +// Package pgproto3 is an encoder and decoder of the PostgreSQL wire protocol version 3. +// +// The primary interfaces are Frontend and Backend. They correspond to a client and server respectively. Messages are +// sent with Send (or a specialized Send variant). Messages are automatically buffered to minimize small writes. Call +// Flush to ensure a message has actually been sent. +// +// The Trace method of Frontend and Backend can be used to examine the wire-level message traffic. It outputs in a +// similar format to the PQtrace function in libpq. +// +// See https://www.postgresql.org/docs/current/protocol-message-formats.html for meanings of the different messages. +package pgproto3 diff --git a/pgproto3/empty_query_response.go b/pgproto3/empty_query_response.go index d283b06df..cb6cca073 100644 --- a/pgproto3/empty_query_response.go +++ b/pgproto3/empty_query_response.go @@ -6,8 +6,11 @@ import ( type EmptyQueryResponse struct{} +// Backend identifies this message as sendable by the PostgreSQL backend. func (*EmptyQueryResponse) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *EmptyQueryResponse) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)} @@ -16,11 +19,13 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error { return nil } -func (src *EmptyQueryResponse) Encode(dst []byte) []byte { - return append(dst, 'I', 0, 0, 0, 4) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) { + return append(dst, 'I', 0, 0, 0, 4), nil } -func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src EmptyQueryResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/pgproto3/error_response.go b/pgproto3/error_response.go index 160234f20..6ef9bd061 100644 --- a/pgproto3/error_response.go +++ b/pgproto3/error_response.go @@ -2,34 +2,38 @@ package pgproto3 import ( "bytes" - "encoding/binary" + "encoding/json" "strconv" ) type ErrorResponse struct { - Severity string - Code string - Message string - Detail string - Hint string - Position int32 - InternalPosition int32 - InternalQuery string - Where string - SchemaName string - TableName string - ColumnName string - DataTypeName string - ConstraintName string - File string - Line int32 - Routine string + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string UnknownFields map[byte]string } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*ErrorResponse) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *ErrorResponse) Decode(src []byte) error { *dst = ErrorResponse{} @@ -53,6 +57,8 @@ func (dst *ErrorResponse) Decode(src []byte) error { switch k { case 'S': dst.Severity = v + case 'V': + dst.SeverityUnlocalized = v case 'C': dst.Code = v case 'M': @@ -103,95 +109,218 @@ func (dst *ErrorResponse) Decode(src []byte) error { return nil } -func (src *ErrorResponse) Encode(dst []byte) []byte { - return append(dst, src.marshalBinary('E')...) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'E') + dst = src.appendFields(dst) + return finishMessage(dst, sp) } -func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte(typeByte) - buf.Write(bigEndian.Uint32(0)) - +func (src *ErrorResponse) appendFields(dst []byte) []byte { if src.Severity != "" { - buf.WriteString(src.Severity) - buf.WriteByte(0) + dst = append(dst, 'S') + dst = append(dst, src.Severity...) + dst = append(dst, 0) + } + if src.SeverityUnlocalized != "" { + dst = append(dst, 'V') + dst = append(dst, src.SeverityUnlocalized...) + dst = append(dst, 0) } if src.Code != "" { - buf.WriteString(src.Code) - buf.WriteByte(0) + dst = append(dst, 'C') + dst = append(dst, src.Code...) + dst = append(dst, 0) } if src.Message != "" { - buf.WriteString(src.Message) - buf.WriteByte(0) + dst = append(dst, 'M') + dst = append(dst, src.Message...) + dst = append(dst, 0) } if src.Detail != "" { - buf.WriteString(src.Detail) - buf.WriteByte(0) + dst = append(dst, 'D') + dst = append(dst, src.Detail...) + dst = append(dst, 0) } if src.Hint != "" { - buf.WriteString(src.Hint) - buf.WriteByte(0) + dst = append(dst, 'H') + dst = append(dst, src.Hint...) + dst = append(dst, 0) } if src.Position != 0 { - buf.WriteString(strconv.Itoa(int(src.Position))) - buf.WriteByte(0) + dst = append(dst, 'P') + dst = append(dst, strconv.Itoa(int(src.Position))...) + dst = append(dst, 0) } if src.InternalPosition != 0 { - buf.WriteString(strconv.Itoa(int(src.InternalPosition))) - buf.WriteByte(0) + dst = append(dst, 'p') + dst = append(dst, strconv.Itoa(int(src.InternalPosition))...) + dst = append(dst, 0) } if src.InternalQuery != "" { - buf.WriteString(src.InternalQuery) - buf.WriteByte(0) + dst = append(dst, 'q') + dst = append(dst, src.InternalQuery...) + dst = append(dst, 0) } if src.Where != "" { - buf.WriteString(src.Where) - buf.WriteByte(0) + dst = append(dst, 'W') + dst = append(dst, src.Where...) + dst = append(dst, 0) } if src.SchemaName != "" { - buf.WriteString(src.SchemaName) - buf.WriteByte(0) + dst = append(dst, 's') + dst = append(dst, src.SchemaName...) + dst = append(dst, 0) } if src.TableName != "" { - buf.WriteString(src.TableName) - buf.WriteByte(0) + dst = append(dst, 't') + dst = append(dst, src.TableName...) + dst = append(dst, 0) } if src.ColumnName != "" { - buf.WriteString(src.ColumnName) - buf.WriteByte(0) + dst = append(dst, 'c') + dst = append(dst, src.ColumnName...) + dst = append(dst, 0) } if src.DataTypeName != "" { - buf.WriteString(src.DataTypeName) - buf.WriteByte(0) + dst = append(dst, 'd') + dst = append(dst, src.DataTypeName...) + dst = append(dst, 0) } if src.ConstraintName != "" { - buf.WriteString(src.ConstraintName) - buf.WriteByte(0) + dst = append(dst, 'n') + dst = append(dst, src.ConstraintName...) + dst = append(dst, 0) } if src.File != "" { - buf.WriteString(src.File) - buf.WriteByte(0) + dst = append(dst, 'F') + dst = append(dst, src.File...) + dst = append(dst, 0) } if src.Line != 0 { - buf.WriteString(strconv.Itoa(int(src.Line))) - buf.WriteByte(0) + dst = append(dst, 'L') + dst = append(dst, strconv.Itoa(int(src.Line))...) + dst = append(dst, 0) } if src.Routine != "" { - buf.WriteString(src.Routine) - buf.WriteByte(0) + dst = append(dst, 'R') + dst = append(dst, src.Routine...) + dst = append(dst, 0) } for k, v := range src.UnknownFields { - buf.WriteByte(k) - buf.WriteByte(0) - buf.WriteString(v) - buf.WriteByte(0) + dst = append(dst, k) + dst = append(dst, v...) + dst = append(dst, 0) } - buf.WriteByte(0) - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + dst = append(dst, 0) + + return dst +} - return buf.Bytes() +// MarshalJSON implements encoding/json.Marshaler. +func (src ErrorResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string + }{ + Type: "ErrorResponse", + Severity: src.Severity, + SeverityUnlocalized: src.SeverityUnlocalized, + Code: src.Code, + Message: src.Message, + Detail: src.Detail, + Hint: src.Hint, + Position: src.Position, + InternalPosition: src.InternalPosition, + InternalQuery: src.InternalQuery, + Where: src.Where, + SchemaName: src.SchemaName, + TableName: src.TableName, + ColumnName: src.ColumnName, + DataTypeName: src.DataTypeName, + ConstraintName: src.ConstraintName, + File: src.File, + Line: src.Line, + Routine: src.Routine, + UnknownFields: src.UnknownFields, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *ErrorResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Severity = msg.Severity + dst.SeverityUnlocalized = msg.SeverityUnlocalized + dst.Code = msg.Code + dst.Message = msg.Message + dst.Detail = msg.Detail + dst.Hint = msg.Hint + dst.Position = msg.Position + dst.InternalPosition = msg.InternalPosition + dst.InternalQuery = msg.InternalQuery + dst.Where = msg.Where + dst.SchemaName = msg.SchemaName + dst.TableName = msg.TableName + dst.ColumnName = msg.ColumnName + dst.DataTypeName = msg.DataTypeName + dst.ConstraintName = msg.ConstraintName + dst.File = msg.File + dst.Line = msg.Line + dst.Routine = msg.Routine + + dst.UnknownFields = msg.UnknownFields + + return nil } diff --git a/pgproto3/example/pgfortune/README.md b/pgproto3/example/pgfortune/README.md new file mode 100644 index 000000000..c181c38a0 --- /dev/null +++ b/pgproto3/example/pgfortune/README.md @@ -0,0 +1,53 @@ +# pgfortune + +pgfortune is a mock PostgreSQL server that responds to every query with a fortune. + +## Installation + +Install `fortune` and `cowsay`. They should be available in any Unix package manager (apt, yum, brew, etc.) + +``` +go get -u github.com/jackc/pgproto3/example/pgfortune +``` + +## Usage + +``` +$ pgfortune +``` + +By default pgfortune listens on 127.0.0.1:15432 and responds to queries with `fortune | cowsay -f elephant`. These are +configurable with the `listen` and `response-command` arguments respectively. + +While `pgfortune` is running connect to it with `psql`. + +``` +$ psql -h 127.0.0.1 -p 15432 +Timing is on. +Null display is "∅". +Line style is unicode. +psql (11.5, server 0.0.0) +Type "help" for help. + +jack@127.0.0.1:15432 jack=# select foo; + fortune +───────────────────────────────────────────── + _________________________________________ ↵ + / Ships are safe in harbor, but they were \↵ + \ never meant to stay there. /↵ + ----------------------------------------- ↵ + \ /\ ___ /\ ↵ + \ // \/ \/ \\ ↵ + (( O O )) ↵ + \\ / \ // ↵ + \/ | | \/ ↵ + | | | | ↵ + | | | | ↵ + | o | ↵ + | | | | ↵ + |m| |m| ↵ + +(1 row) + +Time: 28.161 ms +``` diff --git a/pgproto3/example/pgfortune/main.go b/pgproto3/example/pgfortune/main.go new file mode 100644 index 000000000..0c25510b9 --- /dev/null +++ b/pgproto3/example/pgfortune/main.go @@ -0,0 +1,51 @@ +package main + +import ( + "flag" + "fmt" + "log" + "net" + "os" + "os/exec" +) + +var options struct { + listenAddress string + responseCommand string +} + +func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "usage: %s [options]\n", os.Args[0]) + flag.PrintDefaults() + } + + flag.StringVar(&options.listenAddress, "listen", "127.0.0.1:15432", "Listen address") + flag.StringVar(&options.responseCommand, "response-command", "fortune | cowsay -f elephant", "Command to execute to generate query response") + flag.Parse() + + ln, err := net.Listen("tcp", options.listenAddress) + if err != nil { + log.Fatal(err) + } + log.Println("Listening on", ln.Addr()) + + for { + conn, err := ln.Accept() + if err != nil { + log.Fatal(err) + } + log.Println("Accepted connection from", conn.RemoteAddr()) + + b := NewPgFortuneBackend(conn, func() ([]byte, error) { + return exec.Command("sh", "-c", options.responseCommand).CombinedOutput() + }) + go func() { + err := b.Run() + if err != nil { + log.Println(err) + } + log.Println("Closed connection from", conn.RemoteAddr()) + }() + } +} diff --git a/pgproto3/example/pgfortune/server.go b/pgproto3/example/pgfortune/server.go new file mode 100644 index 000000000..06a45dda0 --- /dev/null +++ b/pgproto3/example/pgfortune/server.go @@ -0,0 +1,111 @@ +package main + +import ( + "fmt" + "net" + + "github.com/jackc/pgx/v5/pgproto3" +) + +type PgFortuneBackend struct { + backend *pgproto3.Backend + conn net.Conn + responder func() ([]byte, error) +} + +func NewPgFortuneBackend(conn net.Conn, responder func() ([]byte, error)) *PgFortuneBackend { + backend := pgproto3.NewBackend(conn, conn) + + connHandler := &PgFortuneBackend{ + backend: backend, + conn: conn, + responder: responder, + } + + return connHandler +} + +func (p *PgFortuneBackend) Run() error { + defer p.Close() + + err := p.handleStartup() + if err != nil { + return err + } + + for { + msg, err := p.backend.Receive() + if err != nil { + return fmt.Errorf("error receiving message: %w", err) + } + + switch msg.(type) { + case *pgproto3.Query: + response, err := p.responder() + if err != nil { + return fmt.Errorf("error generating query response: %w", err) + } + + buf := mustEncode((&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + { + Name: []byte("fortune"), + TableOID: 0, + TableAttributeNumber: 0, + DataTypeOID: 25, + DataTypeSize: -1, + TypeModifier: -1, + Format: 0, + }, + }}).Encode(nil)) + buf = mustEncode((&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf)) + buf = mustEncode((&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)) + buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)) + _, err = p.conn.Write(buf) + if err != nil { + return fmt.Errorf("error writing query response: %w", err) + } + case *pgproto3.Terminate: + return nil + default: + return fmt.Errorf("received message other than Query from client: %#v", msg) + } + } +} + +func (p *PgFortuneBackend) handleStartup() error { + startupMessage, err := p.backend.ReceiveStartupMessage() + if err != nil { + return fmt.Errorf("error receiving startup message: %w", err) + } + + switch startupMessage.(type) { + case *pgproto3.StartupMessage: + buf := mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)) + buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)) + _, err = p.conn.Write(buf) + if err != nil { + return fmt.Errorf("error sending ready for query: %w", err) + } + case *pgproto3.SSLRequest: + _, err = p.conn.Write([]byte("N")) + if err != nil { + return fmt.Errorf("error sending deny SSL request: %w", err) + } + return p.handleStartup() + default: + return fmt.Errorf("unknown startup message: %#v", startupMessage) + } + + return nil +} + +func (p *PgFortuneBackend) Close() error { + return p.conn.Close() +} + +func mustEncode(buf []byte, err error) []byte { + if err != nil { + panic(err) + } + return buf +} diff --git a/pgproto3/execute.go b/pgproto3/execute.go index 76da9943d..31bc714d1 100644 --- a/pgproto3/execute.go +++ b/pgproto3/execute.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type Execute struct { @@ -13,8 +13,11 @@ type Execute struct { MaxRows uint32 } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Execute) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Execute) Decode(src []byte) error { buf := bytes.NewBuffer(src) @@ -32,22 +35,17 @@ func (dst *Execute) Decode(src []byte) error { return nil } -func (src *Execute) Encode(dst []byte) []byte { - dst = append(dst, 'E') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Execute) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'E') dst = append(dst, src.Portal...) dst = append(dst, 0) - dst = pgio.AppendUint32(dst, src.MaxRows) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } -func (src *Execute) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src Execute) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string Portal string diff --git a/pgproto3/flush.go b/pgproto3/flush.go index 7fd5e987c..e5dc1fbbd 100644 --- a/pgproto3/flush.go +++ b/pgproto3/flush.go @@ -6,8 +6,11 @@ import ( type Flush struct{} +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Flush) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Flush) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)} @@ -16,11 +19,13 @@ func (dst *Flush) Decode(src []byte) error { return nil } -func (src *Flush) Encode(dst []byte) []byte { - return append(dst, 'H', 0, 0, 0, 4) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Flush) Encode(dst []byte) ([]byte, error) { + return append(dst, 'H', 0, 0, 0, 4), nil } -func (src *Flush) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src Flush) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index d803d362d..056e547cd 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -1,122 +1,468 @@ package pgproto3 import ( + "bytes" "encoding/binary" + "errors" + "fmt" "io" - - "github.com/jackc/pgx/chunkreader" - "github.com/pkg/errors" ) +// Frontend acts as a client for the PostgreSQL wire protocol version 3. type Frontend struct { - cr *chunkreader.ChunkReader + cr *chunkReader w io.Writer + // tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced + // before it is actually transmitted (i.e. before Flush). It is safe to change this variable when the Frontend is + // idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq. + tracer *tracer + + wbuf []byte + encodeError error + // Backend message flyweights - authentication Authentication - backendKeyData BackendKeyData - bindComplete BindComplete - closeComplete CloseComplete - commandComplete CommandComplete - copyBothResponse CopyBothResponse - copyData CopyData - copyInResponse CopyInResponse - copyOutResponse CopyOutResponse - dataRow DataRow - emptyQueryResponse EmptyQueryResponse - errorResponse ErrorResponse - functionCallResponse FunctionCallResponse - noData NoData - noticeResponse NoticeResponse - notificationResponse NotificationResponse - parameterDescription ParameterDescription - parameterStatus ParameterStatus - parseComplete ParseComplete - readyForQuery ReadyForQuery - rowDescription RowDescription + authenticationOk AuthenticationOk + authenticationCleartextPassword AuthenticationCleartextPassword + authenticationMD5Password AuthenticationMD5Password + authenticationGSS AuthenticationGSS + authenticationGSSContinue AuthenticationGSSContinue + authenticationSASL AuthenticationSASL + authenticationSASLContinue AuthenticationSASLContinue + authenticationSASLFinal AuthenticationSASLFinal + backendKeyData BackendKeyData + bindComplete BindComplete + closeComplete CloseComplete + commandComplete CommandComplete + copyBothResponse CopyBothResponse + copyData CopyData + copyInResponse CopyInResponse + copyOutResponse CopyOutResponse + copyDone CopyDone + dataRow DataRow + emptyQueryResponse EmptyQueryResponse + errorResponse ErrorResponse + functionCallResponse FunctionCallResponse + noData NoData + noticeResponse NoticeResponse + notificationResponse NotificationResponse + parameterDescription ParameterDescription + parameterStatus ParameterStatus + parseComplete ParseComplete + readyForQuery ReadyForQuery + rowDescription RowDescription + portalSuspended PortalSuspended bodyLen int + maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error. msgType byte partialMsg bool + authType uint32 +} + +// NewFrontend creates a new Frontend. +func NewFrontend(r io.Reader, w io.Writer) *Frontend { + cr := newChunkReader(r, 0) + return &Frontend{cr: cr, w: w} +} + +// Send sends a message to the backend (i.e. the server). The message is buffered until Flush is called. Any error +// encountered will be returned from Flush. +// +// Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods +// such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an +// extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden +// behind an interface. +func (f *Frontend) Send(msg FrontendMessage) { + if f.encodeError != nil { + return + } + + prevLen := len(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + + if f.tracer != nil { + f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// Flush writes any pending messages to the backend (i.e. the server). +func (f *Frontend) Flush() error { + if err := f.encodeError; err != nil { + f.encodeError = nil + f.wbuf = f.wbuf[:0] + return &writeError{err: err, safeToRetry: true} + } + + if len(f.wbuf) == 0 { + return nil + } + + n, err := f.w.Write(f.wbuf) + + const maxLen = 1024 + if len(f.wbuf) > maxLen { + f.wbuf = make([]byte, 0, maxLen) + } else { + f.wbuf = f.wbuf[:0] + } + + if err != nil { + return &writeError{err: err, safeToRetry: n == 0} + } + + return nil +} + +// Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function +// PQtrace. +func (f *Frontend) Trace(w io.Writer, options TracerOptions) { + f.tracer = &tracer{ + w: w, + buf: &bytes.Buffer{}, + TracerOptions: options, + } +} + +// Untrace stops tracing. +func (f *Frontend) Untrace() { + f.tracer = nil +} + +// SendBind sends a Bind message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. +func (f *Frontend) SendBind(msg *Bind) { + if f.encodeError != nil { + return + } + + prevLen := len(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + + if f.tracer != nil { + f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendParse sends a Parse message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. +func (f *Frontend) SendParse(msg *Parse) { + if f.encodeError != nil { + return + } + + prevLen := len(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + + if f.tracer != nil { + f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendClose sends a Close message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. +func (f *Frontend) SendClose(msg *Close) { + if f.encodeError != nil { + return + } + + prevLen := len(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + + if f.tracer != nil { + f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendDescribe sends a Describe message to the backend (i.e. the server). The message is buffered until Flush is +// called. Any error encountered will be returned from Flush. +func (f *Frontend) SendDescribe(msg *Describe) { + if f.encodeError != nil { + return + } + + prevLen := len(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + + if f.tracer != nil { + f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg) + } } -func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { - cr := chunkreader.NewChunkReader(r) - return &Frontend{cr: cr, w: w}, nil +// SendExecute sends an Execute message to the backend (i.e. the server). The message is buffered until Flush is called. +// Any error encountered will be returned from Flush. +func (f *Frontend) SendExecute(msg *Execute) { + if f.encodeError != nil { + return + } + + prevLen := len(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + + if f.tracer != nil { + f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendSync sends a Sync message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. +func (f *Frontend) SendSync(msg *Sync) { + if f.encodeError != nil { + return + } + + prevLen := len(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + + if f.tracer != nil { + f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendQuery sends a Query message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. +func (f *Frontend) SendQuery(msg *Query) { + if f.encodeError != nil { + return + } + + prevLen := len(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + + if f.tracer != nil { + f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendUnbufferedEncodedCopyData immediately sends an encoded CopyData message to the backend (i.e. the server). This method +// is more efficient than sending a CopyData message with Send as the message data is not copied to the internal buffer +// before being written out. The internal buffer is flushed before the message is sent. +func (f *Frontend) SendUnbufferedEncodedCopyData(msg []byte) error { + err := f.Flush() + if err != nil { + return err + } + + n, err := f.w.Write(msg) + if err != nil { + return &writeError{err: err, safeToRetry: n == 0} + } + + if f.tracer != nil { + f.tracer.traceCopyData('F', int32(len(msg)-1), &CopyData{}) + } + + return nil } -func (b *Frontend) Send(msg FrontendMessage) error { - _, err := b.w.Write(msg.Encode(nil)) +func translateEOFtoErrUnexpectedEOF(err error) error { + if err == io.EOF { + return io.ErrUnexpectedEOF + } return err } -func (b *Frontend) Receive() (BackendMessage, error) { - if !b.partialMsg { - header, err := b.cr.Next(5) +// Receive receives a message from the backend. The returned message is only valid until the next call to Receive. +func (f *Frontend) Receive() (BackendMessage, error) { + if !f.partialMsg { + header, err := f.cr.Next(5) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) + } + + f.msgType = header[0] + + msgLength := int(binary.BigEndian.Uint32(header[1:])) + if msgLength < 4 { + return nil, fmt.Errorf("invalid message length: %d", msgLength) + } + + f.bodyLen = msgLength - 4 + if f.maxBodyLen > 0 && f.bodyLen > f.maxBodyLen { + return nil, &ExceededMaxBodyLenErr{f.maxBodyLen, f.bodyLen} } + f.partialMsg = true + } - b.msgType = header[0] - b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 - b.partialMsg = true + msgBody, err := f.cr.Next(f.bodyLen) + if err != nil { + return nil, translateEOFtoErrUnexpectedEOF(err) } + f.partialMsg = false + var msg BackendMessage - switch b.msgType { + switch f.msgType { case '1': - msg = &b.parseComplete + msg = &f.parseComplete case '2': - msg = &b.bindComplete + msg = &f.bindComplete case '3': - msg = &b.closeComplete + msg = &f.closeComplete case 'A': - msg = &b.notificationResponse + msg = &f.notificationResponse + case 'c': + msg = &f.copyDone case 'C': - msg = &b.commandComplete + msg = &f.commandComplete case 'd': - msg = &b.copyData + msg = &f.copyData case 'D': - msg = &b.dataRow + msg = &f.dataRow case 'E': - msg = &b.errorResponse + msg = &f.errorResponse case 'G': - msg = &b.copyInResponse + msg = &f.copyInResponse case 'H': - msg = &b.copyOutResponse + msg = &f.copyOutResponse case 'I': - msg = &b.emptyQueryResponse + msg = &f.emptyQueryResponse case 'K': - msg = &b.backendKeyData + msg = &f.backendKeyData case 'n': - msg = &b.noData + msg = &f.noData case 'N': - msg = &b.noticeResponse + msg = &f.noticeResponse case 'R': - msg = &b.authentication + var err error + msg, err = f.findAuthenticationMessageType(msgBody) + if err != nil { + return nil, err + } + case 's': + msg = &f.portalSuspended case 'S': - msg = &b.parameterStatus + msg = &f.parameterStatus case 't': - msg = &b.parameterDescription + msg = &f.parameterDescription case 'T': - msg = &b.rowDescription + msg = &f.rowDescription case 'V': - msg = &b.functionCallResponse + msg = &f.functionCallResponse case 'W': - msg = &b.copyBothResponse + msg = &f.copyBothResponse case 'Z': - msg = &b.readyForQuery + msg = &f.readyForQuery default: - return nil, errors.Errorf("unknown message type: %c", b.msgType) + return nil, fmt.Errorf("unknown message type: %c", f.msgType) } - msgBody, err := b.cr.Next(b.bodyLen) + err = msg.Decode(msgBody) if err != nil { return nil, err } - b.partialMsg = false + if f.tracer != nil { + f.tracer.traceMessage('B', int32(5+len(msgBody)), msg) + } - err = msg.Decode(msgBody) - return msg, err + return msg, nil +} + +// Authentication message type constants. +// See src/include/libpq/pqcomm.h for all +// constants. +const ( + AuthTypeOk = 0 + AuthTypeCleartextPassword = 3 + AuthTypeMD5Password = 5 + AuthTypeSCMCreds = 6 + AuthTypeGSS = 7 + AuthTypeGSSCont = 8 + AuthTypeSSPI = 9 + AuthTypeSASL = 10 + AuthTypeSASLContinue = 11 + AuthTypeSASLFinal = 12 +) + +func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) { + if len(src) < 4 { + return nil, errors.New("authentication message too short") + } + f.authType = binary.BigEndian.Uint32(src[:4]) + + switch f.authType { + case AuthTypeOk: + return &f.authenticationOk, nil + case AuthTypeCleartextPassword: + return &f.authenticationCleartextPassword, nil + case AuthTypeMD5Password: + return &f.authenticationMD5Password, nil + case AuthTypeSCMCreds: + return nil, errors.New("AuthTypeSCMCreds is unimplemented") + case AuthTypeGSS: + return &f.authenticationGSS, nil + case AuthTypeGSSCont: + return &f.authenticationGSSContinue, nil + case AuthTypeSSPI: + return nil, errors.New("AuthTypeSSPI is unimplemented") + case AuthTypeSASL: + return &f.authenticationSASL, nil + case AuthTypeSASLContinue: + return &f.authenticationSASLContinue, nil + case AuthTypeSASLFinal: + return &f.authenticationSASLFinal, nil + default: + return nil, fmt.Errorf("unknown authentication type: %d", f.authType) + } +} + +// GetAuthType returns the authType used in the current state of the frontend. +// See SetAuthType for more information. +func (f *Frontend) GetAuthType() uint32 { + return f.authType +} + +func (f *Frontend) ReadBufferLen() int { + return f.cr.wp - f.cr.rp +} + +// SetMaxBodyLen sets the maximum length of a message body in octets. +// If a message body exceeds this length, Receive will return an error. +// This is useful for protecting against a corrupted server that sends +// messages with incorrect length, which can cause memory exhaustion. +// The default value is 0. +// If maxBodyLen is 0, then no maximum is enforced. +func (f *Frontend) SetMaxBodyLen(maxBodyLen int) { + f.maxBodyLen = maxBodyLen } diff --git a/pgproto3/frontend_test.go b/pgproto3/frontend_test.go index 7d6652c10..c2872661d 100644 --- a/pgproto3/frontend_test.go +++ b/pgproto3/frontend_test.go @@ -1,11 +1,12 @@ package pgproto3_test import ( + "io" "testing" - "github.com/pkg/errors" - - "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type interruptReader struct { @@ -14,7 +15,7 @@ type interruptReader struct { func (ir *interruptReader) Read(p []byte) (n int, err error) { if len(ir.chunks) == 0 { - return 0, errors.New("no data") + return 0, io.EOF } n = copy(p, ir.chunks[0]) @@ -37,10 +38,7 @@ func TestFrontendReceiveInterrupted(t *testing.T) { server := &interruptReader{} server.push([]byte{'Z', 0, 0, 0, 5}) - frontend, err := pgproto3.NewFrontend(server, nil) - if err != nil { - t.Fatal(err) - } + frontend := pgproto3.NewFrontend(server, nil) msg, err := frontend.Receive() if err == nil { @@ -60,3 +58,78 @@ func TestFrontendReceiveInterrupted(t *testing.T) { t.Fatalf("unexpected msg: %v", msg) } } + +func TestFrontendReceiveUnexpectedEOF(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Z', 0, 0, 0, 5}) + + frontend := pgproto3.NewFrontend(server, nil) + + msg, err := frontend.Receive() + if err == nil { + t.Fatal("expected err") + } + if msg != nil { + t.Fatalf("did not expect msg, but %v", msg) + } + + msg, err = frontend.Receive() + assert.Nil(t, msg) + assert.Equal(t, io.ErrUnexpectedEOF, err) +} + +func TestErrorResponse(t *testing.T) { + t.Parallel() + + want := &pgproto3.ErrorResponse{ + Severity: "ERROR", + SeverityUnlocalized: "ERROR", + Message: `column "foo" does not exist`, + File: "parse_relation.c", + Code: "42703", + Position: 8, + Line: 3513, + Routine: "errorMissingColumn", + } + + raw := []byte{ + 'E', 0, 0, 0, 'f', + 'S', 'E', 'R', 'R', 'O', 'R', 0, + 'V', 'E', 'R', 'R', 'O', 'R', 0, + 'C', '4', '2', '7', '0', '3', 0, + 'M', 'c', 'o', 'l', 'u', 'm', 'n', 32, '"', 'f', 'o', 'o', '"', 32, 'd', 'o', 'e', 's', 32, 'n', 'o', 't', 32, 'e', 'x', 'i', 's', 't', 0, + 'P', '8', 0, + 'F', 'p', 'a', 'r', 's', 'e', '_', 'r', 'e', 'l', 'a', 't', 'i', 'o', 'n', '.', 'c', 0, + 'L', '3', '5', '1', '3', 0, + 'R', 'e', 'r', 'r', 'o', 'r', 'M', 'i', 's', 's', 'i', 'n', 'g', 'C', 'o', 'l', 'u', 'm', 'n', 0, 0, + } + + server := &interruptReader{} + server.push(raw) + + frontend := pgproto3.NewFrontend(server, nil) + + got, err := frontend.Receive() + require.NoError(t, err) + assert.Equal(t, want, got) +} + +func TestFrontendReceiveExceededMaxBodyLen(t *testing.T) { + t.Parallel() + + client := &interruptReader{} + client.push([]byte{'D', 0, 0, 10, 10}) + + frontend := pgproto3.NewFrontend(client, nil) + + // Set max body len to 5 + frontend.SetMaxBodyLen(5) + + // Receive regular msg + msg, err := frontend.Receive() + assert.Nil(t, msg) + var invalidBodyLenErr *pgproto3.ExceededMaxBodyLenErr + assert.ErrorAs(t, err, &invalidBodyLenErr) +} diff --git a/pgproto3/function_call.go b/pgproto3/function_call.go new file mode 100644 index 000000000..7d83579ff --- /dev/null +++ b/pgproto3/function_call.go @@ -0,0 +1,102 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + "math" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type FunctionCall struct { + Function uint32 + ArgFormatCodes []uint16 + Arguments [][]byte + ResultFormatCode uint16 +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*FunctionCall) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *FunctionCall) Decode(src []byte) error { + *dst = FunctionCall{} + rp := 0 + // Specifies the object ID of the function to call. + dst.Function = binary.BigEndian.Uint32(src[rp:]) + rp += 4 + // The number of argument format codes that follow (denoted C below). + // This can be zero to indicate that there are no arguments or that the arguments all use the default format (text); + // or one, in which case the specified format code is applied to all arguments; + // or it can equal the actual number of arguments. + nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + argumentCodes := make([]uint16, nArgumentCodes) + for i := 0; i < nArgumentCodes; i++ { + // The argument format codes. Each must presently be zero (text) or one (binary). + ac := binary.BigEndian.Uint16(src[rp:]) + if ac != 0 && ac != 1 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + argumentCodes[i] = ac + rp += 2 + } + dst.ArgFormatCodes = argumentCodes + + // Specifies the number of arguments being supplied to the function. + nArguments := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + arguments := make([][]byte, nArguments) + for i := 0; i < nArguments; i++ { + // The length of the argument value, in bytes (this count does not include itself). Can be zero. + // As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case. + argumentLength := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + if argumentLength == -1 { + arguments[i] = nil + } else { + // The value of the argument, in the format indicated by the associated format code. n is the above length. + argumentValue := src[rp : rp+argumentLength] + rp += argumentLength + arguments[i] = argumentValue + } + } + dst.Arguments = arguments + // The format code for the function result. Must presently be zero (text) or one (binary). + resultFormatCode := binary.BigEndian.Uint16(src[rp:]) + if resultFormatCode != 0 && resultFormatCode != 1 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + dst.ResultFormatCode = resultFormatCode + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *FunctionCall) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'F') + dst = pgio.AppendUint32(dst, src.Function) + + if len(src.ArgFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many arg format codes") + } + dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes))) + for _, argFormatCode := range src.ArgFormatCodes { + dst = pgio.AppendUint16(dst, argFormatCode) + } + + if len(src.Arguments) > math.MaxUint16 { + return nil, errors.New("too many arguments") + } + dst = pgio.AppendUint16(dst, uint16(len(src.Arguments))) + for _, argument := range src.Arguments { + if argument == nil { + dst = pgio.AppendInt32(dst, -1) + } else { + dst = pgio.AppendInt32(dst, int32(len(argument))) + dst = append(dst, argument...) + } + } + dst = pgio.AppendUint16(dst, src.ResultFormatCode) + return finishMessage(dst, sp) +} diff --git a/pgproto3/function_call_response.go b/pgproto3/function_call_response.go index bb325b698..1f2734952 100644 --- a/pgproto3/function_call_response.go +++ b/pgproto3/function_call_response.go @@ -5,15 +5,18 @@ import ( "encoding/hex" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type FunctionCallResponse struct { Result []byte } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*FunctionCallResponse) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *FunctionCallResponse) Decode(src []byte) error { if len(src) < 4 { return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} @@ -35,10 +38,9 @@ func (dst *FunctionCallResponse) Decode(src []byte) error { return nil } -func (src *FunctionCallResponse) Encode(dst []byte) []byte { - dst = append(dst, 'V') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *FunctionCallResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'V') if src.Result == nil { dst = pgio.AppendInt32(dst, -1) @@ -47,12 +49,11 @@ func (src *FunctionCallResponse) Encode(dst []byte) []byte { dst = append(dst, src.Result...) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } -func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src FunctionCallResponse) MarshalJSON() ([]byte, error) { var formattedValue map[string]string var hasNonPrintable bool for _, b := range src.Result { @@ -76,3 +77,21 @@ func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) { Result: formattedValue, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *FunctionCallResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Result map[string]string + } + err := json.Unmarshal(data, &msg) + if err != nil { + return err + } + dst.Result, err = getValueFromJSON(msg.Result) + return err +} diff --git a/pgproto3/function_call_test.go b/pgproto3/function_call_test.go new file mode 100644 index 000000000..2a70fd308 --- /dev/null +++ b/pgproto3/function_call_test.go @@ -0,0 +1,65 @@ +package pgproto3 + +import ( + "encoding/binary" + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFunctionCall_EncodeDecode(t *testing.T) { + type fields struct { + Function uint32 + ArgFormatCodes []uint16 + Arguments [][]byte + ResultFormatCode uint16 + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"valid", fields{uint32(123), []uint16{0, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(1)}, false}, + {"invalid format code", fields{uint32(123), []uint16{2, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, true}, + {"invalid result format code", fields{uint32(123), []uint16{1, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(2)}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + src := &FunctionCall{ + Function: tt.fields.Function, + ArgFormatCodes: tt.fields.ArgFormatCodes, + Arguments: tt.fields.Arguments, + ResultFormatCode: tt.fields.ResultFormatCode, + } + encoded, err := src.Encode([]byte{}) + require.NoError(t, err) + dst := &FunctionCall{} + // Check the header + msgTypeCode := encoded[0] + if msgTypeCode != 'F' { + t.Errorf("msgTypeCode %v should be 'F'", msgTypeCode) + return + } + // Check length, does not include type code character + l := binary.BigEndian.Uint32(encoded[1:5]) + if int(l) != (len(encoded) - 1) { + t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded)) + } + // Check decoding works as expected + err = dst.Decode(encoded[5:]) + if err != nil { + if !tt.wantErr { + t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr) + } + return + } + + if !reflect.DeepEqual(src, dst) { + t.Error("difference after encode / decode cycle") + t.Errorf("src = %v", src) + t.Errorf("dst = %v", dst) + } + }) + } +} diff --git a/pgproto3/fuzz_test.go b/pgproto3/fuzz_test.go new file mode 100644 index 000000000..332596aba --- /dev/null +++ b/pgproto3/fuzz_test.go @@ -0,0 +1,57 @@ +package pgproto3_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +func FuzzFrontend(f *testing.F) { + testcases := []struct { + msgType byte + msgLen uint32 + msgBody []byte + }{ + { + msgType: 'Z', + msgLen: 2, + msgBody: []byte{'I'}, + }, + { + msgType: 'Z', + msgLen: 5, + msgBody: []byte{'I'}, + }, + } + for _, tc := range testcases { + f.Add(tc.msgType, tc.msgLen, tc.msgBody) + } + f.Fuzz(func(t *testing.T, msgType byte, msgLen uint32, msgBody []byte) { + // Prune any msgLen > len(msgBody) because they would hang the test waiting for more input. + if int(msgLen) > len(msgBody)+4 { + return + } + + // Prune any messages that are too long. + if msgLen > 128 || len(msgBody) > 128 { + return + } + + r := &bytes.Buffer{} + w := &bytes.Buffer{} + fe := pgproto3.NewFrontend(r, w) + + var encodedMsg []byte + encodedMsg = append(encodedMsg, msgType) + encodedMsg = pgio.AppendUint32(encodedMsg, msgLen) + encodedMsg = append(encodedMsg, msgBody...) + _, err := r.Write(encodedMsg) + require.NoError(t, err) + + // Not checking anything other than no panic. + fe.Receive() + }) +} diff --git a/pgproto3/gss_enc_request.go b/pgproto3/gss_enc_request.go new file mode 100644 index 000000000..122d1341c --- /dev/null +++ b/pgproto3/gss_enc_request.go @@ -0,0 +1,48 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +const gssEncReqNumber = 80877104 + +type GSSEncRequest struct{} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*GSSEncRequest) Frontend() {} + +func (dst *GSSEncRequest) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("gss encoding request too short") + } + + requestCode := binary.BigEndian.Uint32(src) + + if requestCode != gssEncReqNumber { + return errors.New("bad gss encoding request code") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 4 byte message length. +func (src *GSSEncRequest) Encode(dst []byte) ([]byte, error) { + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendInt32(dst, gssEncReqNumber) + return dst, nil +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src GSSEncRequest) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProtocolVersion uint32 + Parameters map[string]string + }{ + Type: "GSSEncRequest", + }) +} diff --git a/pgproto3/gss_response.go b/pgproto3/gss_response.go new file mode 100644 index 000000000..10d937759 --- /dev/null +++ b/pgproto3/gss_response.go @@ -0,0 +1,46 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type GSSResponse struct { + Data []byte +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (g *GSSResponse) Frontend() {} + +func (g *GSSResponse) Decode(data []byte) error { + g.Data = data + return nil +} + +func (g *GSSResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') + dst = append(dst, g.Data...) + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (g *GSSResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "GSSResponse", + Data: g.Data, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (g *GSSResponse) UnmarshalJSON(data []byte) error { + var msg struct { + Data []byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + g.Data = msg.Data + return nil +} diff --git a/pgproto3/json_test.go b/pgproto3/json_test.go new file mode 100644 index 000000000..677221249 --- /dev/null +++ b/pgproto3/json_test.go @@ -0,0 +1,611 @@ +package pgproto3 + +import ( + "encoding/hex" + "encoding/json" + "reflect" + "testing" +) + +func TestJSONUnmarshalAuthenticationMD5Password(t *testing.T) { + data := []byte(`{"Type":"AuthenticationMD5Password", "Salt":[97,98,99,100]}`) + want := AuthenticationMD5Password{ + Salt: [4]byte{'a', 'b', 'c', 'd'}, + } + + var got AuthenticationMD5Password + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationMD5Password struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationSASL(t *testing.T) { + data := []byte(`{"Type":"AuthenticationSASL","AuthMechanisms":["SCRAM-SHA-256"]}`) + want := AuthenticationSASL{ + []string{"SCRAM-SHA-256"}, + } + + var got AuthenticationSASL + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationSASL struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationGSS(t *testing.T) { + data := []byte(`{"Type":"AuthenticationGSS"}`) + want := AuthenticationGSS{} + + var got AuthenticationGSS + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationGSS struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationGSSContinue(t *testing.T) { + data := []byte(`{"Type":"AuthenticationGSSContinue","Data":[1,2,3,4]}`) + want := AuthenticationGSSContinue{Data: []byte{1, 2, 3, 4}} + + var got AuthenticationGSSContinue + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationGSSContinue struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) { + data := []byte(`{"Type":"AuthenticationSASLContinue", "Data":"1"}`) + want := AuthenticationSASLContinue{ + Data: []byte{'1'}, + } + + var got AuthenticationSASLContinue + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationSASLContinue struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationSASLFinal(t *testing.T) { + data := []byte(`{"Type":"AuthenticationSASLFinal", "Data":"1"}`) + want := AuthenticationSASLFinal{ + Data: []byte{'1'}, + } + + var got AuthenticationSASLFinal + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationSASLFinal struct doesn't match expected value") + } +} + +func TestJSONUnmarshalBackendKeyData(t *testing.T) { + data := []byte(`{"Type":"BackendKeyData","ProcessID":8864,"SecretKey":3641487067}`) + want := BackendKeyData{ + ProcessID: 8864, + SecretKey: 3641487067, + } + + var got BackendKeyData + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled BackendKeyData struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCommandComplete(t *testing.T) { + data := []byte(`{"Type":"CommandComplete","CommandTag":"SELECT 1"}`) + want := CommandComplete{ + CommandTag: []byte("SELECT 1"), + } + + var got CommandComplete + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CommandComplete struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyBothResponse(t *testing.T) { + data := []byte(`{"Type":"CopyBothResponse", "OverallFormat": "W"}`) + want := CopyBothResponse{ + OverallFormat: 'W', + } + + var got CopyBothResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyBothResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyData(t *testing.T) { + data := []byte(`{"Type":"CopyData"}`) + want := CopyData{ + Data: []byte{}, + } + + var got CopyData + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyData struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyInResponse(t *testing.T) { + data := []byte(`{"Type":"CopyBothResponse", "OverallFormat": "W"}`) + want := CopyBothResponse{ + OverallFormat: 'W', + } + + var got CopyBothResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyBothResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyOutResponse(t *testing.T) { + data := []byte(`{"Type":"CopyOutResponse", "OverallFormat": "W"}`) + want := CopyOutResponse{ + OverallFormat: 'W', + } + + var got CopyOutResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyOutResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalDataRow(t *testing.T) { + data := []byte(`{"Type":"DataRow","Values":[{"text":"abc"},{"text":"this is a test"},{"binary":"000263d3114d2e34"}]}`) + want := DataRow{ + Values: [][]byte{ + []byte("abc"), + []byte("this is a test"), + {0, 2, 99, 211, 17, 77, 46, 52}, + }, + } + + var got DataRow + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled DataRow struct doesn't match expected value") + } +} + +func TestJSONUnmarshalErrorResponse(t *testing.T) { + data := []byte(`{"Type":"ErrorResponse", "UnknownFields": {"97": "foo"}}`) + want := ErrorResponse{ + UnknownFields: map[byte]string{ + 'a': "foo", + }, + } + + var got ErrorResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ErrorResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalFunctionCallResponse(t *testing.T) { + data := []byte(`{"Type":"FunctionCallResponse"}`) + want := FunctionCallResponse{} + + var got FunctionCallResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled FunctionCallResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalNoticeResponse(t *testing.T) { + data := []byte(`{"Type":"NoticeResponse", "UnknownFields": {"97": "foo"}}`) + want := NoticeResponse{ + UnknownFields: map[byte]string{ + 'a': "foo", + }, + } + + var got NoticeResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled NoticeResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalNotificationResponse(t *testing.T) { + data := []byte(`{"Type":"NotificationResponse"}`) + want := NotificationResponse{} + + var got NotificationResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled NotificationResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalParameterDescription(t *testing.T) { + data := []byte(`{"Type":"ParameterDescription", "ParameterOIDs": [25]}`) + want := ParameterDescription{ + ParameterOIDs: []uint32{25}, + } + + var got ParameterDescription + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ParameterDescription struct doesn't match expected value") + } +} + +func TestJSONUnmarshalParameterStatus(t *testing.T) { + data := []byte(`{"Type":"ParameterStatus","Name":"TimeZone","Value":"Europe/Amsterdam"}`) + want := ParameterStatus{ + Name: "TimeZone", + Value: "Europe/Amsterdam", + } + + var got ParameterStatus + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ParameterDescription struct doesn't match expected value") + } +} + +func TestJSONUnmarshalReadyForQuery(t *testing.T) { + data := []byte(`{"Type":"ReadyForQuery","TxStatus":"I"}`) + want := ReadyForQuery{ + TxStatus: 'I', + } + + var got ReadyForQuery + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ParameterDescription struct doesn't match expected value") + } +} + +func TestJSONUnmarshalRowDescription(t *testing.T) { + data := []byte(`{"Type":"RowDescription","Fields":[{"Name":"generate_series","TableOID":0,"TableAttributeNumber":0,"DataTypeOID":23,"DataTypeSize":4,"TypeModifier":-1,"Format":0}]}`) + want := RowDescription{ + Fields: []FieldDescription{ + { + Name: []byte("generate_series"), + DataTypeOID: 23, + DataTypeSize: 4, + TypeModifier: -1, + }, + }, + } + + var got RowDescription + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled RowDescription struct doesn't match expected value") + } +} + +func TestJSONUnmarshalBind(t *testing.T) { + testCases := []struct { + desc string + data []byte + }{ + { + "textual", + []byte(`{"Type":"Bind","DestinationPortal":"","PreparedStatement":"lrupsc_1_0","ParameterFormatCodes":[0],"Parameters":[{"text":"ABC-123"}],"ResultFormatCodes":[0,0,0,0,0,1,1]}`), + }, + { + "binary", + []byte(`{"Type":"Bind","DestinationPortal":"","PreparedStatement":"lrupsc_1_0","ParameterFormatCodes":[0],"Parameters":[{"binary":"` + hex.EncodeToString([]byte("ABC-123")) + `"}],"ResultFormatCodes":[0,0,0,0,0,1,1]}`), + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + want := Bind{ + PreparedStatement: "lrupsc_1_0", + ParameterFormatCodes: []int16{0}, + Parameters: [][]byte{[]byte("ABC-123")}, + ResultFormatCodes: []int16{0, 0, 0, 0, 0, 1, 1}, + } + + var got Bind + if err := json.Unmarshal(tc.data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Bind struct doesn't match expected value") + } + }) + } +} + +func TestJSONUnmarshalCancelRequest(t *testing.T) { + data := []byte(`{"Type":"CancelRequest","ProcessID":8864,"SecretKey":3641487067}`) + want := CancelRequest{ + ProcessID: 8864, + SecretKey: 3641487067, + } + + var got CancelRequest + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CancelRequest struct doesn't match expected value") + } +} + +func TestJSONUnmarshalClose(t *testing.T) { + data := []byte(`{"Type":"Close","ObjectType":"S","Name":"abc"}`) + want := Close{ + ObjectType: 'S', + Name: "abc", + } + + var got Close + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Close struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyFail(t *testing.T) { + data := []byte(`{"Type":"CopyFail","Message":"abc"}`) + want := CopyFail{ + Message: "abc", + } + + var got CopyFail + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyFail struct doesn't match expected value") + } +} + +func TestJSONUnmarshalDescribe(t *testing.T) { + data := []byte(`{"Type":"Describe","ObjectType":"S","Name":"abc"}`) + want := Describe{ + ObjectType: 'S', + Name: "abc", + } + + var got Describe + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Describe struct doesn't match expected value") + } +} + +func TestJSONUnmarshalExecute(t *testing.T) { + data := []byte(`{"Type":"Execute","Portal":"","MaxRows":0}`) + want := Execute{} + + var got Execute + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Execute struct doesn't match expected value") + } +} + +func TestJSONUnmarshalParse(t *testing.T) { + data := []byte(`{"Type":"Parse","Name":"lrupsc_1_0","Query":"SELECT id, name FROM t WHERE id = $1","ParameterOIDs":null}`) + want := Parse{ + Name: "lrupsc_1_0", + Query: "SELECT id, name FROM t WHERE id = $1", + } + + var got Parse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Parse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalPasswordMessage(t *testing.T) { + data := []byte(`{"Type":"PasswordMessage","Password":"abcdef"}`) + want := PasswordMessage{ + Password: "abcdef", + } + + var got PasswordMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled PasswordMessage struct doesn't match expected value") + } +} + +func TestJSONUnmarshalQuery(t *testing.T) { + data := []byte(`{"Type":"Query","String":"SELECT 1"}`) + want := Query{ + String: "SELECT 1", + } + + var got Query + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Query struct doesn't match expected value") + } +} + +func TestJSONUnmarshalSASLInitialResponse(t *testing.T) { + data := []byte(`{"Type":"SASLInitialResponse", "AuthMechanism":"SCRAM-SHA-256", "Data": "6D"}`) + want := SASLInitialResponse{ + AuthMechanism: "SCRAM-SHA-256", + Data: []byte{109}, + } + + var got SASLInitialResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled SASLInitialResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalSASLResponse(t *testing.T) { + data := []byte(`{"Type":"SASLResponse","Message":"abc"}`) + want := SASLResponse{} + + var got SASLResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled SASLResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalStartupMessage(t *testing.T) { + data := []byte(`{"Type":"StartupMessage","ProtocolVersion":196608,"Parameters":{"database":"testing","user":"postgres"}}`) + want := StartupMessage{ + ProtocolVersion: 196608, + Parameters: map[string]string{ + "database": "testing", + "user": "postgres", + }, + } + + var got StartupMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled StartupMessage struct doesn't match expected value") + } +} + +func TestAuthenticationOK(t *testing.T) { + data := []byte(`{"Type":"AuthenticationOK"}`) + want := AuthenticationOk{} + + var got AuthenticationOk + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationOK struct doesn't match expected value") + } +} + +func TestAuthenticationCleartextPassword(t *testing.T) { + data := []byte(`{"Type":"AuthenticationCleartextPassword"}`) + want := AuthenticationCleartextPassword{} + + var got AuthenticationCleartextPassword + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationCleartextPassword struct doesn't match expected value") + } +} + +func TestAuthenticationMD5Password(t *testing.T) { + data := []byte(`{"Type":"AuthenticationMD5Password","Salt":[1,2,3,4]}`) + want := AuthenticationMD5Password{ + Salt: [4]byte{1, 2, 3, 4}, + } + + var got AuthenticationMD5Password + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationMD5Password struct doesn't match expected value") + } +} + +func TestJSONUnmarshalGSSResponse(t *testing.T) { + data := []byte(`{"Type":"GSSResponse","Data":[10,20,30,40]}`) + want := GSSResponse{Data: []byte{10, 20, 30, 40}} + + var got GSSResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled GSSResponse struct doesn't match expected value") + } +} + +func TestErrorResponse(t *testing.T) { + data := []byte(`{"Type":"ErrorResponse","UnknownFields":{"112":"foo"},"Code": "Fail","Position":1,"Message":"this is an error"}`) + want := ErrorResponse{ + UnknownFields: map[byte]string{ + 'p': "foo", + }, + Code: "Fail", + Position: 1, + Message: "this is an error", + } + + var got ErrorResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ErrorResponse struct doesn't match expected value") + } +} diff --git a/pgproto3/no_data.go b/pgproto3/no_data.go index 1fb47c2a4..cbcaad40c 100644 --- a/pgproto3/no_data.go +++ b/pgproto3/no_data.go @@ -6,8 +6,11 @@ import ( type NoData struct{} +// Backend identifies this message as sendable by the PostgreSQL backend. func (*NoData) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *NoData) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)} @@ -16,11 +19,13 @@ func (dst *NoData) Decode(src []byte) error { return nil } -func (src *NoData) Encode(dst []byte) []byte { - return append(dst, 'n', 0, 0, 0, 4) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *NoData) Encode(dst []byte) ([]byte, error) { + return append(dst, 'n', 0, 0, 0, 4), nil } -func (src *NoData) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src NoData) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/pgproto3/notice_response.go b/pgproto3/notice_response.go index e4595aa5a..497aba6dd 100644 --- a/pgproto3/notice_response.go +++ b/pgproto3/notice_response.go @@ -2,12 +2,18 @@ package pgproto3 type NoticeResponse ErrorResponse +// Backend identifies this message as sendable by the PostgreSQL backend. func (*NoticeResponse) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *NoticeResponse) Decode(src []byte) error { return (*ErrorResponse)(dst).Decode(src) } -func (src *NoticeResponse) Encode(dst []byte) []byte { - return append(dst, (*ErrorResponse)(src).marshalBinary('N')...) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *NoticeResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'N') + dst = (*ErrorResponse)(src).appendFields(dst) + return finishMessage(dst, sp) } diff --git a/pgproto3/notification_response.go b/pgproto3/notification_response.go index b14007b48..243b6bf7c 100644 --- a/pgproto3/notification_response.go +++ b/pgproto3/notification_response.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "encoding/json" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type NotificationResponse struct { @@ -14,11 +14,18 @@ type NotificationResponse struct { Payload string } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*NotificationResponse) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *NotificationResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "NotificationResponse", details: "too short"} + } + pid := binary.BigEndian.Uint32(buf.Next(4)) b, err := buf.ReadBytes(0) @@ -37,22 +44,19 @@ func (dst *NotificationResponse) Decode(src []byte) error { return nil } -func (src *NotificationResponse) Encode(dst []byte) []byte { - dst = append(dst, 'A') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *NotificationResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'A') + dst = pgio.AppendUint32(dst, src.PID) dst = append(dst, src.Channel...) dst = append(dst, 0) dst = append(dst, src.Payload...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } -func (src *NotificationResponse) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src NotificationResponse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string PID uint32 diff --git a/pgproto3/parameter_description.go b/pgproto3/parameter_description.go index 1fa3c927e..1ef27b75f 100644 --- a/pgproto3/parameter_description.go +++ b/pgproto3/parameter_description.go @@ -4,16 +4,21 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type ParameterDescription struct { ParameterOIDs []uint32 } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*ParameterDescription) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *ParameterDescription) Decode(src []byte) error { buf := bytes.NewBuffer(src) @@ -35,22 +40,23 @@ func (dst *ParameterDescription) Decode(src []byte) error { return nil } -func (src *ParameterDescription) Encode(dst []byte) []byte { - dst = append(dst, 't') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 't') + if len(src.ParameterOIDs) > math.MaxUint16 { + return nil, errors.New("too many parameter oids") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) for _, oid := range src.ParameterOIDs { dst = pgio.AppendUint32(dst, oid) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } -func (src *ParameterDescription) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src ParameterDescription) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ParameterOIDs []uint32 diff --git a/pgproto3/parameter_status.go b/pgproto3/parameter_status.go index b3bac33f3..9ee0720b5 100644 --- a/pgproto3/parameter_status.go +++ b/pgproto3/parameter_status.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/pgio" ) type ParameterStatus struct { @@ -12,8 +10,11 @@ type ParameterStatus struct { Value string } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*ParameterStatus) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *ParameterStatus) Decode(src []byte) error { buf := bytes.NewBuffer(src) @@ -33,22 +34,18 @@ func (dst *ParameterStatus) Decode(src []byte) error { return nil } -func (src *ParameterStatus) Encode(dst []byte) []byte { - dst = append(dst, 'S') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ParameterStatus) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'S') dst = append(dst, src.Name...) dst = append(dst, 0) dst = append(dst, src.Value...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } -func (ps *ParameterStatus) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (ps ParameterStatus) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string Name string diff --git a/pgproto3/parse.go b/pgproto3/parse.go index ca4834c68..6ba3486cf 100644 --- a/pgproto3/parse.go +++ b/pgproto3/parse.go @@ -4,8 +4,10 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type Parse struct { @@ -14,8 +16,11 @@ type Parse struct { ParameterOIDs []uint32 } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Parse) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Parse) Decode(src []byte) error { *dst = Parse{} @@ -48,27 +53,28 @@ func (dst *Parse) Decode(src []byte) error { return nil } -func (src *Parse) Encode(dst []byte) []byte { - dst = append(dst, 'P') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Parse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'P') dst = append(dst, src.Name...) dst = append(dst, 0) dst = append(dst, src.Query...) dst = append(dst, 0) + if len(src.ParameterOIDs) > math.MaxUint16 { + return nil, errors.New("too many parameter oids") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) for _, oid := range src.ParameterOIDs { dst = pgio.AppendUint32(dst, oid) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } -func (src *Parse) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src Parse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string Name string diff --git a/pgproto3/parse_complete.go b/pgproto3/parse_complete.go index 462a89ba0..cff9e27d0 100644 --- a/pgproto3/parse_complete.go +++ b/pgproto3/parse_complete.go @@ -6,8 +6,11 @@ import ( type ParseComplete struct{} +// Backend identifies this message as sendable by the PostgreSQL backend. func (*ParseComplete) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *ParseComplete) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)} @@ -16,11 +19,13 @@ func (dst *ParseComplete) Decode(src []byte) error { return nil } -func (src *ParseComplete) Encode(dst []byte) []byte { - return append(dst, '1', 0, 0, 0, 4) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ParseComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '1', 0, 0, 0, 4), nil } -func (src *ParseComplete) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src ParseComplete) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/pgproto3/password_message.go b/pgproto3/password_message.go index 2ad3fe4a7..67b78515d 100644 --- a/pgproto3/password_message.go +++ b/pgproto3/password_message.go @@ -3,16 +3,20 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/pgio" ) type PasswordMessage struct { Password string } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*PasswordMessage) Frontend() {} +// InitialResponse identifies this message as an authentication response. +func (*PasswordMessage) InitialResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *PasswordMessage) Decode(src []byte) error { buf := bytes.NewBuffer(src) @@ -25,17 +29,16 @@ func (dst *PasswordMessage) Decode(src []byte) error { return nil } -func (src *PasswordMessage) Encode(dst []byte) []byte { - dst = append(dst, 'p') - dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1)) - +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *PasswordMessage) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, src.Password...) dst = append(dst, 0) - - return dst + return finishMessage(dst, sp) } -func (src *PasswordMessage) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src PasswordMessage) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string Password string diff --git a/pgproto3/pgproto3.go b/pgproto3/pgproto3.go index fe7b085bc..128f97f87 100644 --- a/pgproto3/pgproto3.go +++ b/pgproto3/pgproto3.go @@ -1,6 +1,16 @@ package pgproto3 -import "fmt" +import ( + "encoding/hex" + "errors" + "fmt" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// maxMessageBodyLen is the maximum length of a message body in bytes. See PG_LARGE_MESSAGE_LIMIT in the PostgreSQL +// source. It is defined as (MaxAllocSize - 1). MaxAllocSize is defined as 0x3fffffff. +const maxMessageBodyLen = (0x3fffffff - 1) // Message is the interface implemented by an object that can decode and encode // a particular PostgreSQL message. @@ -10,19 +20,26 @@ type Message interface { Decode(data []byte) error // Encode appends itself to dst and returns the new buffer. - Encode(dst []byte) []byte + Encode(dst []byte) ([]byte, error) } +// FrontendMessage is a message sent by the frontend (i.e. the client). type FrontendMessage interface { Message Frontend() // no-op method to distinguish frontend from backend methods } +// BackendMessage is a message sent by the backend (i.e. the server). type BackendMessage interface { Message Backend() // no-op method to distinguish frontend from backend methods } +type AuthenticationResponseMessage interface { + BackendMessage + AuthenticationResponse() // no-op method to distinguish authentication responses +} + type invalidMessageLenErr struct { messageType string expectedLen int @@ -35,8 +52,69 @@ func (e *invalidMessageLenErr) Error() string { type invalidMessageFormatErr struct { messageType string + details string } func (e *invalidMessageFormatErr) Error() string { - return fmt.Sprintf("%s body is invalid", e.messageType) + return fmt.Sprintf("%s body is invalid %s", e.messageType, e.details) +} + +type writeError struct { + err error + safeToRetry bool +} + +func (e *writeError) Error() string { + return fmt.Sprintf("write failed: %s", e.err.Error()) +} + +func (e *writeError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *writeError) Unwrap() error { + return e.err +} + +type ExceededMaxBodyLenErr struct { + MaxExpectedBodyLen int + ActualBodyLen int +} + +func (e *ExceededMaxBodyLenErr) Error() string { + return fmt.Sprintf("invalid body length: expected at most %d, but got %d", e.MaxExpectedBodyLen, e.ActualBodyLen) +} + +// getValueFromJSON gets the value from a protocol message representation in JSON. +func getValueFromJSON(v map[string]string) ([]byte, error) { + if v == nil { + return nil, nil + } + if text, ok := v["text"]; ok { + return []byte(text), nil + } + if binary, ok := v["binary"]; ok { + return hex.DecodeString(binary) + } + return nil, errors.New("unknown protocol representation") +} + +// beginMessage begins a new message of type t. It appends the message type and a placeholder for the message length to +// dst. It returns the new buffer and the position of the message length placeholder. +func beginMessage(dst []byte, t byte) ([]byte, int) { + dst = append(dst, t) + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + return dst, sp +} + +// finishMessage finishes a message that was started with beginMessage. It computes the message length and writes it to +// dst[sp]. If the message length is too large it returns an error. Otherwise it returns the final message buffer. +func finishMessage(dst []byte, sp int) ([]byte, error) { + messageBodyLen := len(dst[sp:]) + if messageBodyLen > maxMessageBodyLen { + return nil, errors.New("message body too large") + } + pgio.SetInt32(dst[sp:], int32(messageBodyLen)) + return dst, nil } diff --git a/pgproto3/pgproto3_private_test.go b/pgproto3/pgproto3_private_test.go new file mode 100644 index 000000000..15da1eafb --- /dev/null +++ b/pgproto3/pgproto3_private_test.go @@ -0,0 +1,3 @@ +package pgproto3 + +const MaxMessageBodyLen = maxMessageBodyLen diff --git a/pgproto3/portal_suspended.go b/pgproto3/portal_suspended.go new file mode 100644 index 000000000..9e2f8cbc4 --- /dev/null +++ b/pgproto3/portal_suspended.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type PortalSuspended struct{} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*PortalSuspended) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *PortalSuspended) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "PortalSuspended", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *PortalSuspended) Encode(dst []byte) ([]byte, error) { + return append(dst, 's', 0, 0, 0, 4), nil +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src PortalSuspended) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "PortalSuspended", + }) +} diff --git a/pgproto3/query.go b/pgproto3/query.go index d80c0fb4d..aebdfde89 100644 --- a/pgproto3/query.go +++ b/pgproto3/query.go @@ -3,16 +3,17 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/pgio" ) type Query struct { String string } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Query) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Query) Decode(src []byte) error { i := bytes.IndexByte(src, 0) if i != len(src)-1 { @@ -24,17 +25,16 @@ func (dst *Query) Decode(src []byte) error { return nil } -func (src *Query) Encode(dst []byte) []byte { - dst = append(dst, 'Q') - dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1)) - +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Query) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'Q') dst = append(dst, src.String...) dst = append(dst, 0) - - return dst + return finishMessage(dst, sp) } -func (src *Query) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src Query) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string String string diff --git a/pgproto3/query_test.go b/pgproto3/query_test.go new file mode 100644 index 000000000..9551fc14d --- /dev/null +++ b/pgproto3/query_test.go @@ -0,0 +1,20 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +func TestQueryBiggerThanMaxMessageBodyLen(t *testing.T) { + t.Parallel() + + // Maximum allowed size. 4 bytes for size and 1 byte for 0 terminated string. + _, err := (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-5))}).Encode(nil) + require.NoError(t, err) + + // 1 byte too big + _, err = (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-4))}).Encode(nil) + require.Error(t, err) +} diff --git a/pgproto3/ready_for_query.go b/pgproto3/ready_for_query.go index 63b902bdf..a56af9fb2 100644 --- a/pgproto3/ready_for_query.go +++ b/pgproto3/ready_for_query.go @@ -2,14 +2,18 @@ package pgproto3 import ( "encoding/json" + "errors" ) type ReadyForQuery struct { TxStatus byte } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*ReadyForQuery) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *ReadyForQuery) Decode(src []byte) error { if len(src) != 1 { return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)} @@ -20,11 +24,13 @@ func (dst *ReadyForQuery) Decode(src []byte) error { return nil } -func (src *ReadyForQuery) Encode(dst []byte) []byte { - return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ReadyForQuery) Encode(dst []byte) ([]byte, error) { + return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus), nil } -func (src *ReadyForQuery) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src ReadyForQuery) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string TxStatus string @@ -33,3 +39,23 @@ func (src *ReadyForQuery) MarshalJSON() ([]byte, error) { TxStatus: string(src.TxStatus), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *ReadyForQuery) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + TxStatus string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if len(msg.TxStatus) != 1 { + return errors.New("invalid length for ReadyForQuery.TxStatus") + } + dst.TxStatus = msg.TxStatus[0] + return nil +} diff --git a/pgproto3/row_description.go b/pgproto3/row_description.go index d0df11b0a..c40a2261b 100644 --- a/pgproto3/row_description.go +++ b/pgproto3/row_description.go @@ -4,8 +4,10 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) const ( @@ -14,63 +16,96 @@ const ( ) type FieldDescription struct { - Name string + Name []byte TableOID uint32 TableAttributeNumber uint16 DataTypeOID uint32 DataTypeSize int16 - TypeModifier uint32 + TypeModifier int32 Format int16 } +// MarshalJSON implements encoding/json.Marshaler. +func (fd FieldDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + Format int16 + }{ + Name: string(fd.Name), + TableOID: fd.TableOID, + TableAttributeNumber: fd.TableAttributeNumber, + DataTypeOID: fd.DataTypeOID, + DataTypeSize: fd.DataTypeSize, + TypeModifier: fd.TypeModifier, + Format: fd.Format, + }) +} + type RowDescription struct { Fields []FieldDescription } +// Backend identifies this message as sendable by the PostgreSQL backend. func (*RowDescription) Backend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *RowDescription) Decode(src []byte) error { - buf := bytes.NewBuffer(src) - - if buf.Len() < 2 { + if len(src) < 2 { return &invalidMessageFormatErr{messageType: "RowDescription"} } - fieldCount := int(binary.BigEndian.Uint16(buf.Next(2))) + fieldCount := int(binary.BigEndian.Uint16(src)) + rp := 2 - *dst = RowDescription{Fields: make([]FieldDescription, fieldCount)} + dst.Fields = dst.Fields[0:0] for i := 0; i < fieldCount; i++ { var fd FieldDescription - bName, err := buf.ReadBytes(0) - if err != nil { - return err + + idx := bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "RowDescription"} } - fd.Name = string(bName[:len(bName)-1]) + fd.Name = src[rp : rp+idx] + rp += idx + 1 // Since buf.Next() doesn't return an error if we hit the end of the buffer // check Len ahead of time - if buf.Len() < 18 { + if len(src[rp:]) < 18 { return &invalidMessageFormatErr{messageType: "RowDescription"} } - fd.TableOID = binary.BigEndian.Uint32(buf.Next(4)) - fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2)) - fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4)) - fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2))) - fd.TypeModifier = binary.BigEndian.Uint32(buf.Next(4)) - fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2))) - - dst.Fields[i] = fd + fd.TableOID = binary.BigEndian.Uint32(src[rp:]) + rp += 4 + fd.TableAttributeNumber = binary.BigEndian.Uint16(src[rp:]) + rp += 2 + fd.DataTypeOID = binary.BigEndian.Uint32(src[rp:]) + rp += 4 + fd.DataTypeSize = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + fd.TypeModifier = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + fd.Format = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + dst.Fields = append(dst.Fields, fd) } return nil } -func (src *RowDescription) Encode(dst []byte) []byte { - dst = append(dst, 'T') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *RowDescription) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'T') + if len(src.Fields) > math.MaxUint16 { + return nil, errors.New("too many fields") + } dst = pgio.AppendUint16(dst, uint16(len(src.Fields))) for _, fd := range src.Fields { dst = append(dst, fd.Name...) @@ -80,16 +115,15 @@ func (src *RowDescription) Encode(dst []byte) []byte { dst = pgio.AppendUint16(dst, fd.TableAttributeNumber) dst = pgio.AppendUint32(dst, fd.DataTypeOID) dst = pgio.AppendInt16(dst, fd.DataTypeSize) - dst = pgio.AppendUint32(dst, fd.TypeModifier) + dst = pgio.AppendInt32(dst, fd.TypeModifier) dst = pgio.AppendInt16(dst, fd.Format) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } -func (src *RowDescription) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src RowDescription) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string Fields []FieldDescription @@ -98,3 +132,34 @@ func (src *RowDescription) MarshalJSON() ([]byte, error) { Fields: src.Fields, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *RowDescription) UnmarshalJSON(data []byte) error { + var msg struct { + Fields []struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + Format int16 + } + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + dst.Fields = make([]FieldDescription, len(msg.Fields)) + for n, field := range msg.Fields { + dst.Fields[n] = FieldDescription{ + Name: []byte(field.Name), + TableOID: field.TableOID, + TableAttributeNumber: field.TableAttributeNumber, + DataTypeOID: field.DataTypeOID, + DataTypeSize: field.DataTypeSize, + TypeModifier: field.TypeModifier, + Format: field.Format, + } + } + return nil +} diff --git a/pgproto3/sasl_initial_response.go b/pgproto3/sasl_initial_response.go new file mode 100644 index 000000000..9eb1b6a4b --- /dev/null +++ b/pgproto3/sasl_initial_response.go @@ -0,0 +1,90 @@ +package pgproto3 + +import ( + "bytes" + "encoding/hex" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type SASLInitialResponse struct { + AuthMechanism string + Data []byte +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*SASLInitialResponse) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *SASLInitialResponse) Decode(src []byte) error { + *dst = SASLInitialResponse{} + + rp := 0 + + idx := bytes.IndexByte(src, 0) + if idx < 0 { + return errors.New("invalid SASLInitialResponse") + } + + dst.AuthMechanism = string(src[rp:idx]) + rp = idx + 1 + + rp += 4 // The rest of the message is data so we can just skip the size + dst.Data = src[rp:] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *SASLInitialResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') + + dst = append(dst, []byte(src.AuthMechanism)...) + dst = append(dst, 0) + + dst = pgio.AppendInt32(dst, int32(len(src.Data))) + dst = append(dst, src.Data...) + + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src SASLInitialResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + AuthMechanism string + Data string + }{ + Type: "SASLInitialResponse", + AuthMechanism: src.AuthMechanism, + Data: string(src.Data), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *SASLInitialResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + AuthMechanism string + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + dst.AuthMechanism = msg.AuthMechanism + if msg.Data != "" { + decoded, err := hex.DecodeString(msg.Data) + if err != nil { + return err + } + dst.Data = decoded + } + return nil +} diff --git a/pgproto3/sasl_response.go b/pgproto3/sasl_response.go new file mode 100644 index 000000000..1b604c254 --- /dev/null +++ b/pgproto3/sasl_response.go @@ -0,0 +1,56 @@ +package pgproto3 + +import ( + "encoding/hex" + "encoding/json" +) + +type SASLResponse struct { + Data []byte +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*SASLResponse) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *SASLResponse) Decode(src []byte) error { + *dst = SASLResponse{Data: src} + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *SASLResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') + dst = append(dst, src.Data...) + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src SASLResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "SASLResponse", + Data: string(src.Data), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *SASLResponse) UnmarshalJSON(data []byte) error { + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if msg.Data != "" { + decoded, err := hex.DecodeString(msg.Data) + if err != nil { + return err + } + dst.Data = decoded + } + return nil +} diff --git a/pgproto3/ssl_request.go b/pgproto3/ssl_request.go new file mode 100644 index 000000000..bdfc7c427 --- /dev/null +++ b/pgproto3/ssl_request.go @@ -0,0 +1,48 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +const sslRequestNumber = 80877103 + +type SSLRequest struct{} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*SSLRequest) Frontend() {} + +func (dst *SSLRequest) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("ssl request too short") + } + + requestCode := binary.BigEndian.Uint32(src) + + if requestCode != sslRequestNumber { + return errors.New("bad ssl request code") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 4 byte message length. +func (src *SSLRequest) Encode(dst []byte) ([]byte, error) { + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendInt32(dst, sslRequestNumber) + return dst, nil +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src SSLRequest) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProtocolVersion uint32 + Parameters map[string]string + }{ + Type: "SSLRequest", + }) +} diff --git a/pgproto3/startup_message.go b/pgproto3/startup_message.go index 6c5d4f99e..3af4587d8 100644 --- a/pgproto3/startup_message.go +++ b/pgproto3/startup_message.go @@ -4,51 +4,48 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "fmt" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) -const ( - ProtocolVersionNumber = 196608 // 3.0 - sslRequestNumber = 80877103 -) +const ProtocolVersionNumber = 196608 // 3.0 type StartupMessage struct { ProtocolVersion uint32 Parameters map[string]string } +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*StartupMessage) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *StartupMessage) Decode(src []byte) error { if len(src) < 4 { - return errors.Errorf("startup message too short") + return errors.New("startup message too short") } dst.ProtocolVersion = binary.BigEndian.Uint32(src) rp := 4 - if dst.ProtocolVersion == sslRequestNumber { - return errors.Errorf("can't handle ssl connection request") - } - if dst.ProtocolVersion != ProtocolVersionNumber { - return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) + return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) } dst.Parameters = make(map[string]string) for { idx := bytes.IndexByte(src[rp:], 0) if idx < 0 { - return &invalidMessageFormatErr{messageType: "StartupMesage"} + return &invalidMessageFormatErr{messageType: "StartupMessage"} } key := string(src[rp : rp+idx]) rp += idx + 1 idx = bytes.IndexByte(src[rp:], 0) if idx < 0 { - return &invalidMessageFormatErr{messageType: "StartupMesage"} + return &invalidMessageFormatErr{messageType: "StartupMessage"} } value := string(src[rp : rp+idx]) rp += idx + 1 @@ -57,7 +54,7 @@ func (dst *StartupMessage) Decode(src []byte) error { if len(src[rp:]) == 1 { if src[rp] != 0 { - return errors.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) + return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) } break } @@ -66,7 +63,8 @@ func (dst *StartupMessage) Decode(src []byte) error { return nil } -func (src *StartupMessage) Encode(dst []byte) []byte { +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *StartupMessage) Encode(dst []byte) ([]byte, error) { sp := len(dst) dst = pgio.AppendInt32(dst, -1) @@ -79,12 +77,11 @@ func (src *StartupMessage) Encode(dst []byte) []byte { } dst = append(dst, 0) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } -func (src *StartupMessage) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src StartupMessage) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ProtocolVersion uint32 diff --git a/pgproto3/sync.go b/pgproto3/sync.go index 85f4749a4..ea4fc9594 100644 --- a/pgproto3/sync.go +++ b/pgproto3/sync.go @@ -6,8 +6,11 @@ import ( type Sync struct{} +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Sync) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Sync) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)} @@ -16,11 +19,13 @@ func (dst *Sync) Decode(src []byte) error { return nil } -func (src *Sync) Encode(dst []byte) []byte { - return append(dst, 'S', 0, 0, 0, 4) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Sync) Encode(dst []byte) ([]byte, error) { + return append(dst, 'S', 0, 0, 0, 4), nil } -func (src *Sync) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src Sync) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/pgproto3/terminate.go b/pgproto3/terminate.go index 0a3310da7..35a9dc837 100644 --- a/pgproto3/terminate.go +++ b/pgproto3/terminate.go @@ -6,8 +6,11 @@ import ( type Terminate struct{} +// Frontend identifies this message as sendable by a PostgreSQL frontend. func (*Terminate) Frontend() {} +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. func (dst *Terminate) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)} @@ -16,11 +19,13 @@ func (dst *Terminate) Decode(src []byte) error { return nil } -func (src *Terminate) Encode(dst []byte) []byte { - return append(dst, 'X', 0, 0, 0, 4) +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Terminate) Encode(dst []byte) ([]byte, error) { + return append(dst, 'X', 0, 0, 0, 4), nil } -func (src *Terminate) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler. +func (src Terminate) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string }{ diff --git a/pgproto3/testdata/fuzz/FuzzFrontend/39c5e864da4707fc15fea48f7062d6a07796fdc43b33e0ba9dbd7074a0211fa6 b/pgproto3/testdata/fuzz/FuzzFrontend/39c5e864da4707fc15fea48f7062d6a07796fdc43b33e0ba9dbd7074a0211fa6 new file mode 100644 index 000000000..d1c612d35 --- /dev/null +++ b/pgproto3/testdata/fuzz/FuzzFrontend/39c5e864da4707fc15fea48f7062d6a07796fdc43b33e0ba9dbd7074a0211fa6 @@ -0,0 +1,4 @@ +go test fuzz v1 +byte('A') +uint32(5) +[]byte("0") diff --git a/pgproto3/testdata/fuzz/FuzzFrontend/9b06792b1aaac8a907dbfa04d526ae14326c8573b7409032caac8461e83065f7 b/pgproto3/testdata/fuzz/FuzzFrontend/9b06792b1aaac8a907dbfa04d526ae14326c8573b7409032caac8461e83065f7 new file mode 100644 index 000000000..763b70ae4 --- /dev/null +++ b/pgproto3/testdata/fuzz/FuzzFrontend/9b06792b1aaac8a907dbfa04d526ae14326c8573b7409032caac8461e83065f7 @@ -0,0 +1,4 @@ +go test fuzz v1 +byte('D') +uint32(21) +[]byte("00\xb300000000000000") diff --git a/pgproto3/testdata/fuzz/FuzzFrontend/a661fb98e802839f0a7361160fbc6e28794612a411d00bde104364ee281c4214 b/pgproto3/testdata/fuzz/FuzzFrontend/a661fb98e802839f0a7361160fbc6e28794612a411d00bde104364ee281c4214 new file mode 100644 index 000000000..3d995c281 --- /dev/null +++ b/pgproto3/testdata/fuzz/FuzzFrontend/a661fb98e802839f0a7361160fbc6e28794612a411d00bde104364ee281c4214 @@ -0,0 +1,4 @@ +go test fuzz v1 +byte('C') +uint32(4) +[]byte("0") diff --git a/pgproto3/testdata/fuzz/FuzzFrontend/fc98dcd487a5173b38763a5f7dd023933f3a86ab566e3f2b091eb36248107eb4 b/pgproto3/testdata/fuzz/FuzzFrontend/fc98dcd487a5173b38763a5f7dd023933f3a86ab566e3f2b091eb36248107eb4 new file mode 100644 index 000000000..45f0ba817 --- /dev/null +++ b/pgproto3/testdata/fuzz/FuzzFrontend/fc98dcd487a5173b38763a5f7dd023933f3a86ab566e3f2b091eb36248107eb4 @@ -0,0 +1,4 @@ +go test fuzz v1 +byte('R') +uint32(13) +[]byte("\x00\x00\x00\n0\x12\xebG\x8dI']G\xdac\x95\xb7\x18\xb0\x02\xe8m\xc2\x00\xef\x03\x12\x1b\xbdj\x10\x9f\xf9\xeb\xb8") diff --git a/pgproto3/trace.go b/pgproto3/trace.go new file mode 100644 index 000000000..6cc7d3e36 --- /dev/null +++ b/pgproto3/trace.go @@ -0,0 +1,416 @@ +package pgproto3 + +import ( + "bytes" + "fmt" + "io" + "strconv" + "strings" + "sync" + "time" +) + +// tracer traces the messages send to and from a Backend or Frontend. The format it produces roughly mimics the +// format produced by the libpq C function PQtrace. +type tracer struct { + TracerOptions + + mux sync.Mutex + w io.Writer + buf *bytes.Buffer +} + +// TracerOptions controls tracing behavior. It is roughly equivalent to the libpq function PQsetTraceFlags. +type TracerOptions struct { + // SuppressTimestamps prevents printing of timestamps. + SuppressTimestamps bool + + // RegressMode redacts fields that may be vary between executions. + RegressMode bool +} + +func (t *tracer) traceMessage(sender byte, encodedLen int32, msg Message) { + switch msg := msg.(type) { + case *AuthenticationCleartextPassword: + t.traceAuthenticationCleartextPassword(sender, encodedLen, msg) + case *AuthenticationGSS: + t.traceAuthenticationGSS(sender, encodedLen, msg) + case *AuthenticationGSSContinue: + t.traceAuthenticationGSSContinue(sender, encodedLen, msg) + case *AuthenticationMD5Password: + t.traceAuthenticationMD5Password(sender, encodedLen, msg) + case *AuthenticationOk: + t.traceAuthenticationOk(sender, encodedLen, msg) + case *AuthenticationSASL: + t.traceAuthenticationSASL(sender, encodedLen, msg) + case *AuthenticationSASLContinue: + t.traceAuthenticationSASLContinue(sender, encodedLen, msg) + case *AuthenticationSASLFinal: + t.traceAuthenticationSASLFinal(sender, encodedLen, msg) + case *BackendKeyData: + t.traceBackendKeyData(sender, encodedLen, msg) + case *Bind: + t.traceBind(sender, encodedLen, msg) + case *BindComplete: + t.traceBindComplete(sender, encodedLen, msg) + case *CancelRequest: + t.traceCancelRequest(sender, encodedLen, msg) + case *Close: + t.traceClose(sender, encodedLen, msg) + case *CloseComplete: + t.traceCloseComplete(sender, encodedLen, msg) + case *CommandComplete: + t.traceCommandComplete(sender, encodedLen, msg) + case *CopyBothResponse: + t.traceCopyBothResponse(sender, encodedLen, msg) + case *CopyData: + t.traceCopyData(sender, encodedLen, msg) + case *CopyDone: + t.traceCopyDone(sender, encodedLen, msg) + case *CopyFail: + t.traceCopyFail(sender, encodedLen, msg) + case *CopyInResponse: + t.traceCopyInResponse(sender, encodedLen, msg) + case *CopyOutResponse: + t.traceCopyOutResponse(sender, encodedLen, msg) + case *DataRow: + t.traceDataRow(sender, encodedLen, msg) + case *Describe: + t.traceDescribe(sender, encodedLen, msg) + case *EmptyQueryResponse: + t.traceEmptyQueryResponse(sender, encodedLen, msg) + case *ErrorResponse: + t.traceErrorResponse(sender, encodedLen, msg) + case *Execute: + t.TraceQueryute(sender, encodedLen, msg) + case *Flush: + t.traceFlush(sender, encodedLen, msg) + case *FunctionCall: + t.traceFunctionCall(sender, encodedLen, msg) + case *FunctionCallResponse: + t.traceFunctionCallResponse(sender, encodedLen, msg) + case *GSSEncRequest: + t.traceGSSEncRequest(sender, encodedLen, msg) + case *NoData: + t.traceNoData(sender, encodedLen, msg) + case *NoticeResponse: + t.traceNoticeResponse(sender, encodedLen, msg) + case *NotificationResponse: + t.traceNotificationResponse(sender, encodedLen, msg) + case *ParameterDescription: + t.traceParameterDescription(sender, encodedLen, msg) + case *ParameterStatus: + t.traceParameterStatus(sender, encodedLen, msg) + case *Parse: + t.traceParse(sender, encodedLen, msg) + case *ParseComplete: + t.traceParseComplete(sender, encodedLen, msg) + case *PortalSuspended: + t.tracePortalSuspended(sender, encodedLen, msg) + case *Query: + t.traceQuery(sender, encodedLen, msg) + case *ReadyForQuery: + t.traceReadyForQuery(sender, encodedLen, msg) + case *RowDescription: + t.traceRowDescription(sender, encodedLen, msg) + case *SSLRequest: + t.traceSSLRequest(sender, encodedLen, msg) + case *StartupMessage: + t.traceStartupMessage(sender, encodedLen, msg) + case *Sync: + t.traceSync(sender, encodedLen, msg) + case *Terminate: + t.traceTerminate(sender, encodedLen, msg) + default: + t.writeTrace(sender, encodedLen, "Unknown", nil) + } +} + +func (t *tracer) traceAuthenticationCleartextPassword(sender byte, encodedLen int32, msg *AuthenticationCleartextPassword) { + t.writeTrace(sender, encodedLen, "AuthenticationCleartextPassword", nil) +} + +func (t *tracer) traceAuthenticationGSS(sender byte, encodedLen int32, msg *AuthenticationGSS) { + t.writeTrace(sender, encodedLen, "AuthenticationGSS", nil) +} + +func (t *tracer) traceAuthenticationGSSContinue(sender byte, encodedLen int32, msg *AuthenticationGSSContinue) { + t.writeTrace(sender, encodedLen, "AuthenticationGSSContinue", nil) +} + +func (t *tracer) traceAuthenticationMD5Password(sender byte, encodedLen int32, msg *AuthenticationMD5Password) { + t.writeTrace(sender, encodedLen, "AuthenticationMD5Password", nil) +} + +func (t *tracer) traceAuthenticationOk(sender byte, encodedLen int32, msg *AuthenticationOk) { + t.writeTrace(sender, encodedLen, "AuthenticationOk", nil) +} + +func (t *tracer) traceAuthenticationSASL(sender byte, encodedLen int32, msg *AuthenticationSASL) { + t.writeTrace(sender, encodedLen, "AuthenticationSASL", nil) +} + +func (t *tracer) traceAuthenticationSASLContinue(sender byte, encodedLen int32, msg *AuthenticationSASLContinue) { + t.writeTrace(sender, encodedLen, "AuthenticationSASLContinue", nil) +} + +func (t *tracer) traceAuthenticationSASLFinal(sender byte, encodedLen int32, msg *AuthenticationSASLFinal) { + t.writeTrace(sender, encodedLen, "AuthenticationSASLFinal", nil) +} + +func (t *tracer) traceBackendKeyData(sender byte, encodedLen int32, msg *BackendKeyData) { + t.writeTrace(sender, encodedLen, "BackendKeyData", func() { + if t.RegressMode { + t.buf.WriteString("\t NNNN NNNN") + } else { + fmt.Fprintf(t.buf, "\t %d %d", msg.ProcessID, msg.SecretKey) + } + }) +} + +func (t *tracer) traceBind(sender byte, encodedLen int32, msg *Bind) { + t.writeTrace(sender, encodedLen, "Bind", func() { + fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.DestinationPortal)), traceDoubleQuotedString([]byte(msg.PreparedStatement)), len(msg.ParameterFormatCodes)) + for _, fc := range msg.ParameterFormatCodes { + fmt.Fprintf(t.buf, " %d", fc) + } + fmt.Fprintf(t.buf, " %d", len(msg.Parameters)) + for _, p := range msg.Parameters { + fmt.Fprintf(t.buf, " %s", traceSingleQuotedString(p)) + } + fmt.Fprintf(t.buf, " %d", len(msg.ResultFormatCodes)) + for _, fc := range msg.ResultFormatCodes { + fmt.Fprintf(t.buf, " %d", fc) + } + }) +} + +func (t *tracer) traceBindComplete(sender byte, encodedLen int32, msg *BindComplete) { + t.writeTrace(sender, encodedLen, "BindComplete", nil) +} + +func (t *tracer) traceCancelRequest(sender byte, encodedLen int32, msg *CancelRequest) { + t.writeTrace(sender, encodedLen, "CancelRequest", nil) +} + +func (t *tracer) traceClose(sender byte, encodedLen int32, msg *Close) { + t.writeTrace(sender, encodedLen, "Close", nil) +} + +func (t *tracer) traceCloseComplete(sender byte, encodedLen int32, msg *CloseComplete) { + t.writeTrace(sender, encodedLen, "CloseComplete", nil) +} + +func (t *tracer) traceCommandComplete(sender byte, encodedLen int32, msg *CommandComplete) { + t.writeTrace(sender, encodedLen, "CommandComplete", func() { + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString(msg.CommandTag)) + }) +} + +func (t *tracer) traceCopyBothResponse(sender byte, encodedLen int32, msg *CopyBothResponse) { + t.writeTrace(sender, encodedLen, "CopyBothResponse", nil) +} + +func (t *tracer) traceCopyData(sender byte, encodedLen int32, msg *CopyData) { + t.writeTrace(sender, encodedLen, "CopyData", nil) +} + +func (t *tracer) traceCopyDone(sender byte, encodedLen int32, msg *CopyDone) { + t.writeTrace(sender, encodedLen, "CopyDone", nil) +} + +func (t *tracer) traceCopyFail(sender byte, encodedLen int32, msg *CopyFail) { + t.writeTrace(sender, encodedLen, "CopyFail", func() { + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.Message))) + }) +} + +func (t *tracer) traceCopyInResponse(sender byte, encodedLen int32, msg *CopyInResponse) { + t.writeTrace(sender, encodedLen, "CopyInResponse", nil) +} + +func (t *tracer) traceCopyOutResponse(sender byte, encodedLen int32, msg *CopyOutResponse) { + t.writeTrace(sender, encodedLen, "CopyOutResponse", nil) +} + +func (t *tracer) traceDataRow(sender byte, encodedLen int32, msg *DataRow) { + t.writeTrace(sender, encodedLen, "DataRow", func() { + fmt.Fprintf(t.buf, "\t %d", len(msg.Values)) + for _, v := range msg.Values { + if v == nil { + t.buf.WriteString(" -1") + } else { + fmt.Fprintf(t.buf, " %d %s", len(v), traceSingleQuotedString(v)) + } + } + }) +} + +func (t *tracer) traceDescribe(sender byte, encodedLen int32, msg *Describe) { + t.writeTrace(sender, encodedLen, "Describe", func() { + fmt.Fprintf(t.buf, "\t %c %s", msg.ObjectType, traceDoubleQuotedString([]byte(msg.Name))) + }) +} + +func (t *tracer) traceEmptyQueryResponse(sender byte, encodedLen int32, msg *EmptyQueryResponse) { + t.writeTrace(sender, encodedLen, "EmptyQueryResponse", nil) +} + +func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorResponse) { + t.writeTrace(sender, encodedLen, "ErrorResponse", nil) +} + +func (t *tracer) TraceQueryute(sender byte, encodedLen int32, msg *Execute) { + t.writeTrace(sender, encodedLen, "Execute", func() { + fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) + }) +} + +func (t *tracer) traceFlush(sender byte, encodedLen int32, msg *Flush) { + t.writeTrace(sender, encodedLen, "Flush", nil) +} + +func (t *tracer) traceFunctionCall(sender byte, encodedLen int32, msg *FunctionCall) { + t.writeTrace(sender, encodedLen, "FunctionCall", nil) +} + +func (t *tracer) traceFunctionCallResponse(sender byte, encodedLen int32, msg *FunctionCallResponse) { + t.writeTrace(sender, encodedLen, "FunctionCallResponse", nil) +} + +func (t *tracer) traceGSSEncRequest(sender byte, encodedLen int32, msg *GSSEncRequest) { + t.writeTrace(sender, encodedLen, "GSSEncRequest", nil) +} + +func (t *tracer) traceNoData(sender byte, encodedLen int32, msg *NoData) { + t.writeTrace(sender, encodedLen, "NoData", nil) +} + +func (t *tracer) traceNoticeResponse(sender byte, encodedLen int32, msg *NoticeResponse) { + t.writeTrace(sender, encodedLen, "NoticeResponse", nil) +} + +func (t *tracer) traceNotificationResponse(sender byte, encodedLen int32, msg *NotificationResponse) { + t.writeTrace(sender, encodedLen, "NotificationResponse", func() { + fmt.Fprintf(t.buf, "\t %d %s %s", msg.PID, traceDoubleQuotedString([]byte(msg.Channel)), traceDoubleQuotedString([]byte(msg.Payload))) + }) +} + +func (t *tracer) traceParameterDescription(sender byte, encodedLen int32, msg *ParameterDescription) { + t.writeTrace(sender, encodedLen, "ParameterDescription", nil) +} + +func (t *tracer) traceParameterStatus(sender byte, encodedLen int32, msg *ParameterStatus) { + t.writeTrace(sender, encodedLen, "ParameterStatus", func() { + fmt.Fprintf(t.buf, "\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Value))) + }) +} + +func (t *tracer) traceParse(sender byte, encodedLen int32, msg *Parse) { + t.writeTrace(sender, encodedLen, "Parse", func() { + fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Query)), len(msg.ParameterOIDs)) + for _, oid := range msg.ParameterOIDs { + fmt.Fprintf(t.buf, " %d", oid) + } + }) +} + +func (t *tracer) traceParseComplete(sender byte, encodedLen int32, msg *ParseComplete) { + t.writeTrace(sender, encodedLen, "ParseComplete", nil) +} + +func (t *tracer) tracePortalSuspended(sender byte, encodedLen int32, msg *PortalSuspended) { + t.writeTrace(sender, encodedLen, "PortalSuspended", nil) +} + +func (t *tracer) traceQuery(sender byte, encodedLen int32, msg *Query) { + t.writeTrace(sender, encodedLen, "Query", func() { + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.String))) + }) +} + +func (t *tracer) traceReadyForQuery(sender byte, encodedLen int32, msg *ReadyForQuery) { + t.writeTrace(sender, encodedLen, "ReadyForQuery", func() { + fmt.Fprintf(t.buf, "\t %c", msg.TxStatus) + }) +} + +func (t *tracer) traceRowDescription(sender byte, encodedLen int32, msg *RowDescription) { + t.writeTrace(sender, encodedLen, "RowDescription", func() { + fmt.Fprintf(t.buf, "\t %d", len(msg.Fields)) + for _, fd := range msg.Fields { + fmt.Fprintf(t.buf, ` %s %d %d %d %d %d %d`, traceDoubleQuotedString(fd.Name), fd.TableOID, fd.TableAttributeNumber, fd.DataTypeOID, fd.DataTypeSize, fd.TypeModifier, fd.Format) + } + }) +} + +func (t *tracer) traceSSLRequest(sender byte, encodedLen int32, msg *SSLRequest) { + t.writeTrace(sender, encodedLen, "SSLRequest", nil) +} + +func (t *tracer) traceStartupMessage(sender byte, encodedLen int32, msg *StartupMessage) { + t.writeTrace(sender, encodedLen, "StartupMessage", nil) +} + +func (t *tracer) traceSync(sender byte, encodedLen int32, msg *Sync) { + t.writeTrace(sender, encodedLen, "Sync", nil) +} + +func (t *tracer) traceTerminate(sender byte, encodedLen int32, msg *Terminate) { + t.writeTrace(sender, encodedLen, "Terminate", nil) +} + +func (t *tracer) writeTrace(sender byte, encodedLen int32, msgType string, writeDetails func()) { + t.mux.Lock() + defer t.mux.Unlock() + defer func() { + if t.buf.Cap() > 1024 { + t.buf = &bytes.Buffer{} + } else { + t.buf.Reset() + } + }() + + if !t.SuppressTimestamps { + now := time.Now() + t.buf.WriteString(now.Format("2006-01-02 15:04:05.000000")) + t.buf.WriteByte('\t') + } + + t.buf.WriteByte(sender) + t.buf.WriteByte('\t') + t.buf.WriteString(msgType) + t.buf.WriteByte('\t') + t.buf.WriteString(strconv.FormatInt(int64(encodedLen), 10)) + + if writeDetails != nil { + writeDetails() + } + + t.buf.WriteByte('\n') + t.buf.WriteTo(t.w) +} + +// traceDoubleQuotedString returns t.buf as a double-quoted string without any escaping. It is roughly equivalent to +// pqTraceOutputString in libpq. +func traceDoubleQuotedString(buf []byte) string { + return `"` + string(buf) + `"` +} + +// traceSingleQuotedString returns buf as a single-quoted string with non-printable characters hex-escaped. It is +// roughly equivalent to pqTraceOutputNchar in libpq. +func traceSingleQuotedString(buf []byte) string { + sb := &strings.Builder{} + + sb.WriteByte('\'') + for _, b := range buf { + if b < 32 || b > 126 { + fmt.Fprintf(sb, `\x%x`, b) + } else { + sb.WriteByte(b) + } + } + sb.WriteByte('\'') + + return sb.String() +} diff --git a/pgproto3/trace_test.go b/pgproto3/trace_test.go new file mode 100644 index 000000000..c56a49912 --- /dev/null +++ b/pgproto3/trace_test.go @@ -0,0 +1,56 @@ +package pgproto3_test + +import ( + "bytes" + "context" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +func TestTrace(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer conn.Close(ctx) + + if conn.ParameterStatus("crdb_version") != "" { + t.Skip("Skipping message trace on CockroachDB as it varies slightly from PostgreSQL") + } + + traceOutput := &bytes.Buffer{} + conn.Frontend().Trace(traceOutput, pgproto3.TracerOptions{ + SuppressTimestamps: true, + RegressMode: true, + }) + + result := conn.ExecParams(ctx, "select n from generate_series(1,5) n", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + expected := `F Parse 45 "" "select n from generate_series(1,5) n" 0 +F Bind 13 "" "" 0 0 0 +F Describe 7 P "" +F Execute 10 "" 0 +F Sync 5 +B ParseComplete 5 +B BindComplete 5 +B RowDescription 27 1 "n" 0 0 23 4 -1 0 +B DataRow 12 1 1 '1' +B DataRow 12 1 1 '2' +B DataRow 12 1 1 '3' +B DataRow 12 1 1 '4' +B DataRow 12 1 1 '5' +B CommandComplete 14 "SELECT 5" +B ReadyForQuery 6 I +` + + require.Equal(t, expected, traceOutput.String()) +} diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go deleted file mode 100644 index 35269e91f..000000000 --- a/pgtype/aclitem.go +++ /dev/null @@ -1,126 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - - "github.com/pkg/errors" -) - -// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem -// might look like this: -// -// postgres=arwdDxt/postgres -// -// Note, however, that because the user/role name part of an aclitem is -// an identifier, it follows all the usual formatting rules for SQL -// identifiers: if it contains spaces and other special characters, -// it should appear in double-quotes: -// -// postgres=arwdDxt/"role with spaces" -// -type ACLItem struct { - String string - Status Status -} - -func (dst *ACLItem) Set(src interface{}) error { - switch value := src.(type) { - case string: - *dst = ACLItem{String: value, Status: Present} - case *string: - if value == nil { - *dst = ACLItem{Status: Null} - } else { - *dst = ACLItem{String: *value, Status: Present} - } - default: - if originalSrc, ok := underlyingStringType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to ACLItem", value) - } - - return nil -} - -func (dst *ACLItem) Get() interface{} { - switch dst.Status { - case Present: - return dst.String - case Null: - return nil - default: - return dst.Status - } -} - -func (src *ACLItem) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *string: - *v = src.String - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *ACLItem) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = ACLItem{Status: Null} - return nil - } - - *dst = ACLItem{String: string(src), Status: Present} - return nil -} - -func (src *ACLItem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - return append(buf, src.String...), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *ACLItem) Scan(src interface{}) error { - if src == nil { - *dst = ACLItem{Status: Null} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *ACLItem) Value() (driver.Value, error) { - switch src.Status { - case Present: - return src.String, nil - case Null: - return nil, nil - default: - return nil, errUndefined - } -} diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go deleted file mode 100644 index 0a829295d..000000000 --- a/pgtype/aclitem_array.go +++ /dev/null @@ -1,212 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - - "github.com/pkg/errors" -) - -type ACLItemArray struct { - Elements []ACLItem - Dimensions []ArrayDimension - Status Status -} - -func (dst *ACLItemArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = ACLItemArray{Status: Null} - return nil - } - - switch value := src.(type) { - - case []string: - if value == nil { - *dst = ACLItemArray{Status: Null} - } else if len(value) == 0 { - *dst = ACLItemArray{Status: Present} - } else { - elements := make([]ACLItem, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = ACLItemArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to ACLItemArray", value) - } - - return nil -} - -func (dst *ACLItemArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *ACLItemArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = ACLItemArray{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []ACLItem - - if len(uta.Elements) > 0 { - elements = make([]ACLItem, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem ACLItem - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (src *ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *ACLItemArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *ACLItemArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/aclitem_array_test.go b/pgtype/aclitem_array_test.go deleted file mode 100644 index 4e60afcab..000000000 --- a/pgtype/aclitem_array_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestACLItemArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "aclitem[]", []interface{}{ - &pgtype.ACLItemArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.ACLItemArray{Status: pgtype.Null}, - &pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {Status: pgtype.Null}, - {String: "=r/postgres", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - {String: "=r/postgres", Status: pgtype.Present}, - {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestACLItemArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.ACLItemArray - }{ - { - source: []string{"=r/postgres"}, - result: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]string)(nil)), - result: pgtype.ACLItemArray{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.ACLItemArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestACLItemArrayAssignTo(t *testing.T) { - var stringSlice []string - type _stringSlice []string - var namedStringSlice _stringSlice - - simpleTests := []struct { - src pgtype.ACLItemArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - expected: []string{"=r/postgres"}, - }, - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedStringSlice, - expected: _stringSlice{"=r/postgres"}, - }, - { - src: pgtype.ACLItemArray{Status: pgtype.Null}, - dst: &stringSlice, - expected: (([]string)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.ACLItemArray - dst interface{} - }{ - { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/aclitem_test.go b/pgtype/aclitem_test.go deleted file mode 100644 index 65399a30c..000000000 --- a/pgtype/aclitem_test.go +++ /dev/null @@ -1,97 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestACLItemTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "aclitem", []interface{}{ - &pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - &pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - &pgtype.ACLItem{Status: pgtype.Null}, - }) -} - -func TestACLItemSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.ACLItem - }{ - {source: "postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - {source: (*string)(nil), result: pgtype.ACLItem{Status: pgtype.Null}}, - } - - for i, tt := range successfulTests { - var d pgtype.ACLItem - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if d != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } - } -} - -func TestACLItemAssignTo(t *testing.T) { - var s string - var ps *string - - simpleTests := []struct { - src pgtype.ACLItem - dst interface{} - expected interface{} - }{ - {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"}, - {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.ACLItem - dst interface{} - expected interface{} - }{ - {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.ACLItem - dst interface{} - }{ - {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &s}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/array.go b/pgtype/array.go index 5b852ed56..872a08891 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -3,22 +3,22 @@ package pgtype import ( "bytes" "encoding/binary" + "fmt" "io" "strconv" "strings" "unicode" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) // Information on the internals of PostgreSQL arrays can be found in // src/include/utils/array.h and src/backend/utils/adt/arrayfuncs.c. Of // particular interest is the array_send function. -type ArrayHeader struct { +type arrayHeader struct { ContainsNull bool - ElementOID int32 + ElementOID uint32 Dimensions []ArrayDimension } @@ -27,9 +27,23 @@ type ArrayDimension struct { LowerBound int32 } -func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { +// cardinality returns the number of elements in an array of dimensions size. +func cardinality(dimensions []ArrayDimension) int { + if len(dimensions) == 0 { + return 0 + } + + elementCount := int(dimensions[0].Length) + for _, d := range dimensions[1:] { + elementCount *= int(d.Length) + } + + return elementCount +} + +func (dst *arrayHeader) DecodeBinary(m *Map, src []byte) (int, error) { if len(src) < 12 { - return 0, errors.Errorf("array header too short: %d", len(src)) + return 0, fmt.Errorf("array header too short: %d", len(src)) } rp := 0 @@ -40,14 +54,12 @@ func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1 rp += 4 - dst.ElementOID = int32(binary.BigEndian.Uint32(src[rp:])) + dst.ElementOID = binary.BigEndian.Uint32(src[rp:]) rp += 4 - if numDims > 0 { - dst.Dimensions = make([]ArrayDimension, numDims) - } + dst.Dimensions = make([]ArrayDimension, numDims) if len(src) < 12+numDims*8 { - return 0, errors.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) + return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) } for i := range dst.Dimensions { dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:])) @@ -60,7 +72,7 @@ func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { return rp, nil } -func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte { +func (src arrayHeader) EncodeBinary(buf []byte) []byte { buf = pgio.AppendInt32(buf, int32(len(src.Dimensions))) var containsNull int32 @@ -69,7 +81,7 @@ func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte { } buf = pgio.AppendInt32(buf, containsNull) - buf = pgio.AppendInt32(buf, src.ElementOID) + buf = pgio.AppendUint32(buf, src.ElementOID) for i := range src.Dimensions { buf = pgio.AppendInt32(buf, src.Dimensions[i].Length) @@ -79,13 +91,18 @@ func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte { return buf } -type UntypedTextArray struct { +type untypedTextArray struct { Elements []string + Quoted []bool Dimensions []ArrayDimension } -func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { - dst := &UntypedTextArray{} +func parseUntypedTextArray(src string) (*untypedTextArray, error) { + dst := &untypedTextArray{ + Elements: []string{}, + Quoted: []bool{}, + Dimensions: []ArrayDimension{}, + } buf := bytes.NewBufferString(src) @@ -93,7 +110,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { r, _, err := buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } var explicitDimensions []ArrayDimension @@ -105,41 +122,41 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { for { r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } if r == '=' { break } else if r != '[' { - return nil, errors.Errorf("invalid array, expected '[' or '=' got %v", r) + return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r) } lower, err := arrayParseInteger(buf) if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } if r != ':' { - return nil, errors.Errorf("invalid array, expected ':' got %v", r) + return nil, fmt.Errorf("invalid array, expected ':' got %v", r) } upper, err := arrayParseInteger(buf) if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } if r != ']' { - return nil, errors.Errorf("invalid array, expected ']' got %v", r) + return nil, fmt.Errorf("invalid array, expected ']' got %v", r) } explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1}) @@ -147,12 +164,12 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } } if r != '{' { - return nil, errors.Errorf("invalid array, expected '{': %v", err) + return nil, fmt.Errorf("invalid array, expected '{' got %v", r) } implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}} @@ -161,7 +178,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { for { r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } if r == '{' { @@ -178,7 +195,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { for { r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } switch r { @@ -195,13 +212,14 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { } default: buf.UnreadRune() - value, err := arrayParseValue(buf) + value, quoted, err := arrayParseValue(buf) if err != nil { - return nil, errors.Errorf("invalid array value: %v", err) + return nil, fmt.Errorf("invalid array value: %w", err) } if currentDim == counterDim { implicitDimensions[currentDim].Length++ } + dst.Quoted = append(dst.Quoted, quoted) dst.Elements = append(dst.Elements, value) } @@ -213,11 +231,10 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { skipWhitespace(buf) if buf.Len() > 0 { - return nil, errors.Errorf("unexpected trailing data: %v", buf.String()) + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) } if len(dst.Elements) == 0 { - dst.Dimensions = nil } else if len(explicitDimensions) > 0 { dst.Dimensions = explicitDimensions } else { @@ -238,10 +255,10 @@ func skipWhitespace(buf *bytes.Buffer) { } } -func arrayParseValue(buf *bytes.Buffer) (string, error) { +func arrayParseValue(buf *bytes.Buffer) (string, bool, error) { r, _, err := buf.ReadRune() if err != nil { - return "", err + return "", false, err } if r == '"' { return arrayParseQuotedValue(buf) @@ -253,41 +270,41 @@ func arrayParseValue(buf *bytes.Buffer) (string, error) { for { r, _, err := buf.ReadRune() if err != nil { - return "", err + return "", false, err } switch r { case ',', '}': buf.UnreadRune() - return s.String(), nil + return s.String(), false, nil } s.WriteRune(r) } } -func arrayParseQuotedValue(buf *bytes.Buffer) (string, error) { +func arrayParseQuotedValue(buf *bytes.Buffer) (string, bool, error) { s := &bytes.Buffer{} for { r, _, err := buf.ReadRune() if err != nil { - return "", err + return "", false, err } switch r { case '\\': r, _, err = buf.ReadRune() if err != nil { - return "", err + return "", false, err } case '"': r, _, err = buf.ReadRune() if err != nil { - return "", err + return "", false, err } buf.UnreadRune() - return s.String(), nil + return s.String(), true, nil } s.WriteRune(r) } @@ -302,7 +319,7 @@ func arrayParseInteger(buf *bytes.Buffer) (int32, error) { return 0, err } - if '0' <= r && r <= '9' { + if ('0' <= r && r <= '9') || r == '-' { s.WriteRune(r) } else { buf.UnreadRune() @@ -315,7 +332,7 @@ func arrayParseInteger(buf *bytes.Buffer) (int32, error) { } } -func EncodeTextArrayDimensions(buf []byte, dimensions []ArrayDimension) []byte { +func encodeTextArrayDimensions(buf []byte, dimensions []ArrayDimension) []byte { var customDimensions bool for _, dim := range dimensions { if dim.LowerBound != 1 { @@ -344,9 +361,100 @@ func quoteArrayElement(src string) string { return `"` + quoteArrayReplacer.Replace(src) + `"` } -func QuoteArrayElementIfNeeded(src string) string { - if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `{},"\`) { +func isSpace(ch byte) bool { + // see array_isspace: + // https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/arrayfuncs.c + return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\v' || ch == '\f' +} + +func quoteArrayElementIfNeeded(src string) string { + if src == "" || (len(src) == 4 && strings.EqualFold(src, "null")) || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) { return quoteArrayElement(src) } return src } + +// Array represents a PostgreSQL array for T. It implements the [ArrayGetter] and [ArraySetter] interfaces. It preserves +// PostgreSQL dimensions and custom lower bounds. Use [FlatArray] if these are not needed. +type Array[T any] struct { + Elements []T + Dims []ArrayDimension + Valid bool +} + +func (a Array[T]) Dimensions() []ArrayDimension { + return a.Dims +} + +func (a Array[T]) Index(i int) any { + return a.Elements[i] +} + +func (a Array[T]) IndexType() any { + var el T + return el +} + +func (a *Array[T]) SetDimensions(dimensions []ArrayDimension) error { + if dimensions == nil { + *a = Array[T]{} + return nil + } + + elementCount := cardinality(dimensions) + *a = Array[T]{ + Elements: make([]T, elementCount), + Dims: dimensions, + Valid: true, + } + + return nil +} + +func (a Array[T]) ScanIndex(i int) any { + return &a.Elements[i] +} + +func (a Array[T]) ScanIndexType() any { + return new(T) +} + +// FlatArray implements the [ArrayGetter] and [ArraySetter] interfaces for any slice of T. It ignores PostgreSQL dimensions +// and custom lower bounds. Use [Array] to preserve these. +type FlatArray[T any] []T + +func (a FlatArray[T]) Dimensions() []ArrayDimension { + if a == nil { + return nil + } + + return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} +} + +func (a FlatArray[T]) Index(i int) any { + return a[i] +} + +func (a FlatArray[T]) IndexType() any { + var el T + return el +} + +func (a *FlatArray[T]) SetDimensions(dimensions []ArrayDimension) error { + if dimensions == nil { + *a = nil + return nil + } + + elementCount := cardinality(dimensions) + *a = make(FlatArray[T], elementCount) + return nil +} + +func (a FlatArray[T]) ScanIndex(i int) any { + return &a[i] +} + +func (a FlatArray[T]) ScanIndexType() any { + return new(T) +} diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go new file mode 100644 index 000000000..bf5f6989a --- /dev/null +++ b/pgtype/array_codec.go @@ -0,0 +1,405 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// ArrayGetter is a type that can be converted into a PostgreSQL array. +type ArrayGetter interface { + // Dimensions returns the array dimensions. If array is nil then nil is returned. + Dimensions() []ArrayDimension + + // Index returns the element at i. + Index(i int) any + + // IndexType returns a non-nil scan target of the type Index will return. This is used by ArrayCodec.PlanEncode. + IndexType() any +} + +// ArraySetter is a type can be set from a PostgreSQL array. +type ArraySetter interface { + // SetDimensions prepares the value such that ScanIndex can be called for each element. This will remove any existing + // elements. dimensions may be nil to indicate a NULL array. If unable to exactly preserve dimensions SetDimensions + // may return an error or silently flatten the array dimensions. + SetDimensions(dimensions []ArrayDimension) error + + // ScanIndex returns a value usable as a scan target for i. SetDimensions must be called before ScanIndex. + ScanIndex(i int) any + + // ScanIndexType returns a non-nil scan target of the type ScanIndex will return. This is used by + // ArrayCodec.PlanScan. + ScanIndexType() any +} + +// ArrayCodec is a codec for any array type. +type ArrayCodec struct { + ElementType *Type +} + +func (c *ArrayCodec) FormatSupported(format int16) bool { + return c.ElementType.Codec.FormatSupported(format) +} + +func (c *ArrayCodec) PreferredFormat() int16 { + // The binary format should always be preferred for arrays if it is supported. Usually, this will happen automatically + // because most types that support binary prefer it. However, text, json, and jsonb support binary but prefer the text + // format. This is because it is simpler for jsonb and PostgreSQL can be significantly faster using the text format + // for text-like data types than binary. However, arrays appear to always be faster in binary. + // + // https://www.postgresql.org/message-id/CAMovtNoHFod2jMAKQjjxv209PCTJx5Kc66anwWvX0mEiaXwgmA%40mail.gmail.com + if c.ElementType.Codec.FormatSupported(BinaryFormatCode) { + return BinaryFormatCode + } + return TextFormatCode +} + +func (c *ArrayCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + arrayValuer, ok := value.(ArrayGetter) + if !ok { + return nil + } + + elementType := arrayValuer.IndexType() + + elementEncodePlan := m.PlanEncode(c.ElementType.OID, format, elementType) + if elementEncodePlan == nil { + if reflect.TypeOf(elementType) != nil { + return nil + } + } + + switch format { + case BinaryFormatCode: + return &encodePlanArrayCodecBinary{ac: c, m: m, oid: oid} + case TextFormatCode: + return &encodePlanArrayCodecText{ac: c, m: m, oid: oid} + } + + return nil +} + +type encodePlanArrayCodecText struct { + ac *ArrayCodec + m *Map + oid uint32 +} + +func (p *encodePlanArrayCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + array := value.(ArrayGetter) + + dimensions := array.Dimensions() + if dimensions == nil { + return nil, nil + } + + elementCount := cardinality(dimensions) + if elementCount == 0 { + return append(buf, '{', '}'), nil + } + + buf = encodeTextArrayDimensions(buf, dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(dimensions)) + dimElemCounts[len(dimensions)-1] = int(dimensions[len(dimensions)-1].Length) + for i := len(dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(dimensions[i].Length) * dimElemCounts[i+1] + } + + var encodePlan EncodePlan + var lastElemType reflect.Type + inElemBuf := make([]byte, 0, 32) + for i := 0; i < elementCount; i++ { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elem := array.Index(i) + var elemBuf []byte + if elem != nil { + elemType := reflect.TypeOf(elem) + if lastElemType != elemType { + lastElemType = elemType + encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, TextFormatCode, elem) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", array.Index(i)) + } + } + elemBuf, err = encodePlan.Encode(elem, inElemBuf) + if err != nil { + return nil, err + } + } + + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, quoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +type encodePlanArrayCodecBinary struct { + ac *ArrayCodec + m *Map + oid uint32 +} + +func (p *encodePlanArrayCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + array := value.(ArrayGetter) + + dimensions := array.Dimensions() + if dimensions == nil { + return nil, nil + } + + arrayHeader := arrayHeader{ + Dimensions: dimensions, + ElementOID: p.ac.ElementType.OID, + } + + containsNullIndex := len(buf) + 4 + + buf = arrayHeader.EncodeBinary(buf) + + elementCount := cardinality(dimensions) + + var encodePlan EncodePlan + var lastElemType reflect.Type + for i := 0; i < elementCount; i++ { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elem := array.Index(i) + var elemBuf []byte + if elem != nil { + elemType := reflect.TypeOf(elem) + if lastElemType != elemType { + lastElemType = elemType + encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, BinaryFormatCode, elem) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", array.Index(i)) + } + } + elemBuf, err = encodePlan.Encode(elem, buf) + if err != nil { + return nil, err + } + } + + if elemBuf == nil { + pgio.SetInt32(buf[containsNullIndex:], 1) + } else { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + arrayScanner, ok := target.(ArraySetter) + if !ok { + return nil + } + + // target / arrayScanner might be a pointer to a nil. If it is create one so we can call ScanIndexType to plan the + // scan of the elements. + if isNil, _ := isNilDriverValuer(target); isNil { + arrayScanner = reflect.New(reflect.TypeOf(target).Elem()).Interface().(ArraySetter) + } + + elementType := arrayScanner.ScanIndexType() + + elementScanPlan := m.PlanScan(c.ElementType.OID, format, elementType) + if _, ok := elementScanPlan.(*scanPlanFail); ok { + return nil + } + + return &scanPlanArrayCodec{ + arrayCodec: c, + m: m, + oid: oid, + formatCode: format, + } +} + +func (c *ArrayCodec) decodeBinary(m *Map, arrayOID uint32, src []byte, array ArraySetter) error { + var arrayHeader arrayHeader + rp, err := arrayHeader.DecodeBinary(m, src) + if err != nil { + return err + } + + err = array.SetDimensions(arrayHeader.Dimensions) + if err != nil { + return err + } + + elementCount := cardinality(arrayHeader.Dimensions) + if elementCount == 0 { + return nil + } + + elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0)) + if elementScanPlan == nil { + elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0)) + } + + for i := 0; i < elementCount; i++ { + elem := array.ScanIndex(i) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elementScanPlan.Scan(elemSrc, elem) + if err != nil { + return fmt.Errorf("failed to scan array element %d: %w", i, err) + } + } + + return nil +} + +func (c *ArrayCodec) decodeText(m *Map, arrayOID uint32, src []byte, array ArraySetter) error { + uta, err := parseUntypedTextArray(string(src)) + if err != nil { + return err + } + + err = array.SetDimensions(uta.Dimensions) + if err != nil { + return err + } + + if len(uta.Elements) == 0 { + return nil + } + + elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, TextFormatCode, array.ScanIndex(0)) + if elementScanPlan == nil { + elementScanPlan = m.PlanScan(c.ElementType.OID, TextFormatCode, array.ScanIndex(0)) + } + + for i, s := range uta.Elements { + elem := array.ScanIndex(i) + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + + err = elementScanPlan.Scan(elemSrc, elem) + if err != nil { + return err + } + } + + return nil +} + +type scanPlanArrayCodec struct { + arrayCodec *ArrayCodec + m *Map + oid uint32 + formatCode int16 + elementScanPlan ScanPlan +} + +func (spac *scanPlanArrayCodec) Scan(src []byte, dst any) error { + c := spac.arrayCodec + m := spac.m + oid := spac.oid + formatCode := spac.formatCode + + array := dst.(ArraySetter) + + if src == nil { + return array.SetDimensions(nil) + } + + switch formatCode { + case BinaryFormatCode: + return c.decodeBinary(m, oid, src, array) + case TextFormatCode: + return c.decodeText(m, oid, src, array) + default: + return fmt.Errorf("unknown format code %d", formatCode) + } +} + +func (c *ArrayCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } +} + +func (c *ArrayCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var slice []any + err := m.PlanScan(oid, format, &slice).Scan(src, &slice) + return slice, err +} + +func isRagged(slice reflect.Value) bool { + if slice.Type().Elem().Kind() != reflect.Slice { + return false + } + + sliceLen := slice.Len() + innerLen := 0 + for i := 0; i < sliceLen; i++ { + if i == 0 { + innerLen = slice.Index(i).Len() + } else { + if slice.Index(i).Len() != innerLen { + return true + } + } + if isRagged(slice.Index(i)) { + return true + } + } + + return false +} diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go new file mode 100644 index 000000000..eafa94d43 --- /dev/null +++ b/pgtype/array_codec_test.go @@ -0,0 +1,364 @@ +package pgtype_test + +import ( + "context" + "encoding/hex" + "reflect" + "strings" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestArrayCodec(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + for i, tt := range []struct { + expected any + }{ + {[]int16(nil)}, + {[]int16{}}, + {[]int16{1, 2, 3}}, + } { + var actual []int16 + err := conn.QueryRow( + ctx, + "select $1::smallint[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + + newInt16 := func(n int16) *int16 { return &n } + + for i, tt := range []struct { + expected any + }{ + {[]*int16{newInt16(1), nil, newInt16(3), nil, newInt16(5)}}, + } { + var actual []*int16 + err := conn.QueryRow( + ctx, + "select $1::smallint[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + }) +} + +func TestArrayCodecFlatArrayString(t *testing.T) { + testCases := []struct { + input []string + }{ + {nil}, + {[]string{}}, + {[]string{"a"}}, + {[]string{"a", "b"}}, + // previously had a bug with whitespace handling + {[]string{"\v", "\t", "\n", "\r", "\f", " "}}, + {[]string{"a\vb", "a\tb", "a\nb", "a\rb", "a\fb", "a b"}}, + } + + queryModes := []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol, pgx.QueryExecModeDescribeExec} + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + for i, testCase := range testCases { + for _, queryMode := range queryModes { + var out []string + err := conn.QueryRow(ctx, "select $1::text[]", queryMode, testCase.input).Scan(&out) + if err != nil { + t.Fatalf("i=%d input=%#v queryMode=%s: Scan failed: %s", + i, testCase.input, queryMode, err) + } + if !reflect.DeepEqual(out, testCase.input) { + t.Errorf("i=%d input=%#v queryMode=%s: not equal output=%#v", + i, testCase.input, queryMode, out) + } + } + } + }) +} + +func TestArrayCodecArray(t *testing.T) { + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does not support multi-dimensional arrays") + } + + ctr.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + for i, tt := range []struct { + expected any + }{ + {pgtype.Array[int32]{ + Elements: []int32{1, 2, 3, 4}, + Dims: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 2}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }}, + } { + var actual pgtype.Array[int32] + err := conn.QueryRow( + ctx, + "select $1::int[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + }) +} + +func TestArrayCodecNamedSliceType(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + type _int16Slice []int16 + + for i, tt := range []struct { + expected any + }{ + {_int16Slice(nil)}, + {_int16Slice{}}, + {_int16Slice{1, 2, 3}}, + } { + var actual _int16Slice + err := conn.QueryRow( + ctx, + "select $1::smallint[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + }) +} + +// https://github.com/jackc/pgx/issues/1488 +func TestArrayCodecAnySliceArgument(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + type _int16Slice []int16 + + for i, tt := range []struct { + arg any + expected []int16 + }{ + {[]any{1, 2, 3}, []int16{1, 2, 3}}, + } { + var actual []int16 + err := conn.QueryRow( + ctx, + "select $1::smallint[]", + tt.arg, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + }) +} + +// https://github.com/jackc/pgx/issues/1442 +func TestArrayCodecAnyArray(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + type _point3 [3]float32 + + for i, tt := range []struct { + expected any + }{ + {_point3{0, 0, 0}}, + {_point3{1, 2, 3}}, + } { + var actual _point3 + err := conn.QueryRow( + ctx, + "select $1::float4[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + }) +} + +// https://github.com/jackc/pgx/issues/1273#issuecomment-1218262703 +func TestArrayCodecSliceArgConversion(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + arg := []string{ + "3ad95bfd-ecea-4032-83c3-0c823cafb372", + "951baf11-c0cc-4afc-a779-abff0611dbf1", + "8327f244-7e2f-45e7-a10b-fbdc9d6f3378", + } + + var expected []pgtype.UUID + + for _, s := range arg { + buf, err := hex.DecodeString(strings.ReplaceAll(s, "-", "")) + require.NoError(t, err) + var u pgtype.UUID + copy(u.Bytes[:], buf) + u.Valid = true + expected = append(expected, u) + } + + var actual []pgtype.UUID + err := conn.QueryRow( + ctx, + "select $1::uuid[]", + arg, + ).Scan(&actual) + require.NoError(t, err) + require.Equal(t, expected, actual) + }) +} + +func TestArrayCodecDecodeValue(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + for _, tt := range []struct { + sql string + expected any + }{ + { + sql: `select '{}'::int4[]`, + expected: []any{}, + }, + { + sql: `select '{1,2}'::int8[]`, + expected: []any{int64(1), int64(2)}, + }, + { + sql: `select '{foo,bar}'::text[]`, + expected: []any{"foo", "bar"}, + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(ctx, tt.sql) + require.NoError(t, err) + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } + }) +} + +func TestArrayCodecScanMultipleDimensions(t *testing.T) { + skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, err := conn.Query(ctx, `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`) + require.NoError(t, err) + + for rows.Next() { + var ss [][]int32 + err := rows.Scan(&ss) + require.NoError(t, err) + require.Equal(t, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, ss) + } + + require.NoError(t, rows.Err()) + }) +} + +func TestArrayCodecScanMultipleDimensionsEmpty(t *testing.T) { + skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, err := conn.Query(ctx, `select '{}'::int4[]`) + require.NoError(t, err) + + for rows.Next() { + var ss [][]int32 + err := rows.Scan(&ss) + require.NoError(t, err) + require.Equal(t, [][]int32{}, ss) + } + + require.NoError(t, rows.Err()) + }) +} + +func TestArrayCodecScanWrongMultipleDimensions(t *testing.T) { + skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, err := conn.Query(ctx, `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`) + require.NoError(t, err) + + for rows.Next() { + var ss [][][]int32 + err := rows.Scan(&ss) + require.Error(t, err, "can't scan into dest[0]: PostgreSQL array has 2 dimensions but slice has 3 dimensions") + } + }) +} + +func TestArrayCodecEncodeMultipleDimensions(t *testing.T) { + skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, err := conn.Query(ctx, `select $1::int4[]`, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}) + require.NoError(t, err) + + for rows.Next() { + var ss [][]int32 + err := rows.Scan(&ss) + require.NoError(t, err) + require.Equal(t, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, ss) + } + + require.NoError(t, rows.Err()) + }) +} + +func TestArrayCodecEncodeMultipleDimensionsRagged(t *testing.T) { + skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, err := conn.Query(ctx, `select $1::int4[]`, [][]int32{{1, 2, 3, 4}, {5}, {9, 10, 11, 12}}) + require.Error(t, err, "cannot convert [][]int32 to ArrayGetter because it is a ragged multi-dimensional") + defer rows.Close() + }) +} + +// https://github.com/jackc/pgx/issues/1494 +func TestArrayCodecDecodeTextArrayWithTextOfNULL(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + { + var actual []string + err := conn.QueryRow(ctx, `select '{"foo", "NULL", " NULL "}'::text[]`).Scan(&actual) + require.NoError(t, err) + require.Equal(t, []string{"foo", "NULL", " NULL "}, actual) + } + + { + var actual []pgtype.Text + err := conn.QueryRow(ctx, `select '{"foo", "NULL", NULL, " NULL "}'::text[]`).Scan(&actual) + require.NoError(t, err) + require.Equal(t, []pgtype.Text{ + {String: "foo", Valid: true}, + {String: "NULL", Valid: true}, + {}, + {String: " NULL ", Valid: true}, + }, actual) + } + }) +} + +func TestArrayCodecDecodeTextArrayPrefersBinaryFormat(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + sd, err := conn.Prepare(ctx, "", `select '{"foo", "NULL", " NULL "}'::text[]`) + require.NoError(t, err) + require.Equal(t, int16(1), conn.TypeMap().FormatCodeForOID(sd.Fields[0].DataTypeOID)) + }) +} diff --git a/pgtype/array_test.go b/pgtype/array_test.go index d1cdb4c56..f246b346f 100644 --- a/pgtype/array_test.go +++ b/pgtype/array_test.go @@ -1,71 +1,77 @@ -package pgtype_test +package pgtype import ( "reflect" "testing" - - "github.com/jackc/pgx/pgtype" ) func TestParseUntypedTextArray(t *testing.T) { tests := []struct { source string - result pgtype.UntypedTextArray + result untypedTextArray }{ { source: "{}", - result: pgtype.UntypedTextArray{ - Elements: nil, - Dimensions: nil, + result: untypedTextArray{ + Elements: []string{}, + Quoted: []bool{}, + Dimensions: []ArrayDimension{}, }, }, { source: "{1}", - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{"1"}, - Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + Quoted: []bool{false}, + Dimensions: []ArrayDimension{{Length: 1, LowerBound: 1}}, }, }, { source: "{a,b}", - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{"a", "b"}, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Quoted: []bool{false, false}, + Dimensions: []ArrayDimension{{Length: 2, LowerBound: 1}}, }, }, { source: `{"NULL"}`, - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{"NULL"}, - Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + Quoted: []bool{true}, + Dimensions: []ArrayDimension{{Length: 1, LowerBound: 1}}, }, }, { source: `{""}`, - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{""}, - Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + Quoted: []bool{true}, + Dimensions: []ArrayDimension{{Length: 1, LowerBound: 1}}, }, }, { source: `{"He said, \"Hello.\""}`, - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{`He said, "Hello."`}, - Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + Quoted: []bool{true}, + Dimensions: []ArrayDimension{{Length: 1, LowerBound: 1}}, }, }, { source: "{{a,b},{c,d},{e,f}}", - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{"a", "b", "c", "d", "e", "f"}, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Quoted: []bool{false, false, false, false, false, false}, + Dimensions: []ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, }, }, { source: "{{{a,b},{c,d},{e,f}},{{a,b},{c,d},{e,f}}}", - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{"a", "b", "c", "d", "e", "f", "a", "b", "c", "d", "e", "f"}, - Dimensions: []pgtype.ArrayDimension{ + Quoted: []bool{false, false, false, false, false, false, false, false, false, false, false, false}, + Dimensions: []ArrayDimension{ {Length: 2, LowerBound: 1}, {Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}, @@ -74,25 +80,35 @@ func TestParseUntypedTextArray(t *testing.T) { }, { source: "[4:4]={1}", - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{"1"}, - Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 4}}, + Quoted: []bool{false}, + Dimensions: []ArrayDimension{{Length: 1, LowerBound: 4}}, }, }, { source: "[4:5][2:3]={{a,b},{c,d}}", - result: pgtype.UntypedTextArray{ + result: untypedTextArray{ Elements: []string{"a", "b", "c", "d"}, - Dimensions: []pgtype.ArrayDimension{ + Quoted: []bool{false, false, false, false}, + Dimensions: []ArrayDimension{ {Length: 2, LowerBound: 4}, {Length: 2, LowerBound: 2}, }, }, }, + { + source: "[-4:-2]={1,2,3}", + result: untypedTextArray{ + Elements: []string{"1", "2", "3"}, + Quoted: []bool{false, false, false}, + Dimensions: []ArrayDimension{{Length: 3, LowerBound: -4}}, + }, + }, } for i, tt := range tests { - r, err := pgtype.ParseUntypedTextArray(tt.source) + r, err := parseUntypedTextArray(tt.source) if err != nil { t.Errorf("%d: %v", i, err) continue diff --git a/pgtype/bit.go b/pgtype/bit.go deleted file mode 100644 index f892cee5c..000000000 --- a/pgtype/bit.go +++ /dev/null @@ -1,37 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -type Bit Varbit - -func (dst *Bit) Set(src interface{}) error { - return (*Varbit)(dst).Set(src) -} - -func (dst *Bit) Get() interface{} { - return (*Varbit)(dst).Get() -} - -func (src *Bit) AssignTo(dst interface{}) error { - return (*Varbit)(src).AssignTo(dst) -} - -func (dst *Bit) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Varbit)(dst).DecodeBinary(ci, src) -} - -func (src *Bit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Varbit)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *Bit) Scan(src interface{}) error { - return (*Varbit)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Bit) Value() (driver.Value, error) { - return (*Varbit)(src).Value() -} diff --git a/pgtype/bit_test.go b/pgtype/bit_test.go deleted file mode 100644 index 19492bc9e..000000000 --- a/pgtype/bit_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestBitTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bit(40)", []interface{}{ - &pgtype.Varbit{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Status: pgtype.Present}, - &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Status: pgtype.Present}, - &pgtype.Varbit{Status: pgtype.Null}, - }) -} - -func TestBitNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select B'111111111'", - Value: &pgtype.Bit{Bytes: []byte{255, 128}, Len: 9, Status: pgtype.Present}, - }, - }) -} diff --git a/pgtype/bits.go b/pgtype/bits.go new file mode 100644 index 000000000..2a48e3549 --- /dev/null +++ b/pgtype/bits.go @@ -0,0 +1,211 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type BitsScanner interface { + ScanBits(v Bits) error +} + +type BitsValuer interface { + BitsValue() (Bits, error) +} + +// Bits represents the PostgreSQL bit and varbit types. +type Bits struct { + Bytes []byte + Len int32 // Number of bits + Valid bool +} + +// ScanBits implements the [BitsScanner] interface. +func (b *Bits) ScanBits(v Bits) error { + *b = v + return nil +} + +// BitsValue implements the [BitsValuer] interface. +func (b Bits) BitsValue() (Bits, error) { + return b, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *Bits) Scan(src any) error { + if src == nil { + *dst = Bits{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToBitsScanner{}.Scan([]byte(src), dst) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src Bits) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + buf, err := BitsCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type BitsCodec struct{} + +func (BitsCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (BitsCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (BitsCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(BitsValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanBitsCodecBinary{} + case TextFormatCode: + return encodePlanBitsCodecText{} + } + + return nil +} + +type encodePlanBitsCodecBinary struct{} + +func (encodePlanBitsCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + bits, err := value.(BitsValuer).BitsValue() + if err != nil { + return nil, err + } + + if !bits.Valid { + return nil, nil + } + + buf = pgio.AppendInt32(buf, bits.Len) + return append(buf, bits.Bytes...), nil +} + +type encodePlanBitsCodecText struct{} + +func (encodePlanBitsCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + bits, err := value.(BitsValuer).BitsValue() + if err != nil { + return nil, err + } + + if !bits.Valid { + return nil, nil + } + + for i := int32(0); i < bits.Len; i++ { + byteIdx := i / 8 + bitMask := byte(128 >> byte(i%8)) + char := byte('0') + if bits.Bytes[byteIdx]&bitMask > 0 { + char = '1' + } + buf = append(buf, char) + } + + return buf, nil +} + +func (BitsCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case BitsScanner: + return scanPlanBinaryBitsToBitsScanner{} + } + case TextFormatCode: + switch target.(type) { + case BitsScanner: + return scanPlanTextAnyToBitsScanner{} + } + } + + return nil +} + +func (c BitsCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c BitsCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var box Bits + err := codecScan(c, m, oid, format, src, &box) + if err != nil { + return nil, err + } + return box, nil +} + +type scanPlanBinaryBitsToBitsScanner struct{} + +func (scanPlanBinaryBitsToBitsScanner) Scan(src []byte, dst any) error { + scanner := (dst).(BitsScanner) + + if src == nil { + return scanner.ScanBits(Bits{}) + } + + if len(src) < 4 { + return fmt.Errorf("invalid length for bit/varbit: %v", len(src)) + } + + bitLen := int32(binary.BigEndian.Uint32(src)) + rp := 4 + buf := make([]byte, len(src[rp:])) + copy(buf, src[rp:]) + + return scanner.ScanBits(Bits{Bytes: buf, Len: bitLen, Valid: true}) +} + +type scanPlanTextAnyToBitsScanner struct{} + +func (scanPlanTextAnyToBitsScanner) Scan(src []byte, dst any) error { + scanner := (dst).(BitsScanner) + + if src == nil { + return scanner.ScanBits(Bits{}) + } + + bitLen := len(src) + byteLen := bitLen / 8 + if bitLen%8 > 0 { + byteLen++ + } + buf := make([]byte, byteLen) + + for i, b := range src { + if b == '1' { + byteIdx := i / 8 + bitIdx := uint(i % 8) + buf[byteIdx] = buf[byteIdx] | (128 >> bitIdx) + } + } + + return scanner.ScanBits(Bits{Bytes: buf, Len: int32(bitLen), Valid: true}) +} diff --git a/pgtype/bits_test.go b/pgtype/bits_test.go new file mode 100644 index 000000000..d517df2b8 --- /dev/null +++ b/pgtype/bits_test.go @@ -0,0 +1,57 @@ +package pgtype_test + +import ( + "bytes" + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func isExpectedEqBits(a any) func(any) bool { + return func(v any) bool { + ab := a.(pgtype.Bits) + vb := v.(pgtype.Bits) + return bytes.Equal(ab.Bytes, vb.Bytes) && ab.Len == vb.Len && ab.Valid == vb.Valid + } +} + +func TestBitsCodecBit(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "bit(40)", []pgxtest.ValueRoundTripTest{ + { + pgtype.Bits{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Valid: true}), + }, + { + pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}), + }, + {pgtype.Bits{}, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, + {nil, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, + }) +} + +func TestBitsCodecVarbit(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "varbit", []pgxtest.ValueRoundTripTest{ + { + pgtype.Bits{Bytes: []byte{}, Len: 0, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{}, Len: 0, Valid: true}), + }, + { + pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}), + }, + { + pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Valid: true}, + new(pgtype.Bits), + isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Valid: true}), + }, + {pgtype.Bits{}, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, + {nil, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, + }) +} diff --git a/pgtype/bool.go b/pgtype/bool.go index 3a3eef48f..955f01fe8 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -1,104 +1,165 @@ package pgtype import ( + "bytes" "database/sql/driver" + "encoding/json" + "fmt" "strconv" - - "github.com/pkg/errors" + "strings" ) +type BoolScanner interface { + ScanBool(v Bool) error +} + +type BoolValuer interface { + BoolValue() (Bool, error) +} + type Bool struct { - Bool bool - Status Status + Bool bool + Valid bool +} + +// ScanBool implements the [BoolScanner] interface. +func (b *Bool) ScanBool(v Bool) error { + *b = v + return nil +} + +// BoolValue implements the [BoolValuer] interface. +func (b Bool) BoolValue() (Bool, error) { + return b, nil } -func (dst *Bool) Set(src interface{}) error { - switch value := src.(type) { +// Scan implements the [database/sql.Scanner] interface. +func (dst *Bool) Scan(src any) error { + if src == nil { + *dst = Bool{} + return nil + } + + switch src := src.(type) { case bool: - *dst = Bool{Bool: value, Status: Present} + *dst = Bool{Bool: src, Valid: true} + return nil case string: - bb, err := strconv.ParseBool(value) + b, err := strconv.ParseBool(src) if err != nil { return err } - *dst = Bool{Bool: bb, Status: Present} - default: - if originalSrc, ok := underlyingBoolType(src); ok { - return dst.Set(originalSrc) + *dst = Bool{Bool: b, Valid: true} + return nil + case []byte: + b, err := strconv.ParseBool(string(src)) + if err != nil { + return err } - return errors.Errorf("cannot convert %v to Bool", value) + *dst = Bool{Bool: b, Valid: true} + return nil } - return nil + return fmt.Errorf("cannot scan %T", src) } -func (dst *Bool) Get() interface{} { - switch dst.Status { - case Present: - return dst.Bool - case Null: - return nil - default: - return dst.Status +// Value implements the [database/sql/driver.Valuer] interface. +func (src Bool) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil } + + return src.Bool, nil } -func (src *Bool) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *bool: - *v = src.Bool - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (src Bool) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil } - return errors.Errorf("cannot decode %v into %T", src, dst) + if src.Bool { + return []byte("true"), nil + } else { + return []byte("false"), nil + } } -func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Bool{Status: Null} - return nil +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (dst *Bool) UnmarshalJSON(b []byte) error { + var v *bool + err := json.Unmarshal(b, &v) + if err != nil { + return err } - if len(src) != 1 { - return errors.Errorf("invalid length for bool: %v", len(src)) + if v == nil { + *dst = Bool{} + } else { + *dst = Bool{Bool: *v, Valid: true} } - *dst = Bool{Bool: src[0] == 't', Status: Present} return nil } -func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Bool{Status: Null} - return nil - } +type BoolCodec struct{} - if len(src) != 1 { - return errors.Errorf("invalid length for bool: %v", len(src)) +func (BoolCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (BoolCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (BoolCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case bool: + return encodePlanBoolCodecBinaryBool{} + case BoolValuer: + return encodePlanBoolCodecBinaryBoolValuer{} + } + case TextFormatCode: + switch value.(type) { + case bool: + return encodePlanBoolCodecTextBool{} + case BoolValuer: + return encodePlanBoolCodecTextBoolValuer{} + } } - *dst = Bool{Bool: src[0] == 1, Status: Present} return nil } -func (src *Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: +type encodePlanBoolCodecBinaryBool struct{} + +func (encodePlanBoolCodecBinaryBool) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(bool) + + if v { + buf = append(buf, 1) + } else { + buf = append(buf, 0) + } + + return buf, nil +} + +type encodePlanBoolCodecTextBoolValuer struct{} + +func (encodePlanBoolCodecTextBoolValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + b, err := value.(BoolValuer).BoolValue() + if err != nil { + return nil, err + } + + if !b.Valid { return nil, nil - case Undefined: - return nil, errUndefined } - if src.Bool { + if b.Bool { buf = append(buf, 't') } else { buf = append(buf, 'f') @@ -107,15 +168,19 @@ func (src *Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Bool) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: +type encodePlanBoolCodecBinaryBoolValuer struct{} + +func (encodePlanBoolCodecBinaryBoolValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + b, err := value.(BoolValuer).BoolValue() + if err != nil { + return nil, err + } + + if !b.Valid { return nil, nil - case Undefined: - return nil, errUndefined } - if src.Bool { + if b.Bool { buf = append(buf, 1) } else { buf = append(buf, 0) @@ -124,36 +189,158 @@ func (src *Bool) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -// Scan implements the database/sql Scanner interface. -func (dst *Bool) Scan(src interface{}) error { - if src == nil { - *dst = Bool{Status: Null} - return nil +type encodePlanBoolCodecTextBool struct{} + +func (encodePlanBoolCodecTextBool) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(bool) + + if v { + buf = append(buf, 't') + } else { + buf = append(buf, 'f') } - switch src := src.(type) { - case bool: - *dst = Bool{Bool: src, Status: Present} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + return buf, nil +} + +func (BoolCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case *bool: + return scanPlanBinaryBoolToBool{} + case BoolScanner: + return scanPlanBinaryBoolToBoolScanner{} + } + case TextFormatCode: + switch target.(type) { + case *bool: + return scanPlanTextAnyToBool{} + case BoolScanner: + return scanPlanTextAnyToBoolScanner{} + } } - return errors.Errorf("cannot scan %T", src) + return nil +} + +func (c BoolCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(m, oid, format, src) } -// Value implements the database/sql/driver Valuer interface. -func (src *Bool) Value() (driver.Value, error) { - switch src.Status { - case Present: - return src.Bool, nil - case Null: +func (c BoolCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { return nil, nil + } + + var b bool + err := codecScan(c, m, oid, format, src, &b) + if err != nil { + return nil, err + } + return b, nil +} + +type scanPlanBinaryBoolToBool struct{} + +func (scanPlanBinaryBoolToBool) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + p, ok := (dst).(*bool) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = src[0] == 1 + + return nil +} + +type scanPlanTextAnyToBool struct{} + +func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) == 0 { + return fmt.Errorf("cannot scan empty string into %T", dst) + } + + p, ok := (dst).(*bool) + if !ok { + return ErrScanTargetTypeChanged + } + + v, err := planTextToBool(src) + if err != nil { + return err + } + + *p = v + + return nil +} + +type scanPlanBinaryBoolToBoolScanner struct{} + +func (scanPlanBinaryBoolToBoolScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(BoolScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanBool(Bool{}) + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + return s.ScanBool(Bool{Bool: src[0] == 1, Valid: true}) +} + +type scanPlanTextAnyToBoolScanner struct{} + +func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(BoolScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanBool(Bool{}) + } + + if len(src) == 0 { + return fmt.Errorf("cannot scan empty string into %T", dst) + } + + v, err := planTextToBool(src) + if err != nil { + return err + } + + return s.ScanBool(Bool{Bool: v, Valid: true}) +} + +// https://www.postgresql.org/docs/current/datatype-boolean.html +func planTextToBool(src []byte) (bool, error) { + s := string(bytes.ToLower(bytes.TrimSpace(src))) + + switch { + case strings.HasPrefix("true", s), strings.HasPrefix("yes", s), s == "on", s == "1": + return true, nil + case strings.HasPrefix("false", s), strings.HasPrefix("no", s), strings.HasPrefix("off", s), s == "0": + return false, nil default: - return nil, errUndefined + return false, fmt.Errorf("unknown boolean string representation %q", src) } } diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go deleted file mode 100644 index 67dd92a74..000000000 --- a/pgtype/bool_array.go +++ /dev/null @@ -1,300 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type BoolArray struct { - Elements []Bool - Dimensions []ArrayDimension - Status Status -} - -func (dst *BoolArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = BoolArray{Status: Null} - return nil - } - - switch value := src.(type) { - - case []bool: - if value == nil { - *dst = BoolArray{Status: Null} - } else if len(value) == 0 { - *dst = BoolArray{Status: Present} - } else { - elements := make([]Bool, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BoolArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to BoolArray", value) - } - - return nil -} - -func (dst *BoolArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *BoolArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]bool: - *v = make([]bool, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = BoolArray{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Bool - - if len(uta.Elements) > 0 { - elements = make([]Bool, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Bool - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = BoolArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = BoolArray{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = BoolArray{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Bool, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = BoolArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("bool"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "bool") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *BoolArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *BoolArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/bool_array_test.go b/pgtype/bool_array_test.go deleted file mode 100644 index b529555e7..000000000 --- a/pgtype/bool_array_test.go +++ /dev/null @@ -1,153 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestBoolArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bool[]", []interface{}{ - &pgtype.BoolArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.BoolArray{Status: pgtype.Null}, - &pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Bool: false, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - {Bool: true, Status: pgtype.Present}, - {Bool: false, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestBoolArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.BoolArray - }{ - { - source: []bool{true}, - result: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]bool)(nil)), - result: pgtype.BoolArray{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.BoolArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestBoolArrayAssignTo(t *testing.T) { - var boolSlice []bool - type _boolSlice []bool - var namedBoolSlice _boolSlice - - simpleTests := []struct { - src pgtype.BoolArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &boolSlice, - expected: []bool{true}, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedBoolSlice, - expected: _boolSlice{true}, - }, - { - src: pgtype.BoolArray{Status: pgtype.Null}, - dst: &boolSlice, - expected: (([]bool)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.BoolArray - dst interface{} - }{ - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &boolSlice, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go index 2712e3b06..7480471b9 100644 --- a/pgtype/bool_test.go +++ b/pgtype/bool_test.go @@ -1,96 +1,62 @@ package pgtype_test import ( - "reflect" + "context" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" ) -func TestBoolTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bool", []interface{}{ - &pgtype.Bool{Bool: false, Status: pgtype.Present}, - &pgtype.Bool{Bool: true, Status: pgtype.Present}, - &pgtype.Bool{Bool: false, Status: pgtype.Null}, +func TestBoolCodec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "bool", []pgxtest.ValueRoundTripTest{ + {true, new(bool), isExpectedEq(true)}, + {false, new(bool), isExpectedEq(false)}, + {true, new(pgtype.Bool), isExpectedEq(pgtype.Bool{Bool: true, Valid: true})}, + {pgtype.Bool{}, new(pgtype.Bool), isExpectedEq(pgtype.Bool{})}, + {nil, new(*bool), isExpectedEq((*bool)(nil))}, }) } -func TestBoolSet(t *testing.T) { +func TestBoolMarshalJSON(t *testing.T) { successfulTests := []struct { - source interface{} - result pgtype.Bool + source pgtype.Bool + result string }{ - {source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, - {source: false, result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, - {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, - {source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, - {source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, - {source: "f", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, - {source: _bool(true), result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, - {source: _bool(false), result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: pgtype.Bool{}, result: "null"}, + {source: pgtype.Bool{Bool: true, Valid: true}, result: "true"}, + {source: pgtype.Bool{Bool: false, Valid: true}, result: "false"}, } - for i, tt := range successfulTests { - var r pgtype.Bool - err := r.Set(tt.source) + r, err := tt.source.MarshalJSON() if err != nil { t.Errorf("%d: %v", i, err) } - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) } } } -func TestBoolAssignTo(t *testing.T) { - var b bool - var _b _bool - var pb *bool - var _pb *_bool - - simpleTests := []struct { - src pgtype.Bool - dst interface{} - expected interface{} - }{ - {src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &b, expected: false}, - {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &b, expected: true}, - {src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &_b, expected: _bool(false)}, - {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_b, expected: _bool(true)}, - {src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &pb, expected: ((*bool)(nil))}, - {src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &_pb, expected: ((*_bool)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Bool - dst interface{} - expected interface{} +func TestBoolUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Bool }{ - {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &pb, expected: true}, - {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_pb, expected: _bool(true)}, + {source: "null", result: pgtype.Bool{}}, + {source: "true", result: pgtype.Bool{Bool: true, Valid: true}}, + {source: "false", result: pgtype.Bool{Bool: false, Valid: true}}, } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) + for i, tt := range successfulTests { + var r pgtype.Bool + err := r.UnmarshalJSON([]byte(tt.source)) if err != nil { t.Errorf("%d: %v", i, err) } - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) } } } diff --git a/pgtype/box.go b/pgtype/box.go index 83df04992..d243f58e3 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -8,89 +8,154 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) -type Box struct { - P [2]Vec2 - Status Status +type BoxScanner interface { + ScanBox(v Box) error } -func (dst *Box) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Box", src) +type BoxValuer interface { + BoxValue() (Box, error) } -func (dst *Box) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } +type Box struct { + P [2]Vec2 + Valid bool +} + +// ScanBox implements the [BoxScanner] interface. +func (b *Box) ScanBox(v Box) error { + *b = v + return nil } -func (src *Box) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) +// BoxValue implements the [BoxValuer] interface. +func (b Box) BoxValue() (Box, error) { + return b, nil } -func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { +// Scan implements the [database/sql.Scanner] interface. +func (dst *Box) Scan(src any) error { if src == nil { - *dst = Box{Status: Null} + *dst = Box{} return nil } - if len(src) < 11 { - return errors.Errorf("invalid length for Box: %v", len(src)) + switch src := src.(type) { + case string: + return scanPlanTextAnyToBoxScanner{}.Scan([]byte(src), dst) } - str := string(src[1:]) + return fmt.Errorf("cannot scan %T", src) +} - var end int - end = strings.IndexByte(str, ',') +// Value implements the [database/sql/driver.Valuer] interface. +func (src Box) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } - x1, err := strconv.ParseFloat(str[:end], 64) + buf, err := BoxCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) if err != nil { - return err + return nil, err } + return string(buf), err +} - str = str[end+1:] - end = strings.IndexByte(str, ')') +type BoxCodec struct{} - y1, err := strconv.ParseFloat(str[:end], 64) - if err != nil { - return err +func (BoxCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (BoxCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (BoxCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(BoxValuer); !ok { + return nil } - str = str[end+3:] - end = strings.IndexByte(str, ',') + switch format { + case BinaryFormatCode: + return encodePlanBoxCodecBinary{} + case TextFormatCode: + return encodePlanBoxCodecText{} + } - x2, err := strconv.ParseFloat(str[:end], 64) + return nil +} + +type encodePlanBoxCodecBinary struct{} + +func (encodePlanBoxCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + box, err := value.(BoxValuer).BoxValue() if err != nil { - return err + return nil, err } - str = str[end+1 : len(str)-1] + if !box.Valid { + return nil, nil + } - y2, err := strconv.ParseFloat(str, 64) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].Y)) + return buf, nil +} + +type encodePlanBoxCodecText struct{} + +func (encodePlanBoxCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + box, err := value.(BoxValuer).BoxValue() if err != nil { - return err + return nil, err + } + + if !box.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, + strconv.FormatFloat(box.P[0].X, 'f', -1, 64), + strconv.FormatFloat(box.P[0].Y, 'f', -1, 64), + strconv.FormatFloat(box.P[1].X, 'f', -1, 64), + strconv.FormatFloat(box.P[1].Y, 'f', -1, 64), + )...) + return buf, nil +} + +func (BoxCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case BoxScanner: + return scanPlanBinaryBoxToBoxScanner{} + } + case TextFormatCode: + switch target.(type) { + case BoxScanner: + return scanPlanTextAnyToBoxScanner{} + } } - *dst = Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present} return nil } -func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { +type scanPlanBinaryBoxToBoxScanner struct{} + +func (scanPlanBinaryBoxToBoxScanner) Scan(src []byte, dst any) error { + scanner := (dst).(BoxScanner) + if src == nil { - *dst = Box{Status: Null} - return nil + return scanner.ScanBox(Box{}) } if len(src) != 32 { - return errors.Errorf("invalid length for Box: %v", len(src)) + return fmt.Errorf("invalid length for Box: %v", len(src)) } x1 := binary.BigEndian.Uint64(src) @@ -98,65 +163,77 @@ func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { x2 := binary.BigEndian.Uint64(src[16:]) y2 := binary.BigEndian.Uint64(src[24:]) - *dst = Box{ + return scanner.ScanBox(Box{ P: [2]Vec2{ {math.Float64frombits(x1), math.Float64frombits(y1)}, {math.Float64frombits(x2), math.Float64frombits(y2)}, }, - Status: Present, - } - return nil + Valid: true, + }) } -func (src *Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +type scanPlanTextAnyToBoxScanner struct{} + +func (scanPlanTextAnyToBoxScanner) Scan(src []byte, dst any) error { + scanner := (dst).(BoxScanner) + + if src == nil { + return scanner.ScanBox(Box{}) } - buf = append(buf, fmt.Sprintf(`(%f,%f),(%f,%f)`, - src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)...) - return buf, nil -} + if len(src) < 11 { + return fmt.Errorf("invalid length for Box: %v", len(src)) + } -func (src *Box) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined + str := string(src[1:]) + + var end int + end = strings.IndexByte(str, ',') + + x1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err } - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y)) + str = str[end+1:] + end = strings.IndexByte(str, ')') - return buf, nil -} + y1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } -// Scan implements the database/sql Scanner interface. -func (dst *Box) Scan(src interface{}) error { - if src == nil { - *dst = Box{Status: Null} - return nil + str = str[end+3:] + end = strings.IndexByte(str, ',') + + x2, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + str = str[end+1 : len(str)-1] + + y2, err := strconv.ParseFloat(str, 64) + if err != nil { + return err } - return errors.Errorf("cannot scan %T", src) + return scanner.ScanBox(Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true}) +} + +func (c BoxCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -// Value implements the database/sql/driver Valuer interface. -func (src *Box) Value() (driver.Value, error) { - return EncodeValueText(src) +func (c BoxCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var box Box + err := codecScan(c, m, oid, format, src, &box) + if err != nil { + return nil, err + } + return box, nil } diff --git a/pgtype/box_test.go b/pgtype/box_test.go index f26cda68b..3b54c1f83 100644 --- a/pgtype/box_test.go +++ b/pgtype/box_test.go @@ -1,34 +1,40 @@ package pgtype_test import ( + "context" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" ) -func TestBoxTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "box", []interface{}{ - &pgtype.Box{ - P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, - Status: pgtype.Present, - }, - &pgtype.Box{ - P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, - Status: pgtype.Present, - }, - &pgtype.Box{Status: pgtype.Null}, - }) -} +func TestBoxCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support box type") -func TestBoxNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "box", []pgxtest.ValueRoundTripTest{ + { + pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, + Valid: true, + }, + new(pgtype.Box), + isExpectedEq(pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, + Valid: true, + }), + }, { - SQL: "select '3.14, 1.678, 7.1, 5.234'::box", - Value: &pgtype.Box{ - P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, - Status: pgtype.Present, + pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {-13.14, -5.234}}, + Valid: true, }, + new(pgtype.Box), + isExpectedEq(pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {-13.14, -5.234}}, + Valid: true, + }), }, + {pgtype.Box{}, new(pgtype.Box), isExpectedEq(pgtype.Box{})}, + {nil, new(pgtype.Box), isExpectedEq(pgtype.Box{})}, }) } diff --git a/pgtype/bpchar.go b/pgtype/bpchar.go deleted file mode 100644 index 212631841..000000000 --- a/pgtype/bpchar.go +++ /dev/null @@ -1,68 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -// BPChar is fixed-length, blank padded char type -// character(n), char(n) -type BPChar Text - -// Set converts from src to dst. -func (dst *BPChar) Set(src interface{}) error { - return (*Text)(dst).Set(src) -} - -// Get returns underlying value -func (dst *BPChar) Get() interface{} { - return (*Text)(dst).Get() -} - -// AssignTo assigns from src to dst. -func (src *BPChar) AssignTo(dst interface{}) error { - if src.Status == Present { - switch v := dst.(type) { - case *rune: - runes := []rune(src.String) - if len(runes) == 1 { - *v = runes[0] - return nil - } - } - } - return (*Text)(src).AssignTo(dst) -} - -func (dst *BPChar) DecodeText(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeText(ci, src) -} - -func (dst *BPChar) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeBinary(ci, src) -} - -func (src *BPChar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Text)(src).EncodeText(ci, buf) -} - -func (src *BPChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Text)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *BPChar) Scan(src interface{}) error { - return (*Text)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *BPChar) Value() (driver.Value, error) { - return (*Text)(src).Value() -} - -func (src *BPChar) MarshalJSON() ([]byte, error) { - return (*Text)(src).MarshalJSON() -} - -func (dst *BPChar) UnmarshalJSON(b []byte) error { - return (*Text)(dst).UnmarshalJSON(b) -} diff --git a/pgtype/bpchar_array.go b/pgtype/bpchar_array.go deleted file mode 100644 index 1e6220f79..000000000 --- a/pgtype/bpchar_array.go +++ /dev/null @@ -1,300 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type BPCharArray struct { - Elements []BPChar - Dimensions []ArrayDimension - Status Status -} - -func (dst *BPCharArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = BPCharArray{Status: Null} - return nil - } - - switch value := src.(type) { - - case []string: - if value == nil { - *dst = BPCharArray{Status: Null} - } else if len(value) == 0 { - *dst = BPCharArray{Status: Present} - } else { - elements := make([]BPChar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BPCharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to BPCharArray", value) - } - - return nil -} - -func (dst *BPCharArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *BPCharArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *BPCharArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = BPCharArray{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []BPChar - - if len(uta.Elements) > 0 { - elements = make([]BPChar, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem BPChar - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = BPCharArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *BPCharArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = BPCharArray{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = BPCharArray{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]BPChar, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = BPCharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *BPCharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("bpchar"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "bpchar") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *BPCharArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *BPCharArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/bpchar_array_test.go b/pgtype/bpchar_array_test.go deleted file mode 100644 index e4f2e7ebf..000000000 --- a/pgtype/bpchar_array_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestBPCharArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "char(8)[]", []interface{}{ - &pgtype.BPCharArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.BPCharArray{ - Elements: []pgtype.BPChar{ - pgtype.BPChar{String: "foo ", Status: pgtype.Present}, - pgtype.BPChar{Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.BPCharArray{Status: pgtype.Null}, - &pgtype.BPCharArray{ - Elements: []pgtype.BPChar{ - pgtype.BPChar{String: "bar ", Status: pgtype.Present}, - pgtype.BPChar{String: "NuLL ", Status: pgtype.Present}, - pgtype.BPChar{String: `wow"quz\`, Status: pgtype.Present}, - pgtype.BPChar{String: "1 ", Status: pgtype.Present}, - pgtype.BPChar{String: "1 ", Status: pgtype.Present}, - pgtype.BPChar{String: "null ", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 3, LowerBound: 1}, - {Length: 2, LowerBound: 1}, - }, - Status: pgtype.Present, - }, - &pgtype.BPCharArray{ - Elements: []pgtype.BPChar{ - pgtype.BPChar{String: " bar ", Status: pgtype.Present}, - pgtype.BPChar{String: " baz ", Status: pgtype.Present}, - pgtype.BPChar{String: " quz ", Status: pgtype.Present}, - pgtype.BPChar{String: "foo ", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} diff --git a/pgtype/bpchar_test.go b/pgtype/bpchar_test.go deleted file mode 100644 index c076ca1b8..000000000 --- a/pgtype/bpchar_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestChar3Transcode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "char(3)", []interface{}{ - &pgtype.BPChar{String: "a ", Status: pgtype.Present}, - &pgtype.BPChar{String: " a ", Status: pgtype.Present}, - &pgtype.BPChar{String: "嗨 ", Status: pgtype.Present}, - &pgtype.BPChar{String: " ", Status: pgtype.Present}, - &pgtype.BPChar{Status: pgtype.Null}, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.BPChar) - b := bb.(pgtype.BPChar) - - return a.Status == b.Status && a.String == b.String - }) -} - -func TestBPCharAssignTo(t *testing.T) { - var ( - str string - run rune - ) - simpleTests := []struct { - src pgtype.BPChar - dst interface{} - expected interface{} - }{ - {src: pgtype.BPChar{String: "simple", Status: pgtype.Present}, dst: &str, expected: "simple"}, - {src: pgtype.BPChar{String: "嗨", Status: pgtype.Present}, dst: &run, expected: '嗨'}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - -} diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go new file mode 100644 index 000000000..84964425b --- /dev/null +++ b/pgtype/builtin_wrappers.go @@ -0,0 +1,952 @@ +package pgtype + +import ( + "errors" + "fmt" + "math" + "math/big" + "net" + "net/netip" + "reflect" + "time" +) + +type int8Wrapper int8 + +func (w int8Wrapper) SkipUnderlyingTypePlan() {} + +func (w *int8Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int8") + } + + if v.Int64 < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", v.Int64) + } + if v.Int64 > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", v.Int64) + } + *w = int8Wrapper(v.Int64) + + return nil +} + +func (w int8Wrapper) Int64Value() (Int8, error) { + return Int8{Int64: int64(w), Valid: true}, nil +} + +type int16Wrapper int16 + +func (w int16Wrapper) SkipUnderlyingTypePlan() {} + +func (w *int16Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int16") + } + + if v.Int64 < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", v.Int64) + } + if v.Int64 > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", v.Int64) + } + *w = int16Wrapper(v.Int64) + + return nil +} + +func (w int16Wrapper) Int64Value() (Int8, error) { + return Int8{Int64: int64(w), Valid: true}, nil +} + +type int32Wrapper int32 + +func (w int32Wrapper) SkipUnderlyingTypePlan() {} + +func (w *int32Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int32") + } + + if v.Int64 < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", v.Int64) + } + if v.Int64 > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", v.Int64) + } + *w = int32Wrapper(v.Int64) + + return nil +} + +func (w int32Wrapper) Int64Value() (Int8, error) { + return Int8{Int64: int64(w), Valid: true}, nil +} + +type int64Wrapper int64 + +func (w int64Wrapper) SkipUnderlyingTypePlan() {} + +func (w *int64Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int64") + } + + *w = int64Wrapper(v.Int64) + + return nil +} + +func (w int64Wrapper) Int64Value() (Int8, error) { + return Int8{Int64: int64(w), Valid: true}, nil +} + +type intWrapper int + +func (w intWrapper) SkipUnderlyingTypePlan() {} + +func (w *intWrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int") + } + + if v.Int64 < math.MinInt { + return fmt.Errorf("%d is less than minimum value for int", v.Int64) + } + if v.Int64 > math.MaxInt { + return fmt.Errorf("%d is greater than maximum value for int", v.Int64) + } + + *w = intWrapper(v.Int64) + + return nil +} + +func (w intWrapper) Int64Value() (Int8, error) { + return Int8{Int64: int64(w), Valid: true}, nil +} + +type uint8Wrapper uint8 + +func (w uint8Wrapper) SkipUnderlyingTypePlan() {} + +func (w *uint8Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint8") + } + + if v.Int64 < 0 { + return fmt.Errorf("%d is less than minimum value for uint8", v.Int64) + } + if v.Int64 > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", v.Int64) + } + *w = uint8Wrapper(v.Int64) + + return nil +} + +func (w uint8Wrapper) Int64Value() (Int8, error) { + return Int8{Int64: int64(w), Valid: true}, nil +} + +type uint16Wrapper uint16 + +func (w uint16Wrapper) SkipUnderlyingTypePlan() {} + +func (w *uint16Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint16") + } + + if v.Int64 < 0 { + return fmt.Errorf("%d is less than minimum value for uint16", v.Int64) + } + if v.Int64 > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", v.Int64) + } + *w = uint16Wrapper(v.Int64) + + return nil +} + +func (w uint16Wrapper) Int64Value() (Int8, error) { + return Int8{Int64: int64(w), Valid: true}, nil +} + +type uint32Wrapper uint32 + +func (w uint32Wrapper) SkipUnderlyingTypePlan() {} + +func (w *uint32Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint32") + } + + if v.Int64 < 0 { + return fmt.Errorf("%d is less than minimum value for uint32", v.Int64) + } + if v.Int64 > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", v.Int64) + } + *w = uint32Wrapper(v.Int64) + + return nil +} + +func (w uint32Wrapper) Int64Value() (Int8, error) { + return Int8{Int64: int64(w), Valid: true}, nil +} + +type uint64Wrapper uint64 + +func (w uint64Wrapper) SkipUnderlyingTypePlan() {} + +func (w *uint64Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint64") + } + + if v.Int64 < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", v.Int64) + } + + *w = uint64Wrapper(v.Int64) + + return nil +} + +func (w uint64Wrapper) Int64Value() (Int8, error) { + if uint64(w) > uint64(math.MaxInt64) { + return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", w) + } + + return Int8{Int64: int64(w), Valid: true}, nil +} + +func (w *uint64Wrapper) ScanNumeric(v Numeric) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint64") + } + + bi, err := v.toBigInt() + if err != nil { + return fmt.Errorf("cannot scan into *uint64: %w", err) + } + + if !bi.IsUint64() { + return fmt.Errorf("cannot scan %v into *uint64", bi.String()) + } + + *w = uint64Wrapper(bi.Uint64()) + + return nil +} + +func (w uint64Wrapper) NumericValue() (Numeric, error) { + return Numeric{Int: new(big.Int).SetUint64(uint64(w)), Valid: true}, nil +} + +type uintWrapper uint + +func (w uintWrapper) SkipUnderlyingTypePlan() {} + +func (w *uintWrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint64") + } + + if v.Int64 < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", v.Int64) + } + + if uint64(v.Int64) > math.MaxUint { + return fmt.Errorf("%d is greater than maximum value for uint", v.Int64) + } + + *w = uintWrapper(v.Int64) + + return nil +} + +func (w uintWrapper) Int64Value() (Int8, error) { + if uint64(w) > uint64(math.MaxInt64) { + return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", w) + } + + return Int8{Int64: int64(w), Valid: true}, nil +} + +func (w *uintWrapper) ScanNumeric(v Numeric) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint") + } + + bi, err := v.toBigInt() + if err != nil { + return fmt.Errorf("cannot scan into *uint: %w", err) + } + + if !bi.IsUint64() { + return fmt.Errorf("cannot scan %v into *uint", bi.String()) + } + + ui := bi.Uint64() + + if math.MaxUint < ui { + return fmt.Errorf("cannot scan %v into *uint", ui) + } + + *w = uintWrapper(ui) + + return nil +} + +func (w uintWrapper) NumericValue() (Numeric, error) { + return Numeric{Int: new(big.Int).SetUint64(uint64(w)), Valid: true}, nil +} + +type float32Wrapper float32 + +func (w float32Wrapper) SkipUnderlyingTypePlan() {} + +func (w *float32Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *float32") + } + + *w = float32Wrapper(v.Int64) + + return nil +} + +func (w float32Wrapper) Int64Value() (Int8, error) { + if w > math.MaxInt64 { + return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", w) + } + + return Int8{Int64: int64(w), Valid: true}, nil +} + +func (w *float32Wrapper) ScanFloat64(v Float8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *float32") + } + + *w = float32Wrapper(v.Float64) + + return nil +} + +func (w float32Wrapper) Float64Value() (Float8, error) { + return Float8{Float64: float64(w), Valid: true}, nil +} + +type float64Wrapper float64 + +func (w float64Wrapper) SkipUnderlyingTypePlan() {} + +func (w *float64Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *float64") + } + + *w = float64Wrapper(v.Int64) + + return nil +} + +func (w float64Wrapper) Int64Value() (Int8, error) { + if w > math.MaxInt64 { + return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", w) + } + + return Int8{Int64: int64(w), Valid: true}, nil +} + +func (w *float64Wrapper) ScanFloat64(v Float8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *float64") + } + + *w = float64Wrapper(v.Float64) + + return nil +} + +func (w float64Wrapper) Float64Value() (Float8, error) { + return Float8{Float64: float64(w), Valid: true}, nil +} + +type stringWrapper string + +func (w stringWrapper) SkipUnderlyingTypePlan() {} + +func (w *stringWrapper) ScanText(v Text) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *string") + } + + *w = stringWrapper(v.String) + return nil +} + +func (w stringWrapper) TextValue() (Text, error) { + return Text{String: string(w), Valid: true}, nil +} + +type timeWrapper time.Time + +func (w *timeWrapper) ScanDate(v Date) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Time") + } + + switch v.InfinityModifier { + case Finite: + *w = timeWrapper(v.Time) + return nil + case Infinity: + return fmt.Errorf("cannot scan Infinity into *time.Time") + case NegativeInfinity: + return fmt.Errorf("cannot scan -Infinity into *time.Time") + default: + return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) + } +} + +func (w timeWrapper) DateValue() (Date, error) { + return Date{Time: time.Time(w), Valid: true}, nil +} + +func (w *timeWrapper) ScanTimestamp(v Timestamp) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Time") + } + + switch v.InfinityModifier { + case Finite: + *w = timeWrapper(v.Time) + return nil + case Infinity: + return fmt.Errorf("cannot scan Infinity into *time.Time") + case NegativeInfinity: + return fmt.Errorf("cannot scan -Infinity into *time.Time") + default: + return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) + } +} + +func (w timeWrapper) TimestampValue() (Timestamp, error) { + return Timestamp{Time: time.Time(w), Valid: true}, nil +} + +func (w *timeWrapper) ScanTimestamptz(v Timestamptz) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Time") + } + + switch v.InfinityModifier { + case Finite: + *w = timeWrapper(v.Time) + return nil + case Infinity: + return fmt.Errorf("cannot scan Infinity into *time.Time") + case NegativeInfinity: + return fmt.Errorf("cannot scan -Infinity into *time.Time") + default: + return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) + } +} + +func (w timeWrapper) TimestamptzValue() (Timestamptz, error) { + return Timestamptz{Time: time.Time(w), Valid: true}, nil +} + +func (w *timeWrapper) ScanTime(v Time) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Time") + } + + // 24:00:00 is max allowed time in PostgreSQL, but time.Time will normalize that to 00:00:00 the next day. + var maxRepresentableByTime int64 = 24*60*60*1000000 - 1 + if v.Microseconds > maxRepresentableByTime { + return fmt.Errorf("%d microseconds cannot be represented as time.Time", v.Microseconds) + } + + usec := v.Microseconds + hours := usec / microsecondsPerHour + usec -= hours * microsecondsPerHour + minutes := usec / microsecondsPerMinute + usec -= minutes * microsecondsPerMinute + seconds := usec / microsecondsPerSecond + usec -= seconds * microsecondsPerSecond + ns := usec * 1000 + *w = timeWrapper(time.Date(2000, 1, 1, int(hours), int(minutes), int(seconds), int(ns), time.UTC)) + return nil +} + +func (w timeWrapper) TimeValue() (Time, error) { + t := time.Time(w) + usec := int64(t.Hour())*microsecondsPerHour + + int64(t.Minute())*microsecondsPerMinute + + int64(t.Second())*microsecondsPerSecond + + int64(t.Nanosecond())/1000 + return Time{Microseconds: usec, Valid: true}, nil +} + +type durationWrapper time.Duration + +func (w durationWrapper) SkipUnderlyingTypePlan() {} + +func (w *durationWrapper) ScanInterval(v Interval) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Interval") + } + + us := int64(v.Months)*microsecondsPerMonth + int64(v.Days)*microsecondsPerDay + v.Microseconds + *w = durationWrapper(time.Duration(us) * time.Microsecond) + return nil +} + +func (w durationWrapper) IntervalValue() (Interval, error) { + return Interval{Microseconds: int64(w) / 1000, Valid: true}, nil +} + +type netIPNetWrapper net.IPNet + +func (w *netIPNetWrapper) ScanNetipPrefix(v netip.Prefix) error { + if !v.IsValid() { + return fmt.Errorf("cannot scan NULL into *net.IPNet") + } + + *w = netIPNetWrapper{ + IP: v.Addr().AsSlice(), + Mask: net.CIDRMask(v.Bits(), v.Addr().BitLen()), + } + + return nil +} + +func (w netIPNetWrapper) NetipPrefixValue() (netip.Prefix, error) { + ip, ok := netip.AddrFromSlice(w.IP) + if !ok { + return netip.Prefix{}, errors.New("invalid net.IPNet") + } + + ones, _ := w.Mask.Size() + + return netip.PrefixFrom(ip, ones), nil +} + +type netIPWrapper net.IP + +func (w netIPWrapper) SkipUnderlyingTypePlan() {} + +func (w *netIPWrapper) ScanNetipPrefix(v netip.Prefix) error { + if !v.IsValid() { + *w = nil + return nil + } + + if v.Addr().BitLen() != v.Bits() { + return fmt.Errorf("cannot scan %v to *net.IP", v) + } + + *w = netIPWrapper(v.Addr().AsSlice()) + return nil +} + +func (w netIPWrapper) NetipPrefixValue() (netip.Prefix, error) { + if w == nil { + return netip.Prefix{}, nil + } + + addr, ok := netip.AddrFromSlice([]byte(w)) + if !ok { + return netip.Prefix{}, errors.New("invalid net.IP") + } + + return netip.PrefixFrom(addr, addr.BitLen()), nil +} + +type netipPrefixWrapper netip.Prefix + +func (w *netipPrefixWrapper) ScanNetipPrefix(v netip.Prefix) error { + *w = netipPrefixWrapper(v) + return nil +} + +func (w netipPrefixWrapper) NetipPrefixValue() (netip.Prefix, error) { + return netip.Prefix(w), nil +} + +type netipAddrWrapper netip.Addr + +func (w *netipAddrWrapper) ScanNetipPrefix(v netip.Prefix) error { + if !v.IsValid() { + *w = netipAddrWrapper(netip.Addr{}) + return nil + } + + if v.Addr().BitLen() != v.Bits() { + return fmt.Errorf("cannot scan %v to netip.Addr", v) + } + + *w = netipAddrWrapper(v.Addr()) + + return nil +} + +func (w netipAddrWrapper) NetipPrefixValue() (netip.Prefix, error) { + addr := (netip.Addr)(w) + if !addr.IsValid() { + return netip.Prefix{}, nil + } + + return netip.PrefixFrom(addr, addr.BitLen()), nil +} + +type mapStringToPointerStringWrapper map[string]*string + +func (w *mapStringToPointerStringWrapper) ScanHstore(v Hstore) error { + *w = mapStringToPointerStringWrapper(v) + return nil +} + +func (w mapStringToPointerStringWrapper) HstoreValue() (Hstore, error) { + return Hstore(w), nil +} + +type mapStringToStringWrapper map[string]string + +func (w *mapStringToStringWrapper) ScanHstore(v Hstore) error { + *w = make(mapStringToStringWrapper, len(v)) + for k, v := range v { + if v == nil { + return fmt.Errorf("cannot scan NULL to string") + } + (*w)[k] = *v + } + return nil +} + +func (w mapStringToStringWrapper) HstoreValue() (Hstore, error) { + if w == nil { + return nil, nil + } + + hstore := make(Hstore, len(w)) + for k, v := range w { + s := v + hstore[k] = &s + } + return hstore, nil +} + +type fmtStringerWrapper struct { + s fmt.Stringer +} + +func (w fmtStringerWrapper) TextValue() (Text, error) { + return Text{String: w.s.String(), Valid: true}, nil +} + +type byte16Wrapper [16]byte + +func (w *byte16Wrapper) ScanUUID(v UUID) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *[16]byte") + } + *w = byte16Wrapper(v.Bytes) + return nil +} + +func (w byte16Wrapper) UUIDValue() (UUID, error) { + return UUID{Bytes: [16]byte(w), Valid: true}, nil +} + +type byteSliceWrapper []byte + +func (w byteSliceWrapper) SkipUnderlyingTypePlan() {} + +func (w *byteSliceWrapper) ScanText(v Text) error { + if !v.Valid { + *w = nil + return nil + } + + *w = byteSliceWrapper(v.String) + return nil +} + +func (w byteSliceWrapper) TextValue() (Text, error) { + if w == nil { + return Text{}, nil + } + + return Text{String: string(w), Valid: true}, nil +} + +func (w *byteSliceWrapper) ScanUUID(v UUID) error { + if !v.Valid { + *w = nil + return nil + } + *w = make(byteSliceWrapper, 16) + copy(*w, v.Bytes[:]) + return nil +} + +func (w byteSliceWrapper) UUIDValue() (UUID, error) { + if w == nil { + return UUID{}, nil + } + + uuid := UUID{Valid: true} + copy(uuid.Bytes[:], w) + return uuid, nil +} + +// structWrapper implements CompositeIndexGetter for a struct. +type structWrapper struct { + s any + exportedFields []reflect.Value +} + +func (w structWrapper) IsNull() bool { + return w.s == nil +} + +func (w structWrapper) Index(i int) any { + if i >= len(w.exportedFields) { + return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i) + } + + return w.exportedFields[i].Interface() +} + +// ptrStructWrapper implements CompositeIndexScanner for a pointer to a struct. +type ptrStructWrapper struct { + s any + exportedFields []reflect.Value +} + +func (w *ptrStructWrapper) ScanNull() error { + return fmt.Errorf("cannot scan NULL into %#v", w.s) +} + +func (w *ptrStructWrapper) ScanIndex(i int) any { + if i >= len(w.exportedFields) { + return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i) + } + + return w.exportedFields[i].Addr().Interface() +} + +type anySliceArrayReflect struct { + slice reflect.Value +} + +func (a anySliceArrayReflect) Dimensions() []ArrayDimension { + if a.slice.IsNil() { + return nil + } + + return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}} +} + +func (a anySliceArrayReflect) Index(i int) any { + return a.slice.Index(i).Interface() +} + +func (a anySliceArrayReflect) IndexType() any { + return reflect.New(a.slice.Type().Elem()).Elem().Interface() +} + +func (a *anySliceArrayReflect) SetDimensions(dimensions []ArrayDimension) error { + sliceType := a.slice.Type() + + if dimensions == nil { + a.slice.Set(reflect.Zero(sliceType)) + return nil + } + + elementCount := cardinality(dimensions) + slice := reflect.MakeSlice(sliceType, elementCount, elementCount) + a.slice.Set(slice) + return nil +} + +func (a *anySliceArrayReflect) ScanIndex(i int) any { + return a.slice.Index(i).Addr().Interface() +} + +func (a *anySliceArrayReflect) ScanIndexType() any { + return reflect.New(a.slice.Type().Elem()).Interface() +} + +type anyMultiDimSliceArray struct { + slice reflect.Value + dims []ArrayDimension +} + +func (a *anyMultiDimSliceArray) Dimensions() []ArrayDimension { + if a.slice.IsNil() { + return nil + } + + s := a.slice + for { + a.dims = append(a.dims, ArrayDimension{Length: int32(s.Len()), LowerBound: 1}) + if s.Len() > 0 { + s = s.Index(0) + } else { + break + } + if s.Type().Kind() == reflect.Slice { + } else { + break + } + } + + return a.dims +} + +func (a *anyMultiDimSliceArray) Index(i int) any { + if len(a.dims) == 1 { + return a.slice.Index(i).Interface() + } + + indexes := make([]int, len(a.dims)) + for j := len(a.dims) - 1; j >= 0; j-- { + dimLen := int(a.dims[j].Length) + indexes[j] = i % dimLen + i = i / dimLen + } + + v := a.slice + for _, si := range indexes { + v = v.Index(si) + } + + return v.Interface() +} + +func (a *anyMultiDimSliceArray) IndexType() any { + lowestSliceType := a.slice.Type() + for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() { + } + return reflect.New(lowestSliceType.Elem()).Elem().Interface() +} + +func (a *anyMultiDimSliceArray) SetDimensions(dimensions []ArrayDimension) error { + sliceType := a.slice.Type() + + if dimensions == nil { + a.slice.Set(reflect.Zero(sliceType)) + return nil + } + + switch len(dimensions) { + case 0: + // Empty, but non-nil array + slice := reflect.MakeSlice(sliceType, 0, 0) + a.slice.Set(slice) + return nil + case 1: + elementCount := cardinality(dimensions) + slice := reflect.MakeSlice(sliceType, elementCount, elementCount) + a.slice.Set(slice) + return nil + default: + sliceDimensionCount := 1 + lowestSliceType := sliceType + for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() { + sliceDimensionCount++ + } + + if sliceDimensionCount != len(dimensions) { + return fmt.Errorf("PostgreSQL array has %d dimensions but slice has %d dimensions", len(dimensions), sliceDimensionCount) + } + + elementCount := cardinality(dimensions) + flatSlice := reflect.MakeSlice(lowestSliceType, elementCount, elementCount) + + multiDimSlice := a.makeMultidimensionalSlice(sliceType, dimensions, flatSlice, 0) + a.slice.Set(multiDimSlice) + + // Now that a.slice is a multi-dimensional slice with the underlying data pointed at flatSlice change a.slice to + // flatSlice so ScanIndex only has to handle simple one dimensional slices. + a.slice = flatSlice + + return nil + } +} + +func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type, dimensions []ArrayDimension, flatSlice reflect.Value, flatSliceIdx int) reflect.Value { + if len(dimensions) == 1 { + endIdx := flatSliceIdx + int(dimensions[0].Length) + return flatSlice.Slice3(flatSliceIdx, endIdx, endIdx) + } + + sliceLen := int(dimensions[0].Length) + slice := reflect.MakeSlice(sliceType, sliceLen, sliceLen) + for i := 0; i < sliceLen; i++ { + subSlice := a.makeMultidimensionalSlice(sliceType.Elem(), dimensions[1:], flatSlice, flatSliceIdx+(i*int(dimensions[1].Length))) + slice.Index(i).Set(subSlice) + } + + return slice +} + +func (a *anyMultiDimSliceArray) ScanIndex(i int) any { + return a.slice.Index(i).Addr().Interface() +} + +func (a *anyMultiDimSliceArray) ScanIndexType() any { + lowestSliceType := a.slice.Type() + for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() { + } + return reflect.New(lowestSliceType.Elem()).Interface() +} + +type anyArrayArrayReflect struct { + array reflect.Value +} + +func (a anyArrayArrayReflect) Dimensions() []ArrayDimension { + return []ArrayDimension{{Length: int32(a.array.Len()), LowerBound: 1}} +} + +func (a anyArrayArrayReflect) Index(i int) any { + return a.array.Index(i).Interface() +} + +func (a anyArrayArrayReflect) IndexType() any { + return reflect.New(a.array.Type().Elem()).Elem().Interface() +} + +func (a *anyArrayArrayReflect) SetDimensions(dimensions []ArrayDimension) error { + if dimensions == nil { + return fmt.Errorf("anyArrayArrayReflect: cannot scan NULL into %v", a.array.Type().String()) + } + + if len(dimensions) != 1 { + return fmt.Errorf("anyArrayArrayReflect: cannot scan multi-dimensional array into %v", a.array.Type().String()) + } + + if int(dimensions[0].Length) != a.array.Len() { + return fmt.Errorf("anyArrayArrayReflect: cannot scan array with length %v into %v", dimensions[0].Length, a.array.Type().String()) + } + + return nil +} + +func (a *anyArrayArrayReflect) ScanIndex(i int) any { + return a.array.Index(i).Addr().Interface() +} + +func (a *anyArrayArrayReflect) ScanIndexType() any { + return reflect.New(a.array.Type().Elem()).Interface() +} diff --git a/pgtype/bytea.go b/pgtype/bytea.go index c7117f485..6c4f0c5ea 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -3,154 +3,252 @@ package pgtype import ( "database/sql/driver" "encoding/hex" - - "github.com/pkg/errors" + "fmt" ) -type Bytea struct { - Bytes []byte - Status Status +type BytesScanner interface { + // ScanBytes receives a byte slice of driver memory that is only valid until the next database method call. + ScanBytes(v []byte) error } -func (dst *Bytea) Set(src interface{}) error { - if src == nil { - *dst = Bytea{Status: Null} +type BytesValuer interface { + // BytesValue returns a byte slice of the byte data. The caller must not change the returned slice. + BytesValue() ([]byte, error) +} + +// DriverBytes is a byte slice that holds a reference to memory owned by the driver. It is only valid from the time it +// is scanned until Rows.Next or Rows.Close is called. It is never safe to use DriverBytes with QueryRow as Row.Scan +// internally calls Rows.Close before returning. +type DriverBytes []byte + +func (b *DriverBytes) ScanBytes(v []byte) error { + *b = v + return nil +} + +// PreallocBytes is a byte slice of preallocated memory that scanned bytes will be copied to. If it is too small a new +// slice will be allocated. +type PreallocBytes []byte + +func (b *PreallocBytes) ScanBytes(v []byte) error { + if v == nil { + *b = nil return nil } - switch value := src.(type) { - case []byte: - if value != nil { - *dst = Bytea{Bytes: value, Status: Present} - } else { - *dst = Bytea{Status: Null} - } - default: - if originalSrc, ok := underlyingBytesType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Bytea", value) + if len(v) <= len(*b) { + *b = (*b)[:len(v)] + } else { + *b = make(PreallocBytes, len(v)) } - + copy(*b, v) return nil } -func (dst *Bytea) Get() interface{} { - switch dst.Status { - case Present: - return dst.Bytes - case Null: +// UndecodedBytes can be used as a scan target to get the raw bytes from PostgreSQL without any decoding. +type UndecodedBytes []byte + +type scanPlanAnyToUndecodedBytes struct{} + +func (scanPlanAnyToUndecodedBytes) Scan(src []byte, dst any) error { + dstBuf := dst.(*UndecodedBytes) + if src == nil { + *dstBuf = nil return nil - default: - return dst.Status } + + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil } -func (src *Bytea) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *[]byte: - buf := make([]byte, len(src.Bytes)) - copy(buf, src.Bytes) - *v = buf - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } +type ByteaCodec struct{} + +func (ByteaCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (ByteaCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (ByteaCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case []byte: + return encodePlanBytesCodecBinaryBytes{} + case BytesValuer: + return encodePlanBytesCodecBinaryBytesValuer{} + } + case TextFormatCode: + switch value.(type) { + case []byte: + return encodePlanBytesCodecTextBytes{} + case BytesValuer: + return encodePlanBytesCodecTextBytesValuer{} } - case Null: - return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return nil } -// DecodeText only supports the hex format. This has been the default since -// PostgreSQL 9.0. -func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Bytea{Status: Null} - return nil - } +type encodePlanBytesCodecBinaryBytes struct{} - if len(src) < 2 || src[0] != '\\' || src[1] != 'x' { - return errors.Errorf("invalid hex format") +func (encodePlanBytesCodecBinaryBytes) Encode(value any, buf []byte) (newBuf []byte, err error) { + b := value.([]byte) + if b == nil { + return nil, nil } - buf := make([]byte, (len(src)-2)/2) - _, err := hex.Decode(buf, src[2:]) + return append(buf, b...), nil +} + +type encodePlanBytesCodecBinaryBytesValuer struct{} + +func (encodePlanBytesCodecBinaryBytesValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + b, err := value.(BytesValuer).BytesValue() if err != nil { - return err + return nil, err + } + if b == nil { + return nil, nil } - *dst = Bytea{Bytes: buf, Status: Present} - return nil + return append(buf, b...), nil } -func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Bytea{Status: Null} - return nil +type encodePlanBytesCodecTextBytes struct{} + +func (encodePlanBytesCodecTextBytes) Encode(value any, buf []byte) (newBuf []byte, err error) { + b := value.([]byte) + if b == nil { + return nil, nil } - *dst = Bytea{Bytes: src, Status: Present} - return nil + buf = append(buf, `\x`...) + buf = append(buf, hex.EncodeToString(b)...) + return buf, nil } -func (src *Bytea) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: +type encodePlanBytesCodecTextBytesValuer struct{} + +func (encodePlanBytesCodecTextBytesValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + b, err := value.(BytesValuer).BytesValue() + if err != nil { + return nil, err + } + if b == nil { return nil, nil - case Undefined: - return nil, errUndefined } buf = append(buf, `\x`...) - buf = append(buf, hex.EncodeToString(src.Bytes)...) + buf = append(buf, hex.EncodeToString(b)...) return buf, nil } -func (src *Bytea) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +func (ByteaCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case *[]byte: + return scanPlanBinaryBytesToBytes{} + case BytesScanner: + return scanPlanBinaryBytesToBytesScanner{} + } + case TextFormatCode: + switch target.(type) { + case *[]byte: + return scanPlanTextByteaToBytes{} + case BytesScanner: + return scanPlanTextByteaToBytesScanner{} + } } - return append(buf, src.Bytes...), nil + return nil } -// Scan implements the database/sql Scanner interface. -func (dst *Bytea) Scan(src interface{}) error { +type scanPlanBinaryBytesToBytes struct{} + +func (scanPlanBinaryBytesToBytes) Scan(src []byte, dst any) error { + dstBuf := dst.(*[]byte) if src == nil { - *dst = Bytea{Status: Null} + *dstBuf = nil return nil } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - buf := make([]byte, len(src)) - copy(buf, src) - *dst = Bytea{Bytes: buf, Status: Present} + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type scanPlanBinaryBytesToBytesScanner struct{} + +func (scanPlanBinaryBytesToBytesScanner) Scan(src []byte, dst any) error { + scanner := (dst).(BytesScanner) + return scanner.ScanBytes(src) +} + +type scanPlanTextByteaToBytes struct{} + +func (scanPlanTextByteaToBytes) Scan(src []byte, dst any) error { + dstBuf := dst.(*[]byte) + if src == nil { + *dstBuf = nil return nil } - return errors.Errorf("cannot scan %T", src) + buf, err := decodeHexBytea(src) + if err != nil { + return err + } + *dstBuf = buf + + return nil +} + +type scanPlanTextByteaToBytesScanner struct{} + +func (scanPlanTextByteaToBytesScanner) Scan(src []byte, dst any) error { + scanner := (dst).(BytesScanner) + buf, err := decodeHexBytea(src) + if err != nil { + return err + } + return scanner.ScanBytes(buf) +} + +func decodeHexBytea(src []byte) ([]byte, error) { + if src == nil { + return nil, nil + } + + if len(src) < 2 || src[0] != '\\' || src[1] != 'x' { + return nil, fmt.Errorf("invalid hex format") + } + + buf := make([]byte, (len(src)-2)/2) + _, err := hex.Decode(buf, src[2:]) + if err != nil { + return nil, err + } + + return buf, nil +} + +func (c ByteaCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(m, oid, format, src) } -// Value implements the database/sql/driver Valuer interface. -func (src *Bytea) Value() (driver.Value, error) { - switch src.Status { - case Present: - return src.Bytes, nil - case Null: +func (c ByteaCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { return nil, nil - default: - return nil, errUndefined } + + var buf []byte + err := codecScan(c, m, oid, format, src, &buf) + if err != nil { + return nil, err + } + return buf, nil } diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go deleted file mode 100644 index c8eb56698..000000000 --- a/pgtype/bytea_array.go +++ /dev/null @@ -1,300 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type ByteaArray struct { - Elements []Bytea - Dimensions []ArrayDimension - Status Status -} - -func (dst *ByteaArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = ByteaArray{Status: Null} - return nil - } - - switch value := src.(type) { - - case [][]byte: - if value == nil { - *dst = ByteaArray{Status: Null} - } else if len(value) == 0 { - *dst = ByteaArray{Status: Present} - } else { - elements := make([]Bytea, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = ByteaArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to ByteaArray", value) - } - - return nil -} - -func (dst *ByteaArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *ByteaArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[][]byte: - *v = make([][]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = ByteaArray{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Bytea - - if len(uta.Elements) > 0 { - elements = make([]Bytea, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Bytea - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = ByteaArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = ByteaArray{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = ByteaArray{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Bytea, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = ByteaArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("bytea"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "bytea") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *ByteaArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *ByteaArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/bytea_array_test.go b/pgtype/bytea_array_test.go deleted file mode 100644 index 8450b71bc..000000000 --- a/pgtype/bytea_array_test.go +++ /dev/null @@ -1,120 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestByteaArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bytea[]", []interface{}{ - &pgtype.ByteaArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.ByteaArray{Status: pgtype.Null}, - &pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{}, Status: pgtype.Present}, - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Bytes: []byte{1}, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.ByteaArray{ - Elements: []pgtype.Bytea{ - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{}, Status: pgtype.Present}, - {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - {Bytes: []byte{1}, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestByteaArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.ByteaArray - }{ - { - source: [][]byte{{1, 2, 3}}, - result: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([][]byte)(nil)), - result: pgtype.ByteaArray{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.ByteaArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestByteaArrayAssignTo(t *testing.T) { - var byteByteSlice [][]byte - - simpleTests := []struct { - src pgtype.ByteaArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.ByteaArray{ - Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &byteByteSlice, - expected: [][]byte{{1, 2, 3}}, - }, - { - src: pgtype.ByteaArray{Status: pgtype.Null}, - dst: &byteByteSlice, - expected: (([][]byte)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go index fd5a0dec9..ccd147f6f 100644 --- a/pgtype/bytea_test.go +++ b/pgtype/bytea_test.go @@ -1,73 +1,137 @@ package pgtype_test import ( - "reflect" + "bytes" + "context" + "fmt" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" ) -func TestByteaTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bytea", []interface{}{ - &pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - &pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, - &pgtype.Bytea{Bytes: nil, Status: pgtype.Null}, - }) -} +func isExpectedEqBytes(a any) func(any) bool { + return func(v any) bool { + ab := a.([]byte) + vb := v.([]byte) -func TestByteaSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Bytea - }{ - {source: []byte{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, - {source: []byte{}, result: pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}}, - {source: []byte(nil), result: pgtype.Bytea{Status: pgtype.Null}}, - {source: _byteSlice{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, - {source: _byteSlice(nil), result: pgtype.Bytea{Status: pgtype.Null}}, - } - - for i, tt := range successfulTests { - var r pgtype.Bytea - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) + if (ab == nil) != (vb == nil) { + return false } - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + if ab == nil { + return true } + + return bytes.Equal(ab, vb) } } -func TestByteaAssignTo(t *testing.T) { - var buf []byte - var _buf _byteSlice - var pbuf *[]byte - var _pbuf *_byteSlice - - simpleTests := []struct { - src pgtype.Bytea - dst interface{} - expected interface{} - }{ - {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &buf, expected: []byte{1, 2, 3}}, - {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_buf, expected: _byteSlice{1, 2, 3}}, - {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &pbuf, expected: &[]byte{1, 2, 3}}, - {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_pbuf, expected: &_byteSlice{1, 2, 3}}, - {src: pgtype.Bytea{Status: pgtype.Null}, dst: &pbuf, expected: ((*[]byte)(nil))}, - {src: pgtype.Bytea{Status: pgtype.Null}, dst: &_pbuf, expected: ((*_byteSlice)(nil))}, - } +func TestByteaCodec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "bytea", []pgxtest.ValueRoundTripTest{ + {[]byte{1, 2, 3}, new([]byte), isExpectedEqBytes([]byte{1, 2, 3})}, + {[]byte{}, new([]byte), isExpectedEqBytes([]byte{})}, + {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, + }) +} - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) +func TestDriverBytesQueryRow(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var buf []byte + err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.DriverBytes)(&buf)) + require.EqualError(t, err, "cannot scan into *pgtype.DriverBytes from QueryRow") + }) +} + +func TestDriverBytes(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + argBuf := make([]byte, 128) + for i := range argBuf { + argBuf[i] = byte(i) } - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + rows, err := conn.Query(ctx, `select $1::bytea from generate_series(1, 1000)`, argBuf) + require.NoError(t, err) + defer rows.Close() + + rowCount := 0 + resultBuf := argBuf + detectedResultMutation := false + for rows.Next() { + rowCount++ + + // At some point the buffer should be reused and change. + if !bytes.Equal(argBuf, resultBuf) { + detectedResultMutation = true + } + + err = rows.Scan((*pgtype.DriverBytes)(&resultBuf)) + require.NoError(t, err) + + require.Len(t, resultBuf, len(argBuf)) + require.Equal(t, resultBuf, argBuf) + require.Equalf(t, cap(resultBuf), len(resultBuf), "cap(resultBuf) is larger than len(resultBuf)") } - } + + require.True(t, detectedResultMutation) + + err = rows.Err() + require.NoError(t, err) + }) +} + +func TestPreallocBytes(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + origBuf := []byte{5, 6, 7, 8} + buf := origBuf + err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.PreallocBytes)(&buf)) + require.NoError(t, err) + + require.Len(t, buf, 2) + require.Equal(t, 4, cap(buf)) + require.Equal(t, []byte{1, 2}, buf) + + require.Equal(t, []byte{1, 2, 7, 8}, origBuf) + + err = conn.QueryRow(ctx, `select $1::bytea`, []byte{3, 4, 5, 6, 7}).Scan((*pgtype.PreallocBytes)(&buf)) + require.NoError(t, err) + require.Len(t, buf, 5) + require.Equal(t, 5, cap(buf)) + + require.Equal(t, []byte{1, 2, 7, 8}, origBuf) + }) +} + +func TestUndecodedBytes(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var buf []byte + err := conn.QueryRow(ctx, `select 1::int4`).Scan((*pgtype.UndecodedBytes)(&buf)) + require.NoError(t, err) + + require.Len(t, buf, 4) + require.Equal(t, []byte{0, 0, 0, 1}, buf) + }) +} + +func TestByteaCodecDecodeDatabaseSQLValue(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var buf []byte + err := conn.QueryRow(ctx, `select '\xa1b2c3d4'::bytea`).Scan(sqlScannerFunc(func(src any) error { + switch src := src.(type) { + case []byte: + buf = make([]byte, len(src)) + copy(buf, src) + return nil + default: + return fmt.Errorf("expected []byte, got %T", src) + } + })) + require.NoError(t, err) + + require.Len(t, buf, 4) + require.Equal(t, []byte{0xa1, 0xb2, 0xc3, 0xd4}, buf) + }) } diff --git a/pgtype/cid.go b/pgtype/cid.go deleted file mode 100644 index 0ed54f44e..000000000 --- a/pgtype/cid.go +++ /dev/null @@ -1,61 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -// CID is PostgreSQL's Command Identifier type. -// -// When one does -// -// select cmin, cmax, * from some_table; -// -// it is the data type of the cmin and cmax hidden system columns. -// -// It is currently implemented as an unsigned four byte integer. -// Its definition can be found in src/include/c.h as CommandId -// in the PostgreSQL sources. -type CID pguint32 - -// Set converts from src to dst. Note that as CID is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *CID) Set(src interface{}) error { - return (*pguint32)(dst).Set(src) -} - -func (dst *CID) Get() interface{} { - return (*pguint32)(dst).Get() -} - -// AssignTo assigns from src to dst. Note that as CID is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *CID) AssignTo(dst interface{}) error { - return (*pguint32)(src).AssignTo(dst) -} - -func (dst *CID) DecodeText(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeText(ci, src) -} - -func (dst *CID) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeBinary(ci, src) -} - -func (src *CID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*pguint32)(src).EncodeText(ci, buf) -} - -func (src *CID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*pguint32)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *CID) Scan(src interface{}) error { - return (*pguint32)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *CID) Value() (driver.Value, error) { - return (*pguint32)(src).Value() -} diff --git a/pgtype/cid_test.go b/pgtype/cid_test.go deleted file mode 100644 index 0dfc56d4d..000000000 --- a/pgtype/cid_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestCIDTranscode(t *testing.T) { - pgTypeName := "cid" - values := []interface{}{ - &pgtype.CID{Uint: 42, Status: pgtype.Present}, - &pgtype.CID{Status: pgtype.Null}, - } - eqFunc := func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - } - - testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - - // No direct conversion from int to cid, convert through text - testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) - - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) - } -} - -func TestCIDSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.CID - }{ - {source: uint32(1), result: pgtype.CID{Uint: 1, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.CID - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestCIDAssignTo(t *testing.T) { - var ui32 uint32 - var pui32 *uint32 - - simpleTests := []struct { - src pgtype.CID - dst interface{} - expected interface{} - }{ - {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.CID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.CID - dst interface{} - expected interface{} - }{ - {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.CID - dst interface{} - }{ - {src: pgtype.CID{Status: pgtype.Null}, dst: &ui32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/cidr.go b/pgtype/cidr.go deleted file mode 100644 index 519b9caef..000000000 --- a/pgtype/cidr.go +++ /dev/null @@ -1,31 +0,0 @@ -package pgtype - -type CIDR Inet - -func (dst *CIDR) Set(src interface{}) error { - return (*Inet)(dst).Set(src) -} - -func (dst *CIDR) Get() interface{} { - return (*Inet)(dst).Get() -} - -func (src *CIDR) AssignTo(dst interface{}) error { - return (*Inet)(src).AssignTo(dst) -} - -func (dst *CIDR) DecodeText(ci *ConnInfo, src []byte) error { - return (*Inet)(dst).DecodeText(ci, src) -} - -func (dst *CIDR) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Inet)(dst).DecodeBinary(ci, src) -} - -func (src *CIDR) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Inet)(src).EncodeText(ci, buf) -} - -func (src *CIDR) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Inet)(src).EncodeBinary(ci, buf) -} diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go deleted file mode 100644 index e4bb76142..000000000 --- a/pgtype/cidr_array.go +++ /dev/null @@ -1,329 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "net" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type CIDRArray struct { - Elements []CIDR - Dimensions []ArrayDimension - Status Status -} - -func (dst *CIDRArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = CIDRArray{Status: Null} - return nil - } - - switch value := src.(type) { - - case []*net.IPNet: - if value == nil { - *dst = CIDRArray{Status: Null} - } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []net.IP: - if value == nil { - *dst = CIDRArray{Status: Null} - } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to CIDRArray", value) - } - - return nil -} - -func (dst *CIDRArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *CIDRArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]*net.IPNet: - *v = make([]*net.IPNet, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]net.IP: - *v = make([]net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = CIDRArray{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []CIDR - - if len(uta.Elements) > 0 { - elements = make([]CIDR, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem CIDR - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = CIDRArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = CIDRArray{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = CIDRArray{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]CIDR, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = CIDRArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("cidr"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "cidr") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *CIDRArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *CIDRArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/cidr_array_test.go b/pgtype/cidr_array_test.go deleted file mode 100644 index 206a590fa..000000000 --- a/pgtype/cidr_array_test.go +++ /dev/null @@ -1,165 +0,0 @@ -package pgtype_test - -import ( - "net" - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestCIDRArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "cidr[]", []interface{}{ - &pgtype.CIDRArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.CIDRArray{Status: pgtype.Null}, - &pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - {Status: pgtype.Null}, - {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestCIDRArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.CIDRArray - }{ - { - source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]*net.IPNet)(nil)), - result: pgtype.CIDRArray{Status: pgtype.Null}, - }, - { - source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]net.IP)(nil)), - result: pgtype.CIDRArray{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.CIDRArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestCIDRArrayAssignTo(t *testing.T) { - var ipnetSlice []*net.IPNet - var ipSlice []net.IP - - simpleTests := []struct { - src pgtype.CIDRArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &ipnetSlice, - expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &ipnetSlice, - expected: []*net.IPNet{nil}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &ipSlice, - expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &ipSlice, - expected: []net.IP{nil}, - }, - { - src: pgtype.CIDRArray{Status: pgtype.Null}, - dst: &ipnetSlice, - expected: (([]*net.IPNet)(nil)), - }, - { - src: pgtype.CIDRArray{Status: pgtype.Null}, - dst: &ipSlice, - expected: (([]net.IP)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/pgtype/circle.go b/pgtype/circle.go index 97ecbf318..fb9b4c11d 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -8,139 +8,217 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) +type CircleScanner interface { + ScanCircle(v Circle) error +} + +type CircleValuer interface { + CircleValue() (Circle, error) +} + type Circle struct { - P Vec2 - R float64 - Status Status + P Vec2 + R float64 + Valid bool +} + +// ScanCircle implements the [CircleScanner] interface. +func (c *Circle) ScanCircle(v Circle) error { + *c = v + return nil } -func (dst *Circle) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Circle", src) +// CircleValue implements the [CircleValuer] interface. +func (c Circle) CircleValue() (Circle, error) { + return c, nil } -func (dst *Circle) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: +// Scan implements the [database/sql.Scanner] interface. +func (dst *Circle) Scan(src any) error { + if src == nil { + *dst = Circle{} return nil - default: - return dst.Status } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToCircleScanner{}.Scan([]byte(src), dst) + } + + return fmt.Errorf("cannot scan %T", src) } -func (src *Circle) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) +// Value implements the [database/sql/driver.Valuer] interface. +func (src Circle) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + buf, err := CircleCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) + if err != nil { + return nil, err + } + return string(buf), err } -func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Circle{Status: Null} +type CircleCodec struct{} + +func (CircleCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (CircleCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (CircleCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(CircleValuer); !ok { return nil } - if len(src) < 9 { - return errors.Errorf("invalid length for Circle: %v", len(src)) + switch format { + case BinaryFormatCode: + return encodePlanCircleCodecBinary{} + case TextFormatCode: + return encodePlanCircleCodecText{} } - str := string(src[2:]) - end := strings.IndexByte(str, ',') - x, err := strconv.ParseFloat(str[:end], 64) + return nil +} + +type encodePlanCircleCodecBinary struct{} + +func (encodePlanCircleCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + circle, err := value.(CircleValuer).CircleValue() if err != nil { - return err + return nil, err } - str = str[end+1:] - end = strings.IndexByte(str, ')') + if !circle.Valid { + return nil, nil + } - y, err := strconv.ParseFloat(str[:end], 64) + buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(circle.R)) + return buf, nil +} + +type encodePlanCircleCodecText struct{} + +func (encodePlanCircleCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + circle, err := value.(CircleValuer).CircleValue() if err != nil { - return err + return nil, err } - str = str[end+2 : len(str)-1] + if !circle.Valid { + return nil, nil + } - r, err := strconv.ParseFloat(str, 64) - if err != nil { - return err + buf = append(buf, fmt.Sprintf(`<(%s,%s),%s>`, + strconv.FormatFloat(circle.P.X, 'f', -1, 64), + strconv.FormatFloat(circle.P.Y, 'f', -1, 64), + strconv.FormatFloat(circle.R, 'f', -1, 64), + )...) + return buf, nil +} + +func (CircleCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case CircleScanner: + return scanPlanBinaryCircleToCircleScanner{} + } + case TextFormatCode: + switch target.(type) { + case CircleScanner: + return scanPlanTextAnyToCircleScanner{} + } } - *dst = Circle{P: Vec2{x, y}, R: r, Status: Present} return nil } -func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { +func (c CircleCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c CircleCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { - *dst = Circle{Status: Null} - return nil + return nil, nil + } + + var circle Circle + err := codecScan(c, m, oid, format, src, &circle) + if err != nil { + return nil, err + } + return circle, nil +} + +type scanPlanBinaryCircleToCircleScanner struct{} + +func (scanPlanBinaryCircleToCircleScanner) Scan(src []byte, dst any) error { + scanner := (dst).(CircleScanner) + + if src == nil { + return scanner.ScanCircle(Circle{}) } if len(src) != 24 { - return errors.Errorf("invalid length for Circle: %v", len(src)) + return fmt.Errorf("invalid length for Circle: %v", len(src)) } x := binary.BigEndian.Uint64(src) y := binary.BigEndian.Uint64(src[8:]) r := binary.BigEndian.Uint64(src[16:]) - *dst = Circle{ - P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, - R: math.Float64frombits(r), - Status: Present, - } - return nil + return scanner.ScanCircle(Circle{ + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + R: math.Float64frombits(r), + Valid: true, + }) } -func (src *Circle) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } +type scanPlanTextAnyToCircleScanner struct{} - buf = append(buf, fmt.Sprintf(`<(%f,%f),%f>`, src.P.X, src.P.Y, src.R)...) - return buf, nil -} +func (scanPlanTextAnyToCircleScanner) Scan(src []byte, dst any) error { + scanner := (dst).(CircleScanner) -func (src *Circle) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined + if src == nil { + return scanner.ScanCircle(Circle{}) } - buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P.Y)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.R)) - return buf, nil -} + if len(src) < 9 { + return fmt.Errorf("invalid length for Circle: %v", len(src)) + } -// Scan implements the database/sql Scanner interface. -func (dst *Circle) Scan(src interface{}) error { - if src == nil { - *dst = Circle{Status: Null} - return nil + str := string(src[2:]) + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err } - return errors.Errorf("cannot scan %T", src) -} + str = str[end+2 : len(str)-1] + + r, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } -// Value implements the database/sql/driver Valuer interface. -func (src *Circle) Value() (driver.Value, error) { - return EncodeValueText(src) + return scanner.ScanCircle(Circle{P: Vec2{x, y}, R: r, Valid: true}) } diff --git a/pgtype/circle_test.go b/pgtype/circle_test.go index 2747d4f58..7b6db7774 100644 --- a/pgtype/circle_test.go +++ b/pgtype/circle_test.go @@ -1,16 +1,28 @@ package pgtype_test import ( + "context" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" ) func TestCircleTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "circle", []interface{}{ - &pgtype.Circle{P: pgtype.Vec2{1.234, 5.6789}, R: 3.5, Status: pgtype.Present}, - &pgtype.Circle{P: pgtype.Vec2{-1.234, -5.6789}, R: 12.9, Status: pgtype.Present}, - &pgtype.Circle{Status: pgtype.Null}, + skipCockroachDB(t, "Server does not support box type") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "circle", []pgxtest.ValueRoundTripTest{ + { + pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, + new(pgtype.Circle), + isExpectedEq(pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}), + }, + { + pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, + new(pgtype.Circle), + isExpectedEq(pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}), + }, + {pgtype.Circle{}, new(pgtype.Circle), isExpectedEq(pgtype.Circle{})}, + {nil, new(pgtype.Circle), isExpectedEq(pgtype.Circle{})}, }) } diff --git a/pgtype/composite.go b/pgtype/composite.go new file mode 100644 index 000000000..598cf7af9 --- /dev/null +++ b/pgtype/composite.go @@ -0,0 +1,601 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "errors" + "fmt" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// CompositeIndexGetter is a type accessed by index that can be converted into a PostgreSQL composite. +type CompositeIndexGetter interface { + // IsNull returns true if the value is SQL NULL. + IsNull() bool + + // Index returns the element at i. + Index(i int) any +} + +// CompositeIndexScanner is a type accessed by index that can be scanned from a PostgreSQL composite. +type CompositeIndexScanner interface { + // ScanNull sets the value to SQL NULL. + ScanNull() error + + // ScanIndex returns a value usable as a scan target for i. + ScanIndex(i int) any +} + +type CompositeCodecField struct { + Name string + Type *Type +} + +type CompositeCodec struct { + Fields []CompositeCodecField +} + +func (c *CompositeCodec) FormatSupported(format int16) bool { + for _, f := range c.Fields { + if !f.Type.Codec.FormatSupported(format) { + return false + } + } + + return true +} + +func (c *CompositeCodec) PreferredFormat() int16 { + if c.FormatSupported(BinaryFormatCode) { + return BinaryFormatCode + } + return TextFormatCode +} + +func (c *CompositeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(CompositeIndexGetter); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return &encodePlanCompositeCodecCompositeIndexGetterToBinary{cc: c, m: m} + case TextFormatCode: + return &encodePlanCompositeCodecCompositeIndexGetterToText{cc: c, m: m} + } + + return nil +} + +type encodePlanCompositeCodecCompositeIndexGetterToBinary struct { + cc *CompositeCodec + m *Map +} + +func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + getter := value.(CompositeIndexGetter) + + if getter.IsNull() { + return nil, nil + } + + builder := NewCompositeBinaryBuilder(plan.m, buf) + for i, field := range plan.cc.Fields { + builder.AppendValue(field.Type.OID, getter.Index(i)) + } + + return builder.Finish() +} + +type encodePlanCompositeCodecCompositeIndexGetterToText struct { + cc *CompositeCodec + m *Map +} + +func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value any, buf []byte) (newBuf []byte, err error) { + getter := value.(CompositeIndexGetter) + + if getter.IsNull() { + return nil, nil + } + + b := NewCompositeTextBuilder(plan.m, buf) + for i, field := range plan.cc.Fields { + b.AppendValue(field.Type.OID, getter.Index(i)) + } + + return b.Finish() +} + +func (c *CompositeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case CompositeIndexScanner: + return &scanPlanBinaryCompositeToCompositeIndexScanner{cc: c, m: m} + } + case TextFormatCode: + switch target.(type) { + case CompositeIndexScanner: + return &scanPlanTextCompositeToCompositeIndexScanner{cc: c, m: m} + } + } + + return nil +} + +type scanPlanBinaryCompositeToCompositeIndexScanner struct { + cc *CompositeCodec + m *Map +} + +func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, target any) error { + targetScanner := (target).(CompositeIndexScanner) + + if src == nil { + return targetScanner.ScanNull() + } + + scanner := NewCompositeBinaryScanner(plan.m, src) + for i, field := range plan.cc.Fields { + if scanner.Next() { + fieldTarget := targetScanner.ScanIndex(i) + if fieldTarget != nil { + fieldPlan := plan.m.PlanScan(field.Type.OID, BinaryFormatCode, fieldTarget) + if fieldPlan == nil { + return fmt.Errorf("unable to encode %v into OID %d in binary format", field, field.Type.OID) + } + + err := fieldPlan.Scan(scanner.Bytes(), fieldTarget) + if err != nil { + return err + } + } + } else { + return errors.New("read past end of composite") + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil +} + +type scanPlanTextCompositeToCompositeIndexScanner struct { + cc *CompositeCodec + m *Map +} + +func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, target any) error { + targetScanner := (target).(CompositeIndexScanner) + + if src == nil { + return targetScanner.ScanNull() + } + + scanner := NewCompositeTextScanner(plan.m, src) + for i, field := range plan.cc.Fields { + if scanner.Next() { + fieldTarget := targetScanner.ScanIndex(i) + if fieldTarget != nil { + fieldPlan := plan.m.PlanScan(field.Type.OID, TextFormatCode, fieldTarget) + if fieldPlan == nil { + return fmt.Errorf("unable to encode %v into OID %d in text format", field, field.Type.OID) + } + + err := fieldPlan.Scan(scanner.Bytes(), fieldTarget) + if err != nil { + return err + } + } + } else { + return errors.New("read past end of composite") + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil +} + +func (c *CompositeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } +} + +func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + scanner := NewCompositeTextScanner(m, src) + values := make(map[string]any, len(c.Fields)) + for i := 0; scanner.Next() && i < len(c.Fields); i++ { + var v any + fieldPlan := m.PlanScan(c.Fields[i].Type.OID, TextFormatCode, &v) + if fieldPlan == nil { + return nil, fmt.Errorf("unable to scan OID %d in text format into %v", c.Fields[i].Type.OID, v) + } + + err := fieldPlan.Scan(scanner.Bytes(), &v) + if err != nil { + return nil, err + } + + values[c.Fields[i].Name] = v + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return values, nil + case BinaryFormatCode: + scanner := NewCompositeBinaryScanner(m, src) + values := make(map[string]any, len(c.Fields)) + for i := 0; scanner.Next() && i < len(c.Fields); i++ { + var v any + fieldPlan := m.PlanScan(scanner.OID(), BinaryFormatCode, &v) + if fieldPlan == nil { + return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v) + } + + err := fieldPlan.Scan(scanner.Bytes(), &v) + if err != nil { + return nil, err + } + + values[c.Fields[i].Name] = v + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return values, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } +} + +type CompositeBinaryScanner struct { + m *Map + rp int + src []byte + + fieldCount int32 + fieldBytes []byte + fieldOID uint32 + err error +} + +// NewCompositeBinaryScanner a scanner over a binary encoded composite balue. +func NewCompositeBinaryScanner(m *Map, src []byte) *CompositeBinaryScanner { + rp := 0 + if len(src[rp:]) < 4 { + return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)} + } + + fieldCount := int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + return &CompositeBinaryScanner{ + m: m, + rp: rp, + src: src, + fieldCount: fieldCount, + } +} + +// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Next returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeBinaryScanner) Next() bool { + if cfs.err != nil { + return false + } + + if cfs.rp == len(cfs.src) { + return false + } + + if len(cfs.src[cfs.rp:]) < 8 { + cfs.err = fmt.Errorf("Record incomplete %v", cfs.src) + return false + } + cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:]) + cfs.rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:]))) + cfs.rp += 4 + + if fieldLen >= 0 { + if len(cfs.src[cfs.rp:]) < fieldLen { + cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src) + return false + } + cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen] + cfs.rp += fieldLen + } else { + cfs.fieldBytes = nil + } + + return true +} + +func (cfs *CompositeBinaryScanner) FieldCount() int { + return int(cfs.fieldCount) +} + +// Bytes returns the bytes of the field most recently read by Scan(). +func (cfs *CompositeBinaryScanner) Bytes() []byte { + return cfs.fieldBytes +} + +// OID returns the OID of the field most recently read by Scan(). +func (cfs *CompositeBinaryScanner) OID() uint32 { + return cfs.fieldOID +} + +// Err returns any error encountered by the scanner. +func (cfs *CompositeBinaryScanner) Err() error { + return cfs.err +} + +type CompositeTextScanner struct { + m *Map + rp int + src []byte + + fieldBytes []byte + err error +} + +// NewCompositeTextScanner a scanner over a text encoded composite value. +func NewCompositeTextScanner(m *Map, src []byte) *CompositeTextScanner { + if len(src) < 2 { + return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)} + } + + if src[0] != '(' { + return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")} + } + + if src[len(src)-1] != ')' { + return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")} + } + + return &CompositeTextScanner{ + m: m, + rp: 1, + src: src, + } +} + +// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Next returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeTextScanner) Next() bool { + if cfs.err != nil { + return false + } + + if cfs.rp == len(cfs.src) { + return false + } + + switch cfs.src[cfs.rp] { + case ',', ')': // null + cfs.rp++ + cfs.fieldBytes = nil + return true + case '"': // quoted value + cfs.rp++ + cfs.fieldBytes = make([]byte, 0, 16) + for { + ch := cfs.src[cfs.rp] + + if ch == '"' { + cfs.rp++ + if cfs.src[cfs.rp] == '"' { + cfs.fieldBytes = append(cfs.fieldBytes, '"') + cfs.rp++ + } else { + break + } + } else if ch == '\\' { + cfs.rp++ + cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp]) + cfs.rp++ + } else { + cfs.fieldBytes = append(cfs.fieldBytes, ch) + cfs.rp++ + } + } + cfs.rp++ + return true + default: // unquoted value + start := cfs.rp + for { + ch := cfs.src[cfs.rp] + if ch == ',' || ch == ')' { + break + } + cfs.rp++ + } + cfs.fieldBytes = cfs.src[start:cfs.rp] + cfs.rp++ + return true + } +} + +// Bytes returns the bytes of the field most recently read by Scan(). +func (cfs *CompositeTextScanner) Bytes() []byte { + return cfs.fieldBytes +} + +// Err returns any error encountered by the scanner. +func (cfs *CompositeTextScanner) Err() error { + return cfs.err +} + +type CompositeBinaryBuilder struct { + m *Map + buf []byte + startIdx int + fieldCount uint32 + err error +} + +func NewCompositeBinaryBuilder(m *Map, buf []byte) *CompositeBinaryBuilder { + startIdx := len(buf) + buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields + return &CompositeBinaryBuilder{m: m, buf: buf, startIdx: startIdx} +} + +func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field any) { + if b.err != nil { + return + } + + if field == nil { + b.buf = pgio.AppendUint32(b.buf, oid) + b.buf = pgio.AppendInt32(b.buf, -1) + b.fieldCount++ + return + } + + plan := b.m.PlanEncode(oid, BinaryFormatCode, field) + if plan == nil { + b.err = fmt.Errorf("unable to encode %v into OID %d in binary format", field, oid) + return + } + + b.buf = pgio.AppendUint32(b.buf, oid) + lengthPos := len(b.buf) + b.buf = pgio.AppendInt32(b.buf, -1) + fieldBuf, err := plan.Encode(field, b.buf) + if err != nil { + b.err = err + return + } + if fieldBuf != nil { + binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf))) + b.buf = fieldBuf + } + + b.fieldCount++ +} + +func (b *CompositeBinaryBuilder) Finish() ([]byte, error) { + if b.err != nil { + return nil, b.err + } + + binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount) + return b.buf, nil +} + +type CompositeTextBuilder struct { + m *Map + buf []byte + startIdx int + fieldCount uint32 + err error + fieldBuf [32]byte +} + +func NewCompositeTextBuilder(m *Map, buf []byte) *CompositeTextBuilder { + buf = append(buf, '(') // allocate room for number of fields + return &CompositeTextBuilder{m: m, buf: buf} +} + +func (b *CompositeTextBuilder) AppendValue(oid uint32, field any) { + if b.err != nil { + return + } + + if field == nil { + b.buf = append(b.buf, ',') + return + } + + plan := b.m.PlanEncode(oid, TextFormatCode, field) + if plan == nil { + b.err = fmt.Errorf("unable to encode %v into OID %d in text format", field, oid) + return + } + + fieldBuf, err := plan.Encode(field, b.fieldBuf[0:0]) + if err != nil { + b.err = err + return + } + if fieldBuf != nil { + b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...) + } + + b.buf = append(b.buf, ',') +} + +func (b *CompositeTextBuilder) Finish() ([]byte, error) { + if b.err != nil { + return nil, b.err + } + + b.buf[len(b.buf)-1] = ')' + return b.buf, nil +} + +var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteCompositeField(src string) string { + return `"` + quoteCompositeReplacer.Replace(src) + `"` +} + +func quoteCompositeFieldIfNeeded(src string) string { + if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) { + return quoteCompositeField(src) + } + return src +} + +// CompositeFields represents the values of a composite value. It can be used as an encoding source or as a scan target. +// It cannot scan a NULL, but the composite fields can be NULL. +type CompositeFields []any + +func (cf CompositeFields) SkipUnderlyingTypePlan() {} + +func (cf CompositeFields) IsNull() bool { + return cf == nil +} + +func (cf CompositeFields) Index(i int) any { + return cf[i] +} + +func (cf CompositeFields) ScanNull() error { + return fmt.Errorf("cannot scan NULL into CompositeFields") +} + +func (cf CompositeFields) ScanIndex(i int) any { + return cf[i] +} diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go new file mode 100644 index 000000000..7bae67677 --- /dev/null +++ b/pgtype/composite_test.go @@ -0,0 +1,242 @@ +package pgtype_test + +import ( + "context" + "fmt" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/require" +) + +func TestCompositeCodecTranscode(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, `drop type if exists ct_test; + +create type ct_test as ( + a text, + b int4 +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type ct_test") + + dt, err := conn.LoadType(ctx, "ct_test") + require.NoError(t, err) + conn.TypeMap().RegisterType(dt) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for _, format := range formats { + var a string + var b int32 + + err := conn.QueryRow(ctx, "select $1::ct_test", pgx.QueryResultFormats{format.code}, + pgtype.CompositeFields{"hi", int32(42)}, + ).Scan( + pgtype.CompositeFields{&a, &b}, + ) + require.NoErrorf(t, err, "%v", format.name) + require.EqualValuesf(t, "hi", a, "%v", format.name) + require.EqualValuesf(t, 42, b, "%v", format.name) + } + }) +} + +type point3d struct { + X, Y, Z float64 +} + +func (p point3d) IsNull() bool { + return false +} + +func (p point3d) Index(i int) any { + switch i { + case 0: + return p.X + case 1: + return p.Y + case 2: + return p.Z + default: + panic("invalid index") + } +} + +func (p *point3d) ScanNull() error { + return fmt.Errorf("cannot scan NULL into point3d") +} + +func (p *point3d) ScanIndex(i int) any { + switch i { + case 0: + return &p.X + case 1: + return &p.Y + case 2: + return &p.Z + default: + panic("invalid index") + } +} + +func TestCompositeCodecTranscodeStruct(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, `drop type if exists point3d; + +create type point3d as ( + x float8, + y float8, + z float8 +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type point3d") + + dt, err := conn.LoadType(ctx, "point3d") + require.NoError(t, err) + conn.TypeMap().RegisterType(dt) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for _, format := range formats { + input := point3d{X: 1, Y: 2, Z: 3} + var output point3d + err := conn.QueryRow(ctx, "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output) + require.NoErrorf(t, err, "%v", format.name) + require.Equalf(t, input, output, "%v", format.name) + } + }) +} + +func TestCompositeCodecTranscodeStructWrapper(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, `drop type if exists point3d; + +create type point3d as ( + x float8, + y float8, + z float8 +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type point3d") + + dt, err := conn.LoadType(ctx, "point3d") + require.NoError(t, err) + conn.TypeMap().RegisterType(dt) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + type anotherPoint struct { + X, Y, Z float64 + } + + for _, format := range formats { + input := anotherPoint{X: 1, Y: 2, Z: 3} + var output anotherPoint + err := conn.QueryRow(ctx, "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output) + require.NoErrorf(t, err, "%v", format.name) + require.Equalf(t, input, output, "%v", format.name) + } + }) +} + +func TestCompositeCodecDecodeValue(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, `drop type if exists point3d; + +create type point3d as ( + x float8, + y float8, + z float8 +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type point3d") + + dt, err := conn.LoadType(ctx, "point3d") + require.NoError(t, err) + conn.TypeMap().RegisterType(dt) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for _, format := range formats { + rows, err := conn.Query(ctx, "select '(1,2,3)'::point3d", pgx.QueryResultFormats{format.code}) + require.NoErrorf(t, err, "%v", format.name) + require.True(t, rows.Next()) + values, err := rows.Values() + require.NoErrorf(t, err, "%v", format.name) + require.Lenf(t, values, 1, "%v", format.name) + require.Equalf(t, map[string]any{"x": 1.0, "y": 2.0, "z": 3.0}, values[0], "%v", format.name) + require.False(t, rows.Next()) + require.NoErrorf(t, rows.Err(), "%v", format.name) + } + }) +} + +// Test for composite type from table instead of create type. Table types have system / hidden columns like tableoid, +// cmax, xmax, etc. These are not included when sending or receiving composite types. +// +// https://github.com/jackc/pgx/issues/1576 +func TestCompositeCodecTranscodeStructWrapperForTable(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types from table definitions") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, `drop table if exists point3d; + +create table point3d ( + x float8, + y float8, + z float8 +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop table point3d") + + dt, err := conn.LoadType(ctx, "point3d") + require.NoError(t, err) + conn.TypeMap().RegisterType(dt) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + type anotherPoint struct { + X, Y, Z float64 + } + + for _, format := range formats { + input := anotherPoint{X: 1, Y: 2, Z: 3} + var output anotherPoint + err := conn.QueryRow(ctx, "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output) + require.NoErrorf(t, err, "%v", format.name) + require.Equalf(t, input, output, "%v", format.name) + } + }) +} diff --git a/pgtype/convert.go b/pgtype/convert.go index 5dfb738e9..8a9cee9c3 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -1,354 +1,15 @@ package pgtype import ( - "math" "reflect" - "time" - - "github.com/pkg/errors" ) -const maxUint = ^uint(0) -const maxInt = int(maxUint >> 1) -const minInt = -maxInt - 1 - -// underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8 -func underlyingNumberType(val interface{}) (interface{}, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - case reflect.Int: - convVal := int(refVal.Int()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Int8: - convVal := int8(refVal.Int()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Int16: - convVal := int16(refVal.Int()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Int32: - convVal := int32(refVal.Int()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Int64: - convVal := int64(refVal.Int()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Uint: - convVal := uint(refVal.Uint()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Uint8: - convVal := uint8(refVal.Uint()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Uint16: - convVal := uint16(refVal.Uint()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Uint32: - convVal := uint32(refVal.Uint()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Uint64: - convVal := uint64(refVal.Uint()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Float32: - convVal := float32(refVal.Float()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Float64: - convVal := refVal.Float() - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.String: - convVal := refVal.String() - return convVal, reflect.TypeOf(convVal) != refVal.Type() - } - - return nil, false -} - -// underlyingBoolType gets the underlying type that can be converted to Bool -func underlyingBoolType(val interface{}) (interface{}, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - case reflect.Bool: - convVal := refVal.Bool() - return convVal, reflect.TypeOf(convVal) != refVal.Type() - } - - return nil, false -} - -// underlyingBytesType gets the underlying type that can be converted to []byte -func underlyingBytesType(val interface{}) (interface{}, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - case reflect.Slice: - if refVal.Type().Elem().Kind() == reflect.Uint8 { - convVal := refVal.Bytes() - return convVal, reflect.TypeOf(convVal) != refVal.Type() - } - } - - return nil, false -} - -// underlyingStringType gets the underlying type that can be converted to String -func underlyingStringType(val interface{}) (interface{}, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - case reflect.String: - convVal := refVal.String() - return convVal, reflect.TypeOf(convVal) != refVal.Type() - } - - return nil, false -} - -// underlyingPtrType dereferences a pointer -func underlyingPtrType(val interface{}) (interface{}, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - } - - return nil, false -} - -// underlyingTimeType gets the underlying type that can be converted to time.Time -func underlyingTimeType(val interface{}) (interface{}, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return time.Time{}, false - } - convVal := refVal.Elem().Interface() - return convVal, true - } - - timeType := reflect.TypeOf(time.Time{}) - if refVal.Type().ConvertibleTo(timeType) { - return refVal.Convert(timeType).Interface(), true - } - - return time.Time{}, false -} - -// underlyingSliceType gets the underlying slice type -func underlyingSliceType(val interface{}) (interface{}, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - case reflect.Slice: - baseSliceType := reflect.SliceOf(refVal.Type().Elem()) - if refVal.Type().ConvertibleTo(baseSliceType) { - convVal := refVal.Convert(baseSliceType) - return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() - } - } - - return nil, false -} - -func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { - if srcStatus == Present { - switch v := dst.(type) { - case *int: - if srcVal < int64(minInt) { - return errors.Errorf("%d is less than minimum value for int", srcVal) - } else if srcVal > int64(maxInt) { - return errors.Errorf("%d is greater than maximum value for int", srcVal) - } - *v = int(srcVal) - case *int8: - if srcVal < math.MinInt8 { - return errors.Errorf("%d is less than minimum value for int8", srcVal) - } else if srcVal > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for int8", srcVal) - } - *v = int8(srcVal) - case *int16: - if srcVal < math.MinInt16 { - return errors.Errorf("%d is less than minimum value for int16", srcVal) - } else if srcVal > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for int16", srcVal) - } - *v = int16(srcVal) - case *int32: - if srcVal < math.MinInt32 { - return errors.Errorf("%d is less than minimum value for int32", srcVal) - } else if srcVal > math.MaxInt32 { - return errors.Errorf("%d is greater than maximum value for int32", srcVal) - } - *v = int32(srcVal) - case *int64: - if srcVal < math.MinInt64 { - return errors.Errorf("%d is less than minimum value for int64", srcVal) - } else if srcVal > math.MaxInt64 { - return errors.Errorf("%d is greater than maximum value for int64", srcVal) - } - *v = int64(srcVal) - case *uint: - if srcVal < 0 { - return errors.Errorf("%d is less than zero for uint", srcVal) - } else if uint64(srcVal) > uint64(maxUint) { - return errors.Errorf("%d is greater than maximum value for uint", srcVal) - } - *v = uint(srcVal) - case *uint8: - if srcVal < 0 { - return errors.Errorf("%d is less than zero for uint8", srcVal) - } else if srcVal > math.MaxUint8 { - return errors.Errorf("%d is greater than maximum value for uint8", srcVal) - } - *v = uint8(srcVal) - case *uint16: - if srcVal < 0 { - return errors.Errorf("%d is less than zero for uint32", srcVal) - } else if srcVal > math.MaxUint16 { - return errors.Errorf("%d is greater than maximum value for uint16", srcVal) - } - *v = uint16(srcVal) - case *uint32: - if srcVal < 0 { - return errors.Errorf("%d is less than zero for uint32", srcVal) - } else if srcVal > math.MaxUint32 { - return errors.Errorf("%d is greater than maximum value for uint32", srcVal) - } - *v = uint32(srcVal) - case *uint64: - if srcVal < 0 { - return errors.Errorf("%d is less than zero for uint64", srcVal) - } - *v = uint64(srcVal) - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return int64AssignTo(srcVal, srcStatus, el.Interface()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if el.OverflowInt(int64(srcVal)) { - return errors.Errorf("cannot put %d into %T", srcVal, dst) - } - el.SetInt(int64(srcVal)) - return nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if srcVal < 0 { - return errors.Errorf("%d is less than zero for %T", srcVal, dst) - } - if el.OverflowUint(uint64(srcVal)) { - return errors.Errorf("cannot put %d into %T", srcVal, dst) - } - el.SetUint(uint64(srcVal)) - return nil - } - } - return errors.Errorf("cannot assign %v into %T", srcVal, dst) - } - return nil - } - - // if dst is a pointer to pointer and srcStatus is not Present, nil it out - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - if el.Kind() == reflect.Ptr { - el.Set(reflect.Zero(el.Type())) - return nil - } - } - - return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) -} - -func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { - if srcStatus == Present { - switch v := dst.(type) { - case *float32: - *v = float32(srcVal) - case *float64: - *v = srcVal - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return float64AssignTo(srcVal, srcStatus, el.Interface()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - i64 := int64(srcVal) - if float64(i64) == srcVal { - return int64AssignTo(i64, srcStatus, dst) - } - } - } - return errors.Errorf("cannot assign %v into %T", srcVal, dst) - } - return nil - } - - // if dst is a pointer to pointer and srcStatus is not Present, nil it out - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - if el.Kind() == reflect.Ptr { - el.Set(reflect.Zero(el.Type())) - return nil - } - } - - return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) -} - -func NullAssignTo(dst interface{}) error { +func NullAssignTo(dst any) error { dstPtr := reflect.ValueOf(dst) // AssignTo dst must always be a pointer if dstPtr.Kind() != reflect.Ptr { - return errors.Errorf("cannot assign NULL to %T", dst) + return &nullAssignmentError{dst: dst} } dstVal := dstPtr.Elem() @@ -359,11 +20,16 @@ func NullAssignTo(dst interface{}) error { return nil } - return errors.Errorf("cannot assign NULL to %T", dst) + return &nullAssignmentError{dst: dst} } var kindTypes map[reflect.Kind]reflect.Type +func toInterface(dst reflect.Value, t reflect.Type) (any, bool) { + nextDst := dst.Convert(t) + return nextDst.Interface(), dst.Type() != nextDst.Type() +} + // GetAssignToDstType attempts to convert dst to something AssignTo can assign // to. If dst is a pointer to pointer it allocates a value and returns the // dereferences pointer. If dst is a named type such as *Foo where Foo is type @@ -371,7 +37,7 @@ var kindTypes map[reflect.Kind]reflect.Type // // GetAssignToDstType returns the converted dst and a bool representing if any // change was made. -func GetAssignToDstType(dst interface{}) (interface{}, bool) { +func GetAssignToDstType(dst any) (any, bool) { dstPtr := reflect.ValueOf(dst) // AssignTo dst must always be a pointer @@ -389,15 +55,33 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { // if dst is pointer to a base type that has been renamed if baseValType, ok := kindTypes[dstVal.Kind()]; ok { - nextDst := dstPtr.Convert(reflect.PtrTo(baseValType)) - return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + return toInterface(dstPtr, reflect.PtrTo(baseValType)) } if dstVal.Kind() == reflect.Slice { if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { - baseSliceType := reflect.PtrTo(reflect.SliceOf(baseElemType)) - nextDst := dstPtr.Convert(baseSliceType) - return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + return toInterface(dstPtr, reflect.PtrTo(reflect.SliceOf(baseElemType))) + } + } + + if dstVal.Kind() == reflect.Array { + if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { + return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType))) + } + } + + if dstVal.Kind() == reflect.Struct { + if dstVal.Type().NumField() == 1 && dstVal.Type().Field(0).Anonymous { + dstPtr = dstVal.Field(0).Addr() + nested := dstVal.Type().Field(0).Type + if nested.Kind() == reflect.Array { + if baseElemType, ok := kindTypes[nested.Elem().Kind()]; ok { + return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(nested.Len(), baseElemType))) + } + } + if _, ok := kindTypes[nested.Kind()]; ok && dstPtr.CanInterface() { + return dstPtr.Interface(), true + } } } diff --git a/pgtype/database_sql.go b/pgtype/database_sql.go deleted file mode 100644 index 969536ddb..000000000 --- a/pgtype/database_sql.go +++ /dev/null @@ -1,42 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - - "github.com/pkg/errors" -) - -func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { - if valuer, ok := src.(driver.Valuer); ok { - return valuer.Value() - } - - if textEncoder, ok := src.(TextEncoder); ok { - buf, err := textEncoder.EncodeText(ci, nil) - if err != nil { - return nil, err - } - return string(buf), nil - } - - if binaryEncoder, ok := src.(BinaryEncoder); ok { - buf, err := binaryEncoder.EncodeBinary(ci, nil) - if err != nil { - return nil, err - } - return buf, nil - } - - return nil, errors.New("cannot convert to database/sql compatible value") -} - -func EncodeValueText(src TextEncoder) (interface{}, error) { - buf, err := src.EncodeText(nil, make([]byte, 0, 32)) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), err -} diff --git a/pgtype/date.go b/pgtype/date.go index f1c0d8bd5..447056860 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -3,16 +3,38 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" + "fmt" + "regexp" + "strconv" "time" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) +type DateScanner interface { + ScanDate(v Date) error +} + +type DateValuer interface { + DateValue() (Date, error) +} + type Date struct { Time time.Time - Status Status InfinityModifier InfinityModifier + Valid bool +} + +// ScanDate implements the [DateScanner] interface. +func (d *Date) ScanDate(v Date) error { + *d = v + return nil +} + +// DateValue implements the [DateValuer] interface. +func (d Date) DateValue() (Date, error) { + return d, nil } const ( @@ -20,144 +42,127 @@ const ( infinityDayOffset = 2147483647 ) -func (dst *Date) Set(src interface{}) error { +// Scan implements the [database/sql.Scanner] interface. +func (dst *Date) Scan(src any) error { if src == nil { - *dst = Date{Status: Null} + *dst = Date{} return nil } - switch value := src.(type) { + switch src := src.(type) { + case string: + return scanPlanTextAnyToDateScanner{}.Scan([]byte(src), dst) case time.Time: - *dst = Date{Time: value, Status: Present} - default: - if originalSrc, ok := underlyingTimeType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Date", value) + *dst = Date{Time: src, Valid: true} + return nil } - return nil + return fmt.Errorf("cannot scan %T", src) } -func (dst *Date) Get() interface{} { - switch dst.Status { - case Present: - if dst.InfinityModifier != None { - return dst.InfinityModifier - } - return dst.Time - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Date) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *time.Time: - if src.InfinityModifier != None { - return errors.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) +// Value implements the [database/sql/driver.Valuer] interface. +func (src Date) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil } - return errors.Errorf("cannot decode %v into %T", src, dst) + if src.InfinityModifier != Finite { + return src.InfinityModifier.String(), nil + } + return src.Time, nil } -func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Date{Status: Null} +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (src Date) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + var s string + + switch src.InfinityModifier { + case Finite: + s = src.Time.Format("2006-01-02") + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (dst *Date) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *dst = Date{} return nil } - sbuf := string(src) - switch sbuf { + switch *s { case "infinity": - *dst = Date{Status: Present, InfinityModifier: Infinity} + *dst = Date{Valid: true, InfinityModifier: Infinity} case "-infinity": - *dst = Date{Status: Present, InfinityModifier: -Infinity} + *dst = Date{Valid: true, InfinityModifier: -Infinity} default: - t, err := time.ParseInLocation("2006-01-02", sbuf, time.UTC) + t, err := time.ParseInLocation("2006-01-02", *s, time.UTC) if err != nil { return err } - *dst = Date{Time: t, Status: Present} + *dst = Date{Time: t, Valid: true} } return nil } -func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Date{Status: Null} - return nil - } +type DateCodec struct{} - if len(src) != 4 { - return errors.Errorf("invalid length for date: %v", len(src)) - } +func (DateCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} - dayOffset := int32(binary.BigEndian.Uint32(src)) +func (DateCodec) PreferredFormat() int16 { + return BinaryFormatCode +} - switch dayOffset { - case infinityDayOffset: - *dst = Date{Status: Present, InfinityModifier: Infinity} - case negativeInfinityDayOffset: - *dst = Date{Status: Present, InfinityModifier: -Infinity} - default: - t := time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.UTC) - *dst = Date{Time: t, Status: Present} +func (DateCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(DateValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanDateCodecBinary{} + case TextFormatCode: + return encodePlanDateCodecText{} } return nil } -func (src *Date) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - var s string +type encodePlanDateCodecBinary struct{} - switch src.InfinityModifier { - case None: - s = src.Time.Format("2006-01-02") - case Infinity: - s = "infinity" - case NegativeInfinity: - s = "-infinity" +func (encodePlanDateCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + date, err := value.(DateValuer).DateValue() + if err != nil { + return nil, err } - return append(buf, s...), nil -} - -func (src *Date) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !date.Valid { return nil, nil - case Undefined: - return nil, errUndefined } var daysSinceDateEpoch int32 - switch src.InfinityModifier { - case None: - tUnix := time.Date(src.Time.Year(), src.Time.Month(), src.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() + switch date.InfinityModifier { + case Finite: + tUnix := time.Date(date.Time.Year(), date.Time.Month(), date.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() secSinceDateEpoch := tUnix - dateEpoch @@ -171,39 +176,179 @@ func (src *Date) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return pgio.AppendInt32(buf, daysSinceDateEpoch), nil } -// Scan implements the database/sql Scanner interface. -func (dst *Date) Scan(src interface{}) error { +type encodePlanDateCodecText struct{} + +func (encodePlanDateCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + date, err := value.(DateValuer).DateValue() + if err != nil { + return nil, err + } + + if !date.Valid { + return nil, nil + } + + switch date.InfinityModifier { + case Finite: + // Year 0000 is 1 BC + bc := false + year := date.Time.Year() + if year <= 0 { + year = -year + 1 + bc = true + } + + yearBytes := strconv.AppendInt(make([]byte, 0, 6), int64(year), 10) + for i := len(yearBytes); i < 4; i++ { + buf = append(buf, '0') + } + buf = append(buf, yearBytes...) + buf = append(buf, '-') + if date.Time.Month() < 10 { + buf = append(buf, '0') + } + buf = strconv.AppendInt(buf, int64(date.Time.Month()), 10) + buf = append(buf, '-') + if date.Time.Day() < 10 { + buf = append(buf, '0') + } + buf = strconv.AppendInt(buf, int64(date.Time.Day()), 10) + + if bc { + buf = append(buf, " BC"...) + } + case Infinity: + buf = append(buf, "infinity"...) + case NegativeInfinity: + buf = append(buf, "-infinity"...) + } + + return buf, nil +} + +func (DateCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case DateScanner: + return scanPlanBinaryDateToDateScanner{} + } + case TextFormatCode: + switch target.(type) { + case DateScanner: + return scanPlanTextAnyToDateScanner{} + } + } + + return nil +} + +type scanPlanBinaryDateToDateScanner struct{} + +func (scanPlanBinaryDateToDateScanner) Scan(src []byte, dst any) error { + scanner := (dst).(DateScanner) + if src == nil { - *dst = Date{Status: Null} - return nil + return scanner.ScanDate(Date{}) } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - case time.Time: - *dst = Date{Time: src, Status: Present} - return nil + if len(src) != 4 { + return fmt.Errorf("invalid length for date: %v", len(src)) } - return errors.Errorf("cannot scan %T", src) + dayOffset := int32(binary.BigEndian.Uint32(src)) + + switch dayOffset { + case infinityDayOffset: + return scanner.ScanDate(Date{InfinityModifier: Infinity, Valid: true}) + case negativeInfinityDayOffset: + return scanner.ScanDate(Date{InfinityModifier: -Infinity, Valid: true}) + default: + t := time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.UTC) + return scanner.ScanDate(Date{Time: t, Valid: true}) + } } -// Value implements the database/sql/driver Valuer interface. -func (src *Date) Value() (driver.Value, error) { - switch src.Status { - case Present: - if src.InfinityModifier != None { - return src.InfinityModifier.String(), nil +type scanPlanTextAnyToDateScanner struct{} + +var dateRegexp = regexp.MustCompile(`^(\d{4,})-(\d\d)-(\d\d)( BC)?$`) + +func (scanPlanTextAnyToDateScanner) Scan(src []byte, dst any) error { + scanner := (dst).(DateScanner) + + if src == nil { + return scanner.ScanDate(Date{}) + } + + sbuf := string(src) + match := dateRegexp.FindStringSubmatch(sbuf) + if match != nil { + year, err := strconv.ParseInt(match[1], 10, 32) + if err != nil { + return fmt.Errorf("BUG: cannot parse date that regexp matched (year): %w", err) } - return src.Time, nil - case Null: - return nil, nil + + month, err := strconv.ParseInt(match[2], 10, 32) + if err != nil { + return fmt.Errorf("BUG: cannot parse date that regexp matched (month): %w", err) + } + + day, err := strconv.ParseInt(match[3], 10, 32) + if err != nil { + return fmt.Errorf("BUG: cannot parse date that regexp matched (month): %w", err) + } + + // BC matched + if len(match[4]) > 0 { + year = -year + 1 + } + + t := time.Date(int(year), time.Month(month), int(day), 0, 0, 0, 0, time.UTC) + return scanner.ScanDate(Date{Time: t, Valid: true}) + } + + switch sbuf { + case "infinity": + return scanner.ScanDate(Date{InfinityModifier: Infinity, Valid: true}) + case "-infinity": + return scanner.ScanDate(Date{InfinityModifier: -Infinity, Valid: true}) default: - return nil, errUndefined + return fmt.Errorf("invalid date format") + } +} + +func (c DateCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var date Date + err := codecScan(c, m, oid, format, src, &date) + if err != nil { + return nil, err + } + + if date.InfinityModifier != Finite { + return date.InfinityModifier.String(), nil } + + return date.Time, nil +} + +func (c DateCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var date Date + err := codecScan(c, m, oid, format, src, &date) + if err != nil { + return nil, err + } + + if date.InfinityModifier != Finite { + return date.InfinityModifier, nil + } + + return date.Time, nil } diff --git a/pgtype/date_array.go b/pgtype/date_array.go deleted file mode 100644 index 0cb64581a..000000000 --- a/pgtype/date_array.go +++ /dev/null @@ -1,301 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "time" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type DateArray struct { - Elements []Date - Dimensions []ArrayDimension - Status Status -} - -func (dst *DateArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = DateArray{Status: Null} - return nil - } - - switch value := src.(type) { - - case []time.Time: - if value == nil { - *dst = DateArray{Status: Null} - } else if len(value) == 0 { - *dst = DateArray{Status: Present} - } else { - elements := make([]Date, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = DateArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to DateArray", value) - } - - return nil -} - -func (dst *DateArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *DateArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = DateArray{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Date - - if len(uta.Elements) > 0 { - elements = make([]Date, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Date - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = DateArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = DateArray{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = DateArray{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Date, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = DateArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("date"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "date") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *DateArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *DateArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/date_array_test.go b/pgtype/date_array_test.go deleted file mode 100644 index 2ba19d1ad..000000000 --- a/pgtype/date_array_test.go +++ /dev/null @@ -1,143 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - "time" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestDateArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "date[]", []interface{}{ - &pgtype.DateArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.DateArray{Status: pgtype.Null}, - &pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.DateArray{ - Elements: []pgtype.Date{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestDateArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.DateArray - }{ - { - source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - result: pgtype.DateArray{ - Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]time.Time)(nil)), - result: pgtype.DateArray{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.DateArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestDateArrayAssignTo(t *testing.T) { - var timeSlice []time.Time - - simpleTests := []struct { - src pgtype.DateArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &timeSlice, - expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - }, - { - src: pgtype.DateArray{Status: pgtype.Null}, - dst: &timeSlice, - expected: (([]time.Time)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.DateArray - dst interface{} - }{ - { - src: pgtype.DateArray{ - Elements: []pgtype.Date{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &timeSlice, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/date_test.go b/pgtype/date_test.go index d98e16520..c7620fcf6 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -1,118 +1,113 @@ package pgtype_test import ( - "reflect" + "context" "testing" "time" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" ) -func TestDateTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "date", []interface{}{ - &pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Date{Status: pgtype.Null}, - &pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, - &pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, - }, func(a, b interface{}) bool { - at := a.(pgtype.Date) - bt := b.(pgtype.Date) +func isExpectedEqTime(a any) func(any) bool { + return func(v any) bool { + at := a.(time.Time) + vt := v.(time.Time) - return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier + return at.Equal(vt) + } +} + +func TestDateCodec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "date", []pgxtest.ValueRoundTripTest{ + {time.Date(-100, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(-100, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(-1, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(-1, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC))}, + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC))}, + {time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(12200, 1, 2, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(12200, 1, 2, 0, 0, 0, 0, time.UTC))}, + {pgtype.Date{InfinityModifier: pgtype.Infinity, Valid: true}, new(pgtype.Date), isExpectedEq(pgtype.Date{InfinityModifier: pgtype.Infinity, Valid: true})}, + {pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(pgtype.Date), isExpectedEq(pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Valid: true})}, + {pgtype.Date{}, new(pgtype.Date), isExpectedEq(pgtype.Date{})}, + {nil, new(*time.Time), isExpectedEq((*time.Time)(nil))}, }) } -func TestDateSet(t *testing.T) { - type _time time.Time +func TestDateCodecTextEncode(t *testing.T) { + m := pgtype.NewMap() successfulTests := []struct { - source interface{} - result pgtype.Date + source pgtype.Date + result string }{ - {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Valid: true}, result: "2012-03-29"}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}, result: "2012-03-29"}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}, result: "2012-03-29"}, + {source: pgtype.Date{Time: time.Date(789, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true}, result: "0789-01-02"}, + {source: pgtype.Date{Time: time.Date(89, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true}, result: "0089-01-02"}, + {source: pgtype.Date{Time: time.Date(9, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true}, result: "0009-01-02"}, + {source: pgtype.Date{Time: time.Date(12200, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true}, result: "12200-01-02"}, + {source: pgtype.Date{InfinityModifier: pgtype.Infinity, Valid: true}, result: "infinity"}, + {source: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "-infinity"}, } - for i, tt := range successfulTests { - var d pgtype.Date - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if d != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } + buf, err := m.Encode(pgtype.DateOID, pgtype.TextFormatCode, tt.source, nil) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.result, string(buf), "%d", i) } } -func TestDateAssignTo(t *testing.T) { - var tim time.Time - var ptim *time.Time - - simpleTests := []struct { - src pgtype.Date - dst interface{} - expected interface{} +func TestDateMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Date + result string }{ - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, - {src: pgtype.Date{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + {source: pgtype.Date{}, result: "null"}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Valid: true}, result: "\"2012-03-29\""}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}, result: "\"2012-03-29\""}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}, result: "\"2012-03-29\""}, + {source: pgtype.Date{InfinityModifier: pgtype.Infinity, Valid: true}, result: "\"infinity\""}, + {source: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "\"-infinity\""}, } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() if err != nil { t.Errorf("%d: %v", i, err) } - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) } } +} - pointerAllocTests := []struct { - src pgtype.Date - dst interface{} - expected interface{} +func TestDateUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Date }{ - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {source: "null", result: pgtype.Date{}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}}, + {source: "\"infinity\"", result: pgtype.Date{InfinityModifier: pgtype.Infinity, Valid: true}}, + {source: "\"-infinity\"", result: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) + for i, tt := range successfulTests { + var r pgtype.Date + err := r.UnmarshalJSON([]byte(tt.source)) if err != nil { t.Errorf("%d: %v", i, err) } - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Date - dst interface{} - }{ - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + if r.Time.Year() != tt.result.Time.Year() || r.Time.Month() != tt.result.Time.Month() || r.Time.Day() != tt.result.Time.Day() || r.Valid != tt.result.Valid || r.InfinityModifier != tt.result.InfinityModifier { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) } } } diff --git a/pgtype/daterange.go b/pgtype/daterange.go deleted file mode 100644 index 47cd7e460..000000000 --- a/pgtype/daterange.go +++ /dev/null @@ -1,250 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type Daterange struct { - Lower Date - Upper Date - LowerType BoundType - UpperType BoundType - Status Status -} - -func (dst *Daterange) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Daterange", src) -} - -func (dst *Daterange) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Daterange) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Daterange) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Daterange{Status: Null} - return nil - } - - utr, err := ParseUntypedTextRange(string(src)) - if err != nil { - return err - } - - *dst = Daterange{Status: Present} - - dst.LowerType = utr.LowerType - dst.UpperType = utr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { - return err - } - } - - return nil -} - -func (dst *Daterange) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Daterange{Status: Null} - return nil - } - - ubr, err := ParseUntypedBinaryRange(src) - if err != nil { - return err - } - - *dst = Daterange{Status: Present} - - dst.LowerType = ubr.LowerType - dst.UpperType = ubr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { - return err - } - } - - return nil -} - -func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - switch src.LowerType { - case Exclusive, Unbounded: - buf = append(buf, '(') - case Inclusive: - buf = append(buf, '[') - case Empty: - return append(buf, "empty"...), nil - default: - return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) - } - - var err error - - if src.LowerType != Unbounded { - buf, err = src.Lower.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - } - - buf = append(buf, ',') - - if src.UpperType != Unbounded { - buf, err = src.Upper.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - } - - switch src.UpperType { - case Exclusive, Unbounded: - buf = append(buf, ')') - case Inclusive: - buf = append(buf, ']') - default: - return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) - } - - return buf, nil -} - -func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - var rangeType byte - switch src.LowerType { - case Inclusive: - rangeType |= lowerInclusiveMask - case Unbounded: - rangeType |= lowerUnboundedMask - case Exclusive: - case Empty: - return append(buf, emptyMask), nil - default: - return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) - } - - switch src.UpperType { - case Inclusive: - rangeType |= upperInclusiveMask - case Unbounded: - rangeType |= upperUnboundedMask - case Exclusive: - default: - return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) - } - - buf = append(buf, rangeType) - - var err error - - if src.LowerType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Lower.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - if src.UpperType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Upper.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Daterange) Scan(src interface{}) error { - if src == nil { - *dst = Daterange{Status: Null} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Daterange) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/daterange_test.go b/pgtype/daterange_test.go deleted file mode 100644 index d2af5986b..000000000 --- a/pgtype/daterange_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package pgtype_test - -import ( - "testing" - "time" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestDaterangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "daterange", []interface{}{ - &pgtype.Daterange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - &pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Daterange{Status: pgtype.Null}, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.Daterange) - b := bb.(pgtype.Daterange) - - return a.Status == b.Status && - a.Lower.Time.Equal(b.Lower.Time) && - a.Lower.Status == b.Lower.Status && - a.Lower.InfinityModifier == b.Lower.InfinityModifier && - a.Upper.Time.Equal(b.Upper.Time) && - a.Upper.Status == b.Upper.Status && - a.Upper.InfinityModifier == b.Upper.InfinityModifier - }) -} - -func TestDaterangeNormalize(t *testing.T) { - testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ - { - SQL: "select daterange('2010-01-01', '2010-01-11', '(]')", - Value: pgtype.Daterange{ - Lower: pgtype.Date{Time: time.Date(2010, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Date{Time: time.Date(2010, 1, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - }, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.Daterange) - b := bb.(pgtype.Daterange) - - return a.Status == b.Status && - a.Lower.Time.Equal(b.Lower.Time) && - a.Lower.Status == b.Lower.Status && - a.Lower.InfinityModifier == b.Lower.InfinityModifier && - a.Upper.Time.Equal(b.Upper.Time) && - a.Upper.Status == b.Upper.Status && - a.Upper.InfinityModifier == b.Upper.InfinityModifier - }) -} diff --git a/pgtype/decimal.go b/pgtype/decimal.go deleted file mode 100644 index 79653cf38..000000000 --- a/pgtype/decimal.go +++ /dev/null @@ -1,31 +0,0 @@ -package pgtype - -type Decimal Numeric - -func (dst *Decimal) Set(src interface{}) error { - return (*Numeric)(dst).Set(src) -} - -func (dst *Decimal) Get() interface{} { - return (*Numeric)(dst).Get() -} - -func (src *Decimal) AssignTo(dst interface{}) error { - return (*Numeric)(src).AssignTo(dst) -} - -func (dst *Decimal) DecodeText(ci *ConnInfo, src []byte) error { - return (*Numeric)(dst).DecodeText(ci, src) -} - -func (dst *Decimal) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Numeric)(dst).DecodeBinary(ci, src) -} - -func (src *Decimal) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Numeric)(src).EncodeText(ci, buf) -} - -func (src *Decimal) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Numeric)(src).EncodeBinary(ci, buf) -} diff --git a/pgtype/derived_types_test.go b/pgtype/derived_types_test.go new file mode 100644 index 000000000..05d109bda --- /dev/null +++ b/pgtype/derived_types_test.go @@ -0,0 +1,61 @@ +package pgtype_test + +import ( + "context" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/require" +) + +func TestDerivedTypes(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, ` +drop type if exists dt_test; +drop domain if exists dt_uint64; + +create domain dt_uint64 as numeric(20,0); +create type dt_test as ( + a text, + b dt_uint64, + c dt_uint64[] +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop domain dt_uint64") + defer conn.Exec(ctx, "drop type dt_test") + + dtypes, err := conn.LoadTypes(ctx, []string{"dt_test"}) + require.Len(t, dtypes, 6) + require.Equal(t, dtypes[0].Name, "public.dt_uint64") + require.Equal(t, dtypes[1].Name, "dt_uint64") + require.Equal(t, dtypes[2].Name, "public._dt_uint64") + require.Equal(t, dtypes[3].Name, "_dt_uint64") + require.Equal(t, dtypes[4].Name, "public.dt_test") + require.Equal(t, dtypes[5].Name, "dt_test") + require.NoError(t, err) + conn.TypeMap().RegisterTypes(dtypes) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for _, format := range formats { + var a string + var b uint64 + var c *[]uint64 + + row := conn.QueryRow(ctx, "select $1::dt_test", pgx.QueryResultFormats{format.code}, pgtype.CompositeFields{"hi", uint64(42), []uint64{10, 20, 30}}) + err := row.Scan(pgtype.CompositeFields{&a, &b, &c}) + require.NoError(t, err) + require.EqualValuesf(t, "hi", a, "%v", format.name) + require.EqualValuesf(t, 42, b, "%v", format.name) + } + }) +} diff --git a/pgtype/doc.go b/pgtype/doc.go new file mode 100644 index 000000000..83dfc5de5 --- /dev/null +++ b/pgtype/doc.go @@ -0,0 +1,196 @@ +// Package pgtype converts between Go and PostgreSQL values. +/* +The primary type is the Map type. It is a map of PostgreSQL types identified by OID (object ID) to a Codec. A Codec is +responsible for converting between Go and PostgreSQL values. NewMap creates a Map with all supported standard PostgreSQL +types already registered. Additional types can be registered with Map.RegisterType. + +Use Map.Scan and Map.Encode to decode PostgreSQL values to Go and encode Go values to PostgreSQL respectively. + +Base Type Mapping + +pgtype maps between all common base types directly between Go and PostgreSQL. In particular: + + Go PostgreSQL + ----------------------- + string varchar + text + + // Integers are automatically be converted to any other integer type if + // it can be done without overflow or underflow. + int8 + int16 smallint + int32 int + int64 bigint + int + uint8 + uint16 + uint32 + uint64 + uint + + // Floats are strict and do not automatically convert like integers. + float32 float4 + float64 float8 + + time.Time date + timestamp + timestamptz + + netip.Addr inet + netip.Prefix cidr + + []byte bytea + +Null Values + +pgtype can map NULLs in two ways. The first is types that can directly represent NULL such as Int4. They work in a +similar fashion to database/sql. The second is to use a pointer to a pointer. + + var foo pgtype.Text + var bar *string + err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&foo, &bar) + if err != nil { + return err + } + +When using nullable pgtype types as parameters for queries, one has to remember to explicitly set their Valid field to +true, otherwise the parameter's value will be NULL. + +JSON Support + +pgtype automatically marshals and unmarshals data from json and jsonb PostgreSQL types. + +Extending Existing PostgreSQL Type Support + +Generally, all Codecs will support interfaces that can be implemented to enable scanning and encoding. For example, +PointCodec can use any Go type that implements the PointScanner and PointValuer interfaces. So rather than use +pgtype.Point and application can directly use its own point type with pgtype as long as it implements those interfaces. + +See example_custom_type_test.go for an example of a custom type for the PostgreSQL point type. + +Sometimes pgx supports a PostgreSQL type such as numeric but the Go type is in an external package that does not have +pgx support such as github.com/shopspring/decimal. These types can be registered with pgtype with custom conversion +logic. See https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid for example +integrations. + +New PostgreSQL Type Support + +pgtype uses the PostgreSQL OID to determine how to encode or decode a value. pgtype supports array, composite, domain, +and enum types. However, any type created in PostgreSQL with CREATE TYPE will receive a new OID. This means that the OID +of each new PostgreSQL type must be registered for pgtype to handle values of that type with the correct Codec. + +The pgx.Conn LoadType method can return a *Type for array, composite, domain, and enum types by inspecting the database +metadata. This *Type can then be registered with Map.RegisterType. + +For example, the following function could be called after a connection is established: + + func RegisterDataTypes(ctx context.Context, conn *pgx.Conn) error { + dataTypeNames := []string{ + "foo", + "_foo", + "bar", + "_bar", + } + + for _, typeName := range dataTypeNames { + dataType, err := conn.LoadType(ctx, typeName) + if err != nil { + return err + } + conn.TypeMap().RegisterType(dataType) + } + + return nil + } + +A type cannot be registered unless all types it depends on are already registered. e.g. An array type cannot be +registered until its element type is registered. + +ArrayCodec implements support for arrays. If pgtype supports type T then it can easily support []T by registering an +ArrayCodec for the appropriate PostgreSQL OID. In addition, Array[T] type can support multi-dimensional arrays. + +CompositeCodec implements support for PostgreSQL composite types. Go structs can be scanned into if the public fields of +the struct are in the exact order and type of the PostgreSQL type or by implementing CompositeIndexScanner and +CompositeIndexGetter. + +Domain types are treated as their underlying type if the underlying type and the domain type are registered. + +PostgreSQL enums can usually be treated as text. However, EnumCodec implements support for interning strings which can +reduce memory usage. + +While pgtype will often still work with unregistered types it is highly recommended that all types be registered due to +an improvement in performance and the elimination of certain edge cases. + +If an entirely new PostgreSQL type (e.g. PostGIS types) is used then the application or a library can create a new +Codec. Then the OID / Codec mapping can be registered with Map.RegisterType. There is no difference between a Codec +defined and registered by the application and a Codec built in to pgtype. See any of the Codecs in pgtype for Codec +examples and for examples of type registration. + +Encoding Unknown Types + +pgtype works best when the OID of the PostgreSQL type is known. But in some cases such as using the simple protocol the +OID is unknown. In this case Map.RegisterDefaultPgType can be used to register an assumed OID for a particular Go type. + +Renamed Types + +If pgtype does not recognize a type and that type is a renamed simple type simple (e.g. type MyInt32 int32) pgtype acts +as if it is the underlying type. It currently cannot automatically detect the underlying type of renamed structs (eg.g. +type MyTime time.Time). + +Compatibility with database/sql + +pgtype also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer +interfaces. + +Encoding Typed Nils + +pgtype encodes untyped and typed nils (e.g. nil and []byte(nil)) to the SQL NULL value without going through the Codec +system. This means that Codecs and other encoding logic do not have to handle nil or *T(nil). + +However, database/sql compatibility requires Value to be called on T(nil) when T implements driver.Valuer. Therefore, +driver.Valuer values are only considered NULL when *T(nil) where driver.Valuer is implemented on T not on *T. See +https://github.com/golang/go/issues/8415 and +https://github.com/golang/go/commit/0ce1d79a6a771f7449ec493b993ed2a720917870. + +Child Records + +pgtype's support for arrays and composite records can be used to load records and their children in a single query. See +example_child_records_test.go for an example. + +Overview of Scanning Implementation + +The first step is to use the OID to lookup the correct Codec. The Map will call the Codec's PlanScan method to get a +plan for scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentime these Go types +are interfaces rather than explicit types. For example, PointCodec can use any Go type that implements the PointScanner +and PointValuer interfaces. + +If a Go value is not supported directly by a Codec then Map will try see if it is a sql.Scanner. If is then that +interface will be used to scan the value. Most sql.Scanners require the input to be in the text format (e.g. UUIDs and +numeric). However, pgx will typically have received the value in the binary format. In this case the binary value will be +parsed, reencoded as text, and then passed to the sql.Scanner. This may incur additional overhead for query results with +a large number of affected values. + +If a Go value is not supported directly by a Codec then Map will try wrapping it with additional logic and try again. +For example, Int8Codec does not support scanning into a renamed type (e.g. type myInt64 int64). But Map will detect that +myInt64 is a renamed type and create a plan that converts the value to the underlying int64 type and then passes that to +the Codec (see TryFindUnderlyingTypeScanPlan). + +These plan wrappers are contained in Map.TryWrapScanPlanFuncs. By default these contain shared logic to handle renamed +types, pointers to pointers, slices, composite types, etc. Additional plan wrappers can be added to seamlessly integrate +types that do not support pgx directly. For example, the before mentioned +https://github.com/jackc/pgx-shopspring-decimal package detects decimal.Decimal values, wraps them in something +implementing NumericScanner and passes that to the Codec. + +Map.Scan and Map.Encode are convenience methods that wrap Map.PlanScan and Map.PlanEncode. Determining how to scan or +encode a particular type may be a time consuming operation. Hence the planning and execution steps of a conversion are +internally separated. + +Reducing Compiled Binary Size + +pgx.QueryExecModeExec and pgx.QueryExecModeSimpleProtocol require the default PostgreSQL type to be registered for each +Go type used as a query parameter. By default pgx does this for all supported types and their array variants. If an +application does not use those query execution modes or manually registers the default PostgreSQL type for the types it +uses as query parameters it can use the build tag nopgxregisterdefaulttypes. This omits the default type registration +and reduces the compiled binary size by ~2MB. +*/ +package pgtype diff --git a/pgtype/enum_array.go b/pgtype/enum_array.go deleted file mode 100644 index 3a9480159..000000000 --- a/pgtype/enum_array.go +++ /dev/null @@ -1,212 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - - "github.com/pkg/errors" -) - -type EnumArray struct { - Elements []GenericText - Dimensions []ArrayDimension - Status Status -} - -func (dst *EnumArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = EnumArray{Status: Null} - return nil - } - - switch value := src.(type) { - - case []string: - if value == nil { - *dst = EnumArray{Status: Null} - } else if len(value) == 0 { - *dst = EnumArray{Status: Present} - } else { - elements := make([]GenericText, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = EnumArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to EnumArray", value) - } - - return nil -} - -func (dst *EnumArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *EnumArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = EnumArray{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []GenericText - - if len(uta.Elements) > 0 { - elements = make([]GenericText, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem GenericText - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = EnumArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (src *EnumArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *EnumArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *EnumArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/enum_array_test.go b/pgtype/enum_array_test.go deleted file mode 100644 index 9cc950af3..000000000 --- a/pgtype/enum_array_test.go +++ /dev/null @@ -1,150 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestEnumArrayTranscode(t *testing.T) { - setupConn := testutil.MustConnectPgx(t) - defer testutil.MustClose(t, setupConn) - - if _, err := setupConn.Exec("drop type if exists color"); err != nil { - t.Fatal(err) - } - if _, err := setupConn.Exec("create type color as enum ('red', 'green', 'blue')"); err != nil { - t.Fatal(err) - } - - testutil.TestSuccessfulTranscode(t, "color[]", []interface{}{ - &pgtype.EnumArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.EnumArray{ - Elements: []pgtype.GenericText{ - {String: "red", Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.EnumArray{Status: pgtype.Null}, - &pgtype.EnumArray{ - Elements: []pgtype.GenericText{ - {String: "red", Status: pgtype.Present}, - {String: "green", Status: pgtype.Present}, - {String: "blue", Status: pgtype.Present}, - {String: "red", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestEnumArrayArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.EnumArray - }{ - { - source: []string{"foo"}, - result: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]string)(nil)), - result: pgtype.EnumArray{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.EnumArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestEnumArrayArrayAssignTo(t *testing.T) { - var stringSlice []string - type _stringSlice []string - var namedStringSlice _stringSlice - - simpleTests := []struct { - src pgtype.EnumArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - expected: []string{"foo"}, - }, - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedStringSlice, - expected: _stringSlice{"bar"}, - }, - { - src: pgtype.EnumArray{Status: pgtype.Null}, - dst: &stringSlice, - expected: (([]string)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.EnumArray - dst interface{} - }{ - { - src: pgtype.EnumArray{ - Elements: []pgtype.GenericText{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/enum_codec.go b/pgtype/enum_codec.go new file mode 100644 index 000000000..5e787c1e2 --- /dev/null +++ b/pgtype/enum_codec.go @@ -0,0 +1,109 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" +) + +// EnumCodec is a codec that caches the strings it decodes. If the same string is read multiple times only one copy is +// allocated. These strings are only garbage collected when the EnumCodec is garbage collected. EnumCodec can be used +// for any text type not only enums, but it should only be used when there are a small number of possible values. +type EnumCodec struct { + membersMap map[string]string // map to quickly lookup member and reuse string instead of allocating +} + +func (EnumCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (EnumCodec) PreferredFormat() int16 { + return TextFormatCode +} + +func (EnumCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch value.(type) { + case string: + return encodePlanTextCodecString{} + case []byte: + return encodePlanTextCodecByteSlice{} + case TextValuer: + return encodePlanTextCodecTextValuer{} + } + } + + return nil +} + +func (c *EnumCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch target.(type) { + case *string: + return &scanPlanTextAnyToEnumString{codec: c} + case *[]byte: + return scanPlanAnyToNewByteSlice{} + case TextScanner: + return &scanPlanTextAnyToEnumTextScanner{codec: c} + } + } + + return nil +} + +func (c *EnumCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(m, oid, format, src) +} + +func (c *EnumCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + return c.lookupAndCacheString(src), nil +} + +// lookupAndCacheString looks for src in the members map. If it is not found it is added to the map. +func (c *EnumCodec) lookupAndCacheString(src []byte) string { + if c.membersMap == nil { + c.membersMap = make(map[string]string) + } + + if s, found := c.membersMap[string(src)]; found { + return s + } + + s := string(src) + c.membersMap[s] = s + return s +} + +type scanPlanTextAnyToEnumString struct { + codec *EnumCodec +} + +func (plan *scanPlanTextAnyToEnumString) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p := (dst).(*string) + *p = plan.codec.lookupAndCacheString(src) + + return nil +} + +type scanPlanTextAnyToEnumTextScanner struct { + codec *EnumCodec +} + +func (plan *scanPlanTextAnyToEnumTextScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TextScanner) + + if src == nil { + return scanner.ScanText(Text{}) + } + + return scanner.ScanText(Text{String: plan.codec.lookupAndCacheString(src), Valid: true}) +} diff --git a/pgtype/enum_codec_test.go b/pgtype/enum_codec_test.go new file mode 100644 index 000000000..96746438c --- /dev/null +++ b/pgtype/enum_codec_test.go @@ -0,0 +1,67 @@ +package pgtype_test + +import ( + "context" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" +) + +func TestEnumCodec(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, `drop type if exists enum_test; + +create type enum_test as enum ('foo', 'bar', 'baz');`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type enum_test") + + dt, err := conn.LoadType(ctx, "enum_test") + require.NoError(t, err) + + conn.TypeMap().RegisterType(dt) + + var s string + err = conn.QueryRow(ctx, `select 'foo'::enum_test`).Scan(&s) + require.NoError(t, err) + require.Equal(t, "foo", s) + + err = conn.QueryRow(ctx, `select $1::enum_test`, "bar").Scan(&s) + require.NoError(t, err) + require.Equal(t, "bar", s) + + err = conn.QueryRow(ctx, `select 'foo'::enum_test`).Scan(&s) + require.NoError(t, err) + require.Equal(t, "foo", s) + + err = conn.QueryRow(ctx, `select $1::enum_test`, "bar").Scan(&s) + require.NoError(t, err) + require.Equal(t, "bar", s) + + err = conn.QueryRow(ctx, `select 'baz'::enum_test`).Scan(&s) + require.NoError(t, err) + require.Equal(t, "baz", s) + }) +} + +func TestEnumCodecValues(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, `drop type if exists enum_test; + +create type enum_test as enum ('foo', 'bar', 'baz');`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type enum_test") + + dt, err := conn.LoadType(ctx, "enum_test") + require.NoError(t, err) + + conn.TypeMap().RegisterType(dt) + + rows, err := conn.Query(ctx, `select 'foo'::enum_test`) + require.NoError(t, err) + require.True(t, rows.Next()) + values, err := rows.Values() + require.NoError(t, err) + require.Equal(t, []any{"foo"}, values) + }) +} diff --git a/pgtype/example_child_records_test.go b/pgtype/example_child_records_test.go new file mode 100644 index 000000000..08d3c9e78 --- /dev/null +++ b/pgtype/example_child_records_test.go @@ -0,0 +1,103 @@ +package pgtype_test + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/jackc/pgx/v5" +) + +type Player struct { + Name string + Position string +} + +type Team struct { + Name string + Players []Player +} + +// This example uses a single query to return parent and child records. +func Example_childRecords() { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + // Skip test / example when running on CockroachDB. Since an example can't be skipped fake success instead. + fmt.Println(`Alpha + Adam: wing + Bill: halfback + Charlie: fullback +Beta + Don: halfback + Edgar: halfback + Frank: fullback`) + return + } + + // Setup example schema and data. + _, err = conn.Exec(ctx, ` +create temporary table teams ( + name text primary key +); + +create temporary table players ( + name text primary key, + team_name text, + position text +); + +insert into teams (name) values + ('Alpha'), + ('Beta'); + +insert into players (name, team_name, position) values + ('Adam', 'Alpha', 'wing'), + ('Bill', 'Alpha', 'halfback'), + ('Charlie', 'Alpha', 'fullback'), + ('Don', 'Beta', 'halfback'), + ('Edgar', 'Beta', 'halfback'), + ('Frank', 'Beta', 'fullback') +`) + if err != nil { + fmt.Printf("Unable to setup example schema and data: %v", err) + return + } + + rows, _ := conn.Query(ctx, ` +select t.name, + (select array_agg(row(p.name, position) order by p.name) from players p where p.team_name = t.name) +from teams t +order by t.name +`) + teams, err := pgx.CollectRows(rows, pgx.RowToStructByPos[Team]) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + for _, team := range teams { + fmt.Println(team.Name) + for _, player := range team.Players { + fmt.Printf(" %s: %s\n", player.Name, player.Position) + } + } + + // Output: + // Alpha + // Adam: wing + // Bill: halfback + // Charlie: fullback + // Beta + // Don: halfback + // Edgar: halfback + // Frank: fullback +} diff --git a/pgtype/example_custom_type_test.go b/pgtype/example_custom_type_test.go new file mode 100644 index 000000000..ceb9a0aab --- /dev/null +++ b/pgtype/example_custom_type_test.go @@ -0,0 +1,75 @@ +package pgtype_test + +import ( + "context" + "fmt" + "os" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" +) + +// Point represents a point that may be null. +type Point struct { + X, Y float32 // Coordinates of point + Valid bool +} + +func (p *Point) ScanPoint(v pgtype.Point) error { + *p = Point{ + X: float32(v.P.X), + Y: float32(v.P.Y), + Valid: v.Valid, + } + return nil +} + +func (p Point) PointValue() (pgtype.Point, error) { + return pgtype.Point{ + P: pgtype.Vec2{X: float64(p.X), Y: float64(p.Y)}, + Valid: true, + }, nil +} + +func (src *Point) String() string { + if !src.Valid { + return "null point" + } + + return fmt.Sprintf("%.1f, %.1f", src.X, src.Y) +} + +func Example_customType() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + defer conn.Close(context.Background()) + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + // Skip test / example when running on CockroachDB which doesn't support the point type. Since an example can't be + // skipped fake success instead. + fmt.Println("null point") + fmt.Println("1.5, 2.5") + return + } + + p := &Point{} + err = conn.QueryRow(context.Background(), "select null::point").Scan(p) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(p) + + err = conn.QueryRow(context.Background(), "select point(1.5,2.5)").Scan(p) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(p) + // Output: + // null point + // 1.5, 2.5 +} diff --git a/example_json_test.go b/pgtype/example_json_test.go similarity index 59% rename from example_json_test.go rename to pgtype/example_json_test.go index 09e27cff4..98fb675aa 100644 --- a/example_json_test.go +++ b/pgtype/example_json_test.go @@ -1,13 +1,15 @@ -package pgx_test +package pgtype_test import ( + "context" "fmt" + "os" - "github.com/jackc/pgx" + "github.com/jackc/pgx/v5" ) -func Example_JSON() { - conn, err := pgx.Connect(*defaultConnConfig) +func Example_json() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { fmt.Printf("Unable to establish connection: %v", err) return @@ -25,7 +27,7 @@ func Example_JSON() { var output person - err = conn.QueryRow("select $1::json", input).Scan(&output) + err = conn.QueryRow(context.Background(), "select $1::json", input).Scan(&output) if err != nil { fmt.Println(err) return diff --git a/pgtype/ext/satori-uuid/uuid.go b/pgtype/ext/satori-uuid/uuid.go deleted file mode 100644 index 78a900354..000000000 --- a/pgtype/ext/satori-uuid/uuid.go +++ /dev/null @@ -1,161 +0,0 @@ -package uuid - -import ( - "database/sql/driver" - - "github.com/pkg/errors" - - "github.com/jackc/pgx/pgtype" - uuid "github.com/satori/go.uuid" -) - -var errUndefined = errors.New("cannot encode status undefined") - -type UUID struct { - UUID uuid.UUID - Status pgtype.Status -} - -func (dst *UUID) Set(src interface{}) error { - switch value := src.(type) { - case uuid.UUID: - *dst = UUID{UUID: value, Status: pgtype.Present} - case [16]byte: - *dst = UUID{UUID: uuid.UUID(value), Status: pgtype.Present} - case []byte: - if len(value) != 16 { - return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) - } - *dst = UUID{Status: pgtype.Present} - copy(dst.UUID[:], value) - case string: - uuid, err := uuid.FromString(value) - if err != nil { - return err - } - *dst = UUID{UUID: uuid, Status: pgtype.Present} - default: - // If all else fails see if pgtype.UUID can handle it. If so, translate through that. - pgUUID := &pgtype.UUID{} - if err := pgUUID.Set(value); err != nil { - return errors.Errorf("cannot convert %v to UUID", value) - } - - *dst = UUID{UUID: uuid.UUID(pgUUID.Bytes), Status: pgUUID.Status} - } - - return nil -} - -func (dst *UUID) Get() interface{} { - switch dst.Status { - case pgtype.Present: - return dst.UUID - case pgtype.Null: - return nil - default: - return dst.Status - } -} - -func (src *UUID) AssignTo(dst interface{}) error { - switch src.Status { - case pgtype.Present: - switch v := dst.(type) { - case *uuid.UUID: - *v = src.UUID - case *[16]byte: - *v = [16]byte(src.UUID) - return nil - case *[]byte: - *v = make([]byte, 16) - copy(*v, src.UUID[:]) - return nil - case *string: - *v = src.UUID.String() - return nil - default: - if nextDst, retry := pgtype.GetAssignToDstType(v); retry { - return src.AssignTo(nextDst) - } - } - case pgtype.Null: - return pgtype.NullAssignTo(dst) - } - - return errors.Errorf("cannot assign %v into %T", src, dst) -} - -func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - *dst = UUID{Status: pgtype.Null} - return nil - } - - u, err := uuid.FromString(string(src)) - if err != nil { - return err - } - - *dst = UUID{UUID: u, Status: pgtype.Present} - return nil -} - -func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - *dst = UUID{Status: pgtype.Null} - return nil - } - - if len(src) != 16 { - return errors.Errorf("invalid length for UUID: %v", len(src)) - } - - *dst = UUID{Status: pgtype.Present} - copy(dst.UUID[:], src) - return nil -} - -func (src *UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case pgtype.Null: - return nil, nil - case pgtype.Undefined: - return nil, errUndefined - } - - return append(buf, src.UUID.String()...), nil -} - -func (src *UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case pgtype.Null: - return nil, nil - case pgtype.Undefined: - return nil, errUndefined - } - - return append(buf, src.UUID[:]...), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *UUID) Scan(src interface{}) error { - if src == nil { - *dst = UUID{Status: pgtype.Null} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - return dst.DecodeText(nil, src) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *UUID) Value() (driver.Value, error) { - return pgtype.EncodeValueText(src) -} diff --git a/pgtype/ext/satori-uuid/uuid_test.go b/pgtype/ext/satori-uuid/uuid_test.go deleted file mode 100644 index 02ebb770e..000000000 --- a/pgtype/ext/satori-uuid/uuid_test.go +++ /dev/null @@ -1,97 +0,0 @@ -package uuid_test - -import ( - "bytes" - "testing" - - "github.com/jackc/pgx/pgtype" - satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestUUIDTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - &satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - &satori.UUID{Status: pgtype.Null}, - }) -} - -func TestUUIDSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result satori.UUID - }{ - { - source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - }, - { - source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - }, - { - source: "00010203-0405-0607-0809-0a0b0c0d0e0f", - result: satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - }, - } - - for i, tt := range successfulTests { - var r satori.UUID - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestUUIDAssignTo(t *testing.T) { - { - src := satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} - var dst [16]byte - expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} - var dst []byte - expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if bytes.Compare(dst, expected) != 0 { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} - var dst string - expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - -} diff --git a/pgtype/ext/shopspring-numeric/decimal.go b/pgtype/ext/shopspring-numeric/decimal.go deleted file mode 100644 index 507a93dc7..000000000 --- a/pgtype/ext/shopspring-numeric/decimal.go +++ /dev/null @@ -1,317 +0,0 @@ -package numeric - -import ( - "database/sql/driver" - "strconv" - - "github.com/pkg/errors" - - "github.com/jackc/pgx/pgtype" - "github.com/shopspring/decimal" -) - -var errUndefined = errors.New("cannot encode status undefined") - -type Numeric struct { - Decimal decimal.Decimal - Status pgtype.Status -} - -func (dst *Numeric) Set(src interface{}) error { - if src == nil { - *dst = Numeric{Status: pgtype.Null} - return nil - } - - switch value := src.(type) { - case decimal.Decimal: - *dst = Numeric{Decimal: value, Status: pgtype.Present} - case float32: - *dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present} - case float64: - *dst = Numeric{Decimal: decimal.NewFromFloat(value), Status: pgtype.Present} - case int8: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} - case uint8: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} - case int16: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} - case uint16: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} - case int32: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} - case uint32: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} - case int64: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} - case uint64: - // uint64 could be greater than int64 so convert to string then to decimal - dec, err := decimal.NewFromString(strconv.FormatUint(value, 10)) - if err != nil { - return err - } - *dst = Numeric{Decimal: dec, Status: pgtype.Present} - case int: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} - case uint: - // uint could be greater than int64 so convert to string then to decimal - dec, err := decimal.NewFromString(strconv.FormatUint(uint64(value), 10)) - if err != nil { - return err - } - *dst = Numeric{Decimal: dec, Status: pgtype.Present} - case string: - dec, err := decimal.NewFromString(value) - if err != nil { - return err - } - *dst = Numeric{Decimal: dec, Status: pgtype.Present} - default: - // If all else fails see if pgtype.Numeric can handle it. If so, translate through that. - num := &pgtype.Numeric{} - if err := num.Set(value); err != nil { - return errors.Errorf("cannot convert %v to Numeric", value) - } - - buf, err := num.EncodeText(nil, nil) - if err != nil { - return errors.Errorf("cannot convert %v to Numeric", value) - } - - dec, err := decimal.NewFromString(string(buf)) - if err != nil { - return errors.Errorf("cannot convert %v to Numeric", value) - } - *dst = Numeric{Decimal: dec, Status: pgtype.Present} - } - - return nil -} - -func (dst *Numeric) Get() interface{} { - switch dst.Status { - case pgtype.Present: - return dst.Decimal - case pgtype.Null: - return nil - default: - return dst.Status - } -} - -func (src *Numeric) AssignTo(dst interface{}) error { - switch src.Status { - case pgtype.Present: - switch v := dst.(type) { - case *decimal.Decimal: - *v = src.Decimal - case *float32: - f, _ := src.Decimal.Float64() - *v = float32(f) - case *float64: - f, _ := src.Decimal.Float64() - *v = f - case *int: - if src.Decimal.Exponent() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, strconv.IntSize) - if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int(n) - case *int8: - if src.Decimal.Exponent() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, 8) - if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int8(n) - case *int16: - if src.Decimal.Exponent() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, 16) - if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int16(n) - case *int32: - if src.Decimal.Exponent() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, 32) - if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int32(n) - case *int64: - if src.Decimal.Exponent() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, 64) - if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int64(n) - case *uint: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, strconv.IntSize) - if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint(n) - case *uint8: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, 8) - if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint8(n) - case *uint16: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, 16) - if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint16(n) - case *uint32: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, 32) - if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint32(n) - case *uint64: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, 64) - if err != nil { - return errors.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint64(n) - default: - if nextDst, retry := pgtype.GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case pgtype.Null: - return pgtype.NullAssignTo(dst) - } - - return nil -} - -func (dst *Numeric) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - *dst = Numeric{Status: pgtype.Null} - return nil - } - - dec, err := decimal.NewFromString(string(src)) - if err != nil { - return err - } - - *dst = Numeric{Decimal: dec, Status: pgtype.Present} - return nil -} - -func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - *dst = Numeric{Status: pgtype.Null} - return nil - } - - // For now at least, implement this in terms of pgtype.Numeric - - num := &pgtype.Numeric{} - if err := num.DecodeBinary(ci, src); err != nil { - return err - } - - buf, err := num.EncodeText(ci, nil) - if err != nil { - return err - } - - dec, err := decimal.NewFromString(string(buf)) - if err != nil { - return err - } - - *dst = Numeric{Decimal: dec, Status: pgtype.Present} - - return nil -} - -func (src *Numeric) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case pgtype.Null: - return nil, nil - case pgtype.Undefined: - return nil, errUndefined - } - - return append(buf, src.Decimal.String()...), nil -} - -func (src *Numeric) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case pgtype.Null: - return nil, nil - case pgtype.Undefined: - return nil, errUndefined - } - - // For now at least, implement this in terms of pgtype.Numeric - num := &pgtype.Numeric{} - if err := num.DecodeText(ci, []byte(src.Decimal.String())); err != nil { - return nil, err - } - - return num.EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *Numeric) Scan(src interface{}) error { - if src == nil { - *dst = Numeric{Status: pgtype.Null} - return nil - } - - switch src := src.(type) { - case float64: - *dst = Numeric{Decimal: decimal.NewFromFloat(src), Status: pgtype.Present} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - return dst.DecodeText(nil, src) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Numeric) Value() (driver.Value, error) { - switch src.Status { - case pgtype.Present: - return src.Decimal.Value() - case pgtype.Null: - return nil, nil - default: - return nil, errUndefined - } -} diff --git a/pgtype/ext/shopspring-numeric/decimal_test.go b/pgtype/ext/shopspring-numeric/decimal_test.go deleted file mode 100644 index 79121ef3b..000000000 --- a/pgtype/ext/shopspring-numeric/decimal_test.go +++ /dev/null @@ -1,286 +0,0 @@ -package numeric_test - -import ( - "fmt" - "math/big" - "math/rand" - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - shopspring "github.com/jackc/pgx/pgtype/ext/shopspring-numeric" - "github.com/jackc/pgx/pgtype/testutil" - "github.com/shopspring/decimal" -) - -func mustParseDecimal(t *testing.T, src string) decimal.Decimal { - dec, err := decimal.NewFromString(src) - if err != nil { - t.Fatal(err) - } - return dec -} - -func TestNumericNormalize(t *testing.T) { - testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ - { - SQL: "select '0'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, - }, - { - SQL: "select '1'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, - }, - { - SQL: "select '10.00'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Status: pgtype.Present}, - }, - { - SQL: "select '1e-3'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, - }, - { - SQL: "select '-1'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, - }, - { - SQL: "select '10000'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Status: pgtype.Present}, - }, - { - SQL: "select '3.14'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, - }, - { - SQL: "select '1.1'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Status: pgtype.Present}, - }, - { - SQL: "select '100010001'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Status: pgtype.Present}, - }, - { - SQL: "select '100010001.0001'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Status: pgtype.Present}, - }, - { - SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", - Value: &shopspring.Numeric{ - Decimal: mustParseDecimal(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"), - Status: pgtype.Present, - }, - }, - { - SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", - Value: &shopspring.Numeric{ - Decimal: mustParseDecimal(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), - Status: pgtype.Present, - }, - }, - { - SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", - Value: &shopspring.Numeric{ - Decimal: mustParseDecimal(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"), - Status: pgtype.Present, - }, - }, - }, func(aa, bb interface{}) bool { - a := aa.(shopspring.Numeric) - b := bb.(shopspring.Numeric) - - return a.Status == b.Status && a.Decimal.Equal(b.Decimal) - }) -} - -func TestNumericTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "100000"), Status: pgtype.Present}, - - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.1"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.01"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0001"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00001"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000001"), Status: pgtype.Present}, - - &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000123"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000000123"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0000000123"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000123"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001234567890123456789"), Status: pgtype.Present}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "4309132809320932980457137401234890237489238912983572189348951289375283573984571892758234678903467889512893489128589347891272139.8489235871258912789347891235879148795891238915678189467128957812395781238579189025891238901583915890128973578957912385798125789012378905238905471598123758923478294374327894237892234"), Status: pgtype.Present}, - &shopspring.Numeric{Status: pgtype.Null}, - }, func(aa, bb interface{}) bool { - a := aa.(shopspring.Numeric) - b := bb.(shopspring.Numeric) - - return a.Status == b.Status && a.Decimal.Equal(b.Decimal) - }) - -} - -func TestNumericTranscodeFuzz(t *testing.T) { - r := rand.New(rand.NewSource(0)) - max := &big.Int{} - max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) - - values := make([]interface{}, 0, 2000) - for i := 0; i < 500; i++ { - num := fmt.Sprintf("%s.%s", (&big.Int{}).Rand(r, max).String(), (&big.Int{}).Rand(r, max).String()) - negNum := "-" + num - values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, num), Status: pgtype.Present}) - values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, negNum), Status: pgtype.Present}) - } - - testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, - func(aa, bb interface{}) bool { - a := aa.(shopspring.Numeric) - b := bb.(shopspring.Numeric) - - return a.Status == b.Status && a.Decimal.Equal(b.Decimal) - }) -} - -func TestNumericSet(t *testing.T) { - type _int8 int8 - - successfulTests := []struct { - source interface{} - result *shopspring.Numeric - }{ - {source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: int16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: int32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: int64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: int8(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, - {source: int16(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, - {source: int32(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, - {source: int64(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, - {source: uint8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: uint16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: uint32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: uint64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: "1", result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: _int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, - {source: float64(1000), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1000"), Status: pgtype.Present}}, - {source: float64(1234), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1234"), Status: pgtype.Present}}, - {source: float64(12345678900), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345678900"), Status: pgtype.Present}}, - {source: float64(12345.678901), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345.678901"), Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - r := &shopspring.Numeric{} - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !(r.Status == tt.result.Status && r.Decimal.Equal(tt.result.Decimal)) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestNumericAssignTo(t *testing.T) { - type _int8 int8 - - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - var f32 float32 - var f64 float64 - var pf32 *float32 - var pf64 *float64 - - simpleTests := []struct { - src *shopspring.Numeric - dst interface{} - expected interface{} - }{ - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f32, expected: float32(4.2)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f64, expected: float64(4.2)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &i64, expected: int64(42000)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src *shopspring.Numeric - dst interface{} - expected interface{} - }{ - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src *shopspring.Numeric - dst interface{} - }{ - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "150"), Status: pgtype.Present}, dst: &i8}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "40000"), Status: pgtype.Present}, dst: &i16}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui8}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui16}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui32}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui64}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui}, - {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &i32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/float4.go b/pgtype/float4.go index 2207594ac..241a25add 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -3,195 +3,321 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" + "fmt" "math" "strconv" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) type Float4 struct { - Float float32 - Status Status + Float32 float32 + Valid bool } -func (dst *Float4) Set(src interface{}) error { +// ScanFloat64 implements the [Float64Scanner] interface. +func (f *Float4) ScanFloat64(n Float8) error { + *f = Float4{Float32: float32(n.Float64), Valid: n.Valid} + return nil +} + +// Float64Value implements the [Float64Valuer] interface. +func (f Float4) Float64Value() (Float8, error) { + return Float8{Float64: float64(f.Float32), Valid: f.Valid}, nil +} + +// ScanInt64 implements the [Int64Scanner] interface. +func (f *Float4) ScanInt64(n Int8) error { + *f = Float4{Float32: float32(n.Int64), Valid: n.Valid} + return nil +} + +// Int64Value implements the [Int64Valuer] interface. +func (f Float4) Int64Value() (Int8, error) { + return Int8{Int64: int64(f.Float32), Valid: f.Valid}, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (f *Float4) Scan(src any) error { if src == nil { - *dst = Float4{Status: Null} + *f = Float4{} return nil } - switch value := src.(type) { - case float32: - *dst = Float4{Float: value, Status: Present} + switch src := src.(type) { case float64: - *dst = Float4{Float: float32(value), Status: Present} - case int8: - *dst = Float4{Float: float32(value), Status: Present} - case uint8: - *dst = Float4{Float: float32(value), Status: Present} - case int16: - *dst = Float4{Float: float32(value), Status: Present} - case uint16: - *dst = Float4{Float: float32(value), Status: Present} - case int32: - f32 := float32(value) - if int32(f32) == value { - *dst = Float4{Float: f32, Status: Present} - } else { - return errors.Errorf("%v cannot be exactly represented as float32", value) - } - case uint32: - f32 := float32(value) - if uint32(f32) == value { - *dst = Float4{Float: f32, Status: Present} - } else { - return errors.Errorf("%v cannot be exactly represented as float32", value) - } - case int64: - f32 := float32(value) - if int64(f32) == value { - *dst = Float4{Float: f32, Status: Present} - } else { - return errors.Errorf("%v cannot be exactly represented as float32", value) - } - case uint64: - f32 := float32(value) - if uint64(f32) == value { - *dst = Float4{Float: f32, Status: Present} - } else { - return errors.Errorf("%v cannot be exactly represented as float32", value) - } - case int: - f32 := float32(value) - if int(f32) == value { - *dst = Float4{Float: f32, Status: Present} - } else { - return errors.Errorf("%v cannot be exactly represented as float32", value) - } - case uint: - f32 := float32(value) - if uint(f32) == value { - *dst = Float4{Float: f32, Status: Present} - } else { - return errors.Errorf("%v cannot be exactly represented as float32", value) - } + *f = Float4{Float32: float32(src), Valid: true} + return nil case string: - num, err := strconv.ParseFloat(value, 32) + n, err := strconv.ParseFloat(string(src), 32) if err != nil { return err } - *dst = Float4{Float: float32(num), Status: Present} - default: - if originalSrc, ok := underlyingNumberType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Float8", value) + *f = Float4{Float32: float32(n), Valid: true} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (f Float4) Value() (driver.Value, error) { + if !f.Valid { + return nil, nil + } + return float64(f.Float32), nil +} + +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (f Float4) MarshalJSON() ([]byte, error) { + if !f.Valid { + return []byte("null"), nil + } + return json.Marshal(f.Float32) +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (f *Float4) UnmarshalJSON(b []byte) error { + var n *float32 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *f = Float4{} + } else { + *f = Float4{Float32: *n, Valid: true} } return nil } -func (dst *Float4) Get() interface{} { - switch dst.Status { - case Present: - return dst.Float - case Null: - return nil - default: - return dst.Status +type Float4Codec struct{} + +func (Float4Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Float4Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Float4Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case float32: + return encodePlanFloat4CodecBinaryFloat32{} + case Float64Valuer: + return encodePlanFloat4CodecBinaryFloat64Valuer{} + case Int64Valuer: + return encodePlanFloat4CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case float32: + return encodePlanTextFloat32{} + case Float64Valuer: + return encodePlanTextFloat64Valuer{} + case Int64Valuer: + return encodePlanTextInt64Valuer{} + } } + + return nil } -func (src *Float4) AssignTo(dst interface{}) error { - return float64AssignTo(float64(src.Float), src.Status, dst) +type encodePlanFloat4CodecBinaryFloat32 struct{} + +func (encodePlanFloat4CodecBinaryFloat32) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(float32) + return pgio.AppendUint32(buf, math.Float32bits(n)), nil } -func (dst *Float4) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Float4{Status: Null} - return nil +type encodePlanTextFloat32 struct{} + +func (encodePlanTextFloat32) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(float32) + return append(buf, strconv.FormatFloat(float64(n), 'f', -1, 32)...), nil +} + +type encodePlanFloat4CodecBinaryFloat64Valuer struct{} + +func (encodePlanFloat4CodecBinaryFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Float64Valuer).Float64Value() + if err != nil { + return nil, err } - n, err := strconv.ParseFloat(string(src), 32) + if !n.Valid { + return nil, nil + } + + return pgio.AppendUint32(buf, math.Float32bits(float32(n.Float64))), nil +} + +type encodePlanFloat4CodecBinaryInt64Valuer struct{} + +func (encodePlanFloat4CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() if err != nil { - return err + return nil, err + } + + if !n.Valid { + return nil, nil + } + + f := float32(n.Int64) + return pgio.AppendUint32(buf, math.Float32bits(f)), nil +} + +func (Float4Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case *float32: + return scanPlanBinaryFloat4ToFloat32{} + case Float64Scanner: + return scanPlanBinaryFloat4ToFloat64Scanner{} + case Int64Scanner: + return scanPlanBinaryFloat4ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryFloat4ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *float32: + return scanPlanTextAnyToFloat32{} + case Float64Scanner: + return scanPlanTextAnyToFloat64Scanner{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } } - *dst = Float4{Float: float32(n), Status: Present} return nil } -func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { +type scanPlanBinaryFloat4ToFloat32 struct{} + +func (scanPlanBinaryFloat4ToFloat32) Scan(src []byte, dst any) error { if src == nil { - *dst = Float4{Status: Null} - return nil + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 4 { - return errors.Errorf("invalid length for float4: %v", len(src)) + return fmt.Errorf("invalid length for float4: %v", len(src)) } n := int32(binary.BigEndian.Uint32(src)) + f := (dst).(*float32) + *f = math.Float32frombits(uint32(n)) - *dst = Float4{Float: math.Float32frombits(uint32(n)), Status: Present} return nil } -func (src *Float4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +type scanPlanBinaryFloat4ToFloat64Scanner struct{} + +func (scanPlanBinaryFloat4ToFloat64Scanner) Scan(src []byte, dst any) error { + s := (dst).(Float64Scanner) + + if src == nil { + return s.ScanFloat64(Float8{}) } - buf = append(buf, strconv.FormatFloat(float64(src.Float), 'f', -1, 32)...) - return buf, nil + if len(src) != 4 { + return fmt.Errorf("invalid length for float4: %v", len(src)) + } + + n := int32(binary.BigEndian.Uint32(src)) + return s.ScanFloat64(Float8{Float64: float64(math.Float32frombits(uint32(n))), Valid: true}) } -func (src *Float4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +type scanPlanBinaryFloat4ToInt64Scanner struct{} + +func (scanPlanBinaryFloat4ToInt64Scanner) Scan(src []byte, dst any) error { + s := (dst).(Int64Scanner) + + if src == nil { + return s.ScanInt64(Int8{}) } - buf = pgio.AppendUint32(buf, math.Float32bits(src.Float)) - return buf, nil + if len(src) != 4 { + return fmt.Errorf("invalid length for float4: %v", len(src)) + } + + ui32 := int32(binary.BigEndian.Uint32(src)) + f32 := math.Float32frombits(uint32(ui32)) + i64 := int64(f32) + if f32 != float32(i64) { + return fmt.Errorf("cannot losslessly convert %v to int64", f32) + } + + return s.ScanInt64(Int8{Int64: i64, Valid: true}) } -// Scan implements the database/sql Scanner interface. -func (dst *Float4) Scan(src interface{}) error { +type scanPlanBinaryFloat4ToTextScanner struct{} + +func (scanPlanBinaryFloat4ToTextScanner) Scan(src []byte, dst any) error { + s := (dst).(TextScanner) + if src == nil { - *dst = Float4{Status: Null} - return nil + return s.ScanText(Text{}) } - switch src := src.(type) { - case float64: - *dst = Float4{Float: float32(src), Status: Present} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + if len(src) != 4 { + return fmt.Errorf("invalid length for float4: %v", len(src)) + } + + ui32 := int32(binary.BigEndian.Uint32(src)) + f32 := math.Float32frombits(uint32(ui32)) + + return s.ScanText(Text{String: strconv.FormatFloat(float64(f32), 'f', -1, 32), Valid: true}) +} + +type scanPlanTextAnyToFloat32 struct{} + +func (scanPlanTextAnyToFloat32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + n, err := strconv.ParseFloat(string(src), 32) + if err != nil { + return err } - return errors.Errorf("cannot scan %T", src) + f := (dst).(*float32) + *f = float32(n) + + return nil } -// Value implements the database/sql/driver Valuer interface. -func (src *Float4) Value() (driver.Value, error) { - switch src.Status { - case Present: - return float64(src.Float), nil - case Null: +func (c Float4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { return nil, nil - default: - return nil, errUndefined } + + var n float32 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return float64(n), nil +} + +func (c Float4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n float32 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil } diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go deleted file mode 100644 index 02c28caab..000000000 --- a/pgtype/float4_array.go +++ /dev/null @@ -1,300 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type Float4Array struct { - Elements []Float4 - Dimensions []ArrayDimension - Status Status -} - -func (dst *Float4Array) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Float4Array{Status: Null} - return nil - } - - switch value := src.(type) { - - case []float32: - if value == nil { - *dst = Float4Array{Status: Null} - } else if len(value) == 0 { - *dst = Float4Array{Status: Present} - } else { - elements := make([]Float4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Float4Array", value) - } - - return nil -} - -func (dst *Float4Array) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Float4Array) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]float32: - *v = make([]float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Float4Array{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Float4 - - if len(uta.Elements) > 0 { - elements = make([]Float4, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Float4 - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = Float4Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Float4Array{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = Float4Array{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Float4, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = Float4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("float4"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "float4") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Float4Array) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Float4Array) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/float4_array_test.go b/pgtype/float4_array_test.go deleted file mode 100644 index 4d6511b4e..000000000 --- a/pgtype/float4_array_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestFloat4ArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "float4[]", []interface{}{ - &pgtype.Float4Array{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Float4Array{Status: pgtype.Null}, - &pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Float: 6, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Float4Array{ - Elements: []pgtype.Float4{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestFloat4ArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Float4Array - }{ - { - source: []float32{1}, - result: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]float32)(nil)), - result: pgtype.Float4Array{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Float4Array - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestFloat4ArrayAssignTo(t *testing.T) { - var float32Slice []float32 - var namedFloat32Slice _float32Slice - - simpleTests := []struct { - src pgtype.Float4Array - dst interface{} - expected interface{} - }{ - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1.23, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &float32Slice, - expected: []float32{1.23}, - }, - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Float: 1.23, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedFloat32Slice, - expected: _float32Slice{1.23}, - }, - { - src: pgtype.Float4Array{Status: pgtype.Null}, - dst: &float32Slice, - expected: (([]float32)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Float4Array - dst interface{} - }{ - { - src: pgtype.Float4Array{ - Elements: []pgtype.Float4{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &float32Slice, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/float4_test.go b/pgtype/float4_test.go index 2ed8d05d2..bc74921cf 100644 --- a/pgtype/float4_test.go +++ b/pgtype/float4_test.go @@ -1,149 +1,64 @@ package pgtype_test import ( - "reflect" + "context" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" ) -func TestFloat4Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "float4", []interface{}{ - &pgtype.Float4{Float: -1, Status: pgtype.Present}, - &pgtype.Float4{Float: 0, Status: pgtype.Present}, - &pgtype.Float4{Float: 0.00001, Status: pgtype.Present}, - &pgtype.Float4{Float: 1, Status: pgtype.Present}, - &pgtype.Float4{Float: 9999.99, Status: pgtype.Present}, - &pgtype.Float4{Float: 0, Status: pgtype.Null}, +func TestFloat4Codec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "float4", []pgxtest.ValueRoundTripTest{ + {pgtype.Float4{Float32: -1, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float32: -1, Valid: true})}, + {pgtype.Float4{Float32: 0, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float32: 0, Valid: true})}, + {pgtype.Float4{Float32: 1, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float32: 1, Valid: true})}, + {float32(0.00001), new(float32), isExpectedEq(float32(0.00001))}, + {float32(9999.99), new(float32), isExpectedEq(float32(9999.99))}, + {pgtype.Float4{}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{})}, + {int64(1), new(int64), isExpectedEq(int64(1))}, + {"1.23", new(string), isExpectedEq("1.23")}, + {nil, new(*float32), isExpectedEq((*float32)(nil))}, }) } -func TestFloat4Set(t *testing.T) { +func TestFloat4MarshalJSON(t *testing.T) { successfulTests := []struct { - source interface{} - result pgtype.Float4 + source pgtype.Float4 + result string }{ - {source: float32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: float64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: int8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: pgtype.Float4{Float32: 0}, result: "null"}, + {source: pgtype.Float4{Float32: 1.23, Valid: true}, result: "1.23"}, } - for i, tt := range successfulTests { - var r pgtype.Float4 - err := r.Set(tt.source) + r, err := tt.source.MarshalJSON() if err != nil { t.Errorf("%d: %v", i, err) } - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) } } } -func TestFloat4AssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - var f32 float32 - var f64 float64 - var pf32 *float32 - var pf64 *float64 - - simpleTests := []struct { - src pgtype.Float4 - dst interface{} - expected interface{} - }{ - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &f32, expected: float32(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &f64, expected: float64(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Float4 - dst interface{} - expected interface{} +func TestFloat4UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Float4 }{ - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, - {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + {source: "null", result: pgtype.Float4{Float32: 0}}, + {source: "1.23", result: pgtype.Float4{Float32: 1.23, Valid: true}}, } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) + for i, tt := range successfulTests { + var r pgtype.Float4 + err := r.UnmarshalJSON([]byte(tt.source)) if err != nil { t.Errorf("%d: %v", i, err) } - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Float4 - dst interface{} - }{ - {src: pgtype.Float4{Float: 150, Status: pgtype.Present}, dst: &i8}, - {src: pgtype.Float4{Float: 40000, Status: pgtype.Present}, dst: &i16}, - {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &i32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) } } } diff --git a/pgtype/float8.go b/pgtype/float8.go index dd34f541f..54d6781ec 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -3,185 +3,367 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" + "fmt" "math" "strconv" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) +type Float64Scanner interface { + ScanFloat64(Float8) error +} + +type Float64Valuer interface { + Float64Value() (Float8, error) +} + type Float8 struct { - Float float64 - Status Status + Float64 float64 + Valid bool +} + +// ScanFloat64 implements the [Float64Scanner] interface. +func (f *Float8) ScanFloat64(n Float8) error { + *f = n + return nil } -func (dst *Float8) Set(src interface{}) error { +// Float64Value implements the [Float64Valuer] interface. +func (f Float8) Float64Value() (Float8, error) { + return f, nil +} + +// ScanInt64 implements the [Int64Scanner] interface. +func (f *Float8) ScanInt64(n Int8) error { + *f = Float8{Float64: float64(n.Int64), Valid: n.Valid} + return nil +} + +// Int64Value implements the [Int64Valuer] interface. +func (f Float8) Int64Value() (Int8, error) { + return Int8{Int64: int64(f.Float64), Valid: f.Valid}, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (f *Float8) Scan(src any) error { if src == nil { - *dst = Float8{Status: Null} + *f = Float8{} return nil } - switch value := src.(type) { - case float32: - *dst = Float8{Float: float64(value), Status: Present} + switch src := src.(type) { case float64: - *dst = Float8{Float: value, Status: Present} - case int8: - *dst = Float8{Float: float64(value), Status: Present} - case uint8: - *dst = Float8{Float: float64(value), Status: Present} - case int16: - *dst = Float8{Float: float64(value), Status: Present} - case uint16: - *dst = Float8{Float: float64(value), Status: Present} - case int32: - *dst = Float8{Float: float64(value), Status: Present} - case uint32: - *dst = Float8{Float: float64(value), Status: Present} - case int64: - f64 := float64(value) - if int64(f64) == value { - *dst = Float8{Float: f64, Status: Present} - } else { - return errors.Errorf("%v cannot be exactly represented as float64", value) - } - case uint64: - f64 := float64(value) - if uint64(f64) == value { - *dst = Float8{Float: f64, Status: Present} - } else { - return errors.Errorf("%v cannot be exactly represented as float64", value) - } - case int: - f64 := float64(value) - if int(f64) == value { - *dst = Float8{Float: f64, Status: Present} - } else { - return errors.Errorf("%v cannot be exactly represented as float64", value) - } - case uint: - f64 := float64(value) - if uint(f64) == value { - *dst = Float8{Float: f64, Status: Present} - } else { - return errors.Errorf("%v cannot be exactly represented as float64", value) - } + *f = Float8{Float64: src, Valid: true} + return nil case string: - num, err := strconv.ParseFloat(value, 64) + n, err := strconv.ParseFloat(string(src), 64) if err != nil { return err } - *dst = Float8{Float: float64(num), Status: Present} - default: - if originalSrc, ok := underlyingNumberType(src); ok { - return dst.Set(originalSrc) + *f = Float8{Float64: n, Valid: true} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (f Float8) Value() (driver.Value, error) { + if !f.Valid { + return nil, nil + } + return f.Float64, nil +} + +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (f Float8) MarshalJSON() ([]byte, error) { + if !f.Valid { + return []byte("null"), nil + } + return json.Marshal(f.Float64) +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (f *Float8) UnmarshalJSON(b []byte) error { + var n *float64 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *f = Float8{} + } else { + *f = Float8{Float64: *n, Valid: true} + } + + return nil +} + +type Float8Codec struct{} + +func (Float8Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Float8Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Float8Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case float64: + return encodePlanFloat8CodecBinaryFloat64{} + case Float64Valuer: + return encodePlanFloat8CodecBinaryFloat64Valuer{} + case Int64Valuer: + return encodePlanFloat8CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case float64: + return encodePlanTextFloat64{} + case Float64Valuer: + return encodePlanTextFloat64Valuer{} + case Int64Valuer: + return encodePlanTextInt64Valuer{} } - return errors.Errorf("cannot convert %v to Float8", value) } return nil } -func (dst *Float8) Get() interface{} { - switch dst.Status { - case Present: - return dst.Float - case Null: - return nil - default: - return dst.Status +type encodePlanFloat8CodecBinaryFloat64 struct{} + +func (encodePlanFloat8CodecBinaryFloat64) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(float64) + return pgio.AppendUint64(buf, math.Float64bits(n)), nil +} + +type encodePlanTextFloat64 struct{} + +func (encodePlanTextFloat64) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(float64) + return append(buf, strconv.FormatFloat(n, 'f', -1, 64)...), nil +} + +type encodePlanFloat8CodecBinaryFloat64Valuer struct{} + +func (encodePlanFloat8CodecBinaryFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Float64Valuer).Float64Value() + if err != nil { + return nil, err } + + if !n.Valid { + return nil, nil + } + + return pgio.AppendUint64(buf, math.Float64bits(n.Float64)), nil } -func (src *Float8) AssignTo(dst interface{}) error { - return float64AssignTo(src.Float, src.Status, dst) +type encodePlanTextFloat64Valuer struct{} + +func (encodePlanTextFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Float64Valuer).Float64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + return append(buf, strconv.FormatFloat(n.Float64, 'f', -1, 64)...), nil } -func (dst *Float8) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Float8{Status: Null} - return nil +type encodePlanFloat8CodecBinaryInt64Valuer struct{} + +func (encodePlanFloat8CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err } - n, err := strconv.ParseFloat(string(src), 64) + if !n.Valid { + return nil, nil + } + + f := float64(n.Int64) + return pgio.AppendUint64(buf, math.Float64bits(f)), nil +} + +type encodePlanTextInt64Valuer struct{} + +func (encodePlanTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() if err != nil { - return err + return nil, err + } + + if !n.Valid { + return nil, nil + } + + return append(buf, strconv.FormatInt(n.Int64, 10)...), nil +} + +func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case *float64: + return scanPlanBinaryFloat8ToFloat64{} + case Float64Scanner: + return scanPlanBinaryFloat8ToFloat64Scanner{} + case Int64Scanner: + return scanPlanBinaryFloat8ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryFloat8ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *float64: + return scanPlanTextAnyToFloat64{} + case Float64Scanner: + return scanPlanTextAnyToFloat64Scanner{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } } - *dst = Float8{Float: n, Status: Present} return nil } -func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { +type scanPlanBinaryFloat8ToFloat64 struct{} + +func (scanPlanBinaryFloat8ToFloat64) Scan(src []byte, dst any) error { if src == nil { - *dst = Float8{Status: Null} - return nil + return fmt.Errorf("cannot scan NULL into %T", dst) } if len(src) != 8 { - return errors.Errorf("invalid length for float4: %v", len(src)) + return fmt.Errorf("invalid length for float8: %v", len(src)) } n := int64(binary.BigEndian.Uint64(src)) + f := (dst).(*float64) + *f = math.Float64frombits(uint64(n)) - *dst = Float8{Float: math.Float64frombits(uint64(n)), Status: Present} return nil } -func (src *Float8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +type scanPlanBinaryFloat8ToFloat64Scanner struct{} + +func (scanPlanBinaryFloat8ToFloat64Scanner) Scan(src []byte, dst any) error { + s := (dst).(Float64Scanner) + + if src == nil { + return s.ScanFloat64(Float8{}) } - buf = append(buf, strconv.FormatFloat(float64(src.Float), 'f', -1, 64)...) - return buf, nil + if len(src) != 8 { + return fmt.Errorf("invalid length for float8: %v", len(src)) + } + + n := int64(binary.BigEndian.Uint64(src)) + return s.ScanFloat64(Float8{Float64: math.Float64frombits(uint64(n)), Valid: true}) } -func (src *Float8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +type scanPlanBinaryFloat8ToInt64Scanner struct{} + +func (scanPlanBinaryFloat8ToInt64Scanner) Scan(src []byte, dst any) error { + s := (dst).(Int64Scanner) + + if src == nil { + return s.ScanInt64(Int8{}) } - buf = pgio.AppendUint64(buf, math.Float64bits(src.Float)) - return buf, nil + if len(src) != 8 { + return fmt.Errorf("invalid length for float8: %v", len(src)) + } + + ui64 := int64(binary.BigEndian.Uint64(src)) + f64 := math.Float64frombits(uint64(ui64)) + i64 := int64(f64) + if f64 != float64(i64) { + return fmt.Errorf("cannot losslessly convert %v to int64", f64) + } + + return s.ScanInt64(Int8{Int64: i64, Valid: true}) } -// Scan implements the database/sql Scanner interface. -func (dst *Float8) Scan(src interface{}) error { +type scanPlanBinaryFloat8ToTextScanner struct{} + +func (scanPlanBinaryFloat8ToTextScanner) Scan(src []byte, dst any) error { + s := (dst).(TextScanner) + if src == nil { - *dst = Float8{Status: Null} - return nil + return s.ScanText(Text{}) } - switch src := src.(type) { - case float64: - *dst = Float8{Float: src, Status: Present} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + if len(src) != 8 { + return fmt.Errorf("invalid length for float8: %v", len(src)) + } + + ui64 := int64(binary.BigEndian.Uint64(src)) + f64 := math.Float64frombits(uint64(ui64)) + + return s.ScanText(Text{String: strconv.FormatFloat(f64, 'f', -1, 64), Valid: true}) +} + +type scanPlanTextAnyToFloat64 struct{} + +func (scanPlanTextAnyToFloat64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + n, err := strconv.ParseFloat(string(src), 64) + if err != nil { + return err + } + + f := (dst).(*float64) + *f = n + + return nil +} + +type scanPlanTextAnyToFloat64Scanner struct{} + +func (scanPlanTextAnyToFloat64Scanner) Scan(src []byte, dst any) error { + s := (dst).(Float64Scanner) + + if src == nil { + return s.ScanFloat64(Float8{}) + } + + n, err := strconv.ParseFloat(string(src), 64) + if err != nil { + return err } - return errors.Errorf("cannot scan %T", src) + return s.ScanFloat64(Float8{Float64: n, Valid: true}) +} + +func (c Float8Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(m, oid, format, src) } -// Value implements the database/sql/driver Valuer interface. -func (src *Float8) Value() (driver.Value, error) { - switch src.Status { - case Present: - return src.Float, nil - case Null: +func (c Float8Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { return nil, nil - default: - return nil, errUndefined } + + var n float64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil } diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go deleted file mode 100644 index b92a8205a..000000000 --- a/pgtype/float8_array.go +++ /dev/null @@ -1,300 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type Float8Array struct { - Elements []Float8 - Dimensions []ArrayDimension - Status Status -} - -func (dst *Float8Array) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Float8Array{Status: Null} - return nil - } - - switch value := src.(type) { - - case []float64: - if value == nil { - *dst = Float8Array{Status: Null} - } else if len(value) == 0 { - *dst = Float8Array{Status: Present} - } else { - elements := make([]Float8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Float8Array", value) - } - - return nil -} - -func (dst *Float8Array) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Float8Array) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]float64: - *v = make([]float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Float8Array{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Float8 - - if len(uta.Elements) > 0 { - elements = make([]Float8, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Float8 - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = Float8Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Float8Array{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = Float8Array{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Float8, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = Float8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("float8"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "float8") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Float8Array) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Float8Array) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/float8_array_test.go b/pgtype/float8_array_test.go deleted file mode 100644 index ff8e3b261..000000000 --- a/pgtype/float8_array_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestFloat8ArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "float8[]", []interface{}{ - &pgtype.Float8Array{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.Float8Array{ - Elements: []pgtype.Float8{ - {Float: 1, Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Float8Array{Status: pgtype.Null}, - &pgtype.Float8Array{ - Elements: []pgtype.Float8{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Float: 6, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Float8Array{ - Elements: []pgtype.Float8{ - {Float: 1, Status: pgtype.Present}, - {Float: 2, Status: pgtype.Present}, - {Float: 3, Status: pgtype.Present}, - {Float: 4, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestFloat8ArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Float8Array - }{ - { - source: []float64{1}, - result: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]float64)(nil)), - result: pgtype.Float8Array{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Float8Array - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestFloat8ArrayAssignTo(t *testing.T) { - var float64Slice []float64 - var namedFloat64Slice _float64Slice - - simpleTests := []struct { - src pgtype.Float8Array - dst interface{} - expected interface{} - }{ - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1.23, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &float64Slice, - expected: []float64{1.23}, - }, - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Float: 1.23, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedFloat64Slice, - expected: _float64Slice{1.23}, - }, - { - src: pgtype.Float8Array{Status: pgtype.Null}, - dst: &float64Slice, - expected: (([]float64)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Float8Array - dst interface{} - }{ - { - src: pgtype.Float8Array{ - Elements: []pgtype.Float8{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &float64Slice, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/float8_test.go b/pgtype/float8_test.go index 46fc8d5d8..64593d97c 100644 --- a/pgtype/float8_test.go +++ b/pgtype/float8_test.go @@ -1,149 +1,64 @@ package pgtype_test import ( - "reflect" + "context" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" ) -func TestFloat8Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "float8", []interface{}{ - &pgtype.Float8{Float: -1, Status: pgtype.Present}, - &pgtype.Float8{Float: 0, Status: pgtype.Present}, - &pgtype.Float8{Float: 0.00001, Status: pgtype.Present}, - &pgtype.Float8{Float: 1, Status: pgtype.Present}, - &pgtype.Float8{Float: 9999.99, Status: pgtype.Present}, - &pgtype.Float8{Float: 0, Status: pgtype.Null}, +func TestFloat8Codec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "float8", []pgxtest.ValueRoundTripTest{ + {pgtype.Float8{Float64: -1, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float64: -1, Valid: true})}, + {pgtype.Float8{Float64: 0, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float64: 0, Valid: true})}, + {pgtype.Float8{Float64: 1, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float64: 1, Valid: true})}, + {float64(0.00001), new(float64), isExpectedEq(float64(0.00001))}, + {float64(9999.99), new(float64), isExpectedEq(float64(9999.99))}, + {pgtype.Float8{}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{})}, + {int64(1), new(int64), isExpectedEq(int64(1))}, + {"1.23", new(string), isExpectedEq("1.23")}, + {nil, new(*float64), isExpectedEq((*float64)(nil))}, }) } -func TestFloat8Set(t *testing.T) { +func TestFloat8MarshalJSON(t *testing.T) { successfulTests := []struct { - source interface{} - result pgtype.Float8 + source pgtype.Float8 + result string }{ - {source: float32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: float64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: int8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: pgtype.Float8{Float64: 0}, result: "null"}, + {source: pgtype.Float8{Float64: 1.23, Valid: true}, result: "1.23"}, } - for i, tt := range successfulTests { - var r pgtype.Float8 - err := r.Set(tt.source) + r, err := tt.source.MarshalJSON() if err != nil { t.Errorf("%d: %v", i, err) } - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) } } } -func TestFloat8AssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - var f32 float32 - var f64 float64 - var pf32 *float32 - var pf64 *float64 - - simpleTests := []struct { - src pgtype.Float8 - dst interface{} - expected interface{} - }{ - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &f32, expected: float32(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &f64, expected: float64(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Float8 - dst interface{} - expected interface{} +func TestFloat8UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Float8 }{ - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, - {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + {source: "null", result: pgtype.Float8{Float64: 0}}, + {source: "1.23", result: pgtype.Float8{Float64: 1.23, Valid: true}}, } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) + for i, tt := range successfulTests { + var r pgtype.Float8 + err := r.UnmarshalJSON([]byte(tt.source)) if err != nil { t.Errorf("%d: %v", i, err) } - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Float8 - dst interface{} - }{ - {src: pgtype.Float8{Float: 150, Status: pgtype.Present}, dst: &i8}, - {src: pgtype.Float8{Float: 40000, Status: pgtype.Present}, dst: &i16}, - {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &i32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) } } } diff --git a/pgtype/generic_binary.go b/pgtype/generic_binary.go deleted file mode 100644 index 2596ecae1..000000000 --- a/pgtype/generic_binary.go +++ /dev/null @@ -1,39 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -// GenericBinary is a placeholder for binary format values that no other type exists -// to handle. -type GenericBinary Bytea - -func (dst *GenericBinary) Set(src interface{}) error { - return (*Bytea)(dst).Set(src) -} - -func (dst *GenericBinary) Get() interface{} { - return (*Bytea)(dst).Get() -} - -func (src *GenericBinary) AssignTo(dst interface{}) error { - return (*Bytea)(src).AssignTo(dst) -} - -func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Bytea)(dst).DecodeBinary(ci, src) -} - -func (src *GenericBinary) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Bytea)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *GenericBinary) Scan(src interface{}) error { - return (*Bytea)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *GenericBinary) Value() (driver.Value, error) { - return (*Bytea)(src).Value() -} diff --git a/pgtype/generic_text.go b/pgtype/generic_text.go deleted file mode 100644 index 0e3db9def..000000000 --- a/pgtype/generic_text.go +++ /dev/null @@ -1,39 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -// GenericText is a placeholder for text format values that no other type exists -// to handle. -type GenericText Text - -func (dst *GenericText) Set(src interface{}) error { - return (*Text)(dst).Set(src) -} - -func (dst *GenericText) Get() interface{} { - return (*Text)(dst).Get() -} - -func (src *GenericText) AssignTo(dst interface{}) error { - return (*Text)(src).AssignTo(dst) -} - -func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeText(ci, src) -} - -func (src *GenericText) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Text)(src).EncodeText(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *GenericText) Scan(src interface{}) error { - return (*Text)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *GenericText) Value() (driver.Value, error) { - return (*Text)(src).Value() -} diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 347446ae9..ef864928f 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -1,434 +1,487 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" + "errors" + "fmt" "strings" - "unicode" - "unicode/utf8" - "github.com/pkg/errors" - - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) +type HstoreScanner interface { + ScanHstore(v Hstore) error +} + +type HstoreValuer interface { + HstoreValue() (Hstore, error) +} + // Hstore represents an hstore column that can be null or have null values // associated with its keys. -type Hstore struct { - Map map[string]Text - Status Status +type Hstore map[string]*string + +// ScanHstore implements the [HstoreScanner] interface. +func (h *Hstore) ScanHstore(v Hstore) error { + *h = v + return nil +} + +// HstoreValue implements the [HstoreValuer] interface. +func (h Hstore) HstoreValue() (Hstore, error) { + return h, nil } -func (dst *Hstore) Set(src interface{}) error { +// Scan implements the [database/sql.Scanner] interface. +func (h *Hstore) Scan(src any) error { if src == nil { - *dst = Hstore{Status: Null} + *h = nil return nil } - switch value := src.(type) { - case map[string]string: - m := make(map[string]Text, len(value)) - for k, v := range value { - m[k] = Text{String: v, Status: Present} - } - *dst = Hstore{Map: m, Status: Present} - default: - return errors.Errorf("cannot convert %v to Hstore", src) + switch src := src.(type) { + case string: + return scanPlanTextAnyToHstoreScanner{}.scanString(src, h) } - return nil + return fmt.Errorf("cannot scan %T", src) } -func (dst *Hstore) Get() interface{} { - switch dst.Status { - case Present: - return dst.Map - case Null: - return nil - default: - return dst.Status +// Value implements the [database/sql/driver.Valuer] interface. +func (h Hstore) Value() (driver.Value, error) { + if h == nil { + return nil, nil } -} -func (src *Hstore) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *map[string]string: - *v = make(map[string]string, len(src.Map)) - for k, val := range src.Map { - if val.Status != Present { - return errors.Errorf("cannot decode %v into %T", src, dst) - } - (*v)[k] = val.String - } - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) + buf, err := HstoreCodec{}.PlanEncode(nil, 0, TextFormatCode, h).Encode(h, nil) + if err != nil { + return nil, err } + return string(buf), err +} - return errors.Errorf("cannot decode %v into %T", src, dst) +type HstoreCodec struct{} + +func (HstoreCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode } -func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Hstore{Status: Null} - return nil - } +func (HstoreCodec) PreferredFormat() int16 { + return BinaryFormatCode +} - keys, values, err := parseHstore(string(src)) - if err != nil { - return err +func (HstoreCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(HstoreValuer); !ok { + return nil } - m := make(map[string]Text, len(keys)) - for i := range keys { - m[keys[i]] = values[i] + switch format { + case BinaryFormatCode: + return encodePlanHstoreCodecBinary{} + case TextFormatCode: + return encodePlanHstoreCodecText{} } - *dst = Hstore{Map: m, Status: Present} return nil } -func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Hstore{Status: Null} - return nil - } +type encodePlanHstoreCodecBinary struct{} - rp := 0 +func (encodePlanHstoreCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + hstore, err := value.(HstoreValuer).HstoreValue() + if err != nil { + return nil, err + } - if len(src[rp:]) < 4 { - return errors.Errorf("hstore incomplete %v", src) + if hstore == nil { + return nil, nil } - pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - m := make(map[string]Text, pairCount) + buf = pgio.AppendInt32(buf, int32(len(hstore))) - for i := 0; i < pairCount; i++ { - if len(src[rp:]) < 4 { - return errors.Errorf("hstore incomplete %v", src) - } - keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 + for k, v := range hstore { + buf = pgio.AppendInt32(buf, int32(len(k))) + buf = append(buf, k...) - if len(src[rp:]) < keyLen { - return errors.Errorf("hstore incomplete %v", src) + if v == nil { + buf = pgio.AppendInt32(buf, -1) + } else { + buf = pgio.AppendInt32(buf, int32(len(*v))) + buf = append(buf, (*v)...) } - key := string(src[rp : rp+keyLen]) - rp += keyLen + } - if len(src[rp:]) < 4 { - return errors.Errorf("hstore incomplete %v", src) - } - valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 + return buf, nil +} - var valueBuf []byte - if valueLen >= 0 { - valueBuf = src[rp : rp+valueLen] - } - rp += valueLen +type encodePlanHstoreCodecText struct{} - var value Text - err := value.DecodeBinary(ci, valueBuf) - if err != nil { - return err - } - m[key] = value +func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + hstore, err := value.(HstoreValuer).HstoreValue() + if err != nil { + return nil, err } - *dst = Hstore{Map: m, Status: Present} - - return nil -} - -func (src *Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined + if len(hstore) == 0 { + // distinguish between empty and nil: Not strictly required by Postgres, since its protocol + // explicitly marks NULL column values separately. However, the Binary codec does this, and + // this means we can "round trip" Encode and Scan without data loss. + // nil: []byte(nil); empty: []byte{} + if hstore == nil { + return nil, nil + } + return []byte{}, nil } firstPair := true - for k, v := range src.Map { + for k, v := range hstore { if firstPair { firstPair = false } else { - buf = append(buf, ',') + buf = append(buf, ',', ' ') } - buf = append(buf, quoteHstoreElementIfNeeded(k)...) + // unconditionally quote hstore keys/values like Postgres does + // this avoids a Mac OS X Postgres hstore parsing bug: + // https://www.postgresql.org/message-id/CA%2BHWA9awUW0%2BRV_gO9r1ABZwGoZxPztcJxPy8vMFSTbTfi4jig%40mail.gmail.com + buf = append(buf, '"') + buf = append(buf, quoteArrayReplacer.Replace(k)...) + buf = append(buf, '"') buf = append(buf, "=>"...) - elemBuf, err := v.EncodeText(ci, nil) - if err != nil { - return nil, err - } - - if elemBuf == nil { + if v == nil { buf = append(buf, "NULL"...) } else { - buf = append(buf, quoteHstoreElementIfNeeded(string(elemBuf))...) + buf = append(buf, '"') + buf = append(buf, quoteArrayReplacer.Replace(*v)...) + buf = append(buf, '"') } } return buf, nil } -func (src *Hstore) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case HstoreScanner: + return scanPlanBinaryHstoreToHstoreScanner{} + } + case TextFormatCode: + switch target.(type) { + case HstoreScanner: + return scanPlanTextAnyToHstoreScanner{} + } } - buf = pgio.AppendInt32(buf, int32(len(src.Map))) + return nil +} - var err error - for k, v := range src.Map { - buf = pgio.AppendInt32(buf, int32(len(k))) - buf = append(buf, k...) +type scanPlanBinaryHstoreToHstoreScanner struct{} - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) +func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error { + scanner := (dst).(HstoreScanner) - elemBuf, err := v.EncodeText(ci, buf) - if err != nil { - return nil, err + if src == nil { + return scanner.ScanHstore(Hstore(nil)) + } + + rp := 0 + + const uint32Len = 4 + if len(src[rp:]) < uint32Len { + return fmt.Errorf("hstore incomplete %v", src) + } + pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += uint32Len + + hstore := make(Hstore, pairCount) + // one allocation for all *string, rather than one per string, just like text parsing + valueStrings := make([]string, pairCount) + + for i := 0; i < pairCount; i++ { + if len(src[rp:]) < uint32Len { + return fmt.Errorf("hstore incomplete %v", src) } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += uint32Len + + if len(src[rp:]) < keyLen { + return fmt.Errorf("hstore incomplete %v", src) + } + key := string(src[rp : rp+keyLen]) + rp += keyLen + + if len(src[rp:]) < uint32Len { + return fmt.Errorf("hstore incomplete %v", src) + } + valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + if valueLen >= 0 { + valueStrings[i] = string(src[rp : rp+valueLen]) + rp += valueLen + + hstore[key] = &valueStrings[i] + } else { + hstore[key] = nil } } - return buf, err + return scanner.ScanHstore(hstore) } -var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) +type scanPlanTextAnyToHstoreScanner struct{} + +func (s scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error { + scanner := (dst).(HstoreScanner) -func quoteHstoreElement(src string) string { - return `"` + quoteArrayReplacer.Replace(src) + `"` + if src == nil { + return scanner.ScanHstore(Hstore(nil)) + } + return s.scanString(string(src), scanner) } -func quoteHstoreElementIfNeeded(src string) string { - if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || strings.ContainsAny(src, ` {},"\=>`) { - return quoteArrayElement(src) +// scanString does not return nil hstore values because string cannot be nil. +func (scanPlanTextAnyToHstoreScanner) scanString(src string, scanner HstoreScanner) error { + hstore, err := parseHstore(src) + if err != nil { + return err } - return src + return scanner.ScanHstore(hstore) } -const ( - hsPre = iota - hsKey - hsSep - hsVal - hsNul - hsNext -) +func (c HstoreCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c HstoreCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var hstore Hstore + err := codecScan(c, m, oid, format, src, &hstore) + if err != nil { + return nil, err + } + return hstore, nil +} type hstoreParser struct { - str string - pos int + str string + pos int + nextBackslash int } func newHSP(in string) *hstoreParser { return &hstoreParser{ - pos: 0, - str: in, + pos: 0, + str: in, + nextBackslash: strings.IndexByte(in, '\\'), } } -func (p *hstoreParser) Consume() (r rune, end bool) { +func (p *hstoreParser) atEnd() bool { + return p.pos >= len(p.str) +} + +// consume returns the next byte of the string, or end if the string is done. +func (p *hstoreParser) consume() (b byte, end bool) { if p.pos >= len(p.str) { - end = true - return + return 0, true } - r, w := utf8.DecodeRuneInString(p.str[p.pos:]) - p.pos += w - return + b = p.str[p.pos] + p.pos++ + return b, false } -func (p *hstoreParser) Peek() (r rune, end bool) { - if p.pos >= len(p.str) { - end = true - return +func unexpectedByteErr(actualB, expectedB byte) error { + return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB) +} + +// consumeExpectedByte consumes expectedB from the string, or returns an error. +func (p *hstoreParser) consumeExpectedByte(expectedB byte) error { + nextB, end := p.consume() + if end { + return fmt.Errorf("expected '%c' ('%#v'); found end", expectedB, expectedB) + } + if nextB != expectedB { + return unexpectedByteErr(nextB, expectedB) } - r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) - return + return nil } -// parseHstore parses the string representation of an hstore column (the same -// you would get from an ordinary SELECT) into two slices of keys and values. it -// is used internally in the default parsing of hstores. -func parseHstore(s string) (k []string, v []Text, err error) { - if s == "" { - return +// consumeExpected2 consumes two expected bytes or returns an error. +// This was a bit faster than using a string argument (better inlining? Not sure). +func (p *hstoreParser) consumeExpected2(one, two byte) error { + if p.pos+2 > len(p.str) { + return errors.New("unexpected end of string") + } + if p.str[p.pos] != one { + return unexpectedByteErr(p.str[p.pos], one) } + if p.str[p.pos+1] != two { + return unexpectedByteErr(p.str[p.pos+1], two) + } + p.pos += 2 + return nil +} - buf := bytes.Buffer{} - keys := []string{} - values := []Text{} - p := newHSP(s) +var errEOSInQuoted = errors.New(`found end before closing double-quote ('"')`) - r, end := p.Consume() - state := hsPre +// consumeDoubleQuoted consumes a double-quoted string from p. The double quote must have been +// parsed already. This copies the string from the backing string so it can be garbage collected. +func (p *hstoreParser) consumeDoubleQuoted() (string, error) { + // fast path: assume most keys/values do not contain escapes + nextDoubleQuote := strings.IndexByte(p.str[p.pos:], '"') + if nextDoubleQuote == -1 { + return "", errEOSInQuoted + } + nextDoubleQuote += p.pos + if p.nextBackslash == -1 || p.nextBackslash > nextDoubleQuote { + // clone the string from the source string to ensure it can be garbage collected separately + // TODO: use strings.Clone on Go 1.20; this could get optimized away + s := strings.Clone(p.str[p.pos:nextDoubleQuote]) + p.pos = nextDoubleQuote + 1 + return s, nil + } - for !end { - switch state { - case hsPre: - if r == '"' { - state = hsKey - } else { - err = errors.New("String does not begin with \"") - } - case hsKey: - switch r { - case '"': //End of the key - if buf.Len() == 0 { - err = errors.New("Empty Key is invalid") - } else { - keys = append(keys, buf.String()) - buf = bytes.Buffer{} - state = hsSep - } - case '\\': //Potential escaped character - n, end := p.Consume() - switch { - case end: - err = errors.New("Found EOS in key, expecting character or \"") - case n == '"', n == '\\': - buf.WriteRune(n) - default: - buf.WriteRune(r) - buf.WriteRune(n) - } - default: //Any other character - buf.WriteRune(r) - } - case hsSep: - if r == '=' { - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after '=', expecting '>'") - case r == '>': - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") - case r == '"': - state = hsVal - case r == 'N': - state = hsNul - default: - err = errors.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) - } - default: - err = errors.Errorf("Invalid character after '=', expecting '>'") - } - } else { - err = errors.Errorf("Invalid character '%c' after value, expecting '='", r) - } - case hsVal: - switch r { - case '"': //End of the value - values = append(values, Text{String: buf.String(), Status: Present}) - buf = bytes.Buffer{} - state = hsNext - case '\\': //Potential escaped character - n, end := p.Consume() - switch { - case end: - err = errors.New("Found EOS in key, expecting character or \"") - case n == '"', n == '\\': - buf.WriteRune(n) - default: - buf.WriteRune(r) - buf.WriteRune(n) - } - default: //Any other character - buf.WriteRune(r) - } - case hsNul: - nulBuf := make([]rune, 3) - nulBuf[0] = r - for i := 1; i < 3; i++ { - r, end = p.Consume() - if end { - err = errors.New("Found EOS in NULL value") - return - } - nulBuf[i] = r - } - if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { - values = append(values, Text{Status: Null}) - state = hsNext - } else { - err = errors.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) + // slow path: string contains escapes + s, err := p.consumeDoubleQuotedWithEscapes(p.nextBackslash) + p.nextBackslash = strings.IndexByte(p.str[p.pos:], '\\') + if p.nextBackslash != -1 { + p.nextBackslash += p.pos + } + return s, err +} + +// consumeDoubleQuotedWithEscapes consumes a double-quoted string containing escapes, starting +// at p.pos, and with the first backslash at firstBackslash. This copies the string so it can be +// garbage collected separately. +func (p *hstoreParser) consumeDoubleQuotedWithEscapes(firstBackslash int) (string, error) { + // copy the prefix that does not contain backslashes + var builder strings.Builder + builder.WriteString(p.str[p.pos:firstBackslash]) + + // skip to the backslash + p.pos = firstBackslash + + // copy bytes until the end, unescaping backslashes + for { + nextB, end := p.consume() + if end { + return "", errEOSInQuoted + } else if nextB == '"' { + break + } else if nextB == '\\' { + // escape: skip the backslash and copy the char + nextB, end = p.consume() + if end { + return "", errEOSInQuoted } - case hsNext: - if r == ',' { - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after ',', expcting space") - case (unicode.IsSpace(r)): - r, end = p.Consume() - state = hsKey - default: - err = errors.Errorf("Invalid character '%c' after ', ', expecting \"", r) - } - } else { - err = errors.Errorf("Invalid character '%c' after value, expecting ','", r) + if !(nextB == '\\' || nextB == '"') { + return "", fmt.Errorf("unexpected escape in quoted string: found '%#v'", nextB) } + builder.WriteByte(nextB) + } else { + // normal byte: copy it + builder.WriteByte(nextB) } + } + return builder.String(), nil +} +// consumePairSeparator consumes the Hstore pair separator ", " or returns an error. +func (p *hstoreParser) consumePairSeparator() error { + return p.consumeExpected2(',', ' ') +} + +// consumeKVSeparator consumes the Hstore key/value separator "=>" or returns an error. +func (p *hstoreParser) consumeKVSeparator() error { + return p.consumeExpected2('=', '>') +} + +// consumeDoubleQuotedOrNull consumes the Hstore key/value separator "=>" or returns an error. +func (p *hstoreParser) consumeDoubleQuotedOrNull() (Text, error) { + // peek at the next byte + if p.atEnd() { + return Text{}, errors.New("found end instead of value") + } + next := p.str[p.pos] + if next == 'N' { + // must be the exact string NULL: use consumeExpected2 twice + err := p.consumeExpected2('N', 'U') if err != nil { - return + return Text{}, err } - r, end = p.Consume() + err = p.consumeExpected2('L', 'L') + if err != nil { + return Text{}, err + } + return Text{String: "", Valid: false}, nil + } else if next != '"' { + return Text{}, unexpectedByteErr(next, '"') } - if state != hsNext { - err = errors.New("Improperly formatted hstore") - return + + // skip the double quote + p.pos += 1 + s, err := p.consumeDoubleQuoted() + if err != nil { + return Text{}, err } - k = keys - v = values - return + return Text{String: s, Valid: true}, nil } -// Scan implements the database/sql Scanner interface. -func (dst *Hstore) Scan(src interface{}) error { - if src == nil { - *dst = Hstore{Status: Null} - return nil - } +func parseHstore(s string) (Hstore, error) { + p := newHSP(s) - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } + // This is an over-estimate of the number of key/value pairs. Use '>' because I am guessing it + // is less likely to occur in keys/values than '=' or ','. + numPairsEstimate := strings.Count(s, ">") + // makes one allocation of strings for the entire Hstore, rather than one allocation per value. + valueStrings := make([]string, 0, numPairsEstimate) + result := make(Hstore, numPairsEstimate) + first := true + for !p.atEnd() { + if !first { + err := p.consumePairSeparator() + if err != nil { + return nil, err + } + } else { + first = false + } - return errors.Errorf("cannot scan %T", src) -} + err := p.consumeExpectedByte('"') + if err != nil { + return nil, err + } + + key, err := p.consumeDoubleQuoted() + if err != nil { + return nil, err + } + + err = p.consumeKVSeparator() + if err != nil { + return nil, err + } + + value, err := p.consumeDoubleQuotedOrNull() + if err != nil { + return nil, err + } + if value.Valid { + valueStrings = append(valueStrings, value.String) + result[key] = &valueStrings[len(valueStrings)-1] + } else { + result[key] = nil + } + } -// Value implements the database/sql/driver Valuer interface. -func (src *Hstore) Value() (driver.Value, error) { - return EncodeValueText(src) + return result, nil } diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go deleted file mode 100644 index 80530c26f..000000000 --- a/pgtype/hstore_array.go +++ /dev/null @@ -1,300 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type HstoreArray struct { - Elements []Hstore - Dimensions []ArrayDimension - Status Status -} - -func (dst *HstoreArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = HstoreArray{Status: Null} - return nil - } - - switch value := src.(type) { - - case []map[string]string: - if value == nil { - *dst = HstoreArray{Status: Null} - } else if len(value) == 0 { - *dst = HstoreArray{Status: Present} - } else { - elements := make([]Hstore, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = HstoreArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to HstoreArray", value) - } - - return nil -} - -func (dst *HstoreArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *HstoreArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]map[string]string: - *v = make([]map[string]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = HstoreArray{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Hstore - - if len(uta.Elements) > 0 { - elements = make([]Hstore, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Hstore - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = HstoreArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = HstoreArray{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = HstoreArray{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Hstore, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = HstoreArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("hstore"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "hstore") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *HstoreArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *HstoreArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/hstore_array_test.go b/pgtype/hstore_array_test.go deleted file mode 100644 index d629a04b6..000000000 --- a/pgtype/hstore_array_test.go +++ /dev/null @@ -1,184 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestHstoreArrayTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustClose(t, conn) - - text := func(s string) pgtype.Text { - return pgtype.Text{String: s, Status: pgtype.Present} - } - - values := []pgtype.Hstore{ - {Map: map[string]pgtype.Text{}, Status: pgtype.Present}, - {Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, - {Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, - {Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, - {Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, - {Status: pgtype.Null}, - } - - specialStrings := []string{ - `"`, - `'`, - `\`, - `\\`, - `=>`, - ` `, - `\ / / \\ => " ' " '`, - } - for _, s := range specialStrings { - // Special key values - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key - - // Special value values - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key - } - - src := &pgtype.HstoreArray{ - Elements: values, - Dimensions: []pgtype.ArrayDimension{{Length: int32(len(values)), LowerBound: 1}}, - Status: pgtype.Present, - } - - ps, err := conn.Prepare("test", "select $1::hstore[]") - if err != nil { - t.Fatal(err) - } - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for _, fc := range formats { - ps.FieldDescriptions[0].FormatCode = fc.formatCode - vEncoder := testutil.ForceEncoder(src, fc.formatCode) - if vEncoder == nil { - t.Logf("%#v does not implement %v", src, fc.name) - continue - } - - var result pgtype.HstoreArray - err := conn.QueryRow("test", vEncoder).Scan(&result) - if err != nil { - t.Errorf("%v: %v", fc.name, err) - continue - } - - if result.Status != src.Status { - t.Errorf("%v: expected Status %v, got %v", fc.formatCode, src.Status, result.Status) - continue - } - - if len(result.Elements) != len(src.Elements) { - t.Errorf("%v: expected %v elements, got %v", fc.formatCode, len(src.Elements), len(result.Elements)) - continue - } - - for i := range result.Elements { - a := src.Elements[i] - b := result.Elements[i] - - if a.Status != b.Status { - t.Errorf("%v element idx %d: expected status %v, got %v", fc.formatCode, i, a.Status, b.Status) - } - - if len(a.Map) != len(b.Map) { - t.Errorf("%v element idx %d: expected %v pairs, got %v", fc.formatCode, i, len(a.Map), len(b.Map)) - } - - for k := range a.Map { - if a.Map[k] != b.Map[k] { - t.Errorf("%v element idx %d: expected key %v to be %v, got %v", fc.formatCode, i, k, a.Map[k], b.Map[k]) - } - } - } - } -} - -func TestHstoreArraySet(t *testing.T) { - successfulTests := []struct { - src []map[string]string - result pgtype.HstoreArray - }{ - { - src: []map[string]string{{"foo": "bar"}}, - result: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - }, - } - - for i, tt := range successfulTests { - var dst pgtype.HstoreArray - err := dst.Set(tt.src) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(dst, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) - } - } -} - -func TestHstoreArrayAssignTo(t *testing.T) { - var m []map[string]string - - simpleTests := []struct { - src pgtype.HstoreArray - dst *[]map[string]string - expected []map[string]string - }{ - { - src: pgtype.HstoreArray{ - Elements: []pgtype.Hstore{ - { - Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, - Status: pgtype.Present, - }, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &m, - expected: []map[string]string{{"foo": "bar"}}}, - {src: pgtype.HstoreArray{Status: pgtype.Null}, dst: &m, expected: (([]map[string]string)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(*tt.dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } -} diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go index d76c99420..dd064c60e 100644 --- a/pgtype/hstore_test.go +++ b/pgtype/hstore_test.go @@ -1,25 +1,120 @@ package pgtype_test import ( + "context" + "fmt" "reflect" "testing" + "time" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" ) -func TestHstoreTranscode(t *testing.T) { - text := func(s string) pgtype.Text { - return pgtype.Text{String: s, Status: pgtype.Present} +func isExpectedEqMapStringString(a any) func(any) bool { + return func(v any) bool { + am := a.(map[string]string) + vm := v.(map[string]string) + + if len(am) != len(vm) { + return false + } + + for k, v := range am { + if vm[k] != v { + return false + } + } + + return true + } +} + +func isExpectedEqMapStringPointerString(a any) func(any) bool { + return func(v any) bool { + am := a.(map[string]*string) + vm := v.(map[string]*string) + + if len(am) != len(vm) { + return false + } + + for k, v := range am { + if (vm[k] == nil) != (v == nil) { + return false + } + + if v != nil && *vm[k] != *v { + return false + } + } + + return true } +} + +// stringPtr returns a pointer to s. +func stringPtr(s string) *string { + return &s +} - values := []interface{}{ - &pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, - &pgtype.Hstore{Status: pgtype.Null}, +func TestHstoreCodec(t *testing.T) { + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var hstoreOID uint32 + err := conn.QueryRow(context.Background(), `select oid from pg_type where typname = 'hstore'`).Scan(&hstoreOID) + if err != nil { + t.Skipf("Skipping: cannot find hstore OID") + } + + conn.TypeMap().RegisterType(&pgtype.Type{Name: "hstore", OID: hstoreOID, Codec: pgtype.HstoreCodec{}}) + } + + tests := []pgxtest.ValueRoundTripTest{ + { + map[string]string{}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{}), + }, + { + map[string]string{"foo": "", "bar": "", "baz": "123"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": "", "bar": "", "baz": "123"}), + }, + { + map[string]string{"NULL": "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"NULL": "bar"}), + }, + { + map[string]string{"bar": "NULL"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"bar": "NULL"}), + }, + { + map[string]string{"": "foo"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"": "foo"}), + }, + { + map[string]*string{}, + new(map[string]*string), + isExpectedEqMapStringPointerString(map[string]*string{}), + }, + { + map[string]*string{"foo": stringPtr("bar"), "baq": stringPtr("quz")}, + new(map[string]*string), + isExpectedEqMapStringPointerString(map[string]*string{"foo": stringPtr("bar"), "baq": stringPtr("quz")}), + }, + { + map[string]*string{"foo": nil, "baq": stringPtr("quz")}, + new(map[string]*string), + isExpectedEqMapStringPointerString(map[string]*string{"foo": nil, "baq": stringPtr("quz")}), + }, + {nil, new(*map[string]string), isExpectedEq((*map[string]string)(nil))}, + {nil, new(*map[string]*string), isExpectedEq((*map[string]*string)(nil))}, + {nil, new(*pgtype.Hstore), isExpectedEq((*pgtype.Hstore)(nil))}, } specialStrings := []string{ @@ -30,80 +125,286 @@ func TestHstoreTranscode(t *testing.T) { `=>`, ` `, `\ / / \\ => " ' " '`, + "line1\nline2", + "tab\tafter", + "vtab\vafter", + "form\\ffeed", + "carriage\rreturn", + "curly{}braces", + // Postgres on Mac OS X hstore parsing bug: + // ą = "\xc4\x85" in UTF-8; isspace(0x85) on Mac OS X returns true instead of false + "mac_bugą", } for _, s := range specialStrings { // Special key values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + + // at beginning + tests = append(tests, pgxtest.ValueRoundTripTest{ + map[string]string{s + "foo": "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{s + "foo": "bar"}), + }) + // in middle + tests = append(tests, pgxtest.ValueRoundTripTest{ + map[string]string{"foo" + s + "bar": "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo" + s + "bar": "bar"}), + }) + // at end + tests = append(tests, pgxtest.ValueRoundTripTest{ + map[string]string{"foo" + s: "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo" + s: "bar"}), + }) + // is key + tests = append(tests, pgxtest.ValueRoundTripTest{ + map[string]string{s: "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{s: "bar"}), + }) // Special value values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + + // at beginning + tests = append(tests, pgxtest.ValueRoundTripTest{ + map[string]string{"foo": s + "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": s + "bar"}), + }) + // in middle + tests = append(tests, pgxtest.ValueRoundTripTest{ + map[string]string{"foo": "foo" + s + "bar"}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": "foo" + s + "bar"}), + }) + // at end + tests = append(tests, pgxtest.ValueRoundTripTest{ + map[string]string{"foo": "foo" + s}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": "foo" + s}), + }) + // is key + tests = append(tests, pgxtest.ValueRoundTripTest{ + map[string]string{"foo": s}, + new(map[string]string), + isExpectedEqMapStringString(map[string]string{"foo": s}), + }) } - testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { - a := ai.(pgtype.Hstore) - b := bi.(pgtype.Hstore) + pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, pgxtest.KnownOIDQueryExecModes, "hstore", tests) - if len(a.Map) != len(b.Map) || a.Status != b.Status { - return false + // run the tests using pgtype.Hstore as input and output types, and test all query modes + for i := range tests { + var h pgtype.Hstore + switch typedParam := tests[i].Param.(type) { + case map[string]*string: + h = pgtype.Hstore(typedParam) + case map[string]string: + if typedParam != nil { + h = pgtype.Hstore{} + for k, v := range typedParam { + h[k] = stringPtr(v) + } + } } - for k := range a.Map { - if a.Map[k] != b.Map[k] { - return false - } + tests[i].Param = h + tests[i].Result = &pgtype.Hstore{} + tests[i].Test = func(input any) bool { + return reflect.DeepEqual(input, h) } + } + pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, pgxtest.AllQueryExecModes, "hstore", tests) - return true + // run the tests again without the codec registered: uses the text protocol + ctrWithoutCodec := defaultConnTestRunner + pgxtest.RunValueRoundTripTests(context.Background(), t, ctrWithoutCodec, pgxtest.AllQueryExecModes, "hstore", tests) + + // scan empty and NULL: should be different in all query modes + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, pgxtest.AllQueryExecModes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + h := pgtype.Hstore{"should_be_erased": nil} + err := conn.QueryRow(ctx, `select cast(null as hstore)`).Scan(&h) + if err != nil { + t.Fatal(err) + } + expectedNil := pgtype.Hstore(nil) + if !reflect.DeepEqual(h, expectedNil) { + t.Errorf("plain conn.Scan failed expectedNil=%#v actual=%#v", expectedNil, h) + } + + err = conn.QueryRow(ctx, `select cast('' as hstore)`).Scan(&h) + if err != nil { + t.Fatal(err) + } + expectedEmpty := pgtype.Hstore{} + if !reflect.DeepEqual(h, expectedEmpty) { + t.Errorf("plain conn.Scan failed expectedEmpty=%#v actual=%#v", expectedEmpty, h) + } }) } -func TestHstoreSet(t *testing.T) { - successfulTests := []struct { - src map[string]string - result pgtype.Hstore - }{ - {src: map[string]string{"foo": "bar"}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}}, +func TestParseInvalidInputs(t *testing.T) { + // these inputs should be invalid, but previously were considered correct + invalidInputs := []string{ + // extra comma between values + `"a"=>"1", ,b"=>"2"`, + // missing doublequote before second value + `""=>"", 0"=>""`, } - - for i, tt := range successfulTests { - var dst pgtype.Hstore - err := dst.Set(tt.src) - if err != nil { - t.Errorf("%d: %v", i, err) + for i, input := range invalidInputs { + var hstore pgtype.Hstore + err := hstore.Scan(input) + if err == nil { + t.Errorf("test %d: input=%s (%#v) should fail; parsed correctly", i, input, input) } + } +} + +func TestRoundTrip(t *testing.T) { + codecs := []struct { + name string + encodePlan pgtype.EncodePlan + scanPlan pgtype.ScanPlan + }{ + { + "text", + pgtype.HstoreCodec{}.PlanEncode(nil, 0, pgtype.TextFormatCode, pgtype.Hstore(nil)), + pgtype.HstoreCodec{}.PlanScan(nil, 0, pgtype.TextFormatCode, (*pgtype.Hstore)(nil)), + }, + { + "binary", + pgtype.HstoreCodec{}.PlanEncode(nil, 0, pgtype.BinaryFormatCode, pgtype.Hstore(nil)), + pgtype.HstoreCodec{}.PlanScan(nil, 0, pgtype.BinaryFormatCode, (*pgtype.Hstore)(nil)), + }, + } - if !reflect.DeepEqual(dst, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) + inputs := []pgtype.Hstore{ + nil, + {}, + {"": stringPtr("")}, + {"k1": stringPtr("v1")}, + {"k1": stringPtr("v1"), "k2": stringPtr("v2")}, + } + for _, codec := range codecs { + for i, input := range inputs { + t.Run(fmt.Sprintf("%s/%d", codec.name, i), func(t *testing.T) { + serialized, err := codec.encodePlan.Encode(input, nil) + if err != nil { + t.Fatal(err) + } + var output pgtype.Hstore + err = codec.scanPlan.Scan(serialized, &output) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(output, input) { + t.Errorf("output=%#v does not match input=%#v", output, input) + } + }) } } } -func TestHstoreAssignTo(t *testing.T) { - var m map[string]string +func BenchmarkHstoreEncode(b *testing.B) { + h := pgtype.Hstore{ + "a x": stringPtr("100"), "b": stringPtr("200"), "c": stringPtr("300"), + "d": stringPtr("400"), "e": stringPtr("500"), + } - simpleTests := []struct { - src pgtype.Hstore - dst *map[string]string - expected map[string]string + serializeConfigs := []struct { + name string + encodePlan pgtype.EncodePlan }{ - {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}, dst: &m, expected: map[string]string{"foo": "bar"}}, - {src: pgtype.Hstore{Status: pgtype.Null}, dst: &m, expected: ((map[string]string)(nil))}, + {"text", pgtype.HstoreCodec{}.PlanEncode(nil, 0, pgtype.TextFormatCode, h)}, + {"binary", pgtype.HstoreCodec{}.PlanEncode(nil, 0, pgtype.BinaryFormatCode, h)}, + } + + for _, serializeConfig := range serializeConfigs { + var buf []byte + b.Run(serializeConfig.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + var err error + buf, err = serializeConfig.encodePlan.Encode(h, buf) + if err != nil { + b.Fatal(err) + } + buf = buf[:0] + } + }) } +} + +func BenchmarkHstoreScan(b *testing.B) { + // empty, NULL, escapes, and based on some real data + benchStrings := []string{ + "", + `"a"=>"b"`, + `"a"=>"100", "b"=>"200", "c"=>"300", "d"=>"400", "e"=>"500"`, + `"a"=>"100", "b"=>NULL, "c"=>"300", "d"=>NULL, "e"=>"500"`, + `"pmd"=>"piokifjzxdy:mhvvmotns:sf1-dttudcp-orx-fuwzw-j8o-tl-jcg-1fb5d6dp50ke3l24", "ausz"=>"aorc-iosdby_tbxsjihj-kss64-32r128y-i2", "mgjo"=>"hxcp-ciag", "hkbee"=>"bokihheb", "gpcvhc"=>"ne-ywik-1", "olzjegk"=>"rxbkzba", "iy_quthhf"=>"sryizraxx", "bwpdpplfz"=>"gbdh-jikmnp_jwugdvjs-drh64-32k128h-p2", "njy_veipyyl"=>"727006795293", "vsgvqlrnqadzvk"=>"1_7_43", "mfdncuqvxp_gqlkytj"=>"fuyin", "cnuiswkwavoupqebov"=>"x32n128w", "mol_lcabioescln_ulstxauvi"=>"qm1-adbcand-tzi-fpnbv-s8j-vi-gqs-1om5b6lx50zk3u24", "arlyhgdxux.fc/bezucmz/mmfed"=>"vihsk", "jtkf.czddftrhr.ici/qbq_ftaz"=>"sse64", "notxkfqmpq.whxmykhtc.bcu/zmxz"=>"zauaklqp-uwo64-32q128a-g2", "ww_affdwqa_o8o_ilskcucq_urzltnf"=>"i6-9-0", "f8d.eq/bbqxwru-vsznvxerae/wsszbjw"=>"dgd", "ygpghkljze.dkrlrrieo.iur/xfqdqreft"=>"pfby-bhqlmm", "pmho-dqxuezyuu.ppslmznja.eam/ikehtxg"=>"wbku", "ckqeavtcqk.jiqdipgji.hjl/luzgqb-agm-wb"=>"ikpq", "akcn-yobdpxkyl.gktsjdo-xqwmivixku.p8y.vq/axqdw"=>"", "r8u.at/fbqrrss-ihxjmygoyc/ztqe-pqqqewnz/nepdj/njjv"=>"txtlffpp:ebwdksxkej", "q8x.wu/wenlhkz-govetdoibn/rcwg-ticalfjq/mgipy/awmjl"=>"dyzvbzvi", "p8l.wx/vadrnki-yfqhzlwcnt/hvun-geqhjsik/eqediipfr/vlc"=>"31900z", "t8z.be/qbtsmci-jqnqphssdg/ejma-slvywzry/txpnybwvn/kxdl"=>"210", "o8b.nb/bijgpwm-axvvqgujax/fjli-mxqwulfe/revyfoyty/oojpsd"=>"123421925786", "p8q.sk/ccpgzee-ufjempgvty/afwh-qvwzjvog/hsyhr/bklplujbfydtfw"=>"1_7_43", "k8y.jp/hqoymrw-flwqwvbntf/dlli-uggxkdqv/mtutu/qotjmacjitwtvcnblr"=>"m32x128f", "r8z.hj/eczodcw-lxzmeeqqii/fjba-psyoidht/gfjjcdbqs/apkqxiznu-muzubvl"=>"106068512341", "u8v.nf/ocnahkw-prhuwrrbjg/gxms-isohcouc/txfle/zfzw.neyygeeur.ejv/rnd_vdyo"=>"ibx64", "i8c.zz/dtiulqn-mmbskzjcib/fxuj-ejxbrnqi/optyp/wbbrancspv.pnkizgxcj.dbm/bldn"=>"znppnwzg-oxp64-32r128h-d2", "d8t.dg/jqtodoh-sokunyljow/svdf-ghplxxcx/wqkwl/dolljeqv.jcn.dxp.jmh.uyf/lyfv"=>"kc-lmpu-1i", "t8i.dy/imltbpr-atmthzarmk/fbbw-uaovyvdj/mmuwq/kseu-snmt.xtlgkstzph.mg/ehjdpgc"=>"", "o8c.yc/wximcpf-wmffadvnxx/tdim-szbqedqp/ztrui/puhx-kcwp.zziulqvvmb.ik/khfaxajj"=>"", "j8i.zc/sajavzi-kemnitliml/nloy-riqothpw/yxmnp/ttrnynffzy.lswpezbdq.wor/xkvqeexio"=>"ltmp-zajsxt", "a8f.xd/tfrrawy-ymihugugaa/ouzi-xdyecmqx/cwvgjvcrh/trgbxgbumo.uh/xmnqbds-nqxxeuqpq"=>"3123748065", "x8n.vx/juiqxkj-swvwogmncw/hvad-pojmevog/ytxit/auvo-duchssbth.uickilmnz.lja/hbeiakj"=>"hwhd", "z8j.bn/iplhrhv-wjdcwdclos/qndu-qvotchss/spvfx/brqotjnytw.aaemsoxor.ign/uwebjm-vzl-kb"=>"zwdg", "t8j.vx/iekvskm-xhikarvbty/czlm-xtipxwok/eeeow/uvtpuzmlqg.jgtpgiujc.wrs/mcofa-qxjjwak"=>"sovxb", "t8g.ab/wuncjdz-vsozsekgxz/aaea-hmgdjylm/qimwsoecgud-grgoowb/zveahbidvwcaebhlzigytiermehxy"=>"0.95", "n8k.ei/ohovibm-obkaatwlyw/bcow-gndyzpyt/aehyf/dpgifsorjx.ehsqntrka.jrr/meakdzy-ckxgnfavwm"=>"nlgw", "u8e.yi/qavbjew-qnmtzbeyce/rmwa-hcqlvadn/bhpml/taoj-wjnh.qqvkjmccfn.ja/nudbtwme-buc64-32j128i-k2"=>""`, + `"mbgs"=>"eqjillclhkxz", "bxci"=>"etksm.rudiu", "jijqqm"=>"kj-ryxhwqtco-2", "yivvcxy"=>"fwbujcu", "ybk_ztlajai"=>"601427279990"`, + `"wte"=>"nrhw", "lqjm"=>"ifsbygchn", "wbmf"=>"amjsoykkwq\\ghvwbsmz-qeiv-iekd-ukcwbipzy"`, + `"otx"=>"fcreomqbwtk:gqhxzhxuh:wrqo-rf1-avhdpfy-nqi-dldof-i8p-mw-jll-l5r9741753c3", "vbjy"=>"akzfspigip_muzyxzwuso-zvoifh-uw", "fmkb"=>"pkoe-lezf", "wfbq"=>"qoviagajeg", "zvxbiv"=>"db-bcngmoq-1", "olictqnpx"=>"taqcnrcwcj_ticfxydekq-fafbkg-ot", "wkt_jtzzqpt"=>"727006795293", "bsdncvmbvj_xivgkws"=>"zczag", "muzq.oyrphhtne.fqm/itc"=>"ihilzgx", "pfsd.xphmjdohu.hrm/yeimpfm"=>"lrrqxrwyud-uvcljo", "qukdxappwo.or/xgcsmdo/dodoj"=>"onflq", "ktqrsqtllo.xxxpkizlg.tnf/unrt"=>"jrveutvddu-loihei-ww", "tr_qmarsis_s8v_skzbuuvy_cnyuxyk"=>"g6-16-0", "z8q.yc/xistcyy-tftbikuuhg/zvhemmi"=>"knv", "zrgwpjnvzq.twkcxxuyk.qwc/nirbacaom"=>"okfdlcpbdg", "suvk-wwwjqdytq.wdjmzxl-nduettmnmf.e8e.ec/qhkan"=>"", "u8m.xa/uvbhlmw-rqrcyyaiju/otsg-bqjfitoq/zqfuq/fifo"=>"brarmrogdb", "b8o.ci/znwkyby-nzuxiguqus/nwou-cxxnqxrr/rtdsp/yawv"=>"juedpptnbt-khocdt-vg:vfxpdswxnc", "u8h.vl/kgmvysr-xhykrjcssj/jfjv-gzalgika/yhrjfytwz/kbm"=>"3900f", "y8b.cm/ttijscl-rznjossaqw/kvto-gvnavnep/bwdqyuzgo/ozoi"=>"40", "p8j.pd/bnucngv-vnqufgvfqw/qshw-obnkmlfx/obczheyis/zzbsos"=>"7009zf", "p8y.fc/ejbndrq-aariupaovi/mrah-hmrhjcsv/lvrmfwwiz/uskogxfuw-zamygae"=>"18747532246", "y8m.oh/xzuhilr-wqmqqzcznb/pcox-idpxmhfj/yzsoj/qebkjaeymc.abqznnelq.gyd/osvb"=>"hsgxlccalq-eeybug-mx", "p8f.ay/tyntrss-nljxedfihd/grvy-znfykhlf/fjsqd/ffxaixyv.jie.bkg.zpd.kim/mgtc"=>"or-vrkdcxm-1i", "i8m.ms/jtykfbi-jdrqsqjdwt/ibaq-zmeuyznf/uczny/ufmj-zklt.omodkgubqw.ip/xztdevd"=>"", "k8m.ui/ymxurqo-kuhofnewjj/twex-iuwljutj/warlx/zptkdgqdpr.uhvqtrclx.ohj/bdkgsozkk"=>"zlgisdikac", "g8b.wk/vecudfr-pljllpgzxi/lbwd-zsracrgq/fucssaowj/syizbmlfqt.si/swpbend-gxrhddxad"=>"156213905", "z8y.ah/azeasta-gffxfwklrn/hukw-hphwntwy/lfswv/tmaeaxekya.vgkxjhtvg.mht/bzolt-koioxpf"=>"wzkra", "f8l.sy/ouekhco-rlhsclfzwx/erfz-uuejogrs/bgvia/zpohrhmrmu.sbdxzlaxo.wii/jbnwfvz-shekbewool"=>"aiey", "j8w.pz/fjtkxhn-zxxizfldde/wsik-uiodldga/ljdtl/gswz-cjmt.ffkelhxcsd.lw/ftcqgdnnho-ibbfql-ww"=>""`, + `"uvd"=>"oneotg", "wsm"=>"djjgmwqyple:jtxtfvtjv:du1-nfxzmra-idl-ikxbx-t8n-id-nbo-6d08opx70381", "orq"=>"bkdvjw-xydgbd1", "gblm"=>"jtkcfd-unxbag1_xagyfw-nvachf1", "mfer"=>"jclz-yaim", "jvgvas"=>"jf-vhxh-1", "wwardeuqu"=>"ufimeb-bscfdy1_bfuagy-dhdqra1", "szs_rfgpqmc"=>"727006795293", "ckfxcgrnqc_rloxzxu"=>"qffbw", "yaigdvscju.ba/krpgzji/wvxyg"=>"srgtu", "gtxfjsigdv.pxujnffnp.aza/ycco"=>"ntranp-ahgeem1", "xj_lhdpvsl_i8i_qzrtlpjr_nroujqh"=>"q6-1-8", "czxy-sfym.enlohvvjmp.wb/huvcuhy"=>"", "x8a.of/sqpdqiq-vijrlgkkyl/oncckls"=>"mij", "oomgvfopmc.trnzktrtz.gza/rpeqqyqmm"=>"rgwnma-bwcbxe1", "gaud-giar.xuablvwkbo.wy/wvhmsk-uaycqn1"=>"", "oarbmcqzzw.qkfbtmltz.plh/aqssj-tlrhsof"=>"wxfd", "zepirccplb.qanvqnxlo.eld/emulnov-vgddsefeqv"=>"jnvh", "acby-kywxjuczc.suosfcy-drsgroeqvy.o8m.og/vyuxt"=>"", "q8j.by/lrwxbjt-yzrenlniog/gbmw-mnokcndu/etbcy/ibwr"=>"qpttug-jnxhwe1:grmslxhyky", "i8y.uy/awavkxk-nztmqujxys/pocu-sqjdqvzd/tfdjeflpn/xsj"=>"7900c", "z8g.ia/yzfdvta-ffkciorpfl/kmjc-fgcdomlv/snvhhbjil/nhvn"=>"45", "s8l.ky/dtvxoqu-lzfdnykmdh/wtdg-aktximmy/hofzkpzel/wtghso"=>"14837zg", "v8e.rq/uosznaz-drypoapgpe/vxss-mbxmvkjj/oglvxhxcz/whutvtjmr-tewtidr"=>"18747532246", "m8p.sz/hrgniti-aufhjdsdcc/whcp-cfuwjsnl/exugj/evphviokhl.ashpndixr.jvx/vgtt"=>"zdsacy-ppfuxf1", "w8t.fm/kljwjgc-fijbwsrvxa/dbzl-fhxvlrwk/yidyk/orrt-kgpr.wuzmpnxvtb.lc/dmbqfvt"=>"", "m8j.sv/takylmm-ywnolaflnl/ueih-fdcpfcpv/dslbc/dsspusnhtu.vgkihqtpb.fto/qmyksglfx"=>"wpwuih-deuiej1", "m8x.wi/jwobkio-mwupghbqbi/krqn-hqyfgwuw/mcbyi/yzkt-wtdy.pjxevrogab.tj/qlttbz-ppyzkd1"=>"", "c8j.tr/tzcbhid-lggaiypnny/wyms-zcjgxmwp/eaohd/bcwkheknsr.fqvtgecsf.qbf/uaqzj-jburpix"=>"ckkk", "w8h.wk/msbqvqy-nsmvbojwns/edpo-nsivbrmx/qifaf/sopuabsuvq.foyniwomd.zvj/lhvfwvv-zuufhhspso"=>"fghx"`, + `"xxtlvd"=>"ba-zrzy-1", "hlebkcl"=>"entrcad", "ytn_toivqso"=>"601427279990", "czdllqyvkcfemhubpwvxakepubup"=>"jzhpff-vn2-sgiupfiii-qmuuz-ndex-vin-kmfm", "mefjcnjmcspgviisjalxmwdbksmge"=>"2022-11-20"`, + `"ukq"=>"uhkbdj", "bmj"=>"mcoknsnhqcb:vmexvsccu:yt1-nscwdfr-zcp-ajfhr-z8i-ta-jhv-58yl03459t86", "cuq"=>"sqphbh-xkxbcgwdx", "dnac"=>"khzjpq-hljdvlbsw_azdisd-nshizhinc", "flgj"=>"zeem-pggu", "ksnn"=>"vpittgnl-xeojllby-toq", "wwxepg"=>"ki-cwee-1", "vigdnntxw"=>"sydsls-zidlsgugi_wviqvl-umwzyztab", "osz_utmlghi"=>"727006795293", "wacdaefqhc_buqmsci"=>"djtcv", "ljdbotgrsi.xn/gvtjfeg/iiyek"=>"lnfgg", "sohcclfodf.wkwiitult.ppm/hhsf"=>"ecpftm-ecmsibfjy", "dz_cgfnddq_o8j_cowdxlfz_rmjunpm"=>"v5-13-1", "niwk-fozq.tbamcxrhez.kl/zuxnisw"=>"", "h8k.xu/nbsezqz-fopcyqlnwt/lfcmgag"=>"dmm", "zebgpskksd.daigyeicb.dlj/dwmcpkohh"=>"hegecl-bnqmkunkl", "irjreiuove.qpmjixctw.mzv/xizjv-bpecdmy"=>"rkfl", "fupz-eiim.hwaqzvpzgv.yg/zhrqmr-qcydocyak"=>"", "djuscbflju.fmhnephvc.cmo/wzcisia-kqmrrhnkiv"=>"vchu", "hauo-olkeyvbrz.qzpaocu-wdbyfrzjkx.c8a.rn/bwhfe"=>"", "l8d.fj/jzojrmv-mbnxftbdzg/qvgo-oayrldze/tqmoa/oizo"=>"buwgyd-bjlrzrlci:ywosrfsnts", "q8l.sj/vifqvao-ynvfejvleb/ourc-jzridgtt/fgxnueuvm/wsg"=>"7900p", "d8b.mi/steijrv-bgajdbugff/kxkj-jhvctoxw/seyrafhni/xxrc"=>"45", "x8k.bn/dnnkttb-ywqrwwxirk/ngvt-eqyaeqsd/qesxmjfos/nlolbe"=>"14837xp", "v8o.az/vtbyyyo-rjuadsmwyb/gszv-ytnisfau/kfunvihsr/famkeacyo-skpueao"=>"18747532246", "f8w.ip/sjzrxbw-idgsgucprq/ster-zxiilwcf/luwzw/tavccuqfph.mcubdrtcr.ibw/dxnj"=>"ntyjnf-zwlyjqbfq", "y8f.mh/qykpkfr-fsnlckrhpe/hvyu-vstwrxkq/dmesn/kuor-acub.fqwqxcpiet.jf/zaxtdyb"=>"", "c8m.et/ekavnnp-gvpmldvoou/jzva-zzzpiecc/dvckb/qqxrfpoaiy.ssfqerwmb.cnz/odsfndorh"=>"liilkb-aekfuqzss", "e8n.gp/sybrxvz-mghjbpphpc/wcuo-naanbtcj/agtov/dztlgdacuz.fpbhhiybg.ncm/otgfu-hnezrwu"=>"ccez", "t8h.cy/bqsdiil-lxmioonwjt/drsw-qevzljvt/rvzjl/btbz-npvi.ypyxowgmfp.gf/jcfbyh-khpgbaayw"=>"", "y8b.df/anmudfn-gahfengbqw/fhdi-ozqtddmu/lvviu/kndwvowlby.jxkizwkac.hbq/fjkqyna-jijxahivma"=>"wxqg"`, + } + + // convert benchStrings into text and binary bytes + textBytes := make([][]byte, len(benchStrings)) + binaryBytes := make([][]byte, len(benchStrings)) + codec := pgtype.HstoreCodec{}.PlanEncode(nil, 0, pgtype.BinaryFormatCode, pgtype.Hstore(nil)) + for i, s := range benchStrings { + textBytes[i] = []byte(s) - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) + var tempH pgtype.Hstore + err := tempH.Scan(s) if err != nil { - t.Errorf("%d: %v", i, err) + b.Fatal(err) } + binaryBytes[i], err = codec.Encode(tempH, nil) + if err != nil { + b.Fatal(err) + } + } - if !reflect.DeepEqual(*tt.dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + // benchmark the database/sql.Scan API + var h pgtype.Hstore + b.Run("databasesql.Scan", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for _, str := range benchStrings { + err := h.Scan(str) + if err != nil { + b.Fatal(err) + } + } } + }) + + // benchmark the []byte scan API used by pgconn + scanConfigs := []struct { + name string + scanPlan pgtype.ScanPlan + inputBytes [][]byte + }{ + {"text", pgtype.HstoreCodec{}.PlanScan(nil, 0, pgtype.TextFormatCode, &h), textBytes}, + {"binary", pgtype.HstoreCodec{}.PlanScan(nil, 0, pgtype.BinaryFormatCode, &h), binaryBytes}, + } + for _, scanConfig := range scanConfigs { + b.Run(scanConfig.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for _, input := range scanConfig.inputBytes { + err := scanConfig.scanPlan.Scan(input, &h) + if err != nil { + b.Fatalf("input=%#v err=%s", string(input), err) + } + } + } + }) } } diff --git a/pgtype/inet.go b/pgtype/inet.go index 01fc0e5b8..b92edb239 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -1,10 +1,11 @@ package pgtype import ( + "bytes" "database/sql/driver" - "net" - - "github.com/pkg/errors" + "errors" + "fmt" + "net/netip" ) // Network address family is dependent on server socket.h value for AF_INET. @@ -15,201 +16,184 @@ const ( defaultAFInet6 = 3 ) -// Inet represents both inet and cidr PostgreSQL types. -type Inet struct { - IPNet *net.IPNet - Status Status +type NetipPrefixScanner interface { + ScanNetipPrefix(v netip.Prefix) error } -func (dst *Inet) Set(src interface{}) error { - if src == nil { - *dst = Inet{Status: Null} +type NetipPrefixValuer interface { + NetipPrefixValue() (netip.Prefix, error) +} + +// InetCodec handles both inet and cidr PostgreSQL types. The preferred Go types are [netip.Prefix] and [netip.Addr]. If +// IsValid() is false then they are treated as SQL NULL. +type InetCodec struct{} + +func (InetCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (InetCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (InetCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(NetipPrefixValuer); !ok { return nil } - switch value := src.(type) { - case net.IPNet: - *dst = Inet{IPNet: &value, Status: Present} - case *net.IPNet: - *dst = Inet{IPNet: value, Status: Present} - case net.IP: - bitCount := len(value) * 8 - mask := net.CIDRMask(bitCount, bitCount) - *dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Status: Present} - case string: - _, ipnet, err := net.ParseCIDR(value) - if err != nil { - return err - } - *dst = Inet{IPNet: ipnet, Status: Present} - default: - if originalSrc, ok := underlyingPtrType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Inet", value) + switch format { + case BinaryFormatCode: + return encodePlanInetCodecBinary{} + case TextFormatCode: + return encodePlanInetCodecText{} } return nil } -func (dst *Inet) Get() interface{} { - switch dst.Status { - case Present: - return dst.IPNet - case Null: - return nil - default: - return dst.Status - } -} +type encodePlanInetCodecBinary struct{} -func (src *Inet) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *net.IPNet: - *v = net.IPNet{ - IP: make(net.IP, len(src.IPNet.IP)), - Mask: make(net.IPMask, len(src.IPNet.Mask)), - } - copy(v.IP, src.IPNet.IP) - copy(v.Mask, src.IPNet.Mask) - return nil - case *net.IP: - if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { - return errors.Errorf("cannot assign %v to %T", src, dst) - } - *v = make(net.IP, len(src.IPNet.IP)) - copy(*v, src.IPNet.IP) - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) +func (encodePlanInetCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + prefix, err := value.(NetipPrefixValuer).NetipPrefixValue() + if err != nil { + return nil, err } - return errors.Errorf("cannot decode %v into %T", src, dst) -} + if !prefix.IsValid() { + return nil, nil + } -func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Inet{Status: Null} - return nil + var family byte + if prefix.Addr().Is4() { + family = defaultAFInet + } else { + family = defaultAFInet6 } - var ipnet *net.IPNet - var err error + buf = append(buf, family) - if ip := net.ParseIP(string(src)); ip != nil { - ipv4 := ip.To4() - if ipv4 != nil { - ip = ipv4 - } - bitCount := len(ip) * 8 - mask := net.CIDRMask(bitCount, bitCount) - ipnet = &net.IPNet{Mask: mask, IP: ip} + ones := prefix.Bits() + buf = append(buf, byte(ones)) + + // is_cidr is ignored on server + buf = append(buf, 0) + + if family == defaultAFInet { + buf = append(buf, byte(4)) + b := prefix.Addr().As4() + buf = append(buf, b[:]...) } else { - _, ipnet, err = net.ParseCIDR(string(src)) - if err != nil { - return err - } + buf = append(buf, byte(16)) + b := prefix.Addr().As16() + buf = append(buf, b[:]...) } - *dst = Inet{IPNet: ipnet, Status: Present} - return nil + return buf, nil } -func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Inet{Status: Null} - return nil - } +type encodePlanInetCodecText struct{} - if len(src) != 8 && len(src) != 20 { - return errors.Errorf("Received an invalid size for a inet: %d", len(src)) +func (encodePlanInetCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + prefix, err := value.(NetipPrefixValuer).NetipPrefixValue() + if err != nil { + return nil, err } - // ignore family - bits := src[1] - // ignore is_cidr - addressLength := src[3] + if !prefix.IsValid() { + return nil, nil + } - var ipnet net.IPNet - ipnet.IP = make(net.IP, int(addressLength)) - copy(ipnet.IP, src[4:]) - ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) + return append(buf, prefix.String()...), nil +} - *dst = Inet{IPNet: &ipnet, Status: Present} +func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case NetipPrefixScanner: + return scanPlanBinaryInetToNetipPrefixScanner{} + } + case TextFormatCode: + switch target.(type) { + case NetipPrefixScanner: + return scanPlanTextAnyToNetipPrefixScanner{} + } + } return nil } -func (src *Inet) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: +func (c InetCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c InetCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { return nil, nil - case Undefined: - return nil, errUndefined } - return append(buf, src.IPNet.String()...), nil -} + var prefix netip.Prefix + err := codecScan(c, m, oid, format, src, (*netipPrefixWrapper)(&prefix)) + if err != nil { + return nil, err + } -// EncodeBinary encodes src into w. -func (src *Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: + if !prefix.IsValid() { return nil, nil - case Undefined: - return nil, errUndefined } - var family byte - switch len(src.IPNet.IP) { - case net.IPv4len: - family = defaultAFInet - case net.IPv6len: - family = defaultAFInet6 - default: - return nil, errors.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) - } + return prefix, nil +} - buf = append(buf, family) +type scanPlanBinaryInetToNetipPrefixScanner struct{} - ones, _ := src.IPNet.Mask.Size() - buf = append(buf, byte(ones)) +func (scanPlanBinaryInetToNetipPrefixScanner) Scan(src []byte, dst any) error { + scanner := (dst).(NetipPrefixScanner) - // is_cidr is ignored on server - buf = append(buf, 0) + if src == nil { + return scanner.ScanNetipPrefix(netip.Prefix{}) + } - buf = append(buf, byte(len(src.IPNet.IP))) + if len(src) != 8 && len(src) != 20 { + return fmt.Errorf("Received an invalid size for an inet: %d", len(src)) + } - return append(buf, src.IPNet.IP...), nil + // ignore family + bits := src[1] + // ignore is_cidr + // ignore addressLength - implicit in length of message + + addr, ok := netip.AddrFromSlice(src[4:]) + if !ok { + return errors.New("netip.AddrFromSlice failed") + } + + return scanner.ScanNetipPrefix(netip.PrefixFrom(addr, int(bits))) } -// Scan implements the database/sql Scanner interface. -func (dst *Inet) Scan(src interface{}) error { +type scanPlanTextAnyToNetipPrefixScanner struct{} + +func (scanPlanTextAnyToNetipPrefixScanner) Scan(src []byte, dst any) error { + scanner := (dst).(NetipPrefixScanner) + if src == nil { - *dst = Inet{Status: Null} - return nil + return scanner.ScanNetipPrefix(netip.Prefix{}) } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + var prefix netip.Prefix + if bytes.IndexByte(src, '/') == -1 { + addr, err := netip.ParseAddr(string(src)) + if err != nil { + return err + } + prefix = netip.PrefixFrom(addr, addr.BitLen()) + } else { + var err error + prefix, err = netip.ParsePrefix(string(src)) + if err != nil { + return err + } } - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Inet) Value() (driver.Value, error) { - return EncodeValueText(src) + return scanner.ScanNetipPrefix(prefix) } diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go deleted file mode 100644 index f3e4efbfb..000000000 --- a/pgtype/inet_array.go +++ /dev/null @@ -1,329 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "net" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type InetArray struct { - Elements []Inet - Dimensions []ArrayDimension - Status Status -} - -func (dst *InetArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = InetArray{Status: Null} - return nil - } - - switch value := src.(type) { - - case []*net.IPNet: - if value == nil { - *dst = InetArray{Status: Null} - } else if len(value) == 0 { - *dst = InetArray{Status: Present} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []net.IP: - if value == nil { - *dst = InetArray{Status: Null} - } else if len(value) == 0 { - *dst = InetArray{Status: Present} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to InetArray", value) - } - - return nil -} - -func (dst *InetArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *InetArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]*net.IPNet: - *v = make([]*net.IPNet, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]net.IP: - *v = make([]net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = InetArray{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Inet - - if len(uta.Elements) > 0 { - elements = make([]Inet, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Inet - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = InetArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = InetArray{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = InetArray{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Inet, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = InetArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("inet"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "inet") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *InetArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *InetArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/inet_array_test.go b/pgtype/inet_array_test.go deleted file mode 100644 index ca528ed35..000000000 --- a/pgtype/inet_array_test.go +++ /dev/null @@ -1,165 +0,0 @@ -package pgtype_test - -import ( - "net" - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestInetArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "inet[]", []interface{}{ - &pgtype.InetArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.InetArray{Status: pgtype.Null}, - &pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - {Status: pgtype.Null}, - {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestInetArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.InetArray - }{ - { - source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]*net.IPNet)(nil)), - result: pgtype.InetArray{Status: pgtype.Null}, - }, - { - source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]net.IP)(nil)), - result: pgtype.InetArray{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.InetArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInetArrayAssignTo(t *testing.T) { - var ipnetSlice []*net.IPNet - var ipSlice []net.IP - - simpleTests := []struct { - src pgtype.InetArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &ipnetSlice, - expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &ipnetSlice, - expected: []*net.IPNet{nil}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &ipSlice, - expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &ipSlice, - expected: []net.IP{nil}, - }, - { - src: pgtype.InetArray{Status: pgtype.Null}, - dst: &ipnetSlice, - expected: (([]*net.IPNet)(nil)), - }, - { - src: pgtype.InetArray{Status: pgtype.Null}, - dst: &ipSlice, - expected: (([]net.IP)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go index 32d669996..f4b43dafe 100644 --- a/pgtype/inet_test.go +++ b/pgtype/inet_test.go @@ -1,115 +1,99 @@ package pgtype_test import ( + "context" "net" - "reflect" + "net/netip" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) -func TestInetTranscode(t *testing.T) { - for _, pgTypeName := range []string{"inet", "cidr"} { - testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ - &pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - &pgtype.Inet{Status: pgtype.Null}, - }) - } -} +func isExpectedEqIPNet(a any) func(any) bool { + return func(v any) bool { + ap := a.(*net.IPNet) + vp := v.(net.IPNet) -func TestInetSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Inet - }{ - {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.Inet - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } + return ap.IP.Equal(vp.IP) && ap.Mask.String() == vp.Mask.String() } } -func TestInetAssignTo(t *testing.T) { - var ipnet net.IPNet - var pipnet *net.IPNet - var ip net.IP - var pip *net.IP +func TestInetTranscode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "inet", []pgxtest.ValueRoundTripTest{ + {mustParseInet(t, "0.0.0.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "0.0.0.0/32"))}, + {mustParseInet(t, "127.0.0.1/8"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "127.0.0.1/8"))}, + {mustParseInet(t, "12.34.56.65/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "12.34.56.65/32"))}, + {mustParseInet(t, "192.168.1.16/24"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "192.168.1.16/24"))}, + {mustParseInet(t, "255.0.0.0/8"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "255.0.0.0/8"))}, + {mustParseInet(t, "255.255.255.255/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "255.255.255.255/32"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "2607:f8b0:4009:80b::200e"))}, + {mustParseInet(t, "::1/64"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::1/64"))}, + {mustParseInet(t, "::/0"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::/0"))}, + {mustParseInet(t, "::1/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::1/128"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "2607:f8b0:4009:80b::200e/64"))}, - simpleTests := []struct { - src pgtype.Inet - dst interface{} - expected interface{} - }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, - {src: pgtype.Inet{Status: pgtype.Null}, dst: &pipnet, expected: ((*net.IPNet)(nil))}, - {src: pgtype.Inet{Status: pgtype.Null}, dst: &pip, expected: ((*net.IP)(nil))}, - } + {mustParseInet(t, "0.0.0.0/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("0.0.0.0/32"))}, + {mustParseInet(t, "127.0.0.1/8"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("127.0.0.1/8"))}, + {mustParseInet(t, "12.34.56.65/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("12.34.56.65/32"))}, + {mustParseInet(t, "192.168.1.16/24"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("192.168.1.16/24"))}, + {mustParseInet(t, "255.0.0.0/8"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("255.0.0.0/8"))}, + {mustParseInet(t, "255.255.255.255/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("255.255.255.255/32"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("2607:f8b0:4009:80b::200e/128"))}, + {mustParseInet(t, "::1/64"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::1/64"))}, + {mustParseInet(t, "::/0"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::/0"))}, + {mustParseInet(t, "::1/128"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::1/128"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("2607:f8b0:4009:80b::200e/64"))}, - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } + {netip.MustParsePrefix("0.0.0.0/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("0.0.0.0/32"))}, + {netip.MustParsePrefix("127.0.0.1/8"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("127.0.0.1/8"))}, + {netip.MustParsePrefix("12.34.56.65/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("12.34.56.65/32"))}, + {netip.MustParsePrefix("192.168.1.16/24"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("192.168.1.16/24"))}, + {netip.MustParsePrefix("255.0.0.0/8"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("255.0.0.0/8"))}, + {netip.MustParsePrefix("255.255.255.255/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("255.255.255.255/32"))}, + {netip.MustParsePrefix("::1/64"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::1/64"))}, + {netip.MustParsePrefix("::/0"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::/0"))}, + {netip.MustParsePrefix("::1/128"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::1/128"))}, + {netip.MustParsePrefix("2607:f8b0:4009:80b::200e/64"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("2607:f8b0:4009:80b::200e/64"))}, - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %#v, but result was %#v", i, tt.src, tt.expected, dst) - } - } + {netip.MustParseAddr("0.0.0.0"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("0.0.0.0"))}, + {netip.MustParseAddr("127.0.0.1"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("127.0.0.1"))}, + {netip.MustParseAddr("12.34.56.65"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("12.34.56.65"))}, + {netip.MustParseAddr("192.168.1.16"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("192.168.1.16"))}, + {netip.MustParseAddr("255.0.0.0"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("255.0.0.0"))}, + {netip.MustParseAddr("255.255.255.255"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("255.255.255.255"))}, + {netip.MustParseAddr("2607:f8b0:4009:80b::200e"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("2607:f8b0:4009:80b::200e"))}, + {netip.MustParseAddr("::1"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("::1"))}, + {netip.MustParseAddr("::"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("::"))}, + {netip.MustParseAddr("2607:f8b0:4009:80b::200e"), new(netip.Addr), isExpectedEq(netip.MustParseAddr("2607:f8b0:4009:80b::200e"))}, - pointerAllocTests := []struct { - src pgtype.Inet - dst interface{} - expected interface{} - }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, - } + {nil, new(netip.Prefix), isExpectedEq(netip.Prefix{})}, + }) +} - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } +func TestCidrTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support cidr type (see https://github.com/cockroachdb/cockroach/issues/18846)") - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "cidr", []pgxtest.ValueRoundTripTest{ + {mustParseInet(t, "0.0.0.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "0.0.0.0/32"))}, + {mustParseInet(t, "127.0.0.1/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "127.0.0.1/32"))}, + {mustParseInet(t, "12.34.56.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "12.34.56.0/32"))}, + {mustParseInet(t, "192.168.1.0/24"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "192.168.1.0/24"))}, + {mustParseInet(t, "255.0.0.0/8"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "255.0.0.0/8"))}, + {mustParseInet(t, "::/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::/128"))}, + {mustParseInet(t, "::/0"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::/0"))}, + {mustParseInet(t, "::1/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::1/128"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "2607:f8b0:4009:80b::200e/128"))}, - errorTests := []struct { - src pgtype.Inet - dst interface{} - }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip}, - {src: pgtype.Inet{Status: pgtype.Null}, dst: &ipnet}, - } + {netip.MustParsePrefix("0.0.0.0/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("0.0.0.0/32"))}, + {netip.MustParsePrefix("127.0.0.1/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("127.0.0.1/32"))}, + {netip.MustParsePrefix("12.34.56.0/32"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("12.34.56.0/32"))}, + {netip.MustParsePrefix("192.168.1.0/24"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("192.168.1.0/24"))}, + {netip.MustParsePrefix("255.0.0.0/8"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("255.0.0.0/8"))}, + {netip.MustParsePrefix("::/128"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::/128"))}, + {netip.MustParsePrefix("::/0"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::/0"))}, + {netip.MustParsePrefix("::1/128"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("::1/128"))}, + {netip.MustParsePrefix("2607:f8b0:4009:80b::200e/128"), new(netip.Prefix), isExpectedEq(netip.MustParsePrefix("2607:f8b0:4009:80b::200e/128"))}, - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } + {nil, new(netip.Prefix), isExpectedEq(netip.Prefix{})}, + }) } diff --git a/pgtype/int.go b/pgtype/int.go new file mode 100644 index 000000000..d1b8eb612 --- /dev/null +++ b/pgtype/int.go @@ -0,0 +1,1990 @@ +// Code generated from pgtype/int.go.erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Int64Scanner interface { + ScanInt64(Int8) error +} + +type Int64Valuer interface { + Int64Value() (Int8, error) +} + +type Int2 struct { + Int16 int16 + Valid bool +} + +// ScanInt64 implements the [Int64Scanner] interface. +func (dst *Int2) ScanInt64(n Int8) error { + if !n.Valid { + *dst = Int2{} + return nil + } + + if n.Int64 < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for Int2", n.Int64) + } + if n.Int64 > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n.Int64) + } + *dst = Int2{Int16: int16(n.Int64), Valid: true} + + return nil +} + +// Int64Value implements the [Int64Valuer] interface. +func (n Int2) Int64Value() (Int8, error) { + return Int8{Int64: int64(n.Int16), Valid: n.Valid}, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *Int2) Scan(src any) error { + if src == nil { + *dst = Int2{} + return nil + } + + var n int64 + + switch src := src.(type) { + case int64: + n = src + case string: + var err error + n, err = strconv.ParseInt(src, 10, 16) + if err != nil { + return err + } + case []byte: + var err error + n, err = strconv.ParseInt(string(src), 10, 16) + if err != nil { + return err + } + default: + return fmt.Errorf("cannot scan %T", src) + } + + if n < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) + } + if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) + } + *dst = Int2{Int16: int16(n), Valid: true} + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src Int2) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Int16), nil +} + +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (src Int2) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(int64(src.Int16), 10)), nil +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (dst *Int2) UnmarshalJSON(b []byte) error { + var n *int16 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int2{} + } else { + *dst = Int2{Int16: *n, Valid: true} + } + + return nil +} + +type Int2Codec struct{} + +func (Int2Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Int2Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Int2Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case int16: + return encodePlanInt2CodecBinaryInt16{} + case Int64Valuer: + return encodePlanInt2CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case int16: + return encodePlanInt2CodecTextInt16{} + case Int64Valuer: + return encodePlanInt2CodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanInt2CodecBinaryInt16 struct{} + +func (encodePlanInt2CodecBinaryInt16) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(int16) + return pgio.AppendInt16(buf, int16(n)), nil +} + +type encodePlanInt2CodecTextInt16 struct{} + +func (encodePlanInt2CodecTextInt16) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(int16) + return append(buf, strconv.FormatInt(int64(n), 10)...), nil +} + +type encodePlanInt2CodecBinaryInt64Valuer struct{} + +func (encodePlanInt2CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int64 > math.MaxInt16 { + return nil, fmt.Errorf("%d is greater than maximum value for int2", n.Int64) + } + if n.Int64 < math.MinInt16 { + return nil, fmt.Errorf("%d is less than minimum value for int2", n.Int64) + } + + return pgio.AppendInt16(buf, int16(n.Int64)), nil +} + +type encodePlanInt2CodecTextInt64Valuer struct{} + +func (encodePlanInt2CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int64 > math.MaxInt16 { + return nil, fmt.Errorf("%d is greater than maximum value for int2", n.Int64) + } + if n.Int64 < math.MinInt16 { + return nil, fmt.Errorf("%d is less than minimum value for int2", n.Int64) + } + + return append(buf, strconv.FormatInt(n.Int64, 10)...), nil +} + +func (Int2Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *int8: + return scanPlanBinaryInt2ToInt8{} + case *int16: + return scanPlanBinaryInt2ToInt16{} + case *int32: + return scanPlanBinaryInt2ToInt32{} + case *int64: + return scanPlanBinaryInt2ToInt64{} + case *int: + return scanPlanBinaryInt2ToInt{} + case *uint8: + return scanPlanBinaryInt2ToUint8{} + case *uint16: + return scanPlanBinaryInt2ToUint16{} + case *uint32: + return scanPlanBinaryInt2ToUint32{} + case *uint64: + return scanPlanBinaryInt2ToUint64{} + case *uint: + return scanPlanBinaryInt2ToUint{} + case Int64Scanner: + return scanPlanBinaryInt2ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryInt2ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *int8: + return scanPlanTextAnyToInt8{} + case *int16: + return scanPlanTextAnyToInt16{} + case *int32: + return scanPlanTextAnyToInt32{} + case *int64: + return scanPlanTextAnyToInt64{} + case *int: + return scanPlanTextAnyToInt{} + case *uint8: + return scanPlanTextAnyToUint8{} + case *uint16: + return scanPlanTextAnyToUint16{} + case *uint32: + return scanPlanTextAnyToUint32{} + case *uint64: + return scanPlanTextAnyToUint64{} + case *uint: + return scanPlanTextAnyToUint{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +func (c Int2Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n int64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +func (c Int2Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n int16 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +type scanPlanBinaryInt2ToInt8 struct{} + +func (scanPlanBinaryInt2ToInt8) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int16(binary.BigEndian.Uint16(src)) + if n < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", n) + } else if n > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", n) + } + + *p = int8(n) + + return nil +} + +type scanPlanBinaryInt2ToUint8 struct{} + +func (scanPlanBinaryInt2ToUint8) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for uint2: %v", len(src)) + } + + p, ok := (dst).(*uint8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int16(binary.BigEndian.Uint16(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint8", n) + } + + if n > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", n) + } + + *p = uint8(n) + + return nil +} + +type scanPlanBinaryInt2ToInt16 struct{} + +func (scanPlanBinaryInt2ToInt16) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int16(binary.BigEndian.Uint16(src)) + + return nil +} + +type scanPlanBinaryInt2ToUint16 struct{} + +func (scanPlanBinaryInt2ToUint16) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for uint2: %v", len(src)) + } + + p, ok := (dst).(*uint16) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int16(binary.BigEndian.Uint16(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint16", n) + } + + *p = uint16(n) + + return nil +} + +type scanPlanBinaryInt2ToInt32 struct{} + +func (scanPlanBinaryInt2ToInt32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int32) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int32(int16(binary.BigEndian.Uint16(src))) + + return nil +} + +type scanPlanBinaryInt2ToUint32 struct{} + +func (scanPlanBinaryInt2ToUint32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for uint2: %v", len(src)) + } + + p, ok := (dst).(*uint32) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int16(binary.BigEndian.Uint16(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint32", n) + } + + *p = uint32(n) + + return nil +} + +type scanPlanBinaryInt2ToInt64 struct{} + +func (scanPlanBinaryInt2ToInt64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int64) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int64(int16(binary.BigEndian.Uint16(src))) + + return nil +} + +type scanPlanBinaryInt2ToUint64 struct{} + +func (scanPlanBinaryInt2ToUint64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for uint2: %v", len(src)) + } + + p, ok := (dst).(*uint64) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int16(binary.BigEndian.Uint16(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", n) + } + + *p = uint64(n) + + return nil +} + +type scanPlanBinaryInt2ToInt struct{} + +func (scanPlanBinaryInt2ToInt) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int(int16(binary.BigEndian.Uint16(src))) + + return nil +} + +type scanPlanBinaryInt2ToUint struct{} + +func (scanPlanBinaryInt2ToUint) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for uint2: %v", len(src)) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(int16(binary.BigEndian.Uint16(src))) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint", n) + } + + *p = uint(n) + + return nil +} + +type scanPlanBinaryInt2ToInt64Scanner struct{} + +func (scanPlanBinaryInt2ToInt64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(Int8{}) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + n := int64(int16(binary.BigEndian.Uint16(src))) + + return s.ScanInt64(Int8{Int64: n, Valid: true}) +} + +type scanPlanBinaryInt2ToTextScanner struct{} + +func (scanPlanBinaryInt2ToTextScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + n := int64(int16(binary.BigEndian.Uint16(src))) + + return s.ScanText(Text{String: strconv.FormatInt(n, 10), Valid: true}) +} + +type Int4 struct { + Int32 int32 + Valid bool +} + +// ScanInt64 implements the [Int64Scanner] interface. +func (dst *Int4) ScanInt64(n Int8) error { + if !n.Valid { + *dst = Int4{} + return nil + } + + if n.Int64 < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for Int4", n.Int64) + } + if n.Int64 > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n.Int64) + } + *dst = Int4{Int32: int32(n.Int64), Valid: true} + + return nil +} + +// Int64Value implements the [Int64Valuer] interface. +func (n Int4) Int64Value() (Int8, error) { + return Int8{Int64: int64(n.Int32), Valid: n.Valid}, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *Int4) Scan(src any) error { + if src == nil { + *dst = Int4{} + return nil + } + + var n int64 + + switch src := src.(type) { + case int64: + n = src + case string: + var err error + n, err = strconv.ParseInt(src, 10, 32) + if err != nil { + return err + } + case []byte: + var err error + n, err = strconv.ParseInt(string(src), 10, 32) + if err != nil { + return err + } + default: + return fmt.Errorf("cannot scan %T", src) + } + + if n < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n) + } + if n > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n) + } + *dst = Int4{Int32: int32(n), Valid: true} + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src Int4) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Int32), nil +} + +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (src Int4) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(int64(src.Int32), 10)), nil +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (dst *Int4) UnmarshalJSON(b []byte) error { + var n *int32 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int4{} + } else { + *dst = Int4{Int32: *n, Valid: true} + } + + return nil +} + +type Int4Codec struct{} + +func (Int4Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Int4Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Int4Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case int32: + return encodePlanInt4CodecBinaryInt32{} + case Int64Valuer: + return encodePlanInt4CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case int32: + return encodePlanInt4CodecTextInt32{} + case Int64Valuer: + return encodePlanInt4CodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanInt4CodecBinaryInt32 struct{} + +func (encodePlanInt4CodecBinaryInt32) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(int32) + return pgio.AppendInt32(buf, int32(n)), nil +} + +type encodePlanInt4CodecTextInt32 struct{} + +func (encodePlanInt4CodecTextInt32) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(int32) + return append(buf, strconv.FormatInt(int64(n), 10)...), nil +} + +type encodePlanInt4CodecBinaryInt64Valuer struct{} + +func (encodePlanInt4CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int64 > math.MaxInt32 { + return nil, fmt.Errorf("%d is greater than maximum value for int4", n.Int64) + } + if n.Int64 < math.MinInt32 { + return nil, fmt.Errorf("%d is less than minimum value for int4", n.Int64) + } + + return pgio.AppendInt32(buf, int32(n.Int64)), nil +} + +type encodePlanInt4CodecTextInt64Valuer struct{} + +func (encodePlanInt4CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int64 > math.MaxInt32 { + return nil, fmt.Errorf("%d is greater than maximum value for int4", n.Int64) + } + if n.Int64 < math.MinInt32 { + return nil, fmt.Errorf("%d is less than minimum value for int4", n.Int64) + } + + return append(buf, strconv.FormatInt(n.Int64, 10)...), nil +} + +func (Int4Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *int8: + return scanPlanBinaryInt4ToInt8{} + case *int16: + return scanPlanBinaryInt4ToInt16{} + case *int32: + return scanPlanBinaryInt4ToInt32{} + case *int64: + return scanPlanBinaryInt4ToInt64{} + case *int: + return scanPlanBinaryInt4ToInt{} + case *uint8: + return scanPlanBinaryInt4ToUint8{} + case *uint16: + return scanPlanBinaryInt4ToUint16{} + case *uint32: + return scanPlanBinaryInt4ToUint32{} + case *uint64: + return scanPlanBinaryInt4ToUint64{} + case *uint: + return scanPlanBinaryInt4ToUint{} + case Int64Scanner: + return scanPlanBinaryInt4ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryInt4ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *int8: + return scanPlanTextAnyToInt8{} + case *int16: + return scanPlanTextAnyToInt16{} + case *int32: + return scanPlanTextAnyToInt32{} + case *int64: + return scanPlanTextAnyToInt64{} + case *int: + return scanPlanTextAnyToInt{} + case *uint8: + return scanPlanTextAnyToUint8{} + case *uint16: + return scanPlanTextAnyToUint16{} + case *uint32: + return scanPlanTextAnyToUint32{} + case *uint64: + return scanPlanTextAnyToUint64{} + case *uint: + return scanPlanTextAnyToUint{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +func (c Int4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n int64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +func (c Int4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n int32 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +type scanPlanBinaryInt4ToInt8 struct{} + +func (scanPlanBinaryInt4ToInt8) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + p, ok := (dst).(*int8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", n) + } else if n > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", n) + } + + *p = int8(n) + + return nil +} + +type scanPlanBinaryInt4ToUint8 struct{} + +func (scanPlanBinaryInt4ToUint8) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint4: %v", len(src)) + } + + p, ok := (dst).(*uint8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint8", n) + } + + if n > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", n) + } + + *p = uint8(n) + + return nil +} + +type scanPlanBinaryInt4ToInt16 struct{} + +func (scanPlanBinaryInt4ToInt16) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", n) + } else if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", n) + } + + *p = int16(n) + + return nil +} + +type scanPlanBinaryInt4ToUint16 struct{} + +func (scanPlanBinaryInt4ToUint16) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint4: %v", len(src)) + } + + p, ok := (dst).(*uint16) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint16", n) + } + + if n > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", n) + } + + *p = uint16(n) + + return nil +} + +type scanPlanBinaryInt4ToInt32 struct{} + +func (scanPlanBinaryInt4ToInt32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + p, ok := (dst).(*int32) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int32(binary.BigEndian.Uint32(src)) + + return nil +} + +type scanPlanBinaryInt4ToUint32 struct{} + +func (scanPlanBinaryInt4ToUint32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint4: %v", len(src)) + } + + p, ok := (dst).(*uint32) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint32", n) + } + + *p = uint32(n) + + return nil +} + +type scanPlanBinaryInt4ToInt64 struct{} + +func (scanPlanBinaryInt4ToInt64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + p, ok := (dst).(*int64) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int64(int32(binary.BigEndian.Uint32(src))) + + return nil +} + +type scanPlanBinaryInt4ToUint64 struct{} + +func (scanPlanBinaryInt4ToUint64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint4: %v", len(src)) + } + + p, ok := (dst).(*uint64) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int32(binary.BigEndian.Uint32(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", n) + } + + *p = uint64(n) + + return nil +} + +type scanPlanBinaryInt4ToInt struct{} + +func (scanPlanBinaryInt4ToInt) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int(int32(binary.BigEndian.Uint32(src))) + + return nil +} + +type scanPlanBinaryInt4ToUint struct{} + +func (scanPlanBinaryInt4ToUint) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint4: %v", len(src)) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(int32(binary.BigEndian.Uint32(src))) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint", n) + } + + *p = uint(n) + + return nil +} + +type scanPlanBinaryInt4ToInt64Scanner struct{} + +func (scanPlanBinaryInt4ToInt64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(Int8{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + n := int64(int32(binary.BigEndian.Uint32(src))) + + return s.ScanInt64(Int8{Int64: n, Valid: true}) +} + +type scanPlanBinaryInt4ToTextScanner struct{} + +func (scanPlanBinaryInt4ToTextScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + n := int64(int32(binary.BigEndian.Uint32(src))) + + return s.ScanText(Text{String: strconv.FormatInt(n, 10), Valid: true}) +} + +type Int8 struct { + Int64 int64 + Valid bool +} + +// ScanInt64 implements the [Int64Scanner] interface. +func (dst *Int8) ScanInt64(n Int8) error { + if !n.Valid { + *dst = Int8{} + return nil + } + + if n.Int64 < math.MinInt64 { + return fmt.Errorf("%d is less than minimum value for Int8", n.Int64) + } + if n.Int64 > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n.Int64) + } + *dst = Int8{Int64: int64(n.Int64), Valid: true} + + return nil +} + +// Int64Value implements the [Int64Valuer] interface. +func (n Int8) Int64Value() (Int8, error) { + return Int8{Int64: int64(n.Int64), Valid: n.Valid}, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *Int8) Scan(src any) error { + if src == nil { + *dst = Int8{} + return nil + } + + var n int64 + + switch src := src.(type) { + case int64: + n = src + case string: + var err error + n, err = strconv.ParseInt(src, 10, 64) + if err != nil { + return err + } + case []byte: + var err error + n, err = strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + default: + return fmt.Errorf("cannot scan %T", src) + } + + if n < math.MinInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n) + } + if n > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n) + } + *dst = Int8{Int64: int64(n), Valid: true} + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src Int8) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Int64), nil +} + +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (src Int8) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(int64(src.Int64), 10)), nil +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (dst *Int8) UnmarshalJSON(b []byte) error { + var n *int64 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int8{} + } else { + *dst = Int8{Int64: *n, Valid: true} + } + + return nil +} + +type Int8Codec struct{} + +func (Int8Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Int8Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Int8Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case int64: + return encodePlanInt8CodecBinaryInt64{} + case Int64Valuer: + return encodePlanInt8CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case int64: + return encodePlanInt8CodecTextInt64{} + case Int64Valuer: + return encodePlanInt8CodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanInt8CodecBinaryInt64 struct{} + +func (encodePlanInt8CodecBinaryInt64) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(int64) + return pgio.AppendInt64(buf, int64(n)), nil +} + +type encodePlanInt8CodecTextInt64 struct{} + +func (encodePlanInt8CodecTextInt64) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(int64) + return append(buf, strconv.FormatInt(int64(n), 10)...), nil +} + +type encodePlanInt8CodecBinaryInt64Valuer struct{} + +func (encodePlanInt8CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int64 > math.MaxInt64 { + return nil, fmt.Errorf("%d is greater than maximum value for int8", n.Int64) + } + if n.Int64 < math.MinInt64 { + return nil, fmt.Errorf("%d is less than minimum value for int8", n.Int64) + } + + return pgio.AppendInt64(buf, int64(n.Int64)), nil +} + +type encodePlanInt8CodecTextInt64Valuer struct{} + +func (encodePlanInt8CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int64 > math.MaxInt64 { + return nil, fmt.Errorf("%d is greater than maximum value for int8", n.Int64) + } + if n.Int64 < math.MinInt64 { + return nil, fmt.Errorf("%d is less than minimum value for int8", n.Int64) + } + + return append(buf, strconv.FormatInt(n.Int64, 10)...), nil +} + +func (Int8Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *int8: + return scanPlanBinaryInt8ToInt8{} + case *int16: + return scanPlanBinaryInt8ToInt16{} + case *int32: + return scanPlanBinaryInt8ToInt32{} + case *int64: + return scanPlanBinaryInt8ToInt64{} + case *int: + return scanPlanBinaryInt8ToInt{} + case *uint8: + return scanPlanBinaryInt8ToUint8{} + case *uint16: + return scanPlanBinaryInt8ToUint16{} + case *uint32: + return scanPlanBinaryInt8ToUint32{} + case *uint64: + return scanPlanBinaryInt8ToUint64{} + case *uint: + return scanPlanBinaryInt8ToUint{} + case Int64Scanner: + return scanPlanBinaryInt8ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryInt8ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *int8: + return scanPlanTextAnyToInt8{} + case *int16: + return scanPlanTextAnyToInt16{} + case *int32: + return scanPlanTextAnyToInt32{} + case *int64: + return scanPlanTextAnyToInt64{} + case *int: + return scanPlanTextAnyToInt{} + case *uint8: + return scanPlanTextAnyToUint8{} + case *uint16: + return scanPlanTextAnyToUint16{} + case *uint32: + return scanPlanTextAnyToUint32{} + case *uint64: + return scanPlanTextAnyToUint64{} + case *uint: + return scanPlanTextAnyToUint{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +func (c Int8Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n int64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +func (c Int8Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n int64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +type scanPlanBinaryInt8ToInt8 struct{} + +func (scanPlanBinaryInt8ToInt8) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + p, ok := (dst).(*int8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", n) + } else if n > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", n) + } + + *p = int8(n) + + return nil +} + +type scanPlanBinaryInt8ToUint8 struct{} + +func (scanPlanBinaryInt8ToUint8) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint8: %v", len(src)) + } + + p, ok := (dst).(*uint8) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint8", n) + } + + if n > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", n) + } + + *p = uint8(n) + + return nil +} + +type scanPlanBinaryInt8ToInt16 struct{} + +func (scanPlanBinaryInt8ToInt16) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", n) + } else if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", n) + } + + *p = int16(n) + + return nil +} + +type scanPlanBinaryInt8ToUint16 struct{} + +func (scanPlanBinaryInt8ToUint16) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint8: %v", len(src)) + } + + p, ok := (dst).(*uint16) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint16", n) + } + + if n > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", n) + } + + *p = uint16(n) + + return nil +} + +type scanPlanBinaryInt8ToInt32 struct{} + +func (scanPlanBinaryInt8ToInt32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + p, ok := (dst).(*int32) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", n) + } else if n > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", n) + } + + *p = int32(n) + + return nil +} + +type scanPlanBinaryInt8ToUint32 struct{} + +func (scanPlanBinaryInt8ToUint32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint8: %v", len(src)) + } + + p, ok := (dst).(*uint32) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint32", n) + } + + if n > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", n) + } + + *p = uint32(n) + + return nil +} + +type scanPlanBinaryInt8ToInt64 struct{} + +func (scanPlanBinaryInt8ToInt64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + p, ok := (dst).(*int64) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int64(binary.BigEndian.Uint64(src)) + + return nil +} + +type scanPlanBinaryInt8ToUint64 struct{} + +func (scanPlanBinaryInt8ToUint64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint8: %v", len(src)) + } + + p, ok := (dst).(*uint64) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", n) + } + + *p = uint64(n) + + return nil +} + +type scanPlanBinaryInt8ToInt struct{} + +func (scanPlanBinaryInt8ToInt) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(binary.BigEndian.Uint64(src)) + if n < math.MinInt { + return fmt.Errorf("%d is less than minimum value for int", n) + } else if n > math.MaxInt { + return fmt.Errorf("%d is greater than maximum value for int", n) + } + + *p = int(n) + + return nil +} + +type scanPlanBinaryInt8ToUint struct{} + +func (scanPlanBinaryInt8ToUint) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint8: %v", len(src)) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(int64(binary.BigEndian.Uint64(src))) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint", n) + } + + if uint64(n) > math.MaxUint { + return fmt.Errorf("%d is greater than maximum value for uint", n) + } + + *p = uint(n) + + return nil +} + +type scanPlanBinaryInt8ToInt64Scanner struct{} + +func (scanPlanBinaryInt8ToInt64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(Int8{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + n := int64(int64(binary.BigEndian.Uint64(src))) + + return s.ScanInt64(Int8{Int64: n, Valid: true}) +} + +type scanPlanBinaryInt8ToTextScanner struct{} + +func (scanPlanBinaryInt8ToTextScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + n := int64(int64(binary.BigEndian.Uint64(src))) + + return s.ScanText(Text{String: strconv.FormatInt(n, 10), Valid: true}) +} + +type scanPlanTextAnyToInt8 struct{} + +func (scanPlanTextAnyToInt8) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*int8) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 8) + if err != nil { + return err + } + + *p = int8(n) + return nil +} + +type scanPlanTextAnyToUint8 struct{} + +func (scanPlanTextAnyToUint8) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*uint8) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 8) + if err != nil { + return err + } + + *p = uint8(n) + return nil +} + +type scanPlanTextAnyToInt16 struct{} + +func (scanPlanTextAnyToInt16) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 16) + if err != nil { + return err + } + + *p = int16(n) + return nil +} + +type scanPlanTextAnyToUint16 struct{} + +func (scanPlanTextAnyToUint16) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*uint16) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 16) + if err != nil { + return err + } + + *p = uint16(n) + return nil +} + +type scanPlanTextAnyToInt32 struct{} + +func (scanPlanTextAnyToInt32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*int32) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 32) + if err != nil { + return err + } + + *p = int32(n) + return nil +} + +type scanPlanTextAnyToUint32 struct{} + +func (scanPlanTextAnyToUint32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*uint32) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + *p = uint32(n) + return nil +} + +type scanPlanTextAnyToInt64 struct{} + +func (scanPlanTextAnyToInt64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*int64) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + + *p = int64(n) + return nil +} + +type scanPlanTextAnyToUint64 struct{} + +func (scanPlanTextAnyToUint64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*uint64) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 64) + if err != nil { + return err + } + + *p = uint64(n) + return nil +} + +type scanPlanTextAnyToInt struct{} + +func (scanPlanTextAnyToInt) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 0) + if err != nil { + return err + } + + *p = int(n) + return nil +} + +type scanPlanTextAnyToUint struct{} + +func (scanPlanTextAnyToUint) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 0) + if err != nil { + return err + } + + *p = uint(n) + return nil +} + +type scanPlanTextAnyToInt64Scanner struct{} + +func (scanPlanTextAnyToInt64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(Int8{}) + } + + n, err := strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + + err = s.ScanInt64(Int8{Int64: n, Valid: true}) + if err != nil { + return err + } + + return nil +} diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb new file mode 100644 index 000000000..c2d40f60b --- /dev/null +++ b/pgtype/int.go.erb @@ -0,0 +1,551 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Int64Scanner interface { + ScanInt64(Int8) error +} + +type Int64Valuer interface { + Int64Value() (Int8, error) +} + + +<% [2, 4, 8].each do |pg_byte_size| %> +<% pg_bit_size = pg_byte_size * 8 %> +type Int<%= pg_byte_size %> struct { + Int<%= pg_bit_size %> int<%= pg_bit_size %> + Valid bool +} + +// ScanInt64 implements the [Int64Scanner] interface. +func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error { + if !n.Valid { + *dst = Int<%= pg_byte_size %>{} + return nil + } + + if n.Int64 < math.MinInt<%= pg_bit_size %> { + return fmt.Errorf("%d is less than minimum value for Int<%= pg_byte_size %>", n.Int64) + } + if n.Int64 > math.MaxInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int64) + } + *dst = Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: int<%= pg_bit_size %>(n.Int64), Valid: true} + + return nil +} + +// Int64Value implements the [Int64Valuer] interface. +func (n Int<%= pg_byte_size %>) Int64Value() (Int8, error) { + return Int8{Int64: int64(n.Int<%= pg_bit_size %>), Valid: n.Valid}, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *Int<%= pg_byte_size %>) Scan(src any) error { + if src == nil { + *dst = Int<%= pg_byte_size %>{} + return nil + } + + var n int64 + + switch src := src.(type) { + case int64: + n = src + case string: + var err error + n, err = strconv.ParseInt(src, 10, <%= pg_bit_size %>) + if err != nil { + return err + } + case []byte: + var err error + n, err = strconv.ParseInt(string(src), 10, <%= pg_bit_size %>) + if err != nil { + return err + } + default: + return fmt.Errorf("cannot scan %T", src) + } + + if n < math.MinInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + } + if n > math.MaxInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + } + *dst = Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: int<%= pg_bit_size %>(n), Valid: true} + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Int<%= pg_bit_size %>), nil +} + +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(int64(src.Int<%= pg_bit_size %>), 10)), nil +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (dst *Int<%= pg_byte_size %>) UnmarshalJSON(b []byte) error { + var n *int<%= pg_bit_size %> + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int<%= pg_byte_size %>{} + } else { + *dst = Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: *n, Valid: true} + } + + return nil +} + +type Int<%= pg_byte_size %>Codec struct{} + +func (Int<%= pg_byte_size %>Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Int<%= pg_byte_size %>Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Int<%= pg_byte_size %>Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case int<%= pg_bit_size %>: + return encodePlanInt<%= pg_byte_size %>CodecBinaryInt<%= pg_bit_size %>{} + case Int64Valuer: + return encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case int<%= pg_bit_size %>: + return encodePlanInt<%= pg_byte_size %>CodecTextInt<%= pg_bit_size %>{} + case Int64Valuer: + return encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanInt<%= pg_byte_size %>CodecBinaryInt<%= pg_bit_size %> struct{} + +func (encodePlanInt<%= pg_byte_size %>CodecBinaryInt<%= pg_bit_size %>) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(int<%= pg_bit_size %>) + return pgio.AppendInt<%= pg_bit_size %>(buf, int<%= pg_bit_size %>(n)), nil +} + +type encodePlanInt<%= pg_byte_size %>CodecTextInt<%= pg_bit_size %> struct{} + +func (encodePlanInt<%= pg_byte_size %>CodecTextInt<%= pg_bit_size %>) Encode(value any, buf []byte) (newBuf []byte, err error) { + n := value.(int<%= pg_bit_size %>) + return append(buf, strconv.FormatInt(int64(n), 10)...), nil +} + +type encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer struct{} + +func (encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int64 > math.MaxInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n.Int64) + } + if n.Int64 < math.MinInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n.Int64) + } + + return pgio.AppendInt<%= pg_bit_size %>(buf, int<%= pg_bit_size %>(n.Int64)), nil +} + +type encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer struct{} + +func (encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int64 > math.MaxInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n.Int64) + } + if n.Int64 < math.MinInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n.Int64) + } + + return append(buf, strconv.FormatInt(n.Int64, 10)...), nil +} + +func (Int<%= pg_byte_size %>Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *int8: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt8{} + case *int16: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt16{} + case *int32: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt32{} + case *int64: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt64{} + case *int: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt{} + case *uint8: + return scanPlanBinaryInt<%= pg_byte_size %>ToUint8{} + case *uint16: + return scanPlanBinaryInt<%= pg_byte_size %>ToUint16{} + case *uint32: + return scanPlanBinaryInt<%= pg_byte_size %>ToUint32{} + case *uint64: + return scanPlanBinaryInt<%= pg_byte_size %>ToUint64{} + case *uint: + return scanPlanBinaryInt<%= pg_byte_size %>ToUint{} + case Int64Scanner: + return scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner{} + case TextScanner: + return scanPlanBinaryInt<%= pg_byte_size %>ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *int8: + return scanPlanTextAnyToInt8{} + case *int16: + return scanPlanTextAnyToInt16{} + case *int32: + return scanPlanTextAnyToInt32{} + case *int64: + return scanPlanTextAnyToInt64{} + case *int: + return scanPlanTextAnyToInt{} + case *uint8: + return scanPlanTextAnyToUint8{} + case *uint16: + return scanPlanTextAnyToUint16{} + case *uint32: + return scanPlanTextAnyToUint32{} + case *uint64: + return scanPlanTextAnyToUint64{} + case *uint: + return scanPlanTextAnyToUint{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +func (c Int<%= pg_byte_size %>Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n int64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +func (c Int<%= pg_byte_size %>Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n int<%= pg_bit_size %> + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +<%# PostgreSQL binary format integer to fixed size Go integers %> +<% [8, 16, 32, 64].each do |dst_bit_size| %> +type scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %> struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToInt<%= dst_bit_size %>) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for int<%= pg_byte_size %>: %v", len(src)) + } + + p, ok := (dst).(*int<%= dst_bit_size %>) + if !ok { + return ErrScanTargetTypeChanged + } + + <% if dst_bit_size < pg_bit_size %> + n := int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + if n < math.MinInt<%= dst_bit_size %> { + return fmt.Errorf("%d is less than minimum value for int<%= dst_bit_size %>", n) + } else if n > math.MaxInt<%= dst_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for int<%= dst_bit_size %>", n) + } + + *p = int<%= dst_bit_size %>(n) + <% elsif dst_bit_size == pg_bit_size %> + *p = int<%= dst_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + <% else %> + *p = int<%= dst_bit_size %>(int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src))) + <% end %> + + return nil +} + +type scanPlanBinaryInt<%= pg_byte_size %>ToUint<%= dst_bit_size %> struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToUint<%= dst_bit_size %>) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for uint<%= pg_byte_size %>: %v", len(src)) + } + + p, ok := (dst).(*uint<%= dst_bit_size %>) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint<%= dst_bit_size %>", n) + } + <% if dst_bit_size < pg_bit_size %> + if n > math.MaxUint<%= dst_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for uint<%= dst_bit_size %>", n) + } + <% end %> + *p = uint<%= dst_bit_size %>(n) + + return nil +} +<% end %> + +<%# PostgreSQL binary format integer to Go machine integers %> +type scanPlanBinaryInt<%= pg_byte_size %>ToInt struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToInt) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for int<%= pg_byte_size %>: %v", len(src)) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + <% if 32 < pg_bit_size %> + n := int64(binary.BigEndian.Uint<%= pg_bit_size %>(src)) + if n < math.MinInt { + return fmt.Errorf("%d is less than minimum value for int", n) + } else if n > math.MaxInt { + return fmt.Errorf("%d is greater than maximum value for int", n) + } + + *p = int(n) + <% else %> + *p = int(int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src))) + <% end %> + + return nil +} + +type scanPlanBinaryInt<%= pg_byte_size %>ToUint struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToUint) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for uint<%= pg_byte_size %>: %v", len(src)) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n := int64(int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src))) + if n < 0 { + return fmt.Errorf("%d is less than minimum value for uint", n) + } + <% if 32 < pg_bit_size %> + if uint64(n) > math.MaxUint { + return fmt.Errorf("%d is greater than maximum value for uint", n) + } + <% end %> + *p = uint(n) + + return nil +} + +<%# PostgreSQL binary format integer to Go Int64Scanner %> +type scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(Int8{}) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for int<%= pg_byte_size %>: %v", len(src)) + } + + + n := int64(int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src))) + + return s.ScanInt64(Int8{Int64: n, Valid: true}) +} + +<%# PostgreSQL binary format integer to Go TextScanner %> +type scanPlanBinaryInt<%= pg_byte_size %>ToTextScanner struct{} + +func (scanPlanBinaryInt<%= pg_byte_size %>ToTextScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != <%= pg_byte_size %> { + return fmt.Errorf("invalid length for int<%= pg_byte_size %>: %v", len(src)) + } + + + n := int64(int<%= pg_bit_size %>(binary.BigEndian.Uint<%= pg_bit_size %>(src))) + + return s.ScanText(Text{String: strconv.FormatInt(n, 10), Valid: true}) +} +<% end %> + +<%# Any text to all integer types %> +<% [ + ["8", 8], + ["16", 16], + ["32", 32], + ["64", 64], + ["", 0] +].each do |type_suffix, bit_size| %> +type scanPlanTextAnyToInt<%= type_suffix %> struct{} + +func (scanPlanTextAnyToInt<%= type_suffix %>) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*int<%= type_suffix %>) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, <%= bit_size %>) + if err != nil { + return err + } + + *p = int<%= type_suffix %>(n) + return nil +} + +type scanPlanTextAnyToUint<%= type_suffix %> struct{} + +func (scanPlanTextAnyToUint<%= type_suffix %>) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p, ok := (dst).(*uint<%= type_suffix %>) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, <%= bit_size %>) + if err != nil { + return err + } + + *p = uint<%= type_suffix %>(n) + return nil +} +<% end %> + +type scanPlanTextAnyToInt64Scanner struct{} + +func (scanPlanTextAnyToInt64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Int64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanInt64(Int8{}) + } + + n, err := strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + + err = s.ScanInt64(Int8{Int64: n, Valid: true}) + if err != nil { + return err + } + + return nil +} diff --git a/pgtype/int2.go b/pgtype/int2.go deleted file mode 100644 index 6156ea772..000000000 --- a/pgtype/int2.go +++ /dev/null @@ -1,209 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "math" - "strconv" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type Int2 struct { - Int int16 - Status Status -} - -func (dst *Int2) Set(src interface{}) error { - if src == nil { - *dst = Int2{Status: Null} - return nil - } - - switch value := src.(type) { - case int8: - *dst = Int2{Int: int16(value), Status: Present} - case uint8: - *dst = Int2{Int: int16(value), Status: Present} - case int16: - *dst = Int2{Int: int16(value), Status: Present} - case uint16: - if value > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Status: Present} - case int32: - if value < math.MinInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) - } - if value > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Status: Present} - case uint32: - if value > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Status: Present} - case int64: - if value < math.MinInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) - } - if value > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Status: Present} - case uint64: - if value > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Status: Present} - case int: - if value < math.MinInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) - } - if value > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Status: Present} - case uint: - if value > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Status: Present} - case string: - num, err := strconv.ParseInt(value, 10, 16) - if err != nil { - return err - } - *dst = Int2{Int: int16(num), Status: Present} - default: - if originalSrc, ok := underlyingNumberType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Int2", value) - } - - return nil -} - -func (dst *Int2) Get() interface{} { - switch dst.Status { - case Present: - return dst.Int - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Int2) AssignTo(dst interface{}) error { - return int64AssignTo(int64(src.Int), src.Status, dst) -} - -func (dst *Int2) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int2{Status: Null} - return nil - } - - n, err := strconv.ParseInt(string(src), 10, 16) - if err != nil { - return err - } - - *dst = Int2{Int: int16(n), Status: Present} - return nil -} - -func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int2{Status: Null} - return nil - } - - if len(src) != 2 { - return errors.Errorf("invalid length for int2: %v", len(src)) - } - - n := int16(binary.BigEndian.Uint16(src)) - *dst = Int2{Int: n, Status: Present} - return nil -} - -func (src *Int2) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil -} - -func (src *Int2) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - return pgio.AppendInt16(buf, src.Int), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int2) Scan(src interface{}) error { - if src == nil { - *dst = Int2{Status: Null} - return nil - } - - switch src := src.(type) { - case int64: - if src < math.MinInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", src) - } - if src > math.MaxInt16 { - return errors.Errorf("%d is greater than maximum value for Int2", src) - } - *dst = Int2{Int: int16(src), Status: Present} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Int2) Value() (driver.Value, error) { - switch src.Status { - case Present: - return int64(src.Int), nil - case Null: - return nil, nil - default: - return nil, errUndefined - } -} - -func (src *Int2) MarshalJSON() ([]byte, error) { - switch src.Status { - case Present: - return []byte(strconv.FormatInt(int64(src.Int), 10)), nil - case Null: - return []byte("null"), nil - case Undefined: - return nil, errUndefined - } - - return nil, errBadStatus -} diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go deleted file mode 100644 index f50d92755..000000000 --- a/pgtype/int2_array.go +++ /dev/null @@ -1,328 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type Int2Array struct { - Elements []Int2 - Dimensions []ArrayDimension - Status Status -} - -func (dst *Int2Array) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Int2Array{Status: Null} - return nil - } - - switch value := src.(type) { - - case []int16: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint16: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Int2Array", value) - } - - return nil -} - -func (dst *Int2Array) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Int2Array) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int2Array{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Int2 - - if len(uta.Elements) > 0 { - elements = make([]Int2, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Int2 - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int2Array{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = Int2Array{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Int2, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("int2"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "int2") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int2Array) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Int2Array) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/int2_array_test.go b/pgtype/int2_array_test.go deleted file mode 100644 index 0fe763c18..000000000 --- a/pgtype/int2_array_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestInt2ArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int2[]", []interface{}{ - &pgtype.Int2Array{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Int2Array{Status: pgtype.Null}, - &pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: 6, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Int2Array{ - Elements: []pgtype.Int2{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestInt2ArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int2Array - }{ - { - source: []int16{1}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []uint16{1}, - result: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]int16)(nil)), - result: pgtype.Int2Array{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Int2Array - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt2ArrayAssignTo(t *testing.T) { - var int16Slice []int16 - var uint16Slice []uint16 - var namedInt16Slice _int16Slice - - simpleTests := []struct { - src pgtype.Int2Array - dst interface{} - expected interface{} - }{ - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &int16Slice, - expected: []int16{1}, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &uint16Slice, - expected: []uint16{1}, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedInt16Slice, - expected: _int16Slice{1}, - }, - { - src: pgtype.Int2Array{Status: pgtype.Null}, - dst: &int16Slice, - expected: (([]int16)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int2Array - dst interface{} - }{ - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &int16Slice, - }, - { - src: pgtype.Int2Array{ - Elements: []pgtype.Int2{{Int: -1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &uint16Slice, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/int2_test.go b/pgtype/int2_test.go deleted file mode 100644 index d20bf0ed2..000000000 --- a/pgtype/int2_test.go +++ /dev/null @@ -1,142 +0,0 @@ -package pgtype_test - -import ( - "math" - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestInt2Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ - &pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}, - &pgtype.Int2{Int: -1, Status: pgtype.Present}, - &pgtype.Int2{Int: 0, Status: pgtype.Present}, - &pgtype.Int2{Int: 1, Status: pgtype.Present}, - &pgtype.Int2{Int: math.MaxInt16, Status: pgtype.Present}, - &pgtype.Int2{Int: 0, Status: pgtype.Null}, - }) -} - -func TestInt2Set(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int2 - }{ - {source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.Int2 - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt2AssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - - simpleTests := []struct { - src pgtype.Int2 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Int2 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, - {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int2 - dst interface{} - }{ - {src: pgtype.Int2{Int: 150, Status: pgtype.Present}, dst: &i8}, - {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &i16}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/int4.go b/pgtype/int4.go deleted file mode 100644 index 261c51189..000000000 --- a/pgtype/int4.go +++ /dev/null @@ -1,213 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "encoding/json" - "math" - "strconv" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type Int4 struct { - Int int32 - Status Status -} - -func (dst *Int4) Set(src interface{}) error { - if src == nil { - *dst = Int4{Status: Null} - return nil - } - - switch value := src.(type) { - case int8: - *dst = Int4{Int: int32(value), Status: Present} - case uint8: - *dst = Int4{Int: int32(value), Status: Present} - case int16: - *dst = Int4{Int: int32(value), Status: Present} - case uint16: - *dst = Int4{Int: int32(value), Status: Present} - case int32: - *dst = Int4{Int: int32(value), Status: Present} - case uint32: - if value > math.MaxInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", value) - } - *dst = Int4{Int: int32(value), Status: Present} - case int64: - if value < math.MinInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", value) - } - if value > math.MaxInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", value) - } - *dst = Int4{Int: int32(value), Status: Present} - case uint64: - if value > math.MaxInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", value) - } - *dst = Int4{Int: int32(value), Status: Present} - case int: - if value < math.MinInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", value) - } - if value > math.MaxInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", value) - } - *dst = Int4{Int: int32(value), Status: Present} - case uint: - if value > math.MaxInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", value) - } - *dst = Int4{Int: int32(value), Status: Present} - case string: - num, err := strconv.ParseInt(value, 10, 32) - if err != nil { - return err - } - *dst = Int4{Int: int32(num), Status: Present} - default: - if originalSrc, ok := underlyingNumberType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Int4", value) - } - - return nil -} - -func (dst *Int4) Get() interface{} { - switch dst.Status { - case Present: - return dst.Int - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Int4) AssignTo(dst interface{}) error { - return int64AssignTo(int64(src.Int), src.Status, dst) -} - -func (dst *Int4) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int4{Status: Null} - return nil - } - - n, err := strconv.ParseInt(string(src), 10, 32) - if err != nil { - return err - } - - *dst = Int4{Int: int32(n), Status: Present} - return nil -} - -func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int4{Status: Null} - return nil - } - - if len(src) != 4 { - return errors.Errorf("invalid length for int4: %v", len(src)) - } - - n := int32(binary.BigEndian.Uint32(src)) - *dst = Int4{Int: n, Status: Present} - return nil -} - -func (src *Int4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil -} - -func (src *Int4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - return pgio.AppendInt32(buf, src.Int), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int4) Scan(src interface{}) error { - if src == nil { - *dst = Int4{Status: Null} - return nil - } - - switch src := src.(type) { - case int64: - if src < math.MinInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", src) - } - if src > math.MaxInt32 { - return errors.Errorf("%d is greater than maximum value for Int4", src) - } - *dst = Int4{Int: int32(src), Status: Present} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Int4) Value() (driver.Value, error) { - switch src.Status { - case Present: - return int64(src.Int), nil - case Null: - return nil, nil - default: - return nil, errUndefined - } -} - -func (src *Int4) MarshalJSON() ([]byte, error) { - switch src.Status { - case Present: - return []byte(strconv.FormatInt(int64(src.Int), 10)), nil - case Null: - return []byte("null"), nil - case Undefined: - return nil, errUndefined - } - - return nil, errBadStatus -} - -func (dst *Int4) UnmarshalJSON(b []byte) error { - var n int32 - err := json.Unmarshal(b, &n) - if err != nil { - return err - } - - *dst = Int4{Int: n, Status: Present} - - return nil -} diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go deleted file mode 100644 index 6c9418ba6..000000000 --- a/pgtype/int4_array.go +++ /dev/null @@ -1,328 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type Int4Array struct { - Elements []Int4 - Dimensions []ArrayDimension - Status Status -} - -func (dst *Int4Array) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Int4Array{Status: Null} - return nil - } - - switch value := src.(type) { - - case []int32: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint32: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Int4Array", value) - } - - return nil -} - -func (dst *Int4Array) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Int4Array) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int4Array{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Int4 - - if len(uta.Elements) > 0 { - elements = make([]Int4, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Int4 - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = Int4Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int4Array{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = Int4Array{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Int4, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = Int4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("int4"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "int4") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int4Array) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Int4Array) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/int4_array_test.go b/pgtype/int4_array_test.go deleted file mode 100644 index 602a36574..000000000 --- a/pgtype/int4_array_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestInt4ArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int4[]", []interface{}{ - &pgtype.Int4Array{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Int4Array{Status: pgtype.Null}, - &pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: 6, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestInt4ArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int4Array - }{ - { - source: []int32{1}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []uint32{1}, - result: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]int32)(nil)), - result: pgtype.Int4Array{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Int4Array - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt4ArrayAssignTo(t *testing.T) { - var int32Slice []int32 - var uint32Slice []uint32 - var namedInt32Slice _int32Slice - - simpleTests := []struct { - src pgtype.Int4Array - dst interface{} - expected interface{} - }{ - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &int32Slice, - expected: []int32{1}, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &uint32Slice, - expected: []uint32{1}, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedInt32Slice, - expected: _int32Slice{1}, - }, - { - src: pgtype.Int4Array{Status: pgtype.Null}, - dst: &int32Slice, - expected: (([]int32)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int4Array - dst interface{} - }{ - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &int32Slice, - }, - { - src: pgtype.Int4Array{ - Elements: []pgtype.Int4{{Int: -1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &uint32Slice, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/int4_test.go b/pgtype/int4_test.go deleted file mode 100644 index 02f5409fd..000000000 --- a/pgtype/int4_test.go +++ /dev/null @@ -1,143 +0,0 @@ -package pgtype_test - -import ( - "math" - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestInt4Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ - &pgtype.Int4{Int: math.MinInt32, Status: pgtype.Present}, - &pgtype.Int4{Int: -1, Status: pgtype.Present}, - &pgtype.Int4{Int: 0, Status: pgtype.Present}, - &pgtype.Int4{Int: 1, Status: pgtype.Present}, - &pgtype.Int4{Int: math.MaxInt32, Status: pgtype.Present}, - &pgtype.Int4{Int: 0, Status: pgtype.Null}, - }) -} - -func TestInt4Set(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int4 - }{ - {source: int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.Int4 - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt4AssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - - simpleTests := []struct { - src pgtype.Int4 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Int4 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, - {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int4 - dst interface{} - }{ - {src: pgtype.Int4{Int: 150, Status: pgtype.Present}, dst: &i8}, - {src: pgtype.Int4{Int: 40000, Status: pgtype.Present}, dst: &i16}, - {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &i32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/int4range.go b/pgtype/int4range.go deleted file mode 100644 index 95ad15218..000000000 --- a/pgtype/int4range.go +++ /dev/null @@ -1,250 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type Int4range struct { - Lower Int4 - Upper Int4 - LowerType BoundType - UpperType BoundType - Status Status -} - -func (dst *Int4range) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Int4range", src) -} - -func (dst *Int4range) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Int4range) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Int4range) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int4range{Status: Null} - return nil - } - - utr, err := ParseUntypedTextRange(string(src)) - if err != nil { - return err - } - - *dst = Int4range{Status: Present} - - dst.LowerType = utr.LowerType - dst.UpperType = utr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { - return err - } - } - - return nil -} - -func (dst *Int4range) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int4range{Status: Null} - return nil - } - - ubr, err := ParseUntypedBinaryRange(src) - if err != nil { - return err - } - - *dst = Int4range{Status: Present} - - dst.LowerType = ubr.LowerType - dst.UpperType = ubr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { - return err - } - } - - return nil -} - -func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - switch src.LowerType { - case Exclusive, Unbounded: - buf = append(buf, '(') - case Inclusive: - buf = append(buf, '[') - case Empty: - return append(buf, "empty"...), nil - default: - return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) - } - - var err error - - if src.LowerType != Unbounded { - buf, err = src.Lower.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - } - - buf = append(buf, ',') - - if src.UpperType != Unbounded { - buf, err = src.Upper.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - } - - switch src.UpperType { - case Exclusive, Unbounded: - buf = append(buf, ')') - case Inclusive: - buf = append(buf, ']') - default: - return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) - } - - return buf, nil -} - -func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - var rangeType byte - switch src.LowerType { - case Inclusive: - rangeType |= lowerInclusiveMask - case Unbounded: - rangeType |= lowerUnboundedMask - case Exclusive: - case Empty: - return append(buf, emptyMask), nil - default: - return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) - } - - switch src.UpperType { - case Inclusive: - rangeType |= upperInclusiveMask - case Unbounded: - rangeType |= upperUnboundedMask - case Exclusive: - default: - return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) - } - - buf = append(buf, rangeType) - - var err error - - if src.LowerType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Lower.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - if src.UpperType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Upper.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int4range) Scan(src interface{}) error { - if src == nil { - *dst = Int4range{Status: Null} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Int4range) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/int4range_test.go b/pgtype/int4range_test.go deleted file mode 100644 index 961678bbb..000000000 --- a/pgtype/int4range_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestInt4rangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int4range", []interface{}{ - &pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int4{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, Status: pgtype.Present}, - &pgtype.Int4range{Upper: pgtype.Int4{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int4range{Status: pgtype.Null}, - }) -} - -func TestInt4rangeNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select int4range(1, 10, '(]')", - Value: pgtype.Int4range{Lower: pgtype.Int4{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - }, - }) -} diff --git a/pgtype/int8.go b/pgtype/int8.go deleted file mode 100644 index 00a8cd006..000000000 --- a/pgtype/int8.go +++ /dev/null @@ -1,199 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "encoding/json" - "math" - "strconv" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type Int8 struct { - Int int64 - Status Status -} - -func (dst *Int8) Set(src interface{}) error { - if src == nil { - *dst = Int8{Status: Null} - return nil - } - - switch value := src.(type) { - case int8: - *dst = Int8{Int: int64(value), Status: Present} - case uint8: - *dst = Int8{Int: int64(value), Status: Present} - case int16: - *dst = Int8{Int: int64(value), Status: Present} - case uint16: - *dst = Int8{Int: int64(value), Status: Present} - case int32: - *dst = Int8{Int: int64(value), Status: Present} - case uint32: - *dst = Int8{Int: int64(value), Status: Present} - case int64: - *dst = Int8{Int: int64(value), Status: Present} - case uint64: - if value > math.MaxInt64 { - return errors.Errorf("%d is greater than maximum value for Int8", value) - } - *dst = Int8{Int: int64(value), Status: Present} - case int: - if int64(value) < math.MinInt64 { - return errors.Errorf("%d is greater than maximum value for Int8", value) - } - if int64(value) > math.MaxInt64 { - return errors.Errorf("%d is greater than maximum value for Int8", value) - } - *dst = Int8{Int: int64(value), Status: Present} - case uint: - if uint64(value) > math.MaxInt64 { - return errors.Errorf("%d is greater than maximum value for Int8", value) - } - *dst = Int8{Int: int64(value), Status: Present} - case string: - num, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return err - } - *dst = Int8{Int: num, Status: Present} - default: - if originalSrc, ok := underlyingNumberType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Int8", value) - } - - return nil -} - -func (dst *Int8) Get() interface{} { - switch dst.Status { - case Present: - return dst.Int - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Int8) AssignTo(dst interface{}) error { - return int64AssignTo(int64(src.Int), src.Status, dst) -} - -func (dst *Int8) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int8{Status: Null} - return nil - } - - n, err := strconv.ParseInt(string(src), 10, 64) - if err != nil { - return err - } - - *dst = Int8{Int: n, Status: Present} - return nil -} - -func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int8{Status: Null} - return nil - } - - if len(src) != 8 { - return errors.Errorf("invalid length for int8: %v", len(src)) - } - - n := int64(binary.BigEndian.Uint64(src)) - - *dst = Int8{Int: n, Status: Present} - return nil -} - -func (src *Int8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - return append(buf, strconv.FormatInt(src.Int, 10)...), nil -} - -func (src *Int8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - return pgio.AppendInt64(buf, src.Int), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int8) Scan(src interface{}) error { - if src == nil { - *dst = Int8{Status: Null} - return nil - } - - switch src := src.(type) { - case int64: - *dst = Int8{Int: src, Status: Present} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Int8) Value() (driver.Value, error) { - switch src.Status { - case Present: - return int64(src.Int), nil - case Null: - return nil, nil - default: - return nil, errUndefined - } -} - -func (src *Int8) MarshalJSON() ([]byte, error) { - switch src.Status { - case Present: - return []byte(strconv.FormatInt(src.Int, 10)), nil - case Null: - return []byte("null"), nil - case Undefined: - return nil, errUndefined - } - - return nil, errBadStatus -} - -func (dst *Int8) UnmarshalJSON(b []byte) error { - var n int64 - err := json.Unmarshal(b, &n) - if err != nil { - return err - } - - *dst = Int8{Int: n, Status: Present} - - return nil -} diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go deleted file mode 100644 index bb6ce004b..000000000 --- a/pgtype/int8_array.go +++ /dev/null @@ -1,328 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type Int8Array struct { - Elements []Int8 - Dimensions []ArrayDimension - Status Status -} - -func (dst *Int8Array) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Int8Array{Status: Null} - return nil - } - - switch value := src.(type) { - - case []int64: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint64: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Int8Array", value) - } - - return nil -} - -func (dst *Int8Array) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Int8Array) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int8Array{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Int8 - - if len(uta.Elements) > 0 { - elements = make([]Int8, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Int8 - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = Int8Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int8Array{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = Int8Array{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Int8, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = Int8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("int8"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "int8") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int8Array) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Int8Array) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/int8_array_test.go b/pgtype/int8_array_test.go deleted file mode 100644 index 2ca651734..000000000 --- a/pgtype/int8_array_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestInt8ArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int8[]", []interface{}{ - &pgtype.Int8Array{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Int8Array{Status: pgtype.Null}, - &pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: 6, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - {Int: 4, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestInt8ArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int8Array - }{ - { - source: []int64{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []uint64{1}, - result: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]int64)(nil)), - result: pgtype.Int8Array{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.Int8Array - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt8ArrayAssignTo(t *testing.T) { - var int64Slice []int64 - var uint64Slice []uint64 - var namedInt64Slice _int64Slice - - simpleTests := []struct { - src pgtype.Int8Array - dst interface{} - expected interface{} - }{ - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &int64Slice, - expected: []int64{1}, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &uint64Slice, - expected: []uint64{1}, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedInt64Slice, - expected: _int64Slice{1}, - }, - { - src: pgtype.Int8Array{Status: pgtype.Null}, - dst: &int64Slice, - expected: (([]int64)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int8Array - dst interface{} - }{ - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &int64Slice, - }, - { - src: pgtype.Int8Array{ - Elements: []pgtype.Int8{{Int: -1, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &uint64Slice, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/int8_test.go b/pgtype/int8_test.go deleted file mode 100644 index 0b3bb3eb5..000000000 --- a/pgtype/int8_test.go +++ /dev/null @@ -1,144 +0,0 @@ -package pgtype_test - -import ( - "math" - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestInt8Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ - &pgtype.Int8{Int: math.MinInt64, Status: pgtype.Present}, - &pgtype.Int8{Int: -1, Status: pgtype.Present}, - &pgtype.Int8{Int: 0, Status: pgtype.Present}, - &pgtype.Int8{Int: 1, Status: pgtype.Present}, - &pgtype.Int8{Int: math.MaxInt64, Status: pgtype.Present}, - &pgtype.Int8{Int: 0, Status: pgtype.Null}, - }) -} - -func TestInt8Set(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int8 - }{ - {source: int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.Int8 - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt8AssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - - simpleTests := []struct { - src pgtype.Int8 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Int8 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, - {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int8 - dst interface{} - }{ - {src: pgtype.Int8{Int: 150, Status: pgtype.Present}, dst: &i8}, - {src: pgtype.Int8{Int: 40000, Status: pgtype.Present}, dst: &i16}, - {src: pgtype.Int8{Int: 5000000000, Status: pgtype.Present}, dst: &i32}, - {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &i64}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/int8range.go b/pgtype/int8range.go deleted file mode 100644 index 61d860d35..000000000 --- a/pgtype/int8range.go +++ /dev/null @@ -1,250 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type Int8range struct { - Lower Int8 - Upper Int8 - LowerType BoundType - UpperType BoundType - Status Status -} - -func (dst *Int8range) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Int8range", src) -} - -func (dst *Int8range) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Int8range) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Int8range) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int8range{Status: Null} - return nil - } - - utr, err := ParseUntypedTextRange(string(src)) - if err != nil { - return err - } - - *dst = Int8range{Status: Present} - - dst.LowerType = utr.LowerType - dst.UpperType = utr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { - return err - } - } - - return nil -} - -func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int8range{Status: Null} - return nil - } - - ubr, err := ParseUntypedBinaryRange(src) - if err != nil { - return err - } - - *dst = Int8range{Status: Present} - - dst.LowerType = ubr.LowerType - dst.UpperType = ubr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { - return err - } - } - - return nil -} - -func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - switch src.LowerType { - case Exclusive, Unbounded: - buf = append(buf, '(') - case Inclusive: - buf = append(buf, '[') - case Empty: - return append(buf, "empty"...), nil - default: - return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) - } - - var err error - - if src.LowerType != Unbounded { - buf, err = src.Lower.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - } - - buf = append(buf, ',') - - if src.UpperType != Unbounded { - buf, err = src.Upper.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - } - - switch src.UpperType { - case Exclusive, Unbounded: - buf = append(buf, ')') - case Inclusive: - buf = append(buf, ']') - default: - return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) - } - - return buf, nil -} - -func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - var rangeType byte - switch src.LowerType { - case Inclusive: - rangeType |= lowerInclusiveMask - case Unbounded: - rangeType |= lowerUnboundedMask - case Exclusive: - case Empty: - return append(buf, emptyMask), nil - default: - return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) - } - - switch src.UpperType { - case Inclusive: - rangeType |= upperInclusiveMask - case Unbounded: - rangeType |= upperUnboundedMask - case Exclusive: - default: - return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) - } - - buf = append(buf, rangeType) - - var err error - - if src.LowerType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Lower.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - if src.UpperType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Upper.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int8range) Scan(src interface{}) error { - if src == nil { - *dst = Int8range{Status: Null} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Int8range) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/int8range_test.go b/pgtype/int8range_test.go deleted file mode 100644 index f33ae4d81..000000000 --- a/pgtype/int8range_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestInt8rangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "Int8range", []interface{}{ - &pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int8{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, Status: pgtype.Present}, - &pgtype.Int8range{Upper: pgtype.Int8{Int: 1, Status: pgtype.Present}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - &pgtype.Int8range{Status: pgtype.Null}, - }) -} - -func TestInt8rangeNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select Int8range(1, 10, '(]')", - Value: pgtype.Int8range{Lower: pgtype.Int8{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - }, - }) -} diff --git a/pgtype/int_test.go b/pgtype/int_test.go new file mode 100644 index 000000000..8c4987691 --- /dev/null +++ b/pgtype/int_test.go @@ -0,0 +1,258 @@ +// Code generated from pgtype/int_test.go.erb. DO NOT EDIT. + +package pgtype_test + +import ( + "context" + "math" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestInt2Codec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int2", []pgxtest.ValueRoundTripTest{ + {int8(1), new(int16), isExpectedEq(int16(1))}, + {int16(1), new(int16), isExpectedEq(int16(1))}, + {int32(1), new(int16), isExpectedEq(int16(1))}, + {int64(1), new(int16), isExpectedEq(int16(1))}, + {uint8(1), new(int16), isExpectedEq(int16(1))}, + {uint16(1), new(int16), isExpectedEq(int16(1))}, + {uint32(1), new(int16), isExpectedEq(int16(1))}, + {uint64(1), new(int16), isExpectedEq(int16(1))}, + {int(1), new(int16), isExpectedEq(int16(1))}, + {uint(1), new(int16), isExpectedEq(int16(1))}, + {pgtype.Int2{Int16: 1, Valid: true}, new(int16), isExpectedEq(int16(1))}, + {int32(-1), new(pgtype.Int2), isExpectedEq(pgtype.Int2{Int16: -1, Valid: true})}, + {1, new(int8), isExpectedEq(int8(1))}, + {1, new(int16), isExpectedEq(int16(1))}, + {1, new(int32), isExpectedEq(int32(1))}, + {1, new(int64), isExpectedEq(int64(1))}, + {1, new(uint8), isExpectedEq(uint8(1))}, + {1, new(uint16), isExpectedEq(uint16(1))}, + {1, new(uint32), isExpectedEq(uint32(1))}, + {1, new(uint64), isExpectedEq(uint64(1))}, + {1, new(int), isExpectedEq(int(1))}, + {1, new(uint), isExpectedEq(uint(1))}, + {-1, new(int8), isExpectedEq(int8(-1))}, + {-1, new(int16), isExpectedEq(int16(-1))}, + {-1, new(int32), isExpectedEq(int32(-1))}, + {-1, new(int64), isExpectedEq(int64(-1))}, + {-1, new(int), isExpectedEq(int(-1))}, + {math.MinInt16, new(int16), isExpectedEq(int16(math.MinInt16))}, + {-1, new(int16), isExpectedEq(int16(-1))}, + {0, new(int16), isExpectedEq(int16(0))}, + {1, new(int16), isExpectedEq(int16(1))}, + {math.MaxInt16, new(int16), isExpectedEq(int16(math.MaxInt16))}, + {1, new(pgtype.Int2), isExpectedEq(pgtype.Int2{Int16: 1, Valid: true})}, + {"1", new(string), isExpectedEq("1")}, + {pgtype.Int2{}, new(pgtype.Int2), isExpectedEq(pgtype.Int2{})}, + {nil, new(*int16), isExpectedEq((*int16)(nil))}, + }) +} + +func TestInt2MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int2 + result string + }{ + {source: pgtype.Int2{Int16: 0}, result: "null"}, + {source: pgtype.Int2{Int16: 1, Valid: true}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestInt2UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int2 + }{ + {source: "null", result: pgtype.Int2{Int16: 0}}, + {source: "1", result: pgtype.Int2{Int16: 1, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Int2 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt4Codec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4", []pgxtest.ValueRoundTripTest{ + {int8(1), new(int32), isExpectedEq(int32(1))}, + {int16(1), new(int32), isExpectedEq(int32(1))}, + {int32(1), new(int32), isExpectedEq(int32(1))}, + {int64(1), new(int32), isExpectedEq(int32(1))}, + {uint8(1), new(int32), isExpectedEq(int32(1))}, + {uint16(1), new(int32), isExpectedEq(int32(1))}, + {uint32(1), new(int32), isExpectedEq(int32(1))}, + {uint64(1), new(int32), isExpectedEq(int32(1))}, + {int(1), new(int32), isExpectedEq(int32(1))}, + {uint(1), new(int32), isExpectedEq(int32(1))}, + {pgtype.Int4{Int32: 1, Valid: true}, new(int32), isExpectedEq(int32(1))}, + {int32(-1), new(pgtype.Int4), isExpectedEq(pgtype.Int4{Int32: -1, Valid: true})}, + {1, new(int8), isExpectedEq(int8(1))}, + {1, new(int16), isExpectedEq(int16(1))}, + {1, new(int32), isExpectedEq(int32(1))}, + {1, new(int64), isExpectedEq(int64(1))}, + {1, new(uint8), isExpectedEq(uint8(1))}, + {1, new(uint16), isExpectedEq(uint16(1))}, + {1, new(uint32), isExpectedEq(uint32(1))}, + {1, new(uint64), isExpectedEq(uint64(1))}, + {1, new(int), isExpectedEq(int(1))}, + {1, new(uint), isExpectedEq(uint(1))}, + {-1, new(int8), isExpectedEq(int8(-1))}, + {-1, new(int16), isExpectedEq(int16(-1))}, + {-1, new(int32), isExpectedEq(int32(-1))}, + {-1, new(int64), isExpectedEq(int64(-1))}, + {-1, new(int), isExpectedEq(int(-1))}, + {math.MinInt32, new(int32), isExpectedEq(int32(math.MinInt32))}, + {-1, new(int32), isExpectedEq(int32(-1))}, + {0, new(int32), isExpectedEq(int32(0))}, + {1, new(int32), isExpectedEq(int32(1))}, + {math.MaxInt32, new(int32), isExpectedEq(int32(math.MaxInt32))}, + {1, new(pgtype.Int4), isExpectedEq(pgtype.Int4{Int32: 1, Valid: true})}, + {"1", new(string), isExpectedEq("1")}, + {pgtype.Int4{}, new(pgtype.Int4), isExpectedEq(pgtype.Int4{})}, + {nil, new(*int32), isExpectedEq((*int32)(nil))}, + }) +} + +func TestInt4MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int4 + result string + }{ + {source: pgtype.Int4{Int32: 0}, result: "null"}, + {source: pgtype.Int4{Int32: 1, Valid: true}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestInt4UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int4 + }{ + {source: "null", result: pgtype.Int4{Int32: 0}}, + {source: "1", result: pgtype.Int4{Int32: 1, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Int4 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt8Codec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int8", []pgxtest.ValueRoundTripTest{ + {int8(1), new(int64), isExpectedEq(int64(1))}, + {int16(1), new(int64), isExpectedEq(int64(1))}, + {int32(1), new(int64), isExpectedEq(int64(1))}, + {int64(1), new(int64), isExpectedEq(int64(1))}, + {uint8(1), new(int64), isExpectedEq(int64(1))}, + {uint16(1), new(int64), isExpectedEq(int64(1))}, + {uint32(1), new(int64), isExpectedEq(int64(1))}, + {uint64(1), new(int64), isExpectedEq(int64(1))}, + {int(1), new(int64), isExpectedEq(int64(1))}, + {uint(1), new(int64), isExpectedEq(int64(1))}, + {pgtype.Int8{Int64: 1, Valid: true}, new(int64), isExpectedEq(int64(1))}, + {int32(-1), new(pgtype.Int8), isExpectedEq(pgtype.Int8{Int64: -1, Valid: true})}, + {1, new(int8), isExpectedEq(int8(1))}, + {1, new(int16), isExpectedEq(int16(1))}, + {1, new(int32), isExpectedEq(int32(1))}, + {1, new(int64), isExpectedEq(int64(1))}, + {1, new(uint8), isExpectedEq(uint8(1))}, + {1, new(uint16), isExpectedEq(uint16(1))}, + {1, new(uint32), isExpectedEq(uint32(1))}, + {1, new(uint64), isExpectedEq(uint64(1))}, + {1, new(int), isExpectedEq(int(1))}, + {1, new(uint), isExpectedEq(uint(1))}, + {-1, new(int8), isExpectedEq(int8(-1))}, + {-1, new(int16), isExpectedEq(int16(-1))}, + {-1, new(int32), isExpectedEq(int32(-1))}, + {-1, new(int64), isExpectedEq(int64(-1))}, + {-1, new(int), isExpectedEq(int(-1))}, + {math.MinInt64, new(int64), isExpectedEq(int64(math.MinInt64))}, + {-1, new(int64), isExpectedEq(int64(-1))}, + {0, new(int64), isExpectedEq(int64(0))}, + {1, new(int64), isExpectedEq(int64(1))}, + {math.MaxInt64, new(int64), isExpectedEq(int64(math.MaxInt64))}, + {1, new(pgtype.Int8), isExpectedEq(pgtype.Int8{Int64: 1, Valid: true})}, + {"1", new(string), isExpectedEq("1")}, + {pgtype.Int8{}, new(pgtype.Int8), isExpectedEq(pgtype.Int8{})}, + {nil, new(*int64), isExpectedEq((*int64)(nil))}, + }) +} + +func TestInt8MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int8 + result string + }{ + {source: pgtype.Int8{Int64: 0}, result: "null"}, + {source: pgtype.Int8{Int64: 1, Valid: true}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestInt8UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int8 + }{ + {source: "null", result: pgtype.Int8{Int64: 0}}, + {source: "1", result: pgtype.Int8{Int64: 1, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Int8 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/int_test.go.erb b/pgtype/int_test.go.erb new file mode 100644 index 000000000..ac9a3f143 --- /dev/null +++ b/pgtype/int_test.go.erb @@ -0,0 +1,93 @@ +package pgtype_test + +import ( + "math" + "testing" + + "github.com/jackc/pgx/v5/pgtype" +) + +<% [2, 4, 8].each do |pg_byte_size| %> +<% pg_bit_size = pg_byte_size * 8 %> +func TestInt<%= pg_byte_size %>Codec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int<%= pg_byte_size %>", []pgxtest.ValueRoundTripTest{ + {int8(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int16(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int32(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int64(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint8(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint16(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint32(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint64(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 1, Valid: true}, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int32(-1), new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: -1, Valid: true})}, + {1, new(int8), isExpectedEq(int8(1))}, + {1, new(int16), isExpectedEq(int16(1))}, + {1, new(int32), isExpectedEq(int32(1))}, + {1, new(int64), isExpectedEq(int64(1))}, + {1, new(uint8), isExpectedEq(uint8(1))}, + {1, new(uint16), isExpectedEq(uint16(1))}, + {1, new(uint32), isExpectedEq(uint32(1))}, + {1, new(uint64), isExpectedEq(uint64(1))}, + {1, new(int), isExpectedEq(int(1))}, + {1, new(uint), isExpectedEq(uint(1))}, + {-1, new(int8), isExpectedEq(int8(-1))}, + {-1, new(int16), isExpectedEq(int16(-1))}, + {-1, new(int32), isExpectedEq(int32(-1))}, + {-1, new(int64), isExpectedEq(int64(-1))}, + {-1, new(int), isExpectedEq(int(-1))}, + {math.MinInt<%= pg_bit_size %>, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(math.MinInt<%= pg_bit_size %>))}, + {-1, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(-1))}, + {0, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(0))}, + {1, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {math.MaxInt<%= pg_bit_size %>, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(math.MaxInt<%= pg_bit_size %>))}, + {1, new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 1, Valid: true})}, + {"1", new(string), isExpectedEq("1")}, + {pgtype.Int<%= pg_byte_size %>{}, new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{})}, + {nil, new(*int<%= pg_bit_size %>), isExpectedEq((*int<%= pg_bit_size %>)(nil))}, + }) +} + +func TestInt<%= pg_byte_size %>MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int<%= pg_byte_size %> + result string + }{ + {source: pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 0}, result: "null"}, + {source: pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 1, Valid: true}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestInt<%= pg_byte_size %>UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int<%= pg_byte_size %> + }{ + {source: "null", result: pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 0}}, + {source: "1", result: pgtype.Int<%= pg_byte_size %>{Int<%= pg_bit_size %>: 1, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Int<%= pg_byte_size %> + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} +<% end %> diff --git a/pgtype/integration_benchmark_test.go b/pgtype/integration_benchmark_test.go new file mode 100644 index 000000000..88516a9fb --- /dev/null +++ b/pgtype/integration_benchmark_test.go @@ -0,0 +1,1271 @@ +// Code generated from pgtype/integration_benchmark_test.go.erb. DO NOT EDIT. + +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" +) + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0 from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_10_rows_1_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0 from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_10_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select array_agg(n) from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_10(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select array_agg(n) from generate_series(1, 10) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select array_agg(n) from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_100(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select array_agg(n) from generate_series(1, 100) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryTextFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select array_agg(n) from generate_series(1, 1000) n`, + pgx.QueryResultFormats{pgx.TextFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkQueryBinaryFormatDecode_PG_Int4Array_With_Go_Int4Array_1000(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select array_agg(n) from generate_series(1, 1000) n`, + pgx.QueryResultFormats{pgx.BinaryFormatCode}, + ) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} diff --git a/pgtype/integration_benchmark_test.go.erb b/pgtype/integration_benchmark_test.go.erb new file mode 100644 index 000000000..6f4011534 --- /dev/null +++ b/pgtype/integration_benchmark_test.go.erb @@ -0,0 +1,62 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5" +) + +<% + [ + ["int4", ["int16", "int32", "int64", "uint64", "pgtype.Int4"], [[1, 1], [1, 10], [10, 1], [100, 10]]], + ["numeric", ["int64", "float64", "pgtype.Numeric"], [[1, 1], [1, 10], [10, 1], [100, 10]]], + ].each do |pg_type, go_types, rows_columns| +%> +<% go_types.each do |go_type| %> +<% rows_columns.each do |rows, columns| %> +<% [["Text", "pgx.TextFormatCode"], ["Binary", "pgx.BinaryFormatCode"]].each do |format_name, format_code| %> +func BenchmarkQuery<%= format_name %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go_type.gsub(/\W/, "_") %>_<%= rows %>_rows_<%= columns %>_columns(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v [<%= columns %>]<%= go_type %> + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select <% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>n::<%= pg_type %> + <%= col_idx%><% end %> from generate_series(1, <%= rows %>) n`, + pgx.QueryResultFormats{<%= format_code %>}, + ) + _, err := pgx.ForEachRow(rows, []any{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} +<% end %> +<% end %> +<% end %> +<% end %> + +<% [10, 100, 1000].each do |array_size| %> +<% [["Text", "pgx.TextFormatCode"], ["Binary", "pgx.BinaryFormatCode"]].each do |format_name, format_code| %> +func BenchmarkQuery<%= format_name %>FormatDecode_PG_Int4Array_With_Go_Int4Array_<%= array_size %>(b *testing.B) { + defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + b.ResetTimer() + var v []int32 + for i := 0; i < b.N; i++ { + rows, _ := conn.Query( + ctx, + `select array_agg(n) from generate_series(1, <%= array_size %>) n`, + pgx.QueryResultFormats{<%= format_code %>}, + ) + _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) + if err != nil { + b.Fatal(err) + } + } + }) +} +<% end %> +<% end %> diff --git a/pgtype/integration_benchmark_test_gen.sh b/pgtype/integration_benchmark_test_gen.sh new file mode 100755 index 000000000..22ac01aaf --- /dev/null +++ b/pgtype/integration_benchmark_test_gen.sh @@ -0,0 +1,2 @@ +erb integration_benchmark_test.go.erb > integration_benchmark_test.go +goimports -w integration_benchmark_test.go diff --git a/pgtype/interval.go b/pgtype/interval.go index 799ce53a4..ba5e818f0 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -6,81 +6,202 @@ import ( "fmt" "strconv" "strings" - "time" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) const ( microsecondsPerSecond = 1000000 microsecondsPerMinute = 60 * microsecondsPerSecond microsecondsPerHour = 60 * microsecondsPerMinute + microsecondsPerDay = 24 * microsecondsPerHour + microsecondsPerMonth = 30 * microsecondsPerDay ) +type IntervalScanner interface { + ScanInterval(v Interval) error +} + +type IntervalValuer interface { + IntervalValue() (Interval, error) +} + type Interval struct { Microseconds int64 Days int32 Months int32 - Status Status + Valid bool +} + +// ScanInterval implements the [IntervalScanner] interface. +func (interval *Interval) ScanInterval(v Interval) error { + *interval = v + return nil +} + +// IntervalValue implements the [IntervalValuer] interface. +func (interval Interval) IntervalValue() (Interval, error) { + return interval, nil } -func (dst *Interval) Set(src interface{}) error { +// Scan implements the [database/sql.Scanner] interface. +func (interval *Interval) Scan(src any) error { if src == nil { - *dst = Interval{Status: Null} + *interval = Interval{} return nil } - switch value := src.(type) { - case time.Duration: - *dst = Interval{Microseconds: int64(value) / 1000, Status: Present} - default: - if originalSrc, ok := underlyingPtrType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Interval", value) + switch src := src.(type) { + case string: + return scanPlanTextAnyToIntervalScanner{}.Scan([]byte(src), interval) } - return nil + return fmt.Errorf("cannot scan %T", src) } -func (dst *Interval) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: +// Value implements the [database/sql/driver.Valuer] interface. +func (interval Interval) Value() (driver.Value, error) { + if !interval.Valid { + return nil, nil + } + + buf, err := IntervalCodec{}.PlanEncode(nil, 0, TextFormatCode, interval).Encode(interval, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type IntervalCodec struct{} + +func (IntervalCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (IntervalCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (IntervalCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(IntervalValuer); !ok { return nil - default: - return dst.Status } + + switch format { + case BinaryFormatCode: + return encodePlanIntervalCodecBinary{} + case TextFormatCode: + return encodePlanIntervalCodecText{} + } + + return nil } -func (src *Interval) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *time.Duration: - if src.Days > 0 || src.Months > 0 { - return errors.Errorf("interval with months or days cannot be decoded into %T", dst) - } - *v = time.Duration(src.Microseconds) * time.Microsecond - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } +type encodePlanIntervalCodecBinary struct{} + +func (encodePlanIntervalCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + interval, err := value.(IntervalValuer).IntervalValue() + if err != nil { + return nil, err + } + + if !interval.Valid { + return nil, nil + } + + buf = pgio.AppendInt64(buf, interval.Microseconds) + buf = pgio.AppendInt32(buf, interval.Days) + buf = pgio.AppendInt32(buf, interval.Months) + return buf, nil +} + +type encodePlanIntervalCodecText struct{} + +func (encodePlanIntervalCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + interval, err := value.(IntervalValuer).IntervalValue() + if err != nil { + return nil, err + } + + if !interval.Valid { + return nil, nil + } + + if interval.Months != 0 { + buf = append(buf, strconv.FormatInt(int64(interval.Months), 10)...) + buf = append(buf, " mon "...) + } + + if interval.Days != 0 { + buf = append(buf, strconv.FormatInt(int64(interval.Days), 10)...) + buf = append(buf, " day "...) + } + + absMicroseconds := interval.Microseconds + if absMicroseconds < 0 { + absMicroseconds = -absMicroseconds + buf = append(buf, '-') + } + + hours := absMicroseconds / microsecondsPerHour + minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute + seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond + + timeStr := fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds) + buf = append(buf, timeStr...) + + microseconds := absMicroseconds % microsecondsPerSecond + if microseconds != 0 { + buf = append(buf, fmt.Sprintf(".%06d", microseconds)...) + } + + return buf, nil +} + +func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case IntervalScanner: + return scanPlanBinaryIntervalToIntervalScanner{} + } + case TextFormatCode: + switch target.(type) { + case IntervalScanner: + return scanPlanTextAnyToIntervalScanner{} } - case Null: - return NullAssignTo(dst) } - return errors.Errorf("cannot decode %v into %T", src, dst) + return nil } -func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { +type scanPlanBinaryIntervalToIntervalScanner struct{} + +func (scanPlanBinaryIntervalToIntervalScanner) Scan(src []byte, dst any) error { + scanner := (dst).(IntervalScanner) + if src == nil { - *dst = Interval{Status: Null} - return nil + return scanner.ScanInterval(Interval{}) + } + + if len(src) != 16 { + return fmt.Errorf("Received an invalid size for an interval: %d", len(src)) + } + + microseconds := int64(binary.BigEndian.Uint64(src)) + days := int32(binary.BigEndian.Uint32(src[8:])) + months := int32(binary.BigEndian.Uint32(src[12:])) + + return scanner.ScanInterval(Interval{Microseconds: microseconds, Days: days, Months: months, Valid: true}) +} + +type scanPlanTextAnyToIntervalScanner struct{} + +func (scanPlanTextAnyToIntervalScanner) Scan(src []byte, dst any) error { + scanner := (dst).(IntervalScanner) + + if src == nil { + return scanner.ScanInterval(Interval{}) } var microseconds int64 @@ -92,7 +213,7 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { for i := 0; i < len(parts)-1; i += 2 { scalar, err := strconv.ParseInt(parts[i], 10, 64) if err != nil { - return errors.Errorf("bad interval format") + return fmt.Errorf("bad interval format") } switch parts[i+1] { @@ -108,7 +229,7 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { if len(parts)%2 == 1 { timeParts := strings.SplitN(parts[len(parts)-1], ":", 3) if len(timeParts) != 3 { - return errors.Errorf("bad interval format") + return fmt.Errorf("bad interval format") } var negative bool @@ -119,29 +240,29 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { hours, err := strconv.ParseInt(timeParts[0], 10, 64) if err != nil { - return errors.Errorf("bad interval hour format: %s", timeParts[0]) + return fmt.Errorf("bad interval hour format: %s", timeParts[0]) } minutes, err := strconv.ParseInt(timeParts[1], 10, 64) if err != nil { - return errors.Errorf("bad interval minute format: %s", timeParts[1]) + return fmt.Errorf("bad interval minute format: %s", timeParts[1]) } - secondParts := strings.SplitN(timeParts[2], ".", 2) + sec, secFrac, secFracFound := strings.Cut(timeParts[2], ".") - seconds, err := strconv.ParseInt(secondParts[0], 10, 64) + seconds, err := strconv.ParseInt(sec, 10, 64) if err != nil { - return errors.Errorf("bad interval second format: %s", secondParts[0]) + return fmt.Errorf("bad interval second format: %s", sec) } var uSeconds int64 - if len(secondParts) == 2 { - uSeconds, err = strconv.ParseInt(secondParts[1], 10, 64) + if secFracFound { + uSeconds, err = strconv.ParseInt(secFrac, 10, 64) if err != nil { - return errors.Errorf("bad interval decimal format: %s", secondParts[1]) + return fmt.Errorf("bad interval decimal format: %s", secFrac) } - for i := 0; i < 6-len(secondParts[1]); i++ { + for i := 0; i < 6-len(secFrac); i++ { uSeconds *= 10 } } @@ -156,95 +277,22 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = Interval{Months: months, Days: days, Microseconds: microseconds, Status: Present} - return nil + return scanner.ScanInterval(Interval{Months: months, Days: days, Microseconds: microseconds, Valid: true}) } -func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Interval{Status: Null} - return nil - } - - if len(src) != 16 { - return errors.Errorf("Received an invalid size for a interval: %d", len(src)) - } - - microseconds := int64(binary.BigEndian.Uint64(src)) - days := int32(binary.BigEndian.Uint32(src[8:])) - months := int32(binary.BigEndian.Uint32(src[12:])) - - *dst = Interval{Microseconds: microseconds, Days: days, Months: months, Status: Present} - return nil +func (c IntervalCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (src *Interval) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if src.Months != 0 { - buf = append(buf, strconv.FormatInt(int64(src.Months), 10)...) - buf = append(buf, " mon "...) - } - - if src.Days != 0 { - buf = append(buf, strconv.FormatInt(int64(src.Days), 10)...) - buf = append(buf, " day "...) - } - - absMicroseconds := src.Microseconds - if absMicroseconds < 0 { - absMicroseconds = -absMicroseconds - buf = append(buf, '-') - } - - hours := absMicroseconds / microsecondsPerHour - minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute - seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond - microseconds := absMicroseconds % microsecondsPerSecond - - timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds) - return append(buf, timeStr...), nil -} - -// EncodeBinary encodes src into w. -func (src *Interval) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - buf = pgio.AppendInt64(buf, src.Microseconds) - buf = pgio.AppendInt32(buf, src.Days) - return pgio.AppendInt32(buf, src.Months), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Interval) Scan(src interface{}) error { +func (c IntervalCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { - *dst = Interval{Status: Null} - return nil + return nil, nil } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + var interval Interval + err := codecScan(c, m, oid, format, src, &interval) + if err != nil { + return nil, err } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Interval) Value() (driver.Value, error) { - return EncodeValueText(src) + return interval, nil } diff --git a/pgtype/interval_test.go b/pgtype/interval_test.go index 76ea3240b..c06c3b2df 100644 --- a/pgtype/interval_test.go +++ b/pgtype/interval_test.go @@ -1,63 +1,158 @@ package pgtype_test import ( + "context" "testing" + "time" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" ) -func TestIntervalTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "interval", []interface{}{ - &pgtype.Interval{Microseconds: 1, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, - &pgtype.Interval{Days: 1, Status: pgtype.Present}, - &pgtype.Interval{Months: 1, Status: pgtype.Present}, - &pgtype.Interval{Months: 12, Status: pgtype.Present}, - &pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: -1, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: -1000000, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: -1000001, Status: pgtype.Present}, - &pgtype.Interval{Microseconds: -123202800000000, Status: pgtype.Present}, - &pgtype.Interval{Days: -1, Status: pgtype.Present}, - &pgtype.Interval{Months: -1, Status: pgtype.Present}, - &pgtype.Interval{Months: -12, Status: pgtype.Present}, - &pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Status: pgtype.Present}, - &pgtype.Interval{Status: pgtype.Null}, - }) -} - -func TestIntervalNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ +func TestIntervalCodec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "interval", []pgxtest.ValueRoundTripTest{ + { + pgtype.Interval{Microseconds: 1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1, Valid: true}), + }, + { + pgtype.Interval{Microseconds: 1000000, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1000000, Valid: true}), + }, + { + pgtype.Interval{Microseconds: 1000001, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1000001, Valid: true}), + }, + { + pgtype.Interval{Microseconds: 123202800000000, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 123202800000000, Valid: true}), + }, + { + pgtype.Interval{Days: 1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Days: 1, Valid: true}), + }, + { + pgtype.Interval{Months: 1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 1, Valid: true}), + }, + { + pgtype.Interval{Months: 12, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 12, Valid: true}), + }, + { + pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Valid: true}), + }, + { + pgtype.Interval{Microseconds: -1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: -1, Valid: true}), + }, + { + pgtype.Interval{Microseconds: -1000000, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: -1000000, Valid: true}), + }, + { + pgtype.Interval{Microseconds: -1000001, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: -1000001, Valid: true}), + }, + { + pgtype.Interval{Microseconds: -123202800000000, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: -123202800000000, Valid: true}), + }, + { + pgtype.Interval{Days: -1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Days: -1, Valid: true}), + }, + { + pgtype.Interval{Months: -1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: -1, Valid: true}), + }, + { + pgtype.Interval{Months: -12, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: -12, Valid: true}), + }, + { + pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Valid: true}), + }, { - SQL: "select '1 second'::interval", - Value: &pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, + "1 second", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1000000, Valid: true}), }, { - SQL: "select '1.000001 second'::interval", - Value: &pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, + "1.000001 second", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1000001, Valid: true}), }, { - SQL: "select '34223 hours'::interval", - Value: &pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, + "34223 hours", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 123202800000000, Valid: true}), }, { - SQL: "select '1 day'::interval", - Value: &pgtype.Interval{Days: 1, Status: pgtype.Present}, + "1 day", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Days: 1, Valid: true}), }, { - SQL: "select '1 month'::interval", - Value: &pgtype.Interval{Months: 1, Status: pgtype.Present}, + "1 month", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 1, Valid: true}), }, { - SQL: "select '1 year'::interval", - Value: &pgtype.Interval{Months: 12, Status: pgtype.Present}, + "1 year", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 12, Valid: true}), }, { - SQL: "select '-13 mon'::interval", - Value: &pgtype.Interval{Months: -13, Status: pgtype.Present}, + "-13 mon", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: -13, Valid: true}), }, + {time.Hour, new(time.Duration), isExpectedEq(time.Hour)}, + { + pgtype.Interval{Months: 1, Days: 1, Valid: true}, + new(time.Duration), + isExpectedEq(time.Duration(2678400000000000)), + }, + {pgtype.Interval{}, new(pgtype.Interval), isExpectedEq(pgtype.Interval{})}, + {nil, new(pgtype.Interval), isExpectedEq(pgtype.Interval{})}, }) } + +func TestIntervalTextEncode(t *testing.T) { + m := pgtype.NewMap() + + successfulTests := []struct { + source pgtype.Interval + result string + }{ + {source: pgtype.Interval{Months: 2, Days: 1, Microseconds: 0, Valid: true}, result: "2 mon 1 day 00:00:00"}, + {source: pgtype.Interval{Months: 0, Days: 0, Microseconds: 0, Valid: true}, result: "00:00:00"}, + {source: pgtype.Interval{Months: 0, Days: 0, Microseconds: 6 * 60 * 1000000, Valid: true}, result: "00:06:00"}, + {source: pgtype.Interval{Months: 0, Days: 1, Microseconds: 6*60*1000000 + 30, Valid: true}, result: "1 day 00:06:00.000030"}, + } + for i, tt := range successfulTests { + buf, err := m.Encode(pgtype.DateOID, pgtype.TextFormatCode, tt.source, nil) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.result, string(buf), "%d", i) + } +} diff --git a/pgtype/json.go b/pgtype/json.go index ef8231b18..60aa2b71d 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -1,161 +1,243 @@ package pgtype import ( + "database/sql" "database/sql/driver" "encoding/json" - - "github.com/pkg/errors" + "fmt" + "reflect" ) -type JSON struct { - Bytes []byte - Status Status +type JSONCodec struct { + Marshal func(v any) ([]byte, error) + Unmarshal func(data []byte, v any) error } -func (dst *JSON) Set(src interface{}) error { - if src == nil { - *dst = JSON{Status: Null} - return nil - } +func (*JSONCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} - switch value := src.(type) { +func (*JSONCodec) PreferredFormat() int16 { + return TextFormatCode +} + +func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch value.(type) { case string: - *dst = JSON{Bytes: []byte(value), Status: Present} - case *string: - if value == nil { - *dst = JSON{Status: Null} - } else { - *dst = JSON{Bytes: []byte(*value), Status: Present} - } + return encodePlanJSONCodecEitherFormatString{} case []byte: - if value == nil { - *dst = JSON{Status: Null} - } else { - *dst = JSON{Bytes: value, Status: Present} + return encodePlanJSONCodecEitherFormatByteSlice{} + + // Handle json.RawMessage specifically because if it is run through json.Marshal it may be mutated. + // e.g. `{"foo": "bar"}` -> `{"foo":"bar"}`. + case json.RawMessage: + return encodePlanJSONCodecEitherFormatJSONRawMessage{} + + // Cannot rely on driver.Valuer being handled later because anything can be marshalled. + // + // https://github.com/jackc/pgx/issues/1430 + // + // Check for driver.Valuer must come before json.Marshaler so that it is guaranteed to be used + // when both are implemented https://github.com/jackc/pgx/issues/1805 + case driver.Valuer: + return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format} + + // Must come before trying wrap encode plans because a pointer to a struct may be unwrapped to a struct that can be + // marshalled. + // + // https://github.com/jackc/pgx/issues/1681 + case json.Marshaler: + return &encodePlanJSONCodecEitherFormatMarshal{ + marshal: c.Marshal, } - // Encode* methods are defined on *JSON. If JSON is passed directly then the - // struct itself would be encoded instead of Bytes. This is clearly a footgun - // so detect and return an error. See https://github.com/jackc/pgx/issues/350. - case JSON: - return errors.New("use pointer to pgtype.JSON instead of value") - // Same as above but for JSONB (because they share implementation) - case JSONB: - return errors.New("use pointer to pgtype.JSONB instead of value") - - default: - buf, err := json.Marshal(value) - if err != nil { - return err + } + + // Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the + // appropriate wrappers here. + for _, f := range []TryWrapEncodePlanFunc{ + TryWrapDerefPointerEncodePlan, + TryWrapFindUnderlyingTypeEncodePlan, + } { + if wrapperPlan, nextValue, ok := f(value); ok { + if nextPlan := c.PlanEncode(m, oid, format, nextValue); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } } - *dst = JSON{Bytes: buf, Status: Present} } - return nil + return &encodePlanJSONCodecEitherFormatMarshal{ + marshal: c.Marshal, + } } -func (dst *JSON) Get() interface{} { - switch dst.Status { - case Present: - var i interface{} - err := json.Unmarshal(dst.Bytes, &i) - if err != nil { - return dst - } - return i - case Null: +// JSON needs its on scan plan for pointers to handle 'null'::json(b). +// Consider making pointerPointerScanPlan more flexible in the future. +type jsonPointerScanPlan struct { + next ScanPlan +} + +func (p jsonPointerScanPlan) Scan(src []byte, dst any) error { + el := reflect.ValueOf(dst).Elem() + if src == nil || string(src) == "null" { + el.SetZero() return nil - default: - return dst.Status } + + el.Set(reflect.New(el.Type().Elem())) + if p.next != nil { + return p.next.Scan(src, el.Interface()) + } + + return nil } -func (src *JSON) AssignTo(dst interface{}) error { - switch v := dst.(type) { +type encodePlanJSONCodecEitherFormatString struct{} + +func (encodePlanJSONCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) { + jsonString := value.(string) + buf = append(buf, jsonString...) + return buf, nil +} + +type encodePlanJSONCodecEitherFormatByteSlice struct{} + +func (encodePlanJSONCodecEitherFormatByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) { + jsonBytes := value.([]byte) + if jsonBytes == nil { + return nil, nil + } + + buf = append(buf, jsonBytes...) + return buf, nil +} + +type encodePlanJSONCodecEitherFormatJSONRawMessage struct{} + +func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byte) (newBuf []byte, err error) { + jsonBytes := value.(json.RawMessage) + if jsonBytes == nil { + return nil, nil + } + + buf = append(buf, jsonBytes...) + return buf, nil +} + +type encodePlanJSONCodecEitherFormatMarshal struct { + marshal func(v any) ([]byte, error) +} + +func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { + jsonBytes, err := e.marshal(value) + if err != nil { + return nil, err + } + + buf = append(buf, jsonBytes...) + return buf, nil +} + +func (c *JSONCodec) PlanScan(m *Map, oid uint32, formatCode int16, target any) ScanPlan { + return c.planScan(m, oid, formatCode, target, 0) +} + +// JSON cannot fallback to pointerPointerScanPlan because of 'null'::json(b), +// so we need to duplicate the logic here. +func (c *JSONCodec) planScan(m *Map, oid uint32, formatCode int16, target any, depth int) ScanPlan { + if depth > 8 { + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} + } + + switch target.(type) { case *string: - if src.Status != Present { - v = nil - } else { - *v = string(src.Bytes) - } - case **string: - *v = new(string) - return src.AssignTo(*v) + return &scanPlanAnyToString{} case *[]byte: - if src.Status != Present { - *v = nil - } else { - buf := make([]byte, len(src.Bytes)) - copy(buf, src.Bytes) - *v = buf - } - default: - data := src.Bytes - if data == nil || src.Status != Present { - data = []byte("null") - } + return &scanPlanJSONToByteSlice{} + case BytesScanner: + return &scanPlanBinaryBytesToBytesScanner{} + case sql.Scanner: + return &scanPlanSQLScanner{formatCode: formatCode} + } - return json.Unmarshal(data, dst) + rv := reflect.ValueOf(target) + if rv.Kind() == reflect.Pointer && rv.Elem().Kind() == reflect.Pointer { + var plan jsonPointerScanPlan + plan.next = c.planScan(m, oid, formatCode, rv.Elem().Interface(), depth+1) + return plan + } else { + return &scanPlanJSONToJSONUnmarshal{unmarshal: c.Unmarshal} } +} + +type scanPlanAnyToString struct{} +func (scanPlanAnyToString) Scan(src []byte, dst any) error { + p := dst.(*string) + *p = string(src) return nil } -func (dst *JSON) DecodeText(ci *ConnInfo, src []byte) error { +type scanPlanJSONToByteSlice struct{} + +func (scanPlanJSONToByteSlice) Scan(src []byte, dst any) error { + dstBuf := dst.(*[]byte) if src == nil { - *dst = JSON{Status: Null} + *dstBuf = nil return nil } - *dst = JSON{Bytes: src, Status: Present} + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) return nil } -func (dst *JSON) DecodeBinary(ci *ConnInfo, src []byte) error { - return dst.DecodeText(ci, src) +type scanPlanJSONToJSONUnmarshal struct { + unmarshal func(data []byte, v any) error } -func (src *JSON) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { + if src == nil { + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() == reflect.Ptr { + el := dstValue.Elem() + switch el.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Interface: + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot scan NULL into %T", dst) } - return append(buf, src.Bytes...), nil -} + v := reflect.ValueOf(dst) + if v.Kind() != reflect.Pointer || v.IsNil() { + return fmt.Errorf("cannot scan into non-pointer or nil destinations %T", dst) + } + + elem := v.Elem() + elem.Set(reflect.Zero(elem.Type())) -func (src *JSON) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return src.EncodeText(ci, buf) + return s.unmarshal(src, dst) } -// Scan implements the database/sql Scanner interface. -func (dst *JSON) Scan(src interface{}) error { +func (c *JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { - *dst = JSON{Status: Null} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + return nil, nil } - return errors.Errorf("cannot scan %T", src) + dstBuf := make([]byte, len(src)) + copy(dstBuf, src) + return dstBuf, nil } -// Value implements the database/sql/driver Valuer interface. -func (src *JSON) Value() (driver.Value, error) { - switch src.Status { - case Present: - return string(src.Bytes), nil - case Null: +func (c *JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { return nil, nil - default: - return nil, errUndefined } + + var dst any + err := c.Unmarshal(src, &dst) + return dst, err } diff --git a/pgtype/json_test.go b/pgtype/json_test.go index 82c02539e..e1d654f26 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -1,136 +1,362 @@ package pgtype_test import ( - "bytes" + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" ) -func TestJSONTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "json", []interface{}{ - &pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, - &pgtype.JSON{Bytes: []byte("null"), Status: pgtype.Present}, - &pgtype.JSON{Bytes: []byte("42"), Status: pgtype.Present}, - &pgtype.JSON{Bytes: []byte(`"hello"`), Status: pgtype.Present}, - &pgtype.JSON{Status: pgtype.Null}, - }) -} +func isExpectedEqMap(a any) func(any) bool { + return func(v any) bool { + aa := a.(map[string]any) + bb := v.(map[string]any) -func TestJSONSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.JSON - }{ - {source: "{}", result: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: []byte("{}"), result: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: ([]byte)(nil), result: pgtype.JSON{Status: pgtype.Null}}, - {source: (*string)(nil), result: pgtype.JSON{Status: pgtype.Null}}, - {source: []int{1, 2, 3}, result: pgtype.JSON{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, - {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, - } + if (aa == nil) != (bb == nil) { + return false + } + + if aa == nil { + return true + } - for i, tt := range successfulTests { - var d pgtype.JSON - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) + if len(aa) != len(bb) { + return false } - if !reflect.DeepEqual(d, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + for k := range aa { + if aa[k] != bb[k] { + return false + } } + + return true } } -func TestJSONAssignTo(t *testing.T) { - var s string - var ps *string - var b []byte - - rawStringTests := []struct { - src pgtype.JSON - dst *string - expected string - }{ - {src: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, +func TestJSONCodec(t *testing.T) { + type jsonStruct struct { + Name string `json:"name"` + Age int `json:"age"` } - for i, tt := range rawStringTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } + var str string + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "json", []pgxtest.ValueRoundTripTest{ + {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, + {map[string]any(nil), new(*string), isExpectedEq((*string)(nil))}, + {map[string]any(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, - if *tt.dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } + // Test sql.Scanner. (https://github.com/jackc/pgx/issues/1418) + {"42", new(sql.NullInt64), isExpectedEq(sql.NullInt64{Int64: 42, Valid: true})}, - rawBytesTests := []struct { - src pgtype.JSON - dst *[]byte - expected []byte - }{ - {src: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, - {src: pgtype.JSON{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, - } + // Test driver.Valuer. (https://github.com/jackc/pgx/issues/1430) + {sql.NullInt64{Int64: 42, Valid: true}, new(sql.NullInt64), isExpectedEq(sql.NullInt64{Int64: 42, Valid: true})}, - for i, tt := range rawBytesTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } + // Test driver.Valuer is used before json.Marshaler (https://github.com/jackc/pgx/issues/1805) + {Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))}, + // Test driver.Scanner is used before json.Unmarshaler (https://github.com/jackc/pgx/issues/2146) + {Issue2146(7), new(*Issue2146), isPtrExpectedEq(Issue2146(7))}, - if bytes.Compare(tt.expected, *tt.dst) != 0 { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } + // Test driver.Scanner without pointer receiver (https://github.com/jackc/pgx/issues/2204) + {NonPointerJSONScanner{V: stringPtr("{}")}, NonPointerJSONScanner{V: &str}, func(a any) bool { return str == "{}" }}, + }) - var mapDst map[string]interface{} - type structDst struct { - Name string `json:"name"` - Age int `json:"age"` + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{ + {[]byte("{}"), new([]byte), isExpectedEqBytes([]byte("{}"))}, + {[]byte("null"), new([]byte), isExpectedEqBytes([]byte("null"))}, + {[]byte("42"), new([]byte), isExpectedEqBytes([]byte("42"))}, + {[]byte(`"hello"`), new([]byte), isExpectedEqBytes([]byte(`"hello"`))}, + {[]byte(`"hello"`), new(string), isExpectedEq(`"hello"`)}, + {map[string]any{"foo": "bar"}, new(map[string]any), isExpectedEqMap(map[string]any{"foo": "bar"})}, + {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, + }) +} + +type Issue1805 int + +func (i *Issue1805) Scan(src any) error { + var source []byte + switch src.(type) { + case string: + source = []byte(src.(string)) + case []byte: + source = src.([]byte) + default: + return errors.New("unknown source type") } - var strDst structDst - - unmarshalTests := []struct { - src pgtype.JSON - dst interface{} - expected interface{} - }{ - {src: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, - {src: pgtype.JSON{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + var newI int + if err := json.Unmarshal(source, &newI); err != nil { + return err } - for i, tt := range unmarshalTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } + *i = Issue1805(newI) + return nil +} - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } +func (i Issue1805) Value() (driver.Value, error) { + b, err := json.Marshal(int(i)) + return string(b), err +} + +func (i Issue1805) UnmarshalJSON(bytes []byte) error { + return errors.New("UnmarshalJSON called") +} + +func (i Issue1805) MarshalJSON() ([]byte, error) { + return nil, errors.New("MarshalJSON called") +} + +type Issue2146 int + +func (i *Issue2146) Scan(src any) error { + var source []byte + switch src.(type) { + case string: + source = []byte(src.(string)) + case []byte: + source = src.([]byte) + default: + return errors.New("unknown source type") + } + var newI int + if err := json.Unmarshal(source, &newI); err != nil { + return err } + *i = Issue2146(newI + 1) + return nil +} - pointerAllocTests := []struct { - src pgtype.JSON - dst **string - expected *string - }{ - {src: pgtype.JSON{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, +func (i Issue2146) Value() (driver.Value, error) { + b, err := json.Marshal(int(i - 1)) + return string(b), err +} + +type NonPointerJSONScanner struct { + V *string +} + +func (i NonPointerJSONScanner) Scan(src any) error { + switch c := src.(type) { + case string: + *i.V = c + case []byte: + *i.V = string(c) + default: + return errors.New("unknown source type") } - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) + return nil +} + +func (i NonPointerJSONScanner) Value() (driver.Value, error) { + return i.V, nil +} + +// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648 +func TestJSONCodecUnmarshalSQLNull(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + // Slices are nilified + slice := []string{"foo", "bar", "baz"} + err := conn.QueryRow(ctx, "select null::json").Scan(&slice) + require.NoError(t, err) + require.Nil(t, slice) + + // Maps are nilified + m := map[string]any{"foo": "bar"} + err = conn.QueryRow(ctx, "select null::json").Scan(&m) + require.NoError(t, err) + require.Nil(t, m) + + m = map[string]interface{}{"foo": "bar"} + err = conn.QueryRow(ctx, "select null::json").Scan(&m) + require.NoError(t, err) + require.Nil(t, m) + + // Pointer to pointer are nilified + n := 42 + p := &n + err = conn.QueryRow(ctx, "select null::json").Scan(&p) + require.NoError(t, err) + require.Nil(t, p) + + // A string cannot scan a NULL. + str := "foobar" + err = conn.QueryRow(ctx, "select null::json").Scan(&str) + fieldName := "json" + if conn.PgConn().ParameterStatus("crdb_version") != "" { + fieldName = "jsonb" // Seems like CockroachDB treats json as jsonb. } + require.EqualError(t, err, fmt.Sprintf("can't scan into dest[0] (col: %s): cannot scan NULL into *string", fieldName)) + + // A non-string cannot scan a NULL. + err = conn.QueryRow(ctx, "select null::json").Scan(&n) + require.EqualError(t, err, fmt.Sprintf("can't scan into dest[0] (col: %s): cannot scan NULL into *int", fieldName)) + }) +} + +// https://github.com/jackc/pgx/issues/1470 +func TestJSONCodecPointerToPointerToString(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var s *string + err := conn.QueryRow(ctx, "select '{}'::json").Scan(&s) + require.NoError(t, err) + require.NotNil(t, s) + require.Equal(t, "{}", *s) + + err = conn.QueryRow(ctx, "select null::json").Scan(&s) + require.NoError(t, err) + require.Nil(t, s) + }) +} + +// https://github.com/jackc/pgx/issues/1691 +func TestJSONCodecPointerToPointerToInt(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + n := 44 + p := &n + err := conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&p) + require.NoError(t, err) + require.Nil(t, p) + }) +} - if *tt.dst == tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) +// https://github.com/jackc/pgx/issues/1691 +func TestJSONCodecPointerToPointerToStruct(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + type ImageSize struct { + Height int `json:"height"` + Width int `json:"width"` + Str string `json:"str"` } + is := &ImageSize{Height: 100, Width: 100, Str: "str"} + err := conn.QueryRow(ctx, `select 'null'::jsonb`).Scan(&is) + require.NoError(t, err) + require.Nil(t, is) + }) +} + +func TestJSONCodecClearExistingValueBeforeUnmarshal(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + m := map[string]any{} + err := conn.QueryRow(ctx, `select '{"foo": "bar"}'::json`).Scan(&m) + require.NoError(t, err) + require.Equal(t, map[string]any{"foo": "bar"}, m) + + err = conn.QueryRow(ctx, `select '{"baz": "quz"}'::json`).Scan(&m) + require.NoError(t, err) + require.Equal(t, map[string]any{"baz": "quz"}, m) + }) +} + +type ParentIssue1681 struct { + Child ChildIssue1681 +} + +func (t *ParentIssue1681) MarshalJSON() ([]byte, error) { + return []byte(`{"custom":"thing"}`), nil +} + +type ChildIssue1681 struct{} + +func (t ChildIssue1681) MarshalJSON() ([]byte, error) { + return []byte(`{"someVal": false}`), nil +} + +// https://github.com/jackc/pgx/issues/1681 +func TestJSONCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) { + skipCockroachDB(t, "CockroachDB treats json as jsonb. This causes it to format differently than PostgreSQL.") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var jsonStr string + err := conn.QueryRow(context.Background(), "select $1::json", &ParentIssue1681{}).Scan(&jsonStr) + require.NoError(t, err) + require.Equal(t, `{"custom":"thing"}`, jsonStr) + }) +} + +func TestJSONCodecCustomMarshal(t *testing.T) { + skipCockroachDB(t, "CockroachDB treats json as jsonb. This causes it to format differently than PostgreSQL.") + + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "json", OID: pgtype.JSONOID, Codec: &pgtype.JSONCodec{ + Marshal: func(v any) ([]byte, error) { + return []byte(`{"custom":"value"}`), nil + }, + Unmarshal: func(data []byte, v any) error { + return json.Unmarshal([]byte(`{"custom":"value"}`), v) + }, + }, + }) } + + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{ + // There is no space between "custom" and "value" in json type. + {map[string]any{"something": "else"}, new(string), isExpectedEq(`{"custom":"value"}`)}, + {[]byte(`{"something":"else"}`), new(map[string]any), func(v any) bool { + return reflect.DeepEqual(v, map[string]any{"custom": "value"}) + }}, + }) +} + +func TestJSONCodecScanToNonPointerValues(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + n := 44 + err := conn.QueryRow(ctx, "select '42'::jsonb").Scan(n) + require.Error(t, err) + + var i *int + err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(i) + require.Error(t, err) + + m := 0 + err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(&m) + require.NoError(t, err) + require.Equal(t, 42, m) + }) +} + +func TestJSONCodecScanNull(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var dest struct{} + err := conn.QueryRow(ctx, "select null::jsonb").Scan(&dest) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot scan NULL into *struct {}") + + err = conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&dest) + require.NoError(t, err) + + var destPointer *struct{} + err = conn.QueryRow(ctx, "select null::jsonb").Scan(&destPointer) + require.NoError(t, err) + require.Nil(t, destPointer) + + err = conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&destPointer) + require.NoError(t, err) + require.Nil(t, destPointer) + + var raw json.RawMessage + require.NoError(t, conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&raw)) + require.Equal(t, json.RawMessage("null"), raw) + }) +} + +func TestJSONCodecScanNullToPointerToSQLScanner(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var dest *Issue2146 + err := conn.QueryRow(ctx, "select null::jsonb").Scan(&dest) + require.NoError(t, err) + require.Nil(t, dest) + }) } diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index c315c5881..4d4eb58e5 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -2,69 +2,128 @@ package pgtype import ( "database/sql/driver" - - "github.com/pkg/errors" + "fmt" ) -type JSONB JSON +type JSONBCodec struct { + Marshal func(v any) ([]byte, error) + Unmarshal func(data []byte, v any) error +} + +func (*JSONBCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (*JSONBCodec) PreferredFormat() int16 { + return TextFormatCode +} + +func (c *JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, TextFormatCode, value) + if plan != nil { + return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan} + } + case TextFormatCode: + return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, format, value) + } -func (dst *JSONB) Set(src interface{}) error { - return (*JSON)(dst).Set(src) + return nil } -func (dst *JSONB) Get() interface{} { - return (*JSON)(dst).Get() +type encodePlanJSONBCodecBinaryWrapper struct { + textPlan EncodePlan } -func (src *JSONB) AssignTo(dst interface{}) error { - return (*JSON)(src).AssignTo(dst) +func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value any, buf []byte) (newBuf []byte, err error) { + buf = append(buf, 1) + return plan.textPlan.Encode(value, buf) +} + +func (c *JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, TextFormatCode, target) + if plan != nil { + return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan} + } + case TextFormatCode: + return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, format, target) + } + + return nil } -func (dst *JSONB) DecodeText(ci *ConnInfo, src []byte) error { - return (*JSON)(dst).DecodeText(ci, src) +type scanPlanJSONBCodecBinaryUnwrapper struct { + textPlan ScanPlan } -func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { +func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst any) error { if src == nil { - *dst = JSONB{Status: Null} - return nil + return plan.textPlan.Scan(src, dst) } if len(src) == 0 { - return errors.Errorf("jsonb too short") + return fmt.Errorf("jsonb too short") } if src[0] != 1 { - return errors.Errorf("unknown jsonb version number %d", src[0]) + return fmt.Errorf("unknown jsonb version number %d", src[0]) } - *dst = JSONB{Bytes: src[1:], Status: Present} - return nil - + return plan.textPlan.Scan(src[1:], dst) } -func (src *JSONB) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*JSON)(src).EncodeText(ci, buf) +func (c *JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + switch format { + case BinaryFormatCode: + if len(src) == 0 { + return nil, fmt.Errorf("jsonb too short") + } + + if src[0] != 1 { + return nil, fmt.Errorf("unknown jsonb version number %d", src[0]) + } + + dstBuf := make([]byte, len(src)-1) + copy(dstBuf, src[1:]) + return dstBuf, nil + case TextFormatCode: + dstBuf := make([]byte, len(src)) + copy(dstBuf, src) + return dstBuf, nil + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } } -func (src *JSONB) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: +func (c *JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { return nil, nil - case Undefined: - return nil, errUndefined } - buf = append(buf, 1) - return append(buf, src.Bytes...), nil -} + switch format { + case BinaryFormatCode: + if len(src) == 0 { + return nil, fmt.Errorf("jsonb too short") + } -// Scan implements the database/sql Scanner interface. -func (dst *JSONB) Scan(src interface{}) error { - return (*JSON)(dst).Scan(src) -} + if src[0] != 1 { + return nil, fmt.Errorf("unknown jsonb version number %d", src[0]) + } + + src = src[1:] + case TextFormatCode: + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } -// Value implements the database/sql/driver Valuer interface. -func (src *JSONB) Value() (driver.Value, error) { - return (*JSON)(src).Value() + var dst any + err := c.Unmarshal(src, &dst) + return dst, err } diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index 1a9a30562..0826f111f 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -1,142 +1,109 @@ package pgtype_test import ( - "bytes" + "context" + "encoding/json" "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" ) func TestJSONBTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustClose(t, conn) - if _, ok := conn.ConnInfo.DataTypeForName("jsonb"); !ok { - t.Skip("Skipping due to no jsonb type") + type jsonStruct struct { + Name string `json:"name"` + Age int `json:"age"` } - testutil.TestSuccessfulTranscode(t, "jsonb", []interface{}{ - &pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, - &pgtype.JSONB{Bytes: []byte("null"), Status: pgtype.Present}, - &pgtype.JSONB{Bytes: []byte("42"), Status: pgtype.Present}, - &pgtype.JSONB{Bytes: []byte(`"hello"`), Status: pgtype.Present}, - &pgtype.JSONB{Status: pgtype.Null}, + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "jsonb", []pgxtest.ValueRoundTripTest{ + {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, + {map[string]any(nil), new(*string), isExpectedEq((*string)(nil))}, + {map[string]any(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, }) -} - -func TestJSONBSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.JSONB - }{ - {source: "{}", result: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: []byte("{}"), result: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: ([]byte)(nil), result: pgtype.JSONB{Status: pgtype.Null}}, - {source: (*string)(nil), result: pgtype.JSONB{Status: pgtype.Null}}, - {source: []int{1, 2, 3}, result: pgtype.JSONB{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, - {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var d pgtype.JSONB - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - if !reflect.DeepEqual(d, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } - } + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "jsonb", []pgxtest.ValueRoundTripTest{ + {[]byte("{}"), new([]byte), isExpectedEqBytes([]byte("{}"))}, + {[]byte("null"), new([]byte), isExpectedEqBytes([]byte("null"))}, + {[]byte("42"), new([]byte), isExpectedEqBytes([]byte("42"))}, + {[]byte(`"hello"`), new([]byte), isExpectedEqBytes([]byte(`"hello"`))}, + {[]byte(`"hello"`), new(string), isExpectedEq(`"hello"`)}, + {map[string]any{"foo": "bar"}, new(map[string]any), isExpectedEqMap(map[string]any{"foo": "bar"})}, + {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, + }) } -func TestJSONBAssignTo(t *testing.T) { - var s string - var ps *string - var b []byte - - rawStringTests := []struct { - src pgtype.JSONB - dst *string - expected string - }{ - {src: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, - } - - for i, tt := range rawStringTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if *tt.dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } - - rawBytesTests := []struct { - src pgtype.JSONB - dst *[]byte - expected []byte - }{ - {src: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, - {src: pgtype.JSONB{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, - } - - for i, tt := range rawBytesTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if bytes.Compare(tt.expected, *tt.dst) != 0 { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } - - var mapDst map[string]interface{} - type structDst struct { - Name string `json:"name"` - Age int `json:"age"` - } - var strDst structDst - - unmarshalTests := []struct { - src pgtype.JSONB - dst interface{} - expected interface{} - }{ - {src: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, - {src: pgtype.JSONB{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, - } - for i, tt := range unmarshalTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } +func TestJSONBCodecUnmarshalSQLNull(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + // Slices are nilified + slice := []string{"foo", "bar", "baz"} + err := conn.QueryRow(ctx, "select null::jsonb").Scan(&slice) + require.NoError(t, err) + require.Nil(t, slice) + + // Maps are nilified + m := map[string]any{"foo": "bar"} + err = conn.QueryRow(ctx, "select null::jsonb").Scan(&m) + require.NoError(t, err) + require.Nil(t, m) + + m = map[string]interface{}{"foo": "bar"} + err = conn.QueryRow(ctx, "select null::jsonb").Scan(&m) + require.NoError(t, err) + require.Nil(t, m) + + // Pointer to pointer are nilified + n := 42 + p := &n + err = conn.QueryRow(ctx, "select null::jsonb").Scan(&p) + require.NoError(t, err) + require.Nil(t, p) + + // A string cannot scan a NULL. + str := "foobar" + err = conn.QueryRow(ctx, "select null::jsonb").Scan(&str) + require.EqualError(t, err, "can't scan into dest[0] (col: jsonb): cannot scan NULL into *string") + + // A non-string cannot scan a NULL. + err = conn.QueryRow(ctx, "select null::jsonb").Scan(&n) + require.EqualError(t, err, "can't scan into dest[0] (col: jsonb): cannot scan NULL into *int") + }) +} - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } +// https://github.com/jackc/pgx/issues/1681 +func TestJSONBCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var jsonStr string + err := conn.QueryRow(context.Background(), "select $1::jsonb", &ParentIssue1681{}).Scan(&jsonStr) + require.NoError(t, err) + require.Equal(t, `{"custom": "thing"}`, jsonStr) // Note that unlike json, jsonb reformats the JSON string. + }) +} - pointerAllocTests := []struct { - src pgtype.JSONB - dst **string - expected *string - }{ - {src: pgtype.JSONB{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, +func TestJSONBCodecCustomMarshal(t *testing.T) { + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "jsonb", OID: pgtype.JSONBOID, Codec: &pgtype.JSONBCodec{ + Marshal: func(v any) ([]byte, error) { + return []byte(`{"custom":"value"}`), nil + }, + Unmarshal: func(data []byte, v any) error { + return json.Unmarshal([]byte(`{"custom":"value"}`), v) + }, + }, + }) } - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if *tt.dst == tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) - } - } + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "jsonb", []pgxtest.ValueRoundTripTest{ + // There is space between "custom" and "value" in jsonb type. + {map[string]any{"something": "else"}, new(string), isExpectedEq(`{"custom": "value"}`)}, + {[]byte(`{"something":"else"}`), new(map[string]any), func(v any) bool { + return reflect.DeepEqual(v, map[string]any{"custom": "value"}) + }}, + }) } diff --git a/pgtype/line.go b/pgtype/line.go index f6eadf0ef..10efc8ce7 100644 --- a/pgtype/line.go +++ b/pgtype/line.go @@ -8,136 +8,219 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) +type LineScanner interface { + ScanLine(v Line) error +} + +type LineValuer interface { + LineValue() (Line, error) +} + type Line struct { A, B, C float64 - Status Status + Valid bool } -func (dst *Line) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Line", src) +// ScanLine implements the [LineScanner] interface. +func (line *Line) ScanLine(v Line) error { + *line = v + return nil } -func (dst *Line) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } +// LineValue implements the [LineValuer] interface. +func (line Line) LineValue() (Line, error) { + return line, nil } -func (src *Line) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) +func (line *Line) Set(src any) error { + return fmt.Errorf("cannot convert %v to Line", src) } -func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { +// Scan implements the [database/sql.Scanner] interface. +func (line *Line) Scan(src any) error { if src == nil { - *dst = Line{Status: Null} + *line = Line{} return nil } - if len(src) < 7 { - return errors.Errorf("invalid length for Line: %v", len(src)) + switch src := src.(type) { + case string: + return scanPlanTextAnyToLineScanner{}.Scan([]byte(src), line) } - parts := strings.SplitN(string(src[1:len(src)-1]), ",", 3) - if len(parts) < 3 { - return errors.Errorf("invalid format for line") + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (line Line) Value() (driver.Value, error) { + if !line.Valid { + return nil, nil } - a, err := strconv.ParseFloat(parts[0], 64) + buf, err := LineCodec{}.PlanEncode(nil, 0, TextFormatCode, line).Encode(line, nil) if err != nil { - return err + return nil, err } + return string(buf), err +} - b, err := strconv.ParseFloat(parts[1], 64) +type LineCodec struct{} + +func (LineCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (LineCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (LineCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(LineValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanLineCodecBinary{} + case TextFormatCode: + return encodePlanLineCodecText{} + } + + return nil +} + +type encodePlanLineCodecBinary struct{} + +func (encodePlanLineCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + line, err := value.(LineValuer).LineValue() if err != nil { - return err + return nil, err } - c, err := strconv.ParseFloat(parts[2], 64) + if !line.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(line.A)) + buf = pgio.AppendUint64(buf, math.Float64bits(line.B)) + buf = pgio.AppendUint64(buf, math.Float64bits(line.C)) + return buf, nil +} + +type encodePlanLineCodecText struct{} + +func (encodePlanLineCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + line, err := value.(LineValuer).LineValue() if err != nil { - return err + return nil, err + } + + if !line.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`{%s,%s,%s}`, + strconv.FormatFloat(line.A, 'f', -1, 64), + strconv.FormatFloat(line.B, 'f', -1, 64), + strconv.FormatFloat(line.C, 'f', -1, 64), + )...) + return buf, nil +} + +func (LineCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case LineScanner: + return scanPlanBinaryLineToLineScanner{} + } + case TextFormatCode: + switch target.(type) { + case LineScanner: + return scanPlanTextAnyToLineScanner{} + } } - *dst = Line{A: a, B: b, C: c, Status: Present} return nil } -func (dst *Line) DecodeBinary(ci *ConnInfo, src []byte) error { +type scanPlanBinaryLineToLineScanner struct{} + +func (scanPlanBinaryLineToLineScanner) Scan(src []byte, dst any) error { + scanner := (dst).(LineScanner) + if src == nil { - *dst = Line{Status: Null} - return nil + return scanner.ScanLine(Line{}) } if len(src) != 24 { - return errors.Errorf("invalid length for Line: %v", len(src)) + return fmt.Errorf("invalid length for line: %v", len(src)) } a := binary.BigEndian.Uint64(src) b := binary.BigEndian.Uint64(src[8:]) c := binary.BigEndian.Uint64(src[16:]) - *dst = Line{ - A: math.Float64frombits(a), - B: math.Float64frombits(b), - C: math.Float64frombits(c), - Status: Present, - } - return nil + return scanner.ScanLine(Line{ + A: math.Float64frombits(a), + B: math.Float64frombits(b), + C: math.Float64frombits(c), + Valid: true, + }) } -func (src *Line) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +type scanPlanTextAnyToLineScanner struct{} + +func (scanPlanTextAnyToLineScanner) Scan(src []byte, dst any) error { + scanner := (dst).(LineScanner) + + if src == nil { + return scanner.ScanLine(Line{}) } - return append(buf, fmt.Sprintf(`{%f,%f,%f}`, src.A, src.B, src.C)...), nil -} + if len(src) < 7 { + return fmt.Errorf("invalid length for line: %v", len(src)) + } -func (src *Line) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 3) + if len(parts) < 3 { + return fmt.Errorf("invalid format for line") } - buf = pgio.AppendUint64(buf, math.Float64bits(src.A)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.B)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.C)) - return buf, nil -} + a, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return err + } -// Scan implements the database/sql Scanner interface. -func (dst *Line) Scan(src interface{}) error { - if src == nil { - *dst = Line{Status: Null} - return nil + b, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return err } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + c, err := strconv.ParseFloat(parts[2], 64) + if err != nil { + return err } - return errors.Errorf("cannot scan %T", src) + return scanner.ScanLine(Line{A: a, B: b, C: c, Valid: true}) } -// Value implements the database/sql/driver Valuer interface. -func (src *Line) Value() (driver.Value, error) { - return EncodeValueText(src) +func (c LineCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c LineCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var line Line + err := codecScan(c, m, oid, format, src, &line) + if err != nil { + return nil, err + } + return line, nil } diff --git a/pgtype/line_test.go b/pgtype/line_test.go index 09e480198..dc980ce10 100644 --- a/pgtype/line_test.go +++ b/pgtype/line_test.go @@ -1,36 +1,58 @@ package pgtype_test import ( + "context" "testing" - version "github.com/hashicorp/go-version" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" ) func TestLineTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - serverVersion, err := version.NewVersion(conn.RuntimeParams["server_version"]) - if err != nil { - t.Fatalf("cannot get server version: %v", err) - } - testutil.MustClose(t, conn) + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does not support type line") - minVersion := version.Must(version.NewVersion("9.4")) + if _, ok := conn.TypeMap().TypeForName("line"); !ok { + t.Skip("Skipping due to no line type") + } - if serverVersion.LessThan(minVersion) { - t.Skipf("Skipping line test for server version %v", serverVersion) + // line may exist but not be usable on 9.3 :( + var isPG93 bool + err := conn.QueryRow(context.Background(), "select version() ~ '9.3'").Scan(&isPG93) + if err != nil { + t.Fatal(err) + } + if isPG93 { + t.Skip("Skipping due to unimplemented line type in PG 9.3") + } } - testutil.TestSuccessfulTranscode(t, "line", []interface{}{ - &pgtype.Line{ - A: 1.23, B: 4.56, C: 7.89, - Status: pgtype.Present, + pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, nil, "line", []pgxtest.ValueRoundTripTest{ + { + pgtype.Line{ + A: 1.23, B: 4.56, C: 7.89012345, + Valid: true, + }, + new(pgtype.Line), + isExpectedEq(pgtype.Line{ + A: 1.23, B: 4.56, C: 7.89012345, + Valid: true, + }), }, - &pgtype.Line{ - A: -1.23, B: -4.56, C: -7.89, - Status: pgtype.Present, + { + pgtype.Line{ + A: -1.23, B: -4.56, C: -7.89, + Valid: true, + }, + new(pgtype.Line), + isExpectedEq(pgtype.Line{ + A: -1.23, B: -4.56, C: -7.89, + Valid: true, + }), }, - &pgtype.Line{Status: pgtype.Null}, + {pgtype.Line{}, new(pgtype.Line), isExpectedEq(pgtype.Line{})}, + {nil, new(pgtype.Line), isExpectedEq(pgtype.Line{})}, }) } diff --git a/pgtype/lseg.go b/pgtype/lseg.go index a9d740cf7..ed0d40d2a 100644 --- a/pgtype/lseg.go +++ b/pgtype/lseg.go @@ -8,89 +8,154 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) -type Lseg struct { - P [2]Vec2 - Status Status +type LsegScanner interface { + ScanLseg(v Lseg) error } -func (dst *Lseg) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Lseg", src) +type LsegValuer interface { + LsegValue() (Lseg, error) } -func (dst *Lseg) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } +type Lseg struct { + P [2]Vec2 + Valid bool +} + +// ScanLseg implements the [LsegScanner] interface. +func (lseg *Lseg) ScanLseg(v Lseg) error { + *lseg = v + return nil } -func (src *Lseg) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) +// LsegValue implements the [LsegValuer] interface. +func (lseg Lseg) LsegValue() (Lseg, error) { + return lseg, nil } -func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { +// Scan implements the [database/sql.Scanner] interface. +func (lseg *Lseg) Scan(src any) error { if src == nil { - *dst = Lseg{Status: Null} + *lseg = Lseg{} return nil } - if len(src) < 11 { - return errors.Errorf("invalid length for Lseg: %v", len(src)) + switch src := src.(type) { + case string: + return scanPlanTextAnyToLsegScanner{}.Scan([]byte(src), lseg) } - str := string(src[2:]) + return fmt.Errorf("cannot scan %T", src) +} - var end int - end = strings.IndexByte(str, ',') +// Value implements the [database/sql/driver.Valuer] interface. +func (lseg Lseg) Value() (driver.Value, error) { + if !lseg.Valid { + return nil, nil + } - x1, err := strconv.ParseFloat(str[:end], 64) + buf, err := LsegCodec{}.PlanEncode(nil, 0, TextFormatCode, lseg).Encode(lseg, nil) if err != nil { - return err + return nil, err } + return string(buf), err +} - str = str[end+1:] - end = strings.IndexByte(str, ')') +type LsegCodec struct{} - y1, err := strconv.ParseFloat(str[:end], 64) - if err != nil { - return err +func (LsegCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (LsegCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (LsegCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(LsegValuer); !ok { + return nil } - str = str[end+3:] - end = strings.IndexByte(str, ',') + switch format { + case BinaryFormatCode: + return encodePlanLsegCodecBinary{} + case TextFormatCode: + return encodePlanLsegCodecText{} + } - x2, err := strconv.ParseFloat(str[:end], 64) + return nil +} + +type encodePlanLsegCodecBinary struct{} + +func (encodePlanLsegCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + lseg, err := value.(LsegValuer).LsegValue() if err != nil { - return err + return nil, err } - str = str[end+1 : len(str)-2] + if !lseg.Valid { + return nil, nil + } - y2, err := strconv.ParseFloat(str, 64) + buf = pgio.AppendUint64(buf, math.Float64bits(lseg.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(lseg.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(lseg.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(lseg.P[1].Y)) + return buf, nil +} + +type encodePlanLsegCodecText struct{} + +func (encodePlanLsegCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + lseg, err := value.(LsegValuer).LsegValue() if err != nil { - return err + return nil, err + } + + if !lseg.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`[(%s,%s),(%s,%s)]`, + strconv.FormatFloat(lseg.P[0].X, 'f', -1, 64), + strconv.FormatFloat(lseg.P[0].Y, 'f', -1, 64), + strconv.FormatFloat(lseg.P[1].X, 'f', -1, 64), + strconv.FormatFloat(lseg.P[1].Y, 'f', -1, 64), + )...) + return buf, nil +} + +func (LsegCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case LsegScanner: + return scanPlanBinaryLsegToLsegScanner{} + } + case TextFormatCode: + switch target.(type) { + case LsegScanner: + return scanPlanTextAnyToLsegScanner{} + } } - *dst = Lseg{P: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present} return nil } -func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { +type scanPlanBinaryLsegToLsegScanner struct{} + +func (scanPlanBinaryLsegToLsegScanner) Scan(src []byte, dst any) error { + scanner := (dst).(LsegScanner) + if src == nil { - *dst = Lseg{Status: Null} - return nil + return scanner.ScanLseg(Lseg{}) } if len(src) != 32 { - return errors.Errorf("invalid length for Lseg: %v", len(src)) + return fmt.Errorf("invalid length for lseg: %v", len(src)) } x1 := binary.BigEndian.Uint64(src) @@ -98,64 +163,77 @@ func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { x2 := binary.BigEndian.Uint64(src[16:]) y2 := binary.BigEndian.Uint64(src[24:]) - *dst = Lseg{ + return scanner.ScanLseg(Lseg{ P: [2]Vec2{ {math.Float64frombits(x1), math.Float64frombits(y1)}, {math.Float64frombits(x2), math.Float64frombits(y2)}, }, - Status: Present, - } - return nil + Valid: true, + }) } -func (src *Lseg) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +type scanPlanTextAnyToLsegScanner struct{} + +func (scanPlanTextAnyToLsegScanner) Scan(src []byte, dst any) error { + scanner := (dst).(LsegScanner) + + if src == nil { + return scanner.ScanLseg(Lseg{}) } - buf = append(buf, fmt.Sprintf(`(%f,%f),(%f,%f)`, - src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)...) - return buf, nil -} + if len(src) < 11 { + return fmt.Errorf("invalid length for lseg: %v", len(src)) + } -func (src *Lseg) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined + str := string(src[2:]) + + var end int + end = strings.IndexByte(str, ',') + + x1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err } - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y)) - return buf, nil -} + str = str[end+1:] + end = strings.IndexByte(str, ')') -// Scan implements the database/sql Scanner interface. -func (dst *Lseg) Scan(src interface{}) error { - if src == nil { - *dst = Lseg{Status: Null} - return nil + y1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + str = str[end+3:] + end = strings.IndexByte(str, ',') + + x2, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err } - return errors.Errorf("cannot scan %T", src) + str = str[end+1 : len(str)-2] + + y2, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + return scanner.ScanLseg(Lseg{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true}) +} + +func (c LsegCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -// Value implements the database/sql/driver Valuer interface. -func (src *Lseg) Value() (driver.Value, error) { - return EncodeValueText(src) +func (c LsegCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var lseg Lseg + err := codecScan(c, m, oid, format, src, &lseg) + if err != nil { + return nil, err + } + return lseg, nil } diff --git a/pgtype/lseg_test.go b/pgtype/lseg_test.go index bd394e3c0..04fde0ebd 100644 --- a/pgtype/lseg_test.go +++ b/pgtype/lseg_test.go @@ -1,22 +1,40 @@ package pgtype_test import ( + "context" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" ) func TestLsegTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "lseg", []interface{}{ - &pgtype.Lseg{ - P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}}, - Status: pgtype.Present, + skipCockroachDB(t, "Server does not support type lseg") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "lseg", []pgxtest.ValueRoundTripTest{ + { + pgtype.Lseg{ + P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, + Valid: true, + }, + new(pgtype.Lseg), + isExpectedEq(pgtype.Lseg{ + P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, + Valid: true, + }), }, - &pgtype.Lseg{ - P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, - Status: pgtype.Present, + { + pgtype.Lseg{ + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Valid: true, + }, + new(pgtype.Lseg), + isExpectedEq(pgtype.Lseg{ + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Valid: true, + }), }, - &pgtype.Lseg{Status: pgtype.Null}, + {pgtype.Lseg{}, new(pgtype.Lseg), isExpectedEq(pgtype.Lseg{})}, + {nil, new(pgtype.Lseg), isExpectedEq(pgtype.Lseg{})}, }) } diff --git a/pgtype/ltree.go b/pgtype/ltree.go new file mode 100644 index 000000000..6af317794 --- /dev/null +++ b/pgtype/ltree.go @@ -0,0 +1,122 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" +) + +type LtreeCodec struct{} + +func (l LtreeCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +// PreferredFormat returns the preferred format. +func (l LtreeCodec) PreferredFormat() int16 { + return TextFormatCode +} + +// PlanEncode returns an EncodePlan for encoding value into PostgreSQL format for oid and format. If no plan can be +// found then nil is returned. +func (l LtreeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case TextFormatCode: + return (TextCodec)(l).PlanEncode(m, oid, format, value) + case BinaryFormatCode: + switch value.(type) { + case string: + return encodeLtreeCodecBinaryString{} + case []byte: + return encodeLtreeCodecBinaryByteSlice{} + case TextValuer: + return encodeLtreeCodecBinaryTextValuer{} + } + } + + return nil +} + +type encodeLtreeCodecBinaryString struct{} + +func (encodeLtreeCodecBinaryString) Encode(value any, buf []byte) (newBuf []byte, err error) { + ltree := value.(string) + buf = append(buf, 1) + return append(buf, ltree...), nil +} + +type encodeLtreeCodecBinaryByteSlice struct{} + +func (encodeLtreeCodecBinaryByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) { + ltree := value.([]byte) + buf = append(buf, 1) + return append(buf, ltree...), nil +} + +type encodeLtreeCodecBinaryTextValuer struct{} + +func (encodeLtreeCodecBinaryTextValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + t, err := value.(TextValuer).TextValue() + if err != nil { + return nil, err + } + if !t.Valid { + return nil, nil + } + + buf = append(buf, 1) + return append(buf, t.String...), nil +} + +// PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If +// no plan can be found then nil is returned. +func (l LtreeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case TextFormatCode: + return (TextCodec)(l).PlanScan(m, oid, format, target) + case BinaryFormatCode: + switch target.(type) { + case *string: + return scanPlanBinaryLtreeToString{} + case TextScanner: + return scanPlanBinaryLtreeToTextScanner{} + } + } + + return nil +} + +type scanPlanBinaryLtreeToString struct{} + +func (scanPlanBinaryLtreeToString) Scan(src []byte, target any) error { + version := src[0] + if version != 1 { + return fmt.Errorf("unsupported ltree version %d", version) + } + + p := (target).(*string) + *p = string(src[1:]) + + return nil +} + +type scanPlanBinaryLtreeToTextScanner struct{} + +func (scanPlanBinaryLtreeToTextScanner) Scan(src []byte, target any) error { + version := src[0] + if version != 1 { + return fmt.Errorf("unsupported ltree version %d", version) + } + + scanner := (target).(TextScanner) + return scanner.ScanText(Text{String: string(src[1:]), Valid: true}) +} + +// DecodeDatabaseSQLValue returns src decoded into a value compatible with the sql.Scanner interface. +func (l LtreeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return (TextCodec)(l).DecodeDatabaseSQLValue(m, oid, format, src) +} + +// DecodeValue returns src decoded into its default format. +func (l LtreeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + return (TextCodec)(l).DecodeValue(m, oid, format, src) +} diff --git a/pgtype/ltree_test.go b/pgtype/ltree_test.go new file mode 100644 index 000000000..2ec850f56 --- /dev/null +++ b/pgtype/ltree_test.go @@ -0,0 +1,26 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestLtreeCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support type ltree") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "ltree", []pgxtest.ValueRoundTripTest{ + { + Param: "A.B.C", + Result: new(string), + Test: isExpectedEq("A.B.C"), + }, + { + Param: pgtype.Text{String: "", Valid: true}, + Result: new(pgtype.Text), + Test: isExpectedEq(pgtype.Text{String: "", Valid: true}), + }, + }) +} diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go index 4c6e2212f..e913ec903 100644 --- a/pgtype/macaddr.go +++ b/pgtype/macaddr.go @@ -3,152 +3,160 @@ package pgtype import ( "database/sql/driver" "net" - - "github.com/pkg/errors" ) -type Macaddr struct { - Addr net.HardwareAddr - Status Status +type MacaddrCodec struct{} + +func (MacaddrCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode } -func (dst *Macaddr) Set(src interface{}) error { - if src == nil { - *dst = Macaddr{Status: Null} - return nil - } +func (MacaddrCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (MacaddrCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case net.HardwareAddr: + return encodePlanMacaddrCodecBinaryHardwareAddr{} + case TextValuer: + return encodePlanMacAddrCodecTextValuer{} - switch value := src.(type) { - case net.HardwareAddr: - addr := make(net.HardwareAddr, len(value)) - copy(addr, value) - *dst = Macaddr{Addr: addr, Status: Present} - case string: - addr, err := net.ParseMAC(value) - if err != nil { - return err } - *dst = Macaddr{Addr: addr, Status: Present} - default: - if originalSrc, ok := underlyingPtrType(src); ok { - return dst.Set(originalSrc) + case TextFormatCode: + switch value.(type) { + case net.HardwareAddr: + return encodePlanMacaddrCodecTextHardwareAddr{} + case TextValuer: + return encodePlanTextCodecTextValuer{} } - return errors.Errorf("cannot convert %v to Macaddr", value) } return nil } -func (dst *Macaddr) Get() interface{} { - switch dst.Status { - case Present: - return dst.Addr - case Null: - return nil - default: - return dst.Status - } -} +type encodePlanMacaddrCodecBinaryHardwareAddr struct{} -func (src *Macaddr) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *net.HardwareAddr: - *v = make(net.HardwareAddr, len(src.Addr)) - copy(*v, src.Addr) - return nil - case *string: - *v = src.Addr.String() - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) +func (encodePlanMacaddrCodecBinaryHardwareAddr) Encode(value any, buf []byte) (newBuf []byte, err error) { + addr := value.(net.HardwareAddr) + if addr == nil { + return nil, nil } - return errors.Errorf("cannot decode %v into %T", src, dst) + return append(buf, addr...), nil } -func (dst *Macaddr) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Macaddr{Status: Null} - return nil +type encodePlanMacAddrCodecTextValuer struct{} + +func (encodePlanMacAddrCodecTextValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + t, err := value.(TextValuer).TextValue() + if err != nil { + return nil, err + } + if !t.Valid { + return nil, nil } - addr, err := net.ParseMAC(string(src)) + addr, err := net.ParseMAC(t.String) if err != nil { - return err + return nil, err } - *dst = Macaddr{Addr: addr, Status: Present} - return nil + return append(buf, addr...), nil } -func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Macaddr{Status: Null} - return nil - } +type encodePlanMacaddrCodecTextHardwareAddr struct{} - if len(src) != 6 { - return errors.Errorf("Received an invalid size for a macaddr: %d", len(src)) +func (encodePlanMacaddrCodecTextHardwareAddr) Encode(value any, buf []byte) (newBuf []byte, err error) { + addr := value.(net.HardwareAddr) + if addr == nil { + return nil, nil } - addr := make(net.HardwareAddr, 6) - copy(addr, src) + return append(buf, addr.String()...), nil +} - *dst = Macaddr{Addr: addr, Status: Present} +func (MacaddrCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case *net.HardwareAddr: + return scanPlanBinaryMacaddrToHardwareAddr{} + case TextScanner: + return scanPlanBinaryMacaddrToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *net.HardwareAddr: + return scanPlanTextMacaddrToHardwareAddr{} + case TextScanner: + return scanPlanTextAnyToTextScanner{} + } + } return nil } -func (src *Macaddr) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +type scanPlanBinaryMacaddrToHardwareAddr struct{} + +func (scanPlanBinaryMacaddrToHardwareAddr) Scan(src []byte, dst any) error { + dstBuf := dst.(*net.HardwareAddr) + if src == nil { + *dstBuf = nil + return nil } - return append(buf, src.Addr.String()...), nil + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil } -// EncodeBinary encodes src into w. -func (src *Macaddr) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +type scanPlanBinaryMacaddrToTextScanner struct{} + +func (scanPlanBinaryMacaddrToTextScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TextScanner) + if src == nil { + return scanner.ScanText(Text{}) } - return append(buf, src.Addr...), nil + return scanner.ScanText(Text{String: net.HardwareAddr(src).String(), Valid: true}) } -// Scan implements the database/sql Scanner interface. -func (dst *Macaddr) Scan(src interface{}) error { +type scanPlanTextMacaddrToHardwareAddr struct{} + +func (scanPlanTextMacaddrToHardwareAddr) Scan(src []byte, dst any) error { + p := dst.(*net.HardwareAddr) + if src == nil { - *dst = Macaddr{Status: Null} + *p = nil return nil } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + addr, err := net.ParseMAC(string(src)) + if err != nil { + return err } - return errors.Errorf("cannot scan %T", src) + *p = addr + + return nil +} + +func (c MacaddrCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -// Value implements the database/sql/driver Valuer interface. -func (src *Macaddr) Value() (driver.Value, error) { - return EncodeValueText(src) +func (c MacaddrCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var addr net.HardwareAddr + err := codecScan(c, m, oid, format, src, &addr) + if err != nil { + return nil, err + } + return addr, nil } diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go index 5d3292491..58149c87e 100644 --- a/pgtype/macaddr_test.go +++ b/pgtype/macaddr_test.go @@ -2,77 +2,69 @@ package pgtype_test import ( "bytes" + "context" "net" - "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) -func TestMacaddrTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "macaddr", []interface{}{ - &pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - &pgtype.Macaddr{Status: pgtype.Null}, - }) -} - -func TestMacaddrSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Macaddr - }{ - { - source: mustParseMacaddr(t, "01:23:45:67:89:ab"), - result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - }, - { - source: "01:23:45:67:89:ab", - result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - }, - } +func isExpectedEqHardwareAddr(a any) func(any) bool { + return func(v any) bool { + aa := a.(net.HardwareAddr) + vv := v.(net.HardwareAddr) - for i, tt := range successfulTests { - var r pgtype.Macaddr - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) + if (aa == nil) != (vv == nil) { + return false } - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestMacaddrAssignTo(t *testing.T) { - { - src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present} - var dst net.HardwareAddr - expected := mustParseMacaddr(t, "01:23:45:67:89:ab") - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) + if aa == nil { + return true } - if bytes.Compare([]byte(dst), []byte(expected)) != 0 { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } + return bytes.Equal(aa, vv) } +} - { - src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present} - var dst string - expected := "01:23:45:67:89:ab" +func TestMacaddrCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support type macaddr") - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } + // Only testing known OID query exec modes as net.HardwareAddr could map to macaddr or macaddr8. + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "macaddr", []pgxtest.ValueRoundTripTest{ + { + mustParseMacaddr(t, "01:23:45:67:89:ab"), + new(net.HardwareAddr), + isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab")), + }, + { + "01:23:45:67:89:ab", + new(net.HardwareAddr), + isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab")), + }, + { + mustParseMacaddr(t, "01:23:45:67:89:ab"), + new(string), + isExpectedEq("01:23:45:67:89:ab"), + }, + {nil, new(*net.HardwareAddr), isExpectedEq((*net.HardwareAddr)(nil))}, + }) - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "macaddr8", []pgxtest.ValueRoundTripTest{ + { + mustParseMacaddr(t, "01:23:45:67:89:ab:01:08"), + new(net.HardwareAddr), + isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab:01:08")), + }, + { + "01:23:45:67:89:ab:01:08", + new(net.HardwareAddr), + isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab:01:08")), + }, + { + mustParseMacaddr(t, "01:23:45:67:89:ab:01:08"), + new(string), + isExpectedEq("01:23:45:67:89:ab:01:08"), + }, + {nil, new(*net.HardwareAddr), isExpectedEq((*net.HardwareAddr)(nil))}, + }) } diff --git a/pgtype/multirange.go b/pgtype/multirange.go new file mode 100644 index 000000000..4fe6dd40d --- /dev/null +++ b/pgtype/multirange.go @@ -0,0 +1,442 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// MultirangeGetter is a type that can be converted into a PostgreSQL multirange. +type MultirangeGetter interface { + // IsNull returns true if the value is SQL NULL. + IsNull() bool + + // Len returns the number of elements in the multirange. + Len() int + + // Index returns the element at i. + Index(i int) any + + // IndexType returns a non-nil scan target of the type Index will return. This is used by MultirangeCodec.PlanEncode. + IndexType() any +} + +// MultirangeSetter is a type can be set from a PostgreSQL multirange. +type MultirangeSetter interface { + // ScanNull sets the value to SQL NULL. + ScanNull() error + + // SetLen prepares the value such that ScanIndex can be called for each element. This will remove any existing + // elements. + SetLen(n int) error + + // ScanIndex returns a value usable as a scan target for i. SetLen must be called before ScanIndex. + ScanIndex(i int) any + + // ScanIndexType returns a non-nil scan target of the type ScanIndex will return. This is used by + // MultirangeCodec.PlanScan. + ScanIndexType() any +} + +// MultirangeCodec is a codec for any multirange type. +type MultirangeCodec struct { + ElementType *Type +} + +func (c *MultirangeCodec) FormatSupported(format int16) bool { + return c.ElementType.Codec.FormatSupported(format) +} + +func (c *MultirangeCodec) PreferredFormat() int16 { + return c.ElementType.Codec.PreferredFormat() +} + +func (c *MultirangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + multirangeValuer, ok := value.(MultirangeGetter) + if !ok { + return nil + } + + elementType := multirangeValuer.IndexType() + + elementEncodePlan := m.PlanEncode(c.ElementType.OID, format, elementType) + if elementEncodePlan == nil { + return nil + } + + switch format { + case BinaryFormatCode: + return &encodePlanMultirangeCodecBinary{ac: c, m: m, oid: oid} + case TextFormatCode: + return &encodePlanMultirangeCodecText{ac: c, m: m, oid: oid} + } + + return nil +} + +type encodePlanMultirangeCodecText struct { + ac *MultirangeCodec + m *Map + oid uint32 +} + +func (p *encodePlanMultirangeCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + multirange := value.(MultirangeGetter) + + if multirange.IsNull() { + return nil, nil + } + + elementCount := multirange.Len() + + buf = append(buf, '{') + + var encodePlan EncodePlan + var lastElemType reflect.Type + inElemBuf := make([]byte, 0, 32) + for i := 0; i < elementCount; i++ { + if i > 0 { + buf = append(buf, ',') + } + + elem := multirange.Index(i) + var elemBuf []byte + if elem != nil { + elemType := reflect.TypeOf(elem) + if lastElemType != elemType { + lastElemType = elemType + encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, TextFormatCode, elem) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", multirange.Index(i)) + } + } + elemBuf, err = encodePlan.Encode(elem, inElemBuf) + if err != nil { + return nil, err + } + } + + if elemBuf == nil { + return nil, fmt.Errorf("multirange cannot contain NULL element") + } else { + buf = append(buf, elemBuf...) + } + } + + buf = append(buf, '}') + + return buf, nil +} + +type encodePlanMultirangeCodecBinary struct { + ac *MultirangeCodec + m *Map + oid uint32 +} + +func (p *encodePlanMultirangeCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + multirange := value.(MultirangeGetter) + + if multirange.IsNull() { + return nil, nil + } + + elementCount := multirange.Len() + + buf = pgio.AppendInt32(buf, int32(elementCount)) + + var encodePlan EncodePlan + var lastElemType reflect.Type + for i := 0; i < elementCount; i++ { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elem := multirange.Index(i) + var elemBuf []byte + if elem != nil { + elemType := reflect.TypeOf(elem) + if lastElemType != elemType { + lastElemType = elemType + encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, BinaryFormatCode, elem) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", multirange.Index(i)) + } + } + elemBuf, err = encodePlan.Encode(elem, buf) + if err != nil { + return nil, err + } + } + + if elemBuf == nil { + return nil, fmt.Errorf("multirange cannot contain NULL element") + } else { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +func (c *MultirangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + multirangeScanner, ok := target.(MultirangeSetter) + if !ok { + return nil + } + + elementType := multirangeScanner.ScanIndexType() + + elementScanPlan := m.PlanScan(c.ElementType.OID, format, elementType) + if _, ok := elementScanPlan.(*scanPlanFail); ok { + return nil + } + + return &scanPlanMultirangeCodec{ + multirangeCodec: c, + m: m, + oid: oid, + formatCode: format, + } +} + +func (c *MultirangeCodec) decodeBinary(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error { + rp := 0 + + elementCount := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + err := multirange.SetLen(elementCount) + if err != nil { + return err + } + + if elementCount == 0 { + return nil + } + + elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0)) + if elementScanPlan == nil { + elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0)) + } + + for i := 0; i < elementCount; i++ { + elem := multirange.ScanIndex(i) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elementScanPlan.Scan(elemSrc, elem) + if err != nil { + return fmt.Errorf("failed to scan multirange element %d: %w", i, err) + } + } + + return nil +} + +func (c *MultirangeCodec) decodeText(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error { + elements, err := parseUntypedTextMultirange(src) + if err != nil { + return err + } + + err = multirange.SetLen(len(elements)) + if err != nil { + return err + } + + if len(elements) == 0 { + return nil + } + + elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0)) + if elementScanPlan == nil { + elementScanPlan = m.PlanScan(c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0)) + } + + for i, s := range elements { + elem := multirange.ScanIndex(i) + err = elementScanPlan.Scan([]byte(s), elem) + if err != nil { + return err + } + } + + return nil +} + +type scanPlanMultirangeCodec struct { + multirangeCodec *MultirangeCodec + m *Map + oid uint32 + formatCode int16 + elementScanPlan ScanPlan +} + +func (spac *scanPlanMultirangeCodec) Scan(src []byte, dst any) error { + c := spac.multirangeCodec + m := spac.m + oid := spac.oid + formatCode := spac.formatCode + + multirange := dst.(MultirangeSetter) + + if src == nil { + return multirange.ScanNull() + } + + switch formatCode { + case BinaryFormatCode: + return c.decodeBinary(m, oid, src, multirange) + case TextFormatCode: + return c.decodeText(m, oid, src, multirange) + default: + return fmt.Errorf("unknown format code %d", formatCode) + } +} + +func (c *MultirangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } +} + +func (c *MultirangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var multirange Multirange[Range[any]] + err := m.PlanScan(oid, format, &multirange).Scan(src, &multirange) + return multirange, err +} + +func parseUntypedTextMultirange(src []byte) ([]string, error) { + elements := make([]string, 0) + + buf := bytes.NewBuffer(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %w", err) + } + + if r != '{' { + return nil, fmt.Errorf("invalid multirange, expected '{' got %v", r) + } + +parseValueLoop: + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid multirange: %w", err) + } + + switch r { + case ',': // skip range separator + case '}': + break parseValueLoop + default: + buf.UnreadRune() + value, err := parseRange(buf) + if err != nil { + return nil, fmt.Errorf("invalid multirange value: %w", err) + } + elements = append(elements, value) + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + return elements, nil +} + +func parseRange(buf *bytes.Buffer) (string, error) { + s := &bytes.Buffer{} + + boundSepRead := false + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case ',', '}': + if r == ',' && !boundSepRead { + boundSepRead = true + break + } + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} + +// Multirange is a generic multirange type. +// +// T should implement [RangeValuer] and *T should implement [RangeScanner]. However, there does not appear to be a way to +// enforce the [RangeScanner] constraint. +type Multirange[T RangeValuer] []T + +func (r Multirange[T]) IsNull() bool { + return r == nil +} + +func (r Multirange[T]) Len() int { + return len(r) +} + +func (r Multirange[T]) Index(i int) any { + return r[i] +} + +func (r Multirange[T]) IndexType() any { + var zero T + return zero +} + +func (r *Multirange[T]) ScanNull() error { + *r = nil + return nil +} + +func (r *Multirange[T]) SetLen(n int) error { + *r = make([]T, n) + return nil +} + +func (r Multirange[T]) ScanIndex(i int) any { + return &r[i] +} + +func (r Multirange[T]) ScanIndexType() any { + return new(T) +} diff --git a/pgtype/multirange_test.go b/pgtype/multirange_test.go new file mode 100644 index 000000000..fe53083b8 --- /dev/null +++ b/pgtype/multirange_test.go @@ -0,0 +1,113 @@ +package pgtype_test + +import ( + "context" + "reflect" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +func TestMultirangeCodecTranscode(t *testing.T) { + skipPostgreSQLVersionLessThan(t, 14) + skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4multirange", []pgxtest.ValueRoundTripTest{ + { + pgtype.Multirange[pgtype.Range[pgtype.Int4]](nil), + new(pgtype.Multirange[pgtype.Range[pgtype.Int4]]), + func(a any) bool { return reflect.DeepEqual(pgtype.Multirange[pgtype.Range[pgtype.Int4]](nil), a) }, + }, + { + pgtype.Multirange[pgtype.Range[pgtype.Int4]]{}, + new(pgtype.Multirange[pgtype.Range[pgtype.Int4]]), + func(a any) bool { return reflect.DeepEqual(pgtype.Multirange[pgtype.Range[pgtype.Int4]]{}, a) }, + }, + { + pgtype.Multirange[pgtype.Range[pgtype.Int4]]{ + { + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + { + Lower: pgtype.Int4{Int32: 7, Valid: true}, + Upper: pgtype.Int4{Int32: 9, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, + new(pgtype.Multirange[pgtype.Range[pgtype.Int4]]), + func(a any) bool { + return reflect.DeepEqual(pgtype.Multirange[pgtype.Range[pgtype.Int4]]{ + { + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + { + Lower: pgtype.Int4{Int32: 7, Valid: true}, + Upper: pgtype.Int4{Int32: 9, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, a) + }, + }, + }) +} + +func TestMultirangeCodecDecodeValue(t *testing.T) { + skipPostgreSQLVersionLessThan(t, 14) + skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + for _, tt := range []struct { + sql string + expected any + }{ + { + sql: `select int4multirange(int4range(1, 5), int4range(7,9))`, + expected: pgtype.Multirange[pgtype.Range[any]]{ + { + Lower: int32(1), + Upper: int32(5), + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + { + Lower: int32(7), + Upper: int32(9), + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(ctx, tt.sql) + require.NoError(t, err) + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } + }) +} diff --git a/pgtype/name.go b/pgtype/name.go deleted file mode 100644 index af064a82f..000000000 --- a/pgtype/name.go +++ /dev/null @@ -1,58 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -// Name is a type used for PostgreSQL's special 63-byte -// name data type, used for identifiers like table names. -// The pg_class.relname column is a good example of where the -// name data type is used. -// -// Note that the underlying Go data type of pgx.Name is string, -// so there is no way to enforce the 63-byte length. Inputting -// a longer name into PostgreSQL will result in silent truncation -// to 63 bytes. -// -// Also, if you have custom-compiled PostgreSQL and set -// NAMEDATALEN to a different value, obviously that number of -// bytes applies, rather than the default 63. -type Name Text - -func (dst *Name) Set(src interface{}) error { - return (*Text)(dst).Set(src) -} - -func (dst *Name) Get() interface{} { - return (*Text)(dst).Get() -} - -func (src *Name) AssignTo(dst interface{}) error { - return (*Text)(src).AssignTo(dst) -} - -func (dst *Name) DecodeText(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeText(ci, src) -} - -func (dst *Name) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeBinary(ci, src) -} - -func (src *Name) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Text)(src).EncodeText(ci, buf) -} - -func (src *Name) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Text)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *Name) Scan(src interface{}) error { - return (*Text)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Name) Value() (driver.Value, error) { - return (*Text)(src).Value() -} diff --git a/pgtype/name_test.go b/pgtype/name_test.go deleted file mode 100644 index ec0820c4b..000000000 --- a/pgtype/name_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestNameTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "name", []interface{}{ - &pgtype.Name{String: "", Status: pgtype.Present}, - &pgtype.Name{String: "foo", Status: pgtype.Present}, - &pgtype.Name{Status: pgtype.Null}, - }) -} - -func TestNameSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Name - }{ - {source: "foo", result: pgtype.Name{String: "foo", Status: pgtype.Present}}, - {source: _string("bar"), result: pgtype.Name{String: "bar", Status: pgtype.Present}}, - {source: (*string)(nil), result: pgtype.Name{Status: pgtype.Null}}, - } - - for i, tt := range successfulTests { - var d pgtype.Name - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if d != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } - } -} - -func TestNameAssignTo(t *testing.T) { - var s string - var ps *string - - simpleTests := []struct { - src pgtype.Name - dst interface{} - expected interface{} - }{ - {src: pgtype.Name{String: "foo", Status: pgtype.Present}, dst: &s, expected: "foo"}, - {src: pgtype.Name{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Name - dst interface{} - expected interface{} - }{ - {src: pgtype.Name{String: "foo", Status: pgtype.Present}, dst: &ps, expected: "foo"}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Name - dst interface{} - }{ - {src: pgtype.Name{Status: pgtype.Null}, dst: &s}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/numeric.go b/pgtype/numeric.go index fb63df756..7d236902d 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -1,368 +1,632 @@ package pgtype import ( + "bytes" "database/sql/driver" "encoding/binary" + "fmt" "math" "math/big" "strconv" "strings" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) // PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 const nbase = 10000 -var big0 *big.Int = big.NewInt(0) -var big1 *big.Int = big.NewInt(1) -var big10 *big.Int = big.NewInt(10) -var big100 *big.Int = big.NewInt(100) -var big1000 *big.Int = big.NewInt(1000) - -var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8) -var bigMinInt8 *big.Int = big.NewInt(math.MinInt8) -var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16) -var bigMinInt16 *big.Int = big.NewInt(math.MinInt16) -var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32) -var bigMinInt32 *big.Int = big.NewInt(math.MinInt32) -var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64) -var bigMinInt64 *big.Int = big.NewInt(math.MinInt64) -var bigMaxInt *big.Int = big.NewInt(int64(maxInt)) -var bigMinInt *big.Int = big.NewInt(int64(minInt)) - -var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8) -var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16) -var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32) -var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64)) -var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint)) - -var bigNBase *big.Int = big.NewInt(nbase) -var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) -var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) -var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) +const ( + pgNumericNaN = 0x00000000c0000000 + pgNumericNaNSign = 0xc000 + + pgNumericPosInf = 0x00000000d0000000 + pgNumericPosInfSign = 0xd000 + + pgNumericNegInf = 0x00000000f0000000 + pgNumericNegInfSign = 0xf000 +) + +var ( + big0 *big.Int = big.NewInt(0) + big1 *big.Int = big.NewInt(1) + big10 *big.Int = big.NewInt(10) + big100 *big.Int = big.NewInt(100) + big1000 *big.Int = big.NewInt(1000) +) + +var ( + bigNBase *big.Int = big.NewInt(nbase) + bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) + bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) + bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) +) + +type NumericScanner interface { + ScanNumeric(v Numeric) error +} + +type NumericValuer interface { + NumericValue() (Numeric, error) +} type Numeric struct { - Int *big.Int - Exp int32 - Status Status + Int *big.Int + Exp int32 + NaN bool + InfinityModifier InfinityModifier + Valid bool } -func (dst *Numeric) Set(src interface{}) error { - if src == nil { - *dst = Numeric{Status: Null} - return nil +// ScanNumeric implements the [NumericScanner] interface. +func (n *Numeric) ScanNumeric(v Numeric) error { + *n = v + return nil +} + +// NumericValue implements the [NumericValuer] interface. +func (n Numeric) NumericValue() (Numeric, error) { + return n, nil +} + +// Float64Value implements the [Float64Valuer] interface. +func (n Numeric) Float64Value() (Float8, error) { + if !n.Valid { + return Float8{}, nil + } else if n.NaN { + return Float8{Float64: math.NaN(), Valid: true}, nil + } else if n.InfinityModifier == Infinity { + return Float8{Float64: math.Inf(1), Valid: true}, nil + } else if n.InfinityModifier == NegativeInfinity { + return Float8{Float64: math.Inf(-1), Valid: true}, nil } - switch value := src.(type) { - case float32: - num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) - if err != nil { - return err - } - *dst = Numeric{Int: num, Exp: exp, Status: Present} - case float64: - num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) - if err != nil { - return err - } - *dst = Numeric{Int: num, Exp: exp, Status: Present} - case int8: - *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} - case uint8: - *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} - case int16: - *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} - case uint16: - *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} - case int32: - *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} - case uint32: - *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} - case int64: - *dst = Numeric{Int: big.NewInt(value), Status: Present} - case uint64: - *dst = Numeric{Int: (&big.Int{}).SetUint64(value), Status: Present} - case int: - *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} - case uint: - *dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Status: Present} - case string: - num, exp, err := parseNumericString(value) - if err != nil { - return err - } - *dst = Numeric{Int: num, Exp: exp, Status: Present} - default: - if originalSrc, ok := underlyingNumberType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Numeric", value) + buf := make([]byte, 0, 32) + + if n.Int == nil { + buf = append(buf, '0') + } else { + buf = append(buf, n.Int.String()...) + } + buf = append(buf, 'e') + buf = append(buf, strconv.FormatInt(int64(n.Exp), 10)...) + + f, err := strconv.ParseFloat(string(buf), 64) + if err != nil { + return Float8{}, err } - return nil + return Float8{Float64: f, Valid: true}, nil } -func (dst *Numeric) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: +// ScanInt64 implements the [Int64Scanner] interface. +func (n *Numeric) ScanInt64(v Int8) error { + if !v.Valid { + *n = Numeric{} return nil - default: - return dst.Status } + + *n = Numeric{Int: big.NewInt(v.Int64), Valid: true} + return nil } -func (src *Numeric) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *float32: - f, err := src.toFloat64() - if err != nil { - return err - } - return float64AssignTo(f, src.Status, dst) - case *float64: - f, err := src.toFloat64() - if err != nil { - return err - } - return float64AssignTo(f, src.Status, dst) - case *int: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(bigMaxInt) > 0 { - return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) - } - if normalizedInt.Cmp(bigMinInt) < 0 { - return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) - } - *v = int(normalizedInt.Int64()) - case *int8: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(bigMaxInt8) > 0 { - return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) - } - if normalizedInt.Cmp(bigMinInt8) < 0 { - return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) - } - *v = int8(normalizedInt.Int64()) - case *int16: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(bigMaxInt16) > 0 { - return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) - } - if normalizedInt.Cmp(bigMinInt16) < 0 { - return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) - } - *v = int16(normalizedInt.Int64()) - case *int32: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(bigMaxInt32) > 0 { - return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) - } - if normalizedInt.Cmp(bigMinInt32) < 0 { - return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) - } - *v = int32(normalizedInt.Int64()) - case *int64: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(bigMaxInt64) > 0 { - return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) - } - if normalizedInt.Cmp(bigMinInt64) < 0 { - return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) - } - *v = normalizedInt.Int64() - case *uint: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(big0) < 0 { - return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) - } else if normalizedInt.Cmp(bigMaxUint) > 0 { - return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) - } - *v = uint(normalizedInt.Uint64()) - case *uint8: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(big0) < 0 { - return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) - } else if normalizedInt.Cmp(bigMaxUint8) > 0 { - return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) - } - *v = uint8(normalizedInt.Uint64()) - case *uint16: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(big0) < 0 { - return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) - } else if normalizedInt.Cmp(bigMaxUint16) > 0 { - return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) - } - *v = uint16(normalizedInt.Uint64()) - case *uint32: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(big0) < 0 { - return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) - } else if normalizedInt.Cmp(bigMaxUint32) > 0 { - return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) - } - *v = uint32(normalizedInt.Uint64()) - case *uint64: - normalizedInt, err := src.toBigInt() - if err != nil { - return err - } - if normalizedInt.Cmp(big0) < 0 { - return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) - } else if normalizedInt.Cmp(bigMaxUint64) > 0 { - return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) - } - *v = normalizedInt.Uint64() - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) +// Int64Value implements the [Int64Valuer] interface. +func (n Numeric) Int64Value() (Int8, error) { + if !n.Valid { + return Int8{}, nil + } + + bi, err := n.toBigInt() + if err != nil { + return Int8{}, err + } + + if !bi.IsInt64() { + return Int8{}, fmt.Errorf("cannot convert %v to int64", n) } + return Int8{Int64: bi.Int64(), Valid: true}, nil +} + +func (n *Numeric) ScanScientific(src string) error { + if !strings.ContainsAny("eE", src) { + return scanPlanTextAnyToNumericScanner{}.Scan([]byte(src), n) + } + + if bigF, ok := new(big.Float).SetString(string(src)); ok { + smallF, _ := bigF.Float64() + src = strconv.FormatFloat(smallF, 'f', -1, 64) + } + + num, exp, err := parseNumericString(src) + if err != nil { + return err + } + + *n = Numeric{Int: num, Exp: exp, Valid: true} + return nil } -func (dst *Numeric) toBigInt() (*big.Int, error) { - if dst.Exp == 0 { - return dst.Int, nil +func (n *Numeric) toBigInt() (*big.Int, error) { + if n.Exp == 0 { + return n.Int, nil } num := &big.Int{} - num.Set(dst.Int) - if dst.Exp > 0 { + num.Set(n.Int) + if n.Exp > 0 { mul := &big.Int{} - mul.Exp(big10, big.NewInt(int64(dst.Exp)), nil) + mul.Exp(big10, big.NewInt(int64(n.Exp)), nil) num.Mul(num, mul) return num, nil } div := &big.Int{} - div.Exp(big10, big.NewInt(int64(-dst.Exp)), nil) + div.Exp(big10, big.NewInt(int64(-n.Exp)), nil) remainder := &big.Int{} num.DivMod(num, div, remainder) if remainder.Cmp(big0) != 0 { - return nil, errors.Errorf("cannot convert %v to integer", dst) + return nil, fmt.Errorf("cannot convert %v to integer", n) } return num, nil } -func (src *Numeric) toFloat64() (float64, error) { - f, err := strconv.ParseFloat(src.Int.String(), 64) - if err != nil { - return 0, err - } - if src.Exp > 0 { - for i := 0; i < int(src.Exp); i++ { - f *= 10 +func parseNumericString(str string) (n *big.Int, exp int32, err error) { + idx := strings.IndexByte(str, '.') + + if idx == -1 { + for len(str) > 1 && str[len(str)-1] == '0' && str[len(str)-2] != '-' { + str = str[:len(str)-1] + exp++ } - } else if src.Exp < 0 { - for i := 0; i > int(src.Exp); i-- { - f /= 10 + } else { + exp = int32(-(len(str) - idx - 1)) + str = str[:idx] + str[idx+1:] + } + + accum := &big.Int{} + if _, ok := accum.SetString(str, 10); !ok { + return nil, 0, fmt.Errorf("%s is not a number", str) + } + + return accum, exp, nil +} + +func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { + digits := len(src) / 2 + if digits > 4 { + digits = 4 + } + + rp := 0 + + for i := 0; i < digits; i++ { + if i > 0 { + accum *= nbase } + accum += int64(binary.BigEndian.Uint16(src[rp:])) + rp += 2 } - return f, nil + + return accum, rp, digits } -func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { +// Scan implements the [database/sql.Scanner] interface. +func (n *Numeric) Scan(src any) error { if src == nil { - *dst = Numeric{Status: Null} + *n = Numeric{} return nil } - num, exp, err := parseNumericString(string(src)) + switch src := src.(type) { + case string: + return scanPlanTextAnyToNumericScanner{}.Scan([]byte(src), n) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (n Numeric) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + + buf, err := NumericCodec{}.PlanEncode(nil, 0, TextFormatCode, n).Encode(n, nil) if err != nil { - return err + return nil, err + } + return string(buf), err +} + +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (n Numeric) MarshalJSON() ([]byte, error) { + if !n.Valid { + return []byte("null"), nil + } + + if n.NaN { + return []byte(`"NaN"`), nil + } + + return n.numberTextBytes(), nil +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (n *Numeric) UnmarshalJSON(src []byte) error { + if bytes.Equal(src, []byte(`null`)) { + *n = Numeric{} + return nil + } + if bytes.Equal(src, []byte(`"NaN"`)) { + *n = Numeric{NaN: true, Valid: true} + return nil + } + return scanPlanTextAnyToNumericScanner{}.Scan(src, n) +} + +// numberString returns a string of the number. undefined if NaN, infinite, or NULL +func (n Numeric) numberTextBytes() []byte { + intStr := n.Int.String() + + buf := &bytes.Buffer{} + + if len(intStr) > 0 && intStr[:1] == "-" { + intStr = intStr[1:] + buf.WriteByte('-') + } + + exp := int(n.Exp) + if exp > 0 { + buf.WriteString(intStr) + for i := 0; i < exp; i++ { + buf.WriteByte('0') + } + } else if exp < 0 { + if len(intStr) <= -exp { + buf.WriteString("0.") + leadingZeros := -exp - len(intStr) + for i := 0; i < leadingZeros; i++ { + buf.WriteByte('0') + } + buf.WriteString(intStr) + } else if len(intStr) > -exp { + dpPos := len(intStr) + exp + buf.WriteString(intStr[:dpPos]) + buf.WriteByte('.') + buf.WriteString(intStr[dpPos:]) + } + } else { + buf.WriteString(intStr) + } + + return buf.Bytes() +} + +type NumericCodec struct{} + +func (NumericCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (NumericCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (NumericCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case NumericValuer: + return encodePlanNumericCodecBinaryNumericValuer{} + case Float64Valuer: + return encodePlanNumericCodecBinaryFloat64Valuer{} + case Int64Valuer: + return encodePlanNumericCodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case NumericValuer: + return encodePlanNumericCodecTextNumericValuer{} + case Float64Valuer: + return encodePlanNumericCodecTextFloat64Valuer{} + case Int64Valuer: + return encodePlanNumericCodecTextInt64Valuer{} + } } - *dst = Numeric{Int: num, Exp: exp, Status: Present} return nil } -func parseNumericString(str string) (n *big.Int, exp int32, err error) { - parts := strings.SplitN(str, ".", 2) - digits := strings.Join(parts, "") +type encodePlanNumericCodecBinaryNumericValuer struct{} - if len(parts) > 1 { - exp = int32(-len(parts[1])) +func (encodePlanNumericCodecBinaryNumericValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(NumericValuer).NumericValue() + if err != nil { + return nil, err + } + + return encodeNumericBinary(n, buf) +} + +type encodePlanNumericCodecBinaryFloat64Valuer struct{} + +func (encodePlanNumericCodecBinaryFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Float64Valuer).Float64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if math.IsNaN(n.Float64) { + return encodeNumericBinary(Numeric{NaN: true, Valid: true}, buf) + } else if math.IsInf(n.Float64, 1) { + return encodeNumericBinary(Numeric{InfinityModifier: Infinity, Valid: true}, buf) + } else if math.IsInf(n.Float64, -1) { + return encodeNumericBinary(Numeric{InfinityModifier: NegativeInfinity, Valid: true}, buf) + } + num, exp, err := parseNumericString(strconv.FormatFloat(n.Float64, 'f', -1, 64)) + if err != nil { + return nil, err + } + + return encodeNumericBinary(Numeric{Int: num, Exp: exp, Valid: true}, buf) +} + +type encodePlanNumericCodecBinaryInt64Valuer struct{} + +func (encodePlanNumericCodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + return encodeNumericBinary(Numeric{Int: big.NewInt(n.Int64), Valid: true}, buf) +} + +func encodeNumericBinary(n Numeric, buf []byte) (newBuf []byte, err error) { + if !n.Valid { + return nil, nil + } + + if n.NaN { + buf = pgio.AppendUint64(buf, pgNumericNaN) + return buf, nil + } else if n.InfinityModifier == Infinity { + buf = pgio.AppendUint64(buf, pgNumericPosInf) + return buf, nil + } else if n.InfinityModifier == NegativeInfinity { + buf = pgio.AppendUint64(buf, pgNumericNegInf) + return buf, nil + } + + var sign int16 + if n.Int.Cmp(big0) < 0 { + sign = 16384 + } + + absInt := &big.Int{} + wholePart := &big.Int{} + fracPart := &big.Int{} + remainder := &big.Int{} + absInt.Abs(n.Int) + + // Normalize absInt and exp to where exp is always a multiple of 4. This makes + // converting to 16-bit base 10,000 digits easier. + var exp int32 + switch n.Exp % 4 { + case 1, -3: + exp = n.Exp - 1 + absInt.Mul(absInt, big10) + case 2, -2: + exp = n.Exp - 2 + absInt.Mul(absInt, big100) + case 3, -1: + exp = n.Exp - 3 + absInt.Mul(absInt, big1000) + default: + exp = n.Exp + } + + if exp < 0 { + divisor := &big.Int{} + divisor.Exp(big10, big.NewInt(int64(-exp)), nil) + wholePart.DivMod(absInt, divisor, fracPart) + fracPart.Add(fracPart, divisor) } else { - for len(digits) > 1 && digits[len(digits)-1] == '0' { - digits = digits[:len(digits)-1] - exp++ + wholePart = absInt + } + + var wholeDigits, fracDigits []int16 + + for wholePart.Cmp(big0) != 0 { + wholePart.DivMod(wholePart, bigNBase, remainder) + wholeDigits = append(wholeDigits, int16(remainder.Int64())) + } + + if fracPart.Cmp(big0) != 0 { + for fracPart.Cmp(big1) != 0 { + fracPart.DivMod(fracPart, bigNBase, remainder) + fracDigits = append(fracDigits, int16(remainder.Int64())) } } - accum := &big.Int{} - if _, ok := accum.SetString(digits, 10); !ok { - return nil, 0, errors.Errorf("%s is not a number", str) + buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits))) + + var weight int16 + if len(wholeDigits) > 0 { + weight = int16(len(wholeDigits) - 1) + if exp > 0 { + weight += int16(exp / 4) + } + } else { + weight = int16(exp/4) - 1 + int16(len(fracDigits)) } + buf = pgio.AppendInt16(buf, weight) - return accum, exp, nil + buf = pgio.AppendInt16(buf, sign) + + var dscale int16 + if n.Exp < 0 { + dscale = int16(-n.Exp) + } + buf = pgio.AppendInt16(buf, dscale) + + for i := len(wholeDigits) - 1; i >= 0; i-- { + buf = pgio.AppendInt16(buf, wholeDigits[i]) + } + + for i := len(fracDigits) - 1; i >= 0; i-- { + buf = pgio.AppendInt16(buf, fracDigits[i]) + } + + return buf, nil } -func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { +type encodePlanNumericCodecTextNumericValuer struct{} + +func (encodePlanNumericCodecTextNumericValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(NumericValuer).NumericValue() + if err != nil { + return nil, err + } + + return encodeNumericText(n, buf) +} + +type encodePlanNumericCodecTextFloat64Valuer struct{} + +func (encodePlanNumericCodecTextFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Float64Valuer).Float64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if math.IsNaN(n.Float64) { + buf = append(buf, "NaN"...) + } else if math.IsInf(n.Float64, 1) { + buf = append(buf, "Infinity"...) + } else if math.IsInf(n.Float64, -1) { + buf = append(buf, "-Infinity"...) + } else { + buf = append(buf, strconv.FormatFloat(n.Float64, 'f', -1, 64)...) + } + return buf, nil +} + +type encodePlanNumericCodecTextInt64Valuer struct{} + +func (encodePlanNumericCodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + buf = append(buf, strconv.FormatInt(n.Int64, 10)...) + return buf, nil +} + +func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { + if !n.Valid { + return nil, nil + } + + if n.NaN { + buf = append(buf, "NaN"...) + return buf, nil + } else if n.InfinityModifier == Infinity { + buf = append(buf, "Infinity"...) + return buf, nil + } else if n.InfinityModifier == NegativeInfinity { + buf = append(buf, "-Infinity"...) + return buf, nil + } + + buf = append(buf, n.numberTextBytes()...) + + return buf, nil +} + +func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case NumericScanner: + return scanPlanBinaryNumericToNumericScanner{} + case Float64Scanner: + return scanPlanBinaryNumericToFloat64Scanner{} + case Int64Scanner: + return scanPlanBinaryNumericToInt64Scanner{} + case TextScanner: + return scanPlanBinaryNumericToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case NumericScanner: + return scanPlanTextAnyToNumericScanner{} + case Float64Scanner: + return scanPlanTextAnyToFloat64Scanner{} + case Int64Scanner: + return scanPlanTextAnyToInt64Scanner{} + } + } + + return nil +} + +type scanPlanBinaryNumericToNumericScanner struct{} + +func (scanPlanBinaryNumericToNumericScanner) Scan(src []byte, dst any) error { + scanner := (dst).(NumericScanner) + if src == nil { - *dst = Numeric{Status: Null} - return nil + return scanner.ScanNumeric(Numeric{}) } if len(src) < 8 { - return errors.Errorf("numeric incomplete %v", src) + return fmt.Errorf("numeric incomplete %v", src) } rp := 0 - ndigits := int16(binary.BigEndian.Uint16(src[rp:])) + ndigits := binary.BigEndian.Uint16(src[rp:]) rp += 2 - - if ndigits == 0 { - *dst = Numeric{Int: big.NewInt(0), Status: Present} - return nil - } - weight := int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 - sign := int16(binary.BigEndian.Uint16(src[rp:])) + sign := binary.BigEndian.Uint16(src[rp:]) rp += 2 dscale := int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 + if sign == pgNumericNaNSign { + return scanner.ScanNumeric(Numeric{NaN: true, Valid: true}) + } else if sign == pgNumericPosInfSign { + return scanner.ScanNumeric(Numeric{InfinityModifier: Infinity, Valid: true}) + } else if sign == pgNumericNegInfSign { + return scanner.ScanNumeric(Numeric{InfinityModifier: NegativeInfinity, Valid: true}) + } + + if ndigits == 0 { + return scanner.ScanNumeric(Numeric{Int: big.NewInt(0), Valid: true}) + } + if len(src[rp:]) < int(ndigits)*2 { - return errors.Errorf("numeric incomplete %v", src) + return fmt.Errorf("numeric incomplete %v", src) } accum := &big.Int{} @@ -383,7 +647,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { case 4: mul = bigNBaseX4 default: - return errors.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) + return fmt.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) } accum.Mul(accum, mul) } @@ -394,7 +658,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { exp := (int32(weight) - int32(ndigits) + 1) * 4 if dscale > 0 { - fracNBaseDigits := ndigits - weight - 1 + fracNBaseDigits := int16(int32(ndigits) - int32(weight) - 1) fracDecimalDigits := fracNBaseDigits * 4 if dscale > fracDecimalDigits { @@ -429,172 +693,141 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { accum.Neg(accum) } - *dst = Numeric{Int: accum, Exp: exp, Status: Present} + return scanner.ScanNumeric(Numeric{Int: accum, Exp: exp, Valid: true}) +} - return nil +type scanPlanBinaryNumericToFloat64Scanner struct{} -} +func (scanPlanBinaryNumericToFloat64Scanner) Scan(src []byte, dst any) error { + scanner := (dst).(Float64Scanner) -func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { - digits := len(src) / 2 - if digits > 4 { - digits = 4 + if src == nil { + return scanner.ScanFloat64(Float8{}) } - rp := 0 + var n Numeric - for i := 0; i < digits; i++ { - if i > 0 { - accum *= nbase - } - accum += int64(binary.BigEndian.Uint16(src[rp:])) - rp += 2 + err := scanPlanBinaryNumericToNumericScanner{}.Scan(src, &n) + if err != nil { + return err } - return accum, rp, digits -} - -func (src *Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined + f8, err := n.Float64Value() + if err != nil { + return err } - buf = append(buf, src.Int.String()...) - buf = append(buf, 'e') - buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...) - return buf, nil + return scanner.ScanFloat64(f8) } -func (src *Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +type scanPlanBinaryNumericToInt64Scanner struct{} + +func (scanPlanBinaryNumericToInt64Scanner) Scan(src []byte, dst any) error { + scanner := (dst).(Int64Scanner) + + if src == nil { + return scanner.ScanInt64(Int8{}) } - var sign int16 - if src.Int.Cmp(big0) < 0 { - sign = 16384 + var n Numeric + + err := scanPlanBinaryNumericToNumericScanner{}.Scan(src, &n) + if err != nil { + return err } - absInt := &big.Int{} - wholePart := &big.Int{} - fracPart := &big.Int{} - remainder := &big.Int{} - absInt.Abs(src.Int) + bigInt, err := n.toBigInt() + if err != nil { + return err + } - // Normalize absInt and exp to where exp is always a multiple of 4. This makes - // converting to 16-bit base 10,000 digits easier. - var exp int32 - switch src.Exp % 4 { - case 1, -3: - exp = src.Exp - 1 - absInt.Mul(absInt, big10) - case 2, -2: - exp = src.Exp - 2 - absInt.Mul(absInt, big100) - case 3, -1: - exp = src.Exp - 3 - absInt.Mul(absInt, big1000) - default: - exp = src.Exp + if !bigInt.IsInt64() { + return fmt.Errorf("%v is out of range for int64", bigInt) } - if exp < 0 { - divisor := &big.Int{} - divisor.Exp(big10, big.NewInt(int64(-exp)), nil) - wholePart.DivMod(absInt, divisor, fracPart) - fracPart.Add(fracPart, divisor) - } else { - wholePart = absInt + return scanner.ScanInt64(Int8{Int64: bigInt.Int64(), Valid: true}) +} + +type scanPlanBinaryNumericToTextScanner struct{} + +func (scanPlanBinaryNumericToTextScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TextScanner) + + if src == nil { + return scanner.ScanText(Text{}) } - var wholeDigits, fracDigits []int16 + var n Numeric - for wholePart.Cmp(big0) != 0 { - wholePart.DivMod(wholePart, bigNBase, remainder) - wholeDigits = append(wholeDigits, int16(remainder.Int64())) + err := scanPlanBinaryNumericToNumericScanner{}.Scan(src, &n) + if err != nil { + return err } - if fracPart.Cmp(big0) != 0 { - for fracPart.Cmp(big1) != 0 { - fracPart.DivMod(fracPart, bigNBase, remainder) - fracDigits = append(fracDigits, int16(remainder.Int64())) - } + sbuf, err := encodeNumericText(n, nil) + if err != nil { + return err } - buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits))) + return scanner.ScanText(Text{String: string(sbuf), Valid: true}) +} - var weight int16 - if len(wholeDigits) > 0 { - weight = int16(len(wholeDigits) - 1) - if exp > 0 { - weight += int16(exp / 4) - } - } else { - weight = int16(exp/4) - 1 + int16(len(fracDigits)) - } - buf = pgio.AppendInt16(buf, weight) +type scanPlanTextAnyToNumericScanner struct{} - buf = pgio.AppendInt16(buf, sign) +func (scanPlanTextAnyToNumericScanner) Scan(src []byte, dst any) error { + scanner := (dst).(NumericScanner) - var dscale int16 - if src.Exp < 0 { - dscale = int16(-src.Exp) + if src == nil { + return scanner.ScanNumeric(Numeric{}) } - buf = pgio.AppendInt16(buf, dscale) - for i := len(wholeDigits) - 1; i >= 0; i-- { - buf = pgio.AppendInt16(buf, wholeDigits[i]) + if string(src) == "NaN" { + return scanner.ScanNumeric(Numeric{NaN: true, Valid: true}) + } else if string(src) == "Infinity" { + return scanner.ScanNumeric(Numeric{InfinityModifier: Infinity, Valid: true}) + } else if string(src) == "-Infinity" { + return scanner.ScanNumeric(Numeric{InfinityModifier: NegativeInfinity, Valid: true}) } - for i := len(fracDigits) - 1; i >= 0; i-- { - buf = pgio.AppendInt16(buf, fracDigits[i]) + num, exp, err := parseNumericString(string(src)) + if err != nil { + return err } - return buf, nil + return scanner.ScanNumeric(Numeric{Int: num, Exp: exp, Valid: true}) } -// Scan implements the database/sql Scanner interface. -func (dst *Numeric) Scan(src interface{}) error { +func (c NumericCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { - *dst = Numeric{Status: Null} - return nil + return nil, nil } - switch src := src.(type) { - case float64: - // TODO - // *dst = Numeric{Float: src, Status: Present} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + if format == TextFormatCode { + return string(src), nil } - return errors.Errorf("cannot scan %T", src) -} + var n Numeric + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } -// Value implements the database/sql/driver Valuer interface. -func (src *Numeric) Value() (driver.Value, error) { - switch src.Status { - case Present: - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } + buf, err := m.Encode(oid, TextFormatCode, n, nil) + if err != nil { + return nil, err + } + return string(buf), nil +} - return string(buf), nil - case Null: +func (c NumericCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { return nil, nil - default: - return nil, errUndefined } + + var n Numeric + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil } diff --git a/pgtype/numeric_array.go b/pgtype/numeric_array.go deleted file mode 100644 index d991234ac..000000000 --- a/pgtype/numeric_array.go +++ /dev/null @@ -1,328 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type NumericArray struct { - Elements []Numeric - Dimensions []ArrayDimension - Status Status -} - -func (dst *NumericArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = NumericArray{Status: Null} - return nil - } - - switch value := src.(type) { - - case []float32: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []float64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to NumericArray", value) - } - - return nil -} - -func (dst *NumericArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *NumericArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]float32: - *v = make([]float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]float64: - *v = make([]float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = NumericArray{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Numeric - - if len(uta.Elements) > 0 { - elements = make([]Numeric, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Numeric - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = NumericArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *NumericArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = NumericArray{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = NumericArray{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Numeric, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = NumericArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("numeric"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "numeric") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *NumericArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *NumericArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/numeric_array_test.go b/pgtype/numeric_array_test.go deleted file mode 100644 index 22ee1bc4d..000000000 --- a/pgtype/numeric_array_test.go +++ /dev/null @@ -1,160 +0,0 @@ -package pgtype_test - -import ( - "math/big" - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestNumericArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "numeric[]", []interface{}{ - &pgtype.NumericArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.NumericArray{Status: pgtype.Null}, - &pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Int: big.NewInt(2), Status: pgtype.Present}, - {Int: big.NewInt(3), Status: pgtype.Present}, - {Int: big.NewInt(4), Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: big.NewInt(6), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.NumericArray{ - Elements: []pgtype.Numeric{ - {Int: big.NewInt(1), Status: pgtype.Present}, - {Int: big.NewInt(2), Status: pgtype.Present}, - {Int: big.NewInt(3), Status: pgtype.Present}, - {Int: big.NewInt(4), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestNumericArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.NumericArray - }{ - { - source: []float32{1}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []float64{1}, - result: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]float32)(nil)), - result: pgtype.NumericArray{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.NumericArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestNumericArrayAssignTo(t *testing.T) { - var float32Slice []float32 - var float64Slice []float64 - - simpleTests := []struct { - src pgtype.NumericArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &float32Slice, - expected: []float32{1}, - }, - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &float64Slice, - expected: []float64{1}, - }, - { - src: pgtype.NumericArray{Status: pgtype.Null}, - dst: &float32Slice, - expected: (([]float32)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.NumericArray - dst interface{} - }{ - { - src: pgtype.NumericArray{ - Elements: []pgtype.Numeric{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &float32Slice, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 9d7d83d67..6cf951f4b 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -1,42 +1,22 @@ package pgtype_test import ( + "context" + "encoding/json" + "math" "math/big" "math/rand" "reflect" + "strconv" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -// For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0) -func numericEqual(left, right *pgtype.Numeric) bool { - return left.Status == right.Status && - left.Exp == right.Exp && - ((left.Int == nil && right.Int == nil) || (left.Int != nil && right.Int != nil && left.Int.Cmp(right.Int) == 0)) -} - -// For test purposes only. -func numericNormalizedEqual(left, right *pgtype.Numeric) bool { - if left.Status != right.Status { - return false - } - - normLeft := &pgtype.Numeric{Int: (&big.Int{}).Set(left.Int), Status: left.Status} - normRight := &pgtype.Numeric{Int: (&big.Int{}).Set(right.Int), Status: right.Status} - - if left.Exp < right.Exp { - mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(right.Exp-left.Exp)), nil) - normRight.Int.Mul(normRight.Int, mul) - } else if left.Exp > right.Exp { - mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(left.Exp-right.Exp)), nil) - normLeft.Int.Mul(normLeft.Int, mul) - } - - return normLeft.Int.Cmp(normRight.Int) == 0 -} - func mustParseBigInt(t *testing.T, src string) *big.Int { i := &big.Int{} if _, ok := i.SetString(src, 10); !ok { @@ -45,311 +25,279 @@ func mustParseBigInt(t *testing.T, src string) *big.Int { return i } -func TestNumericNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select '0'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, - }, - { - SQL: "select '1'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, - }, - { - SQL: "select '10.00'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present}, - }, - { - SQL: "select '1e-3'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present}, - }, - { - SQL: "select '-1'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, - }, - { - SQL: "select '10000'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present}, - }, - { - SQL: "select '3.14'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, - }, - { - SQL: "select '1.1'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present}, - }, - { - SQL: "select '100010001'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Status: pgtype.Present}, - }, - { - SQL: "select '100010001.0001'::numeric", - Value: &pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present}, - }, - { - SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", - Value: &pgtype.Numeric{ - Int: mustParseBigInt(t, "423723478923478928934789237432487213832189417894318904389012483210893443219085471578891547854892438945012347981"), - Exp: -41, - Status: pgtype.Present, - }, - }, - { - SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", - Value: &pgtype.Numeric{ - Int: mustParseBigInt(t, "8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), - Exp: -196, - Status: pgtype.Present, - }, - }, - { - SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", - Value: &pgtype.Numeric{ - Int: mustParseBigInt(t, "123"), - Exp: -186, - Status: pgtype.Present, - }, - }, - }) -} - -func TestNumericTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ - &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(1), Exp: 6, Status: pgtype.Present}, +func isExpectedEqNumeric(a any) func(any) bool { + return func(v any) bool { + aa := a.(pgtype.Numeric) + vv := v.(pgtype.Numeric) - // preserves significant zeroes - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -1, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -2, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -3, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -4, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -5, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -6, Status: pgtype.Present}, + if aa.Valid != vv.Valid { + return false + } - &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(123), Exp: -7, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(123), Exp: -8, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(123), Exp: -9, Status: pgtype.Present}, - &pgtype.Numeric{Int: big.NewInt(123), Exp: -1500, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "3723409723490243842378942378901237502734019231380123"), Exp: 81, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "723409723490243842378942378901237502734019231380123"), Exp: 82, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "23409723490243842378942378901237502734019231380123"), Exp: 83, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "3409723490243842378942378901237502734019231380123"), Exp: 84, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "3423409823409243892349028349023482934092340892390101"), Exp: -91, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "423409823409243892349028349023482934092340892390101"), Exp: -92, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "23409823409243892349028349023482934092340892390101"), Exp: -93, Status: pgtype.Present}, - &pgtype.Numeric{Int: mustParseBigInt(t, "3409823409243892349028349023482934092340892390101"), Exp: -94, Status: pgtype.Present}, - &pgtype.Numeric{Status: pgtype.Null}, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.Numeric) - b := bb.(pgtype.Numeric) + // If NULL doesn't matter what the rest of the values are. + if !aa.Valid { + return true + } - return numericEqual(&a, &b) - }) + if !(aa.NaN == vv.NaN && aa.InfinityModifier == vv.InfinityModifier) { + return false + } -} + // If NaN or InfinityModifier are set then Int and Exp don't matter. + if aa.NaN || aa.InfinityModifier != pgtype.Finite { + return true + } -func TestNumericTranscodeFuzz(t *testing.T) { - r := rand.New(rand.NewSource(0)) - max := &big.Int{} - max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) + aaInt := (&big.Int{}).Set(aa.Int) + vvInt := (&big.Int{}).Set(vv.Int) - values := make([]interface{}, 0, 2000) - for i := 0; i < 10; i++ { - for j := -50; j < 50; j++ { - num := (&big.Int{}).Rand(r, max) - negNum := &big.Int{} - negNum.Neg(num) - values = append(values, &pgtype.Numeric{Int: num, Exp: int32(j), Status: pgtype.Present}) - values = append(values, &pgtype.Numeric{Int: negNum, Exp: int32(j), Status: pgtype.Present}) + if aa.Exp < vv.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(vv.Exp-aa.Exp)), nil) + vvInt.Mul(vvInt, mul) + } else if aa.Exp > vv.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(aa.Exp-vv.Exp)), nil) + aaInt.Mul(aaInt, mul) } - } - - testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, - func(aa, bb interface{}) bool { - a := aa.(pgtype.Numeric) - b := bb.(pgtype.Numeric) - return numericNormalizedEqual(&a, &b) - }) + return aaInt.Cmp(vvInt) == 0 + } } -func TestNumericSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result *pgtype.Numeric - }{ - {source: float32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: float64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: int16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: int32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: int64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: int8(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, - {source: int16(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, - {source: int32(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, - {source: int64(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, - {source: uint8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: uint16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: uint32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: uint64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: "1", result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: _int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, - {source: float64(1000), result: &pgtype.Numeric{Int: big.NewInt(1), Exp: 3, Status: pgtype.Present}}, - {source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Status: pgtype.Present}}, - {source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Status: pgtype.Present}}, - {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}}, - } +func mustParseNumeric(t *testing.T, src string) pgtype.Numeric { + var n pgtype.Numeric + plan := pgtype.NumericCodec{}.PlanScan(nil, pgtype.NumericOID, pgtype.TextFormatCode, &n) + require.NotNil(t, plan) + err := plan.Scan([]byte(src), &n) + require.NoError(t, err) + return n +} - for i, tt := range successfulTests { - r := &pgtype.Numeric{} - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } +func TestNumericCodec(t *testing.T) { + skipCockroachDB(t, "server formats numeric text format differently") + + max := new(big.Int).Exp(big.NewInt(10), big.NewInt(147454), nil) + max.Add(max, big.NewInt(1)) + longestNumeric := pgtype.Numeric{Int: max, Exp: -16383, Valid: true} + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "numeric", []pgxtest.ValueRoundTripTest{ + {mustParseNumeric(t, "1"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))}, + {mustParseNumeric(t, "3.14159"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "3.14159"))}, + {mustParseNumeric(t, "100010001"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "100010001"))}, + {mustParseNumeric(t, "100010001.0001"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "100010001.0001"))}, + {mustParseNumeric(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"))}, + {mustParseNumeric(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"))}, + {mustParseNumeric(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"))}, + {pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 81, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 81, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 82, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 82, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 83, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 83, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 84, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 84, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -91, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -91, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -92, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -92, Valid: true})}, + {pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -93, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -93, Valid: true})}, + {pgtype.Numeric{NaN: true, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{NaN: true, Valid: true})}, + {longestNumeric, new(pgtype.Numeric), isExpectedEqNumeric(longestNumeric)}, + {mustParseNumeric(t, "1"), new(int64), isExpectedEq(int64(1))}, + {math.NaN(), new(float64), func(a any) bool { return math.IsNaN(a.(float64)) }}, + {float32(math.NaN()), new(float32), func(a any) bool { return math.IsNaN(float64(a.(float32))) }}, + {int64(-1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "-1"))}, + {int64(0), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0"))}, + {int64(1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))}, + {int64(math.MinInt64), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MinInt64, 10)))}, + {int64(math.MinInt64 + 1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MinInt64+1, 10)))}, + {int64(math.MaxInt64), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MaxInt64, 10)))}, + {int64(math.MaxInt64 - 1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MaxInt64-1, 10)))}, + {uint64(100), new(uint64), isExpectedEq(uint64(100))}, + {uint64(math.MaxUint64), new(uint64), isExpectedEq(uint64(math.MaxUint64))}, + {uint(math.MaxUint), new(uint), isExpectedEq(uint(math.MaxUint))}, + {uint(100), new(uint), isExpectedEq(uint(100))}, + {"1.23", new(string), isExpectedEq("1.23")}, + {pgtype.Numeric{}, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, + {nil, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, + {mustParseNumeric(t, "1"), new(string), isExpectedEq("1")}, + {pgtype.Numeric{NaN: true, Valid: true}, new(string), isExpectedEq("NaN")}, + }) - if !numericEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int8", []pgxtest.ValueRoundTripTest{ + {mustParseNumeric(t, "-1"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "-1"))}, + {mustParseNumeric(t, "0"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0"))}, + {mustParseNumeric(t, "1"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))}, + }) } -func TestNumericAssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - var f32 float32 - var f64 float64 - var pf32 *float32 - var pf64 *float64 +func TestNumericCodecInfinity(t *testing.T) { + skipCockroachDB(t, "server formats numeric text format differently") + skipPostgreSQLVersionLessThan(t, 14) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "numeric", []pgxtest.ValueRoundTripTest{ + {math.Inf(1), new(float64), isExpectedEq(math.Inf(1))}, + {float32(math.Inf(1)), new(float32), isExpectedEq(float32(math.Inf(1)))}, + {math.Inf(-1), new(float64), isExpectedEq(math.Inf(-1))}, + {float32(math.Inf(-1)), new(float32), isExpectedEq(float32(math.Inf(-1)))}, + {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true})}, + {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true})}, + {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, new(string), isExpectedEq("Infinity")}, + {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(string), isExpectedEq("-Infinity")}, + }) +} - simpleTests := []struct { - src *pgtype.Numeric - dst interface{} - expected interface{} +func TestNumericFloat64Valuer(t *testing.T) { + for i, tt := range []struct { + n pgtype.Numeric + f pgtype.Float8 }{ - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f32, expected: float32(4.2)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f64, expected: float64(4.2)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: 3, Status: pgtype.Present}, dst: &i64, expected: int64(42000)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + {mustParseNumeric(t, "1"), pgtype.Float8{Float64: 1, Valid: true}}, + {mustParseNumeric(t, "0.0000000000000000001"), pgtype.Float8{Float64: 0.0000000000000000001, Valid: true}}, + {mustParseNumeric(t, "-99999999999"), pgtype.Float8{Float64: -99999999999, Valid: true}}, + {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, pgtype.Float8{Float64: math.Inf(1), Valid: true}}, + {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, pgtype.Float8{Float64: math.Inf(-1), Valid: true}}, + {pgtype.Numeric{Valid: true}, pgtype.Float8{Valid: true}}, + {pgtype.Numeric{}, pgtype.Float8{}}, + } { + f, err := tt.n.Float64Value() + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.f, f, "%d", i) } - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } + f, err := pgtype.Numeric{NaN: true, Valid: true}.Float64Value() + assert.NoError(t, err) + assert.True(t, math.IsNaN(f.Float64)) + assert.True(t, f.Valid) +} - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } +func TestNumericCodecFuzz(t *testing.T) { + skipCockroachDB(t, "server formats numeric text format differently") - pointerAllocTests := []struct { - src *pgtype.Numeric - dst interface{} - expected interface{} - }{ - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, - {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, - } + r := rand.New(rand.NewSource(0)) + max := &big.Int{} + max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } + tests := make([]pgxtest.ValueRoundTripTest, 0, 2000) + for i := 0; i < 10; i++ { + for j := -50; j < 50; j++ { + num := (&big.Int{}).Rand(r, max) + + n := pgtype.Numeric{Int: num, Exp: int32(j), Valid: true} + tests = append(tests, pgxtest.ValueRoundTripTest{n, new(pgtype.Numeric), isExpectedEqNumeric(n)}) - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + negNum := &big.Int{} + negNum.Neg(num) + n = pgtype.Numeric{Int: negNum, Exp: int32(j), Valid: true} + tests = append(tests, pgxtest.ValueRoundTripTest{n, new(pgtype.Numeric), isExpectedEqNumeric(n)}) } } - errorTests := []struct { - src *pgtype.Numeric - dst interface{} - }{ - {src: &pgtype.Numeric{Int: big.NewInt(150), Status: pgtype.Present}, dst: &i8}, - {src: &pgtype.Numeric{Int: big.NewInt(40000), Status: pgtype.Present}, dst: &i16}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui8}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui16}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui32}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui64}, - {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui}, - {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &i32}, - } + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "numeric", tests) +} - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) +func TestNumericMarshalJSON(t *testing.T) { + skipCockroachDB(t, "server formats numeric text format differently") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + for i, tt := range []struct { + decString string + }{ + {"NaN"}, + {"0"}, + {"1"}, + {"-1"}, + {"1000000000000000000"}, + {"1234.56789"}, + {"1.56789"}, + {"0.00000000000056789"}, + {"0.00123000"}, + {"123e-3"}, + {"243723409723490243842378942378901237502734019231380123e23790"}, + {"3409823409243892349028349023482934092340892390101e-14021"}, + {"-1.1"}, + {"-1.0231"}, + {"-10.0231"}, + {"-0.1"}, // failed with "invalid character '.' in numeric literal" + {"-0.01"}, // failed with "invalid character '-' after decimal point in numeric literal" + {"-0.001"}, // failed with "invalid character '-' after top-level value" + } { + var num pgtype.Numeric + var pgJSON string + err := conn.QueryRow(ctx, `select $1::numeric, to_json($1::numeric)`, tt.decString).Scan(&num, &pgJSON) + require.NoErrorf(t, err, "%d", i) + + goJSON, err := json.Marshal(num) + require.NoErrorf(t, err, "%d", i) + + require.Equal(t, pgJSON, string(goJSON)) } - } + }) } -func TestNumericEncodeDecodeBinary(t *testing.T) { - ci := pgtype.NewConnInfo() - tests := []interface{}{ - 123, - 0.000012345, - 1.00002345, +func TestNumericUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + want *pgtype.Numeric + src []byte + wantErr bool + }{ + { + name: "null", + want: &pgtype.Numeric{}, + src: []byte(`null`), + wantErr: false, + }, + { + name: "NaN", + want: &pgtype.Numeric{Valid: true, NaN: true}, + src: []byte(`"NaN"`), + wantErr: false, + }, + { + name: "0", + want: &pgtype.Numeric{Valid: true, Int: big.NewInt(0)}, + src: []byte("0"), + wantErr: false, + }, + { + name: "1", + want: &pgtype.Numeric{Valid: true, Int: big.NewInt(1)}, + src: []byte("1"), + wantErr: false, + }, + { + name: "-1", + want: &pgtype.Numeric{Valid: true, Int: big.NewInt(-1)}, + src: []byte("-1"), + wantErr: false, + }, + { + name: "bigInt", + want: &pgtype.Numeric{Valid: true, Int: big.NewInt(1), Exp: 30}, + src: []byte("1000000000000000000000000000000"), + wantErr: false, + }, + { + name: "float: 1234.56789", + want: &pgtype.Numeric{Valid: true, Int: big.NewInt(123456789), Exp: -5}, + src: []byte("1234.56789"), + wantErr: false, + }, + { + name: "invalid value", + want: &pgtype.Numeric{}, + src: []byte("0xffff"), + wantErr: true, + }, } - - for i, tt := range tests { - toString := func(n *pgtype.Numeric) string { - ci := pgtype.NewConnInfo() - text, err := n.EncodeText(ci, nil) - if err != nil { - t.Errorf("%d: %v", i, err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &pgtype.Numeric{} + if err := got.UnmarshalJSON(tt.src); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) } - return string(text) - } - numeric := &pgtype.Numeric{} - numeric.Set(tt) - - encoded, err := numeric.EncodeBinary(ci, nil) - if err != nil { - t.Errorf("%d: %v", i, err) - } - decoded := &pgtype.Numeric{} - decoded.DecodeBinary(ci, encoded) - - text0 := toString(numeric) - text1 := toString(decoded) - - if text0 != text1 { - t.Errorf("%d: expected %v to equal to %v, but doesn't", i, text0, text1) - } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want) + } + }) } } diff --git a/pgtype/numrange.go b/pgtype/numrange.go deleted file mode 100644 index aaed62cee..000000000 --- a/pgtype/numrange.go +++ /dev/null @@ -1,250 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type Numrange struct { - Lower Numeric - Upper Numeric - LowerType BoundType - UpperType BoundType - Status Status -} - -func (dst *Numrange) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Numrange", src) -} - -func (dst *Numrange) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Numrange) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Numrange) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Numrange{Status: Null} - return nil - } - - utr, err := ParseUntypedTextRange(string(src)) - if err != nil { - return err - } - - *dst = Numrange{Status: Present} - - dst.LowerType = utr.LowerType - dst.UpperType = utr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { - return err - } - } - - return nil -} - -func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Numrange{Status: Null} - return nil - } - - ubr, err := ParseUntypedBinaryRange(src) - if err != nil { - return err - } - - *dst = Numrange{Status: Present} - - dst.LowerType = ubr.LowerType - dst.UpperType = ubr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { - return err - } - } - - return nil -} - -func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - switch src.LowerType { - case Exclusive, Unbounded: - buf = append(buf, '(') - case Inclusive: - buf = append(buf, '[') - case Empty: - return append(buf, "empty"...), nil - default: - return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) - } - - var err error - - if src.LowerType != Unbounded { - buf, err = src.Lower.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - } - - buf = append(buf, ',') - - if src.UpperType != Unbounded { - buf, err = src.Upper.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - } - - switch src.UpperType { - case Exclusive, Unbounded: - buf = append(buf, ')') - case Inclusive: - buf = append(buf, ']') - default: - return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) - } - - return buf, nil -} - -func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - var rangeType byte - switch src.LowerType { - case Inclusive: - rangeType |= lowerInclusiveMask - case Unbounded: - rangeType |= lowerUnboundedMask - case Exclusive: - case Empty: - return append(buf, emptyMask), nil - default: - return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) - } - - switch src.UpperType { - case Inclusive: - rangeType |= upperInclusiveMask - case Unbounded: - rangeType |= upperUnboundedMask - case Exclusive: - default: - return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) - } - - buf = append(buf, rangeType) - - var err error - - if src.LowerType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Lower.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - if src.UpperType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Upper.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Numrange) Scan(src interface{}) error { - if src == nil { - *dst = Numrange{Status: Null} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Numrange) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/numrange_test.go b/pgtype/numrange_test.go deleted file mode 100644 index ccc794d5c..000000000 --- a/pgtype/numrange_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package pgtype_test - -import ( - "math/big" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestNumrangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "numrange", []interface{}{ - &pgtype.Numrange{ - LowerType: pgtype.Empty, - UpperType: pgtype.Empty, - Status: pgtype.Present, - }, - &pgtype.Numrange{ - Lower: pgtype.Numeric{Int: big.NewInt(-543), Exp: 3, Status: pgtype.Present}, - Upper: pgtype.Numeric{Int: big.NewInt(342), Exp: 1, Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Numrange{ - Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, - Upper: pgtype.Numeric{Int: big.NewInt(-5), Exp: 0, Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Numrange{ - Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Unbounded, - Status: pgtype.Present, - }, - &pgtype.Numrange{ - Upper: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, - LowerType: pgtype.Unbounded, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Numrange{Status: pgtype.Null}, - }) -} diff --git a/pgtype/oid.go b/pgtype/oid.go deleted file mode 100644 index 59370d66c..000000000 --- a/pgtype/oid.go +++ /dev/null @@ -1,81 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "strconv" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -// OID (Object Identifier Type) is, according to -// https://www.postgresql.org/docs/current/static/datatype-oid.html, used -// internally by PostgreSQL as a primary key for various system tables. It is -// currently implemented as an unsigned four-byte integer. Its definition can be -// found in src/include/postgres_ext.h in the PostgreSQL sources. Because it is -// so frequently required to be in a NOT NULL condition OID cannot be NULL. To -// allow for NULL OIDs use OIDValue. -type OID uint32 - -func (dst *OID) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - return errors.Errorf("cannot decode nil into OID") - } - - n, err := strconv.ParseUint(string(src), 10, 32) - if err != nil { - return err - } - - *dst = OID(n) - return nil -} - -func (dst *OID) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - return errors.Errorf("cannot decode nil into OID") - } - - if len(src) != 4 { - return errors.Errorf("invalid length: %v", len(src)) - } - - n := binary.BigEndian.Uint32(src) - *dst = OID(n) - return nil -} - -func (src OID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return append(buf, strconv.FormatUint(uint64(src), 10)...), nil -} - -func (src OID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return pgio.AppendUint32(buf, uint32(src)), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *OID) Scan(src interface{}) error { - if src == nil { - return errors.Errorf("cannot scan NULL into %T", src) - } - - switch src := src.(type) { - case int64: - *dst = OID(src) - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src OID) Value() (driver.Value, error) { - return int64(src), nil -} diff --git a/pgtype/oid_value.go b/pgtype/oid_value.go deleted file mode 100644 index 7eae4bf10..000000000 --- a/pgtype/oid_value.go +++ /dev/null @@ -1,55 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -// OIDValue (Object Identifier Type) is, according to -// https://www.postgresql.org/docs/current/static/datatype-OIDValue.html, used -// internally by PostgreSQL as a primary key for various system tables. It is -// currently implemented as an unsigned four-byte integer. Its definition can be -// found in src/include/postgres_ext.h in the PostgreSQL sources. -type OIDValue pguint32 - -// Set converts from src to dst. Note that as OIDValue is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *OIDValue) Set(src interface{}) error { - return (*pguint32)(dst).Set(src) -} - -func (dst *OIDValue) Get() interface{} { - return (*pguint32)(dst).Get() -} - -// AssignTo assigns from src to dst. Note that as OIDValue is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *OIDValue) AssignTo(dst interface{}) error { - return (*pguint32)(src).AssignTo(dst) -} - -func (dst *OIDValue) DecodeText(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeText(ci, src) -} - -func (dst *OIDValue) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeBinary(ci, src) -} - -func (src *OIDValue) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*pguint32)(src).EncodeText(ci, buf) -} - -func (src *OIDValue) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*pguint32)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *OIDValue) Scan(src interface{}) error { - return (*pguint32)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *OIDValue) Value() (driver.Value, error) { - return (*pguint32)(src).Value() -} diff --git a/pgtype/oid_value_test.go b/pgtype/oid_value_test.go deleted file mode 100644 index f5ff16cf0..000000000 --- a/pgtype/oid_value_test.go +++ /dev/null @@ -1,95 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestOIDValueTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "oid", []interface{}{ - &pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, - &pgtype.OIDValue{Status: pgtype.Null}, - }) -} - -func TestOIDValueSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.OIDValue - }{ - {source: uint32(1), result: pgtype.OIDValue{Uint: 1, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.OIDValue - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestOIDValueAssignTo(t *testing.T) { - var ui32 uint32 - var pui32 *uint32 - - simpleTests := []struct { - src pgtype.OIDValue - dst interface{} - expected interface{} - }{ - {src: pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.OIDValue{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.OIDValue - dst interface{} - expected interface{} - }{ - {src: pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.OIDValue - dst interface{} - }{ - {src: pgtype.OIDValue{Status: pgtype.Null}, dst: &ui32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/path.go b/pgtype/path.go index aa0cee8e5..81dc1e5b5 100644 --- a/pgtype/path.go +++ b/pgtype/path.go @@ -8,86 +8,181 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) +type PathScanner interface { + ScanPath(v Path) error +} + +type PathValuer interface { + PathValue() (Path, error) +} + type Path struct { P []Vec2 Closed bool - Status Status + Valid bool +} + +// ScanPath implements the [PathScanner] interface. +func (path *Path) ScanPath(v Path) error { + *path = v + return nil } -func (dst *Path) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Path", src) +// PathValue implements the [PathValuer] interface. +func (path Path) PathValue() (Path, error) { + return path, nil } -func (dst *Path) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: +// Scan implements the [database/sql.Scanner] interface. +func (path *Path) Scan(src any) error { + if src == nil { + *path = Path{} return nil - default: - return dst.Status } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToPathScanner{}.Scan([]byte(src), path) + } + + return fmt.Errorf("cannot scan %T", src) } -func (src *Path) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) +// Value implements the [database/sql/driver.Valuer] interface. +func (path Path) Value() (driver.Value, error) { + if !path.Valid { + return nil, nil + } + + buf, err := PathCodec{}.PlanEncode(nil, 0, TextFormatCode, path).Encode(path, nil) + if err != nil { + return nil, err + } + + return string(buf), err } -func (dst *Path) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Path{Status: Null} +type PathCodec struct{} + +func (PathCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (PathCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (PathCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(PathValuer); !ok { return nil } - if len(src) < 7 { - return errors.Errorf("invalid length for Path: %v", len(src)) + switch format { + case BinaryFormatCode: + return encodePlanPathCodecBinary{} + case TextFormatCode: + return encodePlanPathCodecText{} } - closed := src[0] == '(' - points := make([]Vec2, 0) + return nil +} - str := string(src[2:]) +type encodePlanPathCodecBinary struct{} - for { - end := strings.IndexByte(str, ',') - x, err := strconv.ParseFloat(str[:end], 64) - if err != nil { - return err - } +func (encodePlanPathCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + path, err := value.(PathValuer).PathValue() + if err != nil { + return nil, err + } - str = str[end+1:] - end = strings.IndexByte(str, ')') + if !path.Valid { + return nil, nil + } - y, err := strconv.ParseFloat(str[:end], 64) - if err != nil { - return err + var closeByte byte + if path.Closed { + closeByte = 1 + } + buf = append(buf, closeByte) + + buf = pgio.AppendInt32(buf, int32(len(path.P))) + + for _, p := range path.P { + buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) + } + + return buf, nil +} + +type encodePlanPathCodecText struct{} + +func (encodePlanPathCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + path, err := value.(PathValuer).PathValue() + if err != nil { + return nil, err + } + + if !path.Valid { + return nil, nil + } + + var startByte, endByte byte + if path.Closed { + startByte = '(' + endByte = ')' + } else { + startByte = '[' + endByte = ']' + } + buf = append(buf, startByte) + + for i, p := range path.P { + if i > 0 { + buf = append(buf, ',') } + buf = append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(p.X, 'f', -1, 64), + strconv.FormatFloat(p.Y, 'f', -1, 64), + )...) + } - points = append(points, Vec2{x, y}) + buf = append(buf, endByte) - if end+3 < len(str) { - str = str[end+3:] - } else { - break + return buf, nil +} + +func (PathCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case PathScanner: + return scanPlanBinaryPathToPathScanner{} + } + case TextFormatCode: + switch target.(type) { + case PathScanner: + return scanPlanTextAnyToPathScanner{} } } - *dst = Path{P: points, Closed: closed, Status: Present} return nil } -func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { +type scanPlanBinaryPathToPathScanner struct{} + +func (scanPlanBinaryPathToPathScanner) Scan(src []byte, dst any) error { + scanner := (dst).(PathScanner) + if src == nil { - *dst = Path{Status: Null} - return nil + return scanner.ScanPath(Path{}) } if len(src) < 5 { - return errors.Errorf("invalid length for Path: %v", len(src)) + return fmt.Errorf("invalid length for Path: %v", len(src)) } closed := src[0] == 1 @@ -96,7 +191,7 @@ func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { rp := 5 if 5+pointCount*16 != len(src) { - return errors.Errorf("invalid length for Path with %d points: %v", pointCount, len(src)) + return fmt.Errorf("invalid length for Path with %d points: %v", pointCount, len(src)) } points := make([]Vec2, pointCount) @@ -108,86 +203,71 @@ func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} } - *dst = Path{ + return scanner.ScanPath(Path{ P: points, Closed: closed, - Status: Present, - } - return nil + Valid: true, + }) } -func (src *Path) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } +type scanPlanTextAnyToPathScanner struct{} - var startByte, endByte byte - if src.Closed { - startByte = '(' - endByte = ')' - } else { - startByte = '[' - endByte = ']' +func (scanPlanTextAnyToPathScanner) Scan(src []byte, dst any) error { + scanner := (dst).(PathScanner) + + if src == nil { + return scanner.ScanPath(Path{}) } - buf = append(buf, startByte) - for i, p := range src.P { - if i > 0 { - buf = append(buf, ',') - } - buf = append(buf, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)...) + if len(src) < 7 { + return fmt.Errorf("invalid length for Path: %v", len(src)) } - return append(buf, endByte), nil -} + closed := src[0] == '(' + points := make([]Vec2, 0) -func (src *Path) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } + str := string(src[2:]) - var closeByte byte - if src.Closed { - closeByte = 1 - } - buf = append(buf, closeByte) + for { + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } - buf = pgio.AppendInt32(buf, int32(len(src.P))) + str = str[end+1:] + end = strings.IndexByte(str, ')') - for _, p := range src.P { - buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) - buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + points = append(points, Vec2{x, y}) + + if end+3 < len(str) { + str = str[end+3:] + } else { + break + } } - return buf, nil + return scanner.ScanPath(Path{P: points, Closed: closed, Valid: true}) } -// Scan implements the database/sql Scanner interface. -func (dst *Path) Scan(src interface{}) error { +func (c PathCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c PathCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { - *dst = Path{Status: Null} - return nil + return nil, nil } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + var path Path + err := codecScan(c, m, oid, format, src, &path) + if err != nil { + return nil, err } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Path) Value() (driver.Value, error) { - return EncodeValueText(src) + return path, nil } diff --git a/pgtype/path_test.go b/pgtype/path_test.go index d213a1b44..cfffd22a6 100644 --- a/pgtype/path_test.go +++ b/pgtype/path_test.go @@ -1,29 +1,76 @@ package pgtype_test import ( + "context" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" ) +func isExpectedEqPath(a any) func(any) bool { + return func(v any) bool { + ap := a.(pgtype.Path) + vp := v.(pgtype.Path) + + if !(ap.Valid == vp.Valid && ap.Closed == vp.Closed && len(ap.P) == len(vp.P)) { + return false + } + + for i := range ap.P { + if ap.P[i] != vp.P[i] { + return false + } + } + + return true + } +} + func TestPathTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "path", []interface{}{ - &pgtype.Path{ - P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}}, - Closed: false, - Status: pgtype.Present, + skipCockroachDB(t, "Server does not support type path") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "path", []pgxtest.ValueRoundTripTest{ + { + pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}}, + Closed: false, + Valid: true, + }, + new(pgtype.Path), + isExpectedEqPath(pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}}, + Closed: false, + Valid: true, + }), }, - &pgtype.Path{ - P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}}, - Closed: true, - Status: pgtype.Present, + { + pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}}, + Closed: true, + Valid: true, + }, + new(pgtype.Path), + isExpectedEqPath(pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}}, + Closed: true, + Valid: true, + }), }, - &pgtype.Path{ - P: []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, - Closed: true, - Status: pgtype.Present, + { + pgtype.Path{ + P: []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Closed: true, + Valid: true, + }, + new(pgtype.Path), + isExpectedEqPath(pgtype.Path{ + P: []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Closed: true, + Valid: true, + }), }, - &pgtype.Path{Status: pgtype.Null}, + {pgtype.Path{}, new(pgtype.Path), isExpectedEqPath(pgtype.Path{})}, + {nil, new(pgtype.Path), isExpectedEqPath(pgtype.Path{})}, }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 2643314e1..942dddb8e 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1,80 +1,144 @@ package pgtype import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "net" + "net/netip" "reflect" - - "github.com/pkg/errors" + "time" ) // PostgreSQL oids for common types const ( - BoolOID = 16 - ByteaOID = 17 - CharOID = 18 - NameOID = 19 - Int8OID = 20 - Int2OID = 21 - Int4OID = 23 - TextOID = 25 - OIDOID = 26 - TIDOID = 27 - XIDOID = 28 - CIDOID = 29 - JSONOID = 114 - CIDROID = 650 - CIDRArrayOID = 651 - Float4OID = 700 - Float8OID = 701 - UnknownOID = 705 - InetOID = 869 - BoolArrayOID = 1000 - Int2ArrayOID = 1005 - Int4ArrayOID = 1007 - TextArrayOID = 1009 - ByteaArrayOID = 1001 - BPCharArrayOID = 1014 - VarcharArrayOID = 1015 - Int8ArrayOID = 1016 - Float4ArrayOID = 1021 - Float8ArrayOID = 1022 - ACLItemOID = 1033 - ACLItemArrayOID = 1034 - InetArrayOID = 1041 - BPCharOID = 1042 - VarcharOID = 1043 - DateOID = 1082 - TimestampOID = 1114 - TimestampArrayOID = 1115 - DateArrayOID = 1182 - TimestamptzOID = 1184 - TimestamptzArrayOID = 1185 - NumericOID = 1700 - RecordOID = 2249 - UUIDOID = 2950 - UUIDArrayOID = 2951 - JSONBOID = 3802 -) - -type Status byte - -const ( - Undefined Status = iota - Null - Present + BoolOID = 16 + ByteaOID = 17 + QCharOID = 18 + NameOID = 19 + Int8OID = 20 + Int2OID = 21 + Int4OID = 23 + TextOID = 25 + OIDOID = 26 + TIDOID = 27 + XIDOID = 28 + CIDOID = 29 + JSONOID = 114 + XMLOID = 142 + XMLArrayOID = 143 + JSONArrayOID = 199 + XID8ArrayOID = 271 + PointOID = 600 + LsegOID = 601 + PathOID = 602 + BoxOID = 603 + PolygonOID = 604 + LineOID = 628 + LineArrayOID = 629 + CIDROID = 650 + CIDRArrayOID = 651 + Float4OID = 700 + Float8OID = 701 + CircleOID = 718 + CircleArrayOID = 719 + UnknownOID = 705 + Macaddr8OID = 774 + MacaddrOID = 829 + InetOID = 869 + BoolArrayOID = 1000 + QCharArrayOID = 1002 + NameArrayOID = 1003 + Int2ArrayOID = 1005 + Int4ArrayOID = 1007 + TextArrayOID = 1009 + TIDArrayOID = 1010 + ByteaArrayOID = 1001 + XIDArrayOID = 1011 + CIDArrayOID = 1012 + BPCharArrayOID = 1014 + VarcharArrayOID = 1015 + Int8ArrayOID = 1016 + PointArrayOID = 1017 + LsegArrayOID = 1018 + PathArrayOID = 1019 + BoxArrayOID = 1020 + Float4ArrayOID = 1021 + Float8ArrayOID = 1022 + PolygonArrayOID = 1027 + OIDArrayOID = 1028 + ACLItemOID = 1033 + ACLItemArrayOID = 1034 + MacaddrArrayOID = 1040 + InetArrayOID = 1041 + BPCharOID = 1042 + VarcharOID = 1043 + DateOID = 1082 + TimeOID = 1083 + TimestampOID = 1114 + TimestampArrayOID = 1115 + DateArrayOID = 1182 + TimeArrayOID = 1183 + TimestamptzOID = 1184 + TimestamptzArrayOID = 1185 + IntervalOID = 1186 + IntervalArrayOID = 1187 + NumericArrayOID = 1231 + TimetzOID = 1266 + TimetzArrayOID = 1270 + BitOID = 1560 + BitArrayOID = 1561 + VarbitOID = 1562 + VarbitArrayOID = 1563 + NumericOID = 1700 + RecordOID = 2249 + RecordArrayOID = 2287 + UUIDOID = 2950 + UUIDArrayOID = 2951 + JSONBOID = 3802 + JSONBArrayOID = 3807 + DaterangeOID = 3912 + DaterangeArrayOID = 3913 + Int4rangeOID = 3904 + Int4rangeArrayOID = 3905 + NumrangeOID = 3906 + NumrangeArrayOID = 3907 + TsrangeOID = 3908 + TsrangeArrayOID = 3909 + TstzrangeOID = 3910 + TstzrangeArrayOID = 3911 + Int8rangeOID = 3926 + Int8rangeArrayOID = 3927 + JSONPathOID = 4072 + JSONPathArrayOID = 4073 + Int4multirangeOID = 4451 + NummultirangeOID = 4532 + TsmultirangeOID = 4533 + TstzmultirangeOID = 4534 + DatemultirangeOID = 4535 + Int8multirangeOID = 4536 + XID8OID = 5069 + Int4multirangeArrayOID = 6150 + NummultirangeArrayOID = 6151 + TsmultirangeArrayOID = 6152 + TstzmultirangeArrayOID = 6153 + DatemultirangeArrayOID = 6155 + Int8multirangeArrayOID = 6157 ) type InfinityModifier int8 const ( Infinity InfinityModifier = 1 - None InfinityModifier = 0 + Finite InfinityModifier = 0 NegativeInfinity InfinityModifier = -Infinity ) func (im InfinityModifier) String() string { switch im { - case None: - return "none" + case Finite: + return "finite" case Infinity: return "infinity" case NegativeInfinity: @@ -84,197 +148,1903 @@ func (im InfinityModifier) String() string { } } -type Value interface { - // Set converts and assigns src to itself. - Set(src interface{}) error +// PostgreSQL format codes +const ( + TextFormatCode = 0 + BinaryFormatCode = 1 +) - // Get returns the simplest representation of Value. If the Value is Null or - // Undefined that is the return value. If no simpler representation is - // possible, then Get() returns Value. - Get() interface{} +// A Codec converts between Go and PostgreSQL values. A Codec must not be mutated after it is registered with a Map. +type Codec interface { + // FormatSupported returns true if the format is supported. + FormatSupported(int16) bool - // AssignTo converts and assigns the Value to dst. It MUST make a deep copy of - // any reference types. - AssignTo(dst interface{}) error -} + // PreferredFormat returns the preferred format. + PreferredFormat() int16 -type BinaryDecoder interface { - // DecodeBinary decodes src into BinaryDecoder. If src is nil then the - // original SQL value is NULL. BinaryDecoder takes ownership of src. The - // caller MUST not use it again. - DecodeBinary(ci *ConnInfo, src []byte) error + // PlanEncode returns an EncodePlan for encoding value into PostgreSQL format for oid and format. If no plan can be + // found then nil is returned. + PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan + + // PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If + // no plan can be found then nil is returned. + PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan + + // DecodeDatabaseSQLValue returns src decoded into a value compatible with the sql.Scanner interface. + DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) + + // DecodeValue returns src decoded into its default format. + DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) } -type TextDecoder interface { - // DecodeText decodes src into TextDecoder. If src is nil then the original - // SQL value is NULL. TextDecoder takes ownership of src. The caller MUST not - // use it again. - DecodeText(ci *ConnInfo, src []byte) error +type nullAssignmentError struct { + dst any } -// BinaryEncoder is implemented by types that can encode themselves into the -// PostgreSQL binary wire format. -type BinaryEncoder interface { - // EncodeBinary should append the binary format of self to buf. If self is the - // SQL value NULL then append nothing and return (nil, nil). The caller of - // EncodeBinary is responsible for writing the correct NULL value or the - // length of the data written. - EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) +func (e *nullAssignmentError) Error() string { + return fmt.Sprintf("cannot assign NULL to %T", e.dst) } -// TextEncoder is implemented by types that can encode themselves into the -// PostgreSQL text wire format. -type TextEncoder interface { - // EncodeText should append the text format of self to buf. If self is the - // SQL value NULL then append nothing and return (nil, nil). The caller of - // EncodeText is responsible for writing the correct NULL value or the - // length of the data written. - EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) +// Type represents a PostgreSQL data type. It must not be mutated after it is registered with a Map. +type Type struct { + Codec Codec + Name string + OID uint32 } -var errUndefined = errors.New("cannot encode status undefined") -var errBadStatus = errors.New("invalid status") +// Map is the mapping between PostgreSQL server types and Go type handling logic. It can encode values for +// transmission to a PostgreSQL server and scan received values. +type Map struct { + oidToType map[uint32]*Type + nameToType map[string]*Type + reflectTypeToName map[reflect.Type]string + oidToFormatCode map[uint32]int16 -type DataType struct { - Value Value - Name string - OID OID + reflectTypeToType map[reflect.Type]*Type + + memoizedEncodePlans map[uint32]map[reflect.Type][2]EncodePlan + + // TryWrapEncodePlanFuncs is a slice of functions that will wrap a value that cannot be encoded by the Codec. Every + // time a wrapper is found the PlanEncode method will be recursively called with the new value. This allows several layers of wrappers + // to be built up. There are default functions placed in this slice by NewMap(). In most cases these functions + // should run last. i.e. Additional functions should typically be prepended not appended. + TryWrapEncodePlanFuncs []TryWrapEncodePlanFunc + + // TryWrapScanPlanFuncs is a slice of functions that will wrap a target that cannot be scanned into by the Codec. Every + // time a wrapper is found the PlanScan method will be recursively called with the new target. This allows several layers of wrappers + // to be built up. There are default functions placed in this slice by NewMap(). In most cases these functions + // should run last. i.e. Additional functions should typically be prepended not appended. + TryWrapScanPlanFuncs []TryWrapScanPlanFunc } -type ConnInfo struct { - oidToDataType map[OID]*DataType - nameToDataType map[string]*DataType - reflectTypeToDataType map[reflect.Type]*DataType +// Copy returns a new Map containing the same registered types. +func (m *Map) Copy() *Map { + newMap := NewMap() + for _, type_ := range m.oidToType { + newMap.RegisterType(type_) + } + return newMap } -func NewConnInfo() *ConnInfo { - return &ConnInfo{ - oidToDataType: make(map[OID]*DataType, 256), - nameToDataType: make(map[string]*DataType, 256), - reflectTypeToDataType: make(map[reflect.Type]*DataType, 256), +func NewMap() *Map { + defaultMapInitOnce.Do(initDefaultMap) + + return &Map{ + oidToType: make(map[uint32]*Type), + nameToType: make(map[string]*Type), + reflectTypeToName: make(map[reflect.Type]string), + oidToFormatCode: make(map[uint32]int16), + + memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan), + + TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ + TryWrapDerefPointerEncodePlan, + TryWrapBuiltinTypeEncodePlan, + TryWrapFindUnderlyingTypeEncodePlan, + TryWrapStructEncodePlan, + TryWrapSliceEncodePlan, + TryWrapMultiDimSliceEncodePlan, + TryWrapArrayEncodePlan, + }, + + TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{ + TryPointerPointerScanPlan, + TryWrapBuiltinTypeScanPlan, + TryFindUnderlyingTypeScanPlan, + TryWrapStructScanPlan, + TryWrapPtrSliceScanPlan, + TryWrapPtrMultiDimSliceScanPlan, + TryWrapPtrArrayScanPlan, + }, } } -func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) { - for name, oid := range nameOIDs { - var value Value - if t, ok := nameValues[name]; ok { - value = reflect.New(reflect.ValueOf(t).Elem().Type()).Interface().(Value) - } else { - value = &GenericText{} - } - ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid}) +// RegisterTypes registers multiple data types in the sequence they are provided. +func (m *Map) RegisterTypes(types []*Type) { + for _, t := range types { + m.RegisterType(t) + } +} + +// RegisterType registers a data type with the Map. t must not be mutated after it is registered. +func (m *Map) RegisterType(t *Type) { + m.oidToType[t.OID] = t + m.nameToType[t.Name] = t + m.oidToFormatCode[t.OID] = t.Codec.PreferredFormat() + + // Invalidated by type registration + m.reflectTypeToType = nil + for k := range m.memoizedEncodePlans { + delete(m.memoizedEncodePlans, k) } } -func (ci *ConnInfo) RegisterDataType(t DataType) { - ci.oidToDataType[t.OID] = &t - ci.nameToDataType[t.Name] = &t - ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t +// RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be +// encoded or decoded is determined by the PostgreSQL OID. But if the OID of a value to be encoded or decoded is +// unknown, this additional mapping will be used by TypeForValue to determine a suitable data type. +func (m *Map) RegisterDefaultPgType(value any, name string) { + m.reflectTypeToName[reflect.TypeOf(value)] = name + + // Invalidated by type registration + m.reflectTypeToType = nil + for k := range m.memoizedEncodePlans { + delete(m.memoizedEncodePlans, k) + } } -func (ci *ConnInfo) DataTypeForOID(oid OID) (*DataType, bool) { - dt, ok := ci.oidToDataType[oid] +// TypeForOID returns the Type registered for the given OID. The returned Type must not be mutated. +func (m *Map) TypeForOID(oid uint32) (*Type, bool) { + if dt, ok := m.oidToType[oid]; ok { + return dt, true + } + + dt, ok := defaultMap.oidToType[oid] return dt, ok } -func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) { - dt, ok := ci.nameToDataType[name] +// TypeForName returns the Type registered for the given name. The returned Type must not be mutated. +func (m *Map) TypeForName(name string) (*Type, bool) { + if dt, ok := m.nameToType[name]; ok { + return dt, true + } + dt, ok := defaultMap.nameToType[name] return dt, ok } -func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) { - dt, ok := ci.reflectTypeToDataType[reflect.ValueOf(v).Type()] +func (m *Map) buildReflectTypeToType() { + m.reflectTypeToType = make(map[reflect.Type]*Type) + + for reflectType, name := range m.reflectTypeToName { + if dt, ok := m.TypeForName(name); ok { + m.reflectTypeToType[reflectType] = dt + } + } +} + +// TypeForValue finds a data type suitable for v. Use RegisterType to register types that can encode and decode +// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. The returned Type +// must not be mutated. +func (m *Map) TypeForValue(v any) (*Type, bool) { + if m.reflectTypeToType == nil { + m.buildReflectTypeToType() + } + + if dt, ok := m.reflectTypeToType[reflect.TypeOf(v)]; ok { + return dt, true + } + + dt, ok := defaultMap.reflectTypeToType[reflect.TypeOf(v)] return dt, ok } -// DeepCopy makes a deep copy of the ConnInfo. -func (ci *ConnInfo) DeepCopy() *ConnInfo { - ci2 := &ConnInfo{ - oidToDataType: make(map[OID]*DataType, len(ci.oidToDataType)), - nameToDataType: make(map[string]*DataType, len(ci.nameToDataType)), - reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)), - } - - for _, dt := range ci.oidToDataType { - ci2.RegisterDataType(DataType{ - Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value), - Name: dt.Name, - OID: dt.OID, - }) - } - - return ci2 -} - -var nameValues map[string]Value - -func init() { - nameValues = map[string]Value{ - "_aclitem": &ACLItemArray{}, - "_bool": &BoolArray{}, - "_bpchar": &BPCharArray{}, - "_bytea": &ByteaArray{}, - "_cidr": &CIDRArray{}, - "_date": &DateArray{}, - "_float4": &Float4Array{}, - "_float8": &Float8Array{}, - "_inet": &InetArray{}, - "_int2": &Int2Array{}, - "_int4": &Int4Array{}, - "_int8": &Int8Array{}, - "_numeric": &NumericArray{}, - "_text": &TextArray{}, - "_timestamp": &TimestampArray{}, - "_timestamptz": &TimestamptzArray{}, - "_uuid": &UUIDArray{}, - "_varchar": &VarcharArray{}, - "aclitem": &ACLItem{}, - "bit": &Bit{}, - "bool": &Bool{}, - "box": &Box{}, - "bpchar": &BPChar{}, - "bytea": &Bytea{}, - "char": &QChar{}, - "cid": &CID{}, - "cidr": &CIDR{}, - "circle": &Circle{}, - "date": &Date{}, - "daterange": &Daterange{}, - "decimal": &Decimal{}, - "float4": &Float4{}, - "float8": &Float8{}, - "hstore": &Hstore{}, - "inet": &Inet{}, - "int2": &Int2{}, - "int4": &Int4{}, - "int4range": &Int4range{}, - "int8": &Int8{}, - "int8range": &Int8range{}, - "interval": &Interval{}, - "json": &JSON{}, - "jsonb": &JSONB{}, - "line": &Line{}, - "lseg": &Lseg{}, - "macaddr": &Macaddr{}, - "name": &Name{}, - "numeric": &Numeric{}, - "numrange": &Numrange{}, - "oid": &OIDValue{}, - "path": &Path{}, - "point": &Point{}, - "polygon": &Polygon{}, - "record": &Record{}, - "text": &Text{}, - "tid": &TID{}, - "timestamp": &Timestamp{}, - "timestamptz": &Timestamptz{}, - "tsrange": &Tsrange{}, - "tstzrange": &Tstzrange{}, - "unknown": &Unknown{}, - "uuid": &UUID{}, - "varbit": &Varbit{}, - "varchar": &Varchar{}, - "xid": &XID{}, +// FormatCodeForOID returns the preferred format code for type oid. If the type is not registered it returns the text +// format code. +func (m *Map) FormatCodeForOID(oid uint32) int16 { + if fc, ok := m.oidToFormatCode[oid]; ok { + return fc + } + + if fc, ok := defaultMap.oidToFormatCode[oid]; ok { + return fc + } + + return TextFormatCode +} + +// EncodePlan is a precompiled plan to encode a particular type into a particular OID and format. +type EncodePlan interface { + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return + // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data + // written. + Encode(value any, buf []byte) (newBuf []byte, err error) +} + +// ScanPlan is a precompiled plan to scan into a type of destination. +type ScanPlan interface { + // Scan scans src into target. src is only valid during the call to Scan. The ScanPlan must not retain a reference to + // src. + Scan(src []byte, target any) error +} + +type scanPlanCodecSQLScanner struct { + c Codec + m *Map + oid uint32 + formatCode int16 +} + +func (plan *scanPlanCodecSQLScanner) Scan(src []byte, dst any) error { + value, err := plan.c.DecodeDatabaseSQLValue(plan.m, plan.oid, plan.formatCode, src) + if err != nil { + return err + } + + scanner := dst.(sql.Scanner) + return scanner.Scan(value) +} + +type scanPlanSQLScanner struct { + formatCode int16 +} + +func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error { + scanner := dst.(sql.Scanner) + + if src == nil { + // This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the + // text format path would be converted to empty string. + return scanner.Scan(nil) + } else if plan.formatCode == BinaryFormatCode { + return scanner.Scan(src) + } else { + return scanner.Scan(string(src)) + } +} + +type scanPlanString struct{} + +func (scanPlanString) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + p := (dst).(*string) + *p = string(src) + return nil +} + +type scanPlanAnyTextToBytes struct{} + +func (scanPlanAnyTextToBytes) Scan(src []byte, dst any) error { + dstBuf := dst.(*[]byte) + if src == nil { + *dstBuf = nil + return nil + } + + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type scanPlanFail struct { + m *Map + oid uint32 + formatCode int16 +} + +func (plan *scanPlanFail) Scan(src []byte, dst any) error { + // If src is NULL it might be possible to scan into dst even though it is the types are not compatible. While this + // may seem to be a contrived case it can occur when selecting NULL directly. PostgreSQL assigns it the type of text. + // It would be surprising to the caller to have to cast the NULL (e.g. `select null::int`). So try to figure out a + // compatible data type for dst and scan with that. + // + // See https://github.com/jackc/pgx/issues/1326 + if src == nil { + // As a horrible hack try all types to find anything that can scan into dst. + for oid := range plan.m.oidToType { + // using planScan instead of Scan or PlanScan to avoid polluting the planned scan cache. + plan := plan.m.planScan(oid, plan.formatCode, dst, 0) + if _, ok := plan.(*scanPlanFail); !ok { + return plan.Scan(src, dst) + } + } + for oid := range defaultMap.oidToType { + if _, ok := plan.m.oidToType[oid]; !ok { + plan := plan.m.planScan(oid, plan.formatCode, dst, 0) + if _, ok := plan.(*scanPlanFail); !ok { + return plan.Scan(src, dst) + } + } + } + } + + var format string + switch plan.formatCode { + case TextFormatCode: + format = "text" + case BinaryFormatCode: + format = "binary" + default: + format = fmt.Sprintf("unknown %d", plan.formatCode) + } + + var dataTypeName string + if t, ok := plan.m.TypeForOID(plan.oid); ok { + dataTypeName = t.Name + } else { + dataTypeName = "unknown type" + } + + return fmt.Errorf("cannot scan %s (OID %d) in %v format into %T", dataTypeName, plan.oid, format, dst) +} + +// TryWrapScanPlanFunc is a function that tries to create a wrapper plan for target. If successful it returns a plan +// that will convert the target passed to Scan and then call the next plan. nextTarget is target as it will be converted +// by plan. It must be used to find another suitable ScanPlan. When it is found SetNext must be called on plan for it +// to be usabled. ok indicates if a suitable wrapper was found. +type TryWrapScanPlanFunc func(target any) (plan WrappedScanPlanNextSetter, nextTarget any, ok bool) + +type pointerPointerScanPlan struct { + dstType reflect.Type + next ScanPlan +} + +func (plan *pointerPointerScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *pointerPointerScanPlan) Scan(src []byte, dst any) error { + el := reflect.ValueOf(dst).Elem() + if src == nil { + el.Set(reflect.Zero(el.Type())) + return nil + } + + el.Set(reflect.New(el.Type().Elem())) + return plan.next.Scan(src, el.Interface()) +} + +// TryPointerPointerScanPlan handles a pointer to a pointer by setting the target to nil for SQL NULL and allocating and +// scanning for non-NULL. +func TryPointerPointerScanPlan(target any) (plan WrappedScanPlanNextSetter, nextTarget any, ok bool) { + if dstValue := reflect.ValueOf(target); dstValue.Kind() == reflect.Ptr { + elemValue := dstValue.Elem() + if elemValue.Kind() == reflect.Ptr { + plan = &pointerPointerScanPlan{dstType: dstValue.Type()} + return plan, reflect.Zero(elemValue.Type()).Interface(), true + } + } + + return nil, nil, false +} + +// SkipUnderlyingTypePlanner prevents PlanScan and PlanDecode from trying to use the underlying type. +type SkipUnderlyingTypePlanner interface { + SkipUnderlyingTypePlan() +} + +var elemKindToPointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ + reflect.Int: reflect.TypeOf(new(int)), + reflect.Int8: reflect.TypeOf(new(int8)), + reflect.Int16: reflect.TypeOf(new(int16)), + reflect.Int32: reflect.TypeOf(new(int32)), + reflect.Int64: reflect.TypeOf(new(int64)), + reflect.Uint: reflect.TypeOf(new(uint)), + reflect.Uint8: reflect.TypeOf(new(uint8)), + reflect.Uint16: reflect.TypeOf(new(uint16)), + reflect.Uint32: reflect.TypeOf(new(uint32)), + reflect.Uint64: reflect.TypeOf(new(uint64)), + reflect.Float32: reflect.TypeOf(new(float32)), + reflect.Float64: reflect.TypeOf(new(float64)), + reflect.String: reflect.TypeOf(new(string)), + reflect.Bool: reflect.TypeOf(new(bool)), +} + +type underlyingTypeScanPlan struct { + dstType reflect.Type + nextDstType reflect.Type + next ScanPlan +} + +func (plan *underlyingTypeScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *underlyingTypeScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, reflect.ValueOf(dst).Convert(plan.nextDstType).Interface()) +} + +// TryFindUnderlyingTypeScanPlan tries to convert to a Go builtin type. e.g. If value was of type MyString and +// MyString was defined as a string then a wrapper plan would be returned that converts MyString to string. +func TryFindUnderlyingTypeScanPlan(dst any) (plan WrappedScanPlanNextSetter, nextDst any, ok bool) { + if _, ok := dst.(SkipUnderlyingTypePlanner); ok { + return nil, nil, false + } + + dstValue := reflect.ValueOf(dst) + + if dstValue.Kind() == reflect.Ptr { + var elemValue reflect.Value + if dstValue.IsNil() { + elemValue = reflect.New(dstValue.Type().Elem()).Elem() + } else { + elemValue = dstValue.Elem() + } + nextDstType := elemKindToPointerTypes[elemValue.Kind()] + if nextDstType == nil { + if elemValue.Kind() == reflect.Slice { + if elemValue.Type().Elem().Kind() == reflect.Uint8 { + var v *[]byte + nextDstType = reflect.TypeOf(v) + } + } + + // Get underlying type of any array. + // https://github.com/jackc/pgx/issues/2107 + if elemValue.Kind() == reflect.Array { + nextDstType = reflect.PointerTo(reflect.ArrayOf(elemValue.Len(), elemValue.Type().Elem())) + } + } + + if nextDstType != nil && dstValue.Type() != nextDstType && dstValue.CanConvert(nextDstType) { + return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true + } + } + + return nil, nil, false +} + +type WrappedScanPlanNextSetter interface { + SetNext(ScanPlan) + ScanPlan +} + +// TryWrapBuiltinTypeScanPlan tries to wrap a builtin type with a wrapper that provides additional methods. e.g. If +// value was of type int32 then a wrapper plan would be returned that converts target to a value that implements +// Int64Scanner. +func TryWrapBuiltinTypeScanPlan(target any) (plan WrappedScanPlanNextSetter, nextDst any, ok bool) { + switch target := target.(type) { + case *int8: + return &wrapInt8ScanPlan{}, (*int8Wrapper)(target), true + case *int16: + return &wrapInt16ScanPlan{}, (*int16Wrapper)(target), true + case *int32: + return &wrapInt32ScanPlan{}, (*int32Wrapper)(target), true + case *int64: + return &wrapInt64ScanPlan{}, (*int64Wrapper)(target), true + case *int: + return &wrapIntScanPlan{}, (*intWrapper)(target), true + case *uint8: + return &wrapUint8ScanPlan{}, (*uint8Wrapper)(target), true + case *uint16: + return &wrapUint16ScanPlan{}, (*uint16Wrapper)(target), true + case *uint32: + return &wrapUint32ScanPlan{}, (*uint32Wrapper)(target), true + case *uint64: + return &wrapUint64ScanPlan{}, (*uint64Wrapper)(target), true + case *uint: + return &wrapUintScanPlan{}, (*uintWrapper)(target), true + case *float32: + return &wrapFloat32ScanPlan{}, (*float32Wrapper)(target), true + case *float64: + return &wrapFloat64ScanPlan{}, (*float64Wrapper)(target), true + case *string: + return &wrapStringScanPlan{}, (*stringWrapper)(target), true + case *time.Time: + return &wrapTimeScanPlan{}, (*timeWrapper)(target), true + case *time.Duration: + return &wrapDurationScanPlan{}, (*durationWrapper)(target), true + case *net.IPNet: + return &wrapNetIPNetScanPlan{}, (*netIPNetWrapper)(target), true + case *net.IP: + return &wrapNetIPScanPlan{}, (*netIPWrapper)(target), true + case *netip.Prefix: + return &wrapNetipPrefixScanPlan{}, (*netipPrefixWrapper)(target), true + case *netip.Addr: + return &wrapNetipAddrScanPlan{}, (*netipAddrWrapper)(target), true + case *map[string]*string: + return &wrapMapStringToPointerStringScanPlan{}, (*mapStringToPointerStringWrapper)(target), true + case *map[string]string: + return &wrapMapStringToStringScanPlan{}, (*mapStringToStringWrapper)(target), true + case *[16]byte: + return &wrapByte16ScanPlan{}, (*byte16Wrapper)(target), true + case *[]byte: + return &wrapByteSliceScanPlan{}, (*byteSliceWrapper)(target), true + } + + return nil, nil, false +} + +type wrapInt8ScanPlan struct { + next ScanPlan +} + +func (plan *wrapInt8ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapInt8ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*int8Wrapper)(dst.(*int8))) +} + +type wrapInt16ScanPlan struct { + next ScanPlan +} + +func (plan *wrapInt16ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapInt16ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*int16Wrapper)(dst.(*int16))) +} + +type wrapInt32ScanPlan struct { + next ScanPlan +} + +func (plan *wrapInt32ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapInt32ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*int32Wrapper)(dst.(*int32))) +} + +type wrapInt64ScanPlan struct { + next ScanPlan +} + +func (plan *wrapInt64ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapInt64ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*int64Wrapper)(dst.(*int64))) +} + +type wrapIntScanPlan struct { + next ScanPlan +} + +func (plan *wrapIntScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapIntScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*intWrapper)(dst.(*int))) +} + +type wrapUint8ScanPlan struct { + next ScanPlan +} + +func (plan *wrapUint8ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapUint8ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*uint8Wrapper)(dst.(*uint8))) +} + +type wrapUint16ScanPlan struct { + next ScanPlan +} + +func (plan *wrapUint16ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapUint16ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*uint16Wrapper)(dst.(*uint16))) +} + +type wrapUint32ScanPlan struct { + next ScanPlan +} + +func (plan *wrapUint32ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapUint32ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*uint32Wrapper)(dst.(*uint32))) +} + +type wrapUint64ScanPlan struct { + next ScanPlan +} + +func (plan *wrapUint64ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapUint64ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*uint64Wrapper)(dst.(*uint64))) +} + +type wrapUintScanPlan struct { + next ScanPlan +} + +func (plan *wrapUintScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapUintScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*uintWrapper)(dst.(*uint))) +} + +type wrapFloat32ScanPlan struct { + next ScanPlan +} + +func (plan *wrapFloat32ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapFloat32ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*float32Wrapper)(dst.(*float32))) +} + +type wrapFloat64ScanPlan struct { + next ScanPlan +} + +func (plan *wrapFloat64ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapFloat64ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*float64Wrapper)(dst.(*float64))) +} + +type wrapStringScanPlan struct { + next ScanPlan +} + +func (plan *wrapStringScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapStringScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*stringWrapper)(dst.(*string))) +} + +type wrapTimeScanPlan struct { + next ScanPlan +} + +func (plan *wrapTimeScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapTimeScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*timeWrapper)(dst.(*time.Time))) +} + +type wrapDurationScanPlan struct { + next ScanPlan +} + +func (plan *wrapDurationScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapDurationScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*durationWrapper)(dst.(*time.Duration))) +} + +type wrapNetIPNetScanPlan struct { + next ScanPlan +} + +func (plan *wrapNetIPNetScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapNetIPNetScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*netIPNetWrapper)(dst.(*net.IPNet))) +} + +type wrapNetIPScanPlan struct { + next ScanPlan +} + +func (plan *wrapNetIPScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapNetIPScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*netIPWrapper)(dst.(*net.IP))) +} + +type wrapNetipPrefixScanPlan struct { + next ScanPlan +} + +func (plan *wrapNetipPrefixScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapNetipPrefixScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*netipPrefixWrapper)(dst.(*netip.Prefix))) +} + +type wrapNetipAddrScanPlan struct { + next ScanPlan +} + +func (plan *wrapNetipAddrScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapNetipAddrScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*netipAddrWrapper)(dst.(*netip.Addr))) +} + +type wrapMapStringToPointerStringScanPlan struct { + next ScanPlan +} + +func (plan *wrapMapStringToPointerStringScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapMapStringToPointerStringScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*mapStringToPointerStringWrapper)(dst.(*map[string]*string))) +} + +type wrapMapStringToStringScanPlan struct { + next ScanPlan +} + +func (plan *wrapMapStringToStringScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapMapStringToStringScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*mapStringToStringWrapper)(dst.(*map[string]string))) +} + +type wrapByte16ScanPlan struct { + next ScanPlan +} + +func (plan *wrapByte16ScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapByte16ScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*byte16Wrapper)(dst.(*[16]byte))) +} + +type wrapByteSliceScanPlan struct { + next ScanPlan +} + +func (plan *wrapByteSliceScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapByteSliceScanPlan) Scan(src []byte, dst any) error { + return plan.next.Scan(src, (*byteSliceWrapper)(dst.(*[]byte))) +} + +type pointerEmptyInterfaceScanPlan struct { + codec Codec + m *Map + oid uint32 + formatCode int16 +} + +func (plan *pointerEmptyInterfaceScanPlan) Scan(src []byte, dst any) error { + value, err := plan.codec.DecodeValue(plan.m, plan.oid, plan.formatCode, src) + if err != nil { + return err + } + + ptrAny := dst.(*any) + *ptrAny = value + + return nil +} + +// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. +func TryWrapStructScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) { + targetValue := reflect.ValueOf(target) + if targetValue.Kind() != reflect.Ptr { + return nil, nil, false + } + + var targetElemValue reflect.Value + if targetValue.IsNil() { + targetElemValue = reflect.Zero(targetValue.Type().Elem()) + } else { + targetElemValue = targetValue.Elem() + } + targetElemType := targetElemValue.Type() + + if targetElemType.Kind() == reflect.Struct { + exportedFields := getExportedFieldValues(targetElemValue) + if len(exportedFields) == 0 { + return nil, nil, false + } + + w := ptrStructWrapper{ + s: target, + exportedFields: exportedFields, + } + return &wrapAnyPtrStructScanPlan{}, &w, true + } + + return nil, nil, false +} + +type wrapAnyPtrStructScanPlan struct { + next ScanPlan +} + +func (plan *wrapAnyPtrStructScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapAnyPtrStructScanPlan) Scan(src []byte, target any) error { + w := ptrStructWrapper{ + s: target, + exportedFields: getExportedFieldValues(reflect.ValueOf(target).Elem()), + } + + return plan.next.Scan(src, &w) +} + +// TryWrapPtrSliceScanPlan tries to wrap a pointer to a single dimension slice. +func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) { + // Avoid using reflect path for common types. + switch target := target.(type) { + case *[]int16: + return &wrapPtrSliceScanPlan[int16]{}, (*FlatArray[int16])(target), true + case *[]int32: + return &wrapPtrSliceScanPlan[int32]{}, (*FlatArray[int32])(target), true + case *[]int64: + return &wrapPtrSliceScanPlan[int64]{}, (*FlatArray[int64])(target), true + case *[]float32: + return &wrapPtrSliceScanPlan[float32]{}, (*FlatArray[float32])(target), true + case *[]float64: + return &wrapPtrSliceScanPlan[float64]{}, (*FlatArray[float64])(target), true + case *[]string: + return &wrapPtrSliceScanPlan[string]{}, (*FlatArray[string])(target), true + case *[]time.Time: + return &wrapPtrSliceScanPlan[time.Time]{}, (*FlatArray[time.Time])(target), true + } + + targetType := reflect.TypeOf(target) + if targetType.Kind() != reflect.Ptr { + return nil, nil, false + } + + targetElemType := targetType.Elem() + + if targetElemType.Kind() == reflect.Slice { + slice := reflect.New(targetElemType).Elem() + return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: slice}, true + } + return nil, nil, false +} + +type wrapPtrSliceScanPlan[T any] struct { + next ScanPlan +} + +func (plan *wrapPtrSliceScanPlan[T]) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapPtrSliceScanPlan[T]) Scan(src []byte, target any) error { + return plan.next.Scan(src, (*FlatArray[T])(target.(*[]T))) +} + +type wrapPtrSliceReflectScanPlan struct { + next ScanPlan +} + +func (plan *wrapPtrSliceReflectScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapPtrSliceReflectScanPlan) Scan(src []byte, target any) error { + return plan.next.Scan(src, &anySliceArrayReflect{slice: reflect.ValueOf(target).Elem()}) +} + +// TryWrapPtrMultiDimSliceScanPlan tries to wrap a pointer to a multi-dimension slice. +func TryWrapPtrMultiDimSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) { + targetValue := reflect.ValueOf(target) + if targetValue.Kind() != reflect.Ptr { + return nil, nil, false + } + + targetElemValue := targetValue.Elem() + + if targetElemValue.Kind() == reflect.Slice { + elemElemKind := targetElemValue.Type().Elem().Kind() + if elemElemKind == reflect.Slice { + if !isRagged(targetElemValue) { + return &wrapPtrMultiDimSliceScanPlan{}, &anyMultiDimSliceArray{slice: targetValue.Elem()}, true + } + } + } + + return nil, nil, false +} + +type wrapPtrMultiDimSliceScanPlan struct { + next ScanPlan +} + +func (plan *wrapPtrMultiDimSliceScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapPtrMultiDimSliceScanPlan) Scan(src []byte, target any) error { + return plan.next.Scan(src, &anyMultiDimSliceArray{slice: reflect.ValueOf(target).Elem()}) +} + +// TryWrapPtrArrayScanPlan tries to wrap a pointer to a single dimension array. +func TryWrapPtrArrayScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) { + targetValue := reflect.ValueOf(target) + if targetValue.Kind() != reflect.Ptr { + return nil, nil, false + } + + targetElemValue := targetValue.Elem() + + if targetElemValue.Kind() == reflect.Array { + return &wrapPtrArrayReflectScanPlan{}, &anyArrayArrayReflect{array: targetElemValue}, true + } + return nil, nil, false +} + +type wrapPtrArrayReflectScanPlan struct { + next ScanPlan +} + +func (plan *wrapPtrArrayReflectScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapPtrArrayReflectScanPlan) Scan(src []byte, target any) error { + return plan.next.Scan(src, &anyArrayArrayReflect{array: reflect.ValueOf(target).Elem()}) +} + +// PlanScan prepares a plan to scan a value into target. +func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan { + return m.planScan(oid, formatCode, target, 0) +} + +func (m *Map) planScan(oid uint32, formatCode int16, target any, depth int) ScanPlan { + if depth > 8 { + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} + } + + if target == nil { + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} + } + + if _, ok := target.(*UndecodedBytes); ok { + return scanPlanAnyToUndecodedBytes{} + } + + switch formatCode { + case BinaryFormatCode: + switch target.(type) { + case *string: + switch oid { + case TextOID, VarcharOID: + return scanPlanString{} + } + } + case TextFormatCode: + switch target.(type) { + case *string: + return scanPlanString{} + case *[]byte: + if oid != ByteaOID { + return scanPlanAnyTextToBytes{} + } + case TextScanner: + return scanPlanTextAnyToTextScanner{} + } + } + + var dt *Type + + if dataType, ok := m.TypeForOID(oid); ok { + dt = dataType + } else if dataType, ok := m.TypeForValue(target); ok { + dt = dataType + oid = dt.OID // Preserve assumed OID in case we are recursively called below. + } + + if dt != nil { + if plan := dt.Codec.PlanScan(m, oid, formatCode, target); plan != nil { + return plan + } + } + + // This needs to happen before trying m.TryWrapScanPlanFuncs. Otherwise, a sql.Scanner would not get called if it was + // defined on a type that could be unwrapped such as `type myString string`. + // + // https://github.com/jackc/pgtype/issues/197 + if _, ok := target.(sql.Scanner); ok { + if dt == nil { + return &scanPlanSQLScanner{formatCode: formatCode} + } else { + return &scanPlanCodecSQLScanner{c: dt.Codec, m: m, oid: oid, formatCode: formatCode} + } + } + + for _, f := range m.TryWrapScanPlanFuncs { + if wrapperPlan, nextDst, ok := f(target); ok { + if nextPlan := m.planScan(oid, formatCode, nextDst, depth+1); nextPlan != nil { + if _, failed := nextPlan.(*scanPlanFail); !failed { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + } + + if _, ok := target.(*any); ok { + var codec Codec + if dt != nil { + codec = dt.Codec + } else { + if formatCode == TextFormatCode { + codec = TextCodec{} + } else { + codec = ByteaCodec{} + } + } + return &pointerEmptyInterfaceScanPlan{codec: codec, m: m, oid: oid, formatCode: formatCode} + } + + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} +} + +func (m *Map) Scan(oid uint32, formatCode int16, src []byte, dst any) error { + if dst == nil { + return nil + } + + plan := m.PlanScan(oid, formatCode, dst) + return plan.Scan(src, dst) +} + +var ErrScanTargetTypeChanged = errors.New("scan target type changed") + +func codecScan(codec Codec, m *Map, oid uint32, format int16, src []byte, dst any) error { + scanPlan := codec.PlanScan(m, oid, format, dst) + if scanPlan == nil { + return fmt.Errorf("PlanScan did not find a plan") + } + return scanPlan.Scan(src, dst) +} + +func codecDecodeToTextFormat(codec Codec, m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + if format == TextFormatCode { + return string(src), nil + } else { + value, err := codec.DecodeValue(m, oid, format, src) + if err != nil { + return nil, err + } + buf, err := m.Encode(oid, TextFormatCode, value, nil) + if err != nil { + return nil, err + } + return string(buf), nil + } +} + +// PlanEncode returns an EncodePlan for encoding value into PostgreSQL format for oid and format. If no plan can be +// found then nil is returned. +func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { + return m.planEncodeDepth(oid, format, value, 0) +} + +func (m *Map) planEncodeDepth(oid uint32, format int16, value any, depth int) EncodePlan { + // Guard against infinite recursion. + if depth > 8 { + return nil + } + + oidMemo := m.memoizedEncodePlans[oid] + if oidMemo == nil { + oidMemo = make(map[reflect.Type][2]EncodePlan) + m.memoizedEncodePlans[oid] = oidMemo + } + targetReflectType := reflect.TypeOf(value) + typeMemo := oidMemo[targetReflectType] + plan := typeMemo[format] + if plan == nil { + plan = m.planEncode(oid, format, value, depth) + typeMemo[format] = plan + oidMemo[targetReflectType] = typeMemo + } + + return plan +} + +func (m *Map) planEncode(oid uint32, format int16, value any, depth int) EncodePlan { + if format == TextFormatCode { + switch value.(type) { + case string: + return encodePlanStringToAnyTextFormat{} + case TextValuer: + return encodePlanTextValuerToAnyTextFormat{} + } + } + + var dt *Type + if dataType, ok := m.TypeForOID(oid); ok { + dt = dataType + } else { + // If no type for the OID was found, then either it is unknowable (e.g. the simple protocol) or it is an + // unregistered type. In either case try to find the type and OID that matches the value (e.g. a []byte would be + // registered to PostgreSQL bytea). + if dataType, ok := m.TypeForValue(value); ok { + dt = dataType + oid = dt.OID // Preserve assumed OID in case we are recursively called below. + } + } + + if dt != nil { + if plan := dt.Codec.PlanEncode(m, oid, format, value); plan != nil { + return plan + } + } + + for _, f := range m.TryWrapEncodePlanFuncs { + if wrapperPlan, nextValue, ok := f(value); ok { + if nextPlan := m.planEncodeDepth(oid, format, nextValue, depth+1); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + + if _, ok := value.(driver.Valuer); ok { + return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format} + } + + return nil +} + +type encodePlanStringToAnyTextFormat struct{} + +func (encodePlanStringToAnyTextFormat) Encode(value any, buf []byte) (newBuf []byte, err error) { + s := value.(string) + return append(buf, s...), nil +} + +type encodePlanTextValuerToAnyTextFormat struct{} + +func (encodePlanTextValuerToAnyTextFormat) Encode(value any, buf []byte) (newBuf []byte, err error) { + t, err := value.(TextValuer).TextValue() + if err != nil { + return nil, err + } + if !t.Valid { + return nil, nil + } + + return append(buf, t.String...), nil +} + +type encodePlanDriverValuer struct { + m *Map + oid uint32 + formatCode int16 +} + +func (plan *encodePlanDriverValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + dv := value.(driver.Valuer) + if dv == nil { + return nil, nil + } + v, err := dv.Value() + if err != nil { + return nil, err + } + if v == nil { + return nil, nil + } + + newBuf, err = plan.m.Encode(plan.oid, plan.formatCode, v, buf) + if err == nil { + return newBuf, nil + } + + s, ok := v.(string) + if !ok { + return nil, err + } + + var scannedValue any + scanErr := plan.m.Scan(plan.oid, TextFormatCode, []byte(s), &scannedValue) + if scanErr != nil { + return nil, err + } + + // Prevent infinite loop. We can't encode this. See https://github.com/jackc/pgx/issues/1331. + if reflect.TypeOf(value) == reflect.TypeOf(scannedValue) { + return nil, fmt.Errorf("tried to encode %v via encoding to text and scanning but failed due to receiving same type back", value) + } + + var err2 error + newBuf, err2 = plan.m.Encode(plan.oid, BinaryFormatCode, scannedValue, buf) + if err2 != nil { + return nil, err + } + + return newBuf, nil +} + +// TryWrapEncodePlanFunc is a function that tries to create a wrapper plan for value. If successful it returns a plan +// that will convert the value passed to Encode and then call the next plan. nextValue is value as it will be converted +// by plan. It must be used to find another suitable EncodePlan. When it is found SetNext must be called on plan for it +// to be usabled. ok indicates if a suitable wrapper was found. +type TryWrapEncodePlanFunc func(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) + +type derefPointerEncodePlan struct { + next EncodePlan +} + +func (plan *derefPointerEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *derefPointerEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + ptr := reflect.ValueOf(value) + + if ptr.IsNil() { + return nil, nil + } + + return plan.next.Encode(ptr.Elem().Interface(), buf) +} + +// TryWrapDerefPointerEncodePlan tries to dereference a pointer. e.g. If value was of type *string then a wrapper plan +// would be returned that dereferences the value. +func TryWrapDerefPointerEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + if _, ok := value.(driver.Valuer); ok { + return nil, nil, false + } + + if valueType := reflect.TypeOf(value); valueType != nil && valueType.Kind() == reflect.Ptr { + return &derefPointerEncodePlan{}, reflect.New(valueType.Elem()).Elem().Interface(), true + } + + return nil, nil, false +} + +var kindToTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ + reflect.Int: reflect.TypeOf(int(0)), + reflect.Int8: reflect.TypeOf(int8(0)), + reflect.Int16: reflect.TypeOf(int16(0)), + reflect.Int32: reflect.TypeOf(int32(0)), + reflect.Int64: reflect.TypeOf(int64(0)), + reflect.Uint: reflect.TypeOf(uint(0)), + reflect.Uint8: reflect.TypeOf(uint8(0)), + reflect.Uint16: reflect.TypeOf(uint16(0)), + reflect.Uint32: reflect.TypeOf(uint32(0)), + reflect.Uint64: reflect.TypeOf(uint64(0)), + reflect.Float32: reflect.TypeOf(float32(0)), + reflect.Float64: reflect.TypeOf(float64(0)), + reflect.String: reflect.TypeOf(""), + reflect.Bool: reflect.TypeOf(false), +} + +var byteSliceType = reflect.TypeOf([]byte{}) + +type underlyingTypeEncodePlan struct { + nextValueType reflect.Type + next EncodePlan +} + +func (plan *underlyingTypeEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *underlyingTypeEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(reflect.ValueOf(value).Convert(plan.nextValueType).Interface(), buf) +} + +// TryWrapFindUnderlyingTypeEncodePlan tries to convert to a Go builtin type. e.g. If value was of type MyString and +// MyString was defined as a string then a wrapper plan would be returned that converts MyString to string. +func TryWrapFindUnderlyingTypeEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + if value == nil { + return nil, nil, false + } + + if _, ok := value.(driver.Valuer); ok { + return nil, nil, false + } + + if _, ok := value.(SkipUnderlyingTypePlanner); ok { + return nil, nil, false + } + + refValue := reflect.ValueOf(value) + + nextValueType := kindToTypes[refValue.Kind()] + if nextValueType != nil && refValue.Type() != nextValueType { + return &underlyingTypeEncodePlan{nextValueType: nextValueType}, refValue.Convert(nextValueType).Interface(), true + } + + // []byte is a special case. It is a slice but we treat it as a scalar type. In the case of a named type like + // json.RawMessage which is defined as []byte the underlying type should be considered as []byte. But any other slice + // does not have a special underlying type. + // + // https://github.com/jackc/pgx/issues/1763 + if refValue.Type() != byteSliceType && refValue.Type().AssignableTo(byteSliceType) { + return &underlyingTypeEncodePlan{nextValueType: byteSliceType}, refValue.Convert(byteSliceType).Interface(), true + } + + // Get underlying type of any array. + // https://github.com/jackc/pgx/issues/2107 + if refValue.Kind() == reflect.Array { + underlyingArrayType := reflect.ArrayOf(refValue.Len(), refValue.Type().Elem()) + if refValue.Type() != underlyingArrayType { + return &underlyingTypeEncodePlan{nextValueType: underlyingArrayType}, refValue.Convert(underlyingArrayType).Interface(), true + } + } + + return nil, nil, false +} + +type WrappedEncodePlanNextSetter interface { + SetNext(EncodePlan) + EncodePlan +} + +// TryWrapBuiltinTypeEncodePlan tries to wrap a builtin type with a wrapper that provides additional methods. e.g. If +// value was of type int32 then a wrapper plan would be returned that converts value to a type that implements +// Int64Valuer. +func TryWrapBuiltinTypeEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + if _, ok := value.(driver.Valuer); ok { + return nil, nil, false + } + + switch value := value.(type) { + case int8: + return &wrapInt8EncodePlan{}, int8Wrapper(value), true + case int16: + return &wrapInt16EncodePlan{}, int16Wrapper(value), true + case int32: + return &wrapInt32EncodePlan{}, int32Wrapper(value), true + case int64: + return &wrapInt64EncodePlan{}, int64Wrapper(value), true + case int: + return &wrapIntEncodePlan{}, intWrapper(value), true + case uint8: + return &wrapUint8EncodePlan{}, uint8Wrapper(value), true + case uint16: + return &wrapUint16EncodePlan{}, uint16Wrapper(value), true + case uint32: + return &wrapUint32EncodePlan{}, uint32Wrapper(value), true + case uint64: + return &wrapUint64EncodePlan{}, uint64Wrapper(value), true + case uint: + return &wrapUintEncodePlan{}, uintWrapper(value), true + case float32: + return &wrapFloat32EncodePlan{}, float32Wrapper(value), true + case float64: + return &wrapFloat64EncodePlan{}, float64Wrapper(value), true + case string: + return &wrapStringEncodePlan{}, stringWrapper(value), true + case time.Time: + return &wrapTimeEncodePlan{}, timeWrapper(value), true + case time.Duration: + return &wrapDurationEncodePlan{}, durationWrapper(value), true + case net.IPNet: + return &wrapNetIPNetEncodePlan{}, netIPNetWrapper(value), true + case net.IP: + return &wrapNetIPEncodePlan{}, netIPWrapper(value), true + case netip.Prefix: + return &wrapNetipPrefixEncodePlan{}, netipPrefixWrapper(value), true + case netip.Addr: + return &wrapNetipAddrEncodePlan{}, netipAddrWrapper(value), true + case map[string]*string: + return &wrapMapStringToPointerStringEncodePlan{}, mapStringToPointerStringWrapper(value), true + case map[string]string: + return &wrapMapStringToStringEncodePlan{}, mapStringToStringWrapper(value), true + case [16]byte: + return &wrapByte16EncodePlan{}, byte16Wrapper(value), true + case []byte: + return &wrapByteSliceEncodePlan{}, byteSliceWrapper(value), true + case fmt.Stringer: + return &wrapFmtStringerEncodePlan{}, fmtStringerWrapper{value}, true + } + + return nil, nil, false +} + +type wrapInt8EncodePlan struct { + next EncodePlan +} + +func (plan *wrapInt8EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapInt8EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(int8Wrapper(value.(int8)), buf) +} + +type wrapInt16EncodePlan struct { + next EncodePlan +} + +func (plan *wrapInt16EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapInt16EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(int16Wrapper(value.(int16)), buf) +} + +type wrapInt32EncodePlan struct { + next EncodePlan +} + +func (plan *wrapInt32EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapInt32EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(int32Wrapper(value.(int32)), buf) +} + +type wrapInt64EncodePlan struct { + next EncodePlan +} + +func (plan *wrapInt64EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapInt64EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(int64Wrapper(value.(int64)), buf) +} + +type wrapIntEncodePlan struct { + next EncodePlan +} + +func (plan *wrapIntEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapIntEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(intWrapper(value.(int)), buf) +} + +type wrapUint8EncodePlan struct { + next EncodePlan +} + +func (plan *wrapUint8EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUint8EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uint8Wrapper(value.(uint8)), buf) +} + +type wrapUint16EncodePlan struct { + next EncodePlan +} + +func (plan *wrapUint16EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUint16EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uint16Wrapper(value.(uint16)), buf) +} + +type wrapUint32EncodePlan struct { + next EncodePlan +} + +func (plan *wrapUint32EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUint32EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uint32Wrapper(value.(uint32)), buf) +} + +type wrapUint64EncodePlan struct { + next EncodePlan +} + +func (plan *wrapUint64EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUint64EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uint64Wrapper(value.(uint64)), buf) +} + +type wrapUintEncodePlan struct { + next EncodePlan +} + +func (plan *wrapUintEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUintEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uintWrapper(value.(uint)), buf) +} + +type wrapFloat32EncodePlan struct { + next EncodePlan +} + +func (plan *wrapFloat32EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapFloat32EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(float32Wrapper(value.(float32)), buf) +} + +type wrapFloat64EncodePlan struct { + next EncodePlan +} + +func (plan *wrapFloat64EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapFloat64EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(float64Wrapper(value.(float64)), buf) +} + +type wrapStringEncodePlan struct { + next EncodePlan +} + +func (plan *wrapStringEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapStringEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(stringWrapper(value.(string)), buf) +} + +type wrapTimeEncodePlan struct { + next EncodePlan +} + +func (plan *wrapTimeEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapTimeEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(timeWrapper(value.(time.Time)), buf) +} + +type wrapDurationEncodePlan struct { + next EncodePlan +} + +func (plan *wrapDurationEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapDurationEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(durationWrapper(value.(time.Duration)), buf) +} + +type wrapNetIPNetEncodePlan struct { + next EncodePlan +} + +func (plan *wrapNetIPNetEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapNetIPNetEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(netIPNetWrapper(value.(net.IPNet)), buf) +} + +type wrapNetIPEncodePlan struct { + next EncodePlan +} + +func (plan *wrapNetIPEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapNetIPEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(netIPWrapper(value.(net.IP)), buf) +} + +type wrapNetipPrefixEncodePlan struct { + next EncodePlan +} + +func (plan *wrapNetipPrefixEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapNetipPrefixEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(netipPrefixWrapper(value.(netip.Prefix)), buf) +} + +type wrapNetipAddrEncodePlan struct { + next EncodePlan +} + +func (plan *wrapNetipAddrEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapNetipAddrEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(netipAddrWrapper(value.(netip.Addr)), buf) +} + +type wrapMapStringToPointerStringEncodePlan struct { + next EncodePlan +} + +func (plan *wrapMapStringToPointerStringEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapMapStringToPointerStringEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(mapStringToPointerStringWrapper(value.(map[string]*string)), buf) +} + +type wrapMapStringToStringEncodePlan struct { + next EncodePlan +} + +func (plan *wrapMapStringToStringEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapMapStringToStringEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(mapStringToStringWrapper(value.(map[string]string)), buf) +} + +type wrapByte16EncodePlan struct { + next EncodePlan +} + +func (plan *wrapByte16EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapByte16EncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(byte16Wrapper(value.([16]byte)), buf) +} + +type wrapByteSliceEncodePlan struct { + next EncodePlan +} + +func (plan *wrapByteSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapByteSliceEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(byteSliceWrapper(value.([]byte)), buf) +} + +type wrapFmtStringerEncodePlan struct { + next EncodePlan +} + +func (plan *wrapFmtStringerEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapFmtStringerEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(fmtStringerWrapper{value.(fmt.Stringer)}, buf) +} + +// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. +func TryWrapStructEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + if _, ok := value.(driver.Valuer); ok { + return nil, nil, false + } + + if valueType := reflect.TypeOf(value); valueType != nil && valueType.Kind() == reflect.Struct { + exportedFields := getExportedFieldValues(reflect.ValueOf(value)) + if len(exportedFields) == 0 { + return nil, nil, false + } + + w := structWrapper{ + s: value, + exportedFields: exportedFields, + } + return &wrapAnyStructEncodePlan{}, w, true + } + + return nil, nil, false +} + +type wrapAnyStructEncodePlan struct { + next EncodePlan +} + +func (plan *wrapAnyStructEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapAnyStructEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + w := structWrapper{ + s: value, + exportedFields: getExportedFieldValues(reflect.ValueOf(value)), + } + + return plan.next.Encode(w, buf) +} + +func getExportedFieldValues(structValue reflect.Value) []reflect.Value { + structType := structValue.Type() + exportedFields := make([]reflect.Value, 0, structValue.NumField()) + for i := 0; i < structType.NumField(); i++ { + sf := structType.Field(i) + if sf.IsExported() { + exportedFields = append(exportedFields, structValue.Field(i)) + } + } + + return exportedFields +} + +func TryWrapSliceEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + if _, ok := value.(driver.Valuer); ok { + return nil, nil, false + } + + // Avoid using reflect path for common types. + switch value := value.(type) { + case []int16: + return &wrapSliceEncodePlan[int16]{}, (FlatArray[int16])(value), true + case []int32: + return &wrapSliceEncodePlan[int32]{}, (FlatArray[int32])(value), true + case []int64: + return &wrapSliceEncodePlan[int64]{}, (FlatArray[int64])(value), true + case []float32: + return &wrapSliceEncodePlan[float32]{}, (FlatArray[float32])(value), true + case []float64: + return &wrapSliceEncodePlan[float64]{}, (FlatArray[float64])(value), true + case []string: + return &wrapSliceEncodePlan[string]{}, (FlatArray[string])(value), true + case []time.Time: + return &wrapSliceEncodePlan[time.Time]{}, (FlatArray[time.Time])(value), true + } + + if valueType := reflect.TypeOf(value); valueType != nil && valueType.Kind() == reflect.Slice { + w := anySliceArrayReflect{ + slice: reflect.ValueOf(value), + } + return &wrapSliceEncodeReflectPlan{}, w, true + } + + return nil, nil, false +} + +type wrapSliceEncodePlan[T any] struct { + next EncodePlan +} + +func (plan *wrapSliceEncodePlan[T]) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapSliceEncodePlan[T]) Encode(value any, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode((FlatArray[T])(value.([]T)), buf) +} + +type wrapSliceEncodeReflectPlan struct { + next EncodePlan +} + +func (plan *wrapSliceEncodeReflectPlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapSliceEncodeReflectPlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + w := anySliceArrayReflect{ + slice: reflect.ValueOf(value), + } + + return plan.next.Encode(w, buf) +} + +func TryWrapMultiDimSliceEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + if _, ok := value.(driver.Valuer); ok { + return nil, nil, false + } + + sliceValue := reflect.ValueOf(value) + if sliceValue.Kind() == reflect.Slice { + valueElemType := sliceValue.Type().Elem() + + if valueElemType.Kind() == reflect.Slice { + if !isRagged(sliceValue) { + w := anyMultiDimSliceArray{ + slice: reflect.ValueOf(value), + } + return &wrapMultiDimSliceEncodePlan{}, &w, true + } + } + } + + return nil, nil, false +} + +type wrapMultiDimSliceEncodePlan struct { + next EncodePlan +} + +func (plan *wrapMultiDimSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapMultiDimSliceEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + w := anyMultiDimSliceArray{ + slice: reflect.ValueOf(value), + } + + return plan.next.Encode(&w, buf) +} + +func TryWrapArrayEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + if _, ok := value.(driver.Valuer); ok { + return nil, nil, false + } + + if valueType := reflect.TypeOf(value); valueType != nil && valueType.Kind() == reflect.Array { + w := anyArrayArrayReflect{ + array: reflect.ValueOf(value), + } + return &wrapArrayEncodeReflectPlan{}, w, true + } + + return nil, nil, false +} + +type wrapArrayEncodeReflectPlan struct { + next EncodePlan +} + +func (plan *wrapArrayEncodeReflectPlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapArrayEncodeReflectPlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + w := anyArrayArrayReflect{ + array: reflect.ValueOf(value), + } + + return plan.next.Encode(w, buf) +} + +func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error) error { + var format string + switch formatCode { + case TextFormatCode: + format = "text" + case BinaryFormatCode: + format = "binary" + default: + format = fmt.Sprintf("unknown (%d)", formatCode) + } + + var dataTypeName string + if t, ok := m.TypeForOID(oid); ok { + dataTypeName = t.Name + } else { + dataTypeName = "unknown type" + } + + return fmt.Errorf("unable to encode %#v into %s format for %s (OID %d): %w", value, format, dataTypeName, oid, err) +} + +// Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return +// (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data +// written. +func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBuf []byte, err error) { + if isNil, callNilDriverValuer := isNilDriverValuer(value); isNil { + if callNilDriverValuer { + newBuf, err = (&encodePlanDriverValuer{m: m, oid: oid, formatCode: formatCode}).Encode(value, buf) + if err != nil { + return nil, newEncodeError(value, m, oid, formatCode, err) + } + + return newBuf, nil + } else { + return nil, nil + } + } + + plan := m.PlanEncode(oid, formatCode, value) + if plan == nil { + return nil, newEncodeError(value, m, oid, formatCode, errors.New("cannot find encode plan")) + } + + newBuf, err = plan.Encode(value, buf) + if err != nil { + return nil, newEncodeError(value, m, oid, formatCode, err) + } + + return newBuf, nil +} + +// SQLScanner returns a database/sql.Scanner for v. This is necessary for types like Array[T] and Range[T] where the +// type needs assistance from Map to implement the sql.Scanner interface. It is not necessary for types like Box that +// implement sql.Scanner directly. +// +// This uses the type of v to look up the PostgreSQL OID that v presumably came from. This means v must be registered +// with m by calling RegisterDefaultPgType. +// +// As of Go 1.26, this should be unnecessary. +func (m *Map) SQLScanner(v any) sql.Scanner { + if s, ok := v.(sql.Scanner); ok { + return s + } + + return &sqlScannerWrapper{m: m, v: v} +} + +type sqlScannerWrapper struct { + m *Map + v any +} + +func (w *sqlScannerWrapper) Scan(src any) error { + t, ok := w.m.TypeForValue(w.v) + if !ok { + return fmt.Errorf("cannot convert to sql.Scanner: cannot find registered type for %T", w.v) + } + + var bufSrc []byte + if src != nil { + switch src := src.(type) { + case string: + bufSrc = []byte(src) + case []byte: + bufSrc = src + default: + bufSrc = []byte(fmt.Sprint(bufSrc)) + } + } + + return w.m.Scan(t.OID, TextFormatCode, bufSrc, w.v) +} + +var valuerReflectType = reflect.TypeFor[driver.Valuer]() + +// isNilDriverValuer returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement +// driver.Valuer if it is only implemented by T. +func isNilDriverValuer(value any) (isNil, callNilDriverValuer bool) { + if value == nil { + return true, false + } + + refVal := reflect.ValueOf(value) + kind := refVal.Kind() + switch kind { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: + if !refVal.IsNil() { + return false, false + } + + if _, ok := value.(driver.Valuer); ok { + if kind == reflect.Ptr { + // The type assertion will succeed if driver.Valuer is implemented on T or *T. Check if it is implemented on *T + // by checking if it is not implemented on *T. + return true, !refVal.Type().Elem().Implements(valuerReflectType) + } else { + return true, true + } + } + + return true, false + default: + return false, false } } diff --git a/pgtype/pgtype_default.go b/pgtype/pgtype_default.go new file mode 100644 index 000000000..5648d89bf --- /dev/null +++ b/pgtype/pgtype_default.go @@ -0,0 +1,248 @@ +package pgtype + +import ( + "encoding/json" + "encoding/xml" + "net" + "net/netip" + "reflect" + "sync" + "time" +) + +var ( + // defaultMap contains default mappings between PostgreSQL server types and Go type handling logic. + defaultMap *Map + defaultMapInitOnce = sync.Once{} +) + +func initDefaultMap() { + defaultMap = &Map{ + oidToType: make(map[uint32]*Type), + nameToType: make(map[string]*Type), + reflectTypeToName: make(map[reflect.Type]string), + oidToFormatCode: make(map[uint32]int16), + + memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan), + + TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ + TryWrapDerefPointerEncodePlan, + TryWrapBuiltinTypeEncodePlan, + TryWrapFindUnderlyingTypeEncodePlan, + TryWrapStructEncodePlan, + TryWrapSliceEncodePlan, + TryWrapMultiDimSliceEncodePlan, + TryWrapArrayEncodePlan, + }, + + TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{ + TryPointerPointerScanPlan, + TryWrapBuiltinTypeScanPlan, + TryFindUnderlyingTypeScanPlan, + TryWrapStructScanPlan, + TryWrapPtrSliceScanPlan, + TryWrapPtrMultiDimSliceScanPlan, + TryWrapPtrArrayScanPlan, + }, + } + + // Base types + defaultMap.RegisterType(&Type{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) + defaultMap.RegisterType(&Type{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) + defaultMap.RegisterType(&Type{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) + defaultMap.RegisterType(&Type{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) + defaultMap.RegisterType(&Type{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) + defaultMap.RegisterType(&Type{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) + defaultMap.RegisterType(&Type{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) + defaultMap.RegisterType(&Type{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) + defaultMap.RegisterType(&Type{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) + defaultMap.RegisterType(&Type{Name: "date", OID: DateOID, Codec: DateCodec{}}) + defaultMap.RegisterType(&Type{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) + defaultMap.RegisterType(&Type{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) + defaultMap.RegisterType(&Type{Name: "inet", OID: InetOID, Codec: InetCodec{}}) + defaultMap.RegisterType(&Type{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) + defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) + defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) + defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) + defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: &JSONCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}}) + defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: &JSONBCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}}) + defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) + defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) + defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) + defaultMap.RegisterType(&Type{Name: "macaddr8", OID: Macaddr8OID, Codec: MacaddrCodec{}}) + defaultMap.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) + defaultMap.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) + defaultMap.RegisterType(&Type{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) + defaultMap.RegisterType(&Type{Name: "path", OID: PathOID, Codec: PathCodec{}}) + defaultMap.RegisterType(&Type{Name: "point", OID: PointOID, Codec: PointCodec{}}) + defaultMap.RegisterType(&Type{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) + defaultMap.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) + defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) + defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) + defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: &TimestampCodec{}}) + defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: &TimestamptzCodec{}}) + defaultMap.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) + defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) + defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) + defaultMap.RegisterType(&Type{Name: "xid8", OID: XID8OID, Codec: Uint64Codec{}}) + defaultMap.RegisterType(&Type{Name: "xml", OID: XMLOID, Codec: &XMLCodec{ + Marshal: xml.Marshal, + // xml.Unmarshal does not support unmarshalling into *any. However, XMLCodec.DecodeValue calls Unmarshal with a + // *any. Wrap xml.Marshal with a function that copies the data into a new byte slice in this case. Not implementing + // directly in XMLCodec.DecodeValue to allow for the unlikely possibility that someone uses an alternative XML + // unmarshaler that does support unmarshalling into *any. + // + // https://github.com/jackc/pgx/issues/2227 + // https://github.com/jackc/pgx/pull/2228 + Unmarshal: func(data []byte, v any) error { + if v, ok := v.(*any); ok { + dstBuf := make([]byte, len(data)) + copy(dstBuf, data) + *v = dstBuf + return nil + } + return xml.Unmarshal(data, v) + }, + }}) + + // Range types + defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}}) + defaultMap.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[Int4OID]}}) + defaultMap.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[Int8OID]}}) + defaultMap.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[NumericOID]}}) + defaultMap.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[TimestampOID]}}) + defaultMap.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}}) + + // Multirange types + defaultMap.RegisterType(&Type{Name: "datemultirange", OID: DatemultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[DaterangeOID]}}) + defaultMap.RegisterType(&Type{Name: "int4multirange", OID: Int4multirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[Int4rangeOID]}}) + defaultMap.RegisterType(&Type{Name: "int8multirange", OID: Int8multirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[Int8rangeOID]}}) + defaultMap.RegisterType(&Type{Name: "nummultirange", OID: NummultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[NumrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "tsmultirange", OID: TsmultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[TsrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "tstzmultirange", OID: TstzmultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[TstzrangeOID]}}) + + // Array types + defaultMap.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[ACLItemOID]}}) + defaultMap.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BitOID]}}) + defaultMap.RegisterType(&Type{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BoolOID]}}) + defaultMap.RegisterType(&Type{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BoxOID]}}) + defaultMap.RegisterType(&Type{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BPCharOID]}}) + defaultMap.RegisterType(&Type{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[ByteaOID]}}) + defaultMap.RegisterType(&Type{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[QCharOID]}}) + defaultMap.RegisterType(&Type{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CIDROID]}}) + defaultMap.RegisterType(&Type{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CircleOID]}}) + defaultMap.RegisterType(&Type{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[DateOID]}}) + defaultMap.RegisterType(&Type{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[DaterangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Float4OID]}}) + defaultMap.RegisterType(&Type{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Float8OID]}}) + defaultMap.RegisterType(&Type{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[InetOID]}}) + defaultMap.RegisterType(&Type{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int2OID]}}) + defaultMap.RegisterType(&Type{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int4OID]}}) + defaultMap.RegisterType(&Type{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int4rangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int8OID]}}) + defaultMap.RegisterType(&Type{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int8rangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[IntervalOID]}}) + defaultMap.RegisterType(&Type{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONOID]}}) + defaultMap.RegisterType(&Type{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONBOID]}}) + defaultMap.RegisterType(&Type{Name: "_jsonpath", OID: JSONPathArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONPathOID]}}) + defaultMap.RegisterType(&Type{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[LineOID]}}) + defaultMap.RegisterType(&Type{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[LsegOID]}}) + defaultMap.RegisterType(&Type{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[MacaddrOID]}}) + defaultMap.RegisterType(&Type{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NameOID]}}) + defaultMap.RegisterType(&Type{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NumericOID]}}) + defaultMap.RegisterType(&Type{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NumrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[OIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PathOID]}}) + defaultMap.RegisterType(&Type{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PointOID]}}) + defaultMap.RegisterType(&Type{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PolygonOID]}}) + defaultMap.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[RecordOID]}}) + defaultMap.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TextOID]}}) + defaultMap.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimeOID]}}) + defaultMap.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestampOID]}}) + defaultMap.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}}) + defaultMap.RegisterType(&Type{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TsrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TstzrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[UUIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarbitOID]}}) + defaultMap.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarcharOID]}}) + defaultMap.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_xid8", OID: XID8ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XID8OID]}}) + defaultMap.RegisterType(&Type{Name: "_xml", OID: XMLArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XMLOID]}}) + + // Integer types that directly map to a PostgreSQL type + registerDefaultPgTypeVariants[int16](defaultMap, "int2") + registerDefaultPgTypeVariants[int32](defaultMap, "int4") + registerDefaultPgTypeVariants[int64](defaultMap, "int8") + + // Integer types that do not have a direct match to a PostgreSQL type + registerDefaultPgTypeVariants[int8](defaultMap, "int8") + registerDefaultPgTypeVariants[int](defaultMap, "int8") + registerDefaultPgTypeVariants[uint8](defaultMap, "int8") + registerDefaultPgTypeVariants[uint16](defaultMap, "int8") + registerDefaultPgTypeVariants[uint32](defaultMap, "int8") + registerDefaultPgTypeVariants[uint64](defaultMap, "numeric") + registerDefaultPgTypeVariants[uint](defaultMap, "numeric") + + registerDefaultPgTypeVariants[float32](defaultMap, "float4") + registerDefaultPgTypeVariants[float64](defaultMap, "float8") + + registerDefaultPgTypeVariants[bool](defaultMap, "bool") + registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz") + registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval") + registerDefaultPgTypeVariants[string](defaultMap, "text") + registerDefaultPgTypeVariants[json.RawMessage](defaultMap, "json") + registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea") + + registerDefaultPgTypeVariants[net.IP](defaultMap, "inet") + registerDefaultPgTypeVariants[net.IPNet](defaultMap, "cidr") + registerDefaultPgTypeVariants[netip.Addr](defaultMap, "inet") + registerDefaultPgTypeVariants[netip.Prefix](defaultMap, "cidr") + + // pgtype provided structs + registerDefaultPgTypeVariants[Bits](defaultMap, "varbit") + registerDefaultPgTypeVariants[Bool](defaultMap, "bool") + registerDefaultPgTypeVariants[Box](defaultMap, "box") + registerDefaultPgTypeVariants[Circle](defaultMap, "circle") + registerDefaultPgTypeVariants[Date](defaultMap, "date") + registerDefaultPgTypeVariants[Range[Date]](defaultMap, "daterange") + registerDefaultPgTypeVariants[Multirange[Range[Date]]](defaultMap, "datemultirange") + registerDefaultPgTypeVariants[Float4](defaultMap, "float4") + registerDefaultPgTypeVariants[Float8](defaultMap, "float8") + registerDefaultPgTypeVariants[Range[Float8]](defaultMap, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange. + registerDefaultPgTypeVariants[Multirange[Range[Float8]]](defaultMap, "nummultirange") // There is no PostgreSQL builtin float8multirange so map it to nummultirange. + registerDefaultPgTypeVariants[Int2](defaultMap, "int2") + registerDefaultPgTypeVariants[Int4](defaultMap, "int4") + registerDefaultPgTypeVariants[Range[Int4]](defaultMap, "int4range") + registerDefaultPgTypeVariants[Multirange[Range[Int4]]](defaultMap, "int4multirange") + registerDefaultPgTypeVariants[Int8](defaultMap, "int8") + registerDefaultPgTypeVariants[Range[Int8]](defaultMap, "int8range") + registerDefaultPgTypeVariants[Multirange[Range[Int8]]](defaultMap, "int8multirange") + registerDefaultPgTypeVariants[Interval](defaultMap, "interval") + registerDefaultPgTypeVariants[Line](defaultMap, "line") + registerDefaultPgTypeVariants[Lseg](defaultMap, "lseg") + registerDefaultPgTypeVariants[Numeric](defaultMap, "numeric") + registerDefaultPgTypeVariants[Range[Numeric]](defaultMap, "numrange") + registerDefaultPgTypeVariants[Multirange[Range[Numeric]]](defaultMap, "nummultirange") + registerDefaultPgTypeVariants[Path](defaultMap, "path") + registerDefaultPgTypeVariants[Point](defaultMap, "point") + registerDefaultPgTypeVariants[Polygon](defaultMap, "polygon") + registerDefaultPgTypeVariants[TID](defaultMap, "tid") + registerDefaultPgTypeVariants[Text](defaultMap, "text") + registerDefaultPgTypeVariants[Time](defaultMap, "time") + registerDefaultPgTypeVariants[Timestamp](defaultMap, "timestamp") + registerDefaultPgTypeVariants[Timestamptz](defaultMap, "timestamptz") + registerDefaultPgTypeVariants[Range[Timestamp]](defaultMap, "tsrange") + registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](defaultMap, "tsmultirange") + registerDefaultPgTypeVariants[Range[Timestamptz]](defaultMap, "tstzrange") + registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](defaultMap, "tstzmultirange") + registerDefaultPgTypeVariants[UUID](defaultMap, "uuid") + + defaultMap.buildReflectTypeToType() +} diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index f7e743b29..510b0c62f 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -1,31 +1,79 @@ package pgtype_test import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" "net" + "os" + "reflect" + "regexp" + "strconv" "testing" - _ "github.com/jackc/pgx/stdlib" - _ "github.com/lib/pq" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + _ "github.com/jackc/pgx/v5/stdlib" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +var defaultConnTestRunner pgxtest.ConnTestRunner + +func init() { + defaultConnTestRunner = pgxtest.DefaultConnTestRunner() + defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + return config + } +} + // Test for renamed types -type _string string -type _bool bool -type _int8 int8 -type _int16 int16 -type _int16Slice []int16 -type _int32Slice []int32 -type _int64Slice []int64 -type _float32Slice []float32 -type _float64Slice []float64 -type _byteSlice []byte - -func mustParseCIDR(t testing.TB, s string) *net.IPNet { - _, ipnet, err := net.ParseCIDR(s) - if err != nil { - t.Fatal(err) +type ( + _string string + _bool bool + _uint8 uint8 + _int8 int8 + _int16 int16 + _int16Slice []int16 + _int32Slice []int32 + _int64Slice []int64 + _float32Slice []float32 + _float64Slice []float64 + _byteSlice []byte +) + +// unregisteredOID represents an actual type that is not registered. Cannot use 0 because that represents that the type +// is not known (e.g. when using the simple protocol). +const unregisteredOID = uint32(1) + +func mustParseInet(t testing.TB, s string) *net.IPNet { + ip, ipnet, err := net.ParseCIDR(s) + if err == nil { + if ipv4 := ip.To4(); ipv4 != nil { + ipnet.IP = ipv4 + } else { + ipnet.IP = ip + } + return ipnet } + // May be bare IP address. + // + ip = net.ParseIP(s) + if ip == nil { + t.Fatal(errors.New("unable to parse inet address")) + } + ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} + if ipv4 := ip.To4(); ipv4 != nil { + ipnet.IP = ipv4 + ipnet.Mask = net.CIDRMask(32, 32) + } return ipnet } @@ -37,3 +85,590 @@ func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { return addr } + +func skipCockroachDB(t testing.TB, msg string) { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + defer conn.Close(context.Background()) + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip(msg) + } +} + +func skipPostgreSQLVersionLessThan(t testing.TB, minVersion int64) { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + defer conn.Close(context.Background()) + + serverVersionStr := conn.PgConn().ParameterStatus("server_version") + serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) + // if not PostgreSQL do nothing + if serverVersionStr == "" { + return + } + + serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64) + require.NoError(t, err) + + if serverVersion < minVersion { + t.Skipf("Test requires PostgreSQL v%d+", minVersion) + } +} + +// sqlScannerFunc lets an arbitrary function be used as a sql.Scanner. +type sqlScannerFunc func(src any) error + +func (f sqlScannerFunc) Scan(src any) error { + return f(src) +} + +// driverValuerFunc lets an arbitrary function be used as a driver.Valuer. +type driverValuerFunc func() (driver.Value, error) + +func (f driverValuerFunc) Value() (driver.Value, error) { + return f() +} + +func TestMapScanNilIsNoOp(t *testing.T) { + m := pgtype.NewMap() + + err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), nil) + assert.NoError(t, err) +} + +func TestMapScanTextFormatInterfacePtr(t *testing.T) { + m := pgtype.NewMap() + var got any + err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), &got) + require.NoError(t, err) + assert.Equal(t, "foo", got) +} + +func TestMapScanTextFormatNonByteaIntoByteSlice(t *testing.T) { + m := pgtype.NewMap() + var got []byte + err := m.Scan(pgtype.JSONBOID, pgx.TextFormatCode, []byte("{}"), &got) + require.NoError(t, err) + assert.Equal(t, []byte("{}"), got) +} + +func TestMapScanBinaryFormatInterfacePtr(t *testing.T) { + m := pgtype.NewMap() + var got any + err := m.Scan(pgtype.TextOID, pgx.BinaryFormatCode, []byte("foo"), &got) + require.NoError(t, err) + assert.Equal(t, "foo", got) +} + +func TestMapScanUnknownOIDToPtrToAny(t *testing.T) { + unknownOID := uint32(999999) + srcBuf := []byte("foo") + m := pgtype.NewMap() + + var a any + err := m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &a) + assert.NoError(t, err) + assert.Equal(t, "foo", a) + + err = m.Scan(unknownOID, pgx.BinaryFormatCode, srcBuf, &a) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), a) +} + +func TestMapScanUnknownOIDToStringsAndBytes(t *testing.T) { + unknownOID := uint32(999999) + srcBuf := []byte("foo") + m := pgtype.NewMap() + + var s string + err := m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &s) + assert.NoError(t, err) + assert.Equal(t, "foo", s) + + var rs _string + err = m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rs) + assert.NoError(t, err) + assert.Equal(t, "foo", string(rs)) + + var b []byte + err = m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &b) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), b) + + var rb _byteSlice + err = m.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rb) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), []byte(rb)) +} + +func TestMapScanPointerToNilStructDoesNotCrash(t *testing.T) { + m := pgtype.NewMap() + + type myStruct struct{} + var p *myStruct + err := m.Scan(0, pgx.TextFormatCode, []byte("(foo,bar)"), &p) + require.NotNil(t, err) +} + +func TestMapScanUnknownOIDTextFormat(t *testing.T) { + m := pgtype.NewMap() + + var n int32 + err := m.Scan(0, pgx.TextFormatCode, []byte("123"), &n) + assert.NoError(t, err) + assert.EqualValues(t, 123, n) +} + +func TestMapScanUnknownOIDIntoSQLScanner(t *testing.T) { + m := pgtype.NewMap() + + var s sql.NullString + err := m.Scan(0, pgx.TextFormatCode, []byte(nil), &s) + assert.NoError(t, err) + assert.Equal(t, "", s.String) + assert.False(t, s.Valid) +} + +type scannerString string + +func (ss *scannerString) Scan(v any) error { + *ss = scannerString("scanned") + return nil +} + +// https://github.com/jackc/pgtype/issues/197 +func TestMapScanUnregisteredOIDIntoRenamedStringSQLScanner(t *testing.T) { + m := pgtype.NewMap() + + var s scannerString + err := m.Scan(unregisteredOID, pgx.TextFormatCode, []byte(nil), &s) + assert.NoError(t, err) + assert.Equal(t, "scanned", string(s)) +} + +type pgCustomInt int64 + +func (ci *pgCustomInt) Scan(src interface{}) error { + *ci = pgCustomInt(src.(int64)) + return nil +} + +func TestScanPlanBinaryInt32ScanScanner(t *testing.T) { + m := pgtype.NewMap() + src := []byte{0, 42} + var v pgCustomInt + + plan := m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &v) + err := plan.Scan(src, &v) + require.NoError(t, err) + require.EqualValues(t, 42, v) + + ptr := new(pgCustomInt) + plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(src, &ptr) + require.NoError(t, err) + require.EqualValues(t, 42, *ptr) + + ptr = new(pgCustomInt) + err = plan.Scan(nil, &ptr) + require.NoError(t, err) + assert.Nil(t, ptr) + + ptr = nil + plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(src, &ptr) + require.NoError(t, err) + require.EqualValues(t, 42, *ptr) + + ptr = nil + plan = m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(nil, &ptr) + require.NoError(t, err) + assert.Nil(t, ptr) +} + +// Test for https://github.com/jackc/pgtype/issues/164 +func TestScanPlanInterface(t *testing.T) { + m := pgtype.NewMap() + src := []byte{0, 42} + var v interface{} + plan := m.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, v) + err := plan.Scan(src, v) + assert.Error(t, err) +} + +func TestPointerPointerStructScan(t *testing.T) { + m := pgtype.NewMap() + type composite struct { + ID int + } + + int4Type, _ := m.TypeForOID(pgtype.Int4OID) + pgt := &pgtype.Type{ + Codec: &pgtype.CompositeCodec{ + Fields: []pgtype.CompositeCodecField{ + { + Name: "id", + Type: int4Type, + }, + }, + }, + Name: "composite", + OID: 215333, + } + m.RegisterType(pgt) + + var c *composite + plan := m.PlanScan(pgt.OID, pgtype.TextFormatCode, &c) + err := plan.Scan([]byte("(1)"), &c) + require.NoError(t, err) + require.Equal(t, 1, c.ID) +} + +// https://github.com/jackc/pgx/issues/1263 +func TestMapScanPtrToPtrToSlice(t *testing.T) { + m := pgtype.NewMap() + src := []byte("{foo,bar}") + var v *[]string + plan := m.PlanScan(pgtype.TextArrayOID, pgtype.TextFormatCode, &v) + err := plan.Scan(src, &v) + require.NoError(t, err) + require.Equal(t, []string{"foo", "bar"}, *v) +} + +func TestMapScanPtrToPtrToSliceOfStruct(t *testing.T) { + type Team struct { + TeamID int + Name string + } + + // Have to use binary format because text format doesn't include type information. + m := pgtype.NewMap() + src := []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x8, 0xc9, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x1e, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x17, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x19, 0x0, 0x0, 0x0, 0x6, 0x74, 0x65, 0x61, 0x6d, 0x20, 0x31, 0x0, 0x0, 0x0, 0x1e, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x17, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x19, 0x0, 0x0, 0x0, 0x6, 0x74, 0x65, 0x61, 0x6d, 0x20, 0x32} + var v *[]Team + plan := m.PlanScan(pgtype.RecordArrayOID, pgtype.BinaryFormatCode, &v) + err := plan.Scan(src, &v) + require.NoError(t, err) + require.Equal(t, []Team{{1, "team 1"}, {2, "team 2"}}, *v) +} + +type databaseValuerString string + +func (s databaseValuerString) Value() (driver.Value, error) { + return fmt.Sprintf("%d", len(s)), nil +} + +// https://github.com/jackc/pgx/issues/1319 +func TestMapEncodeTextFormatDatabaseValuerThatIsRenamedSimpleType(t *testing.T) { + m := pgtype.NewMap() + src := databaseValuerString("foo") + buf, err := m.Encode(pgtype.TextOID, pgtype.TextFormatCode, src, nil) + require.NoError(t, err) + require.Equal(t, "3", string(buf)) +} + +type databaseValuerFmtStringer string + +func (s databaseValuerFmtStringer) Value() (driver.Value, error) { + return nil, nil +} + +func (s databaseValuerFmtStringer) String() string { + return "foobar" +} + +// https://github.com/jackc/pgx/issues/1311 +func TestMapEncodeTextFormatDatabaseValuerThatIsFmtStringer(t *testing.T) { + m := pgtype.NewMap() + src := databaseValuerFmtStringer("") + buf, err := m.Encode(pgtype.TextOID, pgtype.TextFormatCode, src, nil) + require.NoError(t, err) + require.Nil(t, buf) +} + +type databaseValuerStringFormat struct { + n int32 +} + +func (v databaseValuerStringFormat) Value() (driver.Value, error) { + return fmt.Sprint(v.n), nil +} + +func TestMapEncodeBinaryFormatDatabaseValuerThatReturnsString(t *testing.T) { + m := pgtype.NewMap() + src := databaseValuerStringFormat{n: 42} + buf, err := m.Encode(pgtype.Int4OID, pgtype.BinaryFormatCode, src, nil) + require.NoError(t, err) + require.Equal(t, []byte{0, 0, 0, 42}, buf) +} + +// https://github.com/jackc/pgx/issues/1445 +func TestMapEncodeDatabaseValuerThatReturnsStringIntoUnregisteredTypeTextFormat(t *testing.T) { + m := pgtype.NewMap() + buf, err := m.Encode(unregisteredOID, pgtype.TextFormatCode, driverValuerFunc(func() (driver.Value, error) { return "foo", nil }), nil) + require.NoError(t, err) + require.Equal(t, []byte("foo"), buf) +} + +// https://github.com/jackc/pgx/issues/1445 +func TestMapEncodeDatabaseValuerThatReturnsByteSliceIntoUnregisteredTypeTextFormat(t *testing.T) { + m := pgtype.NewMap() + buf, err := m.Encode(unregisteredOID, pgtype.TextFormatCode, driverValuerFunc(func() (driver.Value, error) { return []byte{0, 1, 2, 3}, nil }), nil) + require.NoError(t, err) + require.Equal(t, []byte(`\x00010203`), buf) +} + +func TestMapEncodeStringIntoUnregisteredTypeTextFormat(t *testing.T) { + m := pgtype.NewMap() + buf, err := m.Encode(unregisteredOID, pgtype.TextFormatCode, "foo", nil) + require.NoError(t, err) + require.Equal(t, []byte("foo"), buf) +} + +func TestMapEncodeByteSliceIntoUnregisteredTypeTextFormat(t *testing.T) { + m := pgtype.NewMap() + buf, err := m.Encode(unregisteredOID, pgtype.TextFormatCode, []byte{0, 1, 2, 3}, nil) + require.NoError(t, err) + require.Equal(t, []byte(`\x00010203`), buf) +} + +// https://github.com/jackc/pgx/issues/1763 +func TestMapEncodeNamedTypeOfByteSliceIntoTextTextFormat(t *testing.T) { + m := pgtype.NewMap() + buf, err := m.Encode(pgtype.TextOID, pgtype.TextFormatCode, json.RawMessage(`{"foo": "bar"}`), nil) + require.NoError(t, err) + require.Equal(t, []byte(`{"foo": "bar"}`), buf) +} + +// https://github.com/jackc/pgx/issues/1326 +func TestMapScanPointerToRenamedType(t *testing.T) { + srcBuf := []byte("foo") + m := pgtype.NewMap() + + var rs *_string + err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, srcBuf, &rs) + assert.NoError(t, err) + require.NotNil(t, rs) + assert.Equal(t, "foo", string(*rs)) +} + +// https://github.com/jackc/pgx/issues/1326 +func TestMapScanNullToWrongType(t *testing.T) { + m := pgtype.NewMap() + + var n *int32 + err := m.Scan(pgtype.TextOID, pgx.TextFormatCode, nil, &n) + assert.NoError(t, err) + assert.Nil(t, n) + + var pn pgtype.Int4 + err = m.Scan(pgtype.TextOID, pgx.TextFormatCode, nil, &pn) + assert.NoError(t, err) + assert.False(t, pn.Valid) +} + +func TestScanToSliceOfRenamedUint8(t *testing.T) { + m := pgtype.NewMap() + var ruint8 []_uint8 + err := m.Scan(pgtype.Int2ArrayOID, pgx.TextFormatCode, []byte("{2,4}"), &ruint8) + assert.NoError(t, err) + assert.Equal(t, []_uint8{2, 4}, ruint8) +} + +func TestMapScanTextToBool(t *testing.T) { + tests := []struct { + name string + src []byte + want bool + }{ + {"t", []byte("t"), true}, + {"f", []byte("f"), false}, + {"y", []byte("y"), true}, + {"n", []byte("n"), false}, + {"1", []byte("1"), true}, + {"0", []byte("0"), false}, + {"true", []byte("true"), true}, + {"false", []byte("false"), false}, + {"yes", []byte("yes"), true}, + {"no", []byte("no"), false}, + {"on", []byte("on"), true}, + {"off", []byte("off"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := pgtype.NewMap() + + var v bool + err := m.Scan(pgtype.BoolOID, pgx.TextFormatCode, tt.src, &v) + require.NoError(t, err) + assert.Equal(t, tt.want, v) + }) + } +} + +func TestMapScanTextToBoolError(t *testing.T) { + tests := []struct { + name string + src []byte + want string + }{ + {"nil", nil, "cannot scan NULL into *bool"}, + {"empty", []byte{}, "cannot scan empty string into *bool"}, + {"foo", []byte("foo"), "unknown boolean string representation \"foo\""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := pgtype.NewMap() + + var v bool + err := m.Scan(pgtype.BoolOID, pgx.TextFormatCode, tt.src, &v) + require.ErrorContains(t, err, tt.want) + }) + } +} + +type databaseValuerUUID [16]byte + +func (v databaseValuerUUID) Value() (driver.Value, error) { + return fmt.Sprintf("%x", v), nil +} + +// https://github.com/jackc/pgx/issues/1502 +func TestMapEncodePlanCacheUUIDTypeConfusion(t *testing.T) { + expected := []byte{ + 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0xb, 0x86, 0, 0, 0, 2, 0, 0, 0, 1, + 0, 0, 0, 16, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 0, 0, 16, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, + } + + m := pgtype.NewMap() + buf, err := m.Encode(pgtype.UUIDArrayOID, pgtype.BinaryFormatCode, + []databaseValuerUUID{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}}, + nil) + require.NoError(t, err) + require.Equal(t, expected, buf) + + // This actually *should* fail. In the actual query path this error is detected and the encoding falls back to the + // text format. In the bug this test is guarding against regression this would panic. + _, err = m.Encode(pgtype.UUIDArrayOID, pgtype.BinaryFormatCode, + []string{"00010203-0405-0607-0809-0a0b0c0d0e0f", "0f0e0d0c-0b0a-0908-0706-0504-03020100"}, + nil) + require.Error(t, err) +} + +// https://github.com/jackc/pgx/issues/1763 +func TestMapEncodeRawJSONIntoUnknownOID(t *testing.T) { + m := pgtype.NewMap() + buf, err := m.Encode(0, pgtype.TextFormatCode, json.RawMessage(`{"foo": "bar"}`), nil) + require.NoError(t, err) + require.Equal(t, []byte(`{"foo": "bar"}`), buf) +} + +// PlanScan previously used a cache to improve performance. However, the cache could get confused in certain cases. The +// example below was one such failure case. +func TestCachedPlanScanConfusion(t *testing.T) { + m := pgtype.NewMap() + var err error + + var tags any + err = m.Scan(pgtype.TextArrayOID, pgx.TextFormatCode, []byte("{foo,bar,baz}"), &tags) + require.NoError(t, err) + + var cells [][]string + err = m.Scan(pgtype.TextArrayOID, pgx.TextFormatCode, []byte("{{foo,bar},{baz,quz}}"), &cells) + require.NoError(t, err) +} + +func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) { + m := pgtype.NewMap() + src := []byte{0, 0, 0, 42} + var v pgtype.Int4 + + for i := 0; i < b.N; i++ { + v = pgtype.Int4{} + err := m.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != (pgtype.Int4{Int32: 42, Valid: true}) { + b.Fatal("scan failed due to bad value") + } + } +} + +func BenchmarkMapScanInt4IntoGoInt32(b *testing.B) { + m := pgtype.NewMap() + src := []byte{0, 0, 0, 42} + var v int32 + + for i := 0; i < b.N; i++ { + v = 0 + err := m.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != 42 { + b.Fatal("scan failed due to bad value") + } + } +} + +func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) { + m := pgtype.NewMap() + src := []byte{0, 0, 0, 42} + var v pgtype.Int4 + + plan := m.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) + + for i := 0; i < b.N; i++ { + v = pgtype.Int4{} + err := plan.Scan(src, &v) + if err != nil { + b.Fatal(err) + } + if v != (pgtype.Int4{Int32: 42, Valid: true}) { + b.Fatal("scan failed due to bad value") + } + } +} + +func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { + m := pgtype.NewMap() + src := []byte{0, 0, 0, 42} + var v int32 + + plan := m.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) + + for i := 0; i < b.N; i++ { + v = 0 + err := plan.Scan(src, &v) + if err != nil { + b.Fatal(err) + } + if v != 42 { + b.Fatal("scan failed due to bad value") + } + } +} + +func isExpectedEq(a any) func(any) bool { + return func(v any) bool { + return a == v + } +} + +func isPtrExpectedEq(a any) func(any) bool { + return func(v any) bool { + val := reflect.ValueOf(v) + return a == val.Elem().Interface() + } +} diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go deleted file mode 100644 index e441a6901..000000000 --- a/pgtype/pguint32.go +++ /dev/null @@ -1,162 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "math" - "strconv" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -// pguint32 is the core type that is used to implement PostgreSQL types such as -// CID and XID. -type pguint32 struct { - Uint uint32 - Status Status -} - -// Set converts from src to dst. Note that as pguint32 is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *pguint32) Set(src interface{}) error { - switch value := src.(type) { - case int64: - if value < 0 { - return errors.Errorf("%d is less than minimum value for pguint32", value) - } - if value > math.MaxUint32 { - return errors.Errorf("%d is greater than maximum value for pguint32", value) - } - *dst = pguint32{Uint: uint32(value), Status: Present} - case uint32: - *dst = pguint32{Uint: value, Status: Present} - default: - return errors.Errorf("cannot convert %v to pguint32", value) - } - - return nil -} - -func (dst *pguint32) Get() interface{} { - switch dst.Status { - case Present: - return dst.Uint - case Null: - return nil - default: - return dst.Status - } -} - -// AssignTo assigns from src to dst. Note that as pguint32 is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *pguint32) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *uint32: - if src.Status == Present { - *v = src.Uint - } else { - return errors.Errorf("cannot assign %v into %T", src, dst) - } - case **uint32: - if src.Status == Present { - n := src.Uint - *v = &n - } else { - *v = nil - } - } - - return nil -} - -func (dst *pguint32) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = pguint32{Status: Null} - return nil - } - - n, err := strconv.ParseUint(string(src), 10, 32) - if err != nil { - return err - } - - *dst = pguint32{Uint: uint32(n), Status: Present} - return nil -} - -func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = pguint32{Status: Null} - return nil - } - - if len(src) != 4 { - return errors.Errorf("invalid length: %v", len(src)) - } - - n := binary.BigEndian.Uint32(src) - *dst = pguint32{Uint: n, Status: Present} - return nil -} - -func (src *pguint32) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - return append(buf, strconv.FormatUint(uint64(src.Uint), 10)...), nil -} - -func (src *pguint32) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - return pgio.AppendUint32(buf, src.Uint), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *pguint32) Scan(src interface{}) error { - if src == nil { - *dst = pguint32{Status: Null} - return nil - } - - switch src := src.(type) { - case uint32: - *dst = pguint32{Uint: src, Status: Present} - return nil - case int64: - *dst = pguint32{Uint: uint32(src), Status: Present} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *pguint32) Value() (driver.Value, error) { - switch src.Status { - case Present: - return int64(src.Uint), nil - case Null: - return nil, nil - default: - return nil, errUndefined - } -} diff --git a/pgtype/point.go b/pgtype/point.go index 3132a9395..b701513dc 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -1,6 +1,7 @@ package pgtype import ( + "bytes" "database/sql/driver" "encoding/binary" "fmt" @@ -8,8 +9,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) type Vec2 struct { @@ -17,123 +17,253 @@ type Vec2 struct { Y float64 } +type PointScanner interface { + ScanPoint(v Point) error +} + +type PointValuer interface { + PointValue() (Point, error) +} + type Point struct { - P Vec2 - Status Status + P Vec2 + Valid bool } -func (dst *Point) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Point", src) +// ScanPoint implements the [PointScanner] interface. +func (p *Point) ScanPoint(v Point) error { + *p = v + return nil } -func (dst *Point) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } +// PointValue implements the [PointValuer] interface. +func (p Point) PointValue() (Point, error) { + return p, nil } -func (src *Point) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) +func parsePoint(src []byte) (*Point, error) { + if src == nil || bytes.Equal(src, []byte("null")) { + return &Point{}, nil + } + + if len(src) < 5 { + return nil, fmt.Errorf("invalid length for point: %v", len(src)) + } + if src[0] == '"' && src[len(src)-1] == '"' { + src = src[1 : len(src)-1] + } + sx, sy, found := strings.Cut(string(src[1:len(src)-1]), ",") + if !found { + return nil, fmt.Errorf("invalid format for point") + } + + x, err := strconv.ParseFloat(sx, 64) + if err != nil { + return nil, err + } + + y, err := strconv.ParseFloat(sy, 64) + if err != nil { + return nil, err + } + + return &Point{P: Vec2{x, y}, Valid: true}, nil } -func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { +// Scan implements the [database/sql.Scanner] interface. +func (dst *Point) Scan(src any) error { if src == nil { - *dst = Point{Status: Null} + *dst = Point{} return nil } - if len(src) < 5 { - return errors.Errorf("invalid length for point: %v", len(src)) + switch src := src.(type) { + case string: + return scanPlanTextAnyToPointScanner{}.Scan([]byte(src), dst) } - parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) - if len(parts) < 2 { - return errors.Errorf("invalid format for point") + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src Point) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil } - x, err := strconv.ParseFloat(parts[0], 64) + buf, err := PointCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) if err != nil { - return err + return nil, err + } + return string(buf), err +} + +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (src Point) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil } - y, err := strconv.ParseFloat(parts[1], 64) + var buff bytes.Buffer + buff.WriteByte('"') + buff.WriteString(fmt.Sprintf("(%g,%g)", src.P.X, src.P.Y)) + buff.WriteByte('"') + return buff.Bytes(), nil +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (dst *Point) UnmarshalJSON(point []byte) error { + p, err := parsePoint(point) if err != nil { return err } - - *dst = Point{P: Vec2{x, y}, Status: Present} + *dst = *p return nil } -func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Point{Status: Null} +type PointCodec struct{} + +func (PointCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (PointCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (PointCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(PointValuer); !ok { return nil } - if len(src) != 16 { - return errors.Errorf("invalid length for point: %v", len(src)) + switch format { + case BinaryFormatCode: + return encodePlanPointCodecBinary{} + case TextFormatCode: + return encodePlanPointCodecText{} } - x := binary.BigEndian.Uint64(src) - y := binary.BigEndian.Uint64(src[8:]) + return nil +} + +type encodePlanPointCodecBinary struct{} - *dst = Point{ - P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, - Status: Present, +func (encodePlanPointCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + point, err := value.(PointValuer).PointValue() + if err != nil { + return nil, err } - return nil + + if !point.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(point.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(point.P.Y)) + return buf, nil } -func (src *Point) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: +type encodePlanPointCodecText struct{} + +func (encodePlanPointCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + point, err := value.(PointValuer).PointValue() + if err != nil { + return nil, err + } + + if !point.Valid { return nil, nil - case Undefined: - return nil, errUndefined } - return append(buf, fmt.Sprintf(`(%f,%f)`, src.P.X, src.P.Y)...), nil + return append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(point.P.X, 'f', -1, 64), + strconv.FormatFloat(point.P.Y, 'f', -1, 64), + )...), nil +} + +func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case PointScanner: + return scanPlanBinaryPointToPointScanner{} + } + case TextFormatCode: + switch target.(type) { + case PointScanner: + return scanPlanTextAnyToPointScanner{} + } + } + + return nil +} + +func (c PointCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -func (src *Point) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: +func (c PointCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { return nil, nil - case Undefined: - return nil, errUndefined } - buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P.Y)) - return buf, nil + var point Point + err := codecScan(c, m, oid, format, src, &point) + if err != nil { + return nil, err + } + return point, nil } -// Scan implements the database/sql Scanner interface. -func (dst *Point) Scan(src interface{}) error { +type scanPlanBinaryPointToPointScanner struct{} + +func (scanPlanBinaryPointToPointScanner) Scan(src []byte, dst any) error { + scanner := (dst).(PointScanner) + if src == nil { - *dst = Point{Status: Null} - return nil + return scanner.ScanPoint(Point{}) } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + if len(src) != 16 { + return fmt.Errorf("invalid length for point: %v", len(src)) } - return errors.Errorf("cannot scan %T", src) + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + + return scanner.ScanPoint(Point{ + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + Valid: true, + }) } -// Value implements the database/sql/driver Valuer interface. -func (src *Point) Value() (driver.Value, error) { - return EncodeValueText(src) +type scanPlanTextAnyToPointScanner struct{} + +func (scanPlanTextAnyToPointScanner) Scan(src []byte, dst any) error { + scanner := (dst).(PointScanner) + + if src == nil { + return scanner.ScanPoint(Point{}) + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + sx, sy, found := strings.Cut(string(src[1:len(src)-1]), ",") + if !found { + return fmt.Errorf("invalid format for point") + } + + x, err := strconv.ParseFloat(sx, 64) + if err != nil { + return err + } + + y, err := strconv.ParseFloat(sy, 64) + if err != nil { + return err + } + + return scanner.ScanPoint(Point{P: Vec2{x, y}, Valid: true}) } diff --git a/pgtype/point_test.go b/pgtype/point_test.go index f46b342d0..336f1a470 100644 --- a/pgtype/point_test.go +++ b/pgtype/point_test.go @@ -1,16 +1,102 @@ package pgtype_test import ( + "context" + "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" ) -func TestPointTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "point", []interface{}{ - &pgtype.Point{P: pgtype.Vec2{1.234, 5.6789}, Status: pgtype.Present}, - &pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Status: pgtype.Present}, - &pgtype.Point{Status: pgtype.Null}, +func TestPointCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support type point") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "point", []pgxtest.ValueRoundTripTest{ + { + pgtype.Point{P: pgtype.Vec2{1.234, 5.6789012345}, Valid: true}, + new(pgtype.Point), + isExpectedEq(pgtype.Point{P: pgtype.Vec2{1.234, 5.6789012345}, Valid: true}), + }, + { + pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Valid: true}, + new(pgtype.Point), + isExpectedEq(pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Valid: true}), + }, + {pgtype.Point{}, new(pgtype.Point), isExpectedEq(pgtype.Point{})}, + {nil, new(pgtype.Point), isExpectedEq(pgtype.Point{})}, }) } + +func TestPoint_MarshalJSON(t *testing.T) { + tests := []struct { + name string + point pgtype.Point + want []byte + }{ + { + name: "second", + point: pgtype.Point{ + P: pgtype.Vec2{X: 12.245, Y: 432.12}, + Valid: true, + }, + want: []byte(`"(12.245,432.12)"`), + }, + { + name: "third", + point: pgtype.Point{ + P: pgtype.Vec2{}, + }, + want: []byte("null"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.point.MarshalJSON() + require.NoError(t, err) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPoint_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + valid bool + arg []byte + wantErr bool + }{ + { + name: "first", + valid: true, + arg: []byte(`"(123.123,54.12)"`), + wantErr: false, + }, + { + name: "second", + valid: false, + arg: []byte(`"(123.123,54.1sad2)"`), + wantErr: true, + }, + { + name: "third", + valid: false, + arg: []byte("null"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dst := &pgtype.Point{} + if err := dst.UnmarshalJSON(tt.arg); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if dst.Valid != tt.valid { + t.Errorf("Valid mismatch: %v != %v", dst.Valid, tt.valid) + } + }) + } +} diff --git a/pgtype/polygon.go b/pgtype/polygon.go index 3f3d9f537..a84b25fe3 100644 --- a/pgtype/polygon.go +++ b/pgtype/polygon.go @@ -8,91 +8,173 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) +type PolygonScanner interface { + ScanPolygon(v Polygon) error +} + +type PolygonValuer interface { + PolygonValue() (Polygon, error) +} + type Polygon struct { - P []Vec2 - Status Status + P []Vec2 + Valid bool } -func (dst *Polygon) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Polygon", src) +// ScanPolygon implements the [PolygonScanner] interface. +func (p *Polygon) ScanPolygon(v Polygon) error { + *p = v + return nil +} + +// PolygonValue implements the [PolygonValuer] interface. +func (p Polygon) PolygonValue() (Polygon, error) { + return p, nil } -func (dst *Polygon) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: +// Scan implements the [database/sql.Scanner] interface. +func (p *Polygon) Scan(src any) error { + if src == nil { + *p = Polygon{} return nil - default: - return dst.Status } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToPolygonScanner{}.Scan([]byte(src), p) + } + + return fmt.Errorf("cannot scan %T", src) } -func (src *Polygon) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) +// Value implements the [database/sql/driver.Valuer] interface. +func (p Polygon) Value() (driver.Value, error) { + if !p.Valid { + return nil, nil + } + + buf, err := PolygonCodec{}.PlanEncode(nil, 0, TextFormatCode, p).Encode(p, nil) + if err != nil { + return nil, err + } + + return string(buf), err } -func (dst *Polygon) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Polygon{Status: Null} +type PolygonCodec struct{} + +func (PolygonCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (PolygonCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (PolygonCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(PolygonValuer); !ok { return nil } - if len(src) < 7 { - return errors.Errorf("invalid length for Polygon: %v", len(src)) + switch format { + case BinaryFormatCode: + return encodePlanPolygonCodecBinary{} + case TextFormatCode: + return encodePlanPolygonCodecText{} } - points := make([]Vec2, 0) + return nil +} - str := string(src[2:]) +type encodePlanPolygonCodecBinary struct{} - for { - end := strings.IndexByte(str, ',') - x, err := strconv.ParseFloat(str[:end], 64) - if err != nil { - return err - } +func (encodePlanPolygonCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + polygon, err := value.(PolygonValuer).PolygonValue() + if err != nil { + return nil, err + } - str = str[end+1:] - end = strings.IndexByte(str, ')') + if !polygon.Valid { + return nil, nil + } - y, err := strconv.ParseFloat(str[:end], 64) - if err != nil { - return err + buf = pgio.AppendInt32(buf, int32(len(polygon.P))) + + for _, p := range polygon.P { + buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) + } + + return buf, nil +} + +type encodePlanPolygonCodecText struct{} + +func (encodePlanPolygonCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + polygon, err := value.(PolygonValuer).PolygonValue() + if err != nil { + return nil, err + } + + if !polygon.Valid { + return nil, nil + } + + buf = append(buf, '(') + + for i, p := range polygon.P { + if i > 0 { + buf = append(buf, ',') } + buf = append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(p.X, 'f', -1, 64), + strconv.FormatFloat(p.Y, 'f', -1, 64), + )...) + } - points = append(points, Vec2{x, y}) + buf = append(buf, ')') - if end+3 < len(str) { - str = str[end+3:] - } else { - break + return buf, nil +} + +func (PolygonCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case PolygonScanner: + return scanPlanBinaryPolygonToPolygonScanner{} + } + case TextFormatCode: + switch target.(type) { + case PolygonScanner: + return scanPlanTextAnyToPolygonScanner{} } } - *dst = Polygon{P: points, Status: Present} return nil } -func (dst *Polygon) DecodeBinary(ci *ConnInfo, src []byte) error { +type scanPlanBinaryPolygonToPolygonScanner struct{} + +func (scanPlanBinaryPolygonToPolygonScanner) Scan(src []byte, dst any) error { + scanner := (dst).(PolygonScanner) + if src == nil { - *dst = Polygon{Status: Null} - return nil + return scanner.ScanPolygon(Polygon{}) } if len(src) < 5 { - return errors.Errorf("invalid length for Polygon: %v", len(src)) + return fmt.Errorf("invalid length for polygon: %v", len(src)) } pointCount := int(binary.BigEndian.Uint32(src)) rp := 4 if 4+pointCount*16 != len(src) { - return errors.Errorf("invalid length for Polygon with %d points: %v", pointCount, len(src)) + return fmt.Errorf("invalid length for Polygon with %d points: %v", pointCount, len(src)) } points := make([]Vec2, pointCount) @@ -104,71 +186,69 @@ func (dst *Polygon) DecodeBinary(ci *ConnInfo, src []byte) error { points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} } - *dst = Polygon{ - P: points, - Status: Present, - } - return nil + return scanner.ScanPolygon(Polygon{ + P: points, + Valid: true, + }) } -func (src *Polygon) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +type scanPlanTextAnyToPolygonScanner struct{} + +func (scanPlanTextAnyToPolygonScanner) Scan(src []byte, dst any) error { + scanner := (dst).(PolygonScanner) + + if src == nil { + return scanner.ScanPolygon(Polygon{}) } - buf = append(buf, '(') + if len(src) < 7 { + return fmt.Errorf("invalid length for Polygon: %v", len(src)) + } - for i, p := range src.P { - if i > 0 { - buf = append(buf, ',') + points := make([]Vec2, 0) + + str := string(src[2:]) + + for { + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err } - buf = append(buf, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)...) - } - return append(buf, ')'), nil -} + str = str[end+1:] + end = strings.IndexByte(str, ')') -func (src *Polygon) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } - buf = pgio.AppendInt32(buf, int32(len(src.P))) + points = append(points, Vec2{x, y}) - for _, p := range src.P { - buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) - buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) + if end+3 < len(str) { + str = str[end+3:] + } else { + break + } } - return buf, nil + return scanner.ScanPolygon(Polygon{P: points, Valid: true}) +} + +func (c PolygonCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) } -// Scan implements the database/sql Scanner interface. -func (dst *Polygon) Scan(src interface{}) error { +func (c PolygonCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { - *dst = Polygon{Status: Null} - return nil + return nil, nil } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + var polygon Polygon + err := codecScan(c, m, oid, format, src, &polygon) + if err != nil { + return nil, err } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Polygon) Value() (driver.Value, error) { - return EncodeValueText(src) + return polygon, nil } diff --git a/pgtype/polygon_test.go b/pgtype/polygon_test.go index 48481dc5b..5ddbc1669 100644 --- a/pgtype/polygon_test.go +++ b/pgtype/polygon_test.go @@ -1,22 +1,59 @@ package pgtype_test import ( + "context" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" ) +func isExpectedEqPolygon(a any) func(any) bool { + return func(v any) bool { + ap := a.(pgtype.Polygon) + vp := v.(pgtype.Polygon) + + if !(ap.Valid == vp.Valid && len(ap.P) == len(vp.P)) { + return false + } + + for i := range ap.P { + if ap.P[i] != vp.P[i] { + return false + } + } + + return true + } +} + func TestPolygonTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "polygon", []interface{}{ - &pgtype.Polygon{ - P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {5.0, 3.234}}, - Status: pgtype.Present, + skipCockroachDB(t, "Server does not support type polygon") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "polygon", []pgxtest.ValueRoundTripTest{ + { + pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, + Valid: true, + }, + new(pgtype.Polygon), + isExpectedEqPolygon(pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, + Valid: true, + }), }, - &pgtype.Polygon{ - P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, - Status: pgtype.Present, + { + pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, + Valid: true, + }, + new(pgtype.Polygon), + isExpectedEqPolygon(pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, + Valid: true, + }), }, - &pgtype.Polygon{Status: pgtype.Null}, + {pgtype.Polygon{}, new(pgtype.Polygon), isExpectedEqPolygon(pgtype.Polygon{})}, + {nil, new(pgtype.Polygon), isExpectedEqPolygon(pgtype.Polygon{})}, }) } diff --git a/pgtype/qchar.go b/pgtype/qchar.go index 064dab1e9..fc40a5b2c 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -1,146 +1,141 @@ package pgtype import ( + "database/sql/driver" + "fmt" "math" - "strconv" - - "github.com/pkg/errors" ) -// QChar is for PostgreSQL's special 8-bit-only "char" type more akin to the C +// QCharCodec is for PostgreSQL's special 8-bit-only "char" type more akin to the C // language's char type, or Go's byte type. (Note that the name in PostgreSQL // itself is "char", in double-quotes, and not char.) It gets used a lot in // PostgreSQL's system tables to hold a single ASCII character value (eg // pg_class.relkind). It is named Qchar for quoted char to disambiguate from SQL // standard type char. -// -// Not all possible values of QChar are representable in the text format. -// Therefore, QChar does not implement TextEncoder and TextDecoder. In -// addition, database/sql Scanner and database/sql/driver Value are not -// implemented. -type QChar struct { - Int int8 - Status Status +type QCharCodec struct{} + +func (QCharCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode } -func (dst *QChar) Set(src interface{}) error { - if src == nil { - *dst = QChar{Status: Null} - return nil - } +func (QCharCodec) PreferredFormat() int16 { + return BinaryFormatCode +} - switch value := src.(type) { - case int8: - *dst = QChar{Int: value, Status: Present} - case uint8: - if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Status: Present} - case int16: - if value < math.MinInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) - } - if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Status: Present} - case uint16: - if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Status: Present} - case int32: - if value < math.MinInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) - } - if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Status: Present} - case uint32: - if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Status: Present} - case int64: - if value < math.MinInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) - } - if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) +func (QCharCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch value.(type) { + case byte: + return encodePlanQcharCodecByte{} + case rune: + return encodePlanQcharCodecRune{} } - *dst = QChar{Int: int8(value), Status: Present} - case uint64: - if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Status: Present} - case int: - if value < math.MinInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) - } - if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Status: Present} - case uint: - if value > math.MaxInt8 { - return errors.Errorf("%d is greater than maximum value for QChar", value) - } - *dst = QChar{Int: int8(value), Status: Present} - case string: - num, err := strconv.ParseInt(value, 10, 8) - if err != nil { - return err - } - *dst = QChar{Int: int8(num), Status: Present} - default: - if originalSrc, ok := underlyingNumberType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to QChar", value) } return nil } -func (dst *QChar) Get() interface{} { - switch dst.Status { - case Present: - return dst.Int - case Null: - return nil - default: - return dst.Status +type encodePlanQcharCodecByte struct{} + +func (encodePlanQcharCodecByte) Encode(value any, buf []byte) (newBuf []byte, err error) { + b := value.(byte) + buf = append(buf, b) + return buf, nil +} + +type encodePlanQcharCodecRune struct{} + +func (encodePlanQcharCodecRune) Encode(value any, buf []byte) (newBuf []byte, err error) { + r := value.(rune) + if r > math.MaxUint8 { + return nil, fmt.Errorf(`%v cannot be encoded to "char"`, r) + } + b := byte(r) + buf = append(buf, b) + return buf, nil +} + +func (QCharCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch target.(type) { + case *byte: + return scanPlanQcharCodecByte{} + case *rune: + return scanPlanQcharCodecRune{} + } } + + return nil } -func (src *QChar) AssignTo(dst interface{}) error { - return int64AssignTo(int64(src.Int), src.Status, dst) +type scanPlanQcharCodecByte struct{} + +func (scanPlanQcharCodecByte) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) > 1 { + return fmt.Errorf(`invalid length for "char": %v`, len(src)) + } + + b := dst.(*byte) + // In the text format the zero value is returned as a zero byte value instead of 0 + if len(src) == 0 { + *b = 0 + } else { + *b = src[0] + } + + return nil } -func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { +type scanPlanQcharCodecRune struct{} + +func (scanPlanQcharCodecRune) Scan(src []byte, dst any) error { if src == nil { - *dst = QChar{Status: Null} - return nil + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) > 1 { + return fmt.Errorf(`invalid length for "char": %v`, len(src)) } - if len(src) != 1 { - return errors.Errorf(`invalid length for "char": %v`, len(src)) + r := dst.(*rune) + // In the text format the zero value is returned as a zero byte value instead of 0 + if len(src) == 0 { + *r = 0 + } else { + *r = rune(src[0]) } - *dst = QChar{Int: int8(src[0]), Status: Present} return nil } -func (src *QChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: +func (c QCharCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var r rune + err := codecScan(c, m, oid, format, src, &r) + if err != nil { + return nil, err + } + return string(r), nil +} + +func (c QCharCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { return nil, nil - case Undefined: - return nil, errUndefined } - return append(buf, byte(src.Int)), nil + var r rune + err := codecScan(c, m, oid, format, src, &r) + if err != nil { + return nil, err + } + return r, nil } diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go index 057a557ff..da00b89e4 100644 --- a/pgtype/qchar_test.go +++ b/pgtype/qchar_test.go @@ -1,143 +1,24 @@ package pgtype_test import ( + "context" "math" - "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5/pgxtest" ) -func TestQCharTranscode(t *testing.T) { - testutil.TestPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ - &pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, - &pgtype.QChar{Int: -1, Status: pgtype.Present}, - &pgtype.QChar{Int: 0, Status: pgtype.Present}, - &pgtype.QChar{Int: 1, Status: pgtype.Present}, - &pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present}, - &pgtype.QChar{Int: 0, Status: pgtype.Null}, - }, func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - }) -} - -func TestQCharSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.QChar - }{ - {source: int8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: int16(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: int32(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: int64(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: int8(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, - {source: int16(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, - {source: int32(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, - {source: int64(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, - {source: uint8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: uint16(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: uint32(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: uint64(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: "1", result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - {source: _int8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.QChar - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestQCharAssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 +func TestQcharTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support qchar") - simpleTests := []struct { - src pgtype.QChar - dst interface{} - expected interface{} - }{ - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + var tests []pgxtest.ValueRoundTripTest + for i := 0; i <= math.MaxUint8; i++ { + tests = append(tests, pgxtest.ValueRoundTripTest{rune(i), new(rune), isExpectedEq(rune(i))}) + tests = append(tests, pgxtest.ValueRoundTripTest{byte(i), new(byte), isExpectedEq(byte(i))}) } + tests = append(tests, pgxtest.ValueRoundTripTest{nil, new(*rune), isExpectedEq((*rune)(nil))}) + tests = append(tests, pgxtest.ValueRoundTripTest{nil, new(*byte), isExpectedEq((*byte)(nil))}) - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.QChar - dst interface{} - expected interface{} - }{ - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, - {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.QChar - dst interface{} - }{ - {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui8}, - {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui16}, - {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui32}, - {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui64}, - {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui}, - {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &i16}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } + // Can only test with known OIDs as rune and byte would be considered numbers. + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, `"char"`, tests) } diff --git a/pgtype/range.go b/pgtype/range.go index 54fc6ca07..62d699905 100644 --- a/pgtype/range.go +++ b/pgtype/range.go @@ -3,8 +3,7 @@ package pgtype import ( "bytes" "encoding/binary" - - "github.com/pkg/errors" + "fmt" ) type BoundType byte @@ -20,15 +19,15 @@ func (bt BoundType) String() string { return string(bt) } -type UntypedTextRange struct { +type untypedTextRange struct { Lower string Upper string LowerType BoundType UpperType BoundType } -func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { - utr := &UntypedTextRange{} +func parseUntypedTextRange(src string) (*untypedTextRange, error) { + utr := &untypedTextRange{} if src == "empty" { utr.LowerType = Empty utr.UpperType = Empty @@ -41,7 +40,7 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { r, _, err := buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid lower bound: %v", err) + return nil, fmt.Errorf("invalid lower bound: %w", err) } switch r { case '(': @@ -49,12 +48,12 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { case '[': utr.LowerType = Inclusive default: - return nil, errors.Errorf("missing lower bound, instead got: %v", string(r)) + return nil, fmt.Errorf("missing lower bound, instead got: %v", string(r)) } r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid lower value: %v", err) + return nil, fmt.Errorf("invalid lower value: %w", err) } buf.UnreadRune() @@ -63,21 +62,21 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { } else { utr.Lower, err = rangeParseValue(buf) if err != nil { - return nil, errors.Errorf("invalid lower value: %v", err) + return nil, fmt.Errorf("invalid lower value: %w", err) } } r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("missing range separator: %v", err) + return nil, fmt.Errorf("missing range separator: %w", err) } if r != ',' { - return nil, errors.Errorf("missing range separator: %v", r) + return nil, fmt.Errorf("missing range separator: %v", r) } r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("invalid upper value: %v", err) + return nil, fmt.Errorf("invalid upper value: %w", err) } if r == ')' || r == ']' { @@ -86,12 +85,12 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { buf.UnreadRune() utr.Upper, err = rangeParseValue(buf) if err != nil { - return nil, errors.Errorf("invalid upper value: %v", err) + return nil, fmt.Errorf("invalid upper value: %w", err) } r, _, err = buf.ReadRune() if err != nil { - return nil, errors.Errorf("missing upper bound: %v", err) + return nil, fmt.Errorf("missing upper bound: %w", err) } switch r { case ')': @@ -99,14 +98,14 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { case ']': utr.UpperType = Inclusive default: - return nil, errors.Errorf("missing upper bound, instead got: %v", string(r)) + return nil, fmt.Errorf("missing upper bound, instead got: %v", string(r)) } } skipWhitespace(buf) if buf.Len() > 0 { - return nil, errors.Errorf("unexpected trailing data: %v", buf.String()) + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) } return utr, nil @@ -174,7 +173,7 @@ func rangeParseQuotedValue(buf *bytes.Buffer) (string, error) { } } -type UntypedBinaryRange struct { +type untypedBinaryRange struct { Lower []byte Upper []byte LowerType BoundType @@ -192,17 +191,19 @@ type UntypedBinaryRange struct { // 18 = [ = 10010 // 24 = = 11000 -const emptyMask = 1 -const lowerInclusiveMask = 2 -const upperInclusiveMask = 4 -const lowerUnboundedMask = 8 -const upperUnboundedMask = 16 +const ( + emptyMask = 1 + lowerInclusiveMask = 2 + upperInclusiveMask = 4 + lowerUnboundedMask = 8 + upperUnboundedMask = 16 +) -func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { - ubr := &UntypedBinaryRange{} +func parseUntypedBinaryRange(src []byte) (*untypedBinaryRange, error) { + ubr := &untypedBinaryRange{} if len(src) == 0 { - return nil, errors.Errorf("range too short: %v", len(src)) + return nil, fmt.Errorf("range too short: %v", len(src)) } rangeType := src[0] @@ -210,7 +211,7 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { if rangeType&emptyMask > 0 { if len(src[rp:]) > 0 { - return nil, errors.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) + return nil, fmt.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) } ubr.LowerType = Empty ubr.UpperType = Empty @@ -235,13 +236,13 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { if ubr.LowerType == Unbounded && ubr.UpperType == Unbounded { if len(src[rp:]) > 0 { - return nil, errors.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) + return nil, fmt.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) } return ubr, nil } if len(src[rp:]) < 4 { - return nil, errors.Errorf("too few bytes for size: %v", src[rp:]) + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) } valueLen := int(binary.BigEndian.Uint32(src[rp:])) rp += 4 @@ -254,14 +255,14 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { } else { ubr.Upper = val if len(src[rp:]) > 0 { - return nil, errors.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) } return ubr, nil } if ubr.UpperType != Unbounded { if len(src[rp:]) < 4 { - return nil, errors.Errorf("too few bytes for size: %v", src[rp:]) + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) } valueLen := int(binary.BigEndian.Uint32(src[rp:])) rp += 4 @@ -270,9 +271,53 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { } if len(src[rp:]) > 0 { - return nil, errors.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) } return ubr, nil +} + +// Range is a generic range type. +type Range[T any] struct { + Lower T + Upper T + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Range[T]) IsNull() bool { + return !r.Valid +} +func (r Range[T]) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Range[T]) Bounds() (lower, upper any) { + return &r.Lower, &r.Upper +} + +func (r *Range[T]) ScanNull() error { + *r = Range[T]{} + return nil +} + +func (r *Range[T]) ScanBounds() (lowerTarget, upperTarget any) { + return &r.Lower, &r.Upper +} + +func (r *Range[T]) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + var zero T + r.Lower = zero + } + if upper == Unbounded || upper == Empty { + var zero T + r.Upper = zero + } + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil } diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go new file mode 100644 index 000000000..684f1bf73 --- /dev/null +++ b/pgtype/range_codec.go @@ -0,0 +1,379 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +// RangeValuer is a type that can be converted into a PostgreSQL range. +type RangeValuer interface { + // IsNull returns true if the value is SQL NULL. + IsNull() bool + + // BoundTypes returns the lower and upper bound types. + BoundTypes() (lower, upper BoundType) + + // Bounds returns the lower and upper range values. + Bounds() (lower, upper any) +} + +// RangeScanner is a type can be scanned from a PostgreSQL range. +type RangeScanner interface { + // ScanNull sets the value to SQL NULL. + ScanNull() error + + // ScanBounds returns values usable as a scan target. The returned values may not be scanned if the range is empty or + // the bound type is unbounded. + ScanBounds() (lowerTarget, upperTarget any) + + // SetBoundTypes sets the lower and upper bound types. ScanBounds will be called and the returned values scanned + // (if appropriate) before SetBoundTypes is called. If the bound types are unbounded or empty this method must + // also set the bound values. + SetBoundTypes(lower, upper BoundType) error +} + +// RangeCodec is a codec for any range type. +type RangeCodec struct { + ElementType *Type +} + +func (c *RangeCodec) FormatSupported(format int16) bool { + return c.ElementType.Codec.FormatSupported(format) +} + +func (c *RangeCodec) PreferredFormat() int16 { + if c.FormatSupported(BinaryFormatCode) { + return BinaryFormatCode + } + return TextFormatCode +} + +func (c *RangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(RangeValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return &encodePlanRangeCodecRangeValuerToBinary{rc: c, m: m} + case TextFormatCode: + return &encodePlanRangeCodecRangeValuerToText{rc: c, m: m} + } + + return nil +} + +type encodePlanRangeCodecRangeValuerToBinary struct { + rc *RangeCodec + m *Map +} + +func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + getter := value.(RangeValuer) + + if getter.IsNull() { + return nil, nil + } + + lowerType, upperType := getter.BoundTypes() + lower, upper := getter.Bounds() + + var rangeType byte + switch lowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", lowerType) + } + + switch upperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", upperType) + } + + buf = append(buf, rangeType) + + if lowerType != Unbounded { + if lower == nil { + return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded") + } + + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + lowerPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, BinaryFormatCode, lower) + if lowerPlan == nil { + return nil, fmt.Errorf("cannot encode %v as element of range", lower) + } + + buf, err = lowerPlan.Encode(lower, buf) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %w", lower, err) + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if upperType != Unbounded { + if upper == nil { + return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded") + } + + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + upperPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, BinaryFormatCode, upper) + if upperPlan == nil { + return nil, fmt.Errorf("cannot encode %v as element of range", upper) + } + + buf, err = upperPlan.Encode(upper, buf) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %w", upper, err) + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +type encodePlanRangeCodecRangeValuerToText struct { + rc *RangeCodec + m *Map +} + +func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) (newBuf []byte, err error) { + getter := value.(RangeValuer) + + if getter.IsNull() { + return nil, nil + } + + lowerType, upperType := getter.BoundTypes() + lower, upper := getter.Bounds() + + switch lowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", lowerType) + } + + if lowerType != Unbounded { + if lower == nil { + return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded") + } + + lowerPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, lower) + if lowerPlan == nil { + return nil, fmt.Errorf("cannot encode %v as element of range", lower) + } + + buf, err = lowerPlan.Encode(lower, buf) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %w", lower, err) + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if upperType != Unbounded { + if upper == nil { + return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded") + } + + upperPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, upper) + if upperPlan == nil { + return nil, fmt.Errorf("cannot encode %v as element of range", upper) + } + + buf, err = upperPlan.Encode(upper, buf) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %w", upper, err) + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded") + } + } + + switch upperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", upperType) + } + + return buf, nil +} + +func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case RangeScanner: + return &scanPlanBinaryRangeToRangeScanner{rc: c, m: m} + } + case TextFormatCode: + switch target.(type) { + case RangeScanner: + return &scanPlanTextRangeToRangeScanner{rc: c, m: m} + } + } + + return nil +} + +type scanPlanBinaryRangeToRangeScanner struct { + rc *RangeCodec + m *Map +} + +func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) error { + rangeScanner := (target).(RangeScanner) + + if src == nil { + return rangeScanner.ScanNull() + } + + ubr, err := parseUntypedBinaryRange(src) + if err != nil { + return err + } + + if ubr.LowerType == Empty { + return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType) + } + + lowerTarget, upperTarget := rangeScanner.ScanBounds() + + if ubr.LowerType == Inclusive || ubr.LowerType == Exclusive { + lowerPlan := plan.m.PlanScan(plan.rc.ElementType.OID, BinaryFormatCode, lowerTarget) + if lowerPlan == nil { + return fmt.Errorf("cannot scan into %v from range element", lowerTarget) + } + + err = lowerPlan.Scan(ubr.Lower, lowerTarget) + if err != nil { + return fmt.Errorf("cannot scan into %v from range element: %w", lowerTarget, err) + } + } + + if ubr.UpperType == Inclusive || ubr.UpperType == Exclusive { + upperPlan := plan.m.PlanScan(plan.rc.ElementType.OID, BinaryFormatCode, upperTarget) + if upperPlan == nil { + return fmt.Errorf("cannot scan into %v from range element", upperTarget) + } + + err = upperPlan.Scan(ubr.Upper, upperTarget) + if err != nil { + return fmt.Errorf("cannot scan into %v from range element: %w", upperTarget, err) + } + } + + return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType) +} + +type scanPlanTextRangeToRangeScanner struct { + rc *RangeCodec + m *Map +} + +func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error { + rangeScanner := (target).(RangeScanner) + + if src == nil { + return rangeScanner.ScanNull() + } + + utr, err := parseUntypedTextRange(string(src)) + if err != nil { + return err + } + + if utr.LowerType == Empty { + return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType) + } + + lowerTarget, upperTarget := rangeScanner.ScanBounds() + + if utr.LowerType == Inclusive || utr.LowerType == Exclusive { + lowerPlan := plan.m.PlanScan(plan.rc.ElementType.OID, TextFormatCode, lowerTarget) + if lowerPlan == nil { + return fmt.Errorf("cannot scan into %v from range element", lowerTarget) + } + + err = lowerPlan.Scan([]byte(utr.Lower), lowerTarget) + if err != nil { + return fmt.Errorf("cannot scan into %v from range element: %w", lowerTarget, err) + } + } + + if utr.UpperType == Inclusive || utr.UpperType == Exclusive { + upperPlan := plan.m.PlanScan(plan.rc.ElementType.OID, TextFormatCode, upperTarget) + if upperPlan == nil { + return fmt.Errorf("cannot scan into %v from range element", upperTarget) + } + + err = upperPlan.Scan([]byte(utr.Upper), upperTarget) + if err != nil { + return fmt.Errorf("cannot scan into %v from range element: %w", upperTarget, err) + } + } + + return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType) +} + +func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } +} + +func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var r Range[any] + err := c.PlanScan(m, oid, format, &r).Scan(src, &r) + return r, err +} diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go new file mode 100644 index 000000000..f70b7a590 --- /dev/null +++ b/pgtype/range_codec_test.go @@ -0,0 +1,161 @@ +package pgtype_test + +import ( + "context" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +func TestRangeCodecTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4range", []pgxtest.ValueRoundTripTest{ + { + pgtype.Range[pgtype.Int4]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + new(pgtype.Range[pgtype.Int4]), + isExpectedEq(pgtype.Range[pgtype.Int4]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}), + }, + { + pgtype.Range[pgtype.Int4]{ + LowerType: pgtype.Inclusive, + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + UpperType: pgtype.Exclusive, Valid: true, + }, + new(pgtype.Range[pgtype.Int4]), + isExpectedEq(pgtype.Range[pgtype.Int4]{ + LowerType: pgtype.Inclusive, + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + UpperType: pgtype.Exclusive, Valid: true, + }), + }, + {pgtype.Range[pgtype.Int4]{}, new(pgtype.Range[pgtype.Int4]), isExpectedEq(pgtype.Range[pgtype.Int4]{})}, + {nil, new(pgtype.Range[pgtype.Int4]), isExpectedEq(pgtype.Range[pgtype.Int4]{})}, + }) +} + +func TestRangeCodecTranscodeCompatibleRangeElementTypes(t *testing.T) { + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, nil, "numrange", []pgxtest.ValueRoundTripTest{ + { + pgtype.Range[pgtype.Float8]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + new(pgtype.Range[pgtype.Float8]), + isExpectedEq(pgtype.Range[pgtype.Float8]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}), + }, + { + pgtype.Range[pgtype.Float8]{ + LowerType: pgtype.Inclusive, + Lower: pgtype.Float8{Float64: 1, Valid: true}, + Upper: pgtype.Float8{Float64: 5, Valid: true}, + UpperType: pgtype.Exclusive, Valid: true, + }, + new(pgtype.Range[pgtype.Float8]), + isExpectedEq(pgtype.Range[pgtype.Float8]{ + LowerType: pgtype.Inclusive, + Lower: pgtype.Float8{Float64: 1, Valid: true}, + Upper: pgtype.Float8{Float64: 5, Valid: true}, + UpperType: pgtype.Exclusive, Valid: true, + }), + }, + {pgtype.Range[pgtype.Float8]{}, new(pgtype.Range[pgtype.Float8]), isExpectedEq(pgtype.Range[pgtype.Float8]{})}, + {nil, new(pgtype.Range[pgtype.Float8]), isExpectedEq(pgtype.Range[pgtype.Float8]{})}, + }) +} + +func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { + skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var r pgtype.Range[pgtype.Int4] + + err := conn.QueryRow(context.Background(), `select '[1,5)'::int4range`).Scan(&r) + require.NoError(t, err) + + require.Equal( + t, + pgtype.Range[pgtype.Int4]{ + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + r, + ) + + err = conn.QueryRow(ctx, `select '[1,)'::int4range`).Scan(&r) + require.NoError(t, err) + + require.Equal( + t, + pgtype.Range[pgtype.Int4]{ + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Unbounded, + Valid: true, + }, + r, + ) + + err = conn.QueryRow(ctx, `select 'empty'::int4range`).Scan(&r) + require.NoError(t, err) + + require.Equal( + t, + pgtype.Range[pgtype.Int4]{ + Lower: pgtype.Int4{}, + Upper: pgtype.Int4{}, + LowerType: pgtype.Empty, + UpperType: pgtype.Empty, + Valid: true, + }, + r, + ) + }) +} + +func TestRangeCodecDecodeValue(t *testing.T) { + skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + for _, tt := range []struct { + sql string + expected any + }{ + { + sql: `select '[1,5)'::int4range`, + expected: pgtype.Range[any]{ + Lower: int32(1), + Upper: int32(5), + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(ctx, tt.sql) + require.NoError(t, err) + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } + }) +} diff --git a/pgtype/range_test.go b/pgtype/range_test.go index 9e16df59c..eb9486b08 100644 --- a/pgtype/range_test.go +++ b/pgtype/range_test.go @@ -8,68 +8,68 @@ import ( func TestParseUntypedTextRange(t *testing.T) { tests := []struct { src string - result UntypedTextRange + result untypedTextRange err error }{ { src: `[1,2)`, - result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: `[1,2]`, - result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive}, + result: untypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive}, err: nil, }, { src: `(1,3)`, - result: UntypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive}, err: nil, }, { src: ` [1,2) `, - result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: `[ foo , bar )`, - result: UntypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: `["foo","bar")`, - result: UntypedTextRange{Lower: "foo", Upper: "bar", LowerType: Inclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: "foo", Upper: "bar", LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: `["f""oo","b""ar")`, - result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: `["f""oo","b""ar")`, - result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: `["","bar")`, - result: UntypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: `[f\"oo\,,b\\ar\))`, - result: UntypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive}, + result: untypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: `empty`, - result: UntypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty}, + result: untypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty}, err: nil, }, } for i, tt := range tests { - r, err := ParseUntypedTextRange(tt.src) + r, err := parseUntypedTextRange(tt.src) if err != tt.err { t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) continue @@ -96,63 +96,63 @@ func TestParseUntypedTextRange(t *testing.T) { func TestParseUntypedBinaryRange(t *testing.T) { tests := []struct { src []byte - result UntypedBinaryRange + result untypedBinaryRange err error }{ { src: []byte{0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive}, err: nil, }, { src: []byte{1}, - result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty}, + result: untypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty}, err: nil, }, { src: []byte{2, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive}, err: nil, }, { src: []byte{4, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive}, err: nil, }, { src: []byte{6, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive}, err: nil, }, { src: []byte{8, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive}, + result: untypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive}, err: nil, }, { src: []byte{12, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive}, + result: untypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive}, err: nil, }, { src: []byte{16, 0, 0, 0, 2, 0, 4}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded}, err: nil, }, { src: []byte{18, 0, 0, 0, 2, 0, 4}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded}, + result: untypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded}, err: nil, }, { src: []byte{24}, - result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded}, + result: untypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded}, err: nil, }, } for i, tt := range tests { - r, err := ParseUntypedBinaryRange(tt.src) + r, err := parseUntypedBinaryRange(tt.src) if err != tt.err { t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) continue @@ -166,11 +166,11 @@ func TestParseUntypedBinaryRange(t *testing.T) { t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) } - if bytes.Compare(r.Lower, tt.result.Lower) != 0 { + if !bytes.Equal(r.Lower, tt.result.Lower) { t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) } - if bytes.Compare(r.Upper, tt.result.Upper) != 0 { + if !bytes.Equal(r.Upper, tt.result.Upper) { t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) } } diff --git a/pgtype/record.go b/pgtype/record.go deleted file mode 100644 index aeca1c546..000000000 --- a/pgtype/record.go +++ /dev/null @@ -1,129 +0,0 @@ -package pgtype - -import ( - "encoding/binary" - "reflect" - - "github.com/pkg/errors" -) - -// Record is the generic PostgreSQL record type such as is created with the -// "row" function. Record only implements BinaryEncoder and Value. The text -// format output format from PostgreSQL does not include type information and is -// therefore impossible to decode. No encoders are implemented because -// PostgreSQL does not support input of generic records. -type Record struct { - Fields []Value - Status Status -} - -func (dst *Record) Set(src interface{}) error { - if src == nil { - *dst = Record{Status: Null} - return nil - } - - switch value := src.(type) { - case []Value: - *dst = Record{Fields: value, Status: Present} - default: - return errors.Errorf("cannot convert %v to Record", src) - } - - return nil -} - -func (dst *Record) Get() interface{} { - switch dst.Status { - case Present: - return dst.Fields - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Record) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *[]Value: - *v = make([]Value, len(src.Fields)) - copy(*v, src.Fields) - return nil - case *[]interface{}: - *v = make([]interface{}, len(src.Fields)) - for i := range *v { - (*v)[i] = src.Fields[i].Get() - } - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Record{Status: Null} - return nil - } - - rp := 0 - - if len(src[rp:]) < 4 { - return errors.Errorf("Record incomplete %v", src) - } - fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - - fields := make([]Value, fieldCount) - - for i := 0; i < fieldCount; i++ { - if len(src[rp:]) < 8 { - return errors.Errorf("Record incomplete %v", src) - } - fieldOID := OID(binary.BigEndian.Uint32(src[rp:])) - rp += 4 - - fieldLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - - var binaryDecoder BinaryDecoder - if dt, ok := ci.DataTypeForOID(fieldOID); ok { - binaryDecoder, _ = dt.Value.(BinaryDecoder) - } - if binaryDecoder == nil { - return errors.Errorf("unknown oid while decoding record: %v", fieldOID) - } - - var fieldBytes []byte - if fieldLen >= 0 { - if len(src[rp:]) < fieldLen { - return errors.Errorf("Record incomplete %v", src) - } - fieldBytes = src[rp : rp+fieldLen] - rp += fieldLen - } - - // Duplicate struct to scan into - binaryDecoder = reflect.New(reflect.ValueOf(binaryDecoder).Elem().Type()).Interface().(BinaryDecoder) - - if err := binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { - return err - } - - fields[i] = binaryDecoder.(Value) - } - - *dst = Record{Fields: fields, Status: Present} - - return nil -} diff --git a/pgtype/record_codec.go b/pgtype/record_codec.go new file mode 100644 index 000000000..90b9bd4bb --- /dev/null +++ b/pgtype/record_codec.go @@ -0,0 +1,124 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" +) + +// ArrayGetter is a type that can be converted into a PostgreSQL array. + +// RecordCodec is a codec for the generic PostgreSQL record type such as is created with the "row" function. Record can +// only decode the binary format. The text format output format from PostgreSQL does not include type information and +// is therefore impossible to decode. Encoding is impossible because PostgreSQL does not support input of generic +// records. +type RecordCodec struct{} + +func (RecordCodec) FormatSupported(format int16) bool { + return format == BinaryFormatCode +} + +func (RecordCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (RecordCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + return nil +} + +func (RecordCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + if format == BinaryFormatCode { + switch target.(type) { + case CompositeIndexScanner: + return &scanPlanBinaryRecordToCompositeIndexScanner{m: m} + } + } + + return nil +} + +type scanPlanBinaryRecordToCompositeIndexScanner struct { + m *Map +} + +func (plan *scanPlanBinaryRecordToCompositeIndexScanner) Scan(src []byte, target any) error { + targetScanner := (target).(CompositeIndexScanner) + + if src == nil { + return targetScanner.ScanNull() + } + + scanner := NewCompositeBinaryScanner(plan.m, src) + for i := 0; scanner.Next(); i++ { + fieldTarget := targetScanner.ScanIndex(i) + if fieldTarget != nil { + fieldPlan := plan.m.PlanScan(scanner.OID(), BinaryFormatCode, fieldTarget) + if fieldPlan == nil { + return fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), fieldTarget) + } + + err := fieldPlan.Scan(scanner.Bytes(), fieldTarget) + if err != nil { + return err + } + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil +} + +func (RecordCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } +} + +func (RecordCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + scanner := NewCompositeBinaryScanner(m, src) + values := make([]any, scanner.FieldCount()) + for i := 0; scanner.Next(); i++ { + var v any + fieldPlan := m.PlanScan(scanner.OID(), BinaryFormatCode, &v) + if fieldPlan == nil { + return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v) + } + + err := fieldPlan.Scan(scanner.Bytes(), &v) + if err != nil { + return nil, err + } + + values[i] = v + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return values, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } +} diff --git a/pgtype/record_codec_test.go b/pgtype/record_codec_test.go new file mode 100644 index 000000000..2189f99c1 --- /dev/null +++ b/pgtype/record_codec_test.go @@ -0,0 +1,73 @@ +package pgtype_test + +import ( + "context" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/require" +) + +func TestRecordCodec(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var a string + var b int32 + err := conn.QueryRow(ctx, `select row('foo'::text, 42::int4)`).Scan(pgtype.CompositeFields{&a, &b}) + require.NoError(t, err) + + require.Equal(t, "foo", a) + require.Equal(t, int32(42), b) + }) +} + +func TestRecordCodecDecodeValue(t *testing.T) { + skipCockroachDB(t, "Server converts row int4 to int8") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + for _, tt := range []struct { + sql string + expected any + }{ + { + sql: `select row()`, + expected: []any{}, + }, + { + sql: `select row('foo'::text, 42::int4)`, + expected: []any{"foo", int32(42)}, + }, + { + sql: `select row(100.0::float4, 1.09::float4)`, + expected: []any{float32(100), float32(1.09)}, + }, + { + sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, + expected: []any{"foo", []any{int32(1), int32(2), nil, int32(4)}, int32(42)}, + }, + { + sql: `select row(null)`, + expected: []any{nil}, + }, + { + sql: `select null::record`, + expected: nil, + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(context.Background(), tt.sql) + require.NoError(t, err) + defer rows.Close() + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } + }) +} diff --git a/pgtype/record_test.go b/pgtype/record_test.go deleted file mode 100644 index 23ec2cd33..000000000 --- a/pgtype/record_test.go +++ /dev/null @@ -1,183 +0,0 @@ -package pgtype_test - -import ( - "fmt" - "reflect" - "testing" - - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestRecordTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustClose(t, conn) - - tests := []struct { - sql string - expected pgtype.Record - }{ - { - sql: `select row()`, - expected: pgtype.Record{ - Fields: []pgtype.Value{}, - Status: pgtype.Present, - }, - }, - { - sql: `select row('foo'::text, 42::int4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, - }, - Status: pgtype.Present, - }, - }, - { - sql: `select row(100.0::float4, 1.09::float4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Float4{Float: 100, Status: pgtype.Present}, - &pgtype.Float4{Float: 1.09, Status: pgtype.Present}, - }, - Status: pgtype.Present, - }, - }, - { - sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: 4, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, - }, - Status: pgtype.Present, - }, - }, - { - sql: `select row(null)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Unknown{Status: pgtype.Null}, - }, - Status: pgtype.Present, - }, - }, - { - sql: `select null::record`, - expected: pgtype.Record{ - Status: pgtype.Null, - }, - }, - } - - for i, tt := range tests { - psName := fmt.Sprintf("test%d", i) - ps, err := conn.Prepare(psName, tt.sql) - if err != nil { - t.Fatal(err) - } - ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode - - var result pgtype.Record - if err := conn.QueryRow(psName).Scan(&result); err != nil { - t.Errorf("%d: %v", i, err) - continue - } - - if !reflect.DeepEqual(tt.expected, result) { - t.Errorf("%d: expected %#v, got %#v", i, tt.expected, result) - } - } -} - -func TestRecordWithUnknownOID(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustClose(t, conn) - - _, err := conn.Exec(`drop type if exists floatrange; - -create type floatrange as range ( - subtype = float8, - subtype_diff = float8mi -);`) - if err != nil { - t.Fatal(err) - } - defer conn.Exec("drop type floatrange") - - var result pgtype.Record - err = conn.QueryRow("select row('foo'::text, floatrange(1, 10), 'bar'::text)").Scan(&result) - if err == nil { - t.Errorf("expected error but none") - } -} - -func TestRecordAssignTo(t *testing.T) { - var valueSlice []pgtype.Value - var interfaceSlice []interface{} - - simpleTests := []struct { - src pgtype.Record - dst interface{} - expected interface{} - }{ - { - src: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, - }, - Status: pgtype.Present, - }, - dst: &valueSlice, - expected: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, - }, - }, - { - src: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, - }, - Status: pgtype.Present, - }, - dst: &interfaceSlice, - expected: []interface{}{"foo", int32(42)}, - }, - { - src: pgtype.Record{Status: pgtype.Null}, - dst: &valueSlice, - expected: (([]pgtype.Value)(nil)), - }, - { - src: pgtype.Record{Status: pgtype.Null}, - dst: &interfaceSlice, - expected: (([]interface{})(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/pgtype/register_default_pg_types.go b/pgtype/register_default_pg_types.go new file mode 100644 index 000000000..be1ca4a18 --- /dev/null +++ b/pgtype/register_default_pg_types.go @@ -0,0 +1,35 @@ +//go:build !nopgxregisterdefaulttypes + +package pgtype + +func registerDefaultPgTypeVariants[T any](m *Map, name string) { + arrayName := "_" + name + + var value T + m.RegisterDefaultPgType(value, name) // T + m.RegisterDefaultPgType(&value, name) // *T + + var sliceT []T + m.RegisterDefaultPgType(sliceT, arrayName) // []T + m.RegisterDefaultPgType(&sliceT, arrayName) // *[]T + + var slicePtrT []*T + m.RegisterDefaultPgType(slicePtrT, arrayName) // []*T + m.RegisterDefaultPgType(&slicePtrT, arrayName) // *[]*T + + var arrayOfT Array[T] + m.RegisterDefaultPgType(arrayOfT, arrayName) // Array[T] + m.RegisterDefaultPgType(&arrayOfT, arrayName) // *Array[T] + + var arrayOfPtrT Array[*T] + m.RegisterDefaultPgType(arrayOfPtrT, arrayName) // Array[*T] + m.RegisterDefaultPgType(&arrayOfPtrT, arrayName) // *Array[*T] + + var flatArrayOfT FlatArray[T] + m.RegisterDefaultPgType(flatArrayOfT, arrayName) // FlatArray[T] + m.RegisterDefaultPgType(&flatArrayOfT, arrayName) // *FlatArray[T] + + var flatArrayOfPtrT FlatArray[*T] + m.RegisterDefaultPgType(flatArrayOfPtrT, arrayName) // FlatArray[*T] + m.RegisterDefaultPgType(&flatArrayOfPtrT, arrayName) // *FlatArray[*T] +} diff --git a/pgtype/register_default_pg_types_disabled.go b/pgtype/register_default_pg_types_disabled.go new file mode 100644 index 000000000..56fe7c226 --- /dev/null +++ b/pgtype/register_default_pg_types_disabled.go @@ -0,0 +1,6 @@ +//go:build nopgxregisterdefaulttypes + +package pgtype + +func registerDefaultPgTypeVariants[T any](m *Map, name string) { +} diff --git a/pgtype/testutil/testutil.go b/pgtype/testutil/testutil.go deleted file mode 100644 index 0effb42d9..000000000 --- a/pgtype/testutil/testutil.go +++ /dev/null @@ -1,297 +0,0 @@ -package testutil - -import ( - "context" - "database/sql" - "fmt" - "os" - "reflect" - "testing" - - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" - _ "github.com/jackc/pgx/stdlib" - _ "github.com/lib/pq" -) - -func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { - var sqlDriverName string - switch driverName { - case "github.com/lib/pq": - sqlDriverName = "postgres" - case "github.com/jackc/pgx/stdlib": - sqlDriverName = "pgx" - default: - t.Fatalf("Unknown driver %v", driverName) - } - - db, err := sql.Open(sqlDriverName, os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - t.Fatal(err) - } - - return db -} - -func MustConnectPgx(t testing.TB) *pgx.Conn { - config, err := pgx.ParseConnectionString(os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - t.Fatal(err) - } - - conn, err := pgx.Connect(config) - if err != nil { - t.Fatal(err) - } - - return conn -} - -func MustClose(t testing.TB, conn interface { - Close() error -}) { - err := conn.Close() - if err != nil { - t.Fatal(err) - } -} - -type forceTextEncoder struct { - e pgtype.TextEncoder -} - -func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - return f.e.EncodeText(ci, buf) -} - -type forceBinaryEncoder struct { - e pgtype.BinaryEncoder -} - -func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - return f.e.EncodeBinary(ci, buf) -} - -func ForceEncoder(e interface{}, formatCode int16) interface{} { - switch formatCode { - case pgx.TextFormatCode: - if e, ok := e.(pgtype.TextEncoder); ok { - return forceTextEncoder{e: e} - } - case pgx.BinaryFormatCode: - if e, ok := e.(pgtype.BinaryEncoder); ok { - return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} - } - } - return nil -} - -func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { - TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - }) -} - -func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) - } -} - -func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := MustConnectPgx(t) - defer MustClose(t, conn) - - ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for i, v := range values { - for _, fc := range formats { - ps.FieldDescriptions[0].FormatCode = fc.formatCode - vEncoder := ForceEncoder(v, fc.formatCode) - if vEncoder == nil { - t.Logf("Skipping: %#v does not implement %v", v, fc.name) - continue - } - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err := conn.QueryRow("test", ForceEncoder(v, fc.formatCode)).Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", fc.name, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) - } - } - } -} - -func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := MustConnectPgx(t) - defer MustClose(t, conn) - - for i, v := range values { - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err := conn.QueryRowEx( - context.Background(), - fmt.Sprintf("select ($1)::%s", pgTypeName), - &pgx.QueryExOptions{SimpleProtocol: true}, - v, - ).Scan(result.Interface()) - if err != nil { - t.Errorf("Simple protocol %d: %v", i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface()) - } - } -} - -func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := MustConnectDatabaseSQL(t, driverName) - defer MustClose(t, conn) - - ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - for i, v := range values { - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err := ps.QueryRow(v).Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", driverName, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) - } - } -} - -type NormalizeTest struct { - SQL string - Value interface{} -} - -func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) { - TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - }) -} - -func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { - TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - TestDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc) - } -} - -func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { - conn := MustConnectPgx(t) - defer MustClose(t, conn) - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for i, tt := range tests { - for _, fc := range formats { - psName := fmt.Sprintf("test%d", i) - ps, err := conn.Prepare(psName, tt.SQL) - if err != nil { - t.Fatal(err) - } - - ps.FieldDescriptions[0].FormatCode = fc.formatCode - if ForceEncoder(tt.Value, fc.formatCode) == nil { - t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name) - continue - } - // Derefence value if it is a pointer - derefV := tt.Value - refVal := reflect.ValueOf(tt.Value) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err = conn.QueryRow(psName).Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", fc.name, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) - } - } - } -} - -func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { - conn := MustConnectDatabaseSQL(t, driverName) - defer MustClose(t, conn) - - for i, tt := range tests { - ps, err := conn.Prepare(tt.SQL) - if err != nil { - t.Errorf("%d. %v", i, err) - continue - } - - // Derefence value if it is a pointer - derefV := tt.Value - refVal := reflect.ValueOf(tt.Value) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err = ps.QueryRow().Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", driverName, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) - } - } -} diff --git a/pgtype/text.go b/pgtype/text.go index bceeffd40..e08b12549 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -3,161 +3,224 @@ package pgtype import ( "database/sql/driver" "encoding/json" - - "github.com/pkg/errors" + "fmt" ) +type TextScanner interface { + ScanText(v Text) error +} + +type TextValuer interface { + TextValue() (Text, error) +} + type Text struct { String string - Status Status + Valid bool } -func (dst *Text) Set(src interface{}) error { +// ScanText implements the [TextScanner] interface. +func (t *Text) ScanText(v Text) error { + *t = v + return nil +} + +// TextValue implements the [TextValuer] interface. +func (t Text) TextValue() (Text, error) { + return t, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *Text) Scan(src any) error { if src == nil { - *dst = Text{Status: Null} + *dst = Text{} return nil } - switch value := src.(type) { + switch src := src.(type) { case string: - *dst = Text{String: value, Status: Present} - case *string: - if value == nil { - *dst = Text{Status: Null} - } else { - *dst = Text{String: *value, Status: Present} - } + *dst = Text{String: src, Valid: true} + return nil case []byte: - if value == nil { - *dst = Text{Status: Null} - } else { - *dst = Text{String: string(value), Status: Present} - } - default: - if originalSrc, ok := underlyingStringType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Text", value) + *dst = Text{String: string(src), Valid: true} + return nil } - return nil + return fmt.Errorf("cannot scan %T", src) } -func (dst *Text) Get() interface{} { - switch dst.Status { - case Present: - return dst.String - case Null: - return nil - default: - return dst.Status +// Value implements the [database/sql/driver.Valuer] interface. +func (src Text) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil } + return src.String, nil } -func (src *Text) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *string: - *v = src.String - return nil - case *[]byte: - *v = make([]byte, len(src.String)) - copy(*v, src.String) - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (src Text) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil } - return errors.Errorf("cannot decode %v into %T", src, dst) + return json.Marshal(src.String) } -func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Text{Status: Null} - return nil +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (dst *Text) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *dst = Text{} + } else { + *dst = Text{String: *s, Valid: true} } - *dst = Text{String: string(src), Status: Present} return nil } -func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { - return dst.DecodeText(ci, src) +type TextCodec struct{} + +func (TextCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode } -func (src *Text) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +func (TextCodec) PreferredFormat() int16 { + return TextFormatCode +} + +func (TextCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch value.(type) { + case string: + return encodePlanTextCodecString{} + case []byte: + return encodePlanTextCodecByteSlice{} + case TextValuer: + return encodePlanTextCodecTextValuer{} + } } - return append(buf, src.String...), nil + return nil +} + +type encodePlanTextCodecString struct{} + +func (encodePlanTextCodecString) Encode(value any, buf []byte) (newBuf []byte, err error) { + s := value.(string) + buf = append(buf, s...) + return buf, nil } -func (src *Text) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return src.EncodeText(ci, buf) +type encodePlanTextCodecByteSlice struct{} + +func (encodePlanTextCodecByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) { + s := value.([]byte) + buf = append(buf, s...) + return buf, nil } -// Scan implements the database/sql Scanner interface. -func (dst *Text) Scan(src interface{}) error { - if src == nil { - *dst = Text{Status: Null} - return nil +type encodePlanTextCodecStringer struct{} + +func (encodePlanTextCodecStringer) Encode(value any, buf []byte) (newBuf []byte, err error) { + s := value.(fmt.Stringer) + buf = append(buf, s.String()...) + return buf, nil +} + +type encodePlanTextCodecTextValuer struct{} + +func (encodePlanTextCodecTextValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + text, err := value.(TextValuer).TextValue() + if err != nil { + return nil, err } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + if !text.Valid { + return nil, nil } - return errors.Errorf("cannot scan %T", src) + buf = append(buf, text.String...) + return buf, nil } -// Value implements the database/sql/driver Valuer interface. -func (src *Text) Value() (driver.Value, error) { - switch src.Status { - case Present: - return src.String, nil - case Null: - return nil, nil - default: - return nil, errUndefined +func (TextCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch target.(type) { + case *string: + return scanPlanTextAnyToString{} + case *[]byte: + return scanPlanAnyToNewByteSlice{} + case BytesScanner: + return scanPlanAnyToByteScanner{} + case TextScanner: + return scanPlanTextAnyToTextScanner{} + } } + + return nil } -func (src *Text) MarshalJSON() ([]byte, error) { - switch src.Status { - case Present: - return json.Marshal(src.String) - case Null: - return []byte("null"), nil - case Undefined: - return nil, errUndefined +func (c TextCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(m, oid, format, src) +} + +func (c TextCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil } - return nil, errBadStatus + return string(src), nil } -func (dst *Text) UnmarshalJSON(b []byte) error { - var s string - err := json.Unmarshal(b, &s) - if err != nil { - return err +type scanPlanTextAnyToString struct{} + +func (scanPlanTextAnyToString) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) } - *dst = Text{String: s, Status: Present} + p := (dst).(*string) + *p = string(src) + + return nil +} + +type scanPlanAnyToNewByteSlice struct{} + +func (scanPlanAnyToNewByteSlice) Scan(src []byte, dst any) error { + p := (dst).(*[]byte) + if src == nil { + *p = nil + } else { + *p = make([]byte, len(src)) + copy(*p, src) + } return nil } + +type scanPlanAnyToByteScanner struct{} + +func (scanPlanAnyToByteScanner) Scan(src []byte, dst any) error { + p := (dst).(BytesScanner) + return p.ScanBytes(src) +} + +type scanPlanTextAnyToTextScanner struct{} + +func (scanPlanTextAnyToTextScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TextScanner) + + if src == nil { + return scanner.ScanText(Text{}) + } + + return scanner.ScanText(Text{String: string(src), Valid: true}) +} diff --git a/pgtype/text_array.go b/pgtype/text_array.go deleted file mode 100644 index e40f4b863..000000000 --- a/pgtype/text_array.go +++ /dev/null @@ -1,300 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type TextArray struct { - Elements []Text - Dimensions []ArrayDimension - Status Status -} - -func (dst *TextArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = TextArray{Status: Null} - return nil - } - - switch value := src.(type) { - - case []string: - if value == nil { - *dst = TextArray{Status: Null} - } else if len(value) == 0 { - *dst = TextArray{Status: Present} - } else { - elements := make([]Text, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TextArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to TextArray", value) - } - - return nil -} - -func (dst *TextArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *TextArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TextArray{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Text - - if len(uta.Elements) > 0 { - elements = make([]Text, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Text - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = TextArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TextArray{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = TextArray{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Text, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = TextArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `"NULL"`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("text"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "text") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *TextArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *TextArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/text_array_test.go b/pgtype/text_array_test.go deleted file mode 100644 index 105d93534..000000000 --- a/pgtype/text_array_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestTextArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "text[]", []interface{}{ - &pgtype.TextArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "foo", Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.TextArray{Status: pgtype.Null}, - &pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "bar ", Status: pgtype.Present}, - {String: "NuLL", Status: pgtype.Present}, - {String: `wow"quz\`, Status: pgtype.Present}, - {String: "", Status: pgtype.Present}, - {Status: pgtype.Null}, - {String: "null", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.TextArray{ - Elements: []pgtype.Text{ - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "quz", Status: pgtype.Present}, - {String: "foo", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestTextArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.TextArray - }{ - { - source: []string{"foo"}, - result: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]string)(nil)), - result: pgtype.TextArray{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.TextArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestTextArrayAssignTo(t *testing.T) { - var stringSlice []string - type _stringSlice []string - var namedStringSlice _stringSlice - - simpleTests := []struct { - src pgtype.TextArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - expected: []string{"foo"}, - }, - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedStringSlice, - expected: _stringSlice{"bar"}, - }, - { - src: pgtype.TextArray{Status: pgtype.Null}, - dst: &stringSlice, - expected: (([]string)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.TextArray - dst interface{} - }{ - { - src: pgtype.TextArray{ - Elements: []pgtype.Text{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/text_format_only_codec.go b/pgtype/text_format_only_codec.go new file mode 100644 index 000000000..d5e4cdb38 --- /dev/null +++ b/pgtype/text_format_only_codec.go @@ -0,0 +1,13 @@ +package pgtype + +type TextFormatOnlyCodec struct { + Codec +} + +func (c *TextFormatOnlyCodec) FormatSupported(format int16) bool { + return format == TextFormatCode && c.Codec.FormatSupported(format) +} + +func (TextFormatOnlyCodec) PreferredFormat() int16 { + return TextFormatCode +} diff --git a/pgtype/text_test.go b/pgtype/text_test.go index bd9718071..eb5d005ec 100644 --- a/pgtype/text_test.go +++ b/pgtype/text_test.go @@ -1,123 +1,178 @@ package pgtype_test import ( - "bytes" - "reflect" + "context" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" ) -func TestTextTranscode(t *testing.T) { +type someFmtStringer struct{} + +func (someFmtStringer) String() string { + return "some fmt.Stringer" +} + +func TestTextCodec(t *testing.T) { for _, pgTypeName := range []string{"text", "varchar"} { - testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ - &pgtype.Text{String: "", Status: pgtype.Present}, - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Text{Status: pgtype.Null}, + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, pgTypeName, []pgxtest.ValueRoundTripTest{ + { + pgtype.Text{String: "", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "", Valid: true}), + }, + { + pgtype.Text{String: "foo", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "foo", Valid: true}), + }, + {nil, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, + {"foo", new(string), isExpectedEq("foo")}, + {someFmtStringer{}, new(string), isExpectedEq("some fmt.Stringer")}, }) } } -func TestTextSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Text - }{ - {source: "foo", result: pgtype.Text{String: "foo", Status: pgtype.Present}}, - {source: _string("bar"), result: pgtype.Text{String: "bar", Status: pgtype.Present}}, - {source: (*string)(nil), result: pgtype.Text{Status: pgtype.Null}}, - } +// name is PostgreSQL's special 63-byte data type, used for identifiers like table names. The pg_class.relname column +// is a good example of where the name data type is used. +// +// TextCodec does not do length checking. Inputting a longer name into PostgreSQL will result in silent truncation to +// 63 bytes. +// +// Length checking would be possible with a Codec specialized for "name" but it would be perfect because a +// custom-compiled PostgreSQL could have set NAMEDATALEN to a different value rather than the default 63. +// +// So this is simply a smoke test of the name type. +func TestTextCodecName(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "name", []pgxtest.ValueRoundTripTest{ + { + pgtype.Text{String: "", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "", Valid: true}), + }, + { + pgtype.Text{String: "foo", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "foo", Valid: true}), + }, + {nil, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, + {"foo", new(string), isExpectedEq("foo")}, + }) +} - for i, tt := range successfulTests { - var d pgtype.Text - err := d.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } +// Test fixed length char types like char(3) +func TestTextCodecBPChar(t *testing.T) { + skipCockroachDB(t, "Server does not properly handle bpchar with multi-byte character") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "char(3)", []pgxtest.ValueRoundTripTest{ + { + pgtype.Text{String: "a ", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "a ", Valid: true}), + }, + {nil, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, + {" ", new(string), isExpectedEq(" ")}, + {"", new(string), isExpectedEq(" ")}, + {" 嗨 ", new(string), isExpectedEq(" 嗨 ")}, + }) +} - if d != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) - } +// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem +// might look like this: +// +// postgres=arwdDxt/postgres +// +// Note, however, that because the user/role name part of an aclitem is +// an identifier, it follows all the usual formatting rules for SQL +// identifiers: if it contains spaces and other special characters, +// it should appear in double-quotes: +// +// postgres=arwdDxt/"role with spaces" +// +// It only supports the text format. +func TestTextCodecACLItem(t *testing.T) { + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does not support type aclitem") } + + pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, nil, "aclitem", []pgxtest.ValueRoundTripTest{ + { + pgtype.Text{String: "postgres=arwdDxt/postgres", Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: "postgres=arwdDxt/postgres", Valid: true}), + }, + {pgtype.Text{}, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, + {nil, new(pgtype.Text), isExpectedEq(pgtype.Text{})}, + }) } -func TestTextAssignTo(t *testing.T) { - var s string - var ps *string +func TestTextCodecACLItemRoleWithSpecialCharacters(t *testing.T) { + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does not support type aclitem") - stringTests := []struct { - src pgtype.Text - dst interface{} - expected interface{} - }{ - {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &s, expected: "foo"}, - {src: pgtype.Text{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, - } + // The tricky test user, below, has to actually exist so that it can be used in a test + // of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. + roleWithSpecialCharacters := ` tricky, ' } " \ test user ` - for i, tt := range stringTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } + commandTag, err := conn.Exec(ctx, `select * from pg_roles where rolname = $1`, roleWithSpecialCharacters) + require.NoError(t, err) - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + if commandTag.RowsAffected() == 0 { + t.Skipf("Role with special characters does not exist.") } } - var buf []byte + pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, nil, "aclitem", []pgxtest.ValueRoundTripTest{ + { + pgtype.Text{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}, + new(pgtype.Text), + isExpectedEq(pgtype.Text{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}), + }, + }) +} - bytesTests := []struct { - src pgtype.Text - dst *[]byte - expected []byte +func TestTextMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Text + result string }{ - {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &buf, expected: []byte("foo")}, - {src: pgtype.Text{Status: pgtype.Null}, dst: &buf, expected: nil}, + {source: pgtype.Text{String: ""}, result: "null"}, + {source: pgtype.Text{String: "a", Valid: true}, result: "\"a\""}, } - - for i, tt := range bytesTests { - err := tt.src.AssignTo(tt.dst) + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() if err != nil { t.Errorf("%d: %v", i, err) } - if bytes.Compare(*tt.dst, tt.expected) != 0 { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst) + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) } } +} - pointerAllocTests := []struct { - src pgtype.Text - dst interface{} - expected interface{} +func TestTextUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Text }{ - {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &ps, expected: "foo"}, + {source: "null", result: pgtype.Text{String: ""}}, + {source: "\"a\"", result: pgtype.Text{String: "a", Valid: true}}, } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) + for i, tt := range successfulTests { + var r pgtype.Text + err := r.UnmarshalJSON([]byte(tt.source)) if err != nil { t.Errorf("%d: %v", i, err) } - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Text - dst interface{} - }{ - {src: pgtype.Text{Status: pgtype.Null}, dst: &s}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) } } } diff --git a/pgtype/tid.go b/pgtype/tid.go index 21852a144..05c9e6d98 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -7,15 +7,22 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) +type TIDScanner interface { + ScanTID(v TID) error +} + +type TIDValuer interface { + TIDValue() (TID, error) +} + // TID is PostgreSQL's Tuple Identifier type. // // When one does // -// select ctid, * from some_table; +// select ctid, * from some_table; // // it is the data type of the ctid hidden system column. // @@ -25,120 +32,211 @@ import ( type TID struct { BlockNumber uint32 OffsetNumber uint16 - Status Status + Valid bool } -func (dst *TID) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to TID", src) +// ScanTID implements the [TIDScanner] interface. +func (b *TID) ScanTID(v TID) error { + *b = v + return nil } -func (dst *TID) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: +// TIDValue implements the [TIDValuer] interface. +func (b TID) TIDValue() (TID, error) { + return b, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *TID) Scan(src any) error { + if src == nil { + *dst = TID{} return nil - default: - return dst.Status } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToTIDScanner{}.Scan([]byte(src), dst) + } + + return fmt.Errorf("cannot scan %T", src) } -func (src *TID) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) +// Value implements the [database/sql/driver.Valuer] interface. +func (src TID) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + buf, err := TIDCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) + if err != nil { + return nil, err + } + return string(buf), err } -func (dst *TID) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TID{Status: Null} +type TIDCodec struct{} + +func (TIDCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (TIDCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (TIDCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(TIDValuer); !ok { return nil } - if len(src) < 5 { - return errors.Errorf("invalid length for tid: %v", len(src)) + switch format { + case BinaryFormatCode: + return encodePlanTIDCodecBinary{} + case TextFormatCode: + return encodePlanTIDCodecText{} } - parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) - if len(parts) < 2 { - return errors.Errorf("invalid format for tid") - } + return nil +} - blockNumber, err := strconv.ParseUint(parts[0], 10, 32) +type encodePlanTIDCodecBinary struct{} + +func (encodePlanTIDCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + tid, err := value.(TIDValuer).TIDValue() if err != nil { - return err + return nil, err + } + + if !tid.Valid { + return nil, nil } - offsetNumber, err := strconv.ParseUint(parts[1], 10, 16) + buf = pgio.AppendUint32(buf, tid.BlockNumber) + buf = pgio.AppendUint16(buf, tid.OffsetNumber) + return buf, nil +} + +type encodePlanTIDCodecText struct{} + +func (encodePlanTIDCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + tid, err := value.(TIDValuer).TIDValue() if err != nil { - return err + return nil, err + } + + if !tid.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`(%d,%d)`, tid.BlockNumber, tid.OffsetNumber)...) + return buf, nil +} + +func (TIDCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case TIDScanner: + return scanPlanBinaryTIDToTIDScanner{} + case TextScanner: + return scanPlanBinaryTIDToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case TIDScanner: + return scanPlanTextAnyToTIDScanner{} + } } - *dst = TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Status: Present} return nil } -func (dst *TID) DecodeBinary(ci *ConnInfo, src []byte) error { +type scanPlanBinaryTIDToTIDScanner struct{} + +func (scanPlanBinaryTIDToTIDScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TIDScanner) + if src == nil { - *dst = TID{Status: Null} - return nil + return scanner.ScanTID(TID{}) } if len(src) != 6 { - return errors.Errorf("invalid length for tid: %v", len(src)) + return fmt.Errorf("invalid length for tid: %v", len(src)) } - *dst = TID{ + return scanner.ScanTID(TID{ BlockNumber: binary.BigEndian.Uint32(src), OffsetNumber: binary.BigEndian.Uint16(src[4:]), - Status: Present, - } - return nil + Valid: true, + }) } -func (src *TID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } +type scanPlanBinaryTIDToTextScanner struct{} - buf = append(buf, fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber)...) - return buf, nil -} +func (scanPlanBinaryTIDToTextScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TextScanner) -func (src *TID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined + if src == nil { + return scanner.ScanText(Text{}) } - buf = pgio.AppendUint32(buf, src.BlockNumber) - buf = pgio.AppendUint16(buf, src.OffsetNumber) - return buf, nil + if len(src) != 6 { + return fmt.Errorf("invalid length for tid: %v", len(src)) + } + + blockNumber := binary.BigEndian.Uint32(src) + offsetNumber := binary.BigEndian.Uint16(src[4:]) + + return scanner.ScanText(Text{ + String: fmt.Sprintf(`(%d,%d)`, blockNumber, offsetNumber), + Valid: true, + }) } -// Scan implements the database/sql Scanner interface. -func (dst *TID) Scan(src interface{}) error { +type scanPlanTextAnyToTIDScanner struct{} + +func (scanPlanTextAnyToTIDScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TIDScanner) + if src == nil { - *dst = TID{Status: Null} - return nil + return scanner.ScanTID(TID{}) } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + if len(src) < 5 { + return fmt.Errorf("invalid length for tid: %v", len(src)) } - return errors.Errorf("cannot scan %T", src) + block, offset, found := strings.Cut(string(src[1:len(src)-1]), ",") + if !found { + return fmt.Errorf("invalid format for tid") + } + + blockNumber, err := strconv.ParseUint(block, 10, 32) + if err != nil { + return err + } + + offsetNumber, err := strconv.ParseUint(offset, 10, 16) + if err != nil { + return err + } + + return scanner.ScanTID(TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Valid: true}) } -// Value implements the database/sql/driver Valuer interface. -func (src *TID) Value() (driver.Value, error) { - return EncodeValueText(src) +func (c TIDCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c TIDCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var tid TID + err := codecScan(c, m, oid, format, src, &tid) + if err != nil { + return nil, err + } + return tid, nil } diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go index 9185cb31c..3e7a1a50c 100644 --- a/pgtype/tid_test.go +++ b/pgtype/tid_test.go @@ -1,16 +1,38 @@ package pgtype_test import ( + "context" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" ) -func TestTIDTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "tid", []interface{}{ - &pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, - &pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, - &pgtype.TID{Status: pgtype.Null}, +func TestTIDCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support type tid") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "tid", []pgxtest.ValueRoundTripTest{ + { + pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, + new(pgtype.TID), + isExpectedEq(pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}), + }, + { + pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, + new(pgtype.TID), + isExpectedEq(pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}), + }, + { + pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, + new(string), + isExpectedEq("(42,43)"), + }, + { + pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, + new(string), + isExpectedEq("(4294967295,65535)"), + }, + {pgtype.TID{}, new(pgtype.TID), isExpectedEq(pgtype.TID{})}, + {nil, new(pgtype.TID), isExpectedEq(pgtype.TID{})}, }) } diff --git a/pgtype/time.go b/pgtype/time.go new file mode 100644 index 000000000..4b8f69083 --- /dev/null +++ b/pgtype/time.go @@ -0,0 +1,275 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type TimeScanner interface { + ScanTime(v Time) error +} + +type TimeValuer interface { + TimeValue() (Time, error) +} + +// Time represents the PostgreSQL time type. The PostgreSQL time is a time of day without time zone. +// +// Time is represented as the number of microseconds since midnight in the same way that PostgreSQL does. Other time and +// date types in pgtype can use time.Time as the underlying representation. However, pgtype.Time type cannot due to +// needing to handle 24:00:00. time.Time converts that to 00:00:00 on the following day. +// +// The time with time zone type is not supported. Use of time with time zone is discouraged by the PostgreSQL documentation. +type Time struct { + Microseconds int64 // Number of microseconds since midnight + Valid bool +} + +// ScanTime implements the [TimeScanner] interface. +func (t *Time) ScanTime(v Time) error { + *t = v + return nil +} + +// TimeValue implements the [TimeValuer] interface. +func (t Time) TimeValue() (Time, error) { + return t, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (t *Time) Scan(src any) error { + if src == nil { + *t = Time{} + return nil + } + + switch src := src.(type) { + case string: + err := scanPlanTextAnyToTimeScanner{}.Scan([]byte(src), t) + if err != nil { + t.Microseconds = 0 + t.Valid = false + } + return err + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (t Time) Value() (driver.Value, error) { + if !t.Valid { + return nil, nil + } + + buf, err := TimeCodec{}.PlanEncode(nil, 0, TextFormatCode, t).Encode(t, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type TimeCodec struct{} + +func (TimeCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (TimeCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (TimeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(TimeValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanTimeCodecBinary{} + case TextFormatCode: + return encodePlanTimeCodecText{} + } + + return nil +} + +type encodePlanTimeCodecBinary struct{} + +func (encodePlanTimeCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + t, err := value.(TimeValuer).TimeValue() + if err != nil { + return nil, err + } + + if !t.Valid { + return nil, nil + } + + return pgio.AppendInt64(buf, t.Microseconds), nil +} + +type encodePlanTimeCodecText struct{} + +func (encodePlanTimeCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + t, err := value.(TimeValuer).TimeValue() + if err != nil { + return nil, err + } + + if !t.Valid { + return nil, nil + } + + usec := t.Microseconds + hours := usec / microsecondsPerHour + usec -= hours * microsecondsPerHour + minutes := usec / microsecondsPerMinute + usec -= minutes * microsecondsPerMinute + seconds := usec / microsecondsPerSecond + usec -= seconds * microsecondsPerSecond + + s := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, usec) + + return append(buf, s...), nil +} + +func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case TimeScanner: + return scanPlanBinaryTimeToTimeScanner{} + case TextScanner: + return scanPlanBinaryTimeToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case TimeScanner: + return scanPlanTextAnyToTimeScanner{} + } + } + + return nil +} + +type scanPlanBinaryTimeToTimeScanner struct{} + +func (scanPlanBinaryTimeToTimeScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TimeScanner) + + if src == nil { + return scanner.ScanTime(Time{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for time: %v", len(src)) + } + + usec := int64(binary.BigEndian.Uint64(src)) + + return scanner.ScanTime(Time{Microseconds: usec, Valid: true}) +} + +type scanPlanBinaryTimeToTextScanner struct{} + +func (scanPlanBinaryTimeToTextScanner) Scan(src []byte, dst any) error { + ts, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return ts.ScanText(Text{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for time: %v", len(src)) + } + + usec := int64(binary.BigEndian.Uint64(src)) + + tim := Time{Microseconds: usec, Valid: true} + + buf, err := TimeCodec{}.PlanEncode(nil, 0, TextFormatCode, tim).Encode(tim, nil) + if err != nil { + return err + } + + return ts.ScanText(Text{String: string(buf), Valid: true}) +} + +type scanPlanTextAnyToTimeScanner struct{} + +func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TimeScanner) + + if src == nil { + return scanner.ScanTime(Time{}) + } + + s := string(src) + + if len(s) < 8 || s[2] != ':' || s[5] != ':' { + return fmt.Errorf("cannot decode %v into Time", s) + } + + hours, err := strconv.ParseInt(s[0:2], 10, 64) + if err != nil { + return fmt.Errorf("cannot decode %v into Time", s) + } + usec := hours * microsecondsPerHour + + minutes, err := strconv.ParseInt(s[3:5], 10, 64) + if err != nil { + return fmt.Errorf("cannot decode %v into Time", s) + } + usec += minutes * microsecondsPerMinute + + seconds, err := strconv.ParseInt(s[6:8], 10, 64) + if err != nil { + return fmt.Errorf("cannot decode %v into Time", s) + } + usec += seconds * microsecondsPerSecond + + if len(s) > 9 { + if s[8] != '.' || len(s) > 15 { + return fmt.Errorf("cannot decode %v into Time", s) + } + + fraction := s[9:] + n, err := strconv.ParseInt(fraction, 10, 64) + if err != nil { + return fmt.Errorf("cannot decode %v into Time", s) + } + + for i := len(fraction); i < 6; i++ { + n *= 10 + } + + usec += n + } + + return scanner.ScanTime(Time{Microseconds: usec, Valid: true}) +} + +func (c TimeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c TimeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var t Time + err := codecScan(c, m, oid, format, src, &t) + if err != nil { + return nil, err + } + return t, nil +} diff --git a/pgtype/time_test.go b/pgtype/time_test.go new file mode 100644 index 000000000..06970bacd --- /dev/null +++ b/pgtype/time_test.go @@ -0,0 +1,115 @@ +package pgtype_test + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" +) + +func TestTimeCodec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "time", []pgxtest.ValueRoundTripTest{ + { + pgtype.Time{Microseconds: 0, Valid: true}, + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 0, Valid: true}), + }, + { + pgtype.Time{Microseconds: 1, Valid: true}, + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 1, Valid: true}), + }, + { + pgtype.Time{Microseconds: 86399999999, Valid: true}, + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 86399999999, Valid: true}), + }, + { + pgtype.Time{Microseconds: 86400000000, Valid: true}, + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 86400000000, Valid: true}), + }, + { + time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 0, Valid: true}), + }, + { + pgtype.Time{Microseconds: 0, Valid: true}, + new(time.Time), + isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)), + }, + {pgtype.Time{}, new(pgtype.Time), isExpectedEq(pgtype.Time{})}, + {nil, new(pgtype.Time), isExpectedEq(pgtype.Time{})}, + }) +} + +func TestTimeTextScanner(t *testing.T) { + var pgTime pgtype.Time + + assert.NoError(t, pgTime.Scan("07:37:16")) + assert.Equal(t, true, pgTime.Valid) + assert.Equal(t, int64(7*time.Hour+37*time.Minute+16*time.Second), pgTime.Microseconds*int64(time.Microsecond)) + + assert.NoError(t, pgTime.Scan("15:04:05")) + assert.Equal(t, true, pgTime.Valid) + assert.Equal(t, int64(15*time.Hour+4*time.Minute+5*time.Second), pgTime.Microseconds*int64(time.Microsecond)) + + // parsing of fractional digits + assert.NoError(t, pgTime.Scan("15:04:05.00")) + assert.Equal(t, true, pgTime.Valid) + assert.Equal(t, int64(15*time.Hour+4*time.Minute+5*time.Second), pgTime.Microseconds*int64(time.Microsecond)) + + const mirco = "789123" + const woFraction = int64(4*time.Hour + 5*time.Minute + 6*time.Second) // time without fraction + for i := 0; i <= len(mirco); i++ { + assert.NoError(t, pgTime.Scan("04:05:06."+mirco[:i])) + assert.Equal(t, true, pgTime.Valid) + + frac, _ := strconv.ParseInt(mirco[:i], 10, 64) + for k := i; k < 6; k++ { + frac *= 10 + } + assert.Equal(t, woFraction+frac*int64(time.Microsecond), pgTime.Microseconds*int64(time.Microsecond)) + } + + // parsing of too long fraction errors + assert.Error(t, pgTime.Scan("04:05:06.7891234")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + // parsing of timetz errors + assert.Error(t, pgTime.Scan("04:05:06.789-08")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + assert.Error(t, pgTime.Scan("04:05:06-08:00")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + // parsing of date errors + assert.Error(t, pgTime.Scan("1997-12-17")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + // parsing of text errors + assert.Error(t, pgTime.Scan("12345678")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + assert.Error(t, pgTime.Scan("12-34-56")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + assert.Error(t, pgTime.Scan("12:34-56")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + assert.Error(t, pgTime.Scan("12-34:56")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) +} diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index d906f4679..861fa8838 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -3,223 +3,368 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" + "fmt" + "strings" "time" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) -const pgTimestampFormat = "2006-01-02 15:04:05.999999999" +const ( + pgTimestampFormat = "2006-01-02 15:04:05.999999999" + jsonISO8601 = "2006-01-02T15:04:05.999999999" +) + +type TimestampScanner interface { + ScanTimestamp(v Timestamp) error +} + +type TimestampValuer interface { + TimestampValue() (Timestamp, error) +} -// Timestamp represents the PostgreSQL timestamp type. The PostgreSQL -// timestamp does not have a time zone. This presents a problem when -// translating to and from time.Time which requires a time zone. It is highly -// recommended to use timestamptz whenever possible. Timestamp methods either -// convert to UTC or return an error on non-UTC times. +// Timestamp represents the PostgreSQL timestamp type. type Timestamp struct { - Time time.Time // Time must always be in UTC. - Status Status + Time time.Time // Time zone will be ignored when encoding to PostgreSQL. InfinityModifier InfinityModifier + Valid bool +} + +// ScanTimestamp implements the [TimestampScanner] interface. +func (ts *Timestamp) ScanTimestamp(v Timestamp) error { + *ts = v + return nil +} + +// TimestampValue implements the [TimestampValuer] interface. +func (ts Timestamp) TimestampValue() (Timestamp, error) { + return ts, nil } -// Set converts src into a Timestamp and stores in dst. If src is a -// time.Time in a non-UTC time zone, the time zone is discarded. -func (dst *Timestamp) Set(src interface{}) error { +// Scan implements the [database/sql.Scanner] interface. +func (ts *Timestamp) Scan(src any) error { if src == nil { - *dst = Timestamp{Status: Null} + *ts = Timestamp{} return nil } - switch value := src.(type) { + switch src := src.(type) { + case string: + return (&scanPlanTextTimestampToTimestampScanner{}).Scan([]byte(src), ts) case time.Time: - *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present} - default: - if originalSrc, ok := underlyingTimeType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Timestamp", value) + *ts = Timestamp{Time: src, Valid: true} + return nil } - return nil + return fmt.Errorf("cannot scan %T", src) } -func (dst *Timestamp) Get() interface{} { - switch dst.Status { - case Present: - if dst.InfinityModifier != None { - return dst.InfinityModifier - } - return dst.Time - case Null: - return nil - default: - return dst.Status +// Value implements the [database/sql/driver.Valuer] interface. +func (ts Timestamp) Value() (driver.Value, error) { + if !ts.Valid { + return nil, nil } + + if ts.InfinityModifier != Finite { + return ts.InfinityModifier.String(), nil + } + return ts.Time, nil } -func (src *Timestamp) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *time.Time: - if src.InfinityModifier != None { - return errors.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (ts Timestamp) MarshalJSON() ([]byte, error) { + if !ts.Valid { + return []byte("null"), nil } - return errors.Errorf("cannot decode %v into %T", src, dst) + var s string + + switch ts.InfinityModifier { + case Finite: + s = ts.Time.Format(jsonISO8601) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) } -// DecodeText decodes from src into dst. The decoded time is considered to -// be in UTC. -func (dst *Timestamp) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Timestamp{Status: Null} +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (ts *Timestamp) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *ts = Timestamp{} return nil } - sbuf := string(src) - switch sbuf { + switch *s { case "infinity": - *dst = Timestamp{Status: Present, InfinityModifier: Infinity} + *ts = Timestamp{Valid: true, InfinityModifier: Infinity} case "-infinity": - *dst = Timestamp{Status: Present, InfinityModifier: -Infinity} + *ts = Timestamp{Valid: true, InfinityModifier: -Infinity} default: - tim, err := time.Parse(pgTimestampFormat, sbuf) - if err != nil { - return err + // Parse time with or without timezonr + tss := *s + // PostgreSQL uses ISO 8601 without timezone for to_json function and casting from a string to timestampt + tim, err := time.Parse(time.RFC3339Nano, tss) + if err == nil { + *ts = Timestamp{Time: tim, Valid: true} + return nil } - - *dst = Timestamp{Time: tim, Status: Present} + tim, err = time.ParseInLocation(jsonISO8601, tss, time.UTC) + if err == nil { + *ts = Timestamp{Time: tim, Valid: true} + return nil + } + ts.Valid = false + return fmt.Errorf("cannot unmarshal %s to timestamp with layout %s or %s (%w)", + *s, time.RFC3339Nano, jsonISO8601, err) } - return nil } -// DecodeBinary decodes from src into dst. The decoded time is considered to -// be in UTC. -func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Timestamp{Status: Null} +type TimestampCodec struct { + // ScanLocation is the location that the time is assumed to be in for scanning. This is different from + // TimestamptzCodec.ScanLocation in that this setting does change the instant in time that the timestamp represents. + ScanLocation *time.Location +} + +func (*TimestampCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (*TimestampCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (*TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(TimestampValuer); !ok { return nil } - if len(src) != 8 { - return errors.Errorf("invalid length for timestamp: %v", len(src)) + switch format { + case BinaryFormatCode: + return encodePlanTimestampCodecBinary{} + case TextFormatCode: + return encodePlanTimestampCodecText{} } - microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) + return nil +} - switch microsecSinceY2K { - case infinityMicrosecondOffset: - *dst = Timestamp{Status: Present, InfinityModifier: Infinity} - case negativeInfinityMicrosecondOffset: - *dst = Timestamp{Status: Present, InfinityModifier: -Infinity} - default: - microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K - tim := time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000).UTC() - *dst = Timestamp{Time: tim, Status: Present} +type encodePlanTimestampCodecBinary struct{} + +func (encodePlanTimestampCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + ts, err := value.(TimestampValuer).TimestampValue() + if err != nil { + return nil, err } - return nil + if !ts.Valid { + return nil, nil + } + + var microsecSinceY2K int64 + switch ts.InfinityModifier { + case Finite: + t := discardTimeZone(ts.Time) + microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + buf = pgio.AppendInt64(buf, microsecSinceY2K) + + return buf, nil } -// EncodeText writes the text encoding of src into w. If src.Time is not in -// the UTC time zone it returns an error. -func (src *Timestamp) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +type encodePlanTimestampCodecText struct{} + +func (encodePlanTimestampCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + ts, err := value.(TimestampValuer).TimestampValue() + if err != nil { + return nil, err } - if src.Time.Location() != time.UTC { - return nil, errors.Errorf("cannot encode non-UTC time into timestamp") + + if !ts.Valid { + return nil, nil } var s string - switch src.InfinityModifier { - case None: - s = src.Time.Format(pgTimestampFormat) + switch ts.InfinityModifier { + case Finite: + t := discardTimeZone(ts.Time) + + // Year 0000 is 1 BC + bc := false + if year := t.Year(); year <= 0 { + year = -year + 1 + t = time.Date(year, t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC) + bc = true + } + + s = t.Truncate(time.Microsecond).Format(pgTimestampFormat) + + if bc { + s = s + " BC" + } case Infinity: s = "infinity" case NegativeInfinity: s = "-infinity" } - return append(buf, s...), nil + buf = append(buf, s...) + + return buf, nil } -// EncodeBinary writes the binary encoding of src into w. If src.Time is not in -// the UTC time zone it returns an error. -func (src *Timestamp) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - if src.Time.Location() != time.UTC { - return nil, errors.Errorf("cannot encode non-UTC time into timestamp") +func discardTimeZone(t time.Time) time.Time { + if t.Location() != time.UTC { + return time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC) } - var microsecSinceY2K int64 - switch src.InfinityModifier { - case None: - microsecSinceUnixEpoch := src.Time.Unix()*1000000 + int64(src.Time.Nanosecond())/1000 - microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K - case Infinity: - microsecSinceY2K = infinityMicrosecondOffset - case NegativeInfinity: - microsecSinceY2K = negativeInfinityMicrosecondOffset + return t +} + +func (c *TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case TimestampScanner: + return &scanPlanBinaryTimestampToTimestampScanner{location: c.ScanLocation} + } + case TextFormatCode: + switch target.(type) { + case TimestampScanner: + return &scanPlanTextTimestampToTimestampScanner{location: c.ScanLocation} + } } - return pgio.AppendInt64(buf, microsecSinceY2K), nil + return nil } -// Scan implements the database/sql Scanner interface. -func (dst *Timestamp) Scan(src interface{}) error { +type scanPlanBinaryTimestampToTimestampScanner struct{ location *time.Location } + +func (plan *scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TimestampScanner) + if src == nil { - *dst = Timestamp{Status: Null} - return nil + return scanner.ScanTimestamp(Timestamp{}) } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - case time.Time: - *dst = Timestamp{Time: src, Status: Present} - return nil + if len(src) != 8 { + return fmt.Errorf("invalid length for timestamp: %v", len(src)) } - return errors.Errorf("cannot scan %T", src) + var ts Timestamp + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + ts = Timestamp{Valid: true, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + ts = Timestamp{Valid: true, InfinityModifier: -Infinity} + default: + tim := time.Unix( + microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, + (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), + ).UTC() + if plan.location != nil { + tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), plan.location) + } + ts = Timestamp{Time: tim, Valid: true} + } + + return scanner.ScanTimestamp(ts) } -// Value implements the database/sql/driver Valuer interface. -func (src *Timestamp) Value() (driver.Value, error) { - switch src.Status { - case Present: - if src.InfinityModifier != None { - return src.InfinityModifier.String(), nil +type scanPlanTextTimestampToTimestampScanner struct{ location *time.Location } + +func (plan *scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TimestampScanner) + + if src == nil { + return scanner.ScanTimestamp(Timestamp{}) + } + + var ts Timestamp + sbuf := string(src) + switch sbuf { + case "infinity": + ts = Timestamp{Valid: true, InfinityModifier: Infinity} + case "-infinity": + ts = Timestamp{Valid: true, InfinityModifier: -Infinity} + default: + bc := false + if strings.HasSuffix(sbuf, " BC") { + sbuf = sbuf[:len(sbuf)-3] + bc = true } - return src.Time, nil - case Null: + tim, err := time.Parse(pgTimestampFormat, sbuf) + if err != nil { + return err + } + + if bc { + year := -tim.Year() + 1 + tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location()) + } + + if plan.location != nil { + tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), plan.location) + } + + ts = Timestamp{Time: tim, Valid: true} + } + + return scanner.ScanTimestamp(ts) +} + +func (c *TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { return nil, nil - default: - return nil, errUndefined } + + var ts Timestamp + err := codecScan(c, m, oid, format, src, &ts) + if err != nil { + return nil, err + } + + if ts.InfinityModifier != Finite { + return ts.InfinityModifier.String(), nil + } + + return ts.Time, nil +} + +func (c *TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var ts Timestamp + err := codecScan(c, m, oid, format, src, &ts) + if err != nil { + return nil, err + } + + if ts.InfinityModifier != Finite { + return ts.InfinityModifier, nil + } + + return ts.Time, nil } diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go deleted file mode 100644 index 546a3810b..000000000 --- a/pgtype/timestamp_array.go +++ /dev/null @@ -1,301 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "time" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type TimestampArray struct { - Elements []Timestamp - Dimensions []ArrayDimension - Status Status -} - -func (dst *TimestampArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = TimestampArray{Status: Null} - return nil - } - - switch value := src.(type) { - - case []time.Time: - if value == nil { - *dst = TimestampArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestampArray{Status: Present} - } else { - elements := make([]Timestamp, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestampArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to TimestampArray", value) - } - - return nil -} - -func (dst *TimestampArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *TimestampArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TimestampArray{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Timestamp - - if len(uta.Elements) > 0 { - elements = make([]Timestamp, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Timestamp - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = TimestampArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TimestampArray{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = TimestampArray{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Timestamp, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = TimestampArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("timestamp"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "timestamp") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *TimestampArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *TimestampArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/timestamp_array_test.go b/pgtype/timestamp_array_test.go deleted file mode 100644 index 5821f43a7..000000000 --- a/pgtype/timestamp_array_test.go +++ /dev/null @@ -1,159 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - "time" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestTimestampArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp[]", []interface{}{ - &pgtype.TimestampArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.TimestampArray{Status: pgtype.Null}, - &pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }, func(a, b interface{}) bool { - ata := a.(pgtype.TimestampArray) - bta := b.(pgtype.TimestampArray) - - if len(ata.Elements) != len(bta.Elements) || ata.Status != bta.Status { - return false - } - - for i := range ata.Elements { - ae, be := ata.Elements[i], bta.Elements[i] - if !(ae.Time.Equal(be.Time) && ae.Status == be.Status && ae.InfinityModifier == be.InfinityModifier) { - return false - } - } - - return true - }) -} - -func TestTimestampArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.TimestampArray - }{ - { - source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - result: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]time.Time)(nil)), - result: pgtype.TimestampArray{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.TimestampArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestTimestampArrayAssignTo(t *testing.T) { - var timeSlice []time.Time - - simpleTests := []struct { - src pgtype.TimestampArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &timeSlice, - expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - }, - { - src: pgtype.TimestampArray{Status: pgtype.Null}, - dst: &timeSlice, - expected: (([]time.Time)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.TimestampArray - dst interface{} - }{ - { - src: pgtype.TimestampArray{ - Elements: []pgtype.Timestamp{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &timeSlice, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index 267f1a7ed..5e9022f42 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -1,123 +1,179 @@ package pgtype_test import ( - "reflect" + "context" + "encoding/json" "testing" "time" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestTimestampTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ - &pgtype.Timestamp{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - &pgtype.Timestamp{Status: pgtype.Null}, - &pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, - &pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, - }, func(a, b interface{}) bool { - at := a.(pgtype.Timestamp) - bt := b.(pgtype.Timestamp) - - return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier +func TestTimestampCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "timestamp", []pgxtest.ValueRoundTripTest{ + {time.Date(-100, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(-100, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(-1, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(-1, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC))}, + + {time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC))}, + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))}, + {time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC))}, + {time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC))}, + + // Nanosecond truncation + {time.Date(2020, 1, 1, 0, 0, 0, 999999999, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.UTC))}, + {time.Date(2020, 1, 1, 0, 0, 0, 999999001, time.UTC), new(time.Time), isExpectedEqTime(time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.UTC))}, + + {pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true}, new(pgtype.Timestamp), isExpectedEq(pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true})}, + {pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(pgtype.Timestamp), isExpectedEq(pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true})}, + {pgtype.Timestamp{}, new(pgtype.Timestamp), isExpectedEq(pgtype.Timestamp{})}, + {nil, new(*time.Time), isExpectedEq((*time.Time)(nil))}, }) } -func TestTimestampSet(t *testing.T) { - type _time time.Time +func TestTimestampCodecWithScanLocationUTC(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") - successfulTests := []struct { - source interface{} - result pgtype.Timestamp - }{ - {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "timestamp", + OID: pgtype.TimestampOID, + Codec: &pgtype.TimestampCodec{ScanLocation: time.UTC}, + }) } - for i, tt := range successfulTests { - var r pgtype.Timestamp - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, nil, "timestamp", []pgxtest.ValueRoundTripTest{ + // Have to use pgtype.Timestamp instead of time.Time as source because otherwise the simple and exec query exec + // modes will encode the time for timestamptz. That is, they will convert it from local time zone. + {pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, new(time.Time), isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))}, + }) } -func TestTimestampAssignTo(t *testing.T) { - var tim time.Time - var ptim *time.Time +func TestTimestampCodecWithScanLocationLocal(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") - simpleTests := []struct { - src pgtype.Timestamp - dst interface{} - expected interface{} - }{ - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)}, - {src: pgtype.Timestamp{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "timestamp", + OID: pgtype.TimestampOID, + Codec: &pgtype.TimestampCodec{ScanLocation: time.Local}, + }) } - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, nil, "timestamp", []pgxtest.ValueRoundTripTest{ + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local))}, + }) +} + +// https://github.com/jackc/pgx/v4/pgtype/pull/128 +func TestTimestampTranscodeBigTimeBinary(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + in := &pgtype.Timestamp{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Valid: true} + var out pgtype.Timestamp + + err := conn.QueryRow(ctx, "select $1::timestamp", in).Scan(&out) if err != nil { - t.Errorf("%d: %v", i, err) + t.Fatal(err) } - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } + require.Equal(t, in.Valid, out.Valid) + require.Truef(t, in.Time.Equal(out.Time), "expected %v got %v", in.Time, out.Time) + }) +} + +// https://github.com/jackc/pgtype/issues/74 +func TestTimestampCodecDecodeTextInvalid(t *testing.T) { + c := &pgtype.TimestampCodec{} + var ts pgtype.Timestamp + plan := c.PlanScan(nil, pgtype.TimestampOID, pgtype.TextFormatCode, &ts) + err := plan.Scan([]byte(`eeeee`), &ts) + require.Error(t, err) +} + +func TestTimestampMarshalJSON(t *testing.T) { + tsStruct := struct { + TS pgtype.Timestamp `json:"ts"` + }{} - pointerAllocTests := []struct { - src pgtype.Timestamp - dst interface{} - expected interface{} + tm := time.Date(2012, 3, 29, 10, 5, 45, 0, time.UTC) + tsString := "\"" + tm.Format("2006-01-02T15:04:05") + "\"" // `"2012-03-29T10:05:45"` + var pgt pgtype.Timestamp + _ = pgt.Scan(tm) + + successfulTests := []struct { + source pgtype.Timestamp + result string }{ - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {source: pgtype.Timestamp{}, result: "null"}, + {source: pgtype.Timestamp{Time: tm, Valid: true}, result: tsString}, + {source: pgt, result: tsString}, + {source: pgtype.Timestamp{Time: tm.Add(time.Second * 555 / 1000), Valid: true}, result: `"2012-03-29T10:05:45.555"`}, + {source: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true}, result: "\"infinity\""}, + {source: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "\"-infinity\""}, } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() if err != nil { t.Errorf("%d: %v", i, err) } - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + if !assert.Equal(t, tt.result, string(r)) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) } + tsStruct.TS = tt.source + b, err := json.Marshal(tsStruct) + assert.NoErrorf(t, err, "failed to marshal %v %s", tt.source, err) + t2 := tsStruct + t2.TS = pgtype.Timestamp{} // Clear out the value so that we can compare after unmarshalling + err = json.Unmarshal(b, &t2) + assert.NoErrorf(t, err, "failed to unmarshal %v with %s", tt.source, err) + assert.True(t, tsStruct.TS.Time.Unix() == t2.TS.Time.Unix()) } +} - errorTests := []struct { - src pgtype.Timestamp - dst interface{} +func TestTimestampUnmarshalJSONErrors(t *testing.T) { + tsStruct := struct { + TS pgtype.Timestamp `json:"ts"` + }{} + goodJson1 := []byte(`{"ts":"2012-03-29T10:05:45"}`) + assert.NoError(t, json.Unmarshal(goodJson1, &tsStruct)) + goodJson2 := []byte(`{"ts":"2012-03-29T10:05:45Z"}`) + assert.NoError(t, json.Unmarshal(goodJson2, &tsStruct)) + badJson := []byte(`{"ts":"2012-03-29"}`) + assert.Error(t, json.Unmarshal(badJson, &tsStruct)) +} + +func TestTimestampUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Timestamp }{ - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + {source: "null", result: pgtype.Timestamp{}}, + {source: "\"2012-03-29T10:05:45\"", result: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.UTC), Valid: true}}, + {source: "\"2012-03-29T10:05:45.555\"", result: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.UTC), Valid: true}}, + {source: "\"infinity\"", result: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true}}, + {source: "\"-infinity\"", result: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, } + for i, tt := range successfulTests { + var r pgtype.Timestamp + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + if !r.Time.Equal(tt.result.Time) || r.Valid != tt.result.Valid || r.InfinityModifier != tt.result.InfinityModifier { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) } } } diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 74fe4954f..5d67e47f8 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -3,219 +3,369 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" + "fmt" + "strings" "time" - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" ) -const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" -const pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" -const pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" -const microsecFromUnixEpochToY2K = 946684800 * 1000000 +const ( + pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" + pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" + pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" + microsecFromUnixEpochToY2K = 946684800 * 1000000 +) const ( negativeInfinityMicrosecondOffset = -9223372036854775808 infinityMicrosecondOffset = 9223372036854775807 ) +type TimestamptzScanner interface { + ScanTimestamptz(v Timestamptz) error +} + +type TimestamptzValuer interface { + TimestamptzValue() (Timestamptz, error) +} + +// Timestamptz represents the PostgreSQL timestamptz type. type Timestamptz struct { Time time.Time - Status Status InfinityModifier InfinityModifier + Valid bool +} + +// ScanTimestamptz implements the [TimestamptzScanner] interface. +func (tstz *Timestamptz) ScanTimestamptz(v Timestamptz) error { + *tstz = v + return nil +} + +// TimestamptzValue implements the [TimestamptzValuer] interface. +func (tstz Timestamptz) TimestamptzValue() (Timestamptz, error) { + return tstz, nil } -func (dst *Timestamptz) Set(src interface{}) error { +// Scan implements the [database/sql.Scanner] interface. +func (tstz *Timestamptz) Scan(src any) error { if src == nil { - *dst = Timestamptz{Status: Null} + *tstz = Timestamptz{} return nil } - switch value := src.(type) { + switch src := src.(type) { + case string: + return (&scanPlanTextTimestamptzToTimestamptzScanner{}).Scan([]byte(src), tstz) case time.Time: - *dst = Timestamptz{Time: value, Status: Present} - default: - if originalSrc, ok := underlyingTimeType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to Timestamptz", value) + *tstz = Timestamptz{Time: src, Valid: true} + return nil } - return nil + return fmt.Errorf("cannot scan %T", src) } -func (dst *Timestamptz) Get() interface{} { - switch dst.Status { - case Present: - if dst.InfinityModifier != None { - return dst.InfinityModifier - } - return dst.Time - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Timestamptz) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *time.Time: - if src.InfinityModifier != None { - return errors.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) +// Value implements the [database/sql/driver.Valuer] interface. +func (tstz Timestamptz) Value() (driver.Value, error) { + if !tstz.Valid { + return nil, nil } - return errors.Errorf("cannot decode %v into %T", src, dst) + if tstz.InfinityModifier != Finite { + return tstz.InfinityModifier.String(), nil + } + return tstz.Time, nil } -func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Timestamptz{Status: Null} +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (tstz Timestamptz) MarshalJSON() ([]byte, error) { + if !tstz.Valid { + return []byte("null"), nil + } + + var s string + + switch tstz.InfinityModifier { + case Finite: + s = tstz.Time.Format(time.RFC3339Nano) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (tstz *Timestamptz) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *tstz = Timestamptz{} return nil } - sbuf := string(src) - switch sbuf { + switch *s { case "infinity": - *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} + *tstz = Timestamptz{Valid: true, InfinityModifier: Infinity} case "-infinity": - *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} + *tstz = Timestamptz{Valid: true, InfinityModifier: -Infinity} default: - var format string - if sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+' { - format = pgTimestamptzSecondFormat - } else if sbuf[len(sbuf)-6] == '-' || sbuf[len(sbuf)-6] == '+' { - format = pgTimestamptzMinuteFormat - } else { - format = pgTimestamptzHourFormat - } - - tim, err := time.Parse(format, sbuf) + // PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz + tim, err := time.Parse(time.RFC3339Nano, *s) if err != nil { return err } - *dst = Timestamptz{Time: tim, Status: Present} + *tstz = Timestamptz{Time: tim, Valid: true} } return nil } -func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Timestamptz{Status: Null} +type TimestamptzCodec struct { + // ScanLocation is the location to return scanned timestamptz values in. This does not change the instant in time that + // the timestamptz represents. + ScanLocation *time.Location +} + +func (*TimestamptzCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (*TimestamptzCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (*TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(TimestamptzValuer); !ok { return nil } - if len(src) != 8 { - return errors.Errorf("invalid length for timestamptz: %v", len(src)) + switch format { + case BinaryFormatCode: + return encodePlanTimestamptzCodecBinary{} + case TextFormatCode: + return encodePlanTimestamptzCodecText{} } - microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) + return nil +} - switch microsecSinceY2K { - case infinityMicrosecondOffset: - *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} - case negativeInfinityMicrosecondOffset: - *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} - default: - microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K - tim := time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) - *dst = Timestamptz{Time: tim, Status: Present} +type encodePlanTimestamptzCodecBinary struct{} + +func (encodePlanTimestamptzCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + ts, err := value.(TimestamptzValuer).TimestamptzValue() + if err != nil { + return nil, err } - return nil + if !ts.Valid { + return nil, nil + } + + var microsecSinceY2K int64 + switch ts.InfinityModifier { + case Finite: + microsecSinceUnixEpoch := ts.Time.Unix()*1000000 + int64(ts.Time.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + buf = pgio.AppendInt64(buf, microsecSinceY2K) + + return buf, nil } -func (src *Timestamptz) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: +type encodePlanTimestamptzCodecText struct{} + +func (encodePlanTimestamptzCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + ts, err := value.(TimestamptzValuer).TimestamptzValue() + if err != nil { + return nil, err + } + + if !ts.Valid { return nil, nil - case Undefined: - return nil, errUndefined } var s string - switch src.InfinityModifier { - case None: - s = src.Time.UTC().Format(pgTimestamptzSecondFormat) + switch ts.InfinityModifier { + case Finite: + + t := ts.Time.UTC().Truncate(time.Microsecond) + + // Year 0000 is 1 BC + bc := false + if year := t.Year(); year <= 0 { + year = -year + 1 + t = time.Date(year, t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC) + bc = true + } + + s = t.Format(pgTimestamptzSecondFormat) + + if bc { + s = s + " BC" + } case Infinity: s = "infinity" case NegativeInfinity: s = "-infinity" } - return append(buf, s...), nil + buf = append(buf, s...) + + return buf, nil } -func (src *Timestamptz) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined +func (c *TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case TimestamptzScanner: + return &scanPlanBinaryTimestamptzToTimestamptzScanner{location: c.ScanLocation} + } + case TextFormatCode: + switch target.(type) { + case TimestamptzScanner: + return &scanPlanTextTimestamptzToTimestamptzScanner{location: c.ScanLocation} + } } - var microsecSinceY2K int64 - switch src.InfinityModifier { - case None: - microsecSinceUnixEpoch := src.Time.Unix()*1000000 + int64(src.Time.Nanosecond())/1000 - microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K - case Infinity: - microsecSinceY2K = infinityMicrosecondOffset - case NegativeInfinity: - microsecSinceY2K = negativeInfinityMicrosecondOffset + return nil +} + +type scanPlanBinaryTimestamptzToTimestamptzScanner struct{ location *time.Location } + +func (plan *scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TimestamptzScanner) + + if src == nil { + return scanner.ScanTimestamptz(Timestamptz{}) } - return pgio.AppendInt64(buf, microsecSinceY2K), nil + if len(src) != 8 { + return fmt.Errorf("invalid length for timestamptz: %v", len(src)) + } + + var tstz Timestamptz + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + tstz = Timestamptz{Valid: true, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + tstz = Timestamptz{Valid: true, InfinityModifier: -Infinity} + default: + tim := time.Unix( + microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, + (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), + ) + if plan.location != nil { + tim = tim.In(plan.location) + } + tstz = Timestamptz{Time: tim, Valid: true} + } + + return scanner.ScanTimestamptz(tstz) } -// Scan implements the database/sql Scanner interface. -func (dst *Timestamptz) Scan(src interface{}) error { +type scanPlanTextTimestamptzToTimestamptzScanner struct{ location *time.Location } + +func (plan *scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TimestamptzScanner) + if src == nil { - *dst = Timestamptz{Status: Null} - return nil + return scanner.ScanTimestamptz(Timestamptz{}) } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - case time.Time: - *dst = Timestamptz{Time: src, Status: Present} - return nil + var tstz Timestamptz + sbuf := string(src) + switch sbuf { + case "infinity": + tstz = Timestamptz{Valid: true, InfinityModifier: Infinity} + case "-infinity": + tstz = Timestamptz{Valid: true, InfinityModifier: -Infinity} + default: + bc := false + if strings.HasSuffix(sbuf, " BC") { + sbuf = sbuf[:len(sbuf)-3] + bc = true + } + + var format string + if len(sbuf) >= 9 && (sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+') { + format = pgTimestamptzSecondFormat + } else if len(sbuf) >= 6 && (sbuf[len(sbuf)-6] == '-' || sbuf[len(sbuf)-6] == '+') { + format = pgTimestamptzMinuteFormat + } else { + format = pgTimestamptzHourFormat + } + + tim, err := time.Parse(format, sbuf) + if err != nil { + return err + } + + if bc { + year := -tim.Year() + 1 + tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location()) + } + + if plan.location != nil { + tim = tim.In(plan.location) + } + + tstz = Timestamptz{Time: tim, Valid: true} } - return errors.Errorf("cannot scan %T", src) + return scanner.ScanTimestamptz(tstz) } -// Value implements the database/sql/driver Valuer interface. -func (src *Timestamptz) Value() (driver.Value, error) { - switch src.Status { - case Present: - if src.InfinityModifier != None { - return src.InfinityModifier.String(), nil - } - return src.Time, nil - case Null: +func (c *TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var tstz Timestamptz + err := codecScan(c, m, oid, format, src, &tstz) + if err != nil { + return nil, err + } + + if tstz.InfinityModifier != Finite { + return tstz.InfinityModifier.String(), nil + } + + return tstz.Time, nil +} + +func (c *TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { return nil, nil - default: - return nil, errUndefined } + + var tstz Timestamptz + err := codecScan(c, m, oid, format, src, &tstz) + if err != nil { + return nil, err + } + + if tstz.InfinityModifier != Finite { + return tstz.InfinityModifier, nil + } + + return tstz.Time, nil } diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go deleted file mode 100644 index 88b6cc5f1..000000000 --- a/pgtype/timestamptz_array.go +++ /dev/null @@ -1,301 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "time" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type TimestamptzArray struct { - Elements []Timestamptz - Dimensions []ArrayDimension - Status Status -} - -func (dst *TimestamptzArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = TimestamptzArray{Status: Null} - return nil - } - - switch value := src.(type) { - - case []time.Time: - if value == nil { - *dst = TimestamptzArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestamptzArray{Status: Present} - } else { - elements := make([]Timestamptz, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestamptzArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to TimestamptzArray", value) - } - - return nil -} - -func (dst *TimestamptzArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *TimestamptzArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TimestamptzArray{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Timestamptz - - if len(uta.Elements) > 0 { - elements = make([]Timestamptz, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Timestamptz - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = TimestamptzArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = TimestamptzArray{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = TimestamptzArray{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Timestamptz, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = TimestamptzArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("timestamptz"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "timestamptz") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *TimestamptzArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *TimestamptzArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/timestamptz_array_test.go b/pgtype/timestamptz_array_test.go deleted file mode 100644 index 8d7ea4c95..000000000 --- a/pgtype/timestamptz_array_test.go +++ /dev/null @@ -1,159 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - "time" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestTimestamptzArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz[]", []interface{}{ - &pgtype.TimestamptzArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.TimestamptzArray{Status: pgtype.Null}, - &pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Status: pgtype.Null}, - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{ - {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }, func(a, b interface{}) bool { - ata := a.(pgtype.TimestamptzArray) - bta := b.(pgtype.TimestamptzArray) - - if len(ata.Elements) != len(bta.Elements) || ata.Status != bta.Status { - return false - } - - for i := range ata.Elements { - ae, be := ata.Elements[i], bta.Elements[i] - if !(ae.Time.Equal(be.Time) && ae.Status == be.Status && ae.InfinityModifier == be.InfinityModifier) { - return false - } - } - - return true - }) -} - -func TestTimestamptzArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.TimestamptzArray - }{ - { - source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - result: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]time.Time)(nil)), - result: pgtype.TimestamptzArray{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.TimestamptzArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestTimestamptzArrayAssignTo(t *testing.T) { - var timeSlice []time.Time - - simpleTests := []struct { - src pgtype.TimestamptzArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &timeSlice, - expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, - }, - { - src: pgtype.TimestamptzArray{Status: pgtype.Null}, - dst: &timeSlice, - expected: (([]time.Time)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.TimestamptzArray - dst interface{} - }{ - { - src: pgtype.TimestamptzArray{ - Elements: []pgtype.Timestamptz{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &timeSlice, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index c326802d2..572481958 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -1,122 +1,145 @@ package pgtype_test import ( - "reflect" + "context" "testing" "time" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" ) -func TestTimestamptzTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ - &pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - &pgtype.Timestamptz{Status: pgtype.Null}, - &pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, - &pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, - }, func(a, b interface{}) bool { - at := a.(pgtype.Timestamptz) - bt := b.(pgtype.Timestamptz) - - return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier +func TestTimestamptzCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "timestamptz", []pgxtest.ValueRoundTripTest{ + {time.Date(-100, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(-100, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(-1, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(-1, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(0, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(0, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(1, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1, 1, 1, 0, 0, 0, 0, time.Local))}, + + {time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local))}, + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local))}, + {time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local))}, + {time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEqTime(time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local))}, + + // Nanosecond truncation + {time.Date(2020, 1, 1, 0, 0, 0, 999999999, time.Local), new(time.Time), isExpectedEqTime(time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.Local))}, + {time.Date(2020, 1, 1, 0, 0, 0, 999999001, time.Local), new(time.Time), isExpectedEqTime(time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.Local))}, + + {pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true}, new(pgtype.Timestamptz), isExpectedEq(pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true})}, + {pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(pgtype.Timestamptz), isExpectedEq(pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true})}, + {pgtype.Timestamptz{}, new(pgtype.Timestamptz), isExpectedEq(pgtype.Timestamptz{})}, + {nil, new(*time.Time), isExpectedEq((*time.Time)(nil))}, }) } -func TestTimestamptzSet(t *testing.T) { - type _time time.Time +func TestTimestamptzCodecWithLocationUTC(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") - successfulTests := []struct { - source interface{} - result pgtype.Timestamptz - }{ - {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, - {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, - {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), Status: pgtype.Present}}, - {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, - {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "timestamptz", + OID: pgtype.TimestamptzOID, + Codec: &pgtype.TimestamptzCodec{ScanLocation: time.UTC}, + }) } - for i, tt := range successfulTests { - var r pgtype.Timestamptz - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, nil, "timestamptz", []pgxtest.ValueRoundTripTest{ + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))}, + }) } -func TestTimestamptzAssignTo(t *testing.T) { - var tim time.Time - var ptim *time.Time +func TestTimestamptzCodecWithLocationLocal(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") - simpleTests := []struct { - src pgtype.Timestamptz - dst interface{} - expected interface{} - }{ - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, - {src: pgtype.Timestamptz{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "timestamptz", + OID: pgtype.TimestamptzOID, + Codec: &pgtype.TimestamptzCodec{ScanLocation: time.Local}, + }) } - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, nil, "timestamptz", []pgxtest.ValueRoundTripTest{ + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local))}, + }) +} + +// https://github.com/jackc/pgx/v4/pgtype/pull/128 +func TestTimestamptzTranscodeBigTimeBinary(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + in := &pgtype.Timestamptz{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Valid: true} + var out pgtype.Timestamptz + + err := conn.QueryRow(ctx, "select $1::timestamptz", in).Scan(&out) if err != nil { - t.Errorf("%d: %v", i, err) + t.Fatal(err) } - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } + require.Equal(t, in.Valid, out.Valid) + require.Truef(t, in.Time.Equal(out.Time), "expected %v got %v", in.Time, out.Time) + }) +} - pointerAllocTests := []struct { - src pgtype.Timestamptz - dst interface{} - expected interface{} +// https://github.com/jackc/pgtype/issues/74 +func TestTimestamptzDecodeTextInvalid(t *testing.T) { + c := &pgtype.TimestamptzCodec{} + var tstz pgtype.Timestamptz + plan := c.PlanScan(nil, pgtype.TimestamptzOID, pgtype.TextFormatCode, &tstz) + err := plan.Scan([]byte(`eeeee`), &tstz) + require.Error(t, err) +} + +func TestTimestamptzMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Timestamptz + result string }{ - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {source: pgtype.Timestamptz{}, result: "null"}, + {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}, result: "\"2012-03-29T10:05:45-06:00\""}, + {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}, result: "\"2012-03-29T10:05:45.555-06:00\""}, + {source: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true}, result: "\"infinity\""}, + {source: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "\"-infinity\""}, } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() if err != nil { t.Errorf("%d: %v", i, err) } - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) } } +} - errorTests := []struct { - src pgtype.Timestamptz - dst interface{} +func TestTimestamptzUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Timestamptz }{ - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, - {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + {source: "null", result: pgtype.Timestamptz{}}, + {source: "\"2012-03-29T10:05:45-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}}, + {source: "\"2012-03-29T10:05:45.555-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}}, + {source: "\"infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true}}, + {source: "\"-infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, } + for i, tt := range successfulTests { + var r pgtype.Timestamptz + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + if !r.Time.Equal(tt.result.Time) || r.Valid != tt.result.Valid || r.InfinityModifier != tt.result.InfinityModifier { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) } } } diff --git a/pgtype/tsrange.go b/pgtype/tsrange.go deleted file mode 100644 index 8a67d65eb..000000000 --- a/pgtype/tsrange.go +++ /dev/null @@ -1,250 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type Tsrange struct { - Lower Timestamp - Upper Timestamp - LowerType BoundType - UpperType BoundType - Status Status -} - -func (dst *Tsrange) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Tsrange", src) -} - -func (dst *Tsrange) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Tsrange) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Tsrange) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Tsrange{Status: Null} - return nil - } - - utr, err := ParseUntypedTextRange(string(src)) - if err != nil { - return err - } - - *dst = Tsrange{Status: Present} - - dst.LowerType = utr.LowerType - dst.UpperType = utr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { - return err - } - } - - return nil -} - -func (dst *Tsrange) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Tsrange{Status: Null} - return nil - } - - ubr, err := ParseUntypedBinaryRange(src) - if err != nil { - return err - } - - *dst = Tsrange{Status: Present} - - dst.LowerType = ubr.LowerType - dst.UpperType = ubr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { - return err - } - } - - return nil -} - -func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - switch src.LowerType { - case Exclusive, Unbounded: - buf = append(buf, '(') - case Inclusive: - buf = append(buf, '[') - case Empty: - return append(buf, "empty"...), nil - default: - return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) - } - - var err error - - if src.LowerType != Unbounded { - buf, err = src.Lower.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - } - - buf = append(buf, ',') - - if src.UpperType != Unbounded { - buf, err = src.Upper.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - } - - switch src.UpperType { - case Exclusive, Unbounded: - buf = append(buf, ')') - case Inclusive: - buf = append(buf, ']') - default: - return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) - } - - return buf, nil -} - -func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - var rangeType byte - switch src.LowerType { - case Inclusive: - rangeType |= lowerInclusiveMask - case Unbounded: - rangeType |= lowerUnboundedMask - case Exclusive: - case Empty: - return append(buf, emptyMask), nil - default: - return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) - } - - switch src.UpperType { - case Inclusive: - rangeType |= upperInclusiveMask - case Unbounded: - rangeType |= upperUnboundedMask - case Exclusive: - default: - return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) - } - - buf = append(buf, rangeType) - - var err error - - if src.LowerType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Lower.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - if src.UpperType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Upper.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Tsrange) Scan(src interface{}) error { - if src == nil { - *dst = Tsrange{Status: Null} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Tsrange) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/tsrange_test.go b/pgtype/tsrange_test.go deleted file mode 100644 index 78eb1cd36..000000000 --- a/pgtype/tsrange_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package pgtype_test - -import ( - "testing" - "time" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestTsrangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "tsrange", []interface{}{ - &pgtype.Tsrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - &pgtype.Tsrange{ - Lower: pgtype.Timestamp{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Timestamp{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Tsrange{ - Lower: pgtype.Timestamp{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Tsrange{Status: pgtype.Null}, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.Tsrange) - b := bb.(pgtype.Tsrange) - - return a.Status == b.Status && - a.Lower.Time.Equal(b.Lower.Time) && - a.Lower.Status == b.Lower.Status && - a.Lower.InfinityModifier == b.Lower.InfinityModifier && - a.Upper.Time.Equal(b.Upper.Time) && - a.Upper.Status == b.Upper.Status && - a.Upper.InfinityModifier == b.Upper.InfinityModifier - }) -} diff --git a/pgtype/tstzrange.go b/pgtype/tstzrange.go deleted file mode 100644 index b5129093a..000000000 --- a/pgtype/tstzrange.go +++ /dev/null @@ -1,250 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type Tstzrange struct { - Lower Timestamptz - Upper Timestamptz - LowerType BoundType - UpperType BoundType - Status Status -} - -func (dst *Tstzrange) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Tstzrange", src) -} - -func (dst *Tstzrange) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Tstzrange) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Tstzrange) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Tstzrange{Status: Null} - return nil - } - - utr, err := ParseUntypedTextRange(string(src)) - if err != nil { - return err - } - - *dst = Tstzrange{Status: Present} - - dst.LowerType = utr.LowerType - dst.UpperType = utr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { - return err - } - } - - return nil -} - -func (dst *Tstzrange) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Tstzrange{Status: Null} - return nil - } - - ubr, err := ParseUntypedBinaryRange(src) - if err != nil { - return err - } - - *dst = Tstzrange{Status: Present} - - dst.LowerType = ubr.LowerType - dst.UpperType = ubr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { - return err - } - } - - return nil -} - -func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - switch src.LowerType { - case Exclusive, Unbounded: - buf = append(buf, '(') - case Inclusive: - buf = append(buf, '[') - case Empty: - return append(buf, "empty"...), nil - default: - return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) - } - - var err error - - if src.LowerType != Unbounded { - buf, err = src.Lower.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - } - - buf = append(buf, ',') - - if src.UpperType != Unbounded { - buf, err = src.Upper.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - } - - switch src.UpperType { - case Exclusive, Unbounded: - buf = append(buf, ')') - case Inclusive: - buf = append(buf, ']') - default: - return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) - } - - return buf, nil -} - -func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - var rangeType byte - switch src.LowerType { - case Inclusive: - rangeType |= lowerInclusiveMask - case Unbounded: - rangeType |= lowerUnboundedMask - case Exclusive: - case Empty: - return append(buf, emptyMask), nil - default: - return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) - } - - switch src.UpperType { - case Inclusive: - rangeType |= upperInclusiveMask - case Unbounded: - rangeType |= upperUnboundedMask - case Exclusive: - default: - return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) - } - - buf = append(buf, rangeType) - - var err error - - if src.LowerType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Lower.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - if src.UpperType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Upper.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Tstzrange) Scan(src interface{}) error { - if src == nil { - *dst = Tstzrange{Status: Null} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Tstzrange) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/tstzrange_test.go b/pgtype/tstzrange_test.go deleted file mode 100644 index a27ddd3a2..000000000 --- a/pgtype/tstzrange_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package pgtype_test - -import ( - "testing" - "time" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestTstzrangeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "tstzrange", []interface{}{ - &pgtype.Tstzrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - &pgtype.Tstzrange{ - Lower: pgtype.Timestamptz{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Timestamptz{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Tstzrange{ - Lower: pgtype.Timestamptz{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - Upper: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, - LowerType: pgtype.Inclusive, - UpperType: pgtype.Exclusive, - Status: pgtype.Present, - }, - &pgtype.Tstzrange{Status: pgtype.Null}, - }, func(aa, bb interface{}) bool { - a := aa.(pgtype.Tstzrange) - b := bb.(pgtype.Tstzrange) - - return a.Status == b.Status && - a.Lower.Time.Equal(b.Lower.Time) && - a.Lower.Status == b.Lower.Status && - a.Lower.InfinityModifier == b.Lower.InfinityModifier && - a.Upper.Time.Equal(b.Upper.Time) && - a.Upper.Status == b.Upper.Status && - a.Upper.InfinityModifier == b.Upper.InfinityModifier - }) -} diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb deleted file mode 100644 index 6fafc2dfc..000000000 --- a/pgtype/typed_array.go.erb +++ /dev/null @@ -1,304 +0,0 @@ -package pgtype - -import ( - "bytes" - "fmt" - "io" - - "github.com/jackc/pgx/pgio" -) - -type <%= pgtype_array_type %> struct { - Elements []<%= pgtype_element_type %> - Dimensions []ArrayDimension - Status Status -} - -func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = <%= pgtype_array_type %>{Status: Null} - return nil - } - - switch value := src.(type) { - <% go_array_types.split(",").each do |t| %> - case <%= t %>: - if value == nil { - *dst = <%= pgtype_array_type %>{Status: Null} - } else if len(value) == 0 { - *dst = <%= pgtype_array_type %>{Status: Present} - } else { - elements := make([]<%= pgtype_element_type %>, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = <%= pgtype_array_type %>{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - <% end %> - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>", value) - } - - return nil -} - -func (dst *<%= pgtype_array_type %>) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - <% go_array_types.split(",").each do |t| %> - case *<%= t %>: - *v = make(<%= t %>, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - <% end %> - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = <%= pgtype_array_type %>{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []<%= pgtype_element_type %> - - if len(uta.Elements) > 0 { - elements = make([]<%= pgtype_element_type %>, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem <%= pgtype_element_type %> - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -<% if binary_format == "true" %> -func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = <%= pgtype_array_type %>{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = <%= pgtype_array_type %>{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]<%= pgtype_element_type %>, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp:rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} -<% end %> - -func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `<%= text_null %>`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -<% if binary_format == "true" %> - func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil - } -<% end %> - -// Scan implements the database/sql Scanner interface. -func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *<%= pgtype_array_type %>) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh deleted file mode 100644 index 4a8211bca..000000000 --- a/pgtype/typed_array_gen.sh +++ /dev/null @@ -1,24 +0,0 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go -erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' binary_format=true typed_array.go.erb > text_array.go -erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' binary_format=true typed_array.go.erb > varchar_array.go -erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string element_type_name=bpchar text_null='NULL' binary_format=true typed_array.go.erb > bpchar_array.go -erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go -erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go -erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go -erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]float64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go -erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go - -# While the binary format is theoretically possible it is only practical to use the text format. In addition, the text format for NULL enums is unquoted so TextArray or a possible GenericTextArray cannot be used. -erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string text_null='NULL' binary_format=false typed_array.go.erb > enum_array.go - -goimports -w *_array.go diff --git a/pgtype/typed_range.go.erb b/pgtype/typed_range.go.erb deleted file mode 100644 index 91a5cb972..000000000 --- a/pgtype/typed_range.go.erb +++ /dev/null @@ -1,252 +0,0 @@ -package pgtype - -import ( - "bytes" - "database/sql/driver" - "fmt" - "io" - - "github.com/jackc/pgx/pgio" -) - -type <%= range_type %> struct { - Lower <%= element_type %> - Upper <%= element_type %> - LowerType BoundType - UpperType BoundType - Status Status -} - -func (dst *<%= range_type %>) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to <%= range_type %>", src) -} - -func (dst *<%= range_type %>) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *<%= range_type %>) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *<%= range_type %>) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = <%= range_type %>{Status: Null} - return nil - } - - utr, err := ParseUntypedTextRange(string(src)) - if err != nil { - return err - } - - *dst = <%= range_type %>{Status: Present} - - dst.LowerType = utr.LowerType - dst.UpperType = utr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { - return err - } - } - - return nil -} - -func (dst *<%= range_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = <%= range_type %>{Status: Null} - return nil - } - - ubr, err := ParseUntypedBinaryRange(src) - if err != nil { - return err - } - - *dst = <%= range_type %>{Status: Present} - - dst.LowerType = ubr.LowerType - dst.UpperType = ubr.UpperType - - if dst.LowerType == Empty { - return nil - } - - if dst.LowerType == Inclusive || dst.LowerType == Exclusive { - if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { - return err - } - } - - if dst.UpperType == Inclusive || dst.UpperType == Exclusive { - if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { - return err - } - } - - return nil -} - -func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - switch src.LowerType { - case Exclusive, Unbounded: - buf = append(buf, '(') - case Inclusive: - buf = append(buf, '[') - case Empty: - return append(buf, "empty"...), nil - default: - return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) - } - - var err error - - if src.LowerType != Unbounded { - buf, err = src.Lower.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - } - - buf = append(buf, ',') - - if src.UpperType != Unbounded { - buf, err = src.Upper.EncodeText(ci, buf) - if err != nil { - return nil, err - } else if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - } - - switch src.UpperType { - case Exclusive, Unbounded: - buf = append(buf, ')') - case Inclusive: - buf = append(buf, ']') - default: - return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) - } - - return buf, nil -} - -func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - var rangeType byte - switch src.LowerType { - case Inclusive: - rangeType |= lowerInclusiveMask - case Unbounded: - rangeType |= lowerUnboundedMask - case Exclusive: - case Empty: - return append(buf, emptyMask), nil - default: - return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) - } - - switch src.UpperType { - case Inclusive: - rangeType |= upperInclusiveMask - case Unbounded: - rangeType |= upperUnboundedMask - case Exclusive: - default: - return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) - } - - buf = append(buf, rangeType) - - var err error - - if src.LowerType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Lower.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - if src.UpperType != Unbounded { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf, err = src.Upper.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if buf == nil { - return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") - } - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *<%= range_type %>) Scan(src interface{}) error { - if src == nil { - *dst = <%= range_type %>{Status: Null} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src <%= range_type %>) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/typed_range_gen.sh b/pgtype/typed_range_gen.sh deleted file mode 100644 index bedda2925..000000000 --- a/pgtype/typed_range_gen.sh +++ /dev/null @@ -1,7 +0,0 @@ -erb range_type=Int4range element_type=Int4 typed_range.go.erb > int4range.go -erb range_type=Int8range element_type=Int8 typed_range.go.erb > int8range.go -erb range_type=Tsrange element_type=Timestamp typed_range.go.erb > tsrange.go -erb range_type=Tstzrange element_type=Timestamptz typed_range.go.erb > tstzrange.go -erb range_type=Daterange element_type=Date typed_range.go.erb > daterange.go -erb range_type=Numrange element_type=Numeric typed_range.go.erb > numrange.go -goimports -w *range.go diff --git a/pgtype/uint32.go b/pgtype/uint32.go new file mode 100644 index 000000000..e6d4b1cf6 --- /dev/null +++ b/pgtype/uint32.go @@ -0,0 +1,352 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Uint32Scanner interface { + ScanUint32(v Uint32) error +} + +type Uint32Valuer interface { + Uint32Value() (Uint32, error) +} + +// Uint32 is the core type that is used to represent PostgreSQL types such as OID, CID, and XID. +type Uint32 struct { + Uint32 uint32 + Valid bool +} + +// ScanUint32 implements the [Uint32Scanner] interface. +func (n *Uint32) ScanUint32(v Uint32) error { + *n = v + return nil +} + +// Uint32Value implements the [Uint32Valuer] interface. +func (n Uint32) Uint32Value() (Uint32, error) { + return n, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *Uint32) Scan(src any) error { + if src == nil { + *dst = Uint32{} + return nil + } + + var n int64 + + switch src := src.(type) { + case int64: + n = src + case string: + un, err := strconv.ParseUint(src, 10, 32) + if err != nil { + return err + } + n = int64(un) + default: + return fmt.Errorf("cannot scan %T", src) + } + + if n < 0 { + return fmt.Errorf("%d is less than the minimum value for Uint32", n) + } + if n > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for Uint32", n) + } + + *dst = Uint32{Uint32: uint32(n), Valid: true} + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src Uint32) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Uint32), nil +} + +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (src Uint32) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return json.Marshal(src.Uint32) +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (dst *Uint32) UnmarshalJSON(b []byte) error { + var n *uint32 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Uint32{} + } else { + *dst = Uint32{Uint32: *n, Valid: true} + } + + return nil +} + +type Uint32Codec struct{} + +func (Uint32Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Uint32Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Uint32Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case uint32: + return encodePlanUint32CodecBinaryUint32{} + case Uint32Valuer: + return encodePlanUint32CodecBinaryUint32Valuer{} + case Int64Valuer: + return encodePlanUint32CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case uint32: + return encodePlanUint32CodecTextUint32{} + case Int64Valuer: + return encodePlanUint32CodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanUint32CodecBinaryUint32 struct{} + +func (encodePlanUint32CodecBinaryUint32) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(uint32) + return pgio.AppendUint32(buf, v), nil +} + +type encodePlanUint32CodecBinaryUint32Valuer struct{} + +func (encodePlanUint32CodecBinaryUint32Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Uint32Valuer).Uint32Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + return pgio.AppendUint32(buf, v.Uint32), nil +} + +type encodePlanUint32CodecBinaryInt64Valuer struct{} + +func (encodePlanUint32CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + if v.Int64 < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint32", v.Int64) + } + if v.Int64 > math.MaxUint32 { + return nil, fmt.Errorf("%d is greater than maximum value for uint32", v.Int64) + } + + return pgio.AppendUint32(buf, uint32(v.Int64)), nil +} + +type encodePlanUint32CodecTextUint32 struct{} + +func (encodePlanUint32CodecTextUint32) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(uint32) + return append(buf, strconv.FormatUint(uint64(v), 10)...), nil +} + +type encodePlanUint32CodecTextUint32Valuer struct{} + +func (encodePlanUint32CodecTextUint32Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Uint32Valuer).Uint32Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + return append(buf, strconv.FormatUint(uint64(v.Uint32), 10)...), nil +} + +type encodePlanUint32CodecTextInt64Valuer struct{} + +func (encodePlanUint32CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + if v.Int64 < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint32", v.Int64) + } + if v.Int64 > math.MaxUint32 { + return nil, fmt.Errorf("%d is greater than maximum value for uint32", v.Int64) + } + + return append(buf, strconv.FormatInt(v.Int64, 10)...), nil +} + +func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case *uint32: + return scanPlanBinaryUint32ToUint32{} + case Uint32Scanner: + return scanPlanBinaryUint32ToUint32Scanner{} + case TextScanner: + return scanPlanBinaryUint32ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *uint32: + return scanPlanTextAnyToUint32{} + case Uint32Scanner: + return scanPlanTextAnyToUint32Scanner{} + } + } + + return nil +} + +func (c Uint32Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n uint32 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return int64(n), nil +} + +func (c Uint32Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n uint32 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +type scanPlanBinaryUint32ToUint32 struct{} + +func (scanPlanBinaryUint32ToUint32) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint32: %v", len(src)) + } + + p := (dst).(*uint32) + *p = binary.BigEndian.Uint32(src) + + return nil +} + +type scanPlanBinaryUint32ToUint32Scanner struct{} + +func (scanPlanBinaryUint32ToUint32Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Uint32Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanUint32(Uint32{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint32: %v", len(src)) + } + + n := binary.BigEndian.Uint32(src) + + return s.ScanUint32(Uint32{Uint32: n, Valid: true}) +} + +type scanPlanBinaryUint32ToTextScanner struct{} + +func (scanPlanBinaryUint32ToTextScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint32: %v", len(src)) + } + + n := uint64(binary.BigEndian.Uint32(src)) + return s.ScanText(Text{String: strconv.FormatUint(n, 10), Valid: true}) +} + +type scanPlanTextAnyToUint32Scanner struct{} + +func (scanPlanTextAnyToUint32Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Uint32Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanUint32(Uint32{}) + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + return s.ScanUint32(Uint32{Uint32: uint32(n), Valid: true}) +} diff --git a/pgtype/uint32_test.go b/pgtype/uint32_test.go new file mode 100644 index 000000000..efa4e2730 --- /dev/null +++ b/pgtype/uint32_test.go @@ -0,0 +1,22 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestUint32Codec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "oid", []pgxtest.ValueRoundTripTest{ + { + pgtype.Uint32{Uint32: pgtype.TextOID, Valid: true}, + new(pgtype.Uint32), + isExpectedEq(pgtype.Uint32{Uint32: pgtype.TextOID, Valid: true}), + }, + {pgtype.Uint32{}, new(pgtype.Uint32), isExpectedEq(pgtype.Uint32{})}, + {nil, new(pgtype.Uint32), isExpectedEq(pgtype.Uint32{})}, + {"1147", new(string), isExpectedEq("1147")}, + }) +} diff --git a/pgtype/uint64.go b/pgtype/uint64.go new file mode 100644 index 000000000..68fd16613 --- /dev/null +++ b/pgtype/uint64.go @@ -0,0 +1,323 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Uint64Scanner interface { + ScanUint64(v Uint64) error +} + +type Uint64Valuer interface { + Uint64Value() (Uint64, error) +} + +// Uint64 is the core type that is used to represent PostgreSQL types such as XID8. +type Uint64 struct { + Uint64 uint64 + Valid bool +} + +// ScanUint64 implements the [Uint64Scanner] interface. +func (n *Uint64) ScanUint64(v Uint64) error { + *n = v + return nil +} + +// Uint64Value implements the [Uint64Valuer] interface. +func (n Uint64) Uint64Value() (Uint64, error) { + return n, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *Uint64) Scan(src any) error { + if src == nil { + *dst = Uint64{} + return nil + } + + var n uint64 + + switch src := src.(type) { + case int64: + if src < 0 { + return fmt.Errorf("%d is less than the minimum value for Uint64", src) + } + n = uint64(src) + case string: + un, err := strconv.ParseUint(src, 10, 64) + if err != nil { + return err + } + n = un + default: + return fmt.Errorf("cannot scan %T", src) + } + + *dst = Uint64{Uint64: n, Valid: true} + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src Uint64) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + // If the value is greater than the maximum value for int64, return it as a string instead of losing data or returning + // an error. + if src.Uint64 > math.MaxInt64 { + return strconv.FormatUint(src.Uint64, 10), nil + } + + return int64(src.Uint64), nil +} + +type Uint64Codec struct{} + +func (Uint64Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Uint64Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Uint64Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case uint64: + return encodePlanUint64CodecBinaryUint64{} + case Uint64Valuer: + return encodePlanUint64CodecBinaryUint64Valuer{} + case Int64Valuer: + return encodePlanUint64CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case uint64: + return encodePlanUint64CodecTextUint64{} + case Int64Valuer: + return encodePlanUint64CodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanUint64CodecBinaryUint64 struct{} + +func (encodePlanUint64CodecBinaryUint64) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(uint64) + return pgio.AppendUint64(buf, v), nil +} + +type encodePlanUint64CodecBinaryUint64Valuer struct{} + +func (encodePlanUint64CodecBinaryUint64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Uint64Valuer).Uint64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + return pgio.AppendUint64(buf, v.Uint64), nil +} + +type encodePlanUint64CodecBinaryInt64Valuer struct{} + +func (encodePlanUint64CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + if v.Int64 < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint64", v.Int64) + } + + return pgio.AppendUint64(buf, uint64(v.Int64)), nil +} + +type encodePlanUint64CodecTextUint64 struct{} + +func (encodePlanUint64CodecTextUint64) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(uint64) + return append(buf, strconv.FormatUint(uint64(v), 10)...), nil +} + +type encodePlanUint64CodecTextUint64Valuer struct{} + +func (encodePlanUint64CodecTextUint64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Uint64Valuer).Uint64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + return append(buf, strconv.FormatUint(v.Uint64, 10)...), nil +} + +type encodePlanUint64CodecTextInt64Valuer struct{} + +func (encodePlanUint64CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + if v.Int64 < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint64", v.Int64) + } + + return append(buf, strconv.FormatInt(v.Int64, 10)...), nil +} + +func (Uint64Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case *uint64: + return scanPlanBinaryUint64ToUint64{} + case Uint64Scanner: + return scanPlanBinaryUint64ToUint64Scanner{} + case TextScanner: + return scanPlanBinaryUint64ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *uint64: + return scanPlanTextAnyToUint64{} + case Uint64Scanner: + return scanPlanTextAnyToUint64Scanner{} + } + } + + return nil +} + +func (c Uint64Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n uint64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return int64(n), nil +} + +func (c Uint64Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n uint64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +type scanPlanBinaryUint64ToUint64 struct{} + +func (scanPlanBinaryUint64ToUint64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint64: %v", len(src)) + } + + p := (dst).(*uint64) + *p = binary.BigEndian.Uint64(src) + + return nil +} + +type scanPlanBinaryUint64ToUint64Scanner struct{} + +func (scanPlanBinaryUint64ToUint64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Uint64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanUint64(Uint64{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint64: %v", len(src)) + } + + n := binary.BigEndian.Uint64(src) + + return s.ScanUint64(Uint64{Uint64: n, Valid: true}) +} + +type scanPlanBinaryUint64ToTextScanner struct{} + +func (scanPlanBinaryUint64ToTextScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint64: %v", len(src)) + } + + n := uint64(binary.BigEndian.Uint64(src)) + return s.ScanText(Text{String: strconv.FormatUint(n, 10), Valid: true}) +} + +type scanPlanTextAnyToUint64Scanner struct{} + +func (scanPlanTextAnyToUint64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Uint64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanUint64(Uint64{}) + } + + n, err := strconv.ParseUint(string(src), 10, 64) + if err != nil { + return err + } + + return s.ScanUint64(Uint64{Uint64: n, Valid: true}) +} diff --git a/pgtype/uint64_test.go b/pgtype/uint64_test.go new file mode 100644 index 000000000..33c2622d5 --- /dev/null +++ b/pgtype/uint64_test.go @@ -0,0 +1,30 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestUint64Codec(t *testing.T) { + skipCockroachDB(t, "Server does not support xid8 (https://github.com/cockroachdb/cockroach/issues/36815)") + skipPostgreSQLVersionLessThan(t, 13) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "xid8", []pgxtest.ValueRoundTripTest{ + { + pgtype.Uint64{Uint64: 1 << 36, Valid: true}, + new(pgtype.Uint64), + isExpectedEq(pgtype.Uint64{Uint64: 1 << 36, Valid: true}), + }, + {pgtype.Uint64{}, new(pgtype.Uint64), isExpectedEq(pgtype.Uint64{})}, + {nil, new(pgtype.Uint64), isExpectedEq(pgtype.Uint64{})}, + { + uint64(1 << 36), + new(uint64), + isExpectedEq(uint64(1 << 36)), + }, + {"1147", new(string), isExpectedEq("1147")}, + }) +} diff --git a/pgtype/unknown.go b/pgtype/unknown.go deleted file mode 100644 index 567831d71..000000000 --- a/pgtype/unknown.go +++ /dev/null @@ -1,44 +0,0 @@ -package pgtype - -import "database/sql/driver" - -// Unknown represents the PostgreSQL unknown type. It is either a string literal -// or NULL. It is used when PostgreSQL does not know the type of a value. In -// general, this will only be used in pgx when selecting a null value without -// type information. e.g. SELECT NULL; -type Unknown struct { - String string - Status Status -} - -func (dst *Unknown) Set(src interface{}) error { - return (*Text)(dst).Set(src) -} - -func (dst *Unknown) Get() interface{} { - return (*Text)(dst).Get() -} - -// AssignTo assigns from src to dst. Note that as Unknown is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *Unknown) AssignTo(dst interface{}) error { - return (*Text)(src).AssignTo(dst) -} - -func (dst *Unknown) DecodeText(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeText(ci, src) -} - -func (dst *Unknown) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeBinary(ci, src) -} - -// Scan implements the database/sql Scanner interface. -func (dst *Unknown) Scan(src interface{}) error { - return (*Text)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Unknown) Value() (driver.Value, error) { - return (*Text)(src).Value() -} diff --git a/pgtype/uuid.go b/pgtype/uuid.go index f8297b396..83d0c4127 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -1,93 +1,48 @@ package pgtype import ( + "bytes" "database/sql/driver" "encoding/hex" "fmt" - - "github.com/pkg/errors" ) -type UUID struct { - Bytes [16]byte - Status Status +type UUIDScanner interface { + ScanUUID(v UUID) error } -func (dst *UUID) Set(src interface{}) error { - if src == nil { - *dst = UUID{Status: Null} - return nil - } +type UUIDValuer interface { + UUIDValue() (UUID, error) +} - switch value := src.(type) { - case [16]byte: - *dst = UUID{Bytes: value, Status: Present} - case []byte: - if value != nil { - if len(value) != 16 { - return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) - } - *dst = UUID{Status: Present} - copy(dst.Bytes[:], value) - } else { - *dst = UUID{Status: Null} - } - case string: - uuid, err := parseUUID(value) - if err != nil { - return err - } - *dst = UUID{Bytes: uuid, Status: Present} - default: - if originalSrc, ok := underlyingPtrType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to UUID", value) - } +type UUID struct { + Bytes [16]byte + Valid bool +} +// ScanUUID implements the [UUIDScanner] interface. +func (b *UUID) ScanUUID(v UUID) error { + *b = v return nil } -func (dst *UUID) Get() interface{} { - switch dst.Status { - case Present: - return dst.Bytes - case Null: - return nil - default: - return dst.Status - } -} - -func (src *UUID) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - case *[16]byte: - *v = src.Bytes - return nil - case *[]byte: - *v = make([]byte, 16) - copy(*v, src.Bytes[:]) - return nil - case *string: - *v = encodeUUID(src.Bytes) - return nil - default: - if nextDst, retry := GetAssignToDstType(v); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot assign %v into %T", src, dst) +// UUIDValue implements the [UUIDValuer] interface. +func (b UUID) UUIDValue() (UUID, error) { + return b, nil } // parseUUID converts a string UUID in standard form to a byte array. func parseUUID(src string) (dst [16]byte, err error) { - src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:] + switch len(src) { + case 36: + src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:] + case 32: + // dashes already stripped, assume valid + default: + // assume invalid. + return dst, fmt.Errorf("cannot parse UUID %v", src) + } + buf, err := hex.DecodeString(src) if err != nil { return dst, err @@ -99,85 +54,240 @@ func parseUUID(src string) (dst [16]byte, err error) { // encodeUUID converts a uuid byte array to UUID standard string form. func encodeUUID(src [16]byte) string { - return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16]) + var buf [36]byte + + hex.Encode(buf[0:8], src[:4]) + buf[8] = '-' + hex.Encode(buf[9:13], src[4:6]) + buf[13] = '-' + hex.Encode(buf[14:18], src[6:8]) + buf[18] = '-' + hex.Encode(buf[19:23], src[8:10]) + buf[23] = '-' + hex.Encode(buf[24:], src[10:]) + + return string(buf[:]) } -func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error { +// Scan implements the [database/sql.Scanner] interface. +func (dst *UUID) Scan(src any) error { if src == nil { - *dst = UUID{Status: Null} + *dst = UUID{} + return nil + } + + switch src := src.(type) { + case string: + buf, err := parseUUID(src) + if err != nil { + return err + } + *dst = UUID{Bytes: buf, Valid: true} return nil } - if len(src) != 36 { - return errors.Errorf("invalid length for UUID: %v", len(src)) + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src UUID) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil } - buf, err := parseUUID(string(src)) + return encodeUUID(src.Bytes), nil +} + +func (src UUID) String() string { + if !src.Valid { + return "" + } + + return encodeUUID(src.Bytes) +} + +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (src UUID) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + var buff bytes.Buffer + buff.WriteByte('"') + buff.WriteString(encodeUUID(src.Bytes)) + buff.WriteByte('"') + return buff.Bytes(), nil +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (dst *UUID) UnmarshalJSON(src []byte) error { + if bytes.Equal(src, []byte("null")) { + *dst = UUID{} + return nil + } + if len(src) != 38 { + return fmt.Errorf("invalid length for UUID: %v", len(src)) + } + buf, err := parseUUID(string(src[1 : len(src)-1])) if err != nil { return err } - - *dst = UUID{Bytes: buf, Status: Present} + *dst = UUID{Bytes: buf, Valid: true} return nil } -func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = UUID{Status: Null} +type UUIDCodec struct{} + +func (UUIDCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (UUIDCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (UUIDCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(UUIDValuer); !ok { return nil } - if len(src) != 16 { - return errors.Errorf("invalid length for UUID: %v", len(src)) + switch format { + case BinaryFormatCode: + return encodePlanUUIDCodecBinaryUUIDValuer{} + case TextFormatCode: + return encodePlanUUIDCodecTextUUIDValuer{} } - *dst = UUID{Status: Present} - copy(dst.Bytes[:], src) return nil } -func (src *UUID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: +type encodePlanUUIDCodecBinaryUUIDValuer struct{} + +func (encodePlanUUIDCodecBinaryUUIDValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + uuid, err := value.(UUIDValuer).UUIDValue() + if err != nil { + return nil, err + } + + if !uuid.Valid { return nil, nil - case Undefined: - return nil, errUndefined } - return append(buf, encodeUUID(src.Bytes)...), nil + return append(buf, uuid.Bytes[:]...), nil } -func (src *UUID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: +type encodePlanUUIDCodecTextUUIDValuer struct{} + +func (encodePlanUUIDCodecTextUUIDValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + uuid, err := value.(UUIDValuer).UUIDValue() + if err != nil { + return nil, err + } + + if !uuid.Valid { return nil, nil - case Undefined: - return nil, errUndefined } - return append(buf, src.Bytes[:]...), nil + return append(buf, encodeUUID(uuid.Bytes)...), nil } -// Scan implements the database/sql Scanner interface. -func (dst *UUID) Scan(src interface{}) error { +func (UUIDCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case UUIDScanner: + return scanPlanBinaryUUIDToUUIDScanner{} + case TextScanner: + return scanPlanBinaryUUIDToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case UUIDScanner: + return scanPlanTextAnyToUUIDScanner{} + } + } + + return nil +} + +type scanPlanBinaryUUIDToUUIDScanner struct{} + +func (scanPlanBinaryUUIDToUUIDScanner) Scan(src []byte, dst any) error { + scanner := (dst).(UUIDScanner) + if src == nil { - *dst = UUID{Status: Null} - return nil + return scanner.ScanUUID(UUID{}) } - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + if len(src) != 16 { + return fmt.Errorf("invalid length for UUID: %v", len(src)) } - return errors.Errorf("cannot scan %T", src) + uuid := UUID{Valid: true} + copy(uuid.Bytes[:], src) + + return scanner.ScanUUID(uuid) } -// Value implements the database/sql/driver Valuer interface. -func (src *UUID) Value() (driver.Value, error) { - return EncodeValueText(src) +type scanPlanBinaryUUIDToTextScanner struct{} + +func (scanPlanBinaryUUIDToTextScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TextScanner) + + if src == nil { + return scanner.ScanText(Text{}) + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for UUID: %v", len(src)) + } + + var buf [16]byte + copy(buf[:], src) + + return scanner.ScanText(Text{String: encodeUUID(buf), Valid: true}) +} + +type scanPlanTextAnyToUUIDScanner struct{} + +func (scanPlanTextAnyToUUIDScanner) Scan(src []byte, dst any) error { + scanner := (dst).(UUIDScanner) + + if src == nil { + return scanner.ScanUUID(UUID{}) + } + + buf, err := parseUUID(string(src)) + if err != nil { + return err + } + + return scanner.ScanUUID(UUID{Bytes: buf, Valid: true}) +} + +func (c UUIDCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var uuid UUID + err := codecScan(c, m, oid, format, src, &uuid) + if err != nil { + return nil, err + } + + return encodeUUID(uuid.Bytes), nil +} + +func (c UUIDCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var uuid UUID + err := codecScan(c, m, oid, format, src, &uuid) + if err != nil { + return nil, err + } + return uuid.Bytes, nil } diff --git a/pgtype/uuid_array.go b/pgtype/uuid_array.go deleted file mode 100644 index 9c7843a72..000000000 --- a/pgtype/uuid_array.go +++ /dev/null @@ -1,356 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type UUIDArray struct { - Elements []UUID - Dimensions []ArrayDimension - Status Status -} - -func (dst *UUIDArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = UUIDArray{Status: Null} - return nil - } - - switch value := src.(type) { - - case [][16]byte: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case [][]byte: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []string: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to UUIDArray", value) - } - - return nil -} - -func (dst *UUIDArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *UUIDArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[][16]byte: - *v = make([][16]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[][]byte: - *v = make([][]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *UUIDArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = UUIDArray{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []UUID - - if len(uta.Elements) > 0 { - elements = make([]UUID, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem UUID - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = UUIDArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *UUIDArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = UUIDArray{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = UUIDArray{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]UUID, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = UUIDArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *UUIDArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("uuid"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "uuid") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *UUIDArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *UUIDArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/uuid_array_test.go b/pgtype/uuid_array_test.go deleted file mode 100644 index ee9d3dfa6..000000000 --- a/pgtype/uuid_array_test.go +++ /dev/null @@ -1,205 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestUUIDArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "uuid[]", []interface{}{ - &pgtype.UUIDArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.UUIDArray{Status: pgtype.Null}, - &pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, - {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestUUIDArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.UUIDArray - }{ - { - source: nil, - result: pgtype.UUIDArray{Status: pgtype.Null}, - }, - { - source: [][16]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][16]byte{}, - result: pgtype.UUIDArray{Status: pgtype.Present}, - }, - { - source: ([][16]byte)(nil), - result: pgtype.UUIDArray{Status: pgtype.Null}, - }, - { - source: [][]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: [][]byte{ - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, - nil, - {32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, - }, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{ - {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 4}}, - Status: pgtype.Present}, - }, - { - source: [][]byte{}, - result: pgtype.UUIDArray{Status: pgtype.Present}, - }, - { - source: ([][]byte)(nil), - result: pgtype.UUIDArray{Status: pgtype.Null}, - }, - { - source: []string{"00010203-0405-0607-0809-0a0b0c0d0e0f"}, - result: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: []string{}, - result: pgtype.UUIDArray{Status: pgtype.Present}, - }, - { - source: ([]string)(nil), - result: pgtype.UUIDArray{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.UUIDArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestUUIDArrayAssignTo(t *testing.T) { - var byteArraySlice [][16]byte - var byteSliceSlice [][]byte - var stringSlice []string - - simpleTests := []struct { - src pgtype.UUIDArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &byteArraySlice, - expected: [][16]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - }, - { - src: pgtype.UUIDArray{Status: pgtype.Null}, - dst: &byteArraySlice, - expected: ([][16]byte)(nil), - }, - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &byteSliceSlice, - expected: [][]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, - }, - { - src: pgtype.UUIDArray{Status: pgtype.Null}, - dst: &byteSliceSlice, - expected: ([][]byte)(nil), - }, - { - src: pgtype.UUIDArray{ - Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - expected: []string{"00010203-0405-0607-0809-0a0b0c0d0e0f"}, - }, - { - src: pgtype.UUIDArray{Status: pgtype.Null}, - dst: &stringSlice, - expected: ([]string)(nil), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index 162d999f1..255bd92f6 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -1,104 +1,176 @@ package pgtype_test import ( - "bytes" + "context" + "reflect" "testing" - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" ) -func TestUUIDTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - &pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - &pgtype.UUID{Status: pgtype.Null}, +type renamedUUIDByteArray [16]byte + +func TestUUIDCodec(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "uuid", []pgxtest.ValueRoundTripTest{ + { + pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + { + "00010203-0405-0607-0809-0a0b0c0d0e0f", + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + { + "000102030405060708090a0b0c0d0e0f", + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + { + pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + new(string), + isExpectedEq("00010203-0405-0607-0809-0a0b0c0d0e0f"), + }, + {pgtype.UUID{}, new([]byte), isExpectedEqBytes([]byte(nil))}, + {pgtype.UUID{}, new(pgtype.UUID), isExpectedEq(pgtype.UUID{})}, + {nil, new(pgtype.UUID), isExpectedEq(pgtype.UUID{})}, }) -} -func TestUUIDSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.UUID - }{ + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "uuid", []pgxtest.ValueRoundTripTest{ { - source: nil, - result: pgtype.UUID{Status: pgtype.Null}, + [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), }, { - source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + renamedUUIDByteArray{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), }, { - source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + new(renamedUUIDByteArray), + isExpectedEq(renamedUUIDByteArray{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), }, { - source: ([]byte)(nil), - result: pgtype.UUID{Status: pgtype.Null}, + []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + }) +} + +func TestUUID_String(t *testing.T) { + tests := []struct { + name string + src pgtype.UUID + want string + }{ + { + name: "first", + src: pgtype.UUID{ + Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, + Valid: true, + }, + want: "1d485a7a-6d18-4599-8c6c-34425616887a", }, { - source: "00010203-0405-0607-0809-0a0b0c0d0e0f", - result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + name: "third", + src: pgtype.UUID{ + Bytes: [16]byte{}, + }, + want: "", }, } - - for i, tt := range successfulTests { - var r pgtype.UUID - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.src.String() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) + } + }) } } -func TestUUIDAssignTo(t *testing.T) { - { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} - var dst [16]byte - expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } +func TestUUID_MarshalJSON(t *testing.T) { + tests := []struct { + name string + src pgtype.UUID + want []byte + }{ + { + name: "first", + src: pgtype.UUID{ + Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, + Valid: true, + }, + want: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), + }, + { + name: "third", + src: pgtype.UUID{ + Bytes: [16]byte{}, + }, + want: []byte("null"), + }, } - - { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} - var dst []byte - expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if bytes.Compare(dst, expected) != 0 { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.src.MarshalJSON() + require.NoError(t, err) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) + } + }) } +} - { - src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} - var dst string - expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } +func TestUUID_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + want *pgtype.UUID + src []byte + wantErr bool + }{ + { + name: "first", + want: &pgtype.UUID{ + Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, + Valid: true, + }, + src: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), + wantErr: false, + }, + { + name: "second", + want: &pgtype.UUID{ + Bytes: [16]byte{}, + }, + src: []byte("null"), + wantErr: false, + }, + { + name: "third", + want: &pgtype.UUID{ + Bytes: [16]byte{}, + Valid: false, + }, + src: []byte("1d485a7a-6d18-4599-8c6c-34425616887a"), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &pgtype.UUID{} + if err := got.UnmarshalJSON(tt.src); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want) + } + }) } - } diff --git a/pgtype/varbit.go b/pgtype/varbit.go deleted file mode 100644 index dfa194d20..000000000 --- a/pgtype/varbit.go +++ /dev/null @@ -1,133 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type Varbit struct { - Bytes []byte - Len int32 // Number of bits - Status Status -} - -func (dst *Varbit) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Varbit", src) -} - -func (dst *Varbit) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *Varbit) AssignTo(dst interface{}) error { - return errors.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Varbit) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Varbit{Status: Null} - return nil - } - - bitLen := len(src) - byteLen := bitLen / 8 - if bitLen%8 > 0 { - byteLen++ - } - buf := make([]byte, byteLen) - - for i, b := range src { - if b == '1' { - byteIdx := i / 8 - bitIdx := uint(i % 8) - buf[byteIdx] = buf[byteIdx] | (128 >> bitIdx) - } - } - - *dst = Varbit{Bytes: buf, Len: int32(bitLen), Status: Present} - return nil -} - -func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Varbit{Status: Null} - return nil - } - - if len(src) < 4 { - return errors.Errorf("invalid length for varbit: %v", len(src)) - } - - bitLen := int32(binary.BigEndian.Uint32(src)) - rp := 4 - - *dst = Varbit{Bytes: src[rp:], Len: bitLen, Status: Present} - return nil -} - -func (src *Varbit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - for i := int32(0); i < src.Len; i++ { - byteIdx := i / 8 - bitMask := byte(128 >> byte(i%8)) - char := byte('0') - if src.Bytes[byteIdx]&bitMask > 0 { - char = '1' - } - buf = append(buf, char) - } - - return buf, nil -} - -func (src *Varbit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - buf = pgio.AppendInt32(buf, src.Len) - return append(buf, src.Bytes...), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Varbit) Scan(src interface{}) error { - if src == nil { - *dst = Varbit{Status: Null} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Varbit) Value() (driver.Value, error) { - return EncodeValueText(src) -} diff --git a/pgtype/varbit_test.go b/pgtype/varbit_test.go deleted file mode 100644 index 6c813aae4..000000000 --- a/pgtype/varbit_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package pgtype_test - -import ( - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestVarbitTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "varbit", []interface{}{ - &pgtype.Varbit{Bytes: []byte{}, Len: 0, Status: pgtype.Present}, - &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Status: pgtype.Present}, - &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Status: pgtype.Present}, - &pgtype.Varbit{Status: pgtype.Null}, - }) -} - -func TestVarbitNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select B'111111111'", - Value: &pgtype.Varbit{Bytes: []byte{255, 128}, Len: 9, Status: pgtype.Present}, - }, - }) -} diff --git a/pgtype/varchar.go b/pgtype/varchar.go deleted file mode 100644 index 6be1a0352..000000000 --- a/pgtype/varchar.go +++ /dev/null @@ -1,58 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -type Varchar Text - -// Set converts from src to dst. Note that as Varchar is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *Varchar) Set(src interface{}) error { - return (*Text)(dst).Set(src) -} - -func (dst *Varchar) Get() interface{} { - return (*Text)(dst).Get() -} - -// AssignTo assigns from src to dst. Note that as Varchar is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *Varchar) AssignTo(dst interface{}) error { - return (*Text)(src).AssignTo(dst) -} - -func (dst *Varchar) DecodeText(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeText(ci, src) -} - -func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Text)(dst).DecodeBinary(ci, src) -} - -func (src *Varchar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Text)(src).EncodeText(ci, buf) -} - -func (src *Varchar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Text)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *Varchar) Scan(src interface{}) error { - return (*Text)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *Varchar) Value() (driver.Value, error) { - return (*Text)(src).Value() -} - -func (src *Varchar) MarshalJSON() ([]byte, error) { - return (*Text)(src).MarshalJSON() -} - -func (dst *Varchar) UnmarshalJSON(b []byte) error { - return (*Text)(dst).UnmarshalJSON(b) -} diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go deleted file mode 100644 index 09eba3eab..000000000 --- a/pgtype/varchar_array.go +++ /dev/null @@ -1,300 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -type VarcharArray struct { - Elements []Varchar - Dimensions []ArrayDimension - Status Status -} - -func (dst *VarcharArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = VarcharArray{Status: Null} - return nil - } - - switch value := src.(type) { - - case []string: - if value == nil { - *dst = VarcharArray{Status: Null} - } else if len(value) == 0 { - *dst = VarcharArray{Status: Present} - } else { - elements := make([]Varchar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = VarcharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - default: - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return errors.Errorf("cannot convert %v to VarcharArray", value) - } - - return nil -} - -func (dst *VarcharArray) Get() interface{} { - switch dst.Status { - case Present: - return dst - case Null: - return nil - default: - return dst.Status - } -} - -func (src *VarcharArray) AssignTo(dst interface{}) error { - switch src.Status { - case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - } - case Null: - return NullAssignTo(dst) - } - - return errors.Errorf("cannot decode %v into %T", src, dst) -} - -func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = VarcharArray{Status: Null} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Varchar - - if len(uta.Elements) > 0 { - elements = make([]Varchar, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Varchar - var elemSrc []byte - if s != "NULL" { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = VarcharArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} - - return nil -} - -func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = VarcharArray{Status: Null} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = VarcharArray{Dimensions: arrayHeader.Dimensions, Status: Present} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Varchar, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = VarcharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} - return nil -} - -func (src *VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `"NULL"`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src *VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - switch src.Status { - case Null: - return nil, nil - case Undefined: - return nil, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("varchar"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, errors.Errorf("unable to find oid for type name %v", "varchar") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *VarcharArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return errors.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *VarcharArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/varchar_array_test.go b/pgtype/varchar_array_test.go deleted file mode 100644 index 9fb0960f5..000000000 --- a/pgtype/varchar_array_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestVarcharArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "varchar[]", []interface{}{ - &pgtype.VarcharArray{ - Elements: nil, - Dimensions: nil, - Status: pgtype.Present, - }, - &pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "foo", Status: pgtype.Present}, - {Status: pgtype.Null}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.VarcharArray{Status: pgtype.Null}, - &pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "bar ", Status: pgtype.Present}, - {String: "NuLL", Status: pgtype.Present}, - {String: `wow"quz\`, Status: pgtype.Present}, - {String: "", Status: pgtype.Present}, - {Status: pgtype.Null}, - {String: "null", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.VarcharArray{ - Elements: []pgtype.Varchar{ - {String: "bar", Status: pgtype.Present}, - {String: "baz", Status: pgtype.Present}, - {String: "quz", Status: pgtype.Present}, - {String: "foo", Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Status: pgtype.Present, - }, - }) -} - -func TestVarcharArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.VarcharArray - }{ - { - source: []string{"foo"}, - result: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present}, - }, - { - source: (([]string)(nil)), - result: pgtype.VarcharArray{Status: pgtype.Null}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.VarcharArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestVarcharArrayAssignTo(t *testing.T) { - var stringSlice []string - type _stringSlice []string - var namedStringSlice _stringSlice - - simpleTests := []struct { - src pgtype.VarcharArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - expected: []string{"foo"}, - }, - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{String: "bar", Status: pgtype.Present}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &namedStringSlice, - expected: _stringSlice{"bar"}, - }, - { - src: pgtype.VarcharArray{Status: pgtype.Null}, - dst: &stringSlice, - expected: (([]string)(nil)), - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.VarcharArray - dst interface{} - }{ - { - src: pgtype.VarcharArray{ - Elements: []pgtype.Varchar{{Status: pgtype.Null}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Status: pgtype.Present, - }, - dst: &stringSlice, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/xid.go b/pgtype/xid.go deleted file mode 100644 index f66f53670..000000000 --- a/pgtype/xid.go +++ /dev/null @@ -1,64 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" -) - -// XID is PostgreSQL's Transaction ID type. -// -// In later versions of PostgreSQL, it is the type used for the backend_xid -// and backend_xmin columns of the pg_stat_activity system view. -// -// Also, when one does -// -// select xmin, xmax, * from some_table; -// -// it is the data type of the xmin and xmax hidden system columns. -// -// It is currently implemented as an unsigned four byte integer. -// Its definition can be found in src/include/postgres_ext.h as TransactionId -// in the PostgreSQL sources. -type XID pguint32 - -// Set converts from src to dst. Note that as XID is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *XID) Set(src interface{}) error { - return (*pguint32)(dst).Set(src) -} - -func (dst *XID) Get() interface{} { - return (*pguint32)(dst).Get() -} - -// AssignTo assigns from src to dst. Note that as XID is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *XID) AssignTo(dst interface{}) error { - return (*pguint32)(src).AssignTo(dst) -} - -func (dst *XID) DecodeText(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeText(ci, src) -} - -func (dst *XID) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*pguint32)(dst).DecodeBinary(ci, src) -} - -func (src *XID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*pguint32)(src).EncodeText(ci, buf) -} - -func (src *XID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*pguint32)(src).EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *XID) Scan(src interface{}) error { - return (*pguint32)(dst).Scan(src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src *XID) Value() (driver.Value, error) { - return (*pguint32)(src).Value() -} diff --git a/pgtype/xid_test.go b/pgtype/xid_test.go deleted file mode 100644 index d0f3f0ab6..000000000 --- a/pgtype/xid_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/pgtype" - "github.com/jackc/pgx/pgtype/testutil" -) - -func TestXIDTranscode(t *testing.T) { - pgTypeName := "xid" - values := []interface{}{ - &pgtype.XID{Uint: 42, Status: pgtype.Present}, - &pgtype.XID{Status: pgtype.Null}, - } - eqFunc := func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - } - - testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - - // No direct conversion from int to xid, convert through text - testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) - - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) - } -} - -func TestXIDSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.XID - }{ - {source: uint32(1), result: pgtype.XID{Uint: 1, Status: pgtype.Present}}, - } - - for i, tt := range successfulTests { - var r pgtype.XID - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestXIDAssignTo(t *testing.T) { - var ui32 uint32 - var pui32 *uint32 - - simpleTests := []struct { - src pgtype.XID - dst interface{} - expected interface{} - }{ - {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.XID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.XID - dst interface{} - expected interface{} - }{ - {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.XID - dst interface{} - }{ - {src: pgtype.XID{Status: pgtype.Null}, dst: &ui32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/xml.go b/pgtype/xml.go new file mode 100644 index 000000000..79e3698a4 --- /dev/null +++ b/pgtype/xml.go @@ -0,0 +1,198 @@ +package pgtype + +import ( + "database/sql" + "database/sql/driver" + "encoding/xml" + "fmt" + "reflect" +) + +type XMLCodec struct { + Marshal func(v any) ([]byte, error) + Unmarshal func(data []byte, v any) error +} + +func (*XMLCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (*XMLCodec) PreferredFormat() int16 { + return TextFormatCode +} + +func (c *XMLCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch value.(type) { + case string: + return encodePlanXMLCodecEitherFormatString{} + case []byte: + return encodePlanXMLCodecEitherFormatByteSlice{} + + // Cannot rely on driver.Valuer being handled later because anything can be marshalled. + // + // https://github.com/jackc/pgx/issues/1430 + // + // Check for driver.Valuer must come before xml.Marshaler so that it is guaranteed to be used + // when both are implemented https://github.com/jackc/pgx/issues/1805 + case driver.Valuer: + return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format} + + // Must come before trying wrap encode plans because a pointer to a struct may be unwrapped to a struct that can be + // marshalled. + // + // https://github.com/jackc/pgx/issues/1681 + case xml.Marshaler: + return &encodePlanXMLCodecEitherFormatMarshal{ + marshal: c.Marshal, + } + } + + // Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the + // appropriate wrappers here. + for _, f := range []TryWrapEncodePlanFunc{ + TryWrapDerefPointerEncodePlan, + TryWrapFindUnderlyingTypeEncodePlan, + } { + if wrapperPlan, nextValue, ok := f(value); ok { + if nextPlan := c.PlanEncode(m, oid, format, nextValue); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + + return &encodePlanXMLCodecEitherFormatMarshal{ + marshal: c.Marshal, + } +} + +type encodePlanXMLCodecEitherFormatString struct{} + +func (encodePlanXMLCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) { + xmlString := value.(string) + buf = append(buf, xmlString...) + return buf, nil +} + +type encodePlanXMLCodecEitherFormatByteSlice struct{} + +func (encodePlanXMLCodecEitherFormatByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) { + xmlBytes := value.([]byte) + if xmlBytes == nil { + return nil, nil + } + + buf = append(buf, xmlBytes...) + return buf, nil +} + +type encodePlanXMLCodecEitherFormatMarshal struct { + marshal func(v any) ([]byte, error) +} + +func (e *encodePlanXMLCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { + xmlBytes, err := e.marshal(value) + if err != nil { + return nil, err + } + + buf = append(buf, xmlBytes...) + return buf, nil +} + +func (c *XMLCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch target.(type) { + case *string: + return scanPlanAnyToString{} + + case **string: + // This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better + // solution would be. + // + // https://github.com/jackc/pgx/issues/1470 -- **string + // https://github.com/jackc/pgx/issues/1691 -- ** anything else + + if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok { + if nextPlan := m.planScan(oid, format, nextDst, 0); nextPlan != nil { + if _, failed := nextPlan.(*scanPlanFail); !failed { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + + case *[]byte: + return scanPlanXMLToByteSlice{} + case BytesScanner: + return scanPlanBinaryBytesToBytesScanner{} + + // Cannot rely on sql.Scanner being handled later because scanPlanXMLToXMLUnmarshal will take precedence. + // + // https://github.com/jackc/pgx/issues/1418 + case sql.Scanner: + return &scanPlanSQLScanner{formatCode: format} + } + + return &scanPlanXMLToXMLUnmarshal{ + unmarshal: c.Unmarshal, + } +} + +type scanPlanXMLToByteSlice struct{} + +func (scanPlanXMLToByteSlice) Scan(src []byte, dst any) error { + dstBuf := dst.(*[]byte) + if src == nil { + *dstBuf = nil + return nil + } + + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type scanPlanXMLToXMLUnmarshal struct { + unmarshal func(data []byte, v any) error +} + +func (s *scanPlanXMLToXMLUnmarshal) Scan(src []byte, dst any) error { + if src == nil { + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() == reflect.Ptr { + el := dstValue.Elem() + switch el.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Interface, reflect.Struct: + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + elem := reflect.ValueOf(dst).Elem() + elem.Set(reflect.Zero(elem.Type())) + + return s.unmarshal(src, dst) +} + +func (c *XMLCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + dstBuf := make([]byte, len(src)) + copy(dstBuf, src) + return dstBuf, nil +} + +func (c *XMLCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var dst any + err := c.Unmarshal(src, &dst) + return dst, err +} diff --git a/pgtype/xml_test.go b/pgtype/xml_test.go new file mode 100644 index 000000000..2c0b899a5 --- /dev/null +++ b/pgtype/xml_test.go @@ -0,0 +1,128 @@ +package pgtype_test + +import ( + "context" + "database/sql" + "encoding/xml" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type xmlStruct struct { + XMLName xml.Name `xml:"person"` + Name string `xml:"name"` + Age int `xml:"age,attr"` +} + +func TestXMLCodec(t *testing.T) { + skipCockroachDB(t, "CockroachDB does not support XML.") + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "xml", []pgxtest.ValueRoundTripTest{ + {nil, new(*xmlStruct), isExpectedEq((*xmlStruct)(nil))}, + {map[string]any(nil), new(*string), isExpectedEq((*string)(nil))}, + {map[string]any(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, + + // Test sql.Scanner. + {"", new(sql.NullString), isExpectedEq(sql.NullString{String: "", Valid: true})}, + + // Test driver.Valuer. + {sql.NullString{String: "", Valid: true}, new(sql.NullString), isExpectedEq(sql.NullString{String: "", Valid: true})}, + }) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "xml", []pgxtest.ValueRoundTripTest{ + {[]byte(``), new([]byte), isExpectedEqBytes([]byte(``))}, + {[]byte(``), new([]byte), isExpectedEqBytes([]byte(``))}, + {[]byte(``), new(string), isExpectedEq(``)}, + {[]byte(``), new([]byte), isExpectedEqBytes([]byte(``))}, + {[]byte(``), new(string), isExpectedEq(``)}, + {[]byte(""), new([]byte), isExpectedEqBytes([]byte(""))}, + {xmlStruct{Name: "Adam", Age: 10}, new(xmlStruct), isExpectedEq(xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10})}, + {xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10}, new(xmlStruct), isExpectedEq(xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10})}, + {[]byte(`Adam`), new(xmlStruct), isExpectedEq(xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10})}, + }) +} + +// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648 +func TestXMLCodecUnmarshalSQLNull(t *testing.T) { + skipCockroachDB(t, "CockroachDB does not support XML.") + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + // Byte arrays are nilified + slice := []byte{10, 4} + err := conn.QueryRow(ctx, "select null::xml").Scan(&slice) + assert.NoError(t, err) + assert.Nil(t, slice) + + // Non-pointer structs are zeroed + m := xmlStruct{Name: "Adam"} + err = conn.QueryRow(ctx, "select null::xml").Scan(&m) + assert.NoError(t, err) + assert.Empty(t, m) + + // Pointers to structs are nilified + pm := &xmlStruct{Name: "Adam"} + err = conn.QueryRow(ctx, "select null::xml").Scan(&pm) + assert.NoError(t, err) + assert.Nil(t, pm) + + // Pointer to pointer are nilified + n := "" + p := &n + err = conn.QueryRow(ctx, "select null::xml").Scan(&p) + assert.NoError(t, err) + assert.Nil(t, p) + + // A string cannot scan a NULL. + str := "foobar" + err = conn.QueryRow(ctx, "select null::xml").Scan(&str) + assert.EqualError(t, err, "can't scan into dest[0] (col: xml): cannot scan NULL into *string") + }) +} + +func TestXMLCodecPointerToPointerToString(t *testing.T) { + skipCockroachDB(t, "CockroachDB does not support XML.") + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var s *string + err := conn.QueryRow(ctx, "select ''::xml").Scan(&s) + require.NoError(t, err) + require.NotNil(t, s) + require.Equal(t, "", *s) + + err = conn.QueryRow(ctx, "select null::xml").Scan(&s) + require.NoError(t, err) + require.Nil(t, s) + }) +} + +func TestXMLCodecDecodeValue(t *testing.T) { + skipCockroachDB(t, "CockroachDB does not support XML.") + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + for _, tt := range []struct { + sql string + expected any + }{ + { + sql: `select 'bar'::xml`, + expected: []byte("bar"), + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(ctx, tt.sql) + require.NoError(t, err) + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } + }) +} diff --git a/pgtype/zeronull/doc.go b/pgtype/zeronull/doc.go new file mode 100644 index 000000000..78a523072 --- /dev/null +++ b/pgtype/zeronull/doc.go @@ -0,0 +1,22 @@ +// Package zeronull contains types that automatically convert between database NULLs and Go zero values. +/* +Sometimes the distinction between a zero value and a NULL value is not useful at the application level. For example, +in PostgreSQL an empty string may be stored as NULL. There is usually no application level distinction between an +empty string and a NULL string. Package zeronull implements types that seamlessly convert between PostgreSQL NULL and +the zero value. + +It is recommended to convert types at usage time rather than instantiate these types directly. In the example below, +middlename would be stored as a NULL. + + firstname := "John" + middlename := "" + lastname := "Smith" + _, err := conn.Exec( + ctx, + "insert into people(firstname, middlename, lastname) values($1, $2, $3)", + zeronull.Text(firstname), + zeronull.Text(middlename), + zeronull.Text(lastname), + ) +*/ +package zeronull diff --git a/pgtype/zeronull/float8.go b/pgtype/zeronull/float8.go new file mode 100644 index 000000000..919997eab --- /dev/null +++ b/pgtype/zeronull/float8.go @@ -0,0 +1,58 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgx/v5/pgtype" +) + +type Float8 float64 + +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. +func (Float8) SkipUnderlyingTypePlan() {} + +// ScanFloat64 implements the [pgtype.Float64Scanner] interface. +func (f *Float8) ScanFloat64(n pgtype.Float8) error { + if !n.Valid { + *f = 0 + return nil + } + + *f = Float8(n.Float64) + + return nil +} + +// Float64Value implements the [pgtype.Float64Valuer] interface. +func (f Float8) Float64Value() (pgtype.Float8, error) { + if f == 0 { + return pgtype.Float8{}, nil + } + return pgtype.Float8{Float64: float64(f), Valid: true}, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (f *Float8) Scan(src any) error { + if src == nil { + *f = 0 + return nil + } + + var nullable pgtype.Float8 + err := nullable.Scan(src) + if err != nil { + return err + } + + *f = Float8(nullable.Float64) + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (f Float8) Value() (driver.Value, error) { + if f == 0 { + return nil, nil + } + return float64(f), nil +} diff --git a/pgtype/zeronull/float8_test.go b/pgtype/zeronull/float8_test.go new file mode 100644 index 000000000..b3c818aaa --- /dev/null +++ b/pgtype/zeronull/float8_test.go @@ -0,0 +1,35 @@ +package zeronull_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" +) + +func isExpectedEq(a any) func(any) bool { + return func(v any) bool { + return a == v + } +} + +func TestFloat8Transcode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "float8", []pgxtest.ValueRoundTripTest{ + { + (zeronull.Float8)(1), + new(zeronull.Float8), + isExpectedEq((zeronull.Float8)(1)), + }, + { + nil, + new(zeronull.Float8), + isExpectedEq((zeronull.Float8)(0)), + }, + { + (zeronull.Float8)(0), + new(any), + isExpectedEq(nil), + }, + }) +} diff --git a/pgtype/zeronull/int.go b/pgtype/zeronull/int.go new file mode 100644 index 000000000..e36aa2466 --- /dev/null +++ b/pgtype/zeronull/int.go @@ -0,0 +1,182 @@ +// Code generated from pgtype/zeronull/int.go.erb. DO NOT EDIT. + +package zeronull + +import ( + "database/sql/driver" + "fmt" + "math" + + "github.com/jackc/pgx/v5/pgtype" +) + +type Int2 int16 + +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. +func (Int2) SkipUnderlyingTypePlan() {} + +// ScanInt64 implements the [pgtype.Int64Scanner] interface. +func (dst *Int2) ScanInt64(n pgtype.Int8) error { + if !n.Valid { + *dst = 0 + return nil + } + + if n.Int64 < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for Int2", n.Int64) + } + if n.Int64 > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n.Int64) + } + *dst = Int2(n.Int64) + + return nil +} + +// Int64Value implements the [pgtype.Int64Valuer] interface. +func (src Int2) Int64Value() (pgtype.Int8, error) { + if src == 0 { + return pgtype.Int8{}, nil + } + return pgtype.Int8{Int64: int64(src), Valid: true}, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *Int2) Scan(src any) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int2 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int2(nullable.Int16) + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src Int2) Value() (driver.Value, error) { + if src == 0 { + return nil, nil + } + return int64(src), nil +} + +type Int4 int32 + +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. +func (Int4) SkipUnderlyingTypePlan() {} + +// ScanInt64 implements the [pgtype.Int64Scanner] interface. +func (dst *Int4) ScanInt64(n pgtype.Int8) error { + if !n.Valid { + *dst = 0 + return nil + } + + if n.Int64 < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for Int4", n.Int64) + } + if n.Int64 > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n.Int64) + } + *dst = Int4(n.Int64) + + return nil +} + +// Int64Value implements the [pgtype.Int64Valuer] interface. +func (src Int4) Int64Value() (pgtype.Int8, error) { + if src == 0 { + return pgtype.Int8{}, nil + } + return pgtype.Int8{Int64: int64(src), Valid: true}, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *Int4) Scan(src any) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int4 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int4(nullable.Int32) + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src Int4) Value() (driver.Value, error) { + if src == 0 { + return nil, nil + } + return int64(src), nil +} + +type Int8 int64 + +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. +func (Int8) SkipUnderlyingTypePlan() {} + +// ScanInt64 implements the [pgtype.Int64Scanner] interface. +func (dst *Int8) ScanInt64(n pgtype.Int8) error { + if !n.Valid { + *dst = 0 + return nil + } + + if n.Int64 < math.MinInt64 { + return fmt.Errorf("%d is less than minimum value for Int8", n.Int64) + } + if n.Int64 > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n.Int64) + } + *dst = Int8(n.Int64) + + return nil +} + +// Int64Value implements the [pgtype.Int64Valuer] interface. +func (src Int8) Int64Value() (pgtype.Int8, error) { + if src == 0 { + return pgtype.Int8{}, nil + } + return pgtype.Int8{Int64: int64(src), Valid: true}, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *Int8) Scan(src any) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int8 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int8(nullable.Int64) + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src Int8) Value() (driver.Value, error) { + if src == 0 { + return nil, nil + } + return int64(src), nil +} diff --git a/pgtype/zeronull/int.go.erb b/pgtype/zeronull/int.go.erb new file mode 100644 index 000000000..6cb5ddce4 --- /dev/null +++ b/pgtype/zeronull/int.go.erb @@ -0,0 +1,69 @@ +package zeronull + +import ( + "database/sql/driver" + "fmt" + "math" + + "github.com/jackc/pgx/v5/pgtype" +) + +<% [2, 4, 8].each do |pg_byte_size| %> +<% pg_bit_size = pg_byte_size * 8 %> +type Int<%= pg_byte_size %> int<%= pg_bit_size %> + +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. +func (Int<%= pg_byte_size %>) SkipUnderlyingTypePlan() {} + +// ScanInt64 implements the [pgtype.Int64Scanner] interface. +func (dst *Int<%= pg_byte_size %>) ScanInt64(n pgtype.Int8) error { + if !n.Valid { + *dst = 0 + return nil + } + + if n.Int64 < math.MinInt<%= pg_bit_size %> { + return fmt.Errorf("%d is less than minimum value for Int<%= pg_byte_size %>", n.Int64) + } + if n.Int64 > math.MaxInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int64) + } + *dst = Int<%= pg_byte_size %>(n.Int64) + + return nil +} + +// Int64Value implements the [pgtype.Int64Valuer] interface. +func (src Int<%= pg_byte_size %>) Int64Value() (pgtype.Int8, error) { + if src == 0 { + return pgtype.Int8{}, nil + } + return pgtype.Int8{Int64: int64(src), Valid: true}, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *Int<%= pg_byte_size %>) Scan(src any) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int<%= pg_byte_size %> + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int<%= pg_byte_size %>(nullable.Int<%= pg_bit_size %>) + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { + if src == 0 { + return nil, nil + } + return int64(src), nil +} +<% end %> diff --git a/pgtype/zeronull/int_test.go b/pgtype/zeronull/int_test.go new file mode 100644 index 000000000..7e32064ab --- /dev/null +++ b/pgtype/zeronull/int_test.go @@ -0,0 +1,71 @@ +// Code generated from pgtype/zeronull/int_test.go.erb. DO NOT EDIT. + +package zeronull_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestInt2Transcode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int2", []pgxtest.ValueRoundTripTest{ + { + (zeronull.Int2)(1), + new(zeronull.Int2), + isExpectedEq((zeronull.Int2)(1)), + }, + { + nil, + new(zeronull.Int2), + isExpectedEq((zeronull.Int2)(0)), + }, + { + (zeronull.Int2)(0), + new(any), + isExpectedEq(nil), + }, + }) +} + +func TestInt4Transcode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4", []pgxtest.ValueRoundTripTest{ + { + (zeronull.Int4)(1), + new(zeronull.Int4), + isExpectedEq((zeronull.Int4)(1)), + }, + { + nil, + new(zeronull.Int4), + isExpectedEq((zeronull.Int4)(0)), + }, + { + (zeronull.Int4)(0), + new(any), + isExpectedEq(nil), + }, + }) +} + +func TestInt8Transcode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int8", []pgxtest.ValueRoundTripTest{ + { + (zeronull.Int8)(1), + new(zeronull.Int8), + isExpectedEq((zeronull.Int8)(1)), + }, + { + nil, + new(zeronull.Int8), + isExpectedEq((zeronull.Int8)(0)), + }, + { + (zeronull.Int8)(0), + new(any), + isExpectedEq(nil), + }, + }) +} diff --git a/pgtype/zeronull/int_test.go.erb b/pgtype/zeronull/int_test.go.erb new file mode 100644 index 000000000..c0f72ef49 --- /dev/null +++ b/pgtype/zeronull/int_test.go.erb @@ -0,0 +1,31 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/jackc/pgx/v5/pgtype/zeronull" +) + +<% [2, 4, 8].each do |pg_byte_size| %> +<% pg_bit_size = pg_byte_size * 8 %> +func TestInt<%= pg_byte_size %>Transcode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int<%= pg_byte_size %>", []pgxtest.ValueRoundTripTest{ + { + (zeronull.Int<%= pg_byte_size %>)(1), + new(zeronull.Int<%= pg_byte_size %>), + isExpectedEq((zeronull.Int<%= pg_byte_size %>)(1)), + }, + { + nil, + new(zeronull.Int<%= pg_byte_size %>), + isExpectedEq((zeronull.Int<%= pg_byte_size %>)(0)), + }, + { + (zeronull.Int<%= pg_byte_size %>)(0), + new(any), + isExpectedEq(nil), + }, + }) +} +<% end %> diff --git a/pgtype/zeronull/text.go b/pgtype/zeronull/text.go new file mode 100644 index 000000000..ed25c2084 --- /dev/null +++ b/pgtype/zeronull/text.go @@ -0,0 +1,50 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgx/v5/pgtype" +) + +type Text string + +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. +func (Text) SkipUnderlyingTypePlan() {} + +// ScanText implements the [pgtype.TextScanner] interface. +func (dst *Text) ScanText(v pgtype.Text) error { + if !v.Valid { + *dst = "" + return nil + } + + *dst = Text(v.String) + + return nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *Text) Scan(src any) error { + if src == nil { + *dst = "" + return nil + } + + var nullable pgtype.Text + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Text(nullable.String) + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src Text) Value() (driver.Value, error) { + if src == "" { + return nil, nil + } + return string(src), nil +} diff --git a/pgtype/zeronull/text_test.go b/pgtype/zeronull/text_test.go new file mode 100644 index 000000000..5a60baf18 --- /dev/null +++ b/pgtype/zeronull/text_test.go @@ -0,0 +1,29 @@ +package zeronull_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestTextTranscode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "text", []pgxtest.ValueRoundTripTest{ + { + (zeronull.Text)("foo"), + new(zeronull.Text), + isExpectedEq((zeronull.Text)("foo")), + }, + { + nil, + new(zeronull.Text), + isExpectedEq((zeronull.Text)("")), + }, + { + (zeronull.Text)(""), + new(any), + isExpectedEq(nil), + }, + }) +} diff --git a/pgtype/zeronull/timestamp.go b/pgtype/zeronull/timestamp.go new file mode 100644 index 000000000..1daf07e4a --- /dev/null +++ b/pgtype/zeronull/timestamp.go @@ -0,0 +1,70 @@ +package zeronull + +import ( + "database/sql/driver" + "fmt" + "time" + + "github.com/jackc/pgx/v5/pgtype" +) + +type Timestamp time.Time + +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. +func (Timestamp) SkipUnderlyingTypePlan() {} + +// ScanTimestamp implements the [pgtype.TimestampScanner] interface. +func (ts *Timestamp) ScanTimestamp(v pgtype.Timestamp) error { + if !v.Valid { + *ts = Timestamp{} + return nil + } + + switch v.InfinityModifier { + case pgtype.Finite: + *ts = Timestamp(v.Time) + return nil + case pgtype.Infinity: + return fmt.Errorf("cannot scan Infinity into *time.Time") + case pgtype.NegativeInfinity: + return fmt.Errorf("cannot scan -Infinity into *time.Time") + default: + return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) + } +} + +// TimestampValue implements the [pgtype.TimestampValuer] interface. +func (ts Timestamp) TimestampValue() (pgtype.Timestamp, error) { + if time.Time(ts).IsZero() { + return pgtype.Timestamp{}, nil + } + + return pgtype.Timestamp{Time: time.Time(ts), Valid: true}, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (ts *Timestamp) Scan(src any) error { + if src == nil { + *ts = Timestamp{} + return nil + } + + var nullable pgtype.Timestamp + err := nullable.Scan(src) + if err != nil { + return err + } + + *ts = Timestamp(nullable.Time) + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (ts Timestamp) Value() (driver.Value, error) { + if time.Time(ts).IsZero() { + return nil, nil + } + + return time.Time(ts), nil +} diff --git a/pgtype/zeronull/timestamp_test.go b/pgtype/zeronull/timestamp_test.go new file mode 100644 index 000000000..8a5a57966 --- /dev/null +++ b/pgtype/zeronull/timestamp_test.go @@ -0,0 +1,39 @@ +package zeronull_test + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" +) + +func isExpectedEqTimestamp(a any) func(any) bool { + return func(v any) bool { + at := time.Time(a.(zeronull.Timestamp)) + vt := time.Time(v.(zeronull.Timestamp)) + + return at.Equal(vt) + } +} + +func TestTimestampTranscode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "timestamp", []pgxtest.ValueRoundTripTest{ + { + (zeronull.Timestamp)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + new(zeronull.Timestamp), + isExpectedEqTimestamp((zeronull.Timestamp)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC))), + }, + { + nil, + new(zeronull.Timestamp), + isExpectedEqTimestamp((zeronull.Timestamp)(time.Time{})), + }, + { + (zeronull.Timestamp)(time.Time{}), + new(any), + isExpectedEq(nil), + }, + }) +} diff --git a/pgtype/zeronull/timestamptz.go b/pgtype/zeronull/timestamptz.go new file mode 100644 index 000000000..835b4707f --- /dev/null +++ b/pgtype/zeronull/timestamptz.go @@ -0,0 +1,70 @@ +package zeronull + +import ( + "database/sql/driver" + "fmt" + "time" + + "github.com/jackc/pgx/v5/pgtype" +) + +type Timestamptz time.Time + +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. +func (Timestamptz) SkipUnderlyingTypePlan() {} + +// ScanTimestamptz implements the [pgtype.TimestamptzScanner] interface. +func (ts *Timestamptz) ScanTimestamptz(v pgtype.Timestamptz) error { + if !v.Valid { + *ts = Timestamptz{} + return nil + } + + switch v.InfinityModifier { + case pgtype.Finite: + *ts = Timestamptz(v.Time) + return nil + case pgtype.Infinity: + return fmt.Errorf("cannot scan Infinity into *time.Time") + case pgtype.NegativeInfinity: + return fmt.Errorf("cannot scan -Infinity into *time.Time") + default: + return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) + } +} + +// TimestamptzValue implements the [pgtype.TimestamptzValuer] interface. +func (ts Timestamptz) TimestamptzValue() (pgtype.Timestamptz, error) { + if time.Time(ts).IsZero() { + return pgtype.Timestamptz{}, nil + } + + return pgtype.Timestamptz{Time: time.Time(ts), Valid: true}, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (ts *Timestamptz) Scan(src any) error { + if src == nil { + *ts = Timestamptz{} + return nil + } + + var nullable pgtype.Timestamptz + err := nullable.Scan(src) + if err != nil { + return err + } + + *ts = Timestamptz(nullable.Time) + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (ts Timestamptz) Value() (driver.Value, error) { + if time.Time(ts).IsZero() { + return nil, nil + } + + return time.Time(ts), nil +} diff --git a/pgtype/zeronull/timestamptz_test.go b/pgtype/zeronull/timestamptz_test.go new file mode 100644 index 000000000..0a6d380ba --- /dev/null +++ b/pgtype/zeronull/timestamptz_test.go @@ -0,0 +1,39 @@ +package zeronull_test + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" +) + +func isExpectedEqTimestamptz(a any) func(any) bool { + return func(v any) bool { + at := time.Time(a.(zeronull.Timestamptz)) + vt := time.Time(v.(zeronull.Timestamptz)) + + return at.Equal(vt) + } +} + +func TestTimestamptzTranscode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "timestamptz", []pgxtest.ValueRoundTripTest{ + { + (zeronull.Timestamptz)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + new(zeronull.Timestamptz), + isExpectedEqTimestamptz((zeronull.Timestamptz)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC))), + }, + { + nil, + new(zeronull.Timestamptz), + isExpectedEqTimestamptz((zeronull.Timestamptz)(time.Time{})), + }, + { + (zeronull.Timestamptz)(time.Time{}), + new(any), + isExpectedEq(nil), + }, + }) +} diff --git a/pgtype/zeronull/uuid.go b/pgtype/zeronull/uuid.go new file mode 100644 index 000000000..2cf8dec7b --- /dev/null +++ b/pgtype/zeronull/uuid.go @@ -0,0 +1,64 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgx/v5/pgtype" +) + +type UUID [16]byte + +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. +func (UUID) SkipUnderlyingTypePlan() {} + +// ScanUUID implements the [pgtype.UUIDScanner] interface. +func (u *UUID) ScanUUID(v pgtype.UUID) error { + if !v.Valid { + *u = UUID{} + return nil + } + + *u = UUID(v.Bytes) + + return nil +} + +// UUIDValue implements the [pgtype.UUIDValuer] interface. +func (u UUID) UUIDValue() (pgtype.UUID, error) { + if u == (UUID{}) { + return pgtype.UUID{}, nil + } + return pgtype.UUID{Bytes: u, Valid: true}, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (u *UUID) Scan(src any) error { + if src == nil { + *u = UUID{} + return nil + } + + var nullable pgtype.UUID + err := nullable.Scan(src) + if err != nil { + return err + } + + *u = UUID(nullable.Bytes) + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (u UUID) Value() (driver.Value, error) { + if u == (UUID{}) { + return nil, nil + } + + buf, err := pgtype.UUIDCodec{}.PlanEncode(nil, pgtype.UUIDOID, pgtype.TextFormatCode, u).Encode(u, nil) + if err != nil { + return nil, err + } + + return string(buf), nil +} diff --git a/pgtype/zeronull/uuid_test.go b/pgtype/zeronull/uuid_test.go new file mode 100644 index 000000000..c50cb300b --- /dev/null +++ b/pgtype/zeronull/uuid_test.go @@ -0,0 +1,29 @@ +package zeronull_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestUUIDTranscode(t *testing.T) { + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "uuid", []pgxtest.ValueRoundTripTest{ + { + (zeronull.UUID)([16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), + new(zeronull.UUID), + isExpectedEq((zeronull.UUID)([16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15})), + }, + { + nil, + new(zeronull.UUID), + isExpectedEq((zeronull.UUID)([16]byte{})), + }, + { + (zeronull.UUID)([16]byte{}), + new(any), + isExpectedEq(nil), + }, + }) +} diff --git a/pgtype/zeronull/zeronull.go b/pgtype/zeronull/zeronull.go new file mode 100644 index 000000000..bba7b423b --- /dev/null +++ b/pgtype/zeronull/zeronull.go @@ -0,0 +1,17 @@ +package zeronull + +import ( + "github.com/jackc/pgx/v5/pgtype" +) + +// Register registers the zeronull types so they can be used in query exec modes that do not know the server OIDs. +func Register(m *pgtype.Map) { + m.RegisterDefaultPgType(Float8(0), "float8") + m.RegisterDefaultPgType(Int2(0), "int2") + m.RegisterDefaultPgType(Int4(0), "int4") + m.RegisterDefaultPgType(Int8(0), "int8") + m.RegisterDefaultPgType(Text(""), "text") + m.RegisterDefaultPgType(Timestamp{}, "timestamp") + m.RegisterDefaultPgType(Timestamptz{}, "timestamptz") + m.RegisterDefaultPgType(UUID{}, "uuid") +} diff --git a/pgtype/zeronull/zeronull_test.go b/pgtype/zeronull/zeronull_test.go new file mode 100644 index 000000000..9ee45cb7f --- /dev/null +++ b/pgtype/zeronull/zeronull_test.go @@ -0,0 +1,26 @@ +package zeronull_test + +import ( + "context" + "os" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype/zeronull" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +var defaultConnTestRunner pgxtest.ConnTestRunner + +func init() { + defaultConnTestRunner = pgxtest.DefaultConnTestRunner() + defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + return config + } + defaultConnTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + zeronull.Register(conn.TypeMap()) + } +} diff --git a/pgx_test.go b/pgx_test.go new file mode 100644 index 000000000..51b4bbc4e --- /dev/null +++ b/pgx_test.go @@ -0,0 +1,22 @@ +package pgx_test + +import ( + "context" + "os" + "testing" + + "github.com/jackc/pgx/v5" + _ "github.com/jackc/pgx/v5/stdlib" +) + +func skipCockroachDB(t testing.TB, msg string) { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + defer conn.Close(context.Background()) + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip(msg) + } +} diff --git a/pgxpool/batch_results.go b/pgxpool/batch_results.go new file mode 100644 index 000000000..5d5c681d5 --- /dev/null +++ b/pgxpool/batch_results.go @@ -0,0 +1,52 @@ +package pgxpool + +import ( + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type errBatchResults struct { + err error +} + +func (br errBatchResults) Exec() (pgconn.CommandTag, error) { + return pgconn.CommandTag{}, br.err +} + +func (br errBatchResults) Query() (pgx.Rows, error) { + return errRows{err: br.err}, br.err +} + +func (br errBatchResults) QueryRow() pgx.Row { + return errRow{err: br.err} +} + +func (br errBatchResults) Close() error { + return br.err +} + +type poolBatchResults struct { + br pgx.BatchResults + c *Conn +} + +func (br *poolBatchResults) Exec() (pgconn.CommandTag, error) { + return br.br.Exec() +} + +func (br *poolBatchResults) Query() (pgx.Rows, error) { + return br.br.Query() +} + +func (br *poolBatchResults) QueryRow() pgx.Row { + return br.br.QueryRow() +} + +func (br *poolBatchResults) Close() error { + err := br.br.Close() + if br.c != nil { + br.c.Release() + br.c = nil + } + return err +} diff --git a/pgxpool/bench_test.go b/pgxpool/bench_test.go new file mode 100644 index 000000000..c2d58a387 --- /dev/null +++ b/pgxpool/bench_test.go @@ -0,0 +1,84 @@ +package pgxpool_test + +import ( + "context" + "os" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/require" +) + +func BenchmarkAcquireAndRelease(b *testing.B) { + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(b, err) + defer pool.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + c, err := pool.Acquire(context.Background()) + if err != nil { + b.Fatal(err) + } + c.Release() + } +} + +func BenchmarkMinimalPreparedSelectBaseline(b *testing.B) { + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(b, err) + + config.AfterConnect = func(ctx context.Context, c *pgx.Conn) error { + _, err := c.Prepare(ctx, "ps1", "select $1::int8") + return err + } + + db, err := pgxpool.NewWithConfig(context.Background(), config) + require.NoError(b, err) + + conn, err := db.Acquire(context.Background()) + require.NoError(b, err) + defer conn.Release() + + var n int64 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err = conn.QueryRow(context.Background(), "ps1", i).Scan(&n) + if err != nil { + b.Fatal(err) + } + + if n != int64(i) { + b.Fatalf("expected %d, got %d", i, n) + } + } +} + +func BenchmarkMinimalPreparedSelect(b *testing.B) { + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(b, err) + + config.AfterConnect = func(ctx context.Context, c *pgx.Conn) error { + _, err := c.Prepare(ctx, "ps1", "select $1::int8") + return err + } + + db, err := pgxpool.NewWithConfig(context.Background(), config) + require.NoError(b, err) + + var n int64 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err = db.QueryRow(context.Background(), "ps1", i).Scan(&n) + if err != nil { + b.Fatal(err) + } + + if n != int64(i) { + b.Fatalf("expected %d, got %d", i, n) + } + } +} diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go new file mode 100644 index 000000000..1b8e488af --- /dev/null +++ b/pgxpool/common_test.go @@ -0,0 +1,204 @@ +package pgxpool_test + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Conn.Release is an asynchronous process that returns immediately. There is no signal when the actual work is +// completed. To test something that relies on the actual work for Conn.Release being completed we must simply wait. +// This function wraps the sleep so there is more meaning for the callers. +func waitForReleaseToComplete() { + time.Sleep(500 * time.Millisecond) +} + +type execer interface { + Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) +} + +func testExec(t *testing.T, ctx context.Context, db execer) { + results, err := db.Exec(ctx, "set time zone 'America/Chicago'") + require.NoError(t, err) + assert.EqualValues(t, "SET", results.String()) +} + +type queryer interface { + Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) +} + +func testQuery(t *testing.T, ctx context.Context, db queryer) { + var sum, rowCount int32 + + rows, err := db.Query(ctx, "select generate_series(1,$1)", 10) + require.NoError(t, err) + + for rows.Next() { + var n int32 + rows.Scan(&n) + sum += n + rowCount++ + } + + assert.NoError(t, rows.Err()) + assert.Equal(t, int32(10), rowCount) + assert.Equal(t, int32(55), sum) +} + +type queryRower interface { + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row +} + +func testQueryRow(t *testing.T, ctx context.Context, db queryRower) { + var what, who string + err := db.QueryRow(ctx, "select 'hello', $1::text", "world").Scan(&what, &who) + assert.NoError(t, err) + assert.Equal(t, "hello", what) + assert.Equal(t, "world", who) +} + +type sendBatcher interface { + SendBatch(context.Context, *pgx.Batch) pgx.BatchResults +} + +func testSendBatch(t *testing.T, ctx context.Context, db sendBatcher) { + batch := &pgx.Batch{} + batch.Queue("select 1") + batch.Queue("select 2") + + br := db.SendBatch(ctx, batch) + + var err error + var n int32 + err = br.QueryRow().Scan(&n) + assert.NoError(t, err) + assert.EqualValues(t, 1, n) + + err = br.QueryRow().Scan(&n) + assert.NoError(t, err) + assert.EqualValues(t, 2, n) + + err = br.Close() + assert.NoError(t, err) +} + +type copyFromer interface { + CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error) +} + +func testCopyFrom(t *testing.T, ctx context.Context, db interface { + execer + queryer + copyFromer +}, +) { + _, err := db.Exec(ctx, `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`) + require.NoError(t, err) + + tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) + + inputRows := [][]any{ + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, + {nil, nil, nil, nil, nil, nil, nil}, + } + + copyCount, err := db.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) + assert.NoError(t, err) + assert.EqualValues(t, len(inputRows), copyCount) + + rows, err := db.Query(ctx, "select * from foo") + assert.NoError(t, err) + + var outputRows [][]any + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + assert.NoError(t, rows.Err()) + assert.Equal(t, inputRows, outputRows) +} + +func assertConfigsEqual(t *testing.T, expected, actual *pgxpool.Config, testName string) { + if !assert.NotNil(t, expected) { + return + } + if !assert.NotNil(t, actual) { + return + } + + assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) + + // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) + assert.Equalf(t, expected.BeforeAcquire == nil, actual.BeforeAcquire == nil, "%s - BeforeAcquire", testName) + assert.Equalf(t, expected.PrepareConn == nil, actual.PrepareConn == nil, "%s - PrepareConn", testName) + assert.Equalf(t, expected.AfterRelease == nil, actual.AfterRelease == nil, "%s - AfterRelease", testName) + + assert.Equalf(t, expected.MaxConnLifetime, actual.MaxConnLifetime, "%s - MaxConnLifetime", testName) + assert.Equalf(t, expected.MaxConnIdleTime, actual.MaxConnIdleTime, "%s - MaxConnIdleTime", testName) + assert.Equalf(t, expected.MaxConns, actual.MaxConns, "%s - MaxConns", testName) + assert.Equalf(t, expected.MinConns, actual.MinConns, "%s - MinConns", testName) + assert.Equalf(t, expected.MinIdleConns, actual.MinIdleConns, "%s - MinIdleConns", testName) + assert.Equalf(t, expected.HealthCheckPeriod, actual.HealthCheckPeriod, "%s - HealthCheckPeriod", testName) + + assertConnConfigsEqual(t, expected.ConnConfig, actual.ConnConfig, testName) +} + +func assertConnConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) { + if !assert.NotNil(t, expected) { + return + } + if !assert.NotNil(t, actual) { + return + } + + assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName) + assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) + assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName) + assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName) + assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName) + assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName) + assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) + assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) + assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) + assert.Equalf(t, expected.User, actual.User, "%s - User", testName) + assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) + assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName) + assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) + + // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) + assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) + + if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { + if expected.TLSConfig != nil { + assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) + } + } + + if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { + for i := range expected.Fallbacks { + assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) + assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) + + if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { + if expected.Fallbacks[i].TLSConfig != nil { + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) + } + } + } + } +} diff --git a/pgxpool/conn.go b/pgxpool/conn.go new file mode 100644 index 000000000..38c90f3da --- /dev/null +++ b/pgxpool/conn.go @@ -0,0 +1,134 @@ +package pgxpool + +import ( + "context" + "sync/atomic" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/puddle/v2" +) + +// Conn is an acquired *pgx.Conn from a Pool. +type Conn struct { + res *puddle.Resource[*connResource] + p *Pool +} + +// Release returns c to the pool it was acquired from. Once Release has been called, other methods must not be called. +// However, it is safe to call Release multiple times. Subsequent calls after the first will be ignored. +func (c *Conn) Release() { + if c.res == nil { + return + } + + conn := c.Conn() + res := c.res + c.res = nil + + if c.p.releaseTracer != nil { + c.p.releaseTracer.TraceRelease(c.p, TraceReleaseData{Conn: conn}) + } + + if conn.IsClosed() || conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' { + res.Destroy() + // Signal to the health check to run since we just destroyed a connections + // and we might be below minConns now + c.p.triggerHealthCheck() + return + } + + // If the pool is consistently being used, we might never get to check the + // lifetime of a connection since we only check idle connections in checkConnsHealth + // so we also check the lifetime here and force a health check + if c.p.isExpired(res) { + atomic.AddInt64(&c.p.lifetimeDestroyCount, 1) + res.Destroy() + // Signal to the health check to run since we just destroyed a connections + // and we might be below minConns now + c.p.triggerHealthCheck() + return + } + + if c.p.afterRelease == nil { + res.Release() + return + } + + go func() { + if c.p.afterRelease(conn) { + res.Release() + } else { + res.Destroy() + // Signal to the health check to run since we just destroyed a connections + // and we might be below minConns now + c.p.triggerHealthCheck() + } + }() +} + +// Hijack assumes ownership of the connection from the pool. Caller is responsible for closing the connection. Hijack +// will panic if called on an already released or hijacked connection. +func (c *Conn) Hijack() *pgx.Conn { + if c.res == nil { + panic("cannot hijack already released or hijacked connection") + } + + conn := c.Conn() + res := c.res + c.res = nil + + res.Hijack() + + return conn +} + +func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { + return c.Conn().Exec(ctx, sql, arguments...) +} + +func (c *Conn) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { + return c.Conn().Query(ctx, sql, args...) +} + +func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { + return c.Conn().QueryRow(ctx, sql, args...) +} + +func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { + return c.Conn().SendBatch(ctx, b) +} + +func (c *Conn) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc) +} + +// Begin starts a transaction block from the *Conn without explicitly setting a transaction mode (see BeginTx with TxOptions if transaction mode is required). +func (c *Conn) Begin(ctx context.Context) (pgx.Tx, error) { + return c.Conn().Begin(ctx) +} + +// BeginTx starts a transaction block from the *Conn with txOptions determining the transaction mode. +func (c *Conn) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { + return c.Conn().BeginTx(ctx, txOptions) +} + +func (c *Conn) Ping(ctx context.Context) error { + return c.Conn().Ping(ctx) +} + +func (c *Conn) Conn() *pgx.Conn { + return c.connResource().conn +} + +func (c *Conn) connResource() *connResource { + return c.res.Value() +} + +func (c *Conn) getPoolRow(r pgx.Row) *poolRow { + return c.connResource().getPoolRow(c, r) +} + +func (c *Conn) getPoolRows(r pgx.Rows) *poolRows { + return c.connResource().getPoolRows(c, r) +} diff --git a/pgxpool/conn_test.go b/pgxpool/conn_test.go new file mode 100644 index 000000000..ce35c494f --- /dev/null +++ b/pgxpool/conn_test.go @@ -0,0 +1,96 @@ +package pgxpool_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/require" +) + +func TestConnExec(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + defer c.Release() + + testExec(t, ctx, c) +} + +func TestConnQuery(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + defer c.Release() + + testQuery(t, ctx, c) +} + +func TestConnQueryRow(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + defer c.Release() + + testQueryRow(t, ctx, c) +} + +func TestConnSendBatch(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + defer c.Release() + + testSendBatch(t, ctx, c) +} + +func TestConnCopyFrom(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + defer c.Release() + + testCopyFrom(t, ctx, c) +} diff --git a/pgxpool/doc.go b/pgxpool/doc.go new file mode 100644 index 000000000..099443bca --- /dev/null +++ b/pgxpool/doc.go @@ -0,0 +1,27 @@ +// Package pgxpool is a concurrency-safe connection pool for pgx. +/* +pgxpool implements a nearly identical interface to pgx connections. + +Creating a Pool + +The primary way of creating a pool is with [pgxpool.New]: + + pool, err := pgxpool.New(context.Background(), os.Getenv("DATABASE_URL")) + +The database connection string can be in URL or keyword/value format. PostgreSQL settings, pgx settings, and pool settings can be +specified here. In addition, a config struct can be created by [ParseConfig]. + + config, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL")) + if err != nil { + // ... + } + config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { + // do something with every new connection + } + + pool, err := pgxpool.NewWithConfig(context.Background(), config) + +A pool returns without waiting for any connections to be established. Acquire a connection immediately after creating +the pool to check if a connection can successfully be established. +*/ +package pgxpool diff --git a/pgxpool/helper_test.go b/pgxpool/helper_test.go new file mode 100644 index 000000000..7d63732a1 --- /dev/null +++ b/pgxpool/helper_test.go @@ -0,0 +1,39 @@ +package pgxpool_test + +import ( + "context" + "net" + "time" + + "github.com/jackc/pgx/v5/pgconn" +) + +// delayProxy is a that introduces a configurable delay on reads from the database connection. +type delayProxy struct { + net.Conn + readDelay time.Duration +} + +func newDelayProxy(conn net.Conn, readDelay time.Duration) *delayProxy { + p := &delayProxy{ + Conn: conn, + readDelay: readDelay, + } + + return p +} + +func (dp *delayProxy) Read(b []byte) (int, error) { + if dp.readDelay > 0 { + time.Sleep(dp.readDelay) + } + + return dp.Conn.Read(b) +} + +func newDelayProxyDialFunc(readDelay time.Duration) pgconn.DialFunc { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := net.Dial(network, addr) + return newDelayProxy(conn, readDelay), err + } +} diff --git a/pgxpool/pool.go b/pgxpool/pool.go new file mode 100644 index 000000000..f94f9cf0f --- /dev/null +++ b/pgxpool/pool.go @@ -0,0 +1,830 @@ +package pgxpool + +import ( + "context" + "errors" + "math/rand" + "runtime" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/puddle/v2" +) + +var ( + defaultMaxConns = int32(4) + defaultMinConns = int32(0) + defaultMinIdleConns = int32(0) + defaultMaxConnLifetime = time.Hour + defaultMaxConnIdleTime = time.Minute * 30 + defaultHealthCheckPeriod = time.Minute +) + +type connResource struct { + conn *pgx.Conn + conns []Conn + poolRows []poolRow + poolRowss []poolRows + maxAgeTime time.Time +} + +func (cr *connResource) getConn(p *Pool, res *puddle.Resource[*connResource]) *Conn { + if len(cr.conns) == 0 { + cr.conns = make([]Conn, 128) + } + + c := &cr.conns[len(cr.conns)-1] + cr.conns = cr.conns[0 : len(cr.conns)-1] + + c.res = res + c.p = p + + return c +} + +func (cr *connResource) getPoolRow(c *Conn, r pgx.Row) *poolRow { + if len(cr.poolRows) == 0 { + cr.poolRows = make([]poolRow, 128) + } + + pr := &cr.poolRows[len(cr.poolRows)-1] + cr.poolRows = cr.poolRows[0 : len(cr.poolRows)-1] + + pr.c = c + pr.r = r + + return pr +} + +func (cr *connResource) getPoolRows(c *Conn, r pgx.Rows) *poolRows { + if len(cr.poolRowss) == 0 { + cr.poolRowss = make([]poolRows, 128) + } + + pr := &cr.poolRowss[len(cr.poolRowss)-1] + cr.poolRowss = cr.poolRowss[0 : len(cr.poolRowss)-1] + + pr.c = c + pr.r = r + + return pr +} + +// Pool allows for connection reuse. +type Pool struct { + // 64 bit fields accessed with atomics must be at beginning of struct to guarantee alignment for certain 32-bit + // architectures. See BUGS section of https://pkg.go.dev/sync/atomic and https://github.com/jackc/pgx/issues/1288. + newConnsCount int64 + lifetimeDestroyCount int64 + idleDestroyCount int64 + + p *puddle.Pool[*connResource] + config *Config + beforeConnect func(context.Context, *pgx.ConnConfig) error + afterConnect func(context.Context, *pgx.Conn) error + prepareConn func(context.Context, *pgx.Conn) (bool, error) + afterRelease func(*pgx.Conn) bool + beforeClose func(*pgx.Conn) + shouldPing func(context.Context, ShouldPingParams) bool + minConns int32 + minIdleConns int32 + maxConns int32 + maxConnLifetime time.Duration + maxConnLifetimeJitter time.Duration + maxConnIdleTime time.Duration + healthCheckPeriod time.Duration + pingTimeout time.Duration + + healthCheckMu sync.Mutex + healthCheckTimer *time.Timer + + healthCheckChan chan struct{} + + acquireTracer AcquireTracer + releaseTracer ReleaseTracer + + closeOnce sync.Once + closeChan chan struct{} +} + +// ShouldPingParams are the parameters passed to ShouldPing. +type ShouldPingParams struct { + Conn *pgx.Conn + IdleDuration time.Duration +} + +// Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be +// modified. +type Config struct { + ConnConfig *pgx.ConnConfig + + // BeforeConnect is called before a new connection is made. It is passed a copy of the underlying pgx.ConnConfig and + // will not impact any existing open connections. + BeforeConnect func(context.Context, *pgx.ConnConfig) error + + // AfterConnect is called after a connection is established, but before it is added to the pool. + AfterConnect func(context.Context, *pgx.Conn) error + + // BeforeAcquire is called before a connection is acquired from the pool. It must return true to allow the + // acquisition or false to indicate that the connection should be destroyed and a different connection should be + // acquired. + // + // Deprecated: Use PrepareConn instead. If both PrepareConn and BeforeAcquire are set, PrepareConn will take + // precedence, ignoring BeforeAcquire. + BeforeAcquire func(context.Context, *pgx.Conn) bool + + // PrepareConn is called before a connection is acquired from the pool. If this function returns true, the connection + // is considered valid, otherwise the connection is destroyed. If the function returns a non-nil error, the instigating + // query will fail with the returned error. + // + // Specifically, this means that: + // + // - If it returns true and a nil error, the query proceeds as normal. + // - If it returns true and an error, the connection will be returned to the pool, and the instigating query will fail with the returned error. + // - If it returns false, and an error, the connection will be destroyed, and the query will fail with the returned error. + // - If it returns false and a nil error, the connection will be destroyed, and the instigating query will be retried on a new connection. + PrepareConn func(context.Context, *pgx.Conn) (bool, error) + + // AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to + // return the connection to the pool or false to destroy the connection. + AfterRelease func(*pgx.Conn) bool + + // BeforeClose is called right before a connection is closed and removed from the pool. + BeforeClose func(*pgx.Conn) + + // ShouldPing is called after a connection is acquired from the pool. If it returns true, the connection is pinged to check for liveness. + // If this func is not set, the default behavior is to ping connections that have been idle for at least 1 second. + ShouldPing func(context.Context, ShouldPingParams) bool + + // MaxConnLifetime is the duration since creation after which a connection will be automatically closed. + MaxConnLifetime time.Duration + + // MaxConnLifetimeJitter is the duration after MaxConnLifetime to randomly decide to close a connection. + // This helps prevent all connections from being closed at the exact same time, starving the pool. + MaxConnLifetimeJitter time.Duration + + // MaxConnIdleTime is the duration after which an idle connection will be automatically closed by the health check. + MaxConnIdleTime time.Duration + + // PingTimeout is the maximum amount of time to wait for a connection to pong before considering it as unhealthy and + // destroying it. If zero, the default is no timeout. + PingTimeout time.Duration + + // MaxConns is the maximum size of the pool. The default is the greater of 4 or runtime.NumCPU(). + MaxConns int32 + + // MinConns is the minimum size of the pool. After connection closes, the pool might dip below MinConns. A low + // number of MinConns might mean the pool is empty after MaxConnLifetime until the health check has a chance + // to create new connections. + MinConns int32 + + // MinIdleConns is the minimum number of idle connections in the pool. You can increase this to ensure that + // there are always idle connections available. This can help reduce tail latencies during request processing, + // as you can avoid the latency of establishing a new connection while handling requests. It is superior + // to MinConns for this purpose. + // Similar to MinConns, the pool might temporarily dip below MinIdleConns after connection closes. + MinIdleConns int32 + + // HealthCheckPeriod is the duration between checks of the health of idle connections. + HealthCheckPeriod time.Duration + + createdByParseConfig bool // Used to enforce created by ParseConfig rule. +} + +// Copy returns a deep copy of the config that is safe to use and modify. +// The only exception is the tls.Config: +// according to the tls.Config docs it must not be modified after creation. +func (c *Config) Copy() *Config { + newConfig := new(Config) + *newConfig = *c + newConfig.ConnConfig = c.ConnConfig.Copy() + return newConfig +} + +// ConnString returns the connection string as parsed by pgxpool.ParseConfig into pgxpool.Config. +func (c *Config) ConnString() string { return c.ConnConfig.ConnString() } + +// New creates a new Pool. See [ParseConfig] for information on connString format. +func New(ctx context.Context, connString string) (*Pool, error) { + config, err := ParseConfig(connString) + if err != nil { + return nil, err + } + + return NewWithConfig(ctx, config) +} + +// NewWithConfig creates a new Pool. config must have been created by [ParseConfig]. +func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from + // zero values. + if !config.createdByParseConfig { + panic("config must be created by ParseConfig") + } + + prepareConn := config.PrepareConn + if prepareConn == nil && config.BeforeAcquire != nil { + prepareConn = func(ctx context.Context, conn *pgx.Conn) (bool, error) { + return config.BeforeAcquire(ctx, conn), nil + } + } + + p := &Pool{ + config: config, + beforeConnect: config.BeforeConnect, + afterConnect: config.AfterConnect, + prepareConn: prepareConn, + afterRelease: config.AfterRelease, + beforeClose: config.BeforeClose, + minConns: config.MinConns, + minIdleConns: config.MinIdleConns, + maxConns: config.MaxConns, + maxConnLifetime: config.MaxConnLifetime, + maxConnLifetimeJitter: config.MaxConnLifetimeJitter, + maxConnIdleTime: config.MaxConnIdleTime, + pingTimeout: config.PingTimeout, + healthCheckPeriod: config.HealthCheckPeriod, + healthCheckChan: make(chan struct{}, 1), + closeChan: make(chan struct{}), + } + + if t, ok := config.ConnConfig.Tracer.(AcquireTracer); ok { + p.acquireTracer = t + } + + if t, ok := config.ConnConfig.Tracer.(ReleaseTracer); ok { + p.releaseTracer = t + } + + if config.ShouldPing != nil { + p.shouldPing = config.ShouldPing + } else { + p.shouldPing = func(ctx context.Context, params ShouldPingParams) bool { + return params.IdleDuration > time.Second + } + } + + var err error + p.p, err = puddle.NewPool( + &puddle.Config[*connResource]{ + Constructor: func(ctx context.Context) (*connResource, error) { + atomic.AddInt64(&p.newConnsCount, 1) + connConfig := p.config.ConnConfig.Copy() + + // Connection will continue in background even if Acquire is canceled. Ensure that a connect won't hang forever. + if connConfig.ConnectTimeout <= 0 { + connConfig.ConnectTimeout = 2 * time.Minute + } + + if p.beforeConnect != nil { + if err := p.beforeConnect(ctx, connConfig); err != nil { + return nil, err + } + } + + conn, err := pgx.ConnectConfig(ctx, connConfig) + if err != nil { + return nil, err + } + + if p.afterConnect != nil { + err = p.afterConnect(ctx, conn) + if err != nil { + conn.Close(ctx) + return nil, err + } + } + + jitterSecs := rand.Float64() * config.MaxConnLifetimeJitter.Seconds() + maxAgeTime := time.Now().Add(config.MaxConnLifetime).Add(time.Duration(jitterSecs) * time.Second) + + cr := &connResource{ + conn: conn, + conns: make([]Conn, 64), + poolRows: make([]poolRow, 64), + poolRowss: make([]poolRows, 64), + maxAgeTime: maxAgeTime, + } + + return cr, nil + }, + Destructor: func(value *connResource) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + conn := value.conn + if p.beforeClose != nil { + p.beforeClose(conn) + } + conn.Close(ctx) + select { + case <-conn.PgConn().CleanupDone(): + case <-ctx.Done(): + } + cancel() + }, + MaxSize: config.MaxConns, + }, + ) + if err != nil { + return nil, err + } + + go func() { + targetIdleResources := max(int(p.minConns), int(p.minIdleConns)) + p.createIdleResources(ctx, targetIdleResources) + p.backgroundHealthCheck() + }() + + return p, nil +} + +// ParseConfig builds a Config from connString. It parses connString with the same behavior as [pgx.ParseConfig] with the +// addition of the following variables: +// +// - pool_max_conns: integer greater than 0 (default 4) +// - pool_min_conns: integer 0 or greater (default 0) +// - pool_max_conn_lifetime: duration string (default 1 hour) +// - pool_max_conn_idle_time: duration string (default 30 minutes) +// - pool_health_check_period: duration string (default 1 minute) +// - pool_max_conn_lifetime_jitter: duration string (default 0) +// +// See Config for definitions of these arguments. +// +// # Example Keyword/Value +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca pool_max_conns=10 pool_max_conn_lifetime=1h30m +// +// # Example URL +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca&pool_max_conns=10&pool_max_conn_lifetime=1h30m +func ParseConfig(connString string) (*Config, error) { + connConfig, err := pgx.ParseConfig(connString) + if err != nil { + return nil, err + } + + config := &Config{ + ConnConfig: connConfig, + createdByParseConfig: true, + } + + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conns"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_max_conns") + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conns", err) + } + if n < 1 { + return nil, pgconn.NewParseConfigError(connString, "pool_max_conns too small", err) + } + config.MaxConns = int32(n) + } else { + config.MaxConns = defaultMaxConns + if numCPU := int32(runtime.NumCPU()); numCPU > config.MaxConns { + config.MaxConns = numCPU + } + } + + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_min_conns"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_min_conns") + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_min_conns", err) + } + config.MinConns = int32(n) + } else { + config.MinConns = defaultMinConns + } + + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_min_idle_conns"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_min_idle_conns") + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_min_idle_conns", err) + } + config.MinIdleConns = int32(n) + } else { + config.MinIdleConns = defaultMinIdleConns + } + + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime") + d, err := time.ParseDuration(s) + if err != nil { + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_lifetime", err) + } + config.MaxConnLifetime = d + } else { + config.MaxConnLifetime = defaultMaxConnLifetime + } + + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_idle_time"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_max_conn_idle_time") + d, err := time.ParseDuration(s) + if err != nil { + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_idle_time", err) + } + config.MaxConnIdleTime = d + } else { + config.MaxConnIdleTime = defaultMaxConnIdleTime + } + + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_health_check_period"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_health_check_period") + d, err := time.ParseDuration(s) + if err != nil { + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_health_check_period", err) + } + config.HealthCheckPeriod = d + } else { + config.HealthCheckPeriod = defaultHealthCheckPeriod + } + + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime_jitter"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime_jitter") + d, err := time.ParseDuration(s) + if err != nil { + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_lifetime_jitter", err) + } + config.MaxConnLifetimeJitter = d + } + + return config, nil +} + +// Close closes all connections in the pool and rejects future Acquire calls. Blocks until all connections are returned +// to pool and closed. +func (p *Pool) Close() { + p.closeOnce.Do(func() { + close(p.closeChan) + p.p.Close() + }) +} + +func (p *Pool) isExpired(res *puddle.Resource[*connResource]) bool { + return time.Now().After(res.Value().maxAgeTime) +} + +func (p *Pool) triggerHealthCheck() { + const healthCheckDelay = 500 * time.Millisecond + + p.healthCheckMu.Lock() + defer p.healthCheckMu.Unlock() + + if p.healthCheckTimer == nil { + // Destroy is asynchronous so we give it time to actually remove itself from + // the pool otherwise we might try to check the pool size too soon + p.healthCheckTimer = time.AfterFunc(healthCheckDelay, func() { + select { + case <-p.closeChan: + case p.healthCheckChan <- struct{}{}: + default: + } + }) + return + } + + p.healthCheckTimer.Reset(healthCheckDelay) +} + +func (p *Pool) backgroundHealthCheck() { + ticker := time.NewTicker(p.healthCheckPeriod) + defer ticker.Stop() + for { + select { + case <-p.closeChan: + return + case <-p.healthCheckChan: + p.checkHealth() + case <-ticker.C: + p.checkHealth() + } + } +} + +func (p *Pool) checkHealth() { + for { + // If checkMinConns failed we don't destroy any connections since we couldn't + // even get to minConns + if err := p.checkMinConns(); err != nil { + // Should we log this error somewhere? + break + } + if !p.checkConnsHealth() { + // Since we didn't destroy any connections we can stop looping + break + } + // Technically Destroy is asynchronous but 500ms should be enough for it to + // remove it from the underlying pool + select { + case <-p.closeChan: + return + case <-time.After(500 * time.Millisecond): + } + } +} + +// checkConnsHealth will check all idle connections, destroy a connection if +// it's idle or too old, and returns true if any were destroyed +func (p *Pool) checkConnsHealth() bool { + var destroyed bool + totalConns := p.Stat().TotalConns() + resources := p.p.AcquireAllIdle() + for _, res := range resources { + // We're okay going under minConns if the lifetime is up + if p.isExpired(res) && totalConns >= p.minConns { + atomic.AddInt64(&p.lifetimeDestroyCount, 1) + res.Destroy() + destroyed = true + // Since Destroy is async we manually decrement totalConns. + totalConns-- + } else if res.IdleDuration() > p.maxConnIdleTime && totalConns > p.minConns { + atomic.AddInt64(&p.idleDestroyCount, 1) + res.Destroy() + destroyed = true + // Since Destroy is async we manually decrement totalConns. + totalConns-- + } else { + res.ReleaseUnused() + } + } + return destroyed +} + +func (p *Pool) checkMinConns() error { + // TotalConns can include ones that are being destroyed but we should have + // sleep(500ms) around all of the destroys to help prevent that from throwing + // off this check + + // Create the number of connections needed to get to both minConns and minIdleConns + toCreate := max(p.minConns-p.Stat().TotalConns(), p.minIdleConns-p.Stat().IdleConns()) + if toCreate > 0 { + return p.createIdleResources(context.Background(), int(toCreate)) + } + return nil +} + +func (p *Pool) createIdleResources(parentCtx context.Context, targetResources int) error { + ctx, cancel := context.WithCancel(parentCtx) + defer cancel() + + errs := make(chan error, targetResources) + + for i := 0; i < targetResources; i++ { + go func() { + err := p.p.CreateResource(ctx) + // Ignore ErrNotAvailable since it means that the pool has become full since we started creating resource. + if err == puddle.ErrNotAvailable { + err = nil + } + errs <- err + }() + } + + var firstError error + for i := 0; i < targetResources; i++ { + err := <-errs + if err != nil && firstError == nil { + cancel() + firstError = err + } + } + + return firstError +} + +// Acquire returns a connection (*Conn) from the Pool +func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) { + if p.acquireTracer != nil { + ctx = p.acquireTracer.TraceAcquireStart(ctx, p, TraceAcquireStartData{}) + defer func() { + var conn *pgx.Conn + if c != nil { + conn = c.Conn() + } + p.acquireTracer.TraceAcquireEnd(ctx, p, TraceAcquireEndData{Conn: conn, Err: err}) + }() + } + + // Try to acquire from the connection pool up to maxConns + 1 times, so that + // any that fatal errors would empty the pool and still at least try 1 fresh + // connection. + for range int(p.maxConns) + 1 { + res, err := p.p.Acquire(ctx) + if err != nil { + return nil, err + } + + cr := res.Value() + + shouldPingParams := ShouldPingParams{Conn: cr.conn, IdleDuration: res.IdleDuration()} + if p.shouldPing(ctx, shouldPingParams) { + pingCtx := ctx + if p.pingTimeout > 0 { + var cancel context.CancelFunc + pingCtx, cancel = context.WithTimeout(ctx, p.pingTimeout) + defer cancel() + } + + err := cr.conn.Ping(pingCtx) + if err != nil { + res.Destroy() + continue + } + } + + if p.prepareConn != nil { + ok, err := p.prepareConn(ctx, cr.conn) + if !ok { + res.Destroy() + } + if err != nil { + if ok { + res.Release() + } + return nil, err + } + if !ok { + continue + } + } + + return cr.getConn(p, res), nil + } + return nil, errors.New("pgxpool: detected infinite loop acquiring connection; likely bug in PrepareConn or BeforeAcquire hook") +} + +// AcquireFunc acquires a *Conn and calls f with that *Conn. ctx will only affect the Acquire. It has no effect on the +// call of f. The return value is either an error acquiring the *Conn or the return value of f. The *Conn is +// automatically released after the call of f. +func (p *Pool) AcquireFunc(ctx context.Context, f func(*Conn) error) error { + conn, err := p.Acquire(ctx) + if err != nil { + return err + } + defer conn.Release() + + return f(conn) +} + +// AcquireAllIdle atomically acquires all currently idle connections. Its intended use is for health check and +// keep-alive functionality. It does not update pool statistics. +func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn { + resources := p.p.AcquireAllIdle() + conns := make([]*Conn, 0, len(resources)) + for _, res := range resources { + cr := res.Value() + if p.prepareConn != nil { + ok, err := p.prepareConn(ctx, cr.conn) + if !ok || err != nil { + res.Destroy() + continue + } + } + conns = append(conns, cr.getConn(p, res)) + } + + return conns +} + +// Reset closes all connections, but leaves the pool open. It is intended for use when an error is detected that would +// disrupt all connections (such as a network interruption or a server state change). +// +// It is safe to reset a pool while connections are checked out. Those connections will be closed when they are returned +// to the pool. +func (p *Pool) Reset() { + p.p.Reset() +} + +// Config returns a copy of config that was used to initialize this pool. +func (p *Pool) Config() *Config { return p.config.Copy() } + +// Stat returns a pgxpool.Stat struct with a snapshot of Pool statistics. +func (p *Pool) Stat() *Stat { + return &Stat{ + s: p.p.Stat(), + newConnsCount: atomic.LoadInt64(&p.newConnsCount), + lifetimeDestroyCount: atomic.LoadInt64(&p.lifetimeDestroyCount), + idleDestroyCount: atomic.LoadInt64(&p.idleDestroyCount), + } +} + +// Exec acquires a connection from the Pool and executes the given SQL. +// SQL can be either a prepared statement name or an SQL string. +// Arguments should be referenced positionally from the SQL string as $1, $2, etc. +// The acquired connection is returned to the pool when the Exec function returns. +func (p *Pool) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { + c, err := p.Acquire(ctx) + if err != nil { + return pgconn.CommandTag{}, err + } + defer c.Release() + + return c.Exec(ctx, sql, arguments...) +} + +// Query acquires a connection and executes a query that returns pgx.Rows. +// Arguments should be referenced positionally from the SQL string as $1, $2, etc. +// See pgx.Rows documentation to close the returned Rows and return the acquired connection to the Pool. +// +// If there is an error, the returned pgx.Rows will be returned in an error state. +// If preferred, ignore the error returned from Query and handle errors using the returned pgx.Rows. +// +// For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and +// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely +// needed. See the documentation for those types for details. +func (p *Pool) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { + c, err := p.Acquire(ctx) + if err != nil { + return errRows{err: err}, err + } + + rows, err := c.Query(ctx, sql, args...) + if err != nil { + c.Release() + return errRows{err: err}, err + } + + return c.getPoolRows(rows), nil +} + +// QueryRow acquires a connection and executes a query that is expected +// to return at most one row (pgx.Row). Errors are deferred until pgx.Row's +// Scan method is called. If the query selects no rows, pgx.Row's Scan will +// return ErrNoRows. Otherwise, pgx.Row's Scan scans the first selected row +// and discards the rest. The acquired connection is returned to the Pool when +// pgx.Row's Scan method is called. +// +// Arguments should be referenced positionally from the SQL string as $1, $2, etc. +// +// For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and +// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely +// needed. See the documentation for those types for details. +func (p *Pool) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { + c, err := p.Acquire(ctx) + if err != nil { + return errRow{err: err} + } + + row := c.QueryRow(ctx, sql, args...) + return c.getPoolRow(row) +} + +func (p *Pool) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { + c, err := p.Acquire(ctx) + if err != nil { + return errBatchResults{err: err} + } + + br := c.SendBatch(ctx, b) + return &poolBatchResults{br: br, c: c} +} + +// Begin acquires a connection from the Pool and starts a transaction. Unlike database/sql, the context only affects the begin command. i.e. there is no +// auto-rollback on context cancellation. Begin initiates a transaction block without explicitly setting a transaction mode for the block (see BeginTx with TxOptions if transaction mode is required). +// *pgxpool.Tx is returned, which implements the pgx.Tx interface. +// Commit or Rollback must be called on the returned transaction to finalize the transaction block. +func (p *Pool) Begin(ctx context.Context) (pgx.Tx, error) { + return p.BeginTx(ctx, pgx.TxOptions{}) +} + +// BeginTx acquires a connection from the Pool and starts a transaction with pgx.TxOptions determining the transaction mode. +// Unlike database/sql, the context only affects the begin command. i.e. there is no auto-rollback on context cancellation. +// *pgxpool.Tx is returned, which implements the pgx.Tx interface. +// Commit or Rollback must be called on the returned transaction to finalize the transaction block. +func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { + c, err := p.Acquire(ctx) + if err != nil { + return nil, err + } + + t, err := c.BeginTx(ctx, txOptions) + if err != nil { + c.Release() + return nil, err + } + + return &Tx{t: t, c: c}, nil +} + +func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + c, err := p.Acquire(ctx) + if err != nil { + return 0, err + } + defer c.Release() + + return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc) +} + +// Ping acquires a connection from the Pool and executes an empty sql statement against it. +// If the sql returns without error, the database Ping is considered successful, otherwise, the error is returned. +func (p *Pool) Ping(ctx context.Context) error { + c, err := p.Acquire(ctx) + if err != nil { + return err + } + defer c.Release() + return c.Ping(ctx) +} diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go new file mode 100644 index 000000000..9674c6fac --- /dev/null +++ b/pgxpool/pool_test.go @@ -0,0 +1,1334 @@ +package pgxpool_test + +import ( + "context" + "errors" + "fmt" + "math" + "os" + "sync/atomic" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConnect(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + connString := os.Getenv("PGX_TEST_DATABASE") + pool, err := pgxpool.New(ctx, connString) + require.NoError(t, err) + assert.Equal(t, connString, pool.Config().ConnString()) + pool.Close() +} + +func TestConnectConfig(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + connString := os.Getenv("PGX_TEST_DATABASE") + config, err := pgxpool.ParseConfig(connString) + require.NoError(t, err) + pool, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + assertConfigsEqual(t, config, pool.Config(), "Pool.Config() returns original config") + pool.Close() +} + +func TestParseConfigExtractsPoolArguments(t *testing.T) { + t.Parallel() + + config, err := pgxpool.ParseConfig("pool_max_conns=42 pool_min_conns=1 pool_min_idle_conns=2") + assert.NoError(t, err) + assert.EqualValues(t, 42, config.MaxConns) + assert.EqualValues(t, 1, config.MinConns) + assert.EqualValues(t, 2, config.MinIdleConns) + assert.NotContains(t, config.ConnConfig.Config.RuntimeParams, "pool_max_conns") + assert.NotContains(t, config.ConnConfig.Config.RuntimeParams, "pool_min_conns") +} + +func TestConstructorIgnoresContext(t *testing.T) { + t.Parallel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + assert.NoError(t, err) + var cancel func() + config.BeforeConnect = func(context.Context, *pgx.ConnConfig) error { + // cancel the query's context before we actually Dial to ensure the Dial's + // context isn't cancelled + cancel() + return nil + } + + pool, err := pgxpool.NewWithConfig(context.Background(), config) + require.NoError(t, err) + + assert.EqualValues(t, 0, pool.Stat().TotalConns()) + + var ctx context.Context + ctx, cancel = context.WithCancel(context.Background()) + defer cancel() + _, err = pool.Exec(ctx, "SELECT 1") + assert.ErrorIs(t, err, context.Canceled) + assert.EqualValues(t, 1, pool.Stat().TotalConns()) +} + +func TestConnectConfigRequiresConnConfigFromParseConfig(t *testing.T) { + t.Parallel() + + config := &pgxpool.Config{} + + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgxpool.NewWithConfig(context.Background(), config) }) +} + +func TestConfigCopyReturnsEqualConfig(t *testing.T) { + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" + original, err := pgxpool.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + + assertConfigsEqual(t, original, copied, t.Name()) +} + +func TestConfigCopyCanBeUsedToConnect(t *testing.T) { + connString := os.Getenv("PGX_TEST_DATABASE") + original, err := pgxpool.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assert.NotPanics(t, func() { + _, err = pgxpool.NewWithConfig(context.Background(), copied) + }) + assert.NoError(t, err) +} + +func TestPoolAcquireAndConnRelease(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + c.Release() +} + +func TestPoolAcquireAndConnHijack(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + + connsBeforeHijack := pool.Stat().TotalConns() + + conn := c.Hijack() + defer conn.Close(ctx) + + connsAfterHijack := pool.Stat().TotalConns() + require.Equal(t, connsBeforeHijack-1, connsAfterHijack) + + var n int32 + err = conn.QueryRow(ctx, `select 1`).Scan(&n) + require.NoError(t, err) + require.Equal(t, int32(1), n) +} + +func TestPoolAcquireChecksIdleConns(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + controllerConn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer controllerConn.Close(ctx) + pgxtest.SkipCockroachDB(t, controllerConn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + var conns []*pgxpool.Conn + for i := 0; i < 3; i++ { + c, err := pool.Acquire(ctx) + require.NoError(t, err) + conns = append(conns, c) + } + + require.EqualValues(t, 3, pool.Stat().TotalConns()) + + var pids []uint32 + for _, c := range conns { + pids = append(pids, c.Conn().PgConn().PID()) + c.Release() + } + + _, err = controllerConn.Exec(ctx, `select pg_terminate_backend(n) from unnest($1::int[]) n`, pids) + require.NoError(t, err) + + // All conns are dead they don't know it and neither does the pool. + require.EqualValues(t, 3, pool.Stat().TotalConns()) + + // Wait long enough so the pool will realize it needs to check the connections. + time.Sleep(time.Second) + + // Pool should try all existing connections and find them dead, then create a new connection which should successfully ping. + err = pool.Ping(ctx) + require.NoError(t, err) + + // The original 3 conns should have been terminated and the a new conn established for the ping. + require.EqualValues(t, 1, pool.Stat().TotalConns()) + c, err := pool.Acquire(ctx) + require.NoError(t, err) + + cPID := c.Conn().PgConn().PID() + c.Release() + + require.NotContains(t, pids, cPID) +} + +func TestPoolAcquireChecksIdleConnsWithShouldPing(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + controllerConn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer controllerConn.Close(ctx) + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + // Replace the default ShouldPing func + var shouldPingLastCalledWith *pgxpool.ShouldPingParams + config.ShouldPing = func(ctx context.Context, params pgxpool.ShouldPingParams) bool { + shouldPingLastCalledWith = ¶ms + return false + } + + pool, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + c.Release() + + time.Sleep(time.Millisecond * 200) + + c, err = pool.Acquire(ctx) + require.NoError(t, err) + conn := c.Conn() + + require.NotNil(t, shouldPingLastCalledWith) + assert.Equal(t, conn, shouldPingLastCalledWith.Conn) + assert.InDelta(t, time.Millisecond*200, shouldPingLastCalledWith.IdleDuration, float64(time.Millisecond*100)) + + c.Release() +} + +// https://github.com/jackc/pgx/issues/2379 +func TestPoolAcquireWithMaxConnsEqualsMaxInt32(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.MaxConns = math.MaxInt32 + + pool, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + c.Release() +} + +func TestPoolAcquireFunc(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + var n int32 + err = pool.AcquireFunc(ctx, func(c *pgxpool.Conn) error { + return c.QueryRow(ctx, "select 1").Scan(&n) + }) + require.NoError(t, err) + require.EqualValues(t, 1, n) +} + +func TestPoolAcquireFuncReturnsFnError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + err = pool.AcquireFunc(ctx, func(c *pgxpool.Conn) error { + return fmt.Errorf("some error") + }) + require.EqualError(t, err, "some error") +} + +func TestPoolBeforeConnect(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { + cfg.Config.RuntimeParams["application_name"] = "pgx" + return nil + } + + db, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer db.Close() + + var str string + err = db.QueryRow(ctx, "SHOW application_name").Scan(&str) + require.NoError(t, err) + assert.EqualValues(t, "pgx", str) +} + +func TestPoolAfterConnect(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.AfterConnect = func(ctx context.Context, c *pgx.Conn) error { + _, err := c.Prepare(ctx, "ps1", "select 1") + return err + } + + db, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer db.Close() + + var n int32 + err = db.QueryRow(ctx, "ps1").Scan(&n) + require.NoError(t, err) + assert.EqualValues(t, 1, n) +} + +func TestPoolBeforeAcquire(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + acquireAttempts := 0 + + config.BeforeAcquire = func(ctx context.Context, c *pgx.Conn) bool { + acquireAttempts++ + return acquireAttempts%2 == 0 + } + + db, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer db.Close() + + conns := make([]*pgxpool.Conn, 4) + for i := range conns { + conns[i], err = db.Acquire(ctx) + assert.NoError(t, err) + } + + for _, c := range conns { + c.Release() + } + waitForReleaseToComplete() + + assert.EqualValues(t, 8, acquireAttempts) + + conns = db.AcquireAllIdle(ctx) + assert.Len(t, conns, 2) + + for _, c := range conns { + c.Release() + } + waitForReleaseToComplete() + + assert.EqualValues(t, 12, acquireAttempts) +} + +func TestPoolPrepareConn(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + acquireAttempts := 0 + + config.PrepareConn = func(context.Context, *pgx.Conn) (bool, error) { + acquireAttempts++ + var err error + if acquireAttempts%3 == 0 { + err = errors.New("PrepareConn error") + } + return acquireAttempts%2 == 0, err + } + + db, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + t.Cleanup(db.Close) + + var errorCount int + conns := make([]*pgxpool.Conn, 0, 4) + for { + conn, err := db.Acquire(ctx) + if err != nil { + errorCount++ + continue + } + conns = append(conns, conn) + if len(conns) == 4 { + break + } + } + const wantErrorCount = 3 + assert.Equal(t, wantErrorCount, errorCount, "Acquire() should have failed %d times", wantErrorCount) + + for _, c := range conns { + c.Release() + } + waitForReleaseToComplete() + + assert.EqualValues(t, len(conns)*2+wantErrorCount-1, acquireAttempts) + + conns = db.AcquireAllIdle(ctx) + assert.Len(t, conns, 1) + + for _, c := range conns { + c.Release() + } + waitForReleaseToComplete() + + assert.EqualValues(t, 14, acquireAttempts) +} + +func TestPoolAfterRelease(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + func() { + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + }() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + afterReleaseCount := 0 + + config.AfterRelease = func(c *pgx.Conn) bool { + afterReleaseCount++ + return afterReleaseCount%2 == 1 + } + + db, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer db.Close() + + connPIDs := map[uint32]struct{}{} + + for i := 0; i < 10; i++ { + conn, err := db.Acquire(ctx) + assert.NoError(t, err) + connPIDs[conn.Conn().PgConn().PID()] = struct{}{} + conn.Release() + waitForReleaseToComplete() + } + + assert.EqualValues(t, 5, len(connPIDs)) +} + +func TestPoolBeforeClose(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + func() { + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + }() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + connPIDs := make(chan uint32, 5) + config.BeforeClose = func(c *pgx.Conn) { + connPIDs <- c.PgConn().PID() + } + + db, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer db.Close() + + acquiredPIDs := make([]uint32, 0, 5) + closedPIDs := make([]uint32, 0, 5) + for i := 0; i < 5; i++ { + conn, err := db.Acquire(ctx) + assert.NoError(t, err) + acquiredPIDs = append(acquiredPIDs, conn.Conn().PgConn().PID()) + conn.Release() + db.Reset() + closedPIDs = append(closedPIDs, <-connPIDs) + } + + assert.ElementsMatch(t, acquiredPIDs, closedPIDs) +} + +func TestPoolAcquireAllIdle(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + db, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer db.Close() + + conns := make([]*pgxpool.Conn, 3) + for i := range conns { + conns[i], err = db.Acquire(ctx) + assert.NoError(t, err) + } + + for _, c := range conns { + if c != nil { + c.Release() + } + } + waitForReleaseToComplete() + + conns = db.AcquireAllIdle(ctx) + assert.Len(t, conns, 3) + + for _, c := range conns { + c.Release() + } +} + +func TestPoolReset(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + db, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer db.Close() + + conns := make([]*pgxpool.Conn, 3) + for i := range conns { + conns[i], err = db.Acquire(ctx) + assert.NoError(t, err) + } + + db.Reset() + + for _, c := range conns { + if c != nil { + c.Release() + } + } + waitForReleaseToComplete() + + require.EqualValues(t, 0, db.Stat().TotalConns()) +} + +func TestConnReleaseChecksMaxConnLifetime(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.MaxConnLifetime = 250 * time.Millisecond + + db, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer db.Close() + + c, err := db.Acquire(ctx) + require.NoError(t, err) + + time.Sleep(config.MaxConnLifetime) + + c.Release() + waitForReleaseToComplete() + + stats := db.Stat() + assert.EqualValues(t, 0, stats.TotalConns()) +} + +func TestConnReleaseClosesBusyConn(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + db, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer db.Close() + + c, err := db.Acquire(ctx) + require.NoError(t, err) + + _, err = c.Query(ctx, "select generate_series(1,10)") + require.NoError(t, err) + + c.Release() + waitForReleaseToComplete() + + // wait for the connection to actually be destroyed + for i := 0; i < 1000; i++ { + if db.Stat().TotalConns() == 0 { + break + } + time.Sleep(time.Millisecond) + } + + stats := db.Stat() + assert.EqualValues(t, 0, stats.TotalConns()) +} + +func TestPoolBackgroundChecksMaxConnLifetime(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.MaxConnLifetime = 100 * time.Millisecond + config.HealthCheckPeriod = 100 * time.Millisecond + + db, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer db.Close() + + c, err := db.Acquire(ctx) + require.NoError(t, err) + c.Release() + time.Sleep(config.MaxConnLifetime + 500*time.Millisecond) + + stats := db.Stat() + assert.EqualValues(t, 0, stats.TotalConns()) + assert.EqualValues(t, 0, stats.MaxIdleDestroyCount()) + assert.EqualValues(t, 1, stats.MaxLifetimeDestroyCount()) + assert.EqualValues(t, 1, stats.NewConnsCount()) +} + +func TestPoolBackgroundChecksMaxConnIdleTime(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.MaxConnLifetime = 1 * time.Minute + config.MaxConnIdleTime = 100 * time.Millisecond + config.HealthCheckPeriod = 150 * time.Millisecond + + db, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer db.Close() + + c, err := db.Acquire(ctx) + require.NoError(t, err) + c.Release() + time.Sleep(config.HealthCheckPeriod) + + for i := 0; i < 1000; i++ { + if db.Stat().TotalConns() == 0 { + break + } + time.Sleep(time.Millisecond) + } + + stats := db.Stat() + assert.EqualValues(t, 0, stats.TotalConns()) + assert.EqualValues(t, 1, stats.MaxIdleDestroyCount()) + assert.EqualValues(t, 0, stats.MaxLifetimeDestroyCount()) + assert.EqualValues(t, 1, stats.NewConnsCount()) +} + +func TestPoolBackgroundChecksMinConns(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.HealthCheckPeriod = 100 * time.Millisecond + config.MinConns = 2 + + db, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer db.Close() + + stats := db.Stat() + for !(stats.IdleConns() == 2 && stats.MaxLifetimeDestroyCount() == 0 && stats.NewConnsCount() == 2) && ctx.Err() == nil { + time.Sleep(50 * time.Millisecond) + stats = db.Stat() + } + require.EqualValues(t, 2, stats.IdleConns()) + require.EqualValues(t, 0, stats.MaxLifetimeDestroyCount()) + require.EqualValues(t, 2, stats.NewConnsCount()) + + c, err := db.Acquire(ctx) + require.NoError(t, err) + + stats = db.Stat() + require.EqualValues(t, 1, stats.IdleConns()) + require.EqualValues(t, 0, stats.MaxLifetimeDestroyCount()) + require.EqualValues(t, 2, stats.NewConnsCount()) + + err = c.Conn().Close(ctx) + require.NoError(t, err) + c.Release() + + stats = db.Stat() + for !(stats.IdleConns() == 2 && stats.MaxIdleDestroyCount() == 0 && stats.NewConnsCount() == 3) && ctx.Err() == nil { + time.Sleep(50 * time.Millisecond) + stats = db.Stat() + } + require.EqualValues(t, 2, stats.TotalConns()) + require.EqualValues(t, 0, stats.MaxIdleDestroyCount()) + require.EqualValues(t, 3, stats.NewConnsCount()) +} + +func TestPoolExec(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + testExec(t, ctx, pool) +} + +func TestPoolQuery(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + // Test common usage + testQuery(t, ctx, pool) + waitForReleaseToComplete() + + // Test expected pool behavior + rows, err := pool.Query(ctx, "select generate_series(1,$1)", 10) + require.NoError(t, err) + + stats := pool.Stat() + assert.EqualValues(t, 1, stats.AcquiredConns()) + assert.EqualValues(t, 1, stats.TotalConns()) + + rows.Close() + assert.NoError(t, rows.Err()) + waitForReleaseToComplete() + + stats = pool.Stat() + assert.EqualValues(t, 0, stats.AcquiredConns()) + assert.EqualValues(t, 1, stats.TotalConns()) +} + +func TestPoolQueryRow(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + testQueryRow(t, ctx, pool) + waitForReleaseToComplete() + + stats := pool.Stat() + assert.EqualValues(t, 0, stats.AcquiredConns()) + assert.EqualValues(t, 1, stats.TotalConns()) +} + +// https://github.com/jackc/pgx/issues/677 +func TestPoolQueryRowErrNoRows(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + err = pool.QueryRow(ctx, "select n from generate_series(1,10) n where n=0").Scan(nil) + require.Equal(t, pgx.ErrNoRows, err) +} + +// https://github.com/jackc/pgx/issues/1628 +func TestPoolQueryRowScanPanicReleasesConnection(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + require.Panics(t, func() { + var greeting *string + pool.QueryRow(ctx, "select 'Hello, world!'").Scan(greeting) // Note lack of &. This means that a typed nil is passed to Scan. + }) + + // If the connection is not released this will block forever in the defer pool.Close(). +} + +func TestPoolSendBatch(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + testSendBatch(t, ctx, pool) + waitForReleaseToComplete() + + stats := pool.Stat() + assert.EqualValues(t, 0, stats.AcquiredConns()) + assert.EqualValues(t, 1, stats.TotalConns()) +} + +func TestPoolCopyFrom(t *testing.T) { + // Not able to use testCopyFrom because it relies on temporary tables and the pool may run subsequent calls under + // different connections. + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + _, err = pool.Exec(ctx, `drop table if exists poolcopyfromtest`) + require.NoError(t, err) + + _, err = pool.Exec(ctx, `create table poolcopyfromtest(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`) + require.NoError(t, err) + defer pool.Exec(ctx, `drop table poolcopyfromtest`) + + tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) + + inputRows := [][]any{ + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, + {nil, nil, nil, nil, nil, nil, nil}, + } + + copyCount, err := pool.CopyFrom(ctx, pgx.Identifier{"poolcopyfromtest"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) + assert.NoError(t, err) + assert.EqualValues(t, len(inputRows), copyCount) + + rows, err := pool.Query(ctx, "select * from poolcopyfromtest") + assert.NoError(t, err) + + var outputRows [][]any + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + assert.NoError(t, rows.Err()) + assert.Equal(t, inputRows, outputRows) +} + +func TestConnReleaseClosesConnInFailedTransaction(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + + pid := c.Conn().PgConn().PID() + + assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus()) + + _, err = c.Exec(ctx, "begin") + assert.NoError(t, err) + + assert.Equal(t, byte('T'), c.Conn().PgConn().TxStatus()) + + _, err = c.Exec(ctx, "selct") + assert.Error(t, err) + + assert.Equal(t, byte('E'), c.Conn().PgConn().TxStatus()) + + c.Release() + waitForReleaseToComplete() + + c, err = pool.Acquire(ctx) + require.NoError(t, err) + + assert.NotEqual(t, pid, c.Conn().PgConn().PID()) + assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus()) + + c.Release() +} + +func TestConnReleaseClosesConnInTransaction(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + + pid := c.Conn().PgConn().PID() + + assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus()) + + _, err = c.Exec(ctx, "begin") + assert.NoError(t, err) + + assert.Equal(t, byte('T'), c.Conn().PgConn().TxStatus()) + + c.Release() + waitForReleaseToComplete() + + c, err = pool.Acquire(ctx) + require.NoError(t, err) + + assert.NotEqual(t, pid, c.Conn().PgConn().PID()) + assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus()) + + c.Release() +} + +func TestConnReleaseDestroysClosedConn(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + + err = c.Conn().Close(ctx) + require.NoError(t, err) + + assert.EqualValues(t, 1, pool.Stat().TotalConns()) + + c.Release() + waitForReleaseToComplete() + + // wait for the connection to actually be destroyed + for i := 0; i < 1000; i++ { + if pool.Stat().TotalConns() == 0 { + break + } + time.Sleep(time.Millisecond) + } + + assert.EqualValues(t, 0, pool.Stat().TotalConns()) +} + +func TestConnPoolQueryConcurrentLoad(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + n := 100 + done := make(chan bool) + + for i := 0; i < n; i++ { + go func() { + defer func() { done <- true }() + testQuery(t, ctx, pool) + testQueryRow(t, ctx, pool) + }() + } + + for i := 0; i < n; i++ { + <-done + } +} + +func TestConnReleaseWhenBeginFail(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + db, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer db.Close() + + tx, err := db.BeginTx(ctx, pgx.TxOptions{ + IsoLevel: pgx.TxIsoLevel("foo"), + }) + assert.Error(t, err) + if !assert.Zero(t, tx) { + err := tx.Rollback(ctx) + assert.NoError(t, err) + } + + for i := 0; i < 1000; i++ { + if db.Stat().TotalConns() == 0 { + break + } + time.Sleep(time.Millisecond) + } + + assert.EqualValues(t, 0, db.Stat().TotalConns()) +} + +func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + db, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer db.Close() + + createSql := ` + drop table if exists pgxpooltx; + create temporary table pgxpooltx( + id integer, + unique (id) + ); + ` + + _, err = db.Exec(ctx, createSql) + require.NoError(t, err) + + defer func() { + db.Exec(ctx, "drop table pgxpooltx") + }() + + err = pgx.BeginFunc(ctx, db, func(db pgx.Tx) error { + _, err := db.Exec(ctx, "insert into pgxpooltx(id) values (1)") + require.NoError(t, err) + + err = pgx.BeginFunc(ctx, db, func(db pgx.Tx) error { + _, err := db.Exec(ctx, "insert into pgxpooltx(id) values (2)") + require.NoError(t, err) + + err = pgx.BeginFunc(ctx, db, func(db pgx.Tx) error { + _, err := db.Exec(ctx, "insert into pgxpooltx(id) values (3)") + require.NoError(t, err) + return nil + }) + require.NoError(t, err) + return nil + }) + require.NoError(t, err) + return nil + }) + require.NoError(t, err) + + var n int64 + err = db.QueryRow(ctx, "select count(*) from pgxpooltx").Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 3, n) +} + +func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + db, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer db.Close() + + createSql := ` + drop table if exists pgxpooltx; + create temporary table pgxpooltx( + id integer, + unique (id) + ); + ` + + _, err = db.Exec(ctx, createSql) + require.NoError(t, err) + + defer func() { + db.Exec(ctx, "drop table pgxpooltx") + }() + + err = pgx.BeginFunc(ctx, db, func(db pgx.Tx) error { + _, err := db.Exec(ctx, "insert into pgxpooltx(id) values (1)") + require.NoError(t, err) + + err = pgx.BeginFunc(ctx, db, func(db pgx.Tx) error { + _, err := db.Exec(ctx, "insert into pgxpooltx(id) values (2)") + require.NoError(t, err) + return errors.New("do a rollback") + }) + require.EqualError(t, err, "do a rollback") + + _, err = db.Exec(ctx, "insert into pgxpooltx(id) values (3)") + require.NoError(t, err) + + return nil + }) + require.NoError(t, err) + + var n int64 + err = db.QueryRow(ctx, "select count(*) from pgxpooltx").Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 2, n) +} + +func TestIdempotentPoolClose(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + // Close the open pool. + require.NotPanics(t, func() { pool.Close() }) + + // Close the already closed pool. + require.NotPanics(t, func() { pool.Close() }) +} + +func TestConnectEagerlyReachesMinPoolSize(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.MinConns = int32(12) + config.MaxConns = int32(15) + + acquireAttempts := int64(0) + connectAttempts := int64(0) + + config.PrepareConn = func(ctx context.Context, conn *pgx.Conn) (bool, error) { + atomic.AddInt64(&acquireAttempts, 1) + return true, nil + } + config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { + atomic.AddInt64(&connectAttempts, 1) + return nil + } + + pool, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer pool.Close() + + for i := 0; i < 500; i++ { + time.Sleep(10 * time.Millisecond) + + stat := pool.Stat() + if stat.IdleConns() == 12 && stat.AcquireCount() == 0 && stat.TotalConns() == 12 && atomic.LoadInt64(&acquireAttempts) == 0 && atomic.LoadInt64(&connectAttempts) == 12 { + return + } + } + + t.Fatal("did not reach min pool size") +} + +func TestPoolSendBatchBatchCloseTwice(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + errChan := make(chan error) + testCount := 5000 + + for i := 0; i < testCount; i++ { + go func() { + batch := &pgx.Batch{} + batch.Queue("select 1") + batch.Queue("select 2") + + br := pool.SendBatch(ctx, batch) + defer br.Close() + + var err error + var n int32 + err = br.QueryRow().Scan(&n) + if err != nil { + errChan <- err + return + } + if n != 1 { + errChan <- fmt.Errorf("expected 1 got %v", n) + return + } + + err = br.QueryRow().Scan(&n) + if err != nil { + errChan <- err + return + } + if n != 2 { + errChan <- fmt.Errorf("expected 2 got %v", n) + return + } + + err = br.Close() + errChan <- err + }() + } + + for i := 0; i < testCount; i++ { + err := <-errChan + assert.NoError(t, err) + } +} + +func TestPoolAcquirePingTimeout(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.PingTimeout = 200 * time.Millisecond + config.ConnConfig.DialFunc = newDelayProxyDialFunc(500 * time.Millisecond) + + var conID *uint32 + // Only ping the connection with the original PID to force creation of a new connection + config.ShouldPing = func(_ context.Context, params pgxpool.ShouldPingParams) bool { + if conID != nil && params.Conn.PgConn().PID() == *conID { + return true + } + return false + } + + // Limit to a single connection to ensure the same connection is reused + config.MinConns = 1 + config.MaxConns = 1 + + pool, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + require.EqualValues(t, 1, pool.Stat().TotalConns()) + originalPID := c.Conn().PgConn().PID() + conID = &originalPID + + c.Release() + require.EqualValues(t, 1, pool.Stat().TotalConns()) + + c, err = pool.Acquire(ctx) + require.NoError(t, err) + require.EqualValues(t, 1, pool.Stat().TotalConns()) + newPID := c.Conn().PgConn().PID() + + c.Release() + + require.EqualValues(t, 1, pool.Stat().TotalConns()) + assert.Nil(t, ctx.Err()) + assert.NotEqualValues(t, originalPID, newPID, + "Expected new connection due to ping timeout, but got same connection") +} diff --git a/pgxpool/rows.go b/pgxpool/rows.go new file mode 100644 index 000000000..f834b7ec3 --- /dev/null +++ b/pgxpool/rows.go @@ -0,0 +1,116 @@ +package pgxpool + +import ( + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type errRows struct { + err error +} + +func (errRows) Close() {} +func (e errRows) Err() error { return e.err } +func (errRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} } +func (errRows) FieldDescriptions() []pgconn.FieldDescription { return nil } +func (errRows) Next() bool { return false } +func (e errRows) Scan(dest ...any) error { return e.err } +func (e errRows) Values() ([]any, error) { return nil, e.err } +func (e errRows) RawValues() [][]byte { return nil } +func (e errRows) Conn() *pgx.Conn { return nil } + +type errRow struct { + err error +} + +func (e errRow) Scan(dest ...any) error { return e.err } + +type poolRows struct { + r pgx.Rows + c *Conn + err error +} + +func (rows *poolRows) Close() { + rows.r.Close() + if rows.c != nil { + rows.c.Release() + rows.c = nil + } +} + +func (rows *poolRows) Err() error { + if rows.err != nil { + return rows.err + } + return rows.r.Err() +} + +func (rows *poolRows) CommandTag() pgconn.CommandTag { + return rows.r.CommandTag() +} + +func (rows *poolRows) FieldDescriptions() []pgconn.FieldDescription { + return rows.r.FieldDescriptions() +} + +func (rows *poolRows) Next() bool { + if rows.err != nil { + return false + } + + n := rows.r.Next() + if !n { + rows.Close() + } + return n +} + +func (rows *poolRows) Scan(dest ...any) error { + err := rows.r.Scan(dest...) + if err != nil { + rows.Close() + } + return err +} + +func (rows *poolRows) Values() ([]any, error) { + values, err := rows.r.Values() + if err != nil { + rows.Close() + } + return values, err +} + +func (rows *poolRows) RawValues() [][]byte { + return rows.r.RawValues() +} + +func (rows *poolRows) Conn() *pgx.Conn { + return rows.r.Conn() +} + +type poolRow struct { + r pgx.Row + c *Conn + err error +} + +func (row *poolRow) Scan(dest ...any) error { + if row.err != nil { + return row.err + } + + panicked := true + defer func() { + if panicked && row.c != nil { + row.c.Release() + } + }() + err := row.r.Scan(dest...) + panicked = false + if row.c != nil { + row.c.Release() + } + return err +} diff --git a/pgxpool/stat.go b/pgxpool/stat.go new file mode 100644 index 000000000..e02b6ac39 --- /dev/null +++ b/pgxpool/stat.go @@ -0,0 +1,91 @@ +package pgxpool + +import ( + "time" + + "github.com/jackc/puddle/v2" +) + +// Stat is a snapshot of Pool statistics. +type Stat struct { + s *puddle.Stat + newConnsCount int64 + lifetimeDestroyCount int64 + idleDestroyCount int64 +} + +// AcquireCount returns the cumulative count of successful acquires from the pool. +func (s *Stat) AcquireCount() int64 { + return s.s.AcquireCount() +} + +// AcquireDuration returns the total duration of all successful acquires from +// the pool. +func (s *Stat) AcquireDuration() time.Duration { + return s.s.AcquireDuration() +} + +// AcquiredConns returns the number of currently acquired connections in the pool. +func (s *Stat) AcquiredConns() int32 { + return s.s.AcquiredResources() +} + +// CanceledAcquireCount returns the cumulative count of acquires from the pool +// that were canceled by a context. +func (s *Stat) CanceledAcquireCount() int64 { + return s.s.CanceledAcquireCount() +} + +// ConstructingConns returns the number of conns with construction in progress in +// the pool. +func (s *Stat) ConstructingConns() int32 { + return s.s.ConstructingResources() +} + +// EmptyAcquireCount returns the cumulative count of successful acquires from the pool +// that waited for a resource to be released or constructed because the pool was +// empty. +func (s *Stat) EmptyAcquireCount() int64 { + return s.s.EmptyAcquireCount() +} + +// IdleConns returns the number of currently idle conns in the pool. +func (s *Stat) IdleConns() int32 { + return s.s.IdleResources() +} + +// MaxConns returns the maximum size of the pool. +func (s *Stat) MaxConns() int32 { + return s.s.MaxResources() +} + +// TotalConns returns the total number of resources currently in the pool. +// The value is the sum of ConstructingConns, AcquiredConns, and +// IdleConns. +func (s *Stat) TotalConns() int32 { + return s.s.TotalResources() +} + +// NewConnsCount returns the cumulative count of new connections opened. +func (s *Stat) NewConnsCount() int64 { + return s.newConnsCount +} + +// MaxLifetimeDestroyCount returns the cumulative count of connections destroyed +// because they exceeded MaxConnLifetime. +func (s *Stat) MaxLifetimeDestroyCount() int64 { + return s.lifetimeDestroyCount +} + +// MaxIdleDestroyCount returns the cumulative count of connections destroyed because +// they exceeded MaxConnIdleTime. +func (s *Stat) MaxIdleDestroyCount() int64 { + return s.idleDestroyCount +} + +// EmptyAcquireWaitTime returns the cumulative time waited for successful acquires +// from the pool for a resource to be released or constructed because the pool was +// empty. +func (s *Stat) EmptyAcquireWaitTime() time.Duration { + return s.s.EmptyAcquireWaitTime() +} diff --git a/pgxpool/tracer.go b/pgxpool/tracer.go new file mode 100644 index 000000000..78b9d15a2 --- /dev/null +++ b/pgxpool/tracer.go @@ -0,0 +1,33 @@ +package pgxpool + +import ( + "context" + + "github.com/jackc/pgx/v5" +) + +// AcquireTracer traces Acquire. +type AcquireTracer interface { + // TraceAcquireStart is called at the beginning of Acquire. + // The returned context is used for the rest of the call and will be passed to the TraceAcquireEnd. + TraceAcquireStart(ctx context.Context, pool *Pool, data TraceAcquireStartData) context.Context + // TraceAcquireEnd is called when a connection has been acquired. + TraceAcquireEnd(ctx context.Context, pool *Pool, data TraceAcquireEndData) +} + +type TraceAcquireStartData struct{} + +type TraceAcquireEndData struct { + Conn *pgx.Conn + Err error +} + +// ReleaseTracer traces Release. +type ReleaseTracer interface { + // TraceRelease is called at the beginning of Release. + TraceRelease(pool *Pool, data TraceReleaseData) +} + +type TraceReleaseData struct { + Conn *pgx.Conn +} diff --git a/pgxpool/tracer_test.go b/pgxpool/tracer_test.go new file mode 100644 index 000000000..10724d94c --- /dev/null +++ b/pgxpool/tracer_test.go @@ -0,0 +1,130 @@ +package pgxpool_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/require" +) + +type testTracer struct { + traceAcquireStart func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context + traceAcquireEnd func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) + traceRelease func(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) +} + +type ctxKey string + +func (tt *testTracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context { + if tt.traceAcquireStart != nil { + return tt.traceAcquireStart(ctx, pool, data) + } + return ctx +} + +func (tt *testTracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) { + if tt.traceAcquireEnd != nil { + tt.traceAcquireEnd(ctx, pool, data) + } +} + +func (tt *testTracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) { + if tt.traceRelease != nil { + tt.traceRelease(pool, data) + } +} + +func (tt *testTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + return ctx +} + +func (tt *testTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { +} + +func TestTraceAcquire(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.ConnConfig.Tracer = tracer + + pool, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer pool.Close() + + traceAcquireStartCalled := false + tracer.traceAcquireStart = func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context { + traceAcquireStartCalled = true + require.NotNil(t, pool) + return context.WithValue(ctx, ctxKey("fromTraceAcquireStart"), "foo") + } + + traceAcquireEndCalled := false + tracer.traceAcquireEnd = func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) { + traceAcquireEndCalled = true + require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceAcquireStart"))) + require.NotNil(t, pool) + require.NotNil(t, data.Conn) + require.NoError(t, data.Err) + } + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + defer c.Release() + require.True(t, traceAcquireStartCalled) + require.True(t, traceAcquireEndCalled) + + traceAcquireStartCalled = false + traceAcquireEndCalled = false + tracer.traceAcquireEnd = func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) { + traceAcquireEndCalled = true + require.NotNil(t, pool) + require.Nil(t, data.Conn) + require.Error(t, data.Err) + } + + ctx, cancel = context.WithCancel(ctx) + cancel() + _, err = pool.Acquire(ctx) + require.ErrorIs(t, err, context.Canceled) + require.True(t, traceAcquireStartCalled) + require.True(t, traceAcquireEndCalled) +} + +func TestTraceRelease(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.ConnConfig.Tracer = tracer + + pool, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer pool.Close() + + traceReleaseCalled := false + tracer.traceRelease = func(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) { + traceReleaseCalled = true + require.NotNil(t, pool) + require.NotNil(t, data.Conn) + } + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + c.Release() + require.True(t, traceReleaseCalled) +} diff --git a/pgxpool/tx.go b/pgxpool/tx.go new file mode 100644 index 000000000..b49e7f4d9 --- /dev/null +++ b/pgxpool/tx.go @@ -0,0 +1,83 @@ +package pgxpool + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +// Tx represents a database transaction acquired from a Pool. +type Tx struct { + t pgx.Tx + c *Conn +} + +// Begin starts a pseudo nested transaction implemented with a savepoint. +func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) { + return tx.t.Begin(ctx) +} + +// Commit commits the transaction and returns the associated connection back to the Pool. Commit will return an error +// where errors.Is(ErrTxClosed) is true if the Tx is already closed, but is otherwise safe to call multiple times. If +// the commit fails with a rollback status (e.g. the transaction was already in a broken state) then ErrTxCommitRollback +// will be returned. +func (tx *Tx) Commit(ctx context.Context) error { + err := tx.t.Commit(ctx) + if tx.c != nil { + tx.c.Release() + tx.c = nil + } + return err +} + +// Rollback rolls back the transaction and returns the associated connection back to the Pool. Rollback will return +// where an error where errors.Is(ErrTxClosed) is true if the Tx is already closed, but is otherwise safe to call +// multiple times. Hence, defer tx.Rollback() is safe even if tx.Commit() will be called first in a non-error condition. +func (tx *Tx) Rollback(ctx context.Context) error { + err := tx.t.Rollback(ctx) + if tx.c != nil { + tx.c.Release() + tx.c = nil + } + return err +} + +func (tx *Tx) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + return tx.t.CopyFrom(ctx, tableName, columnNames, rowSrc) +} + +func (tx *Tx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { + return tx.t.SendBatch(ctx, b) +} + +func (tx *Tx) LargeObjects() pgx.LargeObjects { + return tx.t.LargeObjects() +} + +// Prepare creates a prepared statement with name and sql. If the name is empty, +// an anonymous prepared statement will be used. sql can contain placeholders +// for bound parameters. These placeholders are referenced positionally as $1, $2, etc. +// +// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same +// name and sql arguments. This allows a code path to Prepare and Query/Exec without +// needing to first check whether the statement has already been prepared. +func (tx *Tx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) { + return tx.t.Prepare(ctx, name, sql) +} + +func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { + return tx.t.Exec(ctx, sql, arguments...) +} + +func (tx *Tx) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { + return tx.t.Query(ctx, sql, args...) +} + +func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { + return tx.t.QueryRow(ctx, sql, args...) +} + +func (tx *Tx) Conn() *pgx.Conn { + return tx.t.Conn() +} diff --git a/pgxpool/tx_test.go b/pgxpool/tx_test.go new file mode 100644 index 000000000..e1611e679 --- /dev/null +++ b/pgxpool/tx_test.go @@ -0,0 +1,96 @@ +package pgxpool_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/require" +) + +func TestTxExec(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + tx, err := pool.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + testExec(t, ctx, tx) +} + +func TestTxQuery(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + tx, err := pool.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + testQuery(t, ctx, tx) +} + +func TestTxQueryRow(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + tx, err := pool.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + testQueryRow(t, ctx, tx) +} + +func TestTxSendBatch(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + tx, err := pool.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + testSendBatch(t, ctx, tx) +} + +func TestTxCopyFrom(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + tx, err := pool.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + testCopyFrom(t, ctx, tx) +} diff --git a/pgxtest/pgxtest.go b/pgxtest/pgxtest.go new file mode 100644 index 000000000..ece6d91b8 --- /dev/null +++ b/pgxtest/pgxtest.go @@ -0,0 +1,173 @@ +// Package pgxtest provides utilities for testing pgx and packages that integrate with pgx. +package pgxtest + +import ( + "context" + "fmt" + "reflect" + "regexp" + "strconv" + "testing" + + "github.com/jackc/pgx/v5" +) + +var AllQueryExecModes = []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + pgx.QueryExecModeExec, + pgx.QueryExecModeSimpleProtocol, +} + +// KnownOIDQueryExecModes is a slice of all query exec modes where the param and result OIDs are known before sending the query. +var KnownOIDQueryExecModes = []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, +} + +// ConnTestRunner controls how a *pgx.Conn is created and closed by tests. All fields are required. Use DefaultConnTestRunner to get a +// ConnTestRunner with reasonable default values. +type ConnTestRunner struct { + // CreateConfig returns a *pgx.ConnConfig suitable for use with pgx.ConnectConfig. + CreateConfig func(ctx context.Context, t testing.TB) *pgx.ConnConfig + + // AfterConnect is called after conn is established. It allows for arbitrary connection setup before a test begins. + AfterConnect func(ctx context.Context, t testing.TB, conn *pgx.Conn) + + // AfterTest is called after the test is run. It allows for validating the state of the connection before it is closed. + AfterTest func(ctx context.Context, t testing.TB, conn *pgx.Conn) + + // CloseConn closes conn. + CloseConn func(ctx context.Context, t testing.TB, conn *pgx.Conn) +} + +// DefaultConnTestRunner returns a new ConnTestRunner with all fields set to reasonable default values. +func DefaultConnTestRunner() ConnTestRunner { + return ConnTestRunner{ + CreateConfig: func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config, err := pgx.ParseConfig("") + if err != nil { + t.Fatalf("ParseConfig failed: %v", err) + } + return config + }, + AfterConnect: func(ctx context.Context, t testing.TB, conn *pgx.Conn) {}, + AfterTest: func(ctx context.Context, t testing.TB, conn *pgx.Conn) {}, + CloseConn: func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + err := conn.Close(ctx) + if err != nil { + t.Errorf("Close failed: %v", err) + } + }, + } +} + +func (ctr *ConnTestRunner) RunTest(ctx context.Context, t testing.TB, f func(ctx context.Context, t testing.TB, conn *pgx.Conn)) { + t.Helper() + + config := ctr.CreateConfig(ctx, t) + conn, err := pgx.ConnectConfig(ctx, config) + if err != nil { + t.Fatalf("ConnectConfig failed: %v", err) + } + defer ctr.CloseConn(ctx, t, conn) + + ctr.AfterConnect(ctx, t, conn) + f(ctx, t, conn) + ctr.AfterTest(ctx, t, conn) +} + +// RunWithQueryExecModes runs a f in a new test for each element of modes with a new connection created using connector. +// If modes is nil all pgx.QueryExecModes are tested. +func RunWithQueryExecModes(ctx context.Context, t *testing.T, ctr ConnTestRunner, modes []pgx.QueryExecMode, f func(ctx context.Context, t testing.TB, conn *pgx.Conn)) { + if modes == nil { + modes = AllQueryExecModes + } + + for _, mode := range modes { + ctrWithMode := ctr + ctrWithMode.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := ctr.CreateConfig(ctx, t) + config.DefaultQueryExecMode = mode + return config + } + + t.Run(mode.String(), + func(t *testing.T) { + ctrWithMode.RunTest(ctx, t, f) + }, + ) + } +} + +type ValueRoundTripTest struct { + Param any + Result any + Test func(any) bool +} + +func RunValueRoundTripTests( + ctx context.Context, + t testing.TB, + ctr ConnTestRunner, + modes []pgx.QueryExecMode, + pgTypeName string, + tests []ValueRoundTripTest, +) { + t.Helper() + + if modes == nil { + modes = AllQueryExecModes + } + + ctr.RunTest(ctx, t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + t.Helper() + + sql := fmt.Sprintf("select $1::%s", pgTypeName) + + for i, tt := range tests { + for _, mode := range modes { + err := conn.QueryRow(ctx, sql, mode, tt.Param).Scan(tt.Result) + if err != nil { + t.Errorf("%d. %v: %v", i, mode, err) + } + + result := reflect.ValueOf(tt.Result) + if result.Kind() == reflect.Ptr { + result = result.Elem() + } + + if !tt.Test(result.Interface()) { + t.Errorf("%d. %v: unexpected result for %v: %v", i, mode, tt.Param, result.Interface()) + } + } + } + }) +} + +// SkipCockroachDB calls Skip on t with msg if the connection is to a CockroachDB server. +func SkipCockroachDB(t testing.TB, conn *pgx.Conn, msg string) { + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip(msg) + } +} + +func SkipPostgreSQLVersionLessThan(t testing.TB, conn *pgx.Conn, minVersion int64) { + serverVersionStr := conn.PgConn().ParameterStatus("server_version") + serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) + // if not PostgreSQL do nothing + if serverVersionStr == "" { + return + } + + serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64) + if err != nil { + t.Fatalf("postgres version parsed failed: %s", err) + } + + if serverVersion < minVersion { + t.Skipf("Test requires PostgreSQL v%d+", minVersion) + } +} diff --git a/pipeline_test.go b/pipeline_test.go new file mode 100644 index 000000000..b8590bf9f --- /dev/null +++ b/pipeline_test.go @@ -0,0 +1,79 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/require" +) + +func TestPipelineWithoutPreparedOrDescribedStatements(t *testing.T) { + t.Parallel() + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pipeline := conn.PgConn().StartPipeline(ctx) + + eqb := pgx.ExtendedQueryBuilder{} + + err := eqb.Build(conn.TypeMap(), nil, []any{1, 2}) + require.NoError(t, err) + pipeline.SendQueryParams(`select $1::bigint + $2::bigint`, eqb.ParamValues, nil, eqb.ParamFormats, eqb.ResultFormats) + + err = eqb.Build(conn.TypeMap(), nil, []any{3, 4, 5}) + require.NoError(t, err) + pipeline.SendQueryParams(`select $1::bigint + $2::bigint + $3::bigint`, eqb.ParamValues, nil, eqb.ParamFormats, eqb.ResultFormats) + + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.True(t, ok) + rows := pgx.RowsFromResultReader(conn.TypeMap(), rr) + + rowCount := 0 + var n int64 + for rows.Next() { + err = rows.Scan(&n) + require.NoError(t, err) + rowCount++ + } + require.NoError(t, rows.Err()) + require.Equal(t, 1, rowCount) + require.Equal(t, "SELECT 1", rows.CommandTag().String()) + require.EqualValues(t, 3, n) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.True(t, ok) + rows = pgx.RowsFromResultReader(conn.TypeMap(), rr) + + rowCount = 0 + n = 0 + for rows.Next() { + err = rows.Scan(&n) + require.NoError(t, err) + rowCount++ + } + require.NoError(t, rows.Err()) + require.Equal(t, 1, rowCount) + require.Equal(t, "SELECT 1", rows.CommandTag().String()) + require.EqualValues(t, 12, n) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.True(t, ok) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + }) +} diff --git a/private_test.go b/private_test.go deleted file mode 100644 index df732a723..000000000 --- a/private_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package pgx - -// This file contains methods that expose internal pgx state to tests. - -func (c *Conn) TxStatus() byte { - return c.txStatus -} diff --git a/query.go b/query.go deleted file mode 100644 index c014cacd8..000000000 --- a/query.go +++ /dev/null @@ -1,546 +0,0 @@ -package pgx - -import ( - "context" - "database/sql" - "fmt" - "reflect" - "time" - - "github.com/pkg/errors" - - "github.com/jackc/pgx/internal/sanitize" - "github.com/jackc/pgx/pgproto3" - "github.com/jackc/pgx/pgtype" -) - -// Row is a convenience wrapper over Rows that is returned by QueryRow. -type Row Rows - -// Scan works the same as (*Rows Scan) with the following exceptions. If no -// rows were found it returns ErrNoRows. If multiple rows are returned it -// ignores all but the first. -func (r *Row) Scan(dest ...interface{}) (err error) { - rows := (*Rows)(r) - - if rows.Err() != nil { - return rows.Err() - } - - if !rows.Next() { - if rows.Err() == nil { - return ErrNoRows - } - return rows.Err() - } - - rows.Scan(dest...) - rows.Close() - return rows.Err() -} - -// Rows is the result set returned from *Conn.Query. Rows must be closed before -// the *Conn can be used again. Rows are closed by explicitly calling Close(), -// calling Next() until it returns false, or when a fatal error occurs. -type Rows struct { - conn *Conn - connPool *ConnPool - batch *Batch - values [][]byte - fields []FieldDescription - rowCount int - columnIdx int - err error - startTime time.Time - sql string - args []interface{} - unlockConn bool - closed bool -} - -func (rows *Rows) FieldDescriptions() []FieldDescription { - return rows.fields -} - -// Close closes the rows, making the connection ready for use again. It is safe -// to call Close after rows is already closed. -func (rows *Rows) Close() { - if rows.closed { - return - } - - if rows.unlockConn { - rows.conn.unlock() - rows.unlockConn = false - } - - rows.closed = true - - rows.err = rows.conn.termContext(rows.err) - - if rows.err == nil { - if rows.conn.shouldLog(LogLevelInfo) { - endTime := time.Now() - rows.conn.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) - } - } else if rows.conn.shouldLog(LogLevelError) { - rows.conn.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)}) - } - - if rows.batch != nil && rows.err != nil { - rows.batch.die(rows.err) - } - - if rows.connPool != nil { - rows.connPool.Release(rows.conn) - } -} - -func (rows *Rows) Err() error { - return rows.err -} - -// fatal signals an error occurred after the query was sent to the server. It -// closes the rows automatically. -func (rows *Rows) fatal(err error) { - if rows.err != nil { - return - } - - rows.err = err - rows.Close() -} - -// Next prepares the next row for reading. It returns true if there is another -// row and false if no more rows are available. It automatically closes rows -// when all rows are read. -func (rows *Rows) Next() bool { - if rows.closed { - return false - } - - rows.rowCount++ - rows.columnIdx = 0 - - for { - msg, err := rows.conn.rxMsg() - if err != nil { - rows.fatal(err) - return false - } - - switch msg := msg.(type) { - case *pgproto3.RowDescription: - rows.fields = rows.conn.rxRowDescription(msg) - for i := range rows.fields { - if dt, ok := rows.conn.ConnInfo.DataTypeForOID(rows.fields[i].DataType); ok { - rows.fields[i].DataTypeName = dt.Name - rows.fields[i].FormatCode = TextFormatCode - } else { - rows.fatal(errors.Errorf("unknown oid: %d", rows.fields[i].DataType)) - return false - } - } - case *pgproto3.DataRow: - if len(msg.Values) != len(rows.fields) { - rows.fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), len(msg.Values)))) - return false - } - - rows.values = msg.Values - return true - case *pgproto3.CommandComplete: - if rows.batch != nil { - rows.batch.pendingCommandComplete = false - } - rows.Close() - return false - - default: - err = rows.conn.processContextFreeMsg(msg) - if err != nil { - rows.fatal(err) - return false - } - } - } -} - -func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) { - if rows.closed { - return nil, nil, false - } - if len(rows.fields) <= rows.columnIdx { - rows.fatal(ProtocolError("No next column available")) - return nil, nil, false - } - - buf := rows.values[rows.columnIdx] - fd := &rows.fields[rows.columnIdx] - rows.columnIdx++ - return buf, fd, true -} - -type scanArgError struct { - col int - err error -} - -func (e scanArgError) Error() string { - return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err) -} - -// Scan reads the values from the current row into dest values positionally. -// dest can include pointers to core types, values implementing the Scanner -// interface, []byte, and nil. []byte will skip the decoding process and directly -// copy the raw bytes received from PostgreSQL. nil will skip the value entirely. -func (rows *Rows) Scan(dest ...interface{}) (err error) { - if len(rows.fields) != len(dest) { - err = errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields)) - rows.fatal(err) - return err - } - - for i, d := range dest { - buf, fd, _ := rows.nextColumn() - - if d == nil { - continue - } - - if s, ok := d.(pgtype.BinaryDecoder); ok && fd.FormatCode == BinaryFormatCode { - err = s.DecodeBinary(rows.conn.ConnInfo, buf) - if err != nil { - rows.fatal(scanArgError{col: i, err: err}) - } - } else if s, ok := d.(pgtype.TextDecoder); ok && fd.FormatCode == TextFormatCode { - err = s.DecodeText(rows.conn.ConnInfo, buf) - if err != nil { - rows.fatal(scanArgError{col: i, err: err}) - } - } else { - if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok { - value := dt.Value - switch fd.FormatCode { - case TextFormatCode: - if textDecoder, ok := value.(pgtype.TextDecoder); ok { - err = textDecoder.DecodeText(rows.conn.ConnInfo, buf) - if err != nil { - rows.fatal(scanArgError{col: i, err: err}) - } - } else { - rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.TextDecoder", value)}) - } - case BinaryFormatCode: - if binaryDecoder, ok := value.(pgtype.BinaryDecoder); ok { - err = binaryDecoder.DecodeBinary(rows.conn.ConnInfo, buf) - if err != nil { - rows.fatal(scanArgError{col: i, err: err}) - } - } else { - rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.BinaryDecoder", value)}) - } - default: - rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown format code: %v", fd.FormatCode)}) - } - - if rows.Err() == nil { - if scanner, ok := d.(sql.Scanner); ok { - sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) - if err != nil { - rows.fatal(err) - } - err = scanner.Scan(sqlSrc) - if err != nil { - rows.fatal(scanArgError{col: i, err: err}) - } - } else if err := value.AssignTo(d); err != nil { - rows.fatal(scanArgError{col: i, err: err}) - } - } - } else { - rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown oid: %v", fd.DataType)}) - } - } - - if rows.Err() != nil { - return rows.Err() - } - } - - return nil -} - -// Values returns an array of the row values -func (rows *Rows) Values() ([]interface{}, error) { - if rows.closed { - return nil, errors.New("rows is closed") - } - - values := make([]interface{}, 0, len(rows.fields)) - - for range rows.fields { - buf, fd, _ := rows.nextColumn() - - if buf == nil { - values = append(values, nil) - continue - } - - if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok { - value := reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(pgtype.Value) - - switch fd.FormatCode { - case TextFormatCode: - decoder := value.(pgtype.TextDecoder) - if decoder == nil { - decoder = &pgtype.GenericText{} - } - err := decoder.DecodeText(rows.conn.ConnInfo, buf) - if err != nil { - rows.fatal(err) - } - values = append(values, decoder.(pgtype.Value).Get()) - case BinaryFormatCode: - decoder := value.(pgtype.BinaryDecoder) - if decoder == nil { - decoder = &pgtype.GenericBinary{} - } - err := decoder.DecodeBinary(rows.conn.ConnInfo, buf) - if err != nil { - rows.fatal(err) - } - values = append(values, value.Get()) - default: - rows.fatal(errors.New("Unknown format code")) - } - } else { - rows.fatal(errors.New("Unknown type")) - } - - if rows.Err() != nil { - return nil, rows.Err() - } - } - - return values, rows.Err() -} - -// Query executes sql with args. If there is an error the returned *Rows will -// be returned in an error state. So it is allowed to ignore the error returned -// from Query and handle it in *Rows. -func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { - return c.QueryEx(context.Background(), sql, nil, args...) -} - -func (c *Conn) getRows(sql string, args []interface{}) *Rows { - if len(c.preallocatedRows) == 0 { - c.preallocatedRows = make([]Rows, 64) - } - - r := &c.preallocatedRows[len(c.preallocatedRows)-1] - c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1] - - r.conn = c - r.startTime = c.lastActivityTime - r.sql = sql - r.args = args - - return r -} - -// QueryRow is a convenience wrapper over Query. Any error that occurs while -// querying is deferred until calling Scan on the returned *Row. That *Row will -// error with ErrNoRows if no rows are returned. -func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { - rows, _ := c.Query(sql, args...) - return (*Row)(rows) -} - -type QueryExOptions struct { - // When ParameterOIDs are present and the query is not a prepared statement, - // then ParameterOIDs and ResultFormatCodes will be used to avoid an extra - // network round-trip. - ParameterOIDs []pgtype.OID - ResultFormatCodes []int16 - - SimpleProtocol bool -} - -func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) { - c.lastActivityTime = time.Now() - rows = c.getRows(sql, args) - - err = c.waitForPreviousCancelQuery(ctx) - if err != nil { - rows.fatal(err) - return rows, err - } - - if err := c.ensureConnectionReadyForQuery(); err != nil { - rows.fatal(err) - return rows, err - } - - if err := c.lock(); err != nil { - rows.fatal(err) - return rows, err - } - rows.unlockConn = true - - err = c.initContext(ctx) - if err != nil { - rows.fatal(err) - return rows, rows.err - } - - if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) { - err = c.sanitizeAndSendSimpleQuery(sql, args...) - if err != nil { - rows.fatal(err) - return rows, err - } - - return rows, nil - } - - if options != nil && len(options.ParameterOIDs) > 0 { - - buf, err := c.buildOneRoundTripQueryEx(c.wbuf, sql, options, args) - if err != nil { - rows.fatal(err) - return rows, err - } - - buf = appendSync(buf) - - n, err := c.conn.Write(buf) - if err != nil && fatalWriteErr(n, err) { - rows.fatal(err) - c.die(err) - return rows, err - } - c.pendingReadyForQueryCount++ - - fieldDescriptions, err := c.readUntilRowDescription() - if err != nil { - rows.fatal(err) - return rows, err - } - - if len(options.ResultFormatCodes) == 0 { - for i := range fieldDescriptions { - fieldDescriptions[i].FormatCode = TextFormatCode - } - } else if len(options.ResultFormatCodes) == 1 { - fc := options.ResultFormatCodes[0] - for i := range fieldDescriptions { - fieldDescriptions[i].FormatCode = fc - } - } else { - for i := range options.ResultFormatCodes { - fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i] - } - } - - rows.sql = sql - rows.fields = fieldDescriptions - return rows, nil - } - - ps, ok := c.preparedStatements[sql] - if !ok { - var err error - ps, err = c.prepareEx("", sql, nil) - if err != nil { - rows.fatal(err) - return rows, rows.err - } - } - rows.sql = ps.SQL - rows.fields = ps.FieldDescriptions - - err = c.sendPreparedQuery(ps, args...) - if err != nil { - rows.fatal(err) - } - - return rows, rows.err -} - -func (c *Conn) buildOneRoundTripQueryEx(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) { - if len(arguments) != len(options.ParameterOIDs) { - return nil, errors.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs)) - } - - if len(options.ParameterOIDs) > 65535 { - return nil, errors.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs)) - } - - buf = appendParse(buf, "", sql, options.ParameterOIDs) - buf = appendDescribe(buf, 'S', "") - buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOIDs, arguments, options.ResultFormatCodes) - if err != nil { - return nil, err - } - buf = appendExecute(buf, "", 0) - - return buf, nil -} - -func (c *Conn) readUntilRowDescription() ([]FieldDescription, error) { - for { - msg, err := c.rxMsg() - if err != nil { - return nil, err - } - - switch msg := msg.(type) { - case *pgproto3.ParameterDescription: - case *pgproto3.RowDescription: - fieldDescriptions := c.rxRowDescription(msg) - for i := range fieldDescriptions { - if dt, ok := c.ConnInfo.DataTypeForOID(fieldDescriptions[i].DataType); ok { - fieldDescriptions[i].DataTypeName = dt.Name - } else { - return nil, errors.Errorf("unknown oid: %d", fieldDescriptions[i].DataType) - } - } - return fieldDescriptions, nil - default: - if err := c.processContextFreeMsg(msg); err != nil { - return nil, err - } - } - } -} - -func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) { - if c.RuntimeParams["standard_conforming_strings"] != "on" { - return errors.New("simple protocol queries must be run with standard_conforming_strings=on") - } - - if c.RuntimeParams["client_encoding"] != "UTF8" { - return errors.New("simple protocol queries must be run with client_encoding=UTF8") - } - - valueArgs := make([]interface{}, len(args)) - for i, a := range args { - valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a) - if err != nil { - return err - } - } - - sql, err = sanitize.SanitizeSQL(sql, valueArgs...) - if err != nil { - return err - } - - return c.sendSimpleQuery(sql) -} - -func (c *Conn) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row { - rows, _ := c.QueryEx(ctx, sql, options, args...) - return (*Row)(rows) -} diff --git a/query_test.go b/query_test.go index 6b6b5facb..8c70d51e4 100644 --- a/query_test.go +++ b/query_test.go @@ -4,29 +4,33 @@ import ( "bytes" "context" "database/sql" + "database/sql/driver" + "encoding/json" + "errors" "fmt" - "reflect" + "os" + "strconv" "strings" "testing" "time" - "github.com/cockroachdb/apd" - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" - satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" - "github.com/satori/go.uuid" - "github.com/shopspring/decimal" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestConnQueryScan(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) var sum, rowCount int32 - rows, err := conn.Query("select generate_series(1,$1)", 10) + rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } @@ -40,9 +44,11 @@ func TestConnQueryScan(t *testing.T) { } if rows.Err() != nil { - t.Fatalf("conn.Query failed: %v", err) + t.Fatalf("conn.Query failed: %v", rows.Err()) } + assert.Equal(t, "SELECT 10", rows.CommandTag().String()) + if rowCount != 10 { t.Error("Select called onDataRow wrong number of times") } @@ -51,10 +57,37 @@ func TestConnQueryScan(t *testing.T) { } } +func TestConnQueryRowsFieldDescriptionsBeforeNext(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + rows, err := conn.Query(context.Background(), "select 'hello' as msg") + require.NoError(t, err) + defer rows.Close() + + require.Len(t, rows.FieldDescriptions(), 1) + assert.Equal(t, "msg", rows.FieldDescriptions()[0].Name) +} + +func TestConnQueryWithoutResultSetCommandTag(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + rows, err := conn.Query(context.Background(), "create temporary table t (id serial);") + assert.NoError(t, err) + rows.Close() + assert.NoError(t, rows.Err()) + assert.Equal(t, "CREATE TABLE", rows.CommandTag().String()) +} + func TestConnQueryScanWithManyColumns(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) columnCount := 1000 @@ -71,14 +104,14 @@ func TestConnQueryScanWithManyColumns(t *testing.T) { var rowCount int - rows, err := conn.Query(sql) + rows, err := conn.Query(context.Background(), sql) if err != nil { t.Fatalf("conn.Query failed: %v", err) } defer rows.Close() for rows.Next() { - destPtrs := make([]interface{}, columnCount) + destPtrs := make([]any, columnCount) for i := range destPtrs { destPtrs[i] = &dest[i] } @@ -95,7 +128,7 @@ func TestConnQueryScanWithManyColumns(t *testing.T) { } if rows.Err() != nil { - t.Fatalf("conn.Query failed: %v", err) + t.Fatalf("conn.Query failed: %v", rows.Err()) } if rowCount != 5 { @@ -106,12 +139,12 @@ func TestConnQueryScanWithManyColumns(t *testing.T) { func TestConnQueryValues(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) var rowCount int32 - rows, err := conn.Query("select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", 10) + rows, err := conn.Query(context.Background(), "select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } @@ -121,34 +154,17 @@ func TestConnQueryValues(t *testing.T) { rowCount++ values, err := rows.Values() - if err != nil { - t.Fatalf("rows.Values failed: %v", err) - } - if len(values) != 5 { - t.Errorf("Expected rows.Values to return 5 values, but it returned %d", len(values)) - } - if values[0] != "foo" { - t.Errorf(`Expected values[0] to be "foo", but it was %v`, values[0]) - } - if values[1] != "bar" { - t.Errorf(`Expected values[1] to be "bar", but it was %v`, values[1]) - } - - if values[2] != rowCount { - t.Errorf(`Expected values[2] to be %d, but it was %d`, rowCount, values[2]) - } - - if values[3] != nil { - t.Errorf(`Expected values[3] to be %v, but it was %d`, nil, values[3]) - } - - if values[4] != rowCount { - t.Errorf(`Expected values[4] to be %d, but it was %d`, rowCount, values[4]) - } + require.NoError(t, err) + require.Len(t, values, 5) + assert.Equal(t, "foo", values[0]) + assert.Equal(t, "bar", values[1]) + assert.EqualValues(t, rowCount, values[2]) + assert.Nil(t, values[3]) + assert.EqualValues(t, rowCount, values[4]) } if rows.Err() != nil { - t.Fatalf("conn.Query failed: %v", err) + t.Fatalf("conn.Query failed: %v", rows.Err()) } if rowCount != 10 { @@ -156,94 +172,180 @@ func TestConnQueryValues(t *testing.T) { } } -// https://github.com/jackc/pgx/issues/386 -func TestConnQueryValuesWithMultipleComplexColumnsOfSameType(t *testing.T) { +// https://github.com/jackc/pgx/issues/666 +func TestConnQueryValuesWhenUnableToDecode(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - expected0 := &pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Int: 3, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, - Status: pgtype.Present, - } + // Note that this relies on pgtype.Record not supporting the text protocol. This seems safe as it is impossible to + // decode the text protocol because unlike the binary protocol there is no way to determine the OIDs of the elements. + rows, err := conn.Query(context.Background(), "select (array[1::oid], null)", pgx.QueryResultFormats{pgx.TextFormatCode}) + require.NoError(t, err) + defer rows.Close() - expected1 := &pgtype.Int8Array{ - Elements: []pgtype.Int8{ - {Int: 4, Status: pgtype.Present}, - {Int: 5, Status: pgtype.Present}, - {Int: 6, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, - Status: pgtype.Present, - } + require.True(t, rows.Next()) + + values, err := rows.Values() + require.NoError(t, err) + require.Equal(t, "({1},)", values[0]) +} + +func TestConnQueryValuesWithUnregisteredOID(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + tx, err := conn.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + _, err = tx.Exec(ctx, "create type fruit as enum('orange', 'apple', 'pear')") + require.NoError(t, err) + + rows, err := conn.Query(context.Background(), "select 'orange'::fruit") + require.NoError(t, err) + defer rows.Close() + + require.True(t, rows.Next()) + + values, err := rows.Values() + require.NoError(t, err) + require.Equal(t, "orange", values[0]) +} + +func TestConnQueryArgsAndScanWithUnregisteredOID(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + tx, err := conn.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + _, err = tx.Exec(ctx, "create type fruit as enum('orange', 'apple', 'pear')") + require.NoError(t, err) + + var result string + err = conn.QueryRow(ctx, "select $1::fruit", "orange").Scan(&result) + require.NoError(t, err) + require.Equal(t, "orange", result) + }) +} + +// https://github.com/jackc/pgx/issues/478 +func TestConnQueryReadRowMultipleTimes(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) var rowCount int32 - rows, err := conn.Query("select '{1,2,3}'::bigint[], '{4,5,6}'::bigint[] from generate_series(1,$1) n", 10) - if err != nil { - t.Fatalf("conn.Query failed: %v", err) - } + rows, err := conn.Query(context.Background(), "select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", 10) + require.NoError(t, err) defer rows.Close() for rows.Next() { rowCount++ - values, err := rows.Values() - if err != nil { - t.Fatalf("rows.Values failed: %v", err) + for i := 0; i < 2; i++ { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 5) + require.Equal(t, "foo", values[0]) + require.Equal(t, "bar", values[1]) + require.EqualValues(t, rowCount, values[2]) + require.Nil(t, values[3]) + require.EqualValues(t, rowCount, values[4]) + + var a, b string + var c int32 + var d pgtype.Text + var e int32 + + err = rows.Scan(&a, &b, &c, &d, &e) + require.NoError(t, err) + require.Equal(t, "foo", a) + require.Equal(t, "bar", b) + require.Equal(t, rowCount, c) + require.False(t, d.Valid) + require.Equal(t, rowCount, e) } - if len(values) != 2 { - t.Errorf("Expected rows.Values to return 2 values, but it returned %d", len(values)) - } - if !reflect.DeepEqual(values[0], expected0) { - t.Errorf(`Expected values[0] to be %v, but it was %v`, expected0, values[0]) - } - if !reflect.DeepEqual(values[1], expected1) { - t.Errorf(`Expected values[1] to be %v, but it was %v`, expected1, values[1]) - } - } - - if rows.Err() != nil { - t.Fatalf("conn.Query failed: %v", err) } - if rowCount != 10 { - t.Error("Select called onDataRow wrong number of times") - } + require.NoError(t, rows.Err()) + require.Equal(t, int32(10), rowCount) } // https://github.com/jackc/pgx/issues/228 func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + pgxtest.SkipCockroachDB(t, conn, "Server does not support point type") + var s string - err := conn.QueryRow("select 1").Scan(&s) - if err == nil || !(strings.Contains(err.Error(), "cannot decode binary value into string") || strings.Contains(err.Error(), "cannot assign")) { - t.Fatalf("Expected Scan to fail to encode binary value into string but: %v", err) + err := conn.QueryRow(context.Background(), "select point(1,2)").Scan(&s) + if err == nil || !(strings.Contains(err.Error(), "cannot scan point (OID 600) in binary format into *string")) { + t.Fatalf("Expected Scan to fail to scan binary value into string but: %v", err) } ensureConnValid(t, conn) } +func TestConnQueryRawValues(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + var rowCount int32 + + rows, err := conn.Query( + context.Background(), + "select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", + pgx.QueryExecModeSimpleProtocol, + 10, + ) + require.NoError(t, err) + defer rows.Close() + + for rows.Next() { + rowCount++ + + rawValues := rows.RawValues() + assert.Len(t, rawValues, 5) + assert.Equal(t, "foo", string(rawValues[0])) + assert.Equal(t, "bar", string(rawValues[1])) + assert.Equal(t, strconv.FormatInt(int64(rowCount), 10), string(rawValues[2])) + assert.Nil(t, rawValues[3]) + assert.Equal(t, strconv.FormatInt(int64(rowCount), 10), string(rawValues[4])) + } + + require.NoError(t, rows.Err()) + assert.EqualValues(t, 10, rowCount) +} + // Test that a connection stays valid when query results are closed early func TestConnQueryCloseEarly(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) // Immediately close query without reading any rows - rows, err := conn.Query("select generate_series(1,$1)", 10) + rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } @@ -252,7 +354,7 @@ func TestConnQueryCloseEarly(t *testing.T) { ensureConnValid(t, conn) // Read partial response then close - rows, err = conn.Query("select generate_series(1,$1)", 10) + rows, err = conn.Query(context.Background(), "select generate_series(1,$1)", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } @@ -276,13 +378,14 @@ func TestConnQueryCloseEarly(t *testing.T) { func TestConnQueryCloseEarlyWithErrorOnWire(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - rows, err := conn.Query("select 1/(10-n) from generate_series(1,10) n") + rows, err := conn.Query(context.Background(), "select 1/(10-n) from generate_series(1,10) n") if err != nil { t.Fatalf("conn.Query failed: %v", err) } + assert.False(t, pgconn.SafeToRetry(err)) rows.Close() ensureConnValid(t, conn) @@ -292,11 +395,11 @@ func TestConnQueryCloseEarlyWithErrorOnWire(t *testing.T) { func TestConnQueryReadWrongTypeError(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) // Read a single value incorrectly - rows, err := conn.Query("select generate_series(1,$1)", 10) + rows, err := conn.Query(context.Background(), "select n::int4 from generate_series(1,$1) n", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } @@ -317,7 +420,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) { t.Fatal("Expected Rows to have an error after an improper read but it didn't") } - if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" && !strings.Contains(rows.Err().Error(), "cannot assign") { + if rows.Err().Error() != "can't scan into dest[0] (col: n): cannot scan int4 (OID 23) in binary format into *time.Time" { t.Fatalf("Expected different Rows.Err(): %v", rows.Err()) } @@ -328,11 +431,11 @@ func TestConnQueryReadWrongTypeError(t *testing.T) { func TestConnQueryReadTooManyValues(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) // Read too many values - rows, err := conn.Query("select generate_series(1,$1)", 10) + rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } @@ -359,10 +462,10 @@ func TestConnQueryReadTooManyValues(t *testing.T) { func TestConnQueryScanIgnoreColumn(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - rows, err := conn.Query("select 1::int8, 2::int8, 3::int8") + rows, err := conn.Query(context.Background(), "select 1::int8, 2::int8, 3::int8") if err != nil { t.Fatalf("conn.Query failed: %v", err) } @@ -390,17 +493,62 @@ func TestConnQueryScanIgnoreColumn(t *testing.T) { ensureConnValid(t, conn) } +// https://github.com/jackc/pgx/issues/570 +func TestConnQueryDeferredError(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred +); + +insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) + + rows, err := conn.Query(context.Background(), `update t set n=n+1 where id='b' returning *`) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + for rows.Next() { + var id string + var n int32 + err = rows.Scan(&id, &n) + if err != nil { + t.Fatal(err) + } + } + + if rows.Err() == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := rows.Err().(*pgconn.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + ensureConnValid(t, conn) +} + func TestConnQueryErrorWhileReturningRows(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + pgxtest.SkipCockroachDB(t, conn, "Server uses numeric instead of int") + for i := 0; i < 100; i++ { func() { sql := `select 42 / (random() * 20)::integer from generate_series(1,100000)` - rows, err := conn.Query(sql) + rows, err := conn.Query(context.Background(), sql) if err != nil { t.Fatal(err) } @@ -408,29 +556,31 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) { for rows.Next() { var n int32 - rows.Scan(&n) + if err := rows.Scan(&n); err != nil { + t.Fatalf("Row scan failed: %v", err) + } } - if err, ok := rows.Err().(pgx.PgError); !ok { - t.Fatalf("Expected pgx.PgError, got %v", err) + if _, ok := rows.Err().(*pgconn.PgError); !ok { + t.Fatalf("Expected pgconn.PgError, got %v", rows.Err()) } ensureConnValid(t, conn) }() } - } func TestQueryEncodeError(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - rows, err := conn.Query("select $1::integer", "wrong") + rows, err := conn.Query(context.Background(), "select $1::integer", "wrong") if err != nil { t.Errorf("conn.Query failure: %v", err) } + assert.False(t, pgconn.SafeToRetry(err)) defer rows.Close() rows.Next() @@ -438,7 +588,7 @@ func TestQueryEncodeError(t *testing.T) { if rows.Err() == nil { t.Error("Expected rows.Err() to return error, but it didn't") } - if rows.Err().Error() != `ERROR: invalid input syntax for integer: "wrong" (SQLSTATE 22P02)` { + if !strings.Contains(rows.Err().Error(), "SQLSTATE 22P02") { t.Error("Expected rows.Err() to return different error:", rows.Err()) } } @@ -446,7 +596,7 @@ func TestQueryEncodeError(t *testing.T) { func TestQueryRowCoreTypes(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) type allTypes struct { @@ -455,43 +605,43 @@ func TestQueryRowCoreTypes(t *testing.T) { f64 float64 b bool t time.Time - oid pgtype.OID + oid uint32 } var actual, zero allTypes tests := []struct { sql string - queryArgs []interface{} - scanArgs []interface{} + queryArgs []any + scanArgs []any expected allTypes }{ - {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.s}, allTypes{s: "Jack"}}, - {"select $1::float4", []interface{}{float32(1.23)}, []interface{}{&actual.f32}, allTypes{f32: 1.23}}, - {"select $1::float8", []interface{}{float64(1.23)}, []interface{}{&actual.f64}, allTypes{f64: 1.23}}, - {"select $1::bool", []interface{}{true}, []interface{}{&actual.b}, allTypes{b: true}}, - {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, - {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}}, - {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}}, - {"select $1::oid", []interface{}{pgtype.OID(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, + {"select $1::text", []any{"Jack"}, []any{&actual.s}, allTypes{s: "Jack"}}, + {"select $1::float4", []any{float32(1.23)}, []any{&actual.f32}, allTypes{f32: 1.23}}, + {"select $1::float8", []any{float64(1.23)}, []any{&actual.f64}, allTypes{f64: 1.23}}, + {"select $1::bool", []any{true}, []any{&actual.b}, allTypes{b: true}}, + {"select $1::timestamptz", []any{time.Unix(123, 5000)}, []any{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, + {"select $1::timestamp", []any{time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}, []any{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}}, + {"select $1::date", []any{time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}, []any{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}}, + {"select $1::oid", []any{uint32(42)}, []any{&actual.oid}, allTypes{oid: 42}}, } for i, tt := range tests { actual = zero - err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) + err := conn.QueryRow(context.Background(), tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) if err != nil { t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs) } - if actual != tt.expected { + if actual.s != tt.expected.s || actual.f32 != tt.expected.f32 || actual.b != tt.expected.b || !actual.t.Equal(tt.expected.t) || actual.oid != tt.expected.oid { t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs) } ensureConnValid(t, conn) // Check that Scan errors when a core type is null - err = conn.QueryRow(tt.sql, nil).Scan(tt.scanArgs...) + err = conn.QueryRow(context.Background(), tt.sql, nil).Scan(tt.scanArgs...) if err == nil { t.Errorf("%d. Expected null to cause error, but it didn't (sql -> %v)", i, tt.sql) } @@ -503,28 +653,21 @@ func TestQueryRowCoreTypes(t *testing.T) { func TestQueryRowCoreIntegerEncoding(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) type allTypes struct { - ui uint - ui8 uint8 - ui16 uint16 - ui32 uint32 - ui64 uint64 - i int - i8 int8 - i16 int16 - i32 int32 - i64 int64 + i16 int16 + i32 int32 + i64 int64 } var actual, zero allTypes successfulEncodeTests := []struct { sql string - queryArg interface{} - scanArg interface{} + queryArg any + scanArg any expected allTypes }{ // Check any integer type where value is within int2 range can be encoded @@ -567,7 +710,7 @@ func TestQueryRowCoreIntegerEncoding(t *testing.T) { for i, tt := range successfulEncodeTests { actual = zero - err := conn.QueryRow(tt.sql, tt.queryArg).Scan(tt.scanArg) + err := conn.QueryRow(context.Background(), tt.sql, tt.queryArg).Scan(tt.scanArg) if err != nil { t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArg -> %v)", i, err, tt.sql, tt.queryArg) continue @@ -582,7 +725,7 @@ func TestQueryRowCoreIntegerEncoding(t *testing.T) { failedEncodeTests := []struct { sql string - queryArg interface{} + queryArg any }{ // Check any integer type where value is outside pg:int2 range cannot be encoded {"select $1::int2", int(32769)}, @@ -604,7 +747,7 @@ func TestQueryRowCoreIntegerEncoding(t *testing.T) { } for i, tt := range failedEncodeTests { - err := conn.QueryRow(tt.sql, tt.queryArg).Scan(nil) + err := conn.QueryRow(context.Background(), tt.sql, tt.queryArg).Scan(nil) if err == nil { t.Errorf("%d. Expected failure to encode, but unexpectedly succeeded: %v (sql -> %v, queryArg -> %v)", i, err, tt.sql, tt.queryArg) } else if !strings.Contains(err.Error(), "is greater than") { @@ -618,7 +761,7 @@ func TestQueryRowCoreIntegerEncoding(t *testing.T) { func TestQueryRowCoreIntegerDecoding(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) type allTypes struct { @@ -638,7 +781,7 @@ func TestQueryRowCoreIntegerDecoding(t *testing.T) { successfulDecodeTests := []struct { sql string - scanArg interface{} + scanArg any expected allTypes }{ // Check any integer type where value is within Go:int range can be decoded @@ -710,7 +853,7 @@ func TestQueryRowCoreIntegerDecoding(t *testing.T) { for i, tt := range successfulDecodeTests { actual = zero - err := conn.QueryRow(tt.sql).Scan(tt.scanArg) + err := conn.QueryRow(context.Background(), tt.sql).Scan(tt.scanArg) if err != nil { t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) continue @@ -724,65 +867,64 @@ func TestQueryRowCoreIntegerDecoding(t *testing.T) { } failedDecodeTests := []struct { - sql string - scanArg interface{} - expectedErr string + sql string + scanArg any }{ // Check any integer type where value is outside Go:int8 range cannot be decoded - {"select 128::int2", &actual.i8, "is greater than"}, - {"select 128::int4", &actual.i8, "is greater than"}, - {"select 128::int8", &actual.i8, "is greater than"}, - {"select -129::int2", &actual.i8, "is less than"}, - {"select -129::int4", &actual.i8, "is less than"}, - {"select -129::int8", &actual.i8, "is less than"}, + {"select 128::int2", &actual.i8}, + {"select 128::int4", &actual.i8}, + {"select 128::int8", &actual.i8}, + {"select -129::int2", &actual.i8}, + {"select -129::int4", &actual.i8}, + {"select -129::int8", &actual.i8}, // Check any integer type where value is outside Go:int16 range cannot be decoded - {"select 32768::int4", &actual.i16, "is greater than"}, - {"select 32768::int8", &actual.i16, "is greater than"}, - {"select -32769::int4", &actual.i16, "is less than"}, - {"select -32769::int8", &actual.i16, "is less than"}, + {"select 32768::int4", &actual.i16}, + {"select 32768::int8", &actual.i16}, + {"select -32769::int4", &actual.i16}, + {"select -32769::int8", &actual.i16}, // Check any integer type where value is outside Go:int32 range cannot be decoded - {"select 2147483648::int8", &actual.i32, "is greater than"}, - {"select -2147483649::int8", &actual.i32, "is less than"}, + {"select 2147483648::int8", &actual.i32}, + {"select -2147483649::int8", &actual.i32}, // Check any integer type where value is outside Go:uint range cannot be decoded - {"select -1::int2", &actual.ui, "is less than"}, - {"select -1::int4", &actual.ui, "is less than"}, - {"select -1::int8", &actual.ui, "is less than"}, + {"select -1::int2", &actual.ui}, + {"select -1::int4", &actual.ui}, + {"select -1::int8", &actual.ui}, // Check any integer type where value is outside Go:uint8 range cannot be decoded - {"select 256::int2", &actual.ui8, "is greater than"}, - {"select 256::int4", &actual.ui8, "is greater than"}, - {"select 256::int8", &actual.ui8, "is greater than"}, - {"select -1::int2", &actual.ui8, "is less than"}, - {"select -1::int4", &actual.ui8, "is less than"}, - {"select -1::int8", &actual.ui8, "is less than"}, + {"select 256::int2", &actual.ui8}, + {"select 256::int4", &actual.ui8}, + {"select 256::int8", &actual.ui8}, + {"select -1::int2", &actual.ui8}, + {"select -1::int4", &actual.ui8}, + {"select -1::int8", &actual.ui8}, // Check any integer type where value is outside Go:uint16 cannot be decoded - {"select 65536::int4", &actual.ui16, "is greater than"}, - {"select 65536::int8", &actual.ui16, "is greater than"}, - {"select -1::int2", &actual.ui16, "is less than"}, - {"select -1::int4", &actual.ui16, "is less than"}, - {"select -1::int8", &actual.ui16, "is less than"}, + {"select 65536::int4", &actual.ui16}, + {"select 65536::int8", &actual.ui16}, + {"select -1::int2", &actual.ui16}, + {"select -1::int4", &actual.ui16}, + {"select -1::int8", &actual.ui16}, // Check any integer type where value is outside Go:uint32 range cannot be decoded - {"select 4294967296::int8", &actual.ui32, "is greater than"}, - {"select -1::int2", &actual.ui32, "is less than"}, - {"select -1::int4", &actual.ui32, "is less than"}, - {"select -1::int8", &actual.ui32, "is less than"}, + {"select 4294967296::int8", &actual.ui32}, + {"select -1::int2", &actual.ui32}, + {"select -1::int4", &actual.ui32}, + {"select -1::int8", &actual.ui32}, // Check any integer type where value is outside Go:uint64 range cannot be decoded - {"select -1::int2", &actual.ui64, "is less than"}, - {"select -1::int4", &actual.ui64, "is less than"}, - {"select -1::int8", &actual.ui64, "is less than"}, + {"select -1::int2", &actual.ui64}, + {"select -1::int4", &actual.ui64}, + {"select -1::int8", &actual.ui64}, } for i, tt := range failedDecodeTests { - err := conn.QueryRow(tt.sql).Scan(tt.scanArg) + err := conn.QueryRow(context.Background(), tt.sql).Scan(tt.scanArg) if err == nil { t.Errorf("%d. Expected failure to decode, but unexpectedly succeeded: %v (sql -> %v)", i, err, tt.sql) - } else if !strings.Contains(err.Error(), tt.expectedErr) { + } else if !strings.Contains(err.Error(), "can't scan") { t.Errorf("%d. Expected failure to decode, but got: %v (sql -> %v)", i, err, tt.sql) } @@ -793,12 +935,12 @@ func TestQueryRowCoreIntegerDecoding(t *testing.T) { func TestQueryRowCoreByteSlice(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) tests := []struct { sql string - queryArg interface{} + queryArg any expected []byte }{ {"select $1::text", "Jack", []byte("Jack")}, @@ -810,7 +952,7 @@ func TestQueryRowCoreByteSlice(t *testing.T) { for i, tt := range tests { var actual []byte - err := conn.QueryRow(tt.sql, tt.queryArg).Scan(&actual) + err := conn.QueryRow(context.Background(), tt.sql, tt.queryArg).Scan(&actual) if err != nil { t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) } @@ -823,51 +965,18 @@ func TestQueryRowCoreByteSlice(t *testing.T) { } } -func TestQueryRowUnknownType(t *testing.T) { +func TestQueryRowErrors(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - // Clear existing type mappings - conn.ConnInfo = pgtype.NewConnInfo() - conn.ConnInfo.RegisterDataType(pgtype.DataType{ - Value: &pgtype.GenericText{}, - Name: "point", - OID: 600, - }) - conn.ConnInfo.RegisterDataType(pgtype.DataType{ - Value: &pgtype.Int4{}, - Name: "int4", - OID: pgtype.Int4OID, - }) - - sql := "select $1::point" - expected := "(1,0)" - var actual string - - err := conn.QueryRow(sql, expected).Scan(&actual) - if err != nil { - t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) - } - - if actual != expected { - t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql) - + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip("Skipping due to known server missing point type") } - ensureConnValid(t, conn) -} - -func TestQueryRowErrors(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - type allTypes struct { i16 int16 - i int s string } @@ -875,20 +984,20 @@ func TestQueryRowErrors(t *testing.T) { tests := []struct { sql string - queryArgs []interface{} - scanArgs []interface{} + queryArgs []any + scanArgs []any err string }{ - {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, - {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, - {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "cannot decode"}, - {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Point"}, + {"select $1::badtype", []any{"Jack"}, []any{&actual.i16}, `type "badtype" does not exist`}, + {"SYNTAX ERROR", []any{}, []any{&actual.i16}, "SQLSTATE 42601"}, + {"select $1::text", []any{"Jack"}, []any{&actual.i16}, "cannot scan text (OID 25) in text format into *int16"}, + {"select $1::point", []any{int(705)}, []any{&actual.s}, "unable to encode 705 into binary format for point (OID 600)"}, } for i, tt := range tests { actual = zero - err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) + err := conn.QueryRow(context.Background(), tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) if err == nil { t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs) } @@ -900,65 +1009,45 @@ func TestQueryRowErrors(t *testing.T) { } } -func TestQueryRowExErrorsWrongParameterOIDs(t *testing.T) { +func TestQueryRowNoResults(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - sql := ` - with t as ( - select 1::int8 as some_int, 'foo'::text as some_text - ) - select some_int from t where some_text = $1` - paramOIDs := []pgtype.OID{pgtype.TextArrayOID} - queryArgs := []interface{}{"bar"} - expectedErr := "operator does not exist: text = text[] (SQLSTATE 42883)" - var result int64 - - err := conn.QueryRowEx( - context.Background(), - sql, - &pgx.QueryExOptions{ - ParameterOIDs: paramOIDs, - ResultFormatCodes: []int16{pgx.BinaryFormatCode}, - }, - queryArgs..., - ).Scan(&result) - - if err == nil { - t.Errorf("Unexpected success (sql -> %v, paramOIDs -> %v, queryArgs -> %v)", sql, paramOIDs, queryArgs) - } - if err != nil && !strings.Contains(err.Error(), expectedErr) { - t.Errorf("Expected error to contain %s, but got %v (sql -> %v, paramOIDs -> %v, queryArgs -> %v)", - expectedErr, err, sql, paramOIDs, queryArgs) + var n int32 + err := conn.QueryRow(context.Background(), "select 1 where 1=0").Scan(&n) + if err != pgx.ErrNoRows { + t.Errorf("Expected pgx.ErrNoRows, got %v", err) } ensureConnValid(t, conn) } -func TestQueryRowNoResults(t *testing.T) { +func TestQueryRowEmptyQuery(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + var n int32 - err := conn.QueryRow("select 1 where 1=0").Scan(&n) - if err != pgx.ErrNoRows { - t.Errorf("Expected pgx.ErrNoRows, got %v", err) - } + err := conn.QueryRow(ctx, "").Scan(&n) + require.Error(t, err) + require.False(t, pgconn.Timeout(err)) ensureConnValid(t, conn) } func TestReadingValueAfterEmptyArray(t *testing.T) { - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) var a []string var b int32 - err := conn.QueryRow("select '{}'::text[], 42::integer").Scan(&a, &b) + err := conn.QueryRow(context.Background(), "select '{}'::text[], 42::integer").Scan(&a, &b) if err != nil { t.Fatalf("conn.QueryRow failed: %v", err) } @@ -973,11 +1062,11 @@ func TestReadingValueAfterEmptyArray(t *testing.T) { } func TestReadingNullByteArray(t *testing.T) { - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) var a []byte - err := conn.QueryRow("select null::text").Scan(&a) + err := conn.QueryRow(context.Background(), "select null::text").Scan(&a) if err != nil { t.Fatalf("conn.QueryRow failed: %v", err) } @@ -988,10 +1077,10 @@ func TestReadingNullByteArray(t *testing.T) { } func TestReadingNullByteArrays(t *testing.T) { - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - rows, err := conn.Query("select null::text union all select null::text") + rows, err := conn.Query(context.Background(), "select null::text union all select null::text") if err != nil { t.Fatalf("conn.Query failed: %v", err) } @@ -1012,55 +1101,52 @@ func TestReadingNullByteArrays(t *testing.T) { } } -// Use github.com/shopspring/decimal as real-world database/sql custom type -// to test against. +func TestQueryNullSliceIsSet(t *testing.T) { + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + a := []int32{1, 2, 3} + err := conn.QueryRow(context.Background(), "select null::int[]").Scan(&a) + if err != nil { + t.Fatalf("conn.QueryRow failed: %v", err) + } + + if a != nil { + t.Errorf("Expected 'a' to be nil, but it was: %v", a) + } +} + func TestConnQueryDatabaseSQLScanner(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - var num decimal.Decimal + var num sql.NullFloat64 - err := conn.QueryRow("select '1234.567'::decimal").Scan(&num) + err := conn.QueryRow(context.Background(), "select '1234.567'::float8").Scan(&num) if err != nil { t.Fatalf("Scan failed: %v", err) } - expected, err := decimal.NewFromString("1234.567") - if err != nil { - t.Fatal(err) - } - - if !num.Equals(expected) { - t.Errorf("Expected num to be %v, but it was %v", expected, num) - } + require.True(t, num.Valid) + require.Equal(t, 1234.567, num.Float64) ensureConnValid(t, conn) } -// Use github.com/shopspring/decimal as real-world database/sql custom type -// to test against. func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - expected, err := decimal.NewFromString("1234.567") - if err != nil { - t.Fatal(err) - } - var num decimal.Decimal - - err = conn.QueryRow("select $1::decimal", &expected).Scan(&num) - if err != nil { - t.Fatalf("Scan failed: %v", err) - } + expected := sql.NullFloat64{Float64: 1234.567, Valid: true} + var actual sql.NullFloat64 - if !num.Equals(expected) { - t.Errorf("Expected num to be %v, but it was %v", expected, num) - } + err := conn.QueryRow(context.Background(), "select $1::float8", &expected).Scan(&actual) + require.NoError(t, err) + require.Equal(t, expected, actual) ensureConnValid(t, conn) } @@ -1069,48 +1155,232 @@ func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) { func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) mustExec(t, conn, "create temporary table t(n numeric)") - var d *apd.Decimal - commandTag, err := conn.Exec(`insert into t(n) values($1)`, d) + var d *sql.NullInt64 + commandTag, err := conn.Exec(context.Background(), `insert into t(n) values($1)`, d) if err != nil { t.Fatal(err) } - if commandTag != "INSERT 0 1" { + if commandTag.String() != "INSERT 0 1" { t.Fatalf("want %s, got %s", "INSERT 0 1", commandTag) } ensureConnValid(t, conn) } -func TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t *testing.T) { +type nilPointerAsEmptyJSONObject struct { + ID string + Name string +} + +func (v *nilPointerAsEmptyJSONObject) Value() (driver.Value, error) { + if v == nil { + return "{}", nil + } + + return json.Marshal(v) +} + +// https://github.com/jackc/pgx/issues/1566 +func TestConnQueryDatabaseSQLDriverValuerCalledOnNilPointerImplementers(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - conn.ConnInfo.RegisterDataType(pgtype.DataType{ - Value: &satori.UUID{}, - Name: "uuid", - OID: 2950, - }) + mustExec(t, conn, "create temporary table t(v json not null)") - expected, err := uuid.FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8") - if err != nil { - t.Fatal(err) + var v *nilPointerAsEmptyJSONObject + commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var s string + err = conn.QueryRow(context.Background(), "select v from t").Scan(&s) + require.NoError(t, err) + require.Equal(t, "{}", s) + + _, err = conn.Exec(context.Background(), `delete from t`) + require.NoError(t, err) + + v = &nilPointerAsEmptyJSONObject{ID: "1", Name: "foo"} + commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var v2 *nilPointerAsEmptyJSONObject + err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2) + require.NoError(t, err) + require.Equal(t, v, v2) + + ensureConnValid(t, conn) +} + +type nilSliceAsEmptySlice []byte + +func (j nilSliceAsEmptySlice) Value() (driver.Value, error) { + if len(j) == 0 { + return []byte("[]"), nil } - var u2 uuid.UUID - err = conn.QueryRow("select $1::uuid", expected).Scan(&u2) + return []byte(j), nil +} + +func (j *nilSliceAsEmptySlice) UnmarshalJSON(data []byte) error { + *j = bytes.Clone(data) + return nil +} + +// https://github.com/jackc/pgx/issues/1860 +func TestConnQueryDatabaseSQLDriverValuerCalledOnNilSliceImplementers(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, "create temporary table t(v json not null)") + + var v nilSliceAsEmptySlice + commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var s string + err = conn.QueryRow(context.Background(), "select v from t").Scan(&s) + require.NoError(t, err) + require.Equal(t, "[]", s) + + _, err = conn.Exec(context.Background(), `delete from t`) + require.NoError(t, err) + + v = nilSliceAsEmptySlice(`{"name": "foo"}`) + commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var v2 nilSliceAsEmptySlice + err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2) + require.NoError(t, err) + require.Equal(t, v, v2) + + ensureConnValid(t, conn) +} + +type nilMapAsEmptyObject map[string]any + +func (j nilMapAsEmptyObject) Value() (driver.Value, error) { + if j == nil { + return []byte("{}"), nil + } + + return json.Marshal(j) +} + +func (j *nilMapAsEmptyObject) UnmarshalJSON(data []byte) error { + var m map[string]any + err := json.Unmarshal(data, &m) if err != nil { - t.Fatalf("Scan failed: %v", err) + return err + } + + *j = m + + return nil +} + +// https://github.com/jackc/pgx/pull/2019#discussion_r1605806751 +func TestConnQueryDatabaseSQLDriverValuerCalledOnNilMapImplementers(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, "create temporary table t(v json not null)") + + var v nilMapAsEmptyObject + commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var s string + err = conn.QueryRow(context.Background(), "select v from t").Scan(&s) + require.NoError(t, err) + require.Equal(t, "{}", s) + + _, err = conn.Exec(context.Background(), `delete from t`) + require.NoError(t, err) + + v = nilMapAsEmptyObject{"name": "foo"} + commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var v2 nilMapAsEmptyObject + err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2) + require.NoError(t, err) + require.Equal(t, v, v2) + + ensureConnValid(t, conn) +} + +func TestConnQueryDatabaseSQLDriverScannerWithBinaryPgTypeThatAcceptsSameType(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + var actual sql.NullString + err := conn.QueryRow(context.Background(), "select '6ba7b810-9dad-11d1-80b4-00c04fd430c8'::uuid").Scan(&actual) + require.NoError(t, err) + + require.True(t, actual.Valid) + require.Equal(t, "6ba7b810-9dad-11d1-80b4-00c04fd430c8", actual.String) + + ensureConnValid(t, conn) +} + +// https://github.com/jackc/pgx/issues/1273#issuecomment-1221672175 +func TestConnQueryDatabaseSQLDriverValuerTextWhenBinaryIsPreferred(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + arg := sql.NullString{String: "1.234", Valid: true} + var result pgtype.Numeric + err := conn.QueryRow(context.Background(), "select $1::numeric", arg).Scan(&result) + require.NoError(t, err) + + require.True(t, result.Valid) + f64, err := result.Float64Value() + require.NoError(t, err) + require.Equal(t, pgtype.Float8{Float64: 1.234, Valid: true}, f64) + + ensureConnValid(t, conn) +} + +// https://github.com/jackc/pgx/issues/1426 +func TestConnQueryDatabaseSQLNullFloat64NegativeZeroPointZero(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + tests := []float64{ + -0.01, + -0.001, + -0.0001, } - if expected != u2 { - t.Errorf("Expected u2 to be %v, but it was %v", expected, u2) + for _, val := range tests { + var result sql.NullFloat64 + err := conn.QueryRow(context.Background(), "select $1::numeric", val).Scan(&result) + require.NoError(t, err) + require.Equal(t, sql.NullFloat64{Float64: val, Valid: true}, result) } ensureConnValid(t, conn) @@ -1119,7 +1389,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t * func TestConnQueryDatabaseSQLNullX(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) type row struct { @@ -1143,6 +1413,7 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) { var actual row err := conn.QueryRow( + context.Background(), "select $1::bool, $2::bool, $3::int8, $4::int8, $5::float8, $6::float8, $7::text, $8::text", expected.boolValid, expected.boolNull, @@ -1173,16 +1444,16 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryExContextSuccess(t *testing.T) { +func TestQueryContextSuccess(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - rows, err := conn.QueryEx(ctx, "select 42::integer", nil) + rows, err := conn.Query(ctx, "select 42::integer") if err != nil { t.Fatal(err) } @@ -1210,16 +1481,18 @@ func TestQueryExContextSuccess(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryExContextErrorWhileReceivingRows(t *testing.T) { +func TestQueryContextErrorWhileReceivingRows(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + pgxtest.SkipCockroachDB(t, conn, "Server uses numeric instead of int") + ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - rows, err := conn.QueryEx(ctx, "select 10/(10-n) from generate_series(1, 100) n", nil) + rows, err := conn.Query(ctx, "select 10/(10-n) from generate_series(1, 100) n") if err != nil { t.Fatal(err) } @@ -1247,45 +1520,17 @@ func TestQueryExContextErrorWhileReceivingRows(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryExContextCancelationCancelsQuery(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - ctx, cancelFunc := context.WithCancel(context.Background()) - go func() { - time.Sleep(500 * time.Millisecond) - cancelFunc() - }() - - rows, err := conn.QueryEx(ctx, "select pg_sleep(5)", nil) - if err != nil { - t.Fatal(err) - } - - for rows.Next() { - t.Fatal("No rows should ever be ready -- context cancel apparently did not happen") - } - - if rows.Err() != context.Canceled { - t.Fatalf("Expected context.Canceled error, got %v", rows.Err()) - } - - ensureConnValid(t, conn) -} - -func TestQueryRowExContextSuccess(t *testing.T) { +func TestQueryRowContextSuccess(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() var result int - err := conn.QueryRowEx(ctx, "select 42::integer", nil).Scan(&result) + err := conn.QueryRow(ctx, "select 42::integer").Scan(&result) if err != nil { t.Fatal(err) } @@ -1296,17 +1541,17 @@ func TestQueryRowExContextSuccess(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryRowExContextErrorWhileReceivingRow(t *testing.T) { +func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() var result int - err := conn.QueryRowEx(ctx, "select 10/0", nil).Scan(&result) + err := conn.QueryRow(ctx, "select 10/0").Scan(&result) if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" { t.Fatalf("Expected division by zero error, but got %v", err) } @@ -1314,57 +1559,46 @@ func TestQueryRowExContextErrorWhileReceivingRow(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryRowExContextCancelationCancelsQuery(t *testing.T) { +func TestQueryCloseBefore(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - ctx, cancelFunc := context.WithCancel(context.Background()) - go func() { - time.Sleep(500 * time.Millisecond) - cancelFunc() - }() - - var result []byte - err := conn.QueryRowEx(ctx, "select pg_sleep(5)", nil).Scan(&result) - if err != context.Canceled { - t.Fatalf("Expected context.Canceled error, got %v", err) - } + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + closeConn(t, conn) - ensureConnValid(t, conn) + _, err := conn.Query(context.Background(), "select 1") + require.Error(t, err) + assert.True(t, pgconn.SafeToRetry(err)) } -func TestConnQueryRowExSingleRoundTrip(t *testing.T) { +func TestScanRow(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - var result int32 - err := conn.QueryRowEx( - context.Background(), - "select $1 + $2", - &pgx.QueryExOptions{ - ParameterOIDs: []pgtype.OID{pgtype.Int4OID, pgtype.Int4OID}, - ResultFormatCodes: []int16{pgx.BinaryFormatCode}, - }, - 1, 2, - ).Scan(&result) - if err != nil { - t.Fatal(err) - } - if result != 3 { - t.Fatalf("result => %d, want %d", result, 3) + resultReader := conn.PgConn().ExecParams(context.Background(), "select generate_series(1,$1)", [][]byte{[]byte("10")}, nil, nil, nil) + + var sum, rowCount int32 + + for resultReader.NextRow() { + var n int32 + err := pgx.ScanRow(conn.TypeMap(), resultReader.FieldDescriptions(), resultReader.Values(), &n) + assert.NoError(t, err) + sum += n + rowCount++ } - ensureConnValid(t, conn) + _, err := resultReader.Close() + + require.NoError(t, err) + assert.EqualValues(t, 10, rowCount) + assert.EqualValues(t, 55, sum) } func TestConnSimpleProtocol(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) // Test all supported low-level types @@ -1372,10 +1606,10 @@ func TestConnSimpleProtocol(t *testing.T) { { expected := int64(42) var actual int64 - err := conn.QueryRowEx( + err := conn.QueryRow( context.Background(), "select $1::int8", - &pgx.QueryExOptions{SimpleProtocol: true}, + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1389,10 +1623,10 @@ func TestConnSimpleProtocol(t *testing.T) { { expected := float64(1.23) var actual float64 - err := conn.QueryRowEx( + err := conn.QueryRow( context.Background(), "select $1::float8", - &pgx.QueryExOptions{SimpleProtocol: true}, + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1406,10 +1640,10 @@ func TestConnSimpleProtocol(t *testing.T) { { expected := true var actual bool - err := conn.QueryRowEx( + err := conn.QueryRow( context.Background(), - "select $1", - &pgx.QueryExOptions{SimpleProtocol: true}, + "select $1::boolean", + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1423,16 +1657,16 @@ func TestConnSimpleProtocol(t *testing.T) { { expected := []byte{0, 1, 20, 35, 64, 80, 120, 3, 255, 240, 128, 95} var actual []byte - err := conn.QueryRowEx( + err := conn.QueryRow( context.Background(), "select $1::bytea", - &pgx.QueryExOptions{SimpleProtocol: true}, + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { t.Error(err) } - if bytes.Compare(actual, expected) != 0 { + if !bytes.Equal(actual, expected) { t.Errorf("expected %v got %v", expected, actual) } } @@ -1440,10 +1674,10 @@ func TestConnSimpleProtocol(t *testing.T) { { expected := "test" var actual string - err := conn.QueryRowEx( + err := conn.QueryRow( context.Background(), "select $1::text", - &pgx.QueryExOptions{SimpleProtocol: true}, + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1454,22 +1688,257 @@ func TestConnSimpleProtocol(t *testing.T) { } } - // Test high-level type + { + tests := []struct { + expected []string + }{ + {[]string(nil)}, + {[]string{}}, + {[]string{"test", "foo", "bar"}}, + {[]string{`foo'bar"\baz;quz`, `foo'bar"\baz;quz`}}, + } + for i, tt := range tests { + var actual []string + err := conn.QueryRow( + context.Background(), + "select $1::text[]", + pgx.QueryExecModeSimpleProtocol, + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } { - expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Status: pgtype.Present} - actual := expected - err := conn.QueryRowEx( - context.Background(), - "select $1::circle", - &pgx.QueryExOptions{SimpleProtocol: true}, - &expected, - ).Scan(&actual) - if err != nil { - t.Error(err) + tests := []struct { + expected []int16 + }{ + {[]int16(nil)}, + {[]int16{}}, + {[]int16{1, 2, 3}}, } - if expected != actual { - t.Errorf("expected %v got %v", expected, actual) + for i, tt := range tests { + var actual []int16 + err := conn.QueryRow( + context.Background(), + "select $1::smallint[]", + pgx.QueryExecModeSimpleProtocol, + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []int32 + }{ + {[]int32(nil)}, + {[]int32{}}, + {[]int32{1, 2, 3}}, + } + for i, tt := range tests { + var actual []int32 + err := conn.QueryRow( + context.Background(), + "select $1::int[]", + pgx.QueryExecModeSimpleProtocol, + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []int64 + }{ + {[]int64(nil)}, + {[]int64{}}, + {[]int64{1, 2, 3}}, + } + for i, tt := range tests { + var actual []int64 + err := conn.QueryRow( + context.Background(), + "select $1::bigint[]", + pgx.QueryExecModeSimpleProtocol, + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []int + }{ + {[]int(nil)}, + {[]int{}}, + {[]int{1, 2, 3}}, + } + for i, tt := range tests { + var actual []int + err := conn.QueryRow( + context.Background(), + "select $1::bigint[]", + pgx.QueryExecModeSimpleProtocol, + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []uint16 + }{ + {[]uint16(nil)}, + {[]uint16{}}, + {[]uint16{1, 2, 3}}, + } + for i, tt := range tests { + var actual []uint16 + err := conn.QueryRow( + context.Background(), + "select $1::smallint[]", + pgx.QueryExecModeSimpleProtocol, + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []uint32 + }{ + {[]uint32(nil)}, + {[]uint32{}}, + {[]uint32{1, 2, 3}}, + } + for i, tt := range tests { + var actual []uint32 + err := conn.QueryRow( + context.Background(), + "select $1::bigint[]", + pgx.QueryExecModeSimpleProtocol, + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []uint64 + }{ + {[]uint64(nil)}, + {[]uint64{}}, + {[]uint64{1, 2, 3}}, + } + for i, tt := range tests { + var actual []uint64 + err := conn.QueryRow( + context.Background(), + "select $1::bigint[]", + pgx.QueryExecModeSimpleProtocol, + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []uint + }{ + {[]uint(nil)}, + {[]uint{}}, + {[]uint{1, 2, 3}}, + } + for i, tt := range tests { + var actual []uint + err := conn.QueryRow( + context.Background(), + "select $1::bigint[]", + pgx.QueryExecModeSimpleProtocol, + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []float32 + }{ + {[]float32(nil)}, + {[]float32{}}, + {[]float32{1, 2, 3}}, + } + for i, tt := range tests { + var actual []float32 + err := conn.QueryRow( + context.Background(), + "select $1::float4[]", + pgx.QueryExecModeSimpleProtocol, + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []float64 + }{ + {[]float64(nil)}, + {[]float64{}}, + {[]float64{1, 2, 3}}, + } + for i, tt := range tests { + var actual []float64 + err := conn.QueryRow( + context.Background(), + "select $1::float8[]", + pgx.QueryExecModeSimpleProtocol, + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + // Test high-level type + + { + if conn.PgConn().ParameterStatus("crdb_version") == "" { + // CockroachDB doesn't support circle type. + expected := pgtype.Circle{P: pgtype.Vec2{X: 1, Y: 2}, R: 1.5, Valid: true} + actual := expected + err := conn.QueryRow( + context.Background(), + "select $1::circle", + pgx.QueryExecModeSimpleProtocol, + &expected, + ).Scan(&actual) + if err != nil { + t.Error(err) + } + if expected != actual { + t.Errorf("expected %v got %v", expected, actual) + } } } @@ -1486,10 +1955,10 @@ func TestConnSimpleProtocol(t *testing.T) { var actualBool bool var actualBytes []byte var actualString string - err := conn.QueryRowEx( + err := conn.QueryRow( context.Background(), - "select $1::int8, $2::float8, $3, $4::bytea, $5::text", - &pgx.QueryExOptions{SimpleProtocol: true}, + "select $1::int8, $2::float8, $3::boolean, $4::bytea, $5::text", + pgx.QueryExecModeSimpleProtocol, expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString, ).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString) if err != nil { @@ -1504,7 +1973,7 @@ func TestConnSimpleProtocol(t *testing.T) { if expectedBool != actualBool { t.Errorf("expected %v got %v", expectedBool, actualBool) } - if bytes.Compare(expectedBytes, actualBytes) != 0 { + if !bytes.Equal(expectedBytes, actualBytes) { t.Errorf("expected %v got %v", expectedBytes, actualBytes) } if expectedString != actualString { @@ -1517,10 +1986,10 @@ func TestConnSimpleProtocol(t *testing.T) { { expected := "foo';drop table users;" var actual string - err := conn.QueryRowEx( + err := conn.QueryRow( context.Background(), "select $1", - &pgx.QueryExOptions{SimpleProtocol: true}, + pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) if err != nil { @@ -1537,16 +2006,18 @@ func TestConnSimpleProtocol(t *testing.T) { func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + pgxtest.SkipCockroachDB(t, conn, "Server does not support changing client_encoding (https://www.cockroachlabs.com/docs/stable/set-vars.html)") + mustExec(t, conn, "set client_encoding to 'SQL_ASCII'") var expected string - err := conn.QueryRowEx( + err := conn.QueryRow( context.Background(), "select $1", - &pgx.QueryExOptions{SimpleProtocol: true}, + pgx.QueryExecModeSimpleProtocol, "test", ).Scan(&expected) if err == nil { @@ -1559,16 +2030,18 @@ func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) { func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + pgxtest.SkipCockroachDB(t, conn, "Server does not support standard_conforming_strings = off (https://github.com/cockroachdb/cockroach/issues/36215)") + mustExec(t, conn, "set standard_conforming_strings to off") var expected string - err := conn.QueryRowEx( + err := conn.QueryRow( context.Background(), "select $1", - &pgx.QueryExOptions{SimpleProtocol: true}, + pgx.QueryExecModeSimpleProtocol, `\'; drop table users; --`, ).Scan(&expected) if err == nil { @@ -1577,3 +2050,216 @@ func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) { ensureConnValid(t, conn) } + +// https://github.com/jackc/pgx/issues/895 +func TestQueryErrorWithDisabledStatementCache(t *testing.T) { + t.Parallel() + + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec + config.StatementCacheCapacity = 0 + config.DescriptionCacheCapacity = 0 + + conn := mustConnect(t, config) + defer closeConn(t, conn) + + _, err := conn.Exec(context.Background(), "create temporary table t_unq(id text primary key);") + require.NoError(t, err) + + _, err = conn.Exec(context.Background(), "insert into t_unq (id) values ($1)", "abc") + require.NoError(t, err) + + rows, err := conn.Query(context.Background(), "insert into t_unq (id) values ($1)", "abc") + require.NoError(t, err) + rows.Close() + err = rows.Err() + require.Error(t, err) + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + assert.Equal(t, "23505", pgErr.Code) + } else { + t.Errorf("err is not a *pgconn.PgError: %T", err) + } + + ensureConnValid(t, conn) +} + +func TestConnQueryQueryExecModeCacheDescribeSafeEvenWhenTypesChange(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + pgxtest.SkipCockroachDB(t, conn, "Server does not support alter column type from int to float4") + + _, err := conn.Exec(ctx, `create temporary table to_change ( + name text primary key, + age int +); + +insert into to_change (name, age) values ('John', 42);`) + require.NoError(t, err) + + var name string + var ageInt32 int32 + err = conn.QueryRow(ctx, "select * from to_change where age = $1", pgx.QueryExecModeCacheDescribe, int32(42)).Scan(&name, &ageInt32) + require.NoError(t, err) + require.Equal(t, "John", name) + require.Equal(t, int32(42), ageInt32) + + _, err = conn.Exec(ctx, `alter table to_change alter column age type float4;`) + require.NoError(t, err) + + err = conn.QueryRow(ctx, "select * from to_change where age = $1", pgx.QueryExecModeCacheDescribe, int32(42)).Scan(&name, &ageInt32) + require.NoError(t, err) + require.Equal(t, "John", name) + require.Equal(t, int32(42), ageInt32) + + var ageFloat32 float32 + err = conn.QueryRow(ctx, "select * from to_change where age = $1", pgx.QueryExecModeCacheDescribe, int32(42)).Scan(&name, &ageFloat32) + require.NoError(t, err) + require.Equal(t, "John", name) + require.Equal(t, float32(42), ageFloat32) + + _, err = conn.Exec(ctx, `alter table to_change drop column name;`) + require.NoError(t, err) + + // Number of result columns has changed, so just like with a prepared statement, this will fail the first time. + err = conn.QueryRow(ctx, "select * from to_change where age = $1", pgx.QueryExecModeCacheDescribe, int32(42)).Scan(&ageFloat32) + require.EqualError(t, err, "ERROR: bind message has 2 result formats but query has 1 columns (SQLSTATE 08P01)") + + // But it will work the second time after the cache is invalidated. + err = conn.QueryRow(ctx, "select * from to_change where age = $1", pgx.QueryExecModeCacheDescribe, int32(42)).Scan(&ageFloat32) + require.NoError(t, err) + require.Equal(t, float32(42), ageFloat32) + + _, err = conn.Exec(ctx, `alter table to_change alter column age type numeric;`) + require.NoError(t, err) + + err = conn.QueryRow(ctx, "select * from to_change where age = $1", pgx.QueryExecModeCacheDescribe, int32(42)).Scan(&ageFloat32) + require.NoError(t, err) + require.Equal(t, float32(42), ageFloat32) +} + +func TestQueryWithQueryRewriter(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + qr := testQueryRewriter{sql: "select $1::int", args: []any{42}} + rows, err := conn.Query(ctx, "should be replaced", &qr) + require.NoError(t, err) + + var n int32 + var rowCount int + for rows.Next() { + rowCount++ + err = rows.Scan(&n) + require.NoError(t, err) + } + + require.NoError(t, rows.Err()) + }) +} + +// https://github.com/jackc/pgx/issues/2402 +func TestQueryWithEmptyQuery(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + emptyQueryStrings := []string{"", " ", "/* ping */", "-- ping"} + for _, eq := range emptyQueryStrings { + rows, err := conn.Query(ctx, eq) + require.NoError(t, err) + require.Equal(t, []pgconn.FieldDescription(nil), rows.FieldDescriptions()) + require.False(t, rows.Next()) + require.NoError(t, rows.Err()) + } + }) +} + +// This example uses Query without using any helpers to read the results. Normally CollectRows, ForEachRow, or another +// helper function should be used. +func ExampleConn_Query() { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + // Skip test / example when running on CockroachDB. Since an example can't be skipped fake success instead. + fmt.Println(`Cheeseburger: $10 +Fries: $5 +Soft Drink: $3`) + return + } + + // Setup example schema and data. + _, err = conn.Exec(ctx, ` +create temporary table products ( + id int primary key generated by default as identity, + name varchar(100) not null, + price int not null +); + +insert into products (name, price) values + ('Cheeseburger', 10), + ('Double Cheeseburger', 14), + ('Fries', 5), + ('Soft Drink', 3); +`) + if err != nil { + fmt.Printf("Unable to setup example schema and data: %v", err) + return + } + + rows, err := conn.Query(ctx, "select name, price from products where price < $1 order by price desc", 12) + // It is unnecessary to check err. If an error occurred it will be returned by rows.Err() later. But in rare + // cases it may be useful to detect the error as early as possible. + if err != nil { + fmt.Printf("Query error: %v", err) + return + } + + // Ensure rows is closed. It is safe to close rows multiple times. + defer rows.Close() + + // Iterate through the result set + for rows.Next() { + var name string + var price int32 + + err = rows.Scan(&name, &price) + if err != nil { + fmt.Printf("Scan error: %v", err) + return + } + + fmt.Printf("%s: $%d\n", name, price) + } + + // rows is closed automatically when rows.Next() returns false so it is not necessary to manually close rows. + + // The first error encountered by the original Query call, rows.Next or rows.Scan will be returned here. + if rows.Err() != nil { + fmt.Printf("rows error: %v", rows.Err()) + return + } + + // Output: + // Cheeseburger: $10 + // Fries: $5 + // Soft Drink: $3 +} diff --git a/replication.go b/replication.go deleted file mode 100644 index 7dd5efe4b..000000000 --- a/replication.go +++ /dev/null @@ -1,459 +0,0 @@ -package pgx - -import ( - "context" - "encoding/binary" - "fmt" - "time" - - "github.com/pkg/errors" - - "github.com/jackc/pgx/pgio" - "github.com/jackc/pgx/pgproto3" -) - -const ( - copyBothResponse = 'W' - walData = 'w' - senderKeepalive = 'k' - standbyStatusUpdate = 'r' - initialReplicationResponseTimeout = 5 * time.Second -) - -var epochNano int64 - -func init() { - epochNano = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).UnixNano() -} - -// Format the given 64bit LSN value into the XXX/XXX format, -// which is the format reported by postgres. -func FormatLSN(lsn uint64) string { - return fmt.Sprintf("%X/%X", uint32(lsn>>32), uint32(lsn)) -} - -// Parse the given XXX/XXX format LSN as reported by postgres, -// into a 64 bit integer as used internally by the wire procotols -func ParseLSN(lsn string) (outputLsn uint64, err error) { - var upperHalf uint64 - var lowerHalf uint64 - var nparsed int - nparsed, err = fmt.Sscanf(lsn, "%X/%X", &upperHalf, &lowerHalf) - if err != nil { - return - } - - if nparsed != 2 { - err = errors.New(fmt.Sprintf("Failed to parsed LSN: %s", lsn)) - return - } - - outputLsn = (upperHalf << 32) + lowerHalf - return -} - -// The WAL message contains WAL payload entry data -type WalMessage struct { - // The WAL start position of this data. This - // is the WAL position we need to track. - WalStart uint64 - // The server wal end and server time are - // documented to track the end position and current - // time of the server, both of which appear to be - // unimplemented in pg 9.5. - ServerWalEnd uint64 - ServerTime uint64 - // The WAL data is the raw unparsed binary WAL entry. - // The contents of this are determined by the output - // logical encoding plugin. - WalData []byte -} - -func (w *WalMessage) Time() time.Time { - return time.Unix(0, (int64(w.ServerTime)*1000)+epochNano) -} - -func (w *WalMessage) ByteLag() uint64 { - return (w.ServerWalEnd - w.WalStart) -} - -func (w *WalMessage) String() string { - return fmt.Sprintf("Wal: %s Time: %s Lag: %d", FormatLSN(w.WalStart), w.Time(), w.ByteLag()) -} - -// The server heartbeat is sent periodically from the server, -// including server status, and a reply request field -type ServerHeartbeat struct { - // The current max wal position on the server, - // used for lag tracking - ServerWalEnd uint64 - // The server time, in microseconds since jan 1 2000 - ServerTime uint64 - // If 1, the server is requesting a standby status message - // to be sent immediately. - ReplyRequested byte -} - -func (s *ServerHeartbeat) Time() time.Time { - return time.Unix(0, (int64(s.ServerTime)*1000)+epochNano) -} - -func (s *ServerHeartbeat) String() string { - return fmt.Sprintf("WalEnd: %s ReplyRequested: %d T: %s", FormatLSN(s.ServerWalEnd), s.ReplyRequested, s.Time()) -} - -// The replication message wraps all possible messages from the -// server received during replication. At most one of the wal message -// or server heartbeat will be non-nil -type ReplicationMessage struct { - WalMessage *WalMessage - ServerHeartbeat *ServerHeartbeat -} - -// The standby status is the client side heartbeat sent to the postgresql -// server to track the client wal positions. For practical purposes, -// all wal positions are typically set to the same value. -type StandbyStatus struct { - // The WAL position that's been locally written - WalWritePosition uint64 - // The WAL position that's been locally flushed - WalFlushPosition uint64 - // The WAL position that's been locally applied - WalApplyPosition uint64 - // The client time in microseconds since jan 1 2000 - ClientTime uint64 - // If 1, requests the server to immediately send a - // server heartbeat - ReplyRequested byte -} - -// Create a standby status struct, which sets all the WAL positions -// to the given wal position, and the client time to the current time. -// The wal positions are, in order: -// WalFlushPosition -// WalApplyPosition -// WalWritePosition -// -// If only one position is provided, it will be used as the value for all 3 -// status fields. Note you must provide either 1 wal position, or all 3 -// in order to initialize the standby status. -func NewStandbyStatus(walPositions ...uint64) (status *StandbyStatus, err error) { - if len(walPositions) == 1 { - status = new(StandbyStatus) - status.WalFlushPosition = walPositions[0] - status.WalApplyPosition = walPositions[0] - status.WalWritePosition = walPositions[0] - } else if len(walPositions) == 3 { - status = new(StandbyStatus) - status.WalFlushPosition = walPositions[0] - status.WalApplyPosition = walPositions[1] - status.WalWritePosition = walPositions[2] - } else { - err = errors.New(fmt.Sprintf("Invalid number of wal positions provided, need 1 or 3, got %d", len(walPositions))) - return - } - status.ClientTime = uint64((time.Now().UnixNano() - epochNano) / 1000) - return -} - -func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) { - if config.RuntimeParams == nil { - config.RuntimeParams = make(map[string]string) - } - config.RuntimeParams["replication"] = "database" - - c, err := Connect(config) - if err != nil { - return - } - return &ReplicationConn{c: c}, nil -} - -type ReplicationConn struct { - c *Conn -} - -// Send standby status to the server, which both acts as a keepalive -// message to the server, as well as carries the WAL position of the -// client, which then updates the server's replication slot position. -func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) { - buf := rc.c.wbuf - buf = append(buf, copyData) - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf = append(buf, standbyStatusUpdate) - buf = pgio.AppendInt64(buf, int64(k.WalWritePosition)) - buf = pgio.AppendInt64(buf, int64(k.WalFlushPosition)) - buf = pgio.AppendInt64(buf, int64(k.WalApplyPosition)) - buf = pgio.AppendInt64(buf, int64(k.ClientTime)) - buf = append(buf, k.ReplyRequested) - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - _, err = rc.c.conn.Write(buf) - if err != nil { - rc.c.die(err) - } - - return -} - -func (rc *ReplicationConn) Close() error { - return rc.c.Close() -} - -func (rc *ReplicationConn) IsAlive() bool { - return rc.c.IsAlive() -} - -func (rc *ReplicationConn) CauseOfDeath() error { - return rc.c.CauseOfDeath() -} - -func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) { - msg, err := rc.c.rxMsg() - if err != nil { - return - } - - switch msg := msg.(type) { - case *pgproto3.NoticeResponse: - pgError := rc.c.rxErrorResponse((*pgproto3.ErrorResponse)(msg)) - if rc.c.shouldLog(LogLevelInfo) { - rc.c.log(LogLevelInfo, pgError.Error(), nil) - } - case *pgproto3.ErrorResponse: - err = rc.c.rxErrorResponse(msg) - if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, err.Error(), nil) - } - return - case *pgproto3.CopyBothResponse: - // This is the tail end of the replication process start, - // and can be safely ignored - return - case *pgproto3.CopyData: - msgType := msg.Data[0] - rp := 1 - - switch msgType { - case walData: - walStart := binary.BigEndian.Uint64(msg.Data[rp:]) - rp += 8 - serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:]) - rp += 8 - serverTime := binary.BigEndian.Uint64(msg.Data[rp:]) - rp += 8 - walData := msg.Data[rp:] - walMessage := WalMessage{WalStart: walStart, - ServerWalEnd: serverWalEnd, - ServerTime: serverTime, - WalData: walData, - } - - return &ReplicationMessage{WalMessage: &walMessage}, nil - case senderKeepalive: - serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:]) - rp += 8 - serverTime := binary.BigEndian.Uint64(msg.Data[rp:]) - rp += 8 - replyNow := msg.Data[rp] - rp += 1 - h := &ServerHeartbeat{ServerWalEnd: serverWalEnd, ServerTime: serverTime, ReplyRequested: replyNow} - return &ReplicationMessage{ServerHeartbeat: h}, nil - default: - if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, "Unexpected data playload message type", map[string]interface{}{"type": msgType}) - } - } - default: - if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, "Unexpected replication message type", map[string]interface{}{"type": msg}) - } - } - return -} - -// Wait for a single replication message. -// -// Properly using this requires some knowledge of the postgres replication mechanisms, -// as the client can receive both WAL data (the ultimate payload) and server heartbeat -// updates. The caller also must send standby status updates in order to keep the connection -// alive and working. -// -// This returns the context error when there is no replication message before -// the context is canceled. -func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*ReplicationMessage, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - go func() { - select { - case <-ctx.Done(): - if err := rc.c.conn.SetDeadline(time.Now()); err != nil { - rc.Close() // Close connection if unable to set deadline - return - } - rc.c.closedChan <- ctx.Err() - case <-rc.c.doneChan: - } - }() - - r, opErr := rc.readReplicationMessage() - - var err error - select { - case err = <-rc.c.closedChan: - if err := rc.c.conn.SetDeadline(time.Time{}); err != nil { - rc.Close() // Close connection if unable to disable deadline - return nil, err - } - - if opErr == nil { - err = nil - } - case rc.c.doneChan <- struct{}{}: - err = opErr - } - - return r, err -} - -func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { - rc.c.lastActivityTime = time.Now() - - rows := rc.c.getRows(sql, nil) - - if err := rc.c.lock(); err != nil { - rows.fatal(err) - return rows, err - } - rows.unlockConn = true - - err := rc.c.sendSimpleQuery(sql) - if err != nil { - rows.fatal(err) - } - - msg, err := rc.c.rxMsg() - if err != nil { - return nil, err - } - - switch msg := msg.(type) { - case *pgproto3.RowDescription: - rows.fields = rc.c.rxRowDescription(msg) - // We don't have c.PgTypes here because we're a replication - // connection. This means the field descriptions will have - // only OIDs. Not much we can do about this. - default: - if e := rc.c.processContextFreeMsg(msg); e != nil { - rows.fatal(e) - return rows, e - } - } - - return rows, rows.err -} - -// Execute the "IDENTIFY_SYSTEM" command as documented here: -// https://www.postgresql.org/docs/9.5/static/protocol-replication.html -// -// This will return (if successful) a result set that has a single row -// that contains the systemid, current timeline, xlogpos and database -// name. -// -// NOTE: Because this is a replication mode connection, we don't have -// type names, so the field descriptions in the result will have only -// OIDs and no DataTypeName values -func (rc *ReplicationConn) IdentifySystem() (r *Rows, err error) { - return rc.sendReplicationModeQuery("IDENTIFY_SYSTEM") -} - -// Execute the "TIMELINE_HISTORY" command as documented here: -// https://www.postgresql.org/docs/9.5/static/protocol-replication.html -// -// This will return (if successful) a result set that has a single row -// that contains the filename of the history file and the content -// of the history file. If called for timeline 1, typically this will -// generate an error that the timeline history file does not exist. -// -// NOTE: Because this is a replication mode connection, we don't have -// type names, so the field descriptions in the result will have only -// OIDs and no DataTypeName values -func (rc *ReplicationConn) TimelineHistory(timeline int) (r *Rows, err error) { - return rc.sendReplicationModeQuery(fmt.Sprintf("TIMELINE_HISTORY %d", timeline)) -} - -// Start a replication connection, sending WAL data to the given replication -// receiver. This function wraps a START_REPLICATION command as documented -// here: -// https://www.postgresql.org/docs/9.5/static/protocol-replication.html -// -// Once started, the client needs to invoke WaitForReplicationMessage() in order -// to fetch the WAL and standby status. Also, it is the responsibility of the caller -// to periodically send StandbyStatus messages to update the replication slot position. -// -// This function assumes that slotName has already been created. In order to omit the timeline argument -// pass a -1 for the timeline to get the server default behavior. -func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, timeline int64, pluginArguments ...string) (err error) { - var queryString string - if timeline >= 0 { - queryString = fmt.Sprintf("START_REPLICATION SLOT %s LOGICAL %s TIMELINE %d", slotName, FormatLSN(startLsn), timeline) - } else { - queryString = fmt.Sprintf("START_REPLICATION SLOT %s LOGICAL %s", slotName, FormatLSN(startLsn)) - } - - for _, arg := range pluginArguments { - queryString += fmt.Sprintf(" %s", arg) - } - - if err = rc.c.sendQuery(queryString); err != nil { - return - } - - ctx, cancelFn := context.WithTimeout(context.Background(), initialReplicationResponseTimeout) - defer cancelFn() - - // The first replication message that comes back here will be (in a success case) - // a empty CopyBoth that is (apparently) sent as the confirmation that the replication has - // started. This call will either return nil, nil or if it returns an error - // that indicates the start replication command failed - var r *ReplicationMessage - r, err = rc.WaitForReplicationMessage(ctx) - if err != nil && r != nil { - if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, "Unexpected replication message", map[string]interface{}{"msg": r, "err": err}) - } - } - - return -} - -// Create the replication slot, using the given name and output plugin. -func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) (err error) { - _, err = rc.c.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s", slotName, outputPlugin)) - return -} - -// Create the replication slot, using the given name and output plugin, and return the consistent_point and snapshot_name values. -func (rc *ReplicationConn) CreateReplicationSlotEx(slotName, outputPlugin string) (consistentPoint string, snapshotName string, err error) { - var dummy string - var rows *Rows - rows, err = rc.sendReplicationModeQuery(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s", slotName, outputPlugin)) - defer rows.Close() - for rows.Next() { - rows.Scan(&dummy, &consistentPoint, &snapshotName, &dummy) - } - return -} - -// Drop the replication slot for the given name -func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) { - _, err = rc.c.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName)) - return -} diff --git a/replication_test.go b/replication_test.go deleted file mode 100644 index d06d73cd9..000000000 --- a/replication_test.go +++ /dev/null @@ -1,345 +0,0 @@ -package pgx_test - -import ( - "context" - "fmt" - "reflect" - "strconv" - "strings" - "testing" - "time" - - "github.com/jackc/pgx" -) - -// This function uses a postgresql 9.6 specific column -func getConfirmedFlushLsnFor(t *testing.T, conn *pgx.Conn, slot string) string { - // Fetch the restart LSN of the slot, to establish a starting point - rows, err := conn.Query(fmt.Sprintf("select confirmed_flush_lsn from pg_replication_slots where slot_name='%s'", slot)) - if err != nil { - t.Fatalf("conn.Query failed: %v", err) - } - defer rows.Close() - - var restartLsn string - for rows.Next() { - rows.Scan(&restartLsn) - } - return restartLsn -} - -// This battleship test (at least somewhat by necessity) does -// several things all at once in a single run. It: -// - Establishes a replication connection & slot -// - Does a series of operations to create some known WAL entries -// - Replicates the entries down, and checks that the rows it -// created come down in order -// - Sends a standby status message to update the server with the -// wal position of the slot -// - Checks the wal position of the slot on the server to make sure -// the update succeeded -func TestSimpleReplicationConnection(t *testing.T) { - var err error - - if replicationConnConfig == nil { - t.Skip("Skipping due to undefined replicationConnConfig") - } - - conn := mustConnect(t, *replicationConnConfig) - defer func() { - // Ensure replication slot is destroyed, but don't check for errors as it - // should have already been destroyed. - conn.Exec("select pg_drop_replication_slot('pgx_test')") - closeConn(t, conn) - }() - - replicationConn := mustReplicationConnect(t, *replicationConnConfig) - defer closeReplicationConn(t, replicationConn) - - var cp string - var snapshot_name string - cp, snapshot_name, err = replicationConn.CreateReplicationSlotEx("pgx_test", "test_decoding") - if err != nil { - t.Fatalf("replication slot create failed: %v", err) - } - if cp == "" { - t.Logf("consistent_point is empty") - } - if snapshot_name == "" { - t.Logf("snapshot_name is empty") - } - - // Do a simple change so we can get some wal data - _, err = conn.Exec("create table if not exists replication_test (a integer)") - if err != nil { - t.Fatalf("Failed to create table: %v", err) - } - - err = replicationConn.StartReplication("pgx_test", 0, -1) - if err != nil { - t.Fatalf("Failed to start replication: %v", err) - } - - var insertedTimes []int64 - currentTime := time.Now().Unix() - - for i := 0; i < 5; i++ { - var ct pgx.CommandTag - insertedTimes = append(insertedTimes, currentTime) - ct, err = conn.Exec("insert into replication_test(a) values($1)", currentTime) - if err != nil { - t.Fatalf("Insert failed: %v", err) - } - t.Logf("Inserted %d rows", ct.RowsAffected()) - currentTime++ - } - - var foundTimes []int64 - var foundCount int - var maxWal uint64 - - ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) - defer cancelFn() - - for { - var message *pgx.ReplicationMessage - - message, err = replicationConn.WaitForReplicationMessage(ctx) - if err != nil { - t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err)) - } - - if message.WalMessage != nil { - // The waldata payload with the test_decoding plugin looks like: - // public.replication_test: INSERT: a[integer]:2 - // What we wanna do here is check that once we find one of our inserted times, - // that they occur in the wal stream in the order we executed them. - walString := string(message.WalMessage.WalData) - if strings.Contains(walString, "public.replication_test: INSERT") { - stringParts := strings.Split(walString, ":") - offset, err := strconv.ParseInt(stringParts[len(stringParts)-1], 10, 64) - if err != nil { - t.Fatalf("Failed to parse walString %s", walString) - } - if foundCount > 0 || offset == insertedTimes[0] { - foundTimes = append(foundTimes, offset) - foundCount++ - } - if foundCount == len(insertedTimes) { - break - } - } - if message.WalMessage.WalStart > maxWal { - maxWal = message.WalMessage.WalStart - } - - } - if message.ServerHeartbeat != nil { - t.Logf("Got heartbeat: %s", message.ServerHeartbeat) - } - } - - for i := range insertedTimes { - if foundTimes[i] != insertedTimes[i] { - t.Fatalf("Found %d expected %d", foundTimes[i], insertedTimes[i]) - } - } - - t.Logf("Found %d times, as expected", len(foundTimes)) - - // Before closing our connection, let's send a standby status to update our wal - // position, which should then be reflected if we fetch out our current wal position - // for the slot - status, err := pgx.NewStandbyStatus(maxWal) - if err != nil { - t.Errorf("Failed to create standby status %v", err) - } - replicationConn.SendStandbyStatus(status) - - restartLsn := getConfirmedFlushLsnFor(t, conn, "pgx_test") - integerRestartLsn, _ := pgx.ParseLSN(restartLsn) - if integerRestartLsn != maxWal { - t.Fatalf("Wal offset update failed, expected %s found %s", pgx.FormatLSN(maxWal), restartLsn) - } - - closeReplicationConn(t, replicationConn) - - replicationConn2 := mustReplicationConnect(t, *replicationConnConfig) - defer closeReplicationConn(t, replicationConn2) - - err = replicationConn2.DropReplicationSlot("pgx_test") - if err != nil { - t.Fatalf("Failed to drop replication slot: %v", err) - } - - droppedLsn := getConfirmedFlushLsnFor(t, conn, "pgx_test") - if droppedLsn != "" { - t.Errorf("Got odd flush lsn %s for supposedly dropped slot", droppedLsn) - } -} - -func TestReplicationConn_DropReplicationSlot(t *testing.T) { - if replicationConnConfig == nil { - t.Skip("Skipping due to undefined replicationConnConfig") - } - - replicationConn := mustReplicationConnect(t, *replicationConnConfig) - defer closeReplicationConn(t, replicationConn) - - var cp string - var snapshot_name string - cp, snapshot_name, err := replicationConn.CreateReplicationSlotEx("pgx_slot_test", "test_decoding") - if err != nil { - t.Logf("replication slot create failed: %v", err) - } - if cp == "" { - t.Logf("consistent_point is empty") - } - if snapshot_name == "" { - t.Logf("snapshot_name is empty") - } - - err = replicationConn.DropReplicationSlot("pgx_slot_test") - if err != nil { - t.Fatalf("Failed to drop replication slot: %v", err) - } - - // We re-create to ensure the drop worked. - cp, snapshot_name, err = replicationConn.CreateReplicationSlotEx("pgx_slot_test", "test_decoding") - if err != nil { - t.Logf("replication slot create failed: %v", err) - } - if cp == "" { - t.Logf("consistent_point is empty") - } - if snapshot_name == "" { - t.Logf("snapshot_name is empty") - } - - // And finally we drop to ensure we don't leave dirty state - err = replicationConn.DropReplicationSlot("pgx_slot_test") - if err != nil { - t.Fatalf("Failed to drop replication slot: %v", err) - } -} - -func TestIdentifySystem(t *testing.T) { - if replicationConnConfig == nil { - t.Skip("Skipping due to undefined replicationConnConfig") - } - - replicationConn2 := mustReplicationConnect(t, *replicationConnConfig) - defer closeReplicationConn(t, replicationConn2) - - r, err := replicationConn2.IdentifySystem() - if err != nil { - t.Error(err) - } - defer r.Close() - for _, fd := range r.FieldDescriptions() { - t.Logf("Field: %s of type %v", fd.Name, fd.DataType) - } - - var rowCount int - for r.Next() { - rowCount++ - values, err := r.Values() - if err != nil { - t.Error(err) - } - t.Logf("Row values: %v", values) - } - if r.Err() != nil { - t.Error(r.Err()) - } - - if rowCount == 0 { - t.Errorf("Failed to find any rows: %d", rowCount) - } -} - -func getCurrentTimeline(t *testing.T, rc *pgx.ReplicationConn) int { - r, err := rc.IdentifySystem() - if err != nil { - t.Error(err) - } - defer r.Close() - for r.Next() { - values, e := r.Values() - if e != nil { - t.Error(e) - } - return int(values[1].(int32)) - } - t.Fatal("Failed to read timeline") - return -1 -} - -func TestGetTimelineHistory(t *testing.T) { - if replicationConnConfig == nil { - t.Skip("Skipping due to undefined replicationConnConfig") - } - - replicationConn := mustReplicationConnect(t, *replicationConnConfig) - defer closeReplicationConn(t, replicationConn) - - timeline := getCurrentTimeline(t, replicationConn) - - r, err := replicationConn.TimelineHistory(timeline) - if err != nil { - t.Errorf("%#v", err) - } - defer r.Close() - - for _, fd := range r.FieldDescriptions() { - t.Logf("Field: %s of type %v", fd.Name, fd.DataType) - } - - var rowCount int - for r.Next() { - rowCount++ - values, err := r.Values() - if err != nil { - t.Error(err) - } - t.Logf("Row values: %v", values) - } - if r.Err() != nil { - if strings.Contains(r.Err().Error(), "No such file or directory") { - // This is normal, this means the timeline we're on has no - // history, which is the common case in a test db that - // has only one timeline - return - } - t.Error(r.Err()) - } - - // If we have a timeline history (see above) there should have been - // rows emitted - if rowCount == 0 { - t.Errorf("Failed to find any rows: %d", rowCount) - } -} - -func TestStandbyStatusParsing(t *testing.T) { - // Let's push the boundary conditions of the standby status and ensure it errors correctly - status, err := pgx.NewStandbyStatus(0, 1, 2, 3, 4) - if err == nil { - t.Errorf("Expected error from new standby status, got %v", status) - } - - // And if you provide 3 args, ensure the right fields are set - status, err = pgx.NewStandbyStatus(1, 2, 3) - if err != nil { - t.Errorf("Failed to create test status: %v", err) - } - if status.WalFlushPosition != 1 { - t.Errorf("Unexpected flush position %d", status.WalFlushPosition) - } - if status.WalApplyPosition != 2 { - t.Errorf("Unexpected apply position %d", status.WalApplyPosition) - } - if status.WalWritePosition != 3 { - t.Errorf("Unexpected write position %d", status.WalWritePosition) - } -} diff --git a/rows.go b/rows.go new file mode 100644 index 000000000..ac02ba9a0 --- /dev/null +++ b/rows.go @@ -0,0 +1,871 @@ +package pgx + +import ( + "context" + "errors" + "fmt" + "reflect" + "strings" + "sync" + "time" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" +) + +// Rows is the result set returned from *Conn.Query. Rows must be closed before +// the *Conn can be used again. Rows are closed by explicitly calling Close(), +// calling Next() until it returns false, or when a fatal error occurs. +// +// Once a Rows is closed the only methods that may be called are Close(), Err(), +// and CommandTag(). +// +// Rows is an interface instead of a struct to allow tests to mock Query. However, +// adding a method to an interface is technically a breaking change. Because of this +// the Rows interface is partially excluded from semantic version requirements. +// Methods will not be removed or changed, but new methods may be added. +type Rows interface { + // Close closes the rows, making the connection ready for use again. It is safe + // to call Close after rows is already closed. + Close() + + // Err returns any error that occurred while executing a query or reading its results. Err must be called after the + // Rows is closed (either by calling Close or by Next returning false) to check if the query was successful. If it is + // called before the Rows is closed it may return nil even if the query failed on the server. + Err() error + + // CommandTag returns the command tag from this query. It is only available after Rows is closed. + CommandTag() pgconn.CommandTag + + // FieldDescriptions returns the field descriptions of the columns. It may return nil. In particular this can occur + // when there was an error executing the query. + FieldDescriptions() []pgconn.FieldDescription + + // Next prepares the next row for reading. It returns true if there is another row and false if no more rows are + // available or a fatal error has occurred. It automatically closes rows upon returning false (whether due to all rows + // having been read or due to an error). + // + // Callers should check rows.Err() after rows.Next() returns false to detect whether result-set reading ended + // prematurely due to an error. See Conn.Query for details. + // + // For simpler error handling, consider using the higher-level pgx v5 CollectRows() and ForEachRow() helpers instead. + Next() bool + + // Scan reads the values from the current row into dest values positionally. dest can include pointers to core types, + // values implementing the Scanner interface, and nil. nil will skip the value entirely. It is an error to call Scan + // without first calling Next() and checking that it returned true. Rows is automatically closed upon error. + Scan(dest ...any) error + + // Values returns the decoded row values. As with Scan(), it is an error to + // call Values without first calling Next() and checking that it returned + // true. + Values() ([]any, error) + + // RawValues returns the unparsed bytes of the row values. The returned data is only valid until the next Next + // call or the Rows is closed. + RawValues() [][]byte + + // Conn returns the underlying *Conn on which the query was executed. This may return nil if Rows did not come from a + // *Conn (e.g. if it was created by RowsFromResultReader) + Conn() *Conn +} + +// Row is a convenience wrapper over Rows that is returned by QueryRow. +// +// Row is an interface instead of a struct to allow tests to mock QueryRow. However, +// adding a method to an interface is technically a breaking change. Because of this +// the Row interface is partially excluded from semantic version requirements. +// Methods will not be removed or changed, but new methods may be added. +type Row interface { + // Scan works the same as Rows. with the following exceptions. If no + // rows were found it returns ErrNoRows. If multiple rows are returned it + // ignores all but the first. + Scan(dest ...any) error +} + +// RowScanner scans an entire row at a time into the RowScanner. +type RowScanner interface { + // ScanRows scans the row. + ScanRow(rows Rows) error +} + +// connRow implements the Row interface for Conn.QueryRow. +type connRow baseRows + +func (r *connRow) Scan(dest ...any) (err error) { + rows := (*baseRows)(r) + + if rows.Err() != nil { + return rows.Err() + } + + for _, d := range dest { + if _, ok := d.(*pgtype.DriverBytes); ok { + rows.Close() + return fmt.Errorf("cannot scan into *pgtype.DriverBytes from QueryRow") + } + } + + if !rows.Next() { + if rows.Err() == nil { + return ErrNoRows + } + return rows.Err() + } + + rows.Scan(dest...) + rows.Close() + return rows.Err() +} + +// baseRows implements the Rows interface for Conn.Query. +type baseRows struct { + typeMap *pgtype.Map + resultReader *pgconn.ResultReader + + values [][]byte + + commandTag pgconn.CommandTag + err error + closed bool + + scanPlans []pgtype.ScanPlan + scanTypes []reflect.Type + + conn *Conn + multiResultReader *pgconn.MultiResultReader + + queryTracer QueryTracer + batchTracer BatchTracer + ctx context.Context + startTime time.Time + sql string + args []any + rowCount int +} + +func (rows *baseRows) FieldDescriptions() []pgconn.FieldDescription { + return rows.resultReader.FieldDescriptions() +} + +func (rows *baseRows) Close() { + if rows.closed { + return + } + + rows.closed = true + + if rows.resultReader != nil { + var closeErr error + rows.commandTag, closeErr = rows.resultReader.Close() + if rows.err == nil { + rows.err = closeErr + } + } + + if rows.multiResultReader != nil { + closeErr := rows.multiResultReader.Close() + if rows.err == nil { + rows.err = closeErr + } + } + + if rows.err != nil && rows.conn != nil && rows.sql != "" { + if sc := rows.conn.statementCache; sc != nil { + sc.Invalidate(rows.sql) + } + + if sc := rows.conn.descriptionCache; sc != nil { + sc.Invalidate(rows.sql) + } + } + + if rows.batchTracer != nil { + rows.batchTracer.TraceBatchQuery(rows.ctx, rows.conn, TraceBatchQueryData{SQL: rows.sql, Args: rows.args, CommandTag: rows.commandTag, Err: rows.err}) + } else if rows.queryTracer != nil { + rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err}) + } + + // Zero references to other memory allocations. This allows them to be GC'd even when the Rows still referenced. In + // particular, when using pgxpool GC could be delayed as pgxpool.poolRows are allocated in large slices. + // + // https://github.com/jackc/pgx/pull/2269 + rows.values = nil + rows.scanPlans = nil + rows.scanTypes = nil + rows.ctx = nil + rows.sql = "" + rows.args = nil +} + +func (rows *baseRows) CommandTag() pgconn.CommandTag { + return rows.commandTag +} + +func (rows *baseRows) Err() error { + return rows.err +} + +// fatal signals an error occurred after the query was sent to the server. It +// closes the rows automatically. +func (rows *baseRows) fatal(err error) { + if rows.err != nil { + return + } + + rows.err = err + rows.Close() +} + +func (rows *baseRows) Next() bool { + if rows.closed { + return false + } + + if rows.resultReader.NextRow() { + rows.rowCount++ + rows.values = rows.resultReader.Values() + return true + } else { + rows.Close() + return false + } +} + +func (rows *baseRows) Scan(dest ...any) error { + m := rows.typeMap + fieldDescriptions := rows.FieldDescriptions() + values := rows.values + + if len(fieldDescriptions) != len(values) { + err := fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) + rows.fatal(err) + return err + } + + if len(dest) == 1 { + if rc, ok := dest[0].(RowScanner); ok { + err := rc.ScanRow(rows) + if err != nil { + rows.fatal(err) + } + return err + } + } + + if len(fieldDescriptions) != len(dest) { + err := fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) + rows.fatal(err) + return err + } + + if rows.scanPlans == nil { + rows.scanPlans = make([]pgtype.ScanPlan, len(values)) + rows.scanTypes = make([]reflect.Type, len(values)) + for i := range dest { + rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) + rows.scanTypes[i] = reflect.TypeOf(dest[i]) + } + } + + for i, dst := range dest { + if dst == nil { + continue + } + + if rows.scanTypes[i] != reflect.TypeOf(dst) { + rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) + rows.scanTypes[i] = reflect.TypeOf(dest[i]) + } + + err := rows.scanPlans[i].Scan(values[i], dst) + if err != nil { + err = ScanArgError{ColumnIndex: i, FieldName: fieldDescriptions[i].Name, Err: err} + rows.fatal(err) + return err + } + } + + return nil +} + +func (rows *baseRows) Values() ([]any, error) { + if rows.closed { + return nil, errors.New("rows is closed") + } + + values := make([]any, 0, len(rows.FieldDescriptions())) + + for i := range rows.FieldDescriptions() { + buf := rows.values[i] + fd := &rows.FieldDescriptions()[i] + + if buf == nil { + values = append(values, nil) + continue + } + + if dt, ok := rows.typeMap.TypeForOID(fd.DataTypeOID); ok { + value, err := dt.Codec.DecodeValue(rows.typeMap, fd.DataTypeOID, fd.Format, buf) + if err != nil { + rows.fatal(err) + } + values = append(values, value) + } else { + switch fd.Format { + case TextFormatCode: + values = append(values, string(buf)) + case BinaryFormatCode: + newBuf := make([]byte, len(buf)) + copy(newBuf, buf) + values = append(values, newBuf) + default: + rows.fatal(errors.New("unknown format code")) + } + } + + if rows.Err() != nil { + return nil, rows.Err() + } + } + + return values, rows.Err() +} + +func (rows *baseRows) RawValues() [][]byte { + return rows.values +} + +func (rows *baseRows) Conn() *Conn { + return rows.conn +} + +type ScanArgError struct { + ColumnIndex int + FieldName string + Err error +} + +func (e ScanArgError) Error() string { + if e.FieldName == "?column?" { // Don't include the fieldname if it's unknown + return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err) + } + + return fmt.Sprintf("can't scan into dest[%d] (col: %s): %v", e.ColumnIndex, e.FieldName, e.Err) +} + +func (e ScanArgError) Unwrap() error { + return e.Err +} + +// ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level pgconn interface. +// +// typeMap - OID to Go type mapping. +// fieldDescriptions - OID and format of values +// values - the raw data as returned from the PostgreSQL server +// dest - the destination that values will be decoded into +func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, values [][]byte, dest ...any) error { + if len(fieldDescriptions) != len(values) { + return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) + } + if len(fieldDescriptions) != len(dest) { + return fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) + } + + for i, d := range dest { + if d == nil { + continue + } + + err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d) + if err != nil { + return ScanArgError{ColumnIndex: i, FieldName: fieldDescriptions[i].Name, Err: err} + } + } + + return nil +} + +// RowsFromResultReader returns a Rows that will read from values resultReader and decode with typeMap. It can be used +// to read from the lower level pgconn interface. +func RowsFromResultReader(typeMap *pgtype.Map, resultReader *pgconn.ResultReader) Rows { + return &baseRows{ + typeMap: typeMap, + resultReader: resultReader, + } +} + +// ForEachRow iterates through rows. For each row it scans into the elements of scans and calls fn. If any row +// fails to scan or fn returns an error the query will be aborted and the error will be returned. Rows will be closed +// when ForEachRow returns. +func ForEachRow(rows Rows, scans []any, fn func() error) (pgconn.CommandTag, error) { + defer rows.Close() + + for rows.Next() { + err := rows.Scan(scans...) + if err != nil { + return pgconn.CommandTag{}, err + } + + err = fn() + if err != nil { + return pgconn.CommandTag{}, err + } + } + + if err := rows.Err(); err != nil { + return pgconn.CommandTag{}, err + } + + return rows.CommandTag(), nil +} + +// CollectableRow is the subset of Rows methods that a RowToFunc is allowed to call. +type CollectableRow interface { + FieldDescriptions() []pgconn.FieldDescription + Scan(dest ...any) error + Values() ([]any, error) + RawValues() [][]byte +} + +// RowToFunc is a function that scans or otherwise converts row to a T. +type RowToFunc[T any] func(row CollectableRow) (T, error) + +// AppendRows iterates through rows, calling fn for each row, and appending the results into a slice of T. +// +// This function closes the rows automatically on return. +func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) { + defer rows.Close() + + for rows.Next() { + value, err := fn(rows) + if err != nil { + return nil, err + } + slice = append(slice, value) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return slice, nil +} + +// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T. +// +// This function closes the rows automatically on return. +func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { + return AppendRows([]T{}, rows, fn) +} + +// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true. +// CollectOneRow is to CollectRows as QueryRow is to Query. +// +// This function closes the rows automatically on return. +func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { + defer rows.Close() + + var value T + var err error + + if !rows.Next() { + if err = rows.Err(); err != nil { + return value, err + } + return value, ErrNoRows + } + + value, err = fn(rows) + if err != nil { + return value, err + } + + // The defer rows.Close() won't have executed yet. If the query returned more than one row, rows would still be open. + // rows.Close() must be called before rows.Err() so we explicitly call it here. + rows.Close() + return value, rows.Err() +} + +// CollectExactlyOneRow calls fn for the first row in rows and returns the result. +// - If no rows are found returns an error where errors.Is(ErrNoRows) is true. +// - If more than 1 row is found returns an error where errors.Is(ErrTooManyRows) is true. +// +// This function closes the rows automatically on return. +func CollectExactlyOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { + defer rows.Close() + + var ( + err error + value T + ) + + if !rows.Next() { + if err = rows.Err(); err != nil { + return value, err + } + + return value, ErrNoRows + } + + value, err = fn(rows) + if err != nil { + return value, err + } + + if rows.Next() { + var zero T + + return zero, ErrTooManyRows + } + + return value, rows.Err() +} + +// RowTo returns a T scanned from row. +func RowTo[T any](row CollectableRow) (T, error) { + var value T + err := row.Scan(&value) + return value, err +} + +// RowTo returns a the address of a T scanned from row. +func RowToAddrOf[T any](row CollectableRow) (*T, error) { + var value T + err := row.Scan(&value) + return &value, err +} + +// RowToMap returns a map scanned from row. +func RowToMap(row CollectableRow) (map[string]any, error) { + var value map[string]any + err := row.Scan((*mapRowScanner)(&value)) + return value, err +} + +type mapRowScanner map[string]any + +func (rs *mapRowScanner) ScanRow(rows Rows) error { + values, err := rows.Values() + if err != nil { + return err + } + + *rs = make(mapRowScanner, len(values)) + + for i := range values { + (*rs)[string(rows.FieldDescriptions()[i].Name)] = values[i] + } + + return nil +} + +// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number of public fields as row +// has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then the field will be +// ignored. +func RowToStructByPos[T any](row CollectableRow) (T, error) { + var value T + err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row) + return value, err +} + +// RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a +// public fields as row has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then +// the field will be ignored. +func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) { + var value T + err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row) + return &value, err +} + +type positionalStructRowScanner struct { + ptrToStruct any +} + +func (rs *positionalStructRowScanner) ScanRow(rows CollectableRow) error { + typ := reflect.TypeOf(rs.ptrToStruct).Elem() + fields := lookupStructFields(typ) + if len(rows.RawValues()) > len(fields) { + return fmt.Errorf( + "got %d values, but dst struct has only %d fields", + len(rows.RawValues()), + len(fields), + ) + } + scanTargets := setupStructScanTargets(rs.ptrToStruct, fields) + return rows.Scan(scanTargets...) +} + +// Map from reflect.Type -> []structRowField +var positionalStructFieldMap sync.Map + +func lookupStructFields(t reflect.Type) []structRowField { + if cached, ok := positionalStructFieldMap.Load(t); ok { + return cached.([]structRowField) + } + + fieldStack := make([]int, 0, 1) + fields := computeStructFields(t, make([]structRowField, 0, t.NumField()), &fieldStack) + fieldsIface, _ := positionalStructFieldMap.LoadOrStore(t, fields) + return fieldsIface.([]structRowField) +} + +func computeStructFields( + t reflect.Type, + fields []structRowField, + fieldStack *[]int, +) []structRowField { + tail := len(*fieldStack) + *fieldStack = append(*fieldStack, 0) + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + (*fieldStack)[tail] = i + // Handle anonymous struct embedding, but do not try to handle embedded pointers. + if sf.Anonymous && sf.Type.Kind() == reflect.Struct { + fields = computeStructFields(sf.Type, fields, fieldStack) + } else if sf.PkgPath == "" { + dbTag, _ := sf.Tag.Lookup(structTagKey) + if dbTag == "-" { + // Field is ignored, skip it. + continue + } + fields = append(fields, structRowField{ + path: append([]int(nil), *fieldStack...), + }) + } + } + *fieldStack = (*fieldStack)[:tail] + return fields +} + +// RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public +// fields as row has fields. The row and T fields will be matched by name. The match is case-insensitive. The database +// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. +func RowToStructByName[T any](row CollectableRow) (T, error) { + var value T + err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row) + return value, err +} + +// RowToAddrOfStructByName returns the address of a T scanned from row. T must be a struct. T must have the same number +// of named public fields as row has fields. The row and T fields will be matched by name. The match is +// case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" +// then the field will be ignored. +func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) { + var value T + err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row) + return &value, err +} + +// RowToStructByNameLax returns a T scanned from row. T must be a struct. T must have greater than or equal number of named public +// fields as row has fields. The row and T fields will be matched by name. The match is case-insensitive. The database +// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. +func RowToStructByNameLax[T any](row CollectableRow) (T, error) { + var value T + err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row) + return value, err +} + +// RowToAddrOfStructByNameLax returns the address of a T scanned from row. T must be a struct. T must have greater than or +// equal number of named public fields as row has fields. The row and T fields will be matched by name. The match is +// case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" +// then the field will be ignored. +func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) { + var value T + err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row) + return &value, err +} + +type namedStructRowScanner struct { + ptrToStruct any + lax bool +} + +func (rs *namedStructRowScanner) ScanRow(rows CollectableRow) error { + typ := reflect.TypeOf(rs.ptrToStruct).Elem() + fldDescs := rows.FieldDescriptions() + namedStructFields, err := lookupNamedStructFields(typ, fldDescs) + if err != nil { + return err + } + if !rs.lax && namedStructFields.missingField != "" { + return fmt.Errorf("cannot find field %s in returned row", namedStructFields.missingField) + } + fields := namedStructFields.fields + scanTargets := setupStructScanTargets(rs.ptrToStruct, fields) + return rows.Scan(scanTargets...) +} + +// Map from namedStructFieldMap -> *namedStructFields +var namedStructFieldMap sync.Map + +type namedStructFieldsKey struct { + t reflect.Type + colNames string +} + +type namedStructFields struct { + fields []structRowField + // missingField is the first field from the struct without a corresponding row field. + // This is used to construct the correct error message for non-lax queries. + missingField string +} + +func lookupNamedStructFields( + t reflect.Type, + fldDescs []pgconn.FieldDescription, +) (*namedStructFields, error) { + key := namedStructFieldsKey{ + t: t, + colNames: joinFieldNames(fldDescs), + } + if cached, ok := namedStructFieldMap.Load(key); ok { + return cached.(*namedStructFields), nil + } + + // We could probably do two-levels of caching, where we compute the key -> fields mapping + // for a type only once, cache it by type, then use that to compute the column -> fields + // mapping for a given set of columns. + fieldStack := make([]int, 0, 1) + fields, missingField := computeNamedStructFields( + fldDescs, + t, + make([]structRowField, len(fldDescs)), + &fieldStack, + ) + for i, f := range fields { + if f.path == nil { + return nil, fmt.Errorf( + "struct doesn't have corresponding row field %s", + fldDescs[i].Name, + ) + } + } + + fieldsIface, _ := namedStructFieldMap.LoadOrStore( + key, + &namedStructFields{fields: fields, missingField: missingField}, + ) + return fieldsIface.(*namedStructFields), nil +} + +func joinFieldNames(fldDescs []pgconn.FieldDescription) string { + switch len(fldDescs) { + case 0: + return "" + case 1: + return fldDescs[0].Name + } + + totalSize := len(fldDescs) - 1 // Space for separator bytes. + for _, d := range fldDescs { + totalSize += len(d.Name) + } + var b strings.Builder + b.Grow(totalSize) + b.WriteString(fldDescs[0].Name) + for _, d := range fldDescs[1:] { + b.WriteByte(0) // Join with NUL byte as it's (presumably) not a valid column character. + b.WriteString(d.Name) + } + return b.String() +} + +func computeNamedStructFields( + fldDescs []pgconn.FieldDescription, + t reflect.Type, + fields []structRowField, + fieldStack *[]int, +) ([]structRowField, string) { + var missingField string + tail := len(*fieldStack) + *fieldStack = append(*fieldStack, 0) + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + (*fieldStack)[tail] = i + if sf.PkgPath != "" && !sf.Anonymous { + // Field is unexported, skip it. + continue + } + // Handle anonymous struct embedding, but do not try to handle embedded pointers. + if sf.Anonymous && sf.Type.Kind() == reflect.Struct { + var missingSubField string + fields, missingSubField = computeNamedStructFields( + fldDescs, + sf.Type, + fields, + fieldStack, + ) + if missingField == "" { + missingField = missingSubField + } + } else { + dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey) + if dbTagPresent { + dbTag, _, _ = strings.Cut(dbTag, ",") + } + if dbTag == "-" { + // Field is ignored, skip it. + continue + } + colName := dbTag + if !dbTagPresent { + colName = sf.Name + } + fpos := fieldPosByName(fldDescs, colName, !dbTagPresent) + if fpos == -1 { + if missingField == "" { + missingField = colName + } + continue + } + fields[fpos] = structRowField{ + path: append([]int(nil), *fieldStack...), + } + } + } + *fieldStack = (*fieldStack)[:tail] + + return fields, missingField +} + +const structTagKey = "db" + +func fieldPosByName(fldDescs []pgconn.FieldDescription, field string, normalize bool) (i int) { + i = -1 + + if normalize { + field = strings.ReplaceAll(field, "_", "") + } + for i, desc := range fldDescs { + if normalize { + if strings.EqualFold(strings.ReplaceAll(desc.Name, "_", ""), field) { + return i + } + } else { + if desc.Name == field { + return i + } + } + } + return +} + +// structRowField describes a field of a struct. +// +// TODO: It would be a bit more efficient to track the path using the pointer +// offset within the (outermost) struct and use unsafe.Pointer arithmetic to +// construct references when scanning rows. However, it's not clear it's worth +// using unsafe for this. +type structRowField struct { + path []int +} + +func setupStructScanTargets(receiver any, fields []structRowField) []any { + scanTargets := make([]any, len(fields)) + v := reflect.ValueOf(receiver).Elem() + for i, f := range fields { + scanTargets[i] = v.FieldByIndex(f.path).Addr().Interface() + } + return scanTargets +} diff --git a/rows_test.go b/rows_test.go new file mode 100644 index 000000000..4cda957fc --- /dev/null +++ b/rows_test.go @@ -0,0 +1,995 @@ +package pgx_test + +import ( + "context" + "errors" + "fmt" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" +) + +type testRowScanner struct { + name string + age int32 +} + +func (rs *testRowScanner) ScanRow(rows pgx.Rows) error { + return rows.Scan(&rs.name, &rs.age) +} + +func TestRowScanner(t *testing.T) { + t.Parallel() + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var s testRowScanner + err := conn.QueryRow(ctx, "select 'Adam' as name, 72 as height").Scan(&s) + require.NoError(t, err) + require.Equal(t, "Adam", s.name) + require.Equal(t, int32(72), s.age) + }) +} + +type testErrRowScanner string + +func (ers *testErrRowScanner) ScanRow(rows pgx.Rows) error { + return errors.New(string(*ers)) +} + +// https://github.com/jackc/pgx/issues/1654 +func TestRowScannerErrorIsFatalToRows(t *testing.T) { + t.Parallel() + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + s := testErrRowScanner("foo") + err := conn.QueryRow(ctx, "select 'Adam' as name, 72 as height").Scan(&s) + require.EqualError(t, err, "foo") + }) +} + +func TestForEachRow(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var actualResults []any + + rows, _ := conn.Query( + context.Background(), + "select n, n * 2 from generate_series(1, $1) n", + 3, + ) + var a, b int + ct, err := pgx.ForEachRow(rows, []any{&a, &b}, func() error { + actualResults = append(actualResults, []any{a, b}) + return nil + }) + require.NoError(t, err) + + expectedResults := []any{ + []any{1, 2}, + []any{2, 4}, + []any{3, 6}, + } + require.Equal(t, expectedResults, actualResults) + require.EqualValues(t, 3, ct.RowsAffected()) + }) +} + +func TestForEachRowScanError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var actualResults []any + + rows, _ := conn.Query( + context.Background(), + "select 'foo', 'bar' from generate_series(1, $1) n", + 3, + ) + var a, b int + ct, err := pgx.ForEachRow(rows, []any{&a, &b}, func() error { + actualResults = append(actualResults, []any{a, b}) + return nil + }) + require.EqualError(t, err, "can't scan into dest[0]: cannot scan text (OID 25) in text format into *int") + require.Equal(t, pgconn.CommandTag{}, ct) + }) +} + +func TestForEachRowAbort(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query( + context.Background(), + "select n, n * 2 from generate_series(1, $1) n", + 3, + ) + var a, b int + ct, err := pgx.ForEachRow(rows, []any{&a, &b}, func() error { + return errors.New("abort") + }) + require.EqualError(t, err, "abort") + require.Equal(t, pgconn.CommandTag{}, ct) + }) +} + +func ExampleForEachRow() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + rows, _ := conn.Query( + context.Background(), + "select n, n * 2 from generate_series(1, $1) n", + 3, + ) + var a, b int + _, err = pgx.ForEachRow(rows, []any{&a, &b}, func() error { + fmt.Printf("%v, %v\n", a, b) + return nil + }) + if err != nil { + fmt.Printf("ForEachRow error: %v", err) + return + } + + // Output: + // 1, 2 + // 2, 4 + // 3, 6 +} + +func TestCollectRows(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`) + numbers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + require.NoError(t, err) + + assert.Len(t, numbers, 100) + for i := range numbers { + assert.Equal(t, int32(i), numbers[i]) + } + }) +} + +func TestCollectRowsEmpty(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(1, 0) n`) + numbers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + require.NoError(t, err) + require.NotNil(t, numbers) + + assert.Empty(t, numbers) + }) +} + +// This example uses CollectRows with a manually written collector function. In most cases RowTo, RowToAddrOf, +// RowToStructByPos, RowToAddrOfStructByPos, or another generic function would be used. +func ExampleCollectRows() { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + rows, _ := conn.Query(ctx, `select n from generate_series(1, 5) n`) + numbers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + fmt.Println(numbers) + + // Output: + // [1 2 3 4 5] +} + +func TestCollectOneRow(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 42`) + n, err := pgx.CollectOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + assert.NoError(t, err) + assert.Equal(t, int32(42), n) + }) +} + +func TestCollectOneRowNotFound(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 42 where false`) + n, err := pgx.CollectOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + assert.ErrorIs(t, err, pgx.ErrNoRows) + assert.Equal(t, int32(0), n) + }) +} + +func TestCollectOneRowIgnoresExtraRows(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(42, 99) n`) + n, err := pgx.CollectOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + require.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, int32(42), n) + }) +} + +// https://github.com/jackc/pgx/issues/1334 +func TestCollectOneRowPrefersPostgreSQLErrorOverErrNoRows(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, `create temporary table t (name text not null unique)`) + require.NoError(t, err) + + var name string + rows, _ := conn.Query(ctx, `insert into t (name) values ('foo') returning name`) + name, err = pgx.CollectOneRow(rows, func(row pgx.CollectableRow) (string, error) { + var n string + err := row.Scan(&n) + return n, err + }) + require.NoError(t, err) + require.Equal(t, "foo", name) + + rows, _ = conn.Query(ctx, `insert into t (name) values ('foo') returning name`) + name, err = pgx.CollectOneRow(rows, func(row pgx.CollectableRow) (string, error) { + var n string + err := row.Scan(&n) + return n, err + }) + require.Error(t, err) + var pgErr *pgconn.PgError + require.ErrorAs(t, err, &pgErr) + require.Equal(t, "23505", pgErr.Code) + require.Equal(t, "", name) + }) +} + +func TestCollectExactlyOneRow(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 42`) + n, err := pgx.CollectExactlyOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + assert.NoError(t, err) + assert.Equal(t, int32(42), n) + }) +} + +func TestCollectExactlyOneRowNotFound(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 42 where false`) + n, err := pgx.CollectExactlyOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + assert.ErrorIs(t, err, pgx.ErrNoRows) + assert.Equal(t, int32(0), n) + }) +} + +func TestCollectExactlyOneRowExtraRows(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(42, 99) n`) + n, err := pgx.CollectExactlyOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + assert.ErrorIs(t, err, pgx.ErrTooManyRows) + assert.Equal(t, int32(0), n) + }) +} + +func TestRowTo(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`) + numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32]) + require.NoError(t, err) + + assert.Len(t, numbers, 100) + for i := range numbers { + assert.Equal(t, int32(i), numbers[i]) + } + }) +} + +func ExampleRowTo() { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + rows, _ := conn.Query(ctx, `select n from generate_series(1, 5) n`) + numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32]) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + fmt.Println(numbers) + + // Output: + // [1 2 3 4 5] +} + +func TestRowToAddrOf(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`) + numbers, err := pgx.CollectRows(rows, pgx.RowToAddrOf[int32]) + require.NoError(t, err) + + assert.Len(t, numbers, 100) + for i := range numbers { + assert.Equal(t, int32(i), *numbers[i]) + } + }) +} + +func ExampleRowToAddrOf() { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + rows, _ := conn.Query(ctx, `select n from generate_series(1, 5) n`) + pNumbers, err := pgx.CollectRows(rows, pgx.RowToAddrOf[int32]) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + for _, p := range pNumbers { + fmt.Println(*p) + } + + // Output: + // 1 + // 2 + // 3 + // 4 + // 5 +} + +func TestRowToMap(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'Joe' as name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToMap) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Joe", slice[i]["name"]) + assert.EqualValues(t, i, slice[i]["age"]) + } + }) +} + +func TestRowToStructByPos(t *testing.T) { + type person struct { + Name string + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'Joe' as name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByPos[person]) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Joe", slice[i].Name) + assert.EqualValues(t, i, slice[i].Age) + } + }) +} + +func TestRowToStructByPosIgnoredField(t *testing.T) { + type person struct { + Name string + Age int32 `db:"-"` + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'Joe' as name from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByPos[person]) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Joe", slice[i].Name) + } + }) +} + +func TestRowToStructByPosEmbeddedStruct(t *testing.T) { + type Name struct { + First string + Last string + } + + type person struct { + Name + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByPos[person]) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "John", slice[i].Name.First) + assert.Equal(t, "Smith", slice[i].Name.Last) + assert.EqualValues(t, i, slice[i].Age) + } + }) +} + +func TestRowToStructByPosMultipleEmbeddedStruct(t *testing.T) { + type Sandwich struct { + Bread string + Salad string + } + type Drink struct { + Ml int + } + + type meal struct { + Sandwich + Drink + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'Baguette' as bread, 'Lettuce' as salad, drink_ml from generate_series(0, 9) drink_ml`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByPos[meal]) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Baguette", slice[i].Sandwich.Bread) + assert.Equal(t, "Lettuce", slice[i].Sandwich.Salad) + assert.EqualValues(t, i, slice[i].Drink.Ml) + } + }) +} + +func TestRowToStructByPosEmbeddedUnexportedStruct(t *testing.T) { + type name struct { + First string + Last string + } + + type person struct { + name + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByPos[person]) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "John", slice[i].name.First) + assert.Equal(t, "Smith", slice[i].name.Last) + assert.EqualValues(t, i, slice[i].Age) + } + }) +} + +// Pointer to struct is not supported. But check that we don't panic. +func TestRowToStructByPosEmbeddedPointerToStruct(t *testing.T) { + type Name struct { + First string + Last string + } + + type person struct { + *Name + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age from generate_series(0, 9) n`) + _, err := pgx.CollectRows(rows, pgx.RowToStructByPos[person]) + require.EqualError(t, err, "got 3 values, but dst struct has only 2 fields") + }) +} + +func ExampleRowToStructByPos() { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + // Skip test / example when running on CockroachDB. Since an example can't be skipped fake success instead. + fmt.Println(`Cheeseburger: $10 +Fries: $5 +Soft Drink: $3`) + return + } + + // Setup example schema and data. + _, err = conn.Exec(ctx, ` +create temporary table products ( + id int primary key generated by default as identity, + name varchar(100) not null, + price int not null +); + +insert into products (name, price) values + ('Cheeseburger', 10), + ('Double Cheeseburger', 14), + ('Fries', 5), + ('Soft Drink', 3); +`) + if err != nil { + fmt.Printf("Unable to setup example schema and data: %v", err) + return + } + + type product struct { + ID int32 + Name string + Price int32 + } + + rows, _ := conn.Query(ctx, "select * from products where price < $1 order by price desc", 12) + products, err := pgx.CollectRows(rows, pgx.RowToStructByPos[product]) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + for _, p := range products { + fmt.Printf("%s: $%d\n", p.Name, p.Price) + } + + // Output: + // Cheeseburger: $10 + // Fries: $5 + // Soft Drink: $3 +} + +func TestRowToAddrOfStructPos(t *testing.T) { + type person struct { + Name string + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'Joe' as name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToAddrOfStructByPos[person]) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Joe", slice[i].Name) + assert.EqualValues(t, i, slice[i].Age) + } + }) +} + +func TestRowToStructByName(t *testing.T) { + type person struct { + Last string + First string + Age int32 + AccountID string + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first, 'Smith' as last, n as age, 'd5e49d3f' as account_id from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByName[person]) + assert.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Smith", slice[i].Last) + assert.Equal(t, "John", slice[i].First) + assert.EqualValues(t, i, slice[i].Age) + assert.Equal(t, "d5e49d3f", slice[i].AccountID) + } + + // check missing fields in a returned row + rows, _ = conn.Query(ctx, `select 'Smith' as last, n as age from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToStructByName[person]) + assert.ErrorContains(t, err, "cannot find field First in returned row") + + // check missing field in a destination struct + rows, _ = conn.Query(ctx, `select 'John' as first, 'Smith' as last, n as age, 'd5e49d3f' as account_id, null as ignore from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByName[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field ignore") + }) +} + +func TestRowToStructByNameDbTags(t *testing.T) { + type person struct { + Last string `db:"last_name"` + First string `db:"first_name"` + Age int32 `db:"age"` + AccountID string `db:"account_id"` + AnotherAccountID string `db:"account__id"` + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age, 'd5e49d3f' as account_id, '5e49d321' as account__id from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByName[person]) + assert.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Smith", slice[i].Last) + assert.Equal(t, "John", slice[i].First) + assert.EqualValues(t, i, slice[i].Age) + assert.Equal(t, "d5e49d3f", slice[i].AccountID) + assert.Equal(t, "5e49d321", slice[i].AnotherAccountID) + } + + // check missing fields in a returned row + rows, _ = conn.Query(ctx, `select 'Smith' as last_name, n as age from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToStructByName[person]) + assert.ErrorContains(t, err, "cannot find field first_name in returned row") + + // check missing field in a destination struct + rows, _ = conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age, 'd5e49d3f' as account_id, '5e49d321' as account__id, null as ignore from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByName[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field ignore") + }) +} + +func TestRowToStructByNameEmbeddedStruct(t *testing.T) { + type Name struct { + Last string `db:"last_name"` + First string `db:"first_name"` + } + + type person struct { + Ignore bool `db:"-"` + Name + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByName[person]) + assert.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Smith", slice[i].Name.Last) + assert.Equal(t, "John", slice[i].Name.First) + assert.EqualValues(t, i, slice[i].Age) + } + + // check missing fields in a returned row + rows, _ = conn.Query(ctx, `select 'Smith' as last_name, n as age from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToStructByName[person]) + assert.ErrorContains(t, err, "cannot find field first_name in returned row") + + // check missing field in a destination struct + rows, _ = conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age, null as ignore from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByName[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field ignore") + }) +} + +func ExampleRowToStructByName() { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + // Skip test / example when running on CockroachDB. Since an example can't be skipped fake success instead. + fmt.Println(`Cheeseburger: $10 +Fries: $5 +Soft Drink: $3`) + return + } + + // Setup example schema and data. + _, err = conn.Exec(ctx, ` +create temporary table products ( + id int primary key generated by default as identity, + name varchar(100) not null, + price int not null +); + +insert into products (name, price) values + ('Cheeseburger', 10), + ('Double Cheeseburger', 14), + ('Fries', 5), + ('Soft Drink', 3); +`) + if err != nil { + fmt.Printf("Unable to setup example schema and data: %v", err) + return + } + + type product struct { + ID int32 + Name string + Price int32 + } + + rows, _ := conn.Query(ctx, "select * from products where price < $1 order by price desc", 12) + products, err := pgx.CollectRows(rows, pgx.RowToStructByName[product]) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + for _, p := range products { + fmt.Printf("%s: $%d\n", p.Name, p.Price) + } + + // Output: + // Cheeseburger: $10 + // Fries: $5 + // Soft Drink: $3 +} + +func TestRowToStructByNameLax(t *testing.T) { + type person struct { + Last string + First string + Age int32 + Ignore bool `db:"-"` + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first, 'Smith' as last, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByNameLax[person]) + assert.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Smith", slice[i].Last) + assert.Equal(t, "John", slice[i].First) + assert.EqualValues(t, i, slice[i].Age) + } + + // check missing fields in a returned row + rows, _ = conn.Query(ctx, `select 'John' as first, n as age from generate_series(0, 9) n`) + slice, err = pgx.CollectRows(rows, pgx.RowToStructByNameLax[person]) + assert.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "John", slice[i].First) + assert.EqualValues(t, i, slice[i].Age) + } + + // check extra fields in a returned row + rows, _ = conn.Query(ctx, `select 'John' as first, 'Smith' as last, n as age, null as ignore from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByNameLax[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field ignore") + + // check missing fields in a destination struct + rows, _ = conn.Query(ctx, `select 'Smith' as last, 'D.' as middle, n as age from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByNameLax[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field middle") + + // check ignored fields in a destination struct + rows, _ = conn.Query(ctx, `select 'Smith' as last, n as age, null as ignore from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByNameLax[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field ignore") + }) +} + +func TestRowToStructByNameLaxEmbeddedStruct(t *testing.T) { + type Name struct { + Last string `db:"last_name"` + First string `db:"first_name"` + } + + type person struct { + Ignore bool `db:"-"` + Name + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByNameLax[person]) + assert.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Smith", slice[i].Name.Last) + assert.Equal(t, "John", slice[i].Name.First) + assert.EqualValues(t, i, slice[i].Age) + } + + // check missing fields in a returned row + rows, _ = conn.Query(ctx, `select 'John' as first_name, n as age from generate_series(0, 9) n`) + slice, err = pgx.CollectRows(rows, pgx.RowToStructByNameLax[person]) + assert.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "John", slice[i].Name.First) + assert.EqualValues(t, i, slice[i].Age) + } + + // check extra fields in a returned row + rows, _ = conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age, null as ignore from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByNameLax[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field ignore") + + // check missing fields in a destination struct + rows, _ = conn.Query(ctx, `select 'Smith' as last_name, 'D.' as middle_name, n as age from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByNameLax[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field middle_name") + + // check ignored fields in a destination struct + rows, _ = conn.Query(ctx, `select 'Smith' as last_name, n as age, null as ignore from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByNameLax[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field ignore") + }) +} + +func TestRowToStructByNameLaxRowValue(t *testing.T) { + type AnotherTable struct{} + type User struct { + UserID int `json:"userId" db:"user_id"` + Name string `json:"name" db:"name"` + } + type UserAPIKey struct { + UserAPIKeyID int `json:"userApiKeyId" db:"user_api_key_id"` + UserID int `json:"userId" db:"user_id"` + + User *User `json:"user" db:"user"` + AnotherTable *AnotherTable `json:"anotherTable" db:"another_table"` + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "") + + rows, _ := conn.Query(ctx, ` + WITH user_api_keys AS ( + SELECT 1 AS user_id, 101 AS user_api_key_id, 'abc123' AS api_key + ), users AS ( + SELECT 1 AS user_id, 'John Doe' AS name + ) + SELECT user_api_keys.user_api_key_id, user_api_keys.user_id, row(users.*) AS user + FROM user_api_keys + LEFT JOIN users ON users.user_id = user_api_keys.user_id + WHERE user_api_keys.api_key = 'abc123'; + `) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByNameLax[UserAPIKey]) + + assert.NoError(t, err) + assert.ElementsMatch(t, slice, []UserAPIKey{{UserAPIKeyID: 101, UserID: 1, User: &User{UserID: 1, Name: "John Doe"}, AnotherTable: nil}}) + }) +} + +func ExampleRowToStructByNameLax() { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + // Skip test / example when running on CockroachDB. Since an example can't be skipped fake success instead. + fmt.Println(`Cheeseburger: $10 +Fries: $5 +Soft Drink: $3`) + return + } + + // Setup example schema and data. + _, err = conn.Exec(ctx, ` +create temporary table products ( + id int primary key generated by default as identity, + name varchar(100) not null, + price int not null +); + +insert into products (name, price) values + ('Cheeseburger', 10), + ('Double Cheeseburger', 14), + ('Fries', 5), + ('Soft Drink', 3); +`) + if err != nil { + fmt.Printf("Unable to setup example schema and data: %v", err) + return + } + + type product struct { + ID int32 + Name string + Type string + Price int32 + } + + rows, _ := conn.Query(ctx, "select * from products where price < $1 order by price desc", 12) + products, err := pgx.CollectRows(rows, pgx.RowToStructByNameLax[product]) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + for _, p := range products { + fmt.Printf("%s: $%d\n", p.Name, p.Price) + } + + // Output: + // Cheeseburger: $10 + // Fries: $5 + // Soft Drink: $3 +} diff --git a/sql.go b/sql.go deleted file mode 100644 index 7ee0f2a0a..000000000 --- a/sql.go +++ /dev/null @@ -1,29 +0,0 @@ -package pgx - -import ( - "strconv" -) - -// QueryArgs is a container for arguments to an SQL query. It is helpful when -// building SQL statements where the number of arguments is variable. -type QueryArgs []interface{} - -var placeholders []string - -func init() { - placeholders = make([]string, 64) - - for i := 1; i < 64; i++ { - placeholders[i] = "$" + strconv.Itoa(i) - } -} - -// Append adds a value to qa and returns the placeholder value for the -// argument. e.g. $1, $2, etc. -func (qa *QueryArgs) Append(v interface{}) string { - *qa = append(*qa, v) - if len(*qa) < len(placeholders) { - return placeholders[len(*qa)] - } - return "$" + strconv.Itoa(len(*qa)) -} diff --git a/sql_test.go b/sql_test.go deleted file mode 100644 index dd0360356..000000000 --- a/sql_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package pgx_test - -import ( - "strconv" - "testing" - - "github.com/jackc/pgx" -) - -func TestQueryArgs(t *testing.T) { - var qa pgx.QueryArgs - - for i := 1; i < 512; i++ { - expectedPlaceholder := "$" + strconv.Itoa(i) - placeholder := qa.Append(i) - if placeholder != expectedPlaceholder { - t.Errorf(`Expected qa.Append to return "%s", but it returned "%s"`, expectedPlaceholder, placeholder) - } - } -} - -func BenchmarkQueryArgs(b *testing.B) { - for i := 0; i < b.N; i++ { - qa := pgx.QueryArgs(make([]interface{}, 0, 16)) - qa.Append("foo1") - qa.Append("foo2") - qa.Append("foo3") - qa.Append("foo4") - qa.Append("foo5") - qa.Append("foo6") - qa.Append("foo7") - qa.Append("foo8") - qa.Append("foo9") - qa.Append("foo10") - } -} diff --git a/stdlib/bench_test.go b/stdlib/bench_test.go new file mode 100644 index 000000000..141fc4eb5 --- /dev/null +++ b/stdlib/bench_test.go @@ -0,0 +1,160 @@ +package stdlib_test + +import ( + "database/sql" + "fmt" + "os" + "strconv" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" +) + +func getSelectRowsCounts(b *testing.B) []int64 { + var rowCounts []int64 + { + s := os.Getenv("PGX_BENCH_SELECT_ROWS_COUNTS") + if s != "" { + for _, p := range strings.Split(s, " ") { + n, err := strconv.ParseInt(p, 10, 64) + if err != nil { + b.Fatalf("Bad PGX_BENCH_SELECT_ROWS_COUNTS value: %v", err) + } + rowCounts = append(rowCounts, n) + } + } + } + + if len(rowCounts) == 0 { + rowCounts = []int64{1, 10, 100, 1000} + } + + return rowCounts +} + +type BenchRowSimple struct { + ID int32 + FirstName string + LastName string + Sex string + BirthDate time.Time + Weight int32 + Height int32 + UpdateTime time.Time +} + +func BenchmarkSelectRowsScanSimple(b *testing.B) { + db := openDB(b) + defer closeDB(b, db) + + rowCounts := getSelectRowsCounts(b) + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + br := &BenchRowSimple{} + for i := 0; i < b.N; i++ { + rows, err := db.Query("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(1, $1) n", rowCount) + if err != nil { + b.Fatal(err) + } + + for rows.Next() { + rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime) + } + + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + } + }) + } +} + +type BenchRowNull struct { + ID sql.NullInt32 + FirstName sql.NullString + LastName sql.NullString + Sex sql.NullString + BirthDate sql.NullTime + Weight sql.NullInt32 + Height sql.NullInt32 + UpdateTime sql.NullTime +} + +func BenchmarkSelectRowsScanNull(b *testing.B) { + db := openDB(b) + defer closeDB(b, db) + + rowCounts := getSelectRowsCounts(b) + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + br := &BenchRowSimple{} + for i := 0; i < b.N; i++ { + rows, err := db.Query("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100000, 100000 + $1) n", rowCount) + if err != nil { + b.Fatal(err) + } + + for rows.Next() { + rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime) + } + + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + } + }) + } +} + +func BenchmarkFlatArrayEncodeArgument(b *testing.B) { + db := openDB(b) + defer closeDB(b, db) + + input := make(pgtype.FlatArray[string], 10) + for i := range input { + input[i] = fmt.Sprintf("String %d", i) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var n int64 + err := db.QueryRow("select cardinality($1::text[])", input).Scan(&n) + if err != nil { + b.Fatal(err) + } + if n != int64(len(input)) { + b.Fatalf("Expected %d, got %d", len(input), n) + } + } +} + +func BenchmarkFlatArrayScanResult(b *testing.B) { + db := openDB(b) + defer closeDB(b, db) + + var input string + for i := 0; i < 10; i++ { + if i > 0 { + input += "," + } + input += fmt.Sprintf(`'String %d'`, i) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var result pgtype.FlatArray[string] + err := db.QueryRow(fmt.Sprintf("select array[%s]::text[]", input)).Scan(&result) + if err != nil { + b.Fatal(err) + } + if len(result) != 10 { + b.Fatalf("Expected %d, got %d", len(result), 10) + } + } +} diff --git a/stdlib/sql.go b/stdlib/sql.go index 2d4930ee9..9d268b152 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -4,216 +4,437 @@ // // db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable") // if err != nil { -// return err +// return err // } // -// Or from a DSN string. +// Or from a keyword/value string. // // db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable") // if err != nil { -// return err +// return err // } // -// A DriverConfig can be used to further configure the connection process. This -// allows configuring TLS configuration, setting a custom dialer, logging, and -// setting an AfterConnect hook. +// Or from a *pgxpool.Pool. // -// driverConfig := stdlib.DriverConfig{ -// ConnConfig: pgx.ConnConfig{ -// Logger: logger, -// }, -// AfterConnect: func(c *pgx.Conn) error { -// // Ensure all connections have this temp table available -// _, err := c.Exec("create temporary table foo(...)") -// return err -// }, +// pool, err := pgxpool.New(context.Background(), os.Getenv("DATABASE_URL")) +// if err != nil { +// return err // } // -// stdlib.RegisterDriverConfig(&driverConfig) +// db := stdlib.OpenDBFromPool(pool) // -// db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")) -// if err != nil { -// return err -// } +// Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the +// pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used +// with sql.Open. +// +// connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL")) +// connConfig.Tracer = &tracelog.TraceLog{Logger: myLogger, LogLevel: tracelog.LogLevelInfo} +// connStr := stdlib.RegisterConnConfig(connConfig) +// db, _ := sql.Open("pgx", connStr) // -// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2. -// It does not support named parameters. +// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2. It does not support named parameters. // // db.QueryRow("select * from users where id=$1", userID) // -// AcquireConn and ReleaseConn acquire and release a *pgx.Conn from the standard -// database/sql.DB connection pool. This allows operations that must be -// performed on a single connection, but should not be run in a transaction or -// to use pgx specific functionality. +// (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard database/sql.DB connection pool. This allows +// operations that use pgx specific functionality. // -// conn, err := stdlib.AcquireConn(db) +// // Given db is a *sql.DB +// conn, err := db.Conn(context.Background()) // if err != nil { -// return err +// // handle error from acquiring connection from DB pool // } -// defer stdlib.ReleaseConn(db, conn) // -// // do stuff with pgx.Conn +// err = conn.Raw(func(driverConn any) error { +// conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn +// // Do pgx specific stuff with conn +// conn.CopyFrom(...) +// return nil +// }) +// if err != nil { +// // handle error that occurred while using *pgx.Conn +// } // -// It also can be used to enable a fast path for pgx while preserving -// compatibility with other drivers and database. +// # PostgreSQL Specific Data Types // -// conn, err := stdlib.AcquireConn(db) -// if err == nil { -// // fast path with pgx -// // ... -// // release conn when done -// stdlib.ReleaseConn(db, conn) -// } else { -// // normal path for other drivers and databases -// } +// As of Go 1.26 the database/sql allows drivers to implement their own scanning logic by implementing the +// driver.RowsColumnScanner interface. This allows PostgreSQL arrays to be scanned directly into Go slices. +// +// var a []int64 +// err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(&a) +// +// In older versions of Go, *pgtype.Map.SQLScanner can be used as an adapter that makes these types usable as a +// sql.Scanner. +// +// m := pgtype.NewMap() +// var a []int64 +// err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) +// +// The pgtype package provides support for PostgreSQL specific types. These types can be used directly in Go 1.26 and +// with *pgtype.Map.SQLScanner in older Go versions. +// +// var r pgtype.Range[pgtype.Int4] +// err := db.QueryRow("select int4range(1, 5)").Scan(&r) package stdlib import ( "context" "database/sql" "database/sql/driver" - "encoding/binary" + "errors" "fmt" "io" + "math" + "math/rand" "reflect" + "slices" + "strconv" "strings" "sync" + "time" - "github.com/pkg/errors" - - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" ) -// oids that map to intrinsic database/sql types. These will be allowed to be -// binary, anything else will be forced to text format -var databaseSqlOIDs map[pgtype.OID]bool +// Only intrinsic types should be binary format with database/sql. +var databaseSQLResultFormats pgx.QueryResultFormatsByOID var pgxDriver *Driver -type ctxKey int +func init() { + pgxDriver = &Driver{ + configs: make(map[string]*pgx.ConnConfig), + } -var ctxKeyFakeTx ctxKey = 0 + // if pgx driver was already registered by different pgx major version then we + // skip registration under the default name. + if !slices.Contains(sql.Drivers(), "pgx") { + sql.Register("pgx", pgxDriver) + } + sql.Register("pgx/v5", pgxDriver) -var ErrNotPgx = errors.New("not pgx *sql.DB") + databaseSQLResultFormats = pgx.QueryResultFormatsByOID{ + pgtype.BoolOID: 1, + pgtype.ByteaOID: 1, + pgtype.CIDOID: 1, + pgtype.DateOID: 1, + pgtype.Float4OID: 1, + pgtype.Float8OID: 1, + pgtype.Int2OID: 1, + pgtype.Int4OID: 1, + pgtype.Int8OID: 1, + pgtype.OIDOID: 1, + pgtype.TimestampOID: 1, + pgtype.TimestamptzOID: 1, + pgtype.XIDOID: 1, + } +} -func init() { - pgxDriver = &Driver{ - configs: make(map[int64]*DriverConfig), - fakeTxConns: make(map[*pgx.Conn]*sql.Tx), - } - sql.Register("pgx", pgxDriver) - - databaseSqlOIDs = make(map[pgtype.OID]bool) - databaseSqlOIDs[pgtype.BoolOID] = true - databaseSqlOIDs[pgtype.ByteaOID] = true - databaseSqlOIDs[pgtype.CIDOID] = true - databaseSqlOIDs[pgtype.DateOID] = true - databaseSqlOIDs[pgtype.Float4OID] = true - databaseSqlOIDs[pgtype.Float8OID] = true - databaseSqlOIDs[pgtype.Int2OID] = true - databaseSqlOIDs[pgtype.Int4OID] = true - databaseSqlOIDs[pgtype.Int8OID] = true - databaseSqlOIDs[pgtype.OIDOID] = true - databaseSqlOIDs[pgtype.TimestampOID] = true - databaseSqlOIDs[pgtype.TimestamptzOID] = true - databaseSqlOIDs[pgtype.XIDOID] = true +// OptionOpenDB options for configuring the driver when opening a new db pool. +type OptionOpenDB func(*connector) + +// ShouldPingParams are passed to OptionShouldPing to decide whether to ping before reusing a connection. +type ShouldPingParams struct { + // Conn is the underlying pgx connection. + Conn *pgx.Conn + // IdleDuration is how long it has been since ResetSession last ran. + IdleDuration time.Duration } -type Driver struct { - configMutex sync.Mutex - configCount int64 - configs map[int64]*DriverConfig +// OptionShouldPing controls whether stdlib should issue a liveness ping before reusing a connection. +// If the function returns true, stdlib will ping. +// If it returns false, stdlib will skip the ping. +// If not provided, default is ping only when IdleDuration > 1s. +func OptionShouldPing(f func(context.Context, ShouldPingParams) bool) OptionOpenDB { + return func(dc *connector) { dc.ShouldPing = f } +} - fakeTxMutex sync.Mutex - fakeTxConns map[*pgx.Conn]*sql.Tx +// OptionBeforeConnect provides a callback for before connect. It is passed a shallow copy of the ConnConfig that will +// be used to connect, so only its immediate members should be modified. Used only if db is opened with *pgx.ConnConfig. +func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB { + return func(dc *connector) { + dc.BeforeConnect = bc + } } -func (d *Driver) Open(name string) (driver.Conn, error) { - var connConfig pgx.ConnConfig - var afterConnect func(*pgx.Conn) error - if len(name) >= 9 && name[0] == 0 { - idBuf := []byte(name)[1:9] - id := int64(binary.BigEndian.Uint64(idBuf)) - connConfig = d.configs[id].ConnConfig - afterConnect = d.configs[id].AfterConnect - name = name[9:] +// OptionAfterConnect provides a callback for after connect. Used only if db is opened with *pgx.ConnConfig. +func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB { + return func(dc *connector) { + dc.AfterConnect = ac } +} - parsedConfig, err := pgx.ParseConnectionString(name) - if err != nil { - return nil, err +// OptionResetSession provides a callback that can be used to add custom logic prior to executing a query on the +// connection if the connection has been used before. +// If ResetSessionFunc returns ErrBadConn error the connection will be discarded. +func OptionResetSession(rs func(context.Context, *pgx.Conn) error) OptionOpenDB { + return func(dc *connector) { + dc.ResetSession = rs } - connConfig = connConfig.Merge(parsedConfig) +} - conn, err := pgx.Connect(connConfig) - if err != nil { - return nil, err +// RandomizeHostOrderFunc is a BeforeConnect hook that randomizes the host order in the provided connConfig, so that a +// new host becomes primary each time. This is useful to distribute connections for multi-master databases like +// CockroachDB. If you use this you likely should set https://golang.org/pkg/database/sql/#DB.SetConnMaxLifetime as well +// to ensure that connections are periodically rebalanced across your nodes. +func RandomizeHostOrderFunc(ctx context.Context, connConfig *pgx.ConnConfig) error { + if len(connConfig.Fallbacks) == 0 { + return nil + } + + newFallbacks := append([]*pgconn.FallbackConfig{{ + Host: connConfig.Host, + Port: connConfig.Port, + TLSConfig: connConfig.TLSConfig, + }}, connConfig.Fallbacks...) + + rand.Shuffle(len(newFallbacks), func(i, j int) { + newFallbacks[i], newFallbacks[j] = newFallbacks[j], newFallbacks[i] + }) + + // Use the one that sorted last as the primary and keep the rest as the fallbacks + newPrimary := newFallbacks[len(newFallbacks)-1] + connConfig.Host = newPrimary.Host + connConfig.Port = newPrimary.Port + connConfig.TLSConfig = newPrimary.TLSConfig + connConfig.Fallbacks = newFallbacks[:len(newFallbacks)-1] + return nil +} + +func GetConnector(config pgx.ConnConfig, opts ...OptionOpenDB) driver.Connector { + c := connector{ + ConnConfig: config, + BeforeConnect: func(context.Context, *pgx.ConnConfig) error { return nil }, // noop before connect by default + AfterConnect: func(context.Context, *pgx.Conn) error { return nil }, // noop after connect by default + ResetSession: func(context.Context, *pgx.Conn) error { return nil }, // noop reset session by default + driver: pgxDriver, } - if afterConnect != nil { - err = afterConnect(conn) + for _, opt := range opts { + opt(&c) + } + return c +} + +// GetPoolConnector creates a new driver.Connector from the given *pgxpool.Pool. By using this be sure to set the +// maximum idle connections of the *sql.DB created with this connector to zero since they must be managed from the +// *pgxpool.Pool. This is required to avoid acquiring all the connections from the pgxpool and starving any direct +// users of the pgxpool. +func GetPoolConnector(pool *pgxpool.Pool, opts ...OptionOpenDB) driver.Connector { + c := connector{ + pool: pool, + ResetSession: func(context.Context, *pgx.Conn) error { return nil }, // noop reset session by default + driver: pgxDriver, + } + + for _, opt := range opts { + opt(&c) + } + + return c +} + +func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB { + c := GetConnector(config, opts...) + return sql.OpenDB(c) +} + +// OpenDBFromPool creates a new *sql.DB from the given *pgxpool.Pool. Note that this method automatically sets the +// maximum number of idle connections in *sql.DB to zero, since they must be managed from the *pgxpool.Pool. This is +// required to avoid acquiring all the connections from the pgxpool and starving any direct users of the pgxpool. Note +// that closing the returned *sql.DB will not close the *pgxpool.Pool. +func OpenDBFromPool(pool *pgxpool.Pool, opts ...OptionOpenDB) *sql.DB { + c := GetPoolConnector(pool, opts...) + db := sql.OpenDB(c) + db.SetMaxIdleConns(0) + return db +} + +type connector struct { + pgx.ConnConfig + pool *pgxpool.Pool + BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection + AfterConnect func(context.Context, *pgx.Conn) error // function to call after creation of every new connection + ResetSession func(context.Context, *pgx.Conn) error // function is called before a connection is reused + ShouldPing func(context.Context, ShouldPingParams) bool // function to decide if stdlib should ping before reusing a connection + driver *Driver +} + +// Connect implement driver.Connector interface +func (c connector) Connect(ctx context.Context) (driver.Conn, error) { + var ( + connConfig pgx.ConnConfig + conn *pgx.Conn + close func(context.Context) error + err error + ) + + if c.pool == nil { + // Create a shallow copy of the config, so that BeforeConnect can safely modify it + connConfig = c.ConnConfig + + if err = c.BeforeConnect(ctx, &connConfig); err != nil { + return nil, err + } + + if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil { + return nil, err + } + + if err = c.AfterConnect(ctx, conn); err != nil { + return nil, err + } + + close = conn.Close + } else { + var pconn *pgxpool.Conn + + pconn, err = c.pool.Acquire(ctx) if err != nil { return nil, err } + + conn = pconn.Conn() + + close = func(_ context.Context) error { + pconn.Release() + return nil + } } - c := &Conn{conn: conn, driver: d, connConfig: connConfig} - return c, nil + return &Conn{ + conn: conn, + close: close, + driver: c.driver, + connConfig: connConfig, + resetSessionFunc: c.ResetSession, + shouldPing: c.ShouldPing, + psRefCounts: make(map[*pgconn.StatementDescription]int), + }, nil } -type DriverConfig struct { - pgx.ConnConfig - AfterConnect func(*pgx.Conn) error // function to call on every new connection - driver *Driver - id int64 +// Driver implement driver.Connector interface +func (c connector) Driver() driver.Driver { + return c.driver } -// ConnectionString encodes the DriverConfig into the original connection -// string. DriverConfig must be registered before calling ConnectionString. -func (c *DriverConfig) ConnectionString(original string) string { - if c.driver == nil { - panic("DriverConfig must be registered before calling ConnectionString") - } +// GetDefaultDriver returns the driver initialized in the init function +// and used when the pgx driver is registered. +func GetDefaultDriver() driver.Driver { + return pgxDriver +} - buf := make([]byte, 9) - binary.BigEndian.PutUint64(buf[1:], uint64(c.id)) - buf = append(buf, original...) - return string(buf) +type Driver struct { + configMutex sync.Mutex + configs map[string]*pgx.ConnConfig + sequence int } -func (d *Driver) registerDriverConfig(c *DriverConfig) { - d.configMutex.Lock() +func (d *Driver) Open(name string) (driver.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout + defer cancel() - c.driver = d - c.id = d.configCount - d.configs[d.configCount] = c - d.configCount++ + connector, err := d.OpenConnector(name) + if err != nil { + return nil, err + } + return connector.Connect(ctx) +} +func (d *Driver) OpenConnector(name string) (driver.Connector, error) { + return &driverConnector{driver: d, name: name}, nil +} + +func (d *Driver) registerConnConfig(c *pgx.ConnConfig) string { + d.configMutex.Lock() + connStr := fmt.Sprintf("registeredConnConfig%d", d.sequence) + d.sequence++ + d.configs[connStr] = c d.configMutex.Unlock() + return connStr } -func (d *Driver) unregisterDriverConfig(c *DriverConfig) { +func (d *Driver) unregisterConnConfig(connStr string) { d.configMutex.Lock() - delete(d.configs, c.id) + delete(d.configs, connStr) d.configMutex.Unlock() } -// RegisterDriverConfig registers a DriverConfig for use with Open. -func RegisterDriverConfig(c *DriverConfig) { - pgxDriver.registerDriverConfig(c) +type driverConnector struct { + driver *Driver + name string +} + +func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) { + var connConfig *pgx.ConnConfig + + dc.driver.configMutex.Lock() + connConfig = dc.driver.configs[dc.name] + dc.driver.configMutex.Unlock() + + if connConfig == nil { + var err error + connConfig, err = pgx.ParseConfig(dc.name) + if err != nil { + return nil, err + } + } + + conn, err := pgx.ConnectConfig(ctx, connConfig) + if err != nil { + return nil, err + } + + c := &Conn{ + conn: conn, + close: conn.Close, + driver: dc.driver, + connConfig: *connConfig, + resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil }, + psRefCounts: make(map[*pgconn.StatementDescription]int), + } + + return c, nil +} + +func (dc *driverConnector) Driver() driver.Driver { + return dc.driver } -// UnregisterDriverConfig removes a DriverConfig registration. -func UnregisterDriverConfig(c *DriverConfig) { - pgxDriver.unregisterDriverConfig(c) +// RegisterConnConfig registers a ConnConfig and returns the connection string to use with Open. +func RegisterConnConfig(c *pgx.ConnConfig) string { + return pgxDriver.registerConnConfig(c) +} + +// UnregisterConnConfig removes the ConnConfig registration for connStr. +func UnregisterConnConfig(connStr string) { + pgxDriver.unregisterConnConfig(connStr) } type Conn struct { - conn *pgx.Conn - psCount int64 // Counter used for creating unique prepared statement names - driver *Driver - connConfig pgx.ConnConfig + conn *pgx.Conn + close func(context.Context) error + driver *Driver + connConfig pgx.ConnConfig + resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused + shouldPing func(context.Context, ShouldPingParams) bool // Function to decide if stdlib should ping before reusing a connection + lastResetSessionTime time.Time + + // psRefCounts contains reference counts for prepared statements. Prepare uses the underlying pgx logic to generate + // deterministic statement names from the statement text. If this query has already been prepared then the existing + // *pgconn.StatementDescription will be returned. However, this means that if Close is called on the returned Stmt + // then the underlying prepared statement will be closed even when the underlying prepared statement is still in use + // by another database/sql Stmt. To prevent this psRefCounts keeps track of how many database/sql statements are using + // the same underlying statement and only closes the underlying statement when the reference count reaches 0. + psRefCounts map[*pgconn.StatementDescription]int +} + +// Conn returns the underlying *pgx.Conn +func (c *Conn) Conn() *pgx.Conn { + return c.conn } func (c *Conn) Prepare(query string) (driver.Stmt, error) { @@ -221,25 +442,23 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) { } func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - if !c.conn.IsAlive() { + if c.conn.IsClosed() { return nil, driver.ErrBadConn } - name := fmt.Sprintf("pgx_%d", c.psCount) - c.psCount++ - - ps, err := c.conn.PrepareEx(ctx, name, query, nil) + sd, err := c.conn.Prepare(ctx, query, query) if err != nil { return nil, err } + c.psRefCounts[sd]++ - restrictBinaryToDatabaseSqlTypes(ps) - - return &Stmt{ps: ps, conn: c}, nil + return &Stmt{sd: sd, conn: c}, nil } func (c *Conn) Close() error { - return c.conn.Close() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + return c.close(ctx) } func (c *Conn) Begin() (driver.Tx, error) { @@ -247,15 +466,10 @@ func (c *Conn) Begin() (driver.Tx, error) { } func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - if !c.conn.IsAlive() { + if c.conn.IsClosed() { return nil, driver.ErrBadConn } - if pconn, ok := ctx.Value(ctxKeyFakeTx).(**pgx.Conn); ok { - *pconn = c.conn - return fakeTx{}, nil - } - var pgxOpts pgx.TxOptions switch sql.IsolationLevel(opts.Isolation) { case sql.LevelDefault: @@ -263,248 +477,387 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e pgxOpts.IsoLevel = pgx.ReadUncommitted case sql.LevelReadCommitted: pgxOpts.IsoLevel = pgx.ReadCommitted - case sql.LevelSnapshot: + case sql.LevelRepeatableRead, sql.LevelSnapshot: pgxOpts.IsoLevel = pgx.RepeatableRead case sql.LevelSerializable: pgxOpts.IsoLevel = pgx.Serializable default: - return nil, errors.Errorf("unsupported isolation: %v", opts.Isolation) + return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation) } if opts.ReadOnly { pgxOpts.AccessMode = pgx.ReadOnly } - return c.conn.BeginEx(ctx, &pgxOpts) -} - -func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) { - if !c.conn.IsAlive() { - return nil, driver.ErrBadConn + tx, err := c.conn.BeginTx(ctx, pgxOpts) + if err != nil { + return nil, err } - args := valueToInterface(argsV) - commandTag, err := c.conn.Exec(query, args...) - return driver.RowsAffected(commandTag.RowsAffected()), err + return wrapTx{ctx: ctx, tx: tx}, nil } func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) { - if !c.conn.IsAlive() { + if c.conn.IsClosed() { return nil, driver.ErrBadConn } - args := namedValueToInterface(argsV) + args := make([]any, len(argsV)) + convertNamedArguments(args, argsV) - commandTag, err := c.conn.ExecEx(ctx, query, nil, args...) - return driver.RowsAffected(commandTag.RowsAffected()), err -} - -func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { - if !c.conn.IsAlive() { - return nil, driver.ErrBadConn - } - - if !c.connConfig.PreferSimpleProtocol { - ps, err := c.conn.Prepare("", query) - if err != nil { - return nil, err - } - - restrictBinaryToDatabaseSqlTypes(ps) - return c.queryPrepared("", argsV) - } - - rows, err := c.conn.Query(query, valueToInterface(argsV)...) + commandTag, err := c.conn.Exec(ctx, query, args...) + // if we got a network error before we had a chance to send the query, retry if err != nil { - return nil, err + if pgconn.SafeToRetry(err) { + return nil, driver.ErrBadConn + } } - - // Preload first row because otherwise we won't know what columns are available when database/sql asks. - more := rows.Next() - return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil + return driver.RowsAffected(commandTag.RowsAffected()), err } func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) { - if !c.conn.IsAlive() { + if c.conn.IsClosed() { return nil, driver.ErrBadConn } - if !c.connConfig.PreferSimpleProtocol { - ps, err := c.conn.PrepareEx(ctx, "", query, nil) - if err != nil { - return nil, err - } + args := make([]any, 1+len(argsV)) + args[0] = databaseSQLResultFormats + convertNamedArguments(args[1:], argsV) - restrictBinaryToDatabaseSqlTypes(ps) - return c.queryPreparedContext(ctx, "", argsV) - } - - rows, err := c.conn.QueryEx(ctx, query, nil, namedValueToInterface(argsV)...) + rows, err := c.conn.Query(ctx, query, args...) if err != nil { + if pgconn.SafeToRetry(err) { + return nil, driver.ErrBadConn + } return nil, err } // Preload first row because otherwise we won't know what columns are available when database/sql asks. more := rows.Next() - return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil -} - -func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) { - if !c.conn.IsAlive() { - return nil, driver.ErrBadConn - } - - args := valueToInterface(argsV) - - rows, err := c.conn.Query(name, args...) - if err != nil { + if err = rows.Err(); err != nil { + rows.Close() return nil, err } - - return &Rows{rows: rows}, nil + return &Rows{conn: c, rows: rows, skipNext: true, skipNextMore: more}, nil } -func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []driver.NamedValue) (driver.Rows, error) { - if !c.conn.IsAlive() { - return nil, driver.ErrBadConn +func (c *Conn) Ping(ctx context.Context) error { + if c.conn.IsClosed() { + return driver.ErrBadConn } - args := namedValueToInterface(argsV) - - rows, err := c.conn.QueryEx(ctx, name, nil, args...) + err := c.conn.Ping(ctx) if err != nil { - return nil, err + // A Ping failure implies some sort of fatal state. The connection is almost certainly already closed by the + // failure, but manually close it just to be sure. + c.Close() + return driver.ErrBadConn } - return &Rows{rows: rows}, nil + return nil } -func (c *Conn) Ping(ctx context.Context) error { - if !c.conn.IsAlive() { +func (c *Conn) CheckNamedValue(*driver.NamedValue) error { + // Underlying pgx supports sql.Scanner and driver.Valuer interfaces natively. So everything can be passed through directly. + return nil +} + +func (c *Conn) ResetSession(ctx context.Context) error { + if c.conn.IsClosed() { return driver.ErrBadConn } - return c.conn.Ping(ctx) -} + now := time.Now() + idle := now.Sub(c.lastResetSessionTime) + + doPing := idle > time.Second // default behavior: ping only if idle > 1s -// Anything that isn't a database/sql compatible type needs to be forced to -// text format so that pgx.Rows.Values doesn't decode it into a native type -// (e.g. []int32) -func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) { - for i := range ps.FieldDescriptions { - intrinsic, _ := databaseSqlOIDs[ps.FieldDescriptions[i].DataType] - if !intrinsic { - ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode + if c.shouldPing != nil { + doPing = c.shouldPing(ctx, ShouldPingParams{ + Conn: c.conn, + IdleDuration: idle, + }) + } + + if doPing { + if err := c.conn.PgConn().Ping(ctx); err != nil { + return driver.ErrBadConn } } + + c.lastResetSessionTime = now + + return c.resetSessionFunc(ctx, c.conn) } type Stmt struct { - ps *pgx.PreparedStatement + sd *pgconn.StatementDescription conn *Conn } func (s *Stmt) Close() error { - return s.conn.conn.Deallocate(s.ps.Name) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + refCount := s.conn.psRefCounts[s.sd] + if refCount == 1 { + delete(s.conn.psRefCounts, s.sd) + } else { + s.conn.psRefCounts[s.sd]-- + return nil + } + + return s.conn.conn.Deallocate(ctx, s.sd.SQL) } func (s *Stmt) NumInput() int { - return len(s.ps.ParameterOIDs) + return len(s.sd.ParamOIDs) } func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { - return s.conn.Exec(s.ps.Name, argsV) + return nil, errors.New("Stmt.Exec deprecated and not implemented") } func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) { - return s.conn.ExecContext(ctx, s.ps.Name, argsV) + return s.conn.ExecContext(ctx, s.sd.SQL, argsV) } func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { - return s.conn.queryPrepared(s.ps.Name, argsV) + return nil, errors.New("Stmt.Query deprecated and not implemented") } func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) { - return s.conn.queryPreparedContext(ctx, s.ps.Name, argsV) + return s.conn.QueryContext(ctx, s.sd.SQL, argsV) } +type rowValueFunc func(src []byte) (driver.Value, error) + type Rows struct { - rows *pgx.Rows - values []interface{} + conn *Conn + rows pgx.Rows + valueFuncs []rowValueFunc skipNext bool skipNextMore bool + + columnNames []string } func (r *Rows) Columns() []string { - fieldDescriptions := r.rows.FieldDescriptions() - names := make([]string, 0, len(fieldDescriptions)) - for _, fd := range fieldDescriptions { - names = append(names, fd.Name) + if r.columnNames == nil { + fields := r.rows.FieldDescriptions() + r.columnNames = make([]string, len(fields)) + for i, fd := range fields { + r.columnNames[i] = string(fd.Name) + } } - return names + + return r.columnNames } -// ColumnTypeDatabaseTypeName return the database system type name. +// ColumnTypeDatabaseTypeName returns the database system type name. If the name is unknown the OID is returned. func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { - return strings.ToUpper(r.rows.FieldDescriptions()[index].DataTypeName) + if dt, ok := r.conn.conn.TypeMap().TypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok { + return strings.ToUpper(dt.Name) + } + + return strconv.FormatInt(int64(r.rows.FieldDescriptions()[index].DataTypeOID), 10) } +const varHeaderSize = 4 + // ColumnTypeLength returns the length of the column type if the column is a // variable length type. If the column is not a variable length type ok // should return false. func (r *Rows) ColumnTypeLength(index int) (int64, bool) { - return r.rows.FieldDescriptions()[index].Length() + fd := r.rows.FieldDescriptions()[index] + + switch fd.DataTypeOID { + case pgtype.TextOID, pgtype.ByteaOID: + return math.MaxInt64, true + case pgtype.VarcharOID, pgtype.BPCharArrayOID: + return int64(fd.TypeModifier - varHeaderSize), true + default: + return 0, false + } } // ColumnTypePrecisionScale should return the precision and scale for decimal // types. If not applicable, ok should be false. func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { - return r.rows.FieldDescriptions()[index].PrecisionScale() + fd := r.rows.FieldDescriptions()[index] + + switch fd.DataTypeOID { + case pgtype.NumericOID: + mod := fd.TypeModifier - varHeaderSize + precision = int64((mod >> 16) & 0xffff) + scale = int64(mod & 0xffff) + return precision, scale, true + default: + return 0, 0, false + } } // ColumnTypeScanType returns the value type that can be used to scan types into. func (r *Rows) ColumnTypeScanType(index int) reflect.Type { - return r.rows.FieldDescriptions()[index].Type() + fd := r.rows.FieldDescriptions()[index] + + switch fd.DataTypeOID { + case pgtype.Float8OID: + return reflect.TypeOf(float64(0)) + case pgtype.Float4OID: + return reflect.TypeOf(float32(0)) + case pgtype.Int8OID: + return reflect.TypeOf(int64(0)) + case pgtype.Int4OID: + return reflect.TypeOf(int32(0)) + case pgtype.Int2OID: + return reflect.TypeOf(int16(0)) + case pgtype.BoolOID: + return reflect.TypeOf(false) + case pgtype.NumericOID: + return reflect.TypeOf(float64(0)) + case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID: + return reflect.TypeOf(time.Time{}) + case pgtype.ByteaOID: + return reflect.TypeOf([]byte(nil)) + default: + return reflect.TypeOf("") + } } func (r *Rows) Close() error { r.rows.Close() - return nil + return r.rows.Err() } func (r *Rows) Next(dest []driver.Value) error { - if r.values == nil { - r.values = make([]interface{}, len(r.rows.FieldDescriptions())) - for i, fd := range r.rows.FieldDescriptions() { - switch fd.DataType { + m := r.conn.conn.TypeMap() + fieldDescriptions := r.rows.FieldDescriptions() + + if r.valueFuncs == nil { + r.valueFuncs = make([]rowValueFunc, len(fieldDescriptions)) + + for i, fd := range fieldDescriptions { + dataTypeOID := fd.DataTypeOID + format := fd.Format + + switch fd.DataTypeOID { case pgtype.BoolOID: - r.values[i] = &pgtype.Bool{} + var d bool + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + return d, err + } case pgtype.ByteaOID: - r.values[i] = &pgtype.Bytea{} - case pgtype.CIDOID: - r.values[i] = &pgtype.CID{} + var d []byte + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + return d, err + } + case pgtype.CIDOID, pgtype.OIDOID, pgtype.XIDOID: + var d pgtype.Uint32 + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + if err != nil { + return nil, err + } + return d.Value() + } case pgtype.DateOID: - r.values[i] = &pgtype.Date{} + var d pgtype.Date + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + if err != nil { + return nil, err + } + return d.Value() + } case pgtype.Float4OID: - r.values[i] = &pgtype.Float4{} + var d float32 + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + return float64(d), err + } case pgtype.Float8OID: - r.values[i] = &pgtype.Float8{} + var d float64 + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + return d, err + } case pgtype.Int2OID: - r.values[i] = &pgtype.Int2{} + var d int16 + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + return int64(d), err + } case pgtype.Int4OID: - r.values[i] = &pgtype.Int4{} + var d int32 + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + return int64(d), err + } case pgtype.Int8OID: - r.values[i] = &pgtype.Int8{} - case pgtype.OIDOID: - r.values[i] = &pgtype.OIDValue{} + var d int64 + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + return d, err + } + case pgtype.JSONOID, pgtype.JSONBOID: + var d []byte + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + if err != nil { + return nil, err + } + return d, nil + } case pgtype.TimestampOID: - r.values[i] = &pgtype.Timestamp{} + var d pgtype.Timestamp + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + if err != nil { + return nil, err + } + return d.Value() + } case pgtype.TimestamptzOID: - r.values[i] = &pgtype.Timestamptz{} - case pgtype.XIDOID: - r.values[i] = &pgtype.XID{} + var d pgtype.Timestamptz + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + if err != nil { + return nil, err + } + return d.Value() + } + case pgtype.XMLOID: + var d []byte + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + if err != nil { + return nil, err + } + return d, nil + } default: - r.values[i] = &pgtype.GenericText{} + var d string + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + return d, err + } } } } @@ -525,85 +878,42 @@ func (r *Rows) Next(dest []driver.Value) error { } } - err := r.rows.Scan(r.values...) - if err != nil { - return err - } - - for i, v := range r.values { - dest[i], err = v.(driver.Valuer).Value() - if err != nil { - return err + for i, rv := range r.rows.RawValues() { + if rv != nil { + var err error + dest[i], err = r.valueFuncs[i](rv) + if err != nil { + return fmt.Errorf("convert field %d failed: %w", i, err) + } + } else { + dest[i] = nil } } return nil } -func valueToInterface(argsV []driver.Value) []interface{} { - args := make([]interface{}, 0, len(argsV)) - for _, v := range argsV { - if v != nil { - args = append(args, v.(interface{})) - } else { - args = append(args, nil) - } - } - return args -} - -func namedValueToInterface(argsV []driver.NamedValue) []interface{} { - args := make([]interface{}, 0, len(argsV)) - for _, v := range argsV { +func convertNamedArguments(args []any, argsV []driver.NamedValue) { + for i, v := range argsV { if v.Value != nil { - args = append(args, v.Value.(interface{})) + args[i] = v.Value.(any) } else { - args = append(args, nil) + args[i] = nil } } - return args } -type fakeTx struct{} - -func (fakeTx) Commit() error { return nil } - -func (fakeTx) Rollback() error { return nil } - -func AcquireConn(db *sql.DB) (*pgx.Conn, error) { - driver, ok := db.Driver().(*Driver) - if !ok { - return nil, ErrNotPgx - } - - var conn *pgx.Conn - ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn) - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - - driver.fakeTxMutex.Lock() - driver.fakeTxConns[conn] = tx - driver.fakeTxMutex.Unlock() - - return conn, nil +func (r *Rows) ScanColumn(dest any, index int) error { + m := r.conn.conn.TypeMap() + fd := r.rows.FieldDescriptions()[index] + return m.Scan(fd.DataTypeOID, fd.Format, r.rows.RawValues()[index], dest) } -func ReleaseConn(db *sql.DB, conn *pgx.Conn) error { - var tx *sql.Tx - var ok bool +type wrapTx struct { + ctx context.Context + tx pgx.Tx +} - driver := db.Driver().(*Driver) - driver.fakeTxMutex.Lock() - tx, ok = driver.fakeTxConns[conn] - if ok { - delete(driver.fakeTxConns, conn) - driver.fakeTxMutex.Unlock() - } else { - driver.fakeTxMutex.Unlock() - return errors.Errorf("can't release conn that is not acquired") - } +func (wtx wrapTx) Commit() error { return wtx.tx.Commit(wtx.ctx) } - return tx.Rollback() -} +func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(wtx.ctx) } diff --git a/stdlib/sql_go1.26_test.go b/stdlib/sql_go1.26_test.go new file mode 100644 index 000000000..c8b9aaa2e --- /dev/null +++ b/stdlib/sql_go1.26_test.go @@ -0,0 +1,168 @@ +//go:build go1.26 + +package stdlib_test + +import ( + "database/sql" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/jackc/pgx/v5/pgtype" +) + +func TestGoArray(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + var names []string + + err := db.QueryRow("select array['John', 'Jane']::text[]").Scan(&names) + require.NoError(t, err) + require.Equal(t, []string{"John", "Jane"}, names) + + var n int + err = db.QueryRow("select cardinality($1::text[])", names).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 2, n) + + err = db.QueryRow("select null::text[]").Scan(&names) + require.NoError(t, err) + require.Nil(t, names) + }) +} + +func TestGoArrayOfDriverValuer(t *testing.T) { + // Because []sql.NullString is not a registered type on the connection, it will only work with known OIDs. + testWithKnownOIDQueryExecModes(t, func(t *testing.T, db *sql.DB) { + var names []sql.NullString + + err := db.QueryRow("select array['John', null, 'Jane']::text[]").Scan(&names) + require.NoError(t, err) + require.Equal(t, []sql.NullString{{String: "John", Valid: true}, {}, {String: "Jane", Valid: true}}, names) + + var n int + err = db.QueryRow("select cardinality($1::text[])", names).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 3, n) + + err = db.QueryRow("select null::text[]").Scan(&names) + require.NoError(t, err) + require.Nil(t, names) + }) +} + +func TestPGTypeFlatArray(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + var names pgtype.FlatArray[string] + + err := db.QueryRow("select array['John', 'Jane']::text[]").Scan(&names) + require.NoError(t, err) + require.Equal(t, pgtype.FlatArray[string]{"John", "Jane"}, names) + + var n int + err = db.QueryRow("select cardinality($1::text[])", names).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 2, n) + + err = db.QueryRow("select null::text[]").Scan(&names) + require.NoError(t, err) + require.Nil(t, names) + }) +} + +func TestPGTypeArray(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + skipCockroachDB(t, db, "Server does not support nested arrays") + + var matrix pgtype.Array[int64] + + err := db.QueryRow("select '{{1,2,3},{4,5,6}}'::bigint[]").Scan(&matrix) + require.NoError(t, err) + require.Equal(t, + pgtype.Array[int64]{ + Elements: []int64{1, 2, 3, 4, 5, 6}, + Dims: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 1}, + {Length: 3, LowerBound: 1}, + }, + Valid: true}, + matrix) + + var equal bool + err = db.QueryRow("select '{{1,2,3},{4,5,6}}'::bigint[] = $1::bigint[]", matrix).Scan(&equal) + require.NoError(t, err) + require.Equal(t, true, equal) + + err = db.QueryRow("select null::bigint[]").Scan(&matrix) + require.NoError(t, err) + assert.Equal(t, pgtype.Array[int64]{Elements: nil, Dims: nil, Valid: false}, matrix) + }) +} + +func TestConnQueryPGTypeRange(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + skipCockroachDB(t, db, "Server does not support int4range") + + var r pgtype.Range[pgtype.Int4] + err := db.QueryRow("select int4range(1, 5)").Scan(&r) + require.NoError(t, err) + assert.Equal( + t, + pgtype.Range[pgtype.Int4]{ + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + r) + + var equal bool + err = db.QueryRow("select int4range(1, 5) = $1::int4range", r).Scan(&equal) + require.NoError(t, err) + require.Equal(t, true, equal) + + err = db.QueryRow("select null::int4range").Scan(&r) + require.NoError(t, err) + assert.Equal(t, pgtype.Range[pgtype.Int4]{}, r) + }) +} + +func TestConnQueryPGTypeMultirange(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + skipCockroachDB(t, db, "Server does not support int4range") + skipPostgreSQLVersionLessThan(t, db, 14) + + var r pgtype.Multirange[pgtype.Range[pgtype.Int4]] + err := db.QueryRow("select int4multirange(int4range(1, 5), int4range(7,9))").Scan(&r) + require.NoError(t, err) + assert.Equal( + t, + pgtype.Multirange[pgtype.Range[pgtype.Int4]]{ + { + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + { + Lower: pgtype.Int4{Int32: 7, Valid: true}, + Upper: pgtype.Int4{Int32: 9, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, + r) + + var equal bool + err = db.QueryRow("select int4multirange(int4range(1, 5), int4range(7,9)) = $1::int4multirange", r).Scan(&equal) + require.NoError(t, err) + require.Equal(t, true, equal) + + err = db.QueryRow("select null::int4multirange").Scan(&r) + require.NoError(t, err) + require.Nil(t, r) + }) +} diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index a4a999718..84f48f6c9 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -4,42 +4,144 @@ import ( "bytes" "context" "database/sql" + "encoding/json" "fmt" "math" + "os" "reflect" + "regexp" + "strconv" + "sync" "testing" "time" - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgmock" - "github.com/jackc/pgx/pgproto3" - "github.com/jackc/pgx/stdlib" -) + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" -func openDB(t *testing.T) *sql.DB { - db, err := sql.Open("pgx", "postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test") - if err != nil { - t.Fatalf("sql.Open failed: %v", err) - } + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/stdlib" + "github.com/jackc/pgx/v5/tracelog" +) - return db +func openDB(t testing.TB, opts ...stdlib.OptionOpenDB) *sql.DB { + t.Helper() + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + return stdlib.OpenDB(*config, opts...) } -func closeDB(t *testing.T, db *sql.DB) { +func closeDB(t testing.TB, db *sql.DB) { err := db.Close() - if err != nil { - t.Fatalf("db.Close unexpectedly failed: %v", err) + require.NoError(t, err) +} + +func skipCockroachDB(t testing.TB, db *sql.DB, msg string) { + conn, err := db.Conn(context.Background()) + require.NoError(t, err) + defer conn.Close() + + err = conn.Raw(func(driverConn any) error { + conn := driverConn.(*stdlib.Conn).Conn() + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip(msg) + } + return nil + }) + require.NoError(t, err) +} + +func skipPostgreSQLVersionLessThan(t testing.TB, db *sql.DB, minVersion int64) { + conn, err := db.Conn(context.Background()) + require.NoError(t, err) + defer conn.Close() + + err = conn.Raw(func(driverConn any) error { + conn := driverConn.(*stdlib.Conn).Conn() + serverVersionStr := conn.PgConn().ParameterStatus("server_version") + serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) + // if not PostgreSQL do nothing + if serverVersionStr == "" { + return nil + } + + serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64) + if err != nil { + return err + } + + if serverVersion < minVersion { + t.Skipf("Test requires PostgreSQL v%d+", minVersion) + } + + return nil + }) + require.NoError(t, err) +} + +func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, db *sql.DB)) { + for _, mode := range []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + pgx.QueryExecModeExec, + pgx.QueryExecModeSimpleProtocol, + } { + t.Run(mode.String(), + func(t *testing.T) { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.DefaultQueryExecMode = mode + db := stdlib.OpenDB(*config) + defer func() { + err := db.Close() + require.NoError(t, err) + }() + + f(t, db) + + ensureDBValid(t, db) + }, + ) + } +} + +func testWithKnownOIDQueryExecModes(t *testing.T, f func(t *testing.T, db *sql.DB)) { + for _, mode := range []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + } { + t.Run(mode.String(), + func(t *testing.T) { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.DefaultQueryExecMode = mode + db := stdlib.OpenDB(*config) + defer func() { + err := db.Close() + require.NoError(t, err) + }() + + f(t, db) + + ensureDBValid(t, db) + }, + ) } } -// Do a simple query to ensure the connection is still usable -func ensureConnValid(t *testing.T, db *sql.DB) { +// Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should +// cover broken connections. +func ensureDBValid(t testing.TB, db *sql.DB) { var sum, rowCount int32 rows, err := db.Query("select generate_series(1,$1)", 10) - if err != nil { - t.Fatalf("db.Query failed: %v", err) - } + require.NoError(t, err) defer rows.Close() for rows.Next() { @@ -49,9 +151,7 @@ func ensureConnValid(t *testing.T, db *sql.DB) { rowCount++ } - if rows.Err() != nil { - t.Fatalf("db.Query failed: %v", err) - } + require.NoError(t, rows.Err()) if rowCount != 10 { t.Error("Select called onDataRow wrong number of times") @@ -67,31 +167,56 @@ type preparer interface { func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt { stmt, err := p.Prepare(sql) - if err != nil { - t.Fatalf("%v Prepare unexpectedly failed: %v", p, err) - } - + require.NoError(t, err) return stmt } func closeStmt(t *testing.T, stmt *sql.Stmt) { err := stmt.Close() - if err != nil { - t.Fatalf("stmt.Close unexpectedly failed: %v", err) + require.NoError(t, err) +} + +func TestSQLOpen(t *testing.T) { + tests := []struct { + driverName string + }{ + {driverName: "pgx"}, + {driverName: "pgx/v5"}, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.driverName, func(t *testing.T) { + db, err := sql.Open(tt.driverName, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + closeDB(t, db) + }) } } +func TestSQLOpenFromPool(t *testing.T) { + pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + t.Cleanup(pool.Close) + + db := stdlib.OpenDBFromPool(pool) + ensureDBValid(t, db) + + db.Close() +} + func TestNormalLifeCycle(t *testing.T) { db := openDB(t) defer closeDB(t, db) + skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") + stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n") defer closeStmt(t, stmt) rows, err := stmt.Query(int32(1), int32(10)) - if err != nil { - t.Fatalf("stmt.Query unexpectedly failed: %v", err) - } + require.NoError(t, err) rowCount := int64(0) @@ -100,9 +225,9 @@ func TestNormalLifeCycle(t *testing.T) { var s string var n int64 - if err := rows.Scan(&s, &n); err != nil { - t.Fatalf("rows.Scan unexpectedly failed: %v", err) - } + err := rows.Scan(&s, &n) + require.NoError(t, err) + if s != "foo" { t.Errorf(`Expected "foo", received "%v"`, s) } @@ -110,47 +235,14 @@ func TestNormalLifeCycle(t *testing.T) { t.Errorf("Expected %d, received %d", rowCount, n) } } - err = rows.Err() - if err != nil { - t.Fatalf("rows.Err unexpectedly is: %v", err) - } - if rowCount != 10 { - t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount) - } + require.NoError(t, rows.Err()) - err = rows.Close() - if err != nil { - t.Fatalf("rows.Close unexpectedly failed: %v", err) - } - - ensureConnValid(t, db) -} - -func TestOpenWithDriverConfigAfterConnect(t *testing.T) { - driverConfig := stdlib.DriverConfig{ - AfterConnect: func(c *pgx.Conn) error { - _, err := c.Exec("create temporary sequence pgx") - return err - }, - } - - stdlib.RegisterDriverConfig(&driverConfig) - defer stdlib.UnregisterDriverConfig(&driverConfig) + require.EqualValues(t, 10, rowCount) - db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")) - if err != nil { - t.Fatalf("sql.Open failed: %v", err) - } - defer closeDB(t, db) + err = rows.Close() + require.NoError(t, err) - var n int64 - err = db.QueryRow("select nextval('pgx')").Scan(&n) - if err != nil { - t.Fatalf("db.QueryRow unexpectedly failed: %v", err) - } - if n != 1 { - t.Fatalf("n => %d, want %d", n, 1) - } + ensureDBValid(t, db) } func TestStmtExec(t *testing.T) { @@ -158,62 +250,44 @@ func TestStmtExec(t *testing.T) { defer closeDB(t, db) tx, err := db.Begin() - if err != nil { - t.Fatalf("db.Begin unexpectedly failed: %v", err) - } + require.NoError(t, err) createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)") _, err = createStmt.Exec() - if err != nil { - t.Fatalf("stmt.Exec unexpectedly failed: %v", err) - } + require.NoError(t, err) closeStmt(t, createStmt) insertStmt := prepareStmt(t, tx, "insert into t values($1::text)") result, err := insertStmt.Exec("foo") - if err != nil { - t.Fatalf("stmt.Exec unexpectedly failed: %v", err) - } + require.NoError(t, err) n, err := result.RowsAffected() - if err != nil { - t.Fatalf("result.RowsAffected unexpectedly failed: %v", err) - } - if n != 1 { - t.Fatalf("Expected 1, received %d", n) - } + require.NoError(t, err) + require.EqualValues(t, 1, n) closeStmt(t, insertStmt) - if err != nil { - t.Fatalf("tx.Commit unexpectedly failed: %v", err) - } - - ensureConnValid(t, db) + ensureDBValid(t, db) } func TestQueryCloseRowsEarly(t *testing.T) { db := openDB(t) defer closeDB(t, db) + skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") + stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n") defer closeStmt(t, stmt) rows, err := stmt.Query(int32(1), int32(10)) - if err != nil { - t.Fatalf("stmt.Query unexpectedly failed: %v", err) - } + require.NoError(t, err) // Close rows immediately without having read them err = rows.Close() - if err != nil { - t.Fatalf("rows.Close unexpectedly failed: %v", err) - } + require.NoError(t, err) // Run the query again to ensure the connection and statement are still ok rows, err = stmt.Query(int32(1), int32(10)) - if err != nil { - t.Fatalf("stmt.Query unexpectedly failed: %v", err) - } + require.NoError(t, err) rowCount := int64(0) @@ -222,9 +296,8 @@ func TestQueryCloseRowsEarly(t *testing.T) { var s string var n int64 - if err := rows.Scan(&s, &n); err != nil { - t.Fatalf("rows.Scan unexpectedly failed: %v", err) - } + err := rows.Scan(&s, &n) + require.NoError(t, err) if s != "foo" { t.Errorf(`Expected "foo", received "%v"`, s) } @@ -232,289 +305,299 @@ func TestQueryCloseRowsEarly(t *testing.T) { t.Errorf("Expected %d, received %d", rowCount, n) } } - err = rows.Err() - if err != nil { - t.Fatalf("rows.Err unexpectedly is: %v", err) - } - if rowCount != 10 { - t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount) - } + require.NoError(t, rows.Err()) + require.EqualValues(t, 10, rowCount) err = rows.Close() - if err != nil { - t.Fatalf("rows.Close unexpectedly failed: %v", err) - } + require.NoError(t, err) - ensureConnValid(t, db) + ensureDBValid(t, db) } func TestConnExec(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + _, err := db.Exec("create temporary table t(a varchar not null)") + require.NoError(t, err) - _, err := db.Exec("create temporary table t(a varchar not null)") - if err != nil { - t.Fatalf("db.Exec unexpectedly failed: %v", err) - } - - result, err := db.Exec("insert into t values('hey')") - if err != nil { - t.Fatalf("db.Exec unexpectedly failed: %v", err) - } - - n, err := result.RowsAffected() - if err != nil { - t.Fatalf("result.RowsAffected unexpectedly failed: %v", err) - } - if n != 1 { - t.Fatalf("Expected 1, received %d", n) - } + result, err := db.Exec("insert into t values('hey')") + require.NoError(t, err) - ensureConnValid(t, db) + n, err := result.RowsAffected() + require.NoError(t, err) + require.EqualValues(t, 1, n) + }) } func TestConnQuery(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") - rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10)) - if err != nil { - t.Fatalf("db.Query unexpectedly failed: %v", err) - } + rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10)) + require.NoError(t, err) - rowCount := int64(0) + rowCount := int64(0) - for rows.Next() { - rowCount++ + for rows.Next() { + rowCount++ - var s string - var n int64 - if err := rows.Scan(&s, &n); err != nil { - t.Fatalf("rows.Scan unexpectedly failed: %v", err) - } - if s != "foo" { - t.Errorf(`Expected "foo", received "%v"`, s) - } - if n != rowCount { - t.Errorf("Expected %d, received %d", rowCount, n) + var s string + var n int64 + err := rows.Scan(&s, &n) + require.NoError(t, err) + if s != "foo" { + t.Errorf(`Expected "foo", received "%v"`, s) + } + if n != rowCount { + t.Errorf("Expected %d, received %d", rowCount, n) + } } - } - err = rows.Err() - if err != nil { - t.Fatalf("rows.Err unexpectedly is: %v", err) - } - if rowCount != 10 { - t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount) - } - - err = rows.Close() - if err != nil { - t.Fatalf("rows.Close unexpectedly failed: %v", err) - } - - ensureConnValid(t, db) -} + require.NoError(t, rows.Err()) + require.EqualValues(t, 10, rowCount) -type testLog struct { - lvl pgx.LogLevel - msg string - data map[string]interface{} -} - -type testLogger struct { - logs []testLog -} - -func (l *testLogger) Log(lvl pgx.LogLevel, msg string, data map[string]interface{}) { - l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data}) + err = rows.Close() + require.NoError(t, err) + }) } -func TestConnQueryLog(t *testing.T) { - logger := &testLogger{} +func TestConnConcurrency(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + _, err := db.Exec("create table t (id integer primary key, str text, dur_str interval)") + require.NoError(t, err) - driverConfig := stdlib.DriverConfig{ - ConnConfig: pgx.ConnConfig{ - Host: "127.0.0.1", - User: "pgx_md5", - Password: "secret", - Database: "pgx_test", - Logger: logger, - }, - } + defer func() { + _, err := db.Exec("drop table t") + require.NoError(t, err) + }() - stdlib.RegisterDriverConfig(&driverConfig) - defer stdlib.UnregisterDriverConfig(&driverConfig) + var wg sync.WaitGroup + + concurrency := 50 + errChan := make(chan error, concurrency) + + for i := 1; i <= concurrency; i++ { + wg.Add(1) + + go func(idx int) { + defer wg.Done() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + str := strconv.Itoa(idx) + duration := time.Duration(idx) * time.Second + _, err := db.ExecContext(ctx, "insert into t values($1)", idx) + if err != nil { + errChan <- fmt.Errorf("insert failed: %d %w", idx, err) + return + } + _, err = db.ExecContext(ctx, "update t set str = $1 where id = $2", str, idx) + if err != nil { + errChan <- fmt.Errorf("update 1 failed: %d %w", idx, err) + return + } + _, err = db.ExecContext(ctx, "update t set dur_str = $1 where id = $2", duration, idx) + if err != nil { + errChan <- fmt.Errorf("update 2 failed: %d %w", idx, err) + return + } + + errChan <- nil + }(i) + } + wg.Wait() + for i := 1; i <= concurrency; i++ { + err := <-errChan + require.NoError(t, err) + } - db, err := sql.Open("pgx", driverConfig.ConnectionString("")) - if err != nil { - t.Fatalf("sql.Open failed: %v", err) - } - defer closeDB(t, db) + for i := 1; i <= concurrency; i++ { + wg.Add(1) + + go func(idx int) { + defer wg.Done() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + var id int + var str string + var duration pgtype.Interval + err := db.QueryRowContext(ctx, "select id,str,dur_str from t where id = $1", idx).Scan(&id, &str, &duration) + if err != nil { + errChan <- fmt.Errorf("select failed: %d %w", idx, err) + return + } + if id != idx { + errChan <- fmt.Errorf("id mismatch: %d %d", idx, id) + return + } + if str != strconv.Itoa(idx) { + errChan <- fmt.Errorf("str mismatch: %d %s", idx, str) + return + } + expectedDuration := pgtype.Interval{ + Microseconds: int64(idx) * time.Second.Microseconds(), + Valid: true, + } + if duration != expectedDuration { + errChan <- fmt.Errorf("duration mismatch: %d %v", idx, duration) + return + } + + errChan <- nil + }(i) + } + wg.Wait() + for i := 1; i <= concurrency; i++ { + err := <-errChan + require.NoError(t, err) + } + }) +} - var n int64 - err = db.QueryRow("select 1").Scan(&n) - if err != nil { - t.Fatalf("db.QueryRow unexpectedly failed: %v", err) - } +// https://github.com/jackc/pgx/issues/781 +func TestConnQueryDifferentScanPlansIssue781(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + var s string + var b bool - l := logger.logs[len(logger.logs)-1] - if l.msg != "Query" { - t.Errorf("Expected to log Query, but got %v", l) - } + rows, err := db.Query("select true, 'foo'") + require.NoError(t, err) - if l.data["sql"] != "select 1" { - t.Errorf("Expected to log Query with sql 'select 1', but got %v", l) - } + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&b, &s)) + assert.Equal(t, true, b) + assert.Equal(t, "foo", s) + }) } func TestConnQueryNull(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + rows, err := db.Query("select $1::int", nil) + require.NoError(t, err) - rows, err := db.Query("select $1::int", nil) - if err != nil { - t.Fatalf("db.Query unexpectedly failed: %v", err) - } - - rowCount := int64(0) + rowCount := int64(0) - for rows.Next() { - rowCount++ + for rows.Next() { + rowCount++ - var n sql.NullInt64 - if err := rows.Scan(&n); err != nil { - t.Fatalf("rows.Scan unexpectedly failed: %v", err) - } - if n.Valid != false { - t.Errorf("Expected n to be null, but it was %v", n) + var n sql.NullInt64 + err := rows.Scan(&n) + require.NoError(t, err) + if n.Valid != false { + t.Errorf("Expected n to be null, but it was %v", n) + } } - } - err = rows.Err() - if err != nil { - t.Fatalf("rows.Err unexpectedly is: %v", err) - } - if rowCount != 1 { - t.Fatalf("Expected to receive 11 rows, instead received %d", rowCount) - } - - err = rows.Close() - if err != nil { - t.Fatalf("rows.Close unexpectedly failed: %v", err) - } + require.NoError(t, rows.Err()) + require.EqualValues(t, 1, rowCount) - ensureConnValid(t, db) + err = rows.Close() + require.NoError(t, err) + }) } func TestConnQueryRowByteSlice(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - expected := []byte{222, 173, 190, 239} - var actual []byte + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + expected := []byte{222, 173, 190, 239} + var actual []byte + + err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + }) +} - err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual) - if err != nil { - t.Fatalf("db.QueryRow unexpectedly failed: %v", err) - } +func TestConnQueryFailure(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + _, err := db.Query("select 'foo") + require.Error(t, err) + require.IsType(t, new(pgconn.PgError), err) + }) +} - if bytes.Compare(actual, expected) != 0 { - t.Fatalf("Expected %v, but got %v", expected, actual) - } +func TestConnSimpleSlicePassThrough(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + skipCockroachDB(t, db, "Server does not support cardinality function") - ensureConnValid(t, db) + var n int64 + err := db.QueryRow("select cardinality($1::text[])", []string{"a", "b", "c"}).Scan(&n) + require.NoError(t, err) + assert.EqualValues(t, 3, n) + }) } -func TestConnQueryFailure(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - _, err := db.Query("select 'foo") - if _, ok := err.(pgx.PgError); !ok { - t.Fatalf("Expected db.Query to return pgx.PgError, but instead received: %v", err) - } +func TestConnQueryScanGoArray(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + m := pgtype.NewMap() - ensureConnValid(t, db) + var a []int64 + err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) + require.NoError(t, err) + assert.Equal(t, []int64{1, 2, 3}, a) + }) } // Test type that pgx would handle natively in binary, but since it is not a // database/sql native type should be passed through as a string func TestConnQueryRowPgxBinary(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - sql := "select $1::int4[]" - expected := "{1,2,3}" - var actual string - - err := db.QueryRow(sql, expected).Scan(&actual) - if err != nil { - t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) - } - - if actual != expected { - t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql) - } - - ensureConnValid(t, db) + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + sql := "select $1::int4[]" + expected := "{1,2,3}" + var actual string + + err := db.QueryRow(sql, expected).Scan(&actual) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + }) } func TestConnQueryRowUnknownType(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - sql := "select $1::point" - expected := "(1,2)" - var actual string + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + skipCockroachDB(t, db, "Server does not support point type") - err := db.QueryRow(sql, expected).Scan(&actual) - if err != nil { - t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) - } - - if actual != expected { - t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql) - } + sql := "select $1::point" + expected := "(1,2)" + var actual string - ensureConnValid(t, db) + err := db.QueryRow(sql, expected).Scan(&actual) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + }) } func TestConnQueryJSONIntoByteSlice(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - _, err := db.Exec(` + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + _, err := db.Exec(` create temporary table docs( body json not null ); - insert into docs(body) values('{"foo":"bar"}'); + insert into docs(body) values('{"foo": "bar"}'); `) - if err != nil { - t.Fatalf("db.Exec unexpectedly failed: %v", err) - } + require.NoError(t, err) - sql := `select * from docs` - expected := []byte(`{"foo":"bar"}`) - var actual []byte - - err = db.QueryRow(sql).Scan(&actual) - if err != nil { - t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) - } + sql := `select * from docs` + expected := []byte(`{"foo": "bar"}`) + var actual []byte - if bytes.Compare(actual, expected) != 0 { - t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql) - } + err = db.QueryRow(sql).Scan(&actual) + if err != nil { + t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) + } - _, err = db.Exec(`drop table docs`) - if err != nil { - t.Fatalf("db.Exec unexpectedly failed: %v", err) - } + if !bytes.Equal(actual, expected) { + t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql) + } - ensureConnValid(t, db) + _, err = db.Exec(`drop table docs`) + require.NoError(t, err) + }) } func TestConnExecInsertByteSliceIntoJSON(t *testing.T) { + // Not testing with simple protocol because there is no way for that to work. A []byte will be considered binary data + // that needs to escape. No way to know whether the destination is really a text compatible or a bytea. + db := openDB(t) defer closeDB(t, db) @@ -523,477 +606,310 @@ func TestConnExecInsertByteSliceIntoJSON(t *testing.T) { body json not null ); `) - if err != nil { - t.Fatalf("db.Exec unexpectedly failed: %v", err) - } + require.NoError(t, err) - expected := []byte(`{"foo":"bar"}`) + expected := []byte(`{"foo": "bar"}`) _, err = db.Exec(`insert into docs(body) values($1)`, expected) - if err != nil { - t.Fatalf("db.Exec unexpectedly failed: %v", err) - } + require.NoError(t, err) var actual []byte err = db.QueryRow(`select body from docs`).Scan(&actual) - if err != nil { - t.Fatalf("db.QueryRow unexpectedly failed: %v", err) - } + require.NoError(t, err) - if bytes.Compare(actual, expected) != 0 { + if !bytes.Equal(actual, expected) { t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual)) } _, err = db.Exec(`drop table docs`) - if err != nil { - t.Fatalf("db.Exec unexpectedly failed: %v", err) - } - - ensureConnValid(t, db) + require.NoError(t, err) } func TestTransactionLifeCycle(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - _, err := db.Exec("create temporary table t(a varchar not null)") - if err != nil { - t.Fatalf("db.Exec unexpectedly failed: %v", err) - } - - tx, err := db.Begin() - if err != nil { - t.Fatalf("db.Begin unexpectedly failed: %v", err) - } + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + _, err := db.Exec("create temporary table t(a varchar not null)") + require.NoError(t, err) - _, err = tx.Exec("insert into t values('hi')") - if err != nil { - t.Fatalf("tx.Exec unexpectedly failed: %v", err) - } + tx, err := db.Begin() + require.NoError(t, err) - err = tx.Rollback() - if err != nil { - t.Fatalf("tx.Rollback unexpectedly failed: %v", err) - } + _, err = tx.Exec("insert into t values('hi')") + require.NoError(t, err) - var n int64 - err = db.QueryRow("select count(*) from t").Scan(&n) - if err != nil { - t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err) - } - if n != 0 { - t.Fatalf("Expected 0 rows due to rollback, instead found %d", n) - } + err = tx.Rollback() + require.NoError(t, err) - tx, err = db.Begin() - if err != nil { - t.Fatalf("db.Begin unexpectedly failed: %v", err) - } + var n int64 + err = db.QueryRow("select count(*) from t").Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 0, n) - _, err = tx.Exec("insert into t values('hi')") - if err != nil { - t.Fatalf("tx.Exec unexpectedly failed: %v", err) - } + tx, err = db.Begin() + require.NoError(t, err) - err = tx.Commit() - if err != nil { - t.Fatalf("tx.Commit unexpectedly failed: %v", err) - } + _, err = tx.Exec("insert into t values('hi')") + require.NoError(t, err) - err = db.QueryRow("select count(*) from t").Scan(&n) - if err != nil { - t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err) - } - if n != 1 { - t.Fatalf("Expected 1 rows due to rollback, instead found %d", n) - } + err = tx.Commit() + require.NoError(t, err) - ensureConnValid(t, db) + err = db.QueryRow("select count(*) from t").Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + }) } func TestConnBeginTxIsolation(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - var defaultIsoLevel string - err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel) - if err != nil { - t.Fatalf("QueryRow failed: %v", err) - } + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + skipCockroachDB(t, db, "Server always uses serializable isolation level") + + var defaultIsoLevel string + err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel) + require.NoError(t, err) + + supportedTests := []struct { + sqlIso sql.IsolationLevel + pgIso string + }{ + {sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel}, + {sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"}, + {sqlIso: sql.LevelReadCommitted, pgIso: "read committed"}, + {sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"}, + {sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"}, + {sqlIso: sql.LevelSerializable, pgIso: "serializable"}, + } + for i, tt := range supportedTests { + func() { + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) + if err != nil { + t.Errorf("%d. BeginTx failed: %v", i, err) + return + } + defer tx.Rollback() + + var pgIso string + err = tx.QueryRow("show transaction_isolation").Scan(&pgIso) + if err != nil { + t.Errorf("%d. QueryRow failed: %v", i, err) + } + + if pgIso != tt.pgIso { + t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso) + } + }() + } - supportedTests := []struct { - sqlIso sql.IsolationLevel - pgIso string - }{ - {sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel}, - {sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"}, - {sqlIso: sql.LevelReadCommitted, pgIso: "read committed"}, - {sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"}, - {sqlIso: sql.LevelSerializable, pgIso: "serializable"}, - } - for i, tt := range supportedTests { - func() { + unsupportedTests := []struct { + sqlIso sql.IsolationLevel + }{ + {sqlIso: sql.LevelWriteCommitted}, + {sqlIso: sql.LevelLinearizable}, + } + for i, tt := range unsupportedTests { tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) - if err != nil { - t.Errorf("%d. BeginTx failed: %v", i, err) - return - } - defer tx.Rollback() - - var pgIso string - err = tx.QueryRow("show transaction_isolation").Scan(&pgIso) - if err != nil { - t.Errorf("%d. QueryRow failed: %v", i, err) + if err == nil { + t.Errorf("%d. BeginTx should have failed", i) + tx.Rollback() } - - if pgIso != tt.pgIso { - t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso) - } - }() - } - - unsupportedTests := []struct { - sqlIso sql.IsolationLevel - }{ - {sqlIso: sql.LevelWriteCommitted}, - {sqlIso: sql.LevelLinearizable}, - } - for i, tt := range unsupportedTests { - tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) - if err == nil { - t.Errorf("%d. BeginTx should have failed", i) - tx.Rollback() } - } - - ensureConnValid(t, db) + }) } func TestConnBeginTxReadOnly(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) - if err != nil { - t.Fatalf("BeginTx failed: %v", err) - } - defer tx.Rollback() - - var pgReadOnly string - err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly) - if err != nil { - t.Errorf("QueryRow failed: %v", err) - } + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) + require.NoError(t, err) + defer tx.Rollback() - if pgReadOnly != "on" { - t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on") - } + var pgReadOnly string + err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly) + if err != nil { + t.Errorf("QueryRow failed: %v", err) + } - ensureConnValid(t, db) + if pgReadOnly != "on" { + t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on") + } + }) } func TestBeginTxContextCancel(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + _, err := db.Exec("drop table if exists t") + require.NoError(t, err) - _, err := db.Exec("drop table if exists t") - if err != nil { - t.Fatalf("db.Exec failed: %v", err) - } + ctx, cancelFn := context.WithCancel(context.Background()) - ctx, cancelFn := context.WithCancel(context.Background()) + tx, err := db.BeginTx(ctx, nil) + require.NoError(t, err) - tx, err := db.BeginTx(ctx, nil) - if err != nil { - t.Fatalf("BeginTx failed: %v", err) - } + _, err = tx.Exec("create table t(id serial)") + require.NoError(t, err) - _, err = tx.Exec("create table t(id serial)") - if err != nil { - t.Fatalf("tx.Exec failed: %v", err) - } + cancelFn() - cancelFn() - - err = tx.Commit() - if err != context.Canceled && err != sql.ErrTxDone { - t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone) - } - - var n int - err = db.QueryRow("select count(*) from t").Scan(&n) - if pgErr, ok := err.(pgx.PgError); !ok || pgErr.Code != "42P01" { - t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err) - } + err = tx.Commit() + if err != context.Canceled && err != sql.ErrTxDone { + t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone) + } - ensureConnValid(t, db) + var n int + err = db.QueryRow("select count(*) from t").Scan(&n) + if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "42P01" { + t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err) + } + }) } -func acceptStandardPgxConn(backend *pgproto3.Backend) error { - script := pgmock.Script{ - Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), - } - - err := script.Run(backend) - if err != nil { - return err - } - - typeScript := pgmock.Script{ - Steps: pgmock.PgxInitSteps(), - } - - return typeScript.Run(backend) +func TestConnRaw(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + conn, err := db.Conn(context.Background()) + require.NoError(t, err) + + var n int + err = conn.Raw(func(driverConn any) error { + conn := driverConn.(*stdlib.Conn).Conn() + return conn.QueryRow(context.Background(), "select 42").Scan(&n) + }) + require.NoError(t, err) + assert.EqualValues(t, 42, n) + }) } -func TestBeginTxContextCancelWithDeadConn(t *testing.T) { - script := &pgmock.Script{ - Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), - } - script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) - script.Steps = append(script.Steps, - pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}), - pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: "BEGIN"}), - pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'T'}), - ) - - server, err := pgmock.NewServer(script) - if err != nil { - t.Fatal(err) - } - - errChan := make(chan error) - go func() { - errChan <- server.ServeOne() - }() - - db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) - if err != nil { - t.Fatalf("sql.Open failed: %v", err) - } - defer closeDB(t, db) - - ctx, cancelFn := context.WithCancel(context.Background()) - - tx, err := db.BeginTx(ctx, nil) - if err != nil { - t.Fatalf("BeginTx failed: %v", err) - } - - cancelFn() - - err = tx.Commit() - if err != context.Canceled && err != sql.ErrTxDone { - t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone) - } - - if err := <-errChan; err != nil { - t.Fatalf("mock server err: %v", err) - } +func TestConnPingContextSuccess(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + err := db.PingContext(context.Background()) + require.NoError(t, err) + }) } -func TestAcquireConn(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) +func TestConnPrepareContextSuccess(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + stmt, err := db.PrepareContext(context.Background(), "select now()") + require.NoError(t, err) + err = stmt.Close() + require.NoError(t, err) + }) +} - var conns []*pgx.Conn +// https://github.com/jackc/pgx/issues/1753#issuecomment-1746033281 +// https://github.com/jackc/pgx/issues/1754#issuecomment-1752004634 +func TestConnMultiplePrepareAndDeallocate(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + skipCockroachDB(t, db, "Server does not support pg_prepared_statements") + + sql := "select 42" + stmt1, err := db.PrepareContext(context.Background(), sql) + require.NoError(t, err) + stmt2, err := db.PrepareContext(context.Background(), sql) + require.NoError(t, err) + err = stmt1.Close() + require.NoError(t, err) + + var preparedStmtCount int64 + err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount) + require.NoError(t, err) + require.EqualValues(t, 1, preparedStmtCount) + + err = stmt2.Close() // err isn't as useful as it should be as database/sql will ignore errors from Deallocate. + require.NoError(t, err) + + err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount) + require.NoError(t, err) + require.EqualValues(t, 0, preparedStmtCount) + }) +} - for i := 1; i < 6; i++ { - conn, err := stdlib.AcquireConn(db) - if err != nil { - t.Errorf("%d. AcquireConn failed: %v", i, err) - continue - } +func TestConnExecContextSuccess(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)") + require.NoError(t, err) + }) +} - var n int32 - err = conn.QueryRow("select 1").Scan(&n) - if err != nil { - t.Errorf("%d. QueryRow failed: %v", i, err) - } - if n != 1 { - t.Errorf("%d. n => %d, want %d", i, n, 1) +func TestConnQueryContextSuccess(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n") + require.NoError(t, err) + + for rows.Next() { + var n int64 + err := rows.Scan(&n) + require.NoError(t, err) } + require.NoError(t, rows.Err()) + }) +} - stats := db.Stats() - if stats.OpenConnections != i { - t.Errorf("%d. stats.OpenConnections => %d, want %d", i, stats.OpenConnections, i) - } +func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + rows, err := db.Query("select 42::bigint") + require.NoError(t, err) - conns = append(conns, conn) - } + columnTypes, err := rows.ColumnTypes() + require.NoError(t, err) + require.Len(t, columnTypes, 1) - for i, conn := range conns { - if err := stdlib.ReleaseConn(db, conn); err != nil { - t.Errorf("%d. stdlib.ReleaseConn failed: %v", i, err) + if columnTypes[0].DatabaseTypeName() != "INT8" { + t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT8") } - } - ensureConnValid(t, db) + err = rows.Close() + require.NoError(t, err) + }) } -func TestConnPingContextSuccess(t *testing.T) { +func TestStmtExecContextSuccess(t *testing.T) { db := openDB(t) defer closeDB(t, db) - if err := db.PingContext(context.Background()); err != nil { - t.Fatalf("db.PingContext failed: %v", err) - } - - ensureConnValid(t, db) -} - -func TestConnPingContextCancel(t *testing.T) { - script := &pgmock.Script{ - Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), - } - script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) - script.Steps = append(script.Steps, - pgmock.ExpectMessage(&pgproto3.Query{String: ";"}), - pgmock.WaitForClose(), - ) - - server, err := pgmock.NewServer(script) - if err != nil { - t.Fatal(err) - } - defer server.Close() - - errChan := make(chan error, 1) - go func() { - errChan <- server.ServeOne() - }() - - db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) - if err != nil { - t.Fatalf("sql.Open failed: %v", err) - } - defer closeDB(t, db) + _, err := db.Exec("create temporary table t(id int primary key)") + require.NoError(t, err) - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + stmt, err := db.Prepare("insert into t(id) values ($1::int4)") + require.NoError(t, err) + defer stmt.Close() - err = db.PingContext(ctx) - if err != context.DeadlineExceeded { - t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) - } + _, err = stmt.ExecContext(context.Background(), 42) + require.NoError(t, err) - if err := <-errChan; err != nil { - t.Errorf("mock server err: %v", err) - } + ensureDBValid(t, db) } -func TestConnPrepareContextSuccess(t *testing.T) { +func TestStmtExecContextCancel(t *testing.T) { db := openDB(t) defer closeDB(t, db) - stmt, err := db.PrepareContext(context.Background(), "select now()") - if err != nil { - t.Fatalf("db.PrepareContext failed: %v", err) - } - stmt.Close() - - ensureConnValid(t, db) -} - -func TestConnPrepareContextCancel(t *testing.T) { - script := &pgmock.Script{ - Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), - } - script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) - script.Steps = append(script.Steps, - pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select now()"}), - pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}), - pgmock.ExpectMessage(&pgproto3.Sync{}), - pgmock.WaitForClose(), - ) - - server, err := pgmock.NewServer(script) - if err != nil { - t.Fatal(err) - } - defer server.Close() - - errChan := make(chan error) - go func() { - errChan <- server.ServeOne() - }() + _, err := db.Exec("create temporary table t(id int primary key)") + require.NoError(t, err) - db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) - if err != nil { - t.Fatalf("sql.Open failed: %v", err) - } - defer closeDB(t, db) + stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)") + require.NoError(t, err) + defer stmt.Close() - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() - _, err = db.PrepareContext(ctx, "select now()") - if err != context.DeadlineExceeded { - t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + _, err = stmt.ExecContext(ctx, 42) + if !pgconn.Timeout(err) { + t.Errorf("expected timeout error, got %v", err) } - if err := <-errChan; err != nil { - t.Errorf("mock server err: %v", err) - } + ensureDBValid(t, db) } -func TestConnExecContextSuccess(t *testing.T) { +func TestStmtQueryContextSuccess(t *testing.T) { db := openDB(t) defer closeDB(t, db) - _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)") - if err != nil { - t.Fatalf("db.ExecContext failed: %v", err) - } - - ensureConnValid(t, db) -} - -func TestConnExecContextCancel(t *testing.T) { - script := &pgmock.Script{ - Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), - } - script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) - script.Steps = append(script.Steps, - pgmock.ExpectMessage(&pgproto3.Query{String: "create temporary table exec_context_test(id serial primary key)"}), - pgmock.WaitForClose(), - ) - - server, err := pgmock.NewServer(script) - if err != nil { - t.Fatal(err) - } - defer server.Close() + skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") - errChan := make(chan error) - go func() { - errChan <- server.ServeOne() - }() - - db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) - if err != nil { - t.Fatalf("sql.Open failed: %v", err) - } - defer closeDB(t, db) - - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) - - _, err = db.ExecContext(ctx, "create temporary table exec_context_test(id serial primary key)") - if err != context.DeadlineExceeded { - t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) - } - - if err := <-errChan; err != nil { - t.Errorf("mock server err: %v", err) - } -} - -func TestConnQueryContextSuccess(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n") + require.NoError(t, err) + defer stmt.Close() - rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n") - if err != nil { - t.Fatalf("db.QueryContext failed: %v", err) - } + rows, err := stmt.QueryContext(context.Background(), 5) + require.NoError(t, err) for rows.Next() { var n int64 @@ -1006,463 +922,472 @@ func TestConnQueryContextSuccess(t *testing.T) { t.Error(rows.Err()) } - ensureConnValid(t, db) + ensureDBValid(t, db) } -func TestConnQueryContextCancel(t *testing.T) { - script := &pgmock.Script{ - Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), - } - script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) - script.Steps = append(script.Steps, - pgmock.ExpectMessage(&pgproto3.Parse{Query: "select * from generate_series(1,10) n"}), - pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S'}), - pgmock.ExpectMessage(&pgproto3.Sync{}), - - pgmock.SendMessage(&pgproto3.ParseComplete{}), - pgmock.SendMessage(&pgproto3.ParameterDescription{}), - pgmock.SendMessage(&pgproto3.RowDescription{ - Fields: []pgproto3.FieldDescription{ - { - Name: "n", - DataTypeOID: 23, - DataTypeSize: 4, - TypeModifier: 4294967295, +func TestRowsColumnTypes(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + columnTypesTests := []struct { + Name string + TypeName string + Length struct { + Len int64 + OK bool + } + DecimalSize struct { + Precision int64 + Scale int64 + OK bool + } + ScanType reflect.Type + }{ + { + Name: "a", + TypeName: "INT8", + Length: struct { + Len int64 + OK bool + }{ + Len: 0, + OK: false, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 0, + Scale: 0, + OK: false, + }, + ScanType: reflect.TypeOf(int64(0)), + }, { + Name: "bar", + TypeName: "TEXT", + Length: struct { + Len int64 + OK bool + }{ + Len: math.MaxInt64, + OK: true, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 0, + Scale: 0, + OK: false, + }, + ScanType: reflect.TypeOf(""), + }, { + Name: "dec", + TypeName: "NUMERIC", + Length: struct { + Len int64 + OK bool + }{ + Len: 0, + OK: false, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 9, + Scale: 2, + OK: true, }, + ScanType: reflect.TypeOf(float64(0)), + }, { + Name: "d", + TypeName: "1266", + Length: struct { + Len int64 + OK bool + }{ + Len: 0, + OK: false, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 0, + Scale: 0, + OK: false, + }, + ScanType: reflect.TypeOf(""), }, - }), - pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + } - pgmock.ExpectMessage(&pgproto3.Bind{ResultFormatCodes: []int16{1}}), - pgmock.ExpectMessage(&pgproto3.Execute{}), - pgmock.ExpectMessage(&pgproto3.Sync{}), + rows, err := db.Query("SELECT 1::bigint AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec, '12:00:00'::timetz as d") + require.NoError(t, err) - pgmock.SendMessage(&pgproto3.BindComplete{}), - pgmock.WaitForClose(), - ) + columns, err := rows.ColumnTypes() + require.NoError(t, err) + assert.Len(t, columns, 4) - server, err := pgmock.NewServer(script) - if err != nil { - t.Fatal(err) - } - defer server.Close() - - errChan := make(chan error) - go func() { - errChan <- server.ServeOne() - }() + for i, tt := range columnTypesTests { + c := columns[i] + if c.Name() != tt.Name { + t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name) + } + if c.DatabaseTypeName() != tt.TypeName { + t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName) + } + l, ok := c.Length() + if l != tt.Length.Len { + t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len) + } + if ok != tt.Length.OK { + t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK) + } + p, s, ok := c.DecimalSize() + if p != tt.DecimalSize.Precision { + t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision) + } + if s != tt.DecimalSize.Scale { + t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale) + } + if ok != tt.DecimalSize.OK { + t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK) + } + if c.ScanType() != tt.ScanType { + t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType) + } + } + }) +} - db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) - if err != nil { - t.Fatalf("sql.Open failed: %v", err) - } - defer db.Close() +func TestQueryLifeCycle(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") - ctx, cancelFn := context.WithCancel(context.Background()) + rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3) + require.NoError(t, err) - rows, err := db.QueryContext(ctx, "select * from generate_series(1,10) n") - if err != nil { - t.Fatalf("db.QueryContext failed: %v", err) - } + rowCount := int64(0) - cancelFn() + for rows.Next() { + rowCount++ + var ( + s string + n int64 + ) - for rows.Next() { - t.Fatalf("no rows should ever be received") - } + err := rows.Scan(&s, &n) + require.NoError(t, err) - if rows.Err() != context.Canceled { - t.Errorf("rows.Err() => %v, want %v", rows.Err(), context.Canceled) - } + if s != "foo" { + t.Errorf(`Expected "foo", received "%v"`, s) + } - if err := <-errChan; err != nil { - t.Errorf("mock server err: %v", err) - } -} + if n != rowCount { + t.Errorf("Expected %d, received %d", rowCount, n) + } + } + require.NoError(t, rows.Err()) -func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + err = rows.Close() + require.NoError(t, err) - rows, err := db.Query("select * from generate_series(1,10) n") - if err != nil { - t.Fatalf("db.Query failed: %v", err) - } + rows, err = db.Query("select 1 where false") + require.NoError(t, err) - columnTypes, err := rows.ColumnTypes() - if err != nil { - t.Fatalf("rows.ColumnTypes failed: %v", err) - } + rowCount = int64(0) - if len(columnTypes) != 1 { - t.Fatalf("len(columnTypes) => %v, want %v", len(columnTypes), 1) - } + for rows.Next() { + rowCount++ + } + require.NoError(t, rows.Err()) + require.EqualValues(t, 0, rowCount) - if columnTypes[0].DatabaseTypeName() != "INT4" { - t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT4") - } + err = rows.Close() + require.NoError(t, err) + }) +} - rows.Close() +// https://github.com/jackc/pgx/issues/409 +func TestScanJSONIntoJSONRawMessage(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + var msg json.RawMessage - ensureConnValid(t, db) + err := db.QueryRow("select '{}'::json").Scan(&msg) + require.NoError(t, err) + require.EqualValues(t, []byte("{}"), []byte(msg)) + }) } -func TestStmtExecContextSuccess(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) +type testLog struct { + lvl tracelog.LogLevel + msg string + data map[string]any +} - _, err := db.Exec("create temporary table t(id int primary key)") - if err != nil { - t.Fatalf("db.Exec failed: %v", err) - } +type testLogger struct { + logs []testLog +} - stmt, err := db.Prepare("insert into t(id) values ($1::int4)") - if err != nil { - t.Fatal(err) - } - defer stmt.Close() +func (l *testLogger) Log(ctx context.Context, lvl tracelog.LogLevel, msg string, data map[string]any) { + l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data}) +} - _, err = stmt.ExecContext(context.Background(), 42) - if err != nil { - t.Fatal(err) - } +func TestRegisterConnConfig(t *testing.T) { + connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) - ensureConnValid(t, db) -} + logger := &testLogger{} + connConfig.Tracer = &tracelog.TraceLog{Logger: logger, LogLevel: tracelog.LogLevelInfo} -func TestStmtExecContextCancel(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + // Issue 947: Register and unregister a ConnConfig and ensure that the + // returned connection string is not reused. + connStr := stdlib.RegisterConnConfig(connConfig) + require.Equal(t, "registeredConnConfig0", connStr) + stdlib.UnregisterConnConfig(connStr) - _, err := db.Exec("create temporary table t(id int primary key)") - if err != nil { - t.Fatalf("db.Exec failed: %v", err) - } + connStr = stdlib.RegisterConnConfig(connConfig) + defer stdlib.UnregisterConnConfig(connStr) + require.Equal(t, "registeredConnConfig1", connStr) - stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)") - if err != nil { - t.Fatal(err) - } - defer stmt.Close() + db, err := sql.Open("pgx", connStr) + require.NoError(t, err) + defer closeDB(t, db) - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + var n int64 + err = db.QueryRow("select 1").Scan(&n) + require.NoError(t, err) - _, err = stmt.ExecContext(ctx, 42) - if err != context.DeadlineExceeded { - t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) - } + l := logger.logs[len(logger.logs)-1] + assert.Equal(t, "Query", l.msg) + assert.Equal(t, "select 1", l.data["sql"]) +} - ensureConnValid(t, db) +// https://github.com/jackc/pgx/issues/958 +func TestConnQueryRowConstraintErrors(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + skipPostgreSQLVersionLessThan(t, db, 11) + skipCockroachDB(t, db, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + + _, err := db.Exec(`create temporary table defer_test ( + id text primary key, + n int not null, unique (n), + unique (n) deferrable initially deferred )`) + require.NoError(t, err) + + _, err = db.Exec(`drop function if exists test_trigger cascade`) + require.NoError(t, err) + + _, err = db.Exec(`create function test_trigger() returns trigger language plpgsql as $$ + begin + if new.n = 4 then + raise exception 'n cant be 4!'; + end if; + return new; + end$$`) + require.NoError(t, err) + + _, err = db.Exec(`create constraint trigger test + after insert or update on defer_test + deferrable initially deferred + for each row + execute function test_trigger()`) + require.NoError(t, err) + + _, err = db.Exec(`insert into defer_test (id, n) values ('a', 1), ('b', 2), ('c', 3)`) + require.NoError(t, err) + + var id string + err = db.QueryRow(`insert into defer_test (id, n) values ('e', 4) returning id`).Scan(&id) + assert.Error(t, err) + }) } -func TestStmtQueryContextSuccess(t *testing.T) { - db := openDB(t) +func TestOptionBeforeAfterConnect(t *testing.T) { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + var beforeConnConfigs []*pgx.ConnConfig + var afterConns []*pgx.Conn + db := stdlib.OpenDB(*config, + stdlib.OptionBeforeConnect(func(ctx context.Context, connConfig *pgx.ConnConfig) error { + beforeConnConfigs = append(beforeConnConfigs, connConfig) + return nil + }), + stdlib.OptionAfterConnect(func(ctx context.Context, conn *pgx.Conn) error { + afterConns = append(afterConns, conn) + return nil + })) defer closeDB(t, db) - stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n") - if err != nil { - t.Fatal(err) - } - defer stmt.Close() + // Force it to close and reopen a new connection after each query + db.SetMaxIdleConns(0) - rows, err := stmt.QueryContext(context.Background(), 5) - if err != nil { - t.Fatalf("stmt.QueryContext failed: %v", err) - } + _, err = db.Exec("select 1") + require.NoError(t, err) - for rows.Next() { - var n int64 - if err := rows.Scan(&n); err != nil { - t.Error(err) - } - } + _, err = db.Exec("select 1") + require.NoError(t, err) - if rows.Err() != nil { - t.Error(rows.Err()) - } + require.Len(t, beforeConnConfigs, 2) + require.Len(t, afterConns, 2) - ensureConnValid(t, db) + // Note: BeforeConnect creates a shallow copy, so the config contents will be the same but we wean to ensure they + // are different objects, so can't use require.NotEqual + require.False(t, config == beforeConnConfigs[0]) + require.False(t, beforeConnConfigs[0] == beforeConnConfigs[1]) } -func TestStmtQueryContextCancel(t *testing.T) { - script := &pgmock.Script{ - Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), - } - script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) - script.Steps = append(script.Steps, - pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select * from generate_series(1, $1::int4) n"}), - pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}), - pgmock.ExpectMessage(&pgproto3.Sync{}), - - pgmock.SendMessage(&pgproto3.ParseComplete{}), - pgmock.SendMessage(&pgproto3.ParameterDescription{ParameterOIDs: []uint32{23}}), - pgmock.SendMessage(&pgproto3.RowDescription{ - Fields: []pgproto3.FieldDescription{ - { - Name: "n", - DataTypeOID: 23, - DataTypeSize: 4, - TypeModifier: 4294967295, - }, - }, - }), - pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), - - pgmock.ExpectMessage(&pgproto3.Bind{PreparedStatement: "pgx_0", ParameterFormatCodes: []int16{1}, Parameters: [][]uint8{{0x0, 0x0, 0x0, 0x2a}}, ResultFormatCodes: []int16{1}}), - pgmock.ExpectMessage(&pgproto3.Execute{}), - pgmock.ExpectMessage(&pgproto3.Sync{}), - - pgmock.SendMessage(&pgproto3.BindComplete{}), - pgmock.WaitForClose(), - ) +func TestRandomizeHostOrderFunc(t *testing.T) { + config, err := pgx.ParseConfig("postgres://host1,host2,host3") + require.NoError(t, err) - server, err := pgmock.NewServer(script) - if err != nil { - t.Fatal(err) + // Test that at some point we connect to all 3 hosts + hostsNotSeenYet := map[string]struct{}{ + "host1": {}, + "host2": {}, + "host3": {}, } - defer server.Close() - errChan := make(chan error) - go func() { - errChan <- server.ServeOne() - }() + // If we don't succeed within this many iterations, something is certainly wrong + for i := 0; i < 100000; i++ { + connCopy := *config + stdlib.RandomizeHostOrderFunc(context.Background(), &connCopy) - db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) - if err != nil { - t.Fatalf("sql.Open failed: %v", err) - } - // defer closeDB(t, db) // mock DB doesn't close correctly yet - - stmt, err := db.Prepare("select * from generate_series(1, $1::int4) n") - if err != nil { - t.Fatal(err) - } - // defer stmt.Close() - - ctx, cancelFn := context.WithCancel(context.Background()) + delete(hostsNotSeenYet, connCopy.Host) + if len(hostsNotSeenYet) == 0 { + return + } - rows, err := stmt.QueryContext(ctx, 42) - if err != nil { - t.Fatalf("stmt.QueryContext failed: %v", err) + hostCheckLoop: + for _, h := range []string{"host1", "host2", "host3"} { + if connCopy.Host == h { + continue + } + for _, f := range connCopy.Fallbacks { + if f.Host == h { + continue hostCheckLoop + } + } + require.Failf(t, "got configuration from RandomizeHostOrderFunc that did not have all the hosts", "%+v", connCopy) + } } - cancelFn() + require.Fail(t, "did not get all hosts as primaries after many randomizations") +} - for rows.Next() { - t.Fatalf("no rows should ever be received") - } +func TestResetSessionHookCalled(t *testing.T) { + var mockCalled bool - if rows.Err() != context.Canceled { - t.Errorf("rows.Err() => %v, want %v", rows.Err(), context.Canceled) - } + connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) - if err := <-errChan; err != nil { - t.Errorf("mock server err: %v", err) - } -} + db := stdlib.OpenDB(*connConfig, stdlib.OptionResetSession(func(ctx context.Context, conn *pgx.Conn) error { + mockCalled = true -func TestRowsColumnTypes(t *testing.T) { - columnTypesTests := []struct { - Name string - TypeName string - Length struct { - Len int64 - OK bool - } - DecimalSize struct { - Precision int64 - Scale int64 - OK bool - } - ScanType reflect.Type - }{ - { - Name: "a", - TypeName: "INT4", - Length: struct { - Len int64 - OK bool - }{ - Len: 0, - OK: false, - }, - DecimalSize: struct { - Precision int64 - Scale int64 - OK bool - }{ - Precision: 0, - Scale: 0, - OK: false, - }, - ScanType: reflect.TypeOf(int32(0)), - }, { - Name: "bar", - TypeName: "TEXT", - Length: struct { - Len int64 - OK bool - }{ - Len: math.MaxInt64, - OK: true, - }, - DecimalSize: struct { - Precision int64 - Scale int64 - OK bool - }{ - Precision: 0, - Scale: 0, - OK: false, - }, - ScanType: reflect.TypeOf(""), - }, { - Name: "dec", - TypeName: "NUMERIC", - Length: struct { - Len int64 - OK bool - }{ - Len: 0, - OK: false, - }, - DecimalSize: struct { - Precision int64 - Scale int64 - OK bool - }{ - Precision: 9, - Scale: 2, - OK: true, - }, - ScanType: reflect.TypeOf(float64(0)), - }, - } + return nil + })) - db := openDB(t) defer closeDB(t, db) - rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec") - if err != nil { - t.Fatal(err) - } + err = db.Ping() + require.NoError(t, err) - columns, err := rows.ColumnTypes() - if err != nil { - t.Fatal(err) - } - if len(columns) != 3 { - t.Errorf("expected 3 columns found %d", len(columns)) - } + err = db.Ping() + require.NoError(t, err) - for i, tt := range columnTypesTests { - c := columns[i] - if c.Name() != tt.Name { - t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name) - } - if c.DatabaseTypeName() != tt.TypeName { - t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName) - } - l, ok := c.Length() - if l != tt.Length.Len { - t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len) - } - if ok != tt.Length.OK { - t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK) - } - p, s, ok := c.DecimalSize() - if p != tt.DecimalSize.Precision { - t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision) - } - if s != tt.DecimalSize.Scale { - t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale) - } - if ok != tt.DecimalSize.OK { - t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK) - } - if c.ScanType() != tt.ScanType { - t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType) - } - } + require.True(t, mockCalled) } -func TestSimpleQueryLifeCycle(t *testing.T) { - driverConfig := stdlib.DriverConfig{ - ConnConfig: pgx.ConnConfig{PreferSimpleProtocol: true}, - } +func TestCheckIdleConn(t *testing.T) { + controllerConn, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeDB(t, controllerConn) - stdlib.RegisterDriverConfig(&driverConfig) - defer stdlib.UnregisterDriverConfig(&driverConfig) + skipCockroachDB(t, controllerConn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") - db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")) - if err != nil { - t.Fatalf("sql.Open failed: %v", err) - } + db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) defer closeDB(t, db) - rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3) - if err != nil { - t.Fatalf("stmt.Query unexpectedly failed: %v", err) + var conns []*sql.Conn + for i := 0; i < 3; i++ { + c, err := db.Conn(context.Background()) + require.NoError(t, err) + conns = append(conns, c) } - rowCount := int64(0) + require.EqualValues(t, 3, db.Stats().OpenConnections) - for rows.Next() { - rowCount++ - var ( - s string - n int64 - ) + var pids []uint32 + for _, c := range conns { + err := c.Raw(func(driverConn any) error { + pids = append(pids, driverConn.(*stdlib.Conn).Conn().PgConn().PID()) + return nil + }) + require.NoError(t, err) + err = c.Close() + require.NoError(t, err) + } - if err := rows.Scan(&s, &n); err != nil { - t.Fatalf("rows.Scan unexpectedly failed: %v", err) - } + // The database/sql connection pool seems to automatically close idle connections to only keep 2 alive. + // require.EqualValues(t, 3, db.Stats().OpenConnections) - if s != "foo" { - t.Errorf(`Expected "foo", received "%v"`, s) - } + _, err = controllerConn.ExecContext(context.Background(), `select pg_terminate_backend(n) from unnest($1::int[]) n`, pids) + require.NoError(t, err) - if n != rowCount { - t.Errorf("Expected %d, received %d", rowCount, n) - } - } + // All conns are dead they don't know it and neither does the pool. But because of database/sql automatically closing + // idle connections we can't be sure how many we should have. require.EqualValues(t, 3, db.Stats().OpenConnections) - if err = rows.Err(); err != nil { - t.Fatalf("rows.Err unexpectedly is: %v", err) - } + // Wait long enough so the pool will realize it needs to check the connections. + time.Sleep(time.Second) - if rowCount != 10 { - t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount) - } + // Pool should try all existing connections and find them dead, then create a new connection which should successfully ping. + err = db.PingContext(context.Background()) + require.NoError(t, err) - err = rows.Close() - if err != nil { - t.Fatalf("rows.Close unexpectedly failed: %v", err) - } + // The original 3 conns should have been terminated and the a new conn established for the ping. + require.EqualValues(t, 1, db.Stats().OpenConnections) + c, err := db.Conn(context.Background()) + require.NoError(t, err) - rows, err = db.Query("select 1 where false") - if err != nil { - t.Fatalf("stmt.Query unexpectedly failed: %v", err) - } + var cPID uint32 + err = c.Raw(func(driverConn any) error { + cPID = driverConn.(*stdlib.Conn).Conn().PgConn().PID() + return nil + }) + require.NoError(t, err) + err = c.Close() + require.NoError(t, err) - rowCount = int64(0) + require.NotContains(t, pids, cPID) +} - for rows.Next() { - rowCount++ - } +func TestOptionShouldPing_HookCalledOnReuse(t *testing.T) { + hookCalled := false - if err = rows.Err(); err != nil { - t.Fatalf("rows.Err unexpectedly is: %v", err) - } + db := openDB(t, + stdlib.OptionShouldPing(func(context.Context, stdlib.ShouldPingParams) bool { + hookCalled = true + // Return false to avoid relying on actual ping behavior. + return false + }), + ) + defer closeDB(t, db) - if rowCount != 0 { - t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount) - } + // Ensure reuse (so ResetSession runs) + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) - err = rows.Close() - if err != nil { - t.Fatalf("rows.Close unexpectedly failed: %v", err) - } + // Establish the connection + require.NoError(t, db.Ping()) + + // Reuse the connection -> should trigger ResetSession -> ShouldPing + _, err := db.Exec("select 1") + require.NoError(t, err) - ensureConnValid(t, db) + require.True(t, hookCalled, "hook should be called on reuse") } diff --git a/stress_test.go b/stress_test.go deleted file mode 100644 index 114bec816..000000000 --- a/stress_test.go +++ /dev/null @@ -1,359 +0,0 @@ -package pgx_test - -import ( - "context" - "fmt" - "math/rand" - "os" - "strconv" - "testing" - "time" - - "github.com/pkg/errors" - - "github.com/jackc/fake" - "github.com/jackc/pgx" -) - -type execer interface { - Exec(sql string, arguments ...interface{}) (commandTag pgx.CommandTag, err error) -} -type queryer interface { - Query(sql string, args ...interface{}) (*pgx.Rows, error) -} -type queryRower interface { - QueryRow(sql string, args ...interface{}) *pgx.Row -} - -func TestStressConnPool(t *testing.T) { - t.Parallel() - - maxConnections := 8 - pool := createConnPool(t, maxConnections) - defer pool.Close() - - setupStressDB(t, pool) - - actions := []struct { - name string - fn func(*pgx.ConnPool, int) error - }{ - {"insertUnprepared", func(p *pgx.ConnPool, n int) error { return insertUnprepared(p, n) }}, - {"queryRowWithoutParams", func(p *pgx.ConnPool, n int) error { return queryRowWithoutParams(p, n) }}, - {"query", func(p *pgx.ConnPool, n int) error { return queryCloseEarly(p, n) }}, - {"queryCloseEarly", func(p *pgx.ConnPool, n int) error { return query(p, n) }}, - {"queryErrorWhileReturningRows", func(p *pgx.ConnPool, n int) error { return queryErrorWhileReturningRows(p, n) }}, - {"txInsertRollback", txInsertRollback}, - {"txInsertCommit", txInsertCommit}, - {"txMultipleQueries", txMultipleQueries}, - {"notify", notify}, - {"listenAndPoolUnlistens", listenAndPoolUnlistens}, - {"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }}, - {"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate}, - {"canceledQueryExContext", canceledQueryExContext}, - {"canceledExecExContext", canceledExecExContext}, - } - - actionCount := 1000 - if s := os.Getenv("STRESS_FACTOR"); s != "" { - stressFactor, err := strconv.ParseInt(s, 10, 64) - if err != nil { - t.Fatalf("failed to parse STRESS_FACTOR: %v", s) - } - actionCount *= int(stressFactor) - } - - workerCount := 16 - - workChan := make(chan int) - doneChan := make(chan struct{}) - errChan := make(chan error) - - work := func() { - for n := range workChan { - action := actions[rand.Intn(len(actions))] - err := action.fn(pool, n) - if err != nil { - errChan <- errors.Errorf("%s: %v", action.name, err) - break - } - } - doneChan <- struct{}{} - } - - for i := 0; i < workerCount; i++ { - go work() - } - - for i := 0; i < actionCount; i++ { - select { - case workChan <- i: - case err := <-errChan: - close(workChan) - t.Fatal(err) - } - } - close(workChan) - - for i := 0; i < workerCount; i++ { - <-doneChan - } -} - -func setupStressDB(t *testing.T, pool *pgx.ConnPool) { - _, err := pool.Exec(` - drop table if exists widgets; - create table widgets( - id serial primary key, - name varchar not null, - description text, - creation_time timestamptz - ); -`) - if err != nil { - t.Fatal(err) - } -} - -func insertUnprepared(e execer, actionNum int) error { - sql := ` - insert into widgets(name, description, creation_time) - values($1, $2, $3)` - - _, err := e.Exec(sql, fake.ProductName(), fake.Sentences(), time.Now()) - return err -} - -func queryRowWithoutParams(qr queryRower, actionNum int) error { - var id int32 - var name, description string - var creationTime time.Time - - sql := `select * from widgets order by random() limit 1` - - err := qr.QueryRow(sql).Scan(&id, &name, &description, &creationTime) - if err == pgx.ErrNoRows { - return nil - } - return err -} - -func query(q queryer, actionNum int) error { - sql := `select * from widgets order by random() limit $1` - - rows, err := q.Query(sql, 10) - if err != nil { - return err - } - defer rows.Close() - - for rows.Next() { - var id int32 - var name, description string - var creationTime time.Time - rows.Scan(&id, &name, &description, &creationTime) - } - - return rows.Err() -} - -func queryCloseEarly(q queryer, actionNum int) error { - sql := `select * from generate_series(1,$1)` - - rows, err := q.Query(sql, 100) - if err != nil { - return err - } - defer rows.Close() - - for i := 0; i < 10 && rows.Next(); i++ { - var n int32 - rows.Scan(&n) - } - rows.Close() - - return rows.Err() -} - -func queryErrorWhileReturningRows(q queryer, actionNum int) error { - // This query should divide by 0 within the first number of rows - sql := `select 42 / (random() * 20)::integer from generate_series(1,100000)` - - rows, err := q.Query(sql) - if err != nil { - return nil - } - defer rows.Close() - - for rows.Next() { - var n int32 - rows.Scan(&n) - } - - if _, ok := rows.Err().(pgx.PgError); ok { - return nil - } - return rows.Err() -} - -func notify(pool *pgx.ConnPool, actionNum int) error { - _, err := pool.Exec("notify stress") - return err -} - -func listenAndPoolUnlistens(pool *pgx.ConnPool, actionNum int) error { - conn, err := pool.Acquire() - if err != nil { - return err - } - defer pool.Release(conn) - - err = conn.Listen("stress") - if err != nil { - return err - } - - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) - _, err = conn.WaitForNotification(ctx) - if err == context.DeadlineExceeded { - return nil - } - return err -} - -func poolPrepareUseAndDeallocate(pool *pgx.ConnPool, actionNum int) error { - psName := fmt.Sprintf("poolPreparedStatement%d", actionNum) - - _, err := pool.Prepare(psName, "select $1::text") - if err != nil { - return err - } - - var s string - err = pool.QueryRow(psName, "hello").Scan(&s) - if err != nil { - return err - } - - if s != "hello" { - return errors.Errorf("Prepared statement did not return expected value: %v", s) - } - - return pool.Deallocate(psName) -} - -func txInsertRollback(pool *pgx.ConnPool, actionNum int) error { - tx, err := pool.Begin() - if err != nil { - return err - } - - sql := ` - insert into widgets(name, description, creation_time) - values($1, $2, $3)` - - _, err = tx.Exec(sql, fake.ProductName(), fake.Sentences(), time.Now()) - if err != nil { - return err - } - - return tx.Rollback() -} - -func txInsertCommit(pool *pgx.ConnPool, actionNum int) error { - tx, err := pool.Begin() - if err != nil { - return err - } - - sql := ` - insert into widgets(name, description, creation_time) - values($1, $2, $3)` - - _, err = tx.Exec(sql, fake.ProductName(), fake.Sentences(), time.Now()) - if err != nil { - tx.Rollback() - return err - } - - return tx.Commit() -} - -func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error { - tx, err := pool.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - errExpectedTxDeath := errors.New("Expected tx death") - - actions := []struct { - name string - fn func() error - }{ - {"insertUnprepared", func() error { return insertUnprepared(tx, actionNum) }}, - {"queryRowWithoutParams", func() error { return queryRowWithoutParams(tx, actionNum) }}, - {"query", func() error { return query(tx, actionNum) }}, - {"queryCloseEarly", func() error { return queryCloseEarly(tx, actionNum) }}, - {"queryErrorWhileReturningRows", func() error { - err := queryErrorWhileReturningRows(tx, actionNum) - if err != nil { - return err - } - return errExpectedTxDeath - }}, - } - - for i := 0; i < 20; i++ { - action := actions[rand.Intn(len(actions))] - err := action.fn() - if err == errExpectedTxDeath { - return nil - } else if err != nil { - return err - } - } - - return tx.Commit() -} - -func canceledQueryExContext(pool *pgx.ConnPool, actionNum int) error { - ctx, cancelFunc := context.WithCancel(context.Background()) - go func() { - time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) - cancelFunc() - }() - - rows, err := pool.QueryEx(ctx, "select pg_sleep(2)", nil) - if err == context.Canceled { - return nil - } else if err != nil { - return errors.Errorf("Only allowed error is context.Canceled, got %v", err) - } - - for rows.Next() { - return errors.New("should never receive row") - } - - if rows.Err() != context.Canceled { - return errors.Errorf("Expected context.Canceled error, got %v", rows.Err()) - } - - return nil -} - -func canceledExecExContext(pool *pgx.ConnPool, actionNum int) error { - ctx, cancelFunc := context.WithCancel(context.Background()) - go func() { - time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) - cancelFunc() - }() - - _, err := pool.ExecEx(ctx, "select pg_sleep(2)", nil) - if err != context.Canceled { - return errors.Errorf("Expected context.Canceled error, got %v", err) - } - - return nil -} diff --git a/testsetup/README.md b/testsetup/README.md new file mode 100644 index 000000000..4a1dbab91 --- /dev/null +++ b/testsetup/README.md @@ -0,0 +1,3 @@ +# Test Setup + +This directory contains miscellaneous files used to setup a test database. diff --git a/testsetup/generate_certs.go b/testsetup/generate_certs.go new file mode 100644 index 000000000..d465b6c52 --- /dev/null +++ b/testsetup/generate_certs.go @@ -0,0 +1,186 @@ +// Generates a CA, server certificate, and encrypted client certificate for testing pgx. + +package main + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "os" + "time" +) + +func main() { + // Create the CA + ca := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "pgx-root-ca", + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(20, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + caKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + panic(err) + } + + caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caKey.PublicKey, caKey) + if err != nil { + panic(err) + } + + err = writePrivateKey("ca.key", caKey) + if err != nil { + panic(err) + } + + err = writeCertificate("ca.pem", caBytes) + if err != nil { + panic(err) + } + + // Create a server certificate signed by the CA for localhost. + serverCert := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + CommonName: "localhost", + }, + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(20, 0, 0), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + + serverCertPrivKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + + serverBytes, err := x509.CreateCertificate(rand.Reader, serverCert, ca, &serverCertPrivKey.PublicKey, caKey) + if err != nil { + panic(err) + } + + err = writePrivateKey("localhost.key", serverCertPrivKey) + if err != nil { + panic(err) + } + + err = writeCertificate("localhost.crt", serverBytes) + if err != nil { + panic(err) + } + + // Create a client certificate signed by the CA and encrypted. + clientCert := &x509.Certificate{ + SerialNumber: big.NewInt(3), + Subject: pkix.Name{ + CommonName: "pgx_sslcert", + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(20, 0, 0), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + + clientCertPrivKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + + clientBytes, err := x509.CreateCertificate(rand.Reader, clientCert, ca, &clientCertPrivKey.PublicKey, caKey) + if err != nil { + panic(err) + } + + err = writeEncryptedPrivateKey("pgx_sslcert.key", clientCertPrivKey, "certpw") + if err != nil { + panic(err) + } + + err = writeCertificate("pgx_sslcert.crt", clientBytes) + if err != nil { + panic(err) + } +} + +func writePrivateKey(path string, privateKey *rsa.PrivateKey) error { + file, err := os.Create(path) + if err != nil { + return fmt.Errorf("writePrivateKey: %w", err) + } + + err = pem.Encode(file, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + if err != nil { + return fmt.Errorf("writePrivateKey: %w", err) + } + + err = file.Close() + if err != nil { + return fmt.Errorf("writePrivateKey: %w", err) + } + + return nil +} + +func writeEncryptedPrivateKey(path string, privateKey *rsa.PrivateKey, password string) error { + file, err := os.Create(path) + if err != nil { + return fmt.Errorf("writeEncryptedPrivateKey: %w", err) + } + + block, err := x509.EncryptPEMBlock(rand.Reader, "CERTIFICATE", x509.MarshalPKCS1PrivateKey(privateKey), []byte(password), x509.PEMCipher3DES) + if err != nil { + return fmt.Errorf("writeEncryptedPrivateKey: %w", err) + } + + err = pem.Encode(file, block) + if err != nil { + return fmt.Errorf("writeEncryptedPrivateKey: %w", err) + } + + err = file.Close() + if err != nil { + return fmt.Errorf("writeEncryptedPrivateKey: %w", err) + } + + return nil +} + +func writeCertificate(path string, certBytes []byte) error { + file, err := os.Create(path) + if err != nil { + return fmt.Errorf("writeCertificate: %w", err) + } + + err = pem.Encode(file, &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }) + if err != nil { + return fmt.Errorf("writeCertificate: %w", err) + } + + err = file.Close() + if err != nil { + return fmt.Errorf("writeCertificate: %w", err) + } + + return nil +} diff --git a/testsetup/pg_hba.conf b/testsetup/pg_hba.conf new file mode 100644 index 000000000..6609372ef --- /dev/null +++ b/testsetup/pg_hba.conf @@ -0,0 +1,7 @@ +local all postgres trust +local all all trust +host all pgx_md5 127.0.0.1/32 md5 +host all pgx_scram 127.0.0.1/32 scram-sha-256 +host all pgx_pw 127.0.0.1/32 password +hostssl all pgx_ssl 127.0.0.1/32 scram-sha-256 +hostssl all pgx_sslcert 127.0.0.1/32 cert diff --git a/testsetup/postgresql_setup.sql b/testsetup/postgresql_setup.sql new file mode 100644 index 000000000..837c978ac --- /dev/null +++ b/testsetup/postgresql_setup.sql @@ -0,0 +1,20 @@ +-- Create extensions and types. +create extension hstore; +create extension ltree; +create domain uint64 as numeric(20,0); + +-- Create users for different types of connections and authentication. +create user pgx_ssl with superuser PASSWORD 'secret'; +create user pgx_sslcert with superuser PASSWORD 'secret'; +set password_encryption = md5; +create user pgx_md5 with superuser PASSWORD 'secret'; +set password_encryption = 'scram-sha-256'; +create user pgx_pw with superuser PASSWORD 'secret'; +create user pgx_scram with superuser PASSWORD 'secret'; +\set whoami `whoami` +create user :whoami with superuser; -- unix domain socket user + + +-- The tricky test user, below, has to actually exist so that it can be used in a test +-- of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. +create user " tricky, ' } "" \\ test user " superuser password 'secret'; diff --git a/testsetup/postgresql_ssl.conf b/testsetup/postgresql_ssl.conf new file mode 100644 index 000000000..bf75f3c7b --- /dev/null +++ b/testsetup/postgresql_ssl.conf @@ -0,0 +1,4 @@ +ssl = on +ssl_cert_file = 'server.crt' +ssl_key_file = 'server.key' +ssl_ca_file = 'root.crt' diff --git a/tracelog/tracelog.go b/tracelog/tracelog.go new file mode 100644 index 000000000..b36fc99ca --- /dev/null +++ b/tracelog/tracelog.go @@ -0,0 +1,409 @@ +// Package tracelog provides a tracer that acts as a traditional logger. +package tracelog + +import ( + "context" + "encoding/hex" + "errors" + "fmt" + "sync" + "time" + "unicode/utf8" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +// LogLevel represents the pgx logging level. See LogLevel* constants for +// possible values. +type LogLevel int + +// The values for log levels are chosen such that the zero value means that no +// log level was specified. +const ( + LogLevelTrace = LogLevel(6) + LogLevelDebug = LogLevel(5) + LogLevelInfo = LogLevel(4) + LogLevelWarn = LogLevel(3) + LogLevelError = LogLevel(2) + LogLevelNone = LogLevel(1) +) + +func (ll LogLevel) String() string { + switch ll { + case LogLevelTrace: + return "trace" + case LogLevelDebug: + return "debug" + case LogLevelInfo: + return "info" + case LogLevelWarn: + return "warn" + case LogLevelError: + return "error" + case LogLevelNone: + return "none" + default: + return fmt.Sprintf("invalid level %d", ll) + } +} + +// Logger is the interface used to get log output from pgx. +type Logger interface { + // Log a message at the given level with data key/value pairs. data may be nil. + Log(ctx context.Context, level LogLevel, msg string, data map[string]any) +} + +// LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface +type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) + +// Log delegates the logging request to the wrapped function +func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) { + f(ctx, level, msg, data) +} + +// LogLevelFromString converts log level string to constant +// +// Valid levels: +// +// trace +// debug +// info +// warn +// error +// none +func LogLevelFromString(s string) (LogLevel, error) { + switch s { + case "trace": + return LogLevelTrace, nil + case "debug": + return LogLevelDebug, nil + case "info": + return LogLevelInfo, nil + case "warn": + return LogLevelWarn, nil + case "error": + return LogLevelError, nil + case "none": + return LogLevelNone, nil + default: + return 0, errors.New("invalid log level") + } +} + +func logQueryArgs(args []any) []any { + logArgs := make([]any, 0, len(args)) + + for _, a := range args { + switch v := a.(type) { + case []byte: + if len(v) < 64 { + a = hex.EncodeToString(v) + } else { + a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64) + } + case string: + if len(v) > 64 { + l := 0 + for w := 0; l < 64; l += w { + _, w = utf8.DecodeRuneInString(v[l:]) + } + if len(v) > l { + a = fmt.Sprintf("%s (truncated %d bytes)", v[:l], len(v)-l) + } + } + } + logArgs = append(logArgs, a) + } + + return logArgs +} + +// TraceLogConfig holds the configuration for key names +type TraceLogConfig struct { + TimeKey string +} + +// DefaultTraceLogConfig returns the default configuration for TraceLog +func DefaultTraceLogConfig() *TraceLogConfig { + return &TraceLogConfig{ + TimeKey: "time", + } +} + +// TraceLog implements pgx.QueryTracer, pgx.BatchTracer, pgx.ConnectTracer, pgx.CopyFromTracer, pgxpool.AcquireTracer, +// and pgxpool.ReleaseTracer. Logger and LogLevel are required. Config will be automatically initialized on the +// first use if nil. +type TraceLog struct { + Logger Logger + LogLevel LogLevel + + Config *TraceLogConfig + ensureConfigOnce sync.Once +} + +// ensureConfig initializes the Config field with default values if it is nil. +func (tl *TraceLog) ensureConfig() { + tl.ensureConfigOnce.Do( + func() { + if tl.Config == nil { + tl.Config = DefaultTraceLogConfig() + } + }, + ) +} + +type ctxKey int + +const ( + _ ctxKey = iota + tracelogQueryCtxKey + tracelogBatchCtxKey + tracelogCopyFromCtxKey + tracelogConnectCtxKey + tracelogPrepareCtxKey + tracelogAcquireCtxKey +) + +type traceQueryData struct { + startTime time.Time + sql string + args []any +} + +func (tl *TraceLog) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + return context.WithValue(ctx, tracelogQueryCtxKey, &traceQueryData{ + startTime: time.Now(), + sql: data.SQL, + args: data.Args, + }) +} + +func (tl *TraceLog) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + tl.ensureConfig() + queryData := ctx.Value(tracelogQueryCtxKey).(*traceQueryData) + + endTime := time.Now() + interval := endTime.Sub(queryData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "err": data.Err, tl.Config.TimeKey: interval}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), tl.Config.TimeKey: interval, "commandTag": data.CommandTag.String()}) + } +} + +type traceBatchData struct { + startTime time.Time +} + +func (tl *TraceLog) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + return context.WithValue(ctx, tracelogBatchCtxKey, &traceBatchData{ + startTime: time.Now(), + }) +} + +func (tl *TraceLog) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "BatchQuery", map[string]any{"sql": data.SQL, "args": logQueryArgs(data.Args), "err": data.Err}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "BatchQuery", map[string]any{"sql": data.SQL, "args": logQueryArgs(data.Args), "commandTag": data.CommandTag.String()}) + } +} + +func (tl *TraceLog) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + tl.ensureConfig() + queryData := ctx.Value(tracelogBatchCtxKey).(*traceBatchData) + + endTime := time.Now() + interval := endTime.Sub(queryData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "BatchClose", map[string]any{"err": data.Err, tl.Config.TimeKey: interval}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "BatchClose", map[string]any{tl.Config.TimeKey: interval}) + } +} + +type traceCopyFromData struct { + startTime time.Time + TableName pgx.Identifier + ColumnNames []string +} + +func (tl *TraceLog) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + return context.WithValue(ctx, tracelogCopyFromCtxKey, &traceCopyFromData{ + startTime: time.Now(), + TableName: data.TableName, + ColumnNames: data.ColumnNames, + }) +} + +func (tl *TraceLog) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + tl.ensureConfig() + copyFromData := ctx.Value(tracelogCopyFromCtxKey).(*traceCopyFromData) + + endTime := time.Now() + interval := endTime.Sub(copyFromData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, tl.Config.TimeKey: interval}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, tl.Config.TimeKey: interval, "rowCount": data.CommandTag.RowsAffected()}) + } +} + +type traceConnectData struct { + startTime time.Time + connConfig *pgx.ConnConfig +} + +func (tl *TraceLog) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context { + return context.WithValue(ctx, tracelogConnectCtxKey, &traceConnectData{ + startTime: time.Now(), + connConfig: data.ConnConfig, + }) +} + +func (tl *TraceLog) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) { + tl.ensureConfig() + connectData := ctx.Value(tracelogConnectCtxKey).(*traceConnectData) + + endTime := time.Now() + interval := endTime.Sub(connectData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.Logger.Log(ctx, LogLevelError, "Connect", map[string]any{ + "host": connectData.connConfig.Host, + "port": connectData.connConfig.Port, + "database": connectData.connConfig.Database, + tl.Config.TimeKey: interval, + "err": data.Err, + }) + } + return + } + + if data.Conn != nil { + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, data.Conn, LogLevelInfo, "Connect", map[string]any{ + "host": connectData.connConfig.Host, + "port": connectData.connConfig.Port, + "database": connectData.connConfig.Database, + tl.Config.TimeKey: interval, + }) + } + } +} + +type tracePrepareData struct { + startTime time.Time + name string + sql string +} + +func (tl *TraceLog) TracePrepareStart(ctx context.Context, _ *pgx.Conn, data pgx.TracePrepareStartData) context.Context { + return context.WithValue(ctx, tracelogPrepareCtxKey, &tracePrepareData{ + startTime: time.Now(), + name: data.Name, + sql: data.SQL, + }) +} + +func (tl *TraceLog) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + tl.ensureConfig() + prepareData := ctx.Value(tracelogPrepareCtxKey).(*tracePrepareData) + + endTime := time.Now() + interval := endTime.Sub(prepareData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "Prepare", map[string]any{"name": prepareData.name, "sql": prepareData.sql, "err": data.Err, tl.Config.TimeKey: interval}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "Prepare", map[string]any{"name": prepareData.name, "sql": prepareData.sql, tl.Config.TimeKey: interval, "alreadyPrepared": data.AlreadyPrepared}) + } +} + +type traceAcquireData struct { + startTime time.Time +} + +func (tl *TraceLog) TraceAcquireStart(ctx context.Context, _ *pgxpool.Pool, _ pgxpool.TraceAcquireStartData) context.Context { + return context.WithValue(ctx, tracelogAcquireCtxKey, &traceAcquireData{ + startTime: time.Now(), + }) +} + +func (tl *TraceLog) TraceAcquireEnd(ctx context.Context, _ *pgxpool.Pool, data pgxpool.TraceAcquireEndData) { + tl.ensureConfig() + acquireData := ctx.Value(tracelogAcquireCtxKey).(*traceAcquireData) + + endTime := time.Now() + interval := endTime.Sub(acquireData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.Logger.Log(ctx, LogLevelError, "Acquire", map[string]any{"err": data.Err, tl.Config.TimeKey: interval}) + } + return + } + + if data.Conn != nil { + if tl.shouldLog(LogLevelDebug) { + tl.log(ctx, data.Conn, LogLevelDebug, "Acquire", map[string]any{tl.Config.TimeKey: interval}) + } + } +} + +func (tl *TraceLog) TraceRelease(_ *pgxpool.Pool, data pgxpool.TraceReleaseData) { + if tl.shouldLog(LogLevelDebug) { + // there is no context on the TraceRelease callback + tl.log(context.Background(), data.Conn, LogLevelDebug, "Release", map[string]any{}) + } +} + +func (tl *TraceLog) shouldLog(lvl LogLevel) bool { + return tl.LogLevel >= lvl +} + +func (tl *TraceLog) log(ctx context.Context, conn *pgx.Conn, lvl LogLevel, msg string, data map[string]any) { + if data == nil { + data = map[string]any{} + } + + pgConn := conn.PgConn() + if pgConn != nil { + pid := pgConn.PID() + if pid != 0 { + data["pid"] = pid + } + } + + tl.Logger.Log(ctx, lvl, msg, data) +} diff --git a/tracelog/tracelog_test.go b/tracelog/tracelog_test.go new file mode 100644 index 000000000..7d28fab57 --- /dev/null +++ b/tracelog/tracelog_test.go @@ -0,0 +1,593 @@ +package tracelog_test + +import ( + "bytes" + "context" + "log" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/jackc/pgx/v5/tracelog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +var defaultConnTestRunner pgxtest.ConnTestRunner + +func init() { + defaultConnTestRunner = pgxtest.DefaultConnTestRunner() + defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + return config + } +} + +type testLog struct { + lvl tracelog.LogLevel + msg string + data map[string]any +} + +type testLogger struct { + logs []testLog + + mux sync.Mutex +} + +func (l *testLogger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) { + l.mux.Lock() + defer l.mux.Unlock() + + data["ctxdata"] = ctx.Value("ctxdata") + l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) +} + +func (l *testLogger) Clear() { + l.mux.Lock() + defer l.mux.Unlock() + + l.logs = l.logs[0:0] +} + +func (l *testLogger) FilterByMsg(msg string) (res []testLog) { + l.mux.Lock() + defer l.mux.Unlock() + + for _, log := range l.logs { + if log.msg == msg { + res = append(res, log) + } + } + + return res +} + +func TestContextGetsPassedToLogMethod(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.Clear() // Clear any logs written when establishing connection + + ctx = context.WithValue(ctx, "ctxdata", "foo") + _, err := conn.Exec(ctx, `;`) + require.NoError(t, err) + require.Len(t, logger.logs, 1) + require.Equal(t, "foo", logger.logs[0].data["ctxdata"]) + }) +} + +func TestLoggerFunc(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + const testMsg = "foo" + + buf := bytes.Buffer{} + logger := log.New(&buf, "", 0) + + createAdapterFn := func(logger *log.Logger) tracelog.LoggerFunc { + return func(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]interface{}) { + logger.Printf("%s", testMsg) + } + } + + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = &tracelog.TraceLog{ + Logger: createAdapterFn(logger), + LogLevel: tracelog.LogLevelTrace, + } + + conn, err := pgx.ConnectConfig(ctx, config) + require.NoError(t, err) + defer conn.Close(ctx) + + buf.Reset() // Clear logs written when establishing connection + + if _, err := conn.Exec(context.TODO(), ";"); err != nil { + t.Fatal(err) + } + + if strings.TrimSpace(buf.String()) != testMsg { + t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String()) + } +} + +func TestLogQuery(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.Clear() // Clear any logs written when establishing connection + + _, err := conn.Exec(ctx, `select $1::text`, "testing") + require.NoError(t, err) + + logs := logger.FilterByMsg("Query") + require.Len(t, logs, 1) + require.Equal(t, tracelog.LogLevelInfo, logs[0].lvl) + + logger.Clear() + + _, err = conn.Exec(ctx, `foo`, "testing") + require.Error(t, err) + + logs = logger.FilterByMsg("Query") + require.Len(t, logs, 1) + require.Equal(t, tracelog.LogLevelError, logs[0].lvl) + require.Equal(t, err, logs[0].data["err"]) + }) +} + +// https://github.com/jackc/pgx/issues/1365 +func TestLogQueryArgsHandlesUTF8(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.Clear() // Clear any logs written when establishing connection + + var s string + for i := 0; i < 63; i++ { + s += "0" + } + s += "😊" + + _, err := conn.Exec(ctx, `select $1::text`, s) + require.NoError(t, err) + + logs := logger.FilterByMsg("Query") + require.Len(t, logs, 1) + require.Equal(t, tracelog.LogLevelInfo, logs[0].lvl) + require.Equal(t, s, logs[0].data["args"].([]any)[0]) + + logger.Clear() + + _, err = conn.Exec(ctx, `select $1::text`, s+"000") + require.NoError(t, err) + + logs = logger.FilterByMsg("Query") + require.Len(t, logs, 1) + require.Equal(t, tracelog.LogLevelInfo, logs[0].lvl) + require.Equal(t, s+" (truncated 3 bytes)", logs[0].data["args"].([]any)[0]) + }) +} + +func TestLogCopyFrom(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, pgxtest.KnownOIDQueryExecModes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, `create temporary table foo(a int4)`) + require.NoError(t, err) + + logger.Clear() + + inputRows := [][]any{ + {int32(1)}, + {nil}, + } + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + + logs := logger.FilterByMsg("CopyFrom") + require.Len(t, logs, 1) + require.Equal(t, tracelog.LogLevelInfo, logs[0].lvl) + + logger.Clear() + + inputRows = [][]any{ + {"not an integer"}, + {nil}, + } + + copyCount, err = conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.Error(t, err) + require.EqualValues(t, 0, copyCount) + + logs = logger.FilterByMsg("CopyFrom") + require.Len(t, logs, 1) + require.Equal(t, tracelog.LogLevelError, logs[0].lvl) + }) +} + +func TestLogConnect(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + + conn1, err := pgx.ConnectConfig(ctx, config) + require.NoError(t, err) + defer conn1.Close(ctx) + require.Len(t, logger.logs, 1) + require.Equal(t, "Connect", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl) + + logger.Clear() + + config, err = pgx.ParseConfig("host=/invalid") + require.NoError(t, err) + config.Tracer = tracer + + conn2, err := pgx.ConnectConfig(ctx, config) + require.Nil(t, conn2) + require.Error(t, err) + require.Len(t, logger.logs, 1) + require.Equal(t, "Connect", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelError, logger.logs[0].lvl) +} + +func TestLogBatchStatementsOnExec(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.Clear() // Clear any logs written when establishing connection + + batch := &pgx.Batch{} + batch.Queue("create table foo (id bigint)") + batch.Queue("drop table foo") + + br := conn.SendBatch(ctx, batch) + + _, err := br.Exec() + require.NoError(t, err) + + _, err = br.Exec() + require.NoError(t, err) + + err = br.Close() + require.NoError(t, err) + + require.Len(t, logger.logs, 3) + assert.Equal(t, "BatchQuery", logger.logs[0].msg) + assert.Equal(t, "create table foo (id bigint)", logger.logs[0].data["sql"]) + assert.Equal(t, "BatchQuery", logger.logs[1].msg) + assert.Equal(t, "drop table foo", logger.logs[1].data["sql"]) + assert.Equal(t, "BatchClose", logger.logs[2].msg) + }) +} + +func TestLogBatchStatementsOnBatchResultClose(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.Clear() // Clear any logs written when establishing connection + + batch := &pgx.Batch{} + batch.Queue("select generate_series(1,$1)", 100) + batch.Queue("select 1 = 1;") + + br := conn.SendBatch(ctx, batch) + err := br.Close() + require.NoError(t, err) + + require.Len(t, logger.logs, 3) + assert.Equal(t, "BatchQuery", logger.logs[0].msg) + assert.Equal(t, "select generate_series(1,$1)", logger.logs[0].data["sql"]) + assert.Equal(t, "BatchQuery", logger.logs[1].msg) + assert.Equal(t, "select 1 = 1;", logger.logs[1].data["sql"]) + assert.Equal(t, "BatchClose", logger.logs[2].msg) + }) +} + +func TestLogAcquire(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + + poolConfig, err := pgxpool.ParseConfig(config.ConnString()) + require.NoError(t, err) + + poolConfig.ConnConfig = config + pool1, err := pgxpool.NewWithConfig(ctx, poolConfig) + require.NoError(t, err) + defer pool1.Close() + + conn1, err := pool1.Acquire(ctx) + require.NoError(t, err) + defer conn1.Release() + require.Len(t, logger.logs, 2) // Has both the Connect and Acquire logs + require.Equal(t, "Acquire", logger.logs[1].msg) + require.Equal(t, tracelog.LogLevelDebug, logger.logs[1].lvl) + + logger.Clear() + + // create a 2nd pool with a bad host to verify the error handling + poolConfig, err = pgxpool.ParseConfig("host=/invalid") + require.NoError(t, err) + poolConfig.ConnConfig.Tracer = tracer + + pool2, err := pgxpool.NewWithConfig(ctx, poolConfig) + require.NoError(t, err) + defer pool2.Close() + + conn2, err := pool2.Acquire(ctx) + require.Error(t, err) + require.Nil(t, conn2) + require.Len(t, logger.logs, 2) + require.Equal(t, "Acquire", logger.logs[1].msg) + require.Equal(t, tracelog.LogLevelError, logger.logs[1].lvl) +} + +func TestLogRelease(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + + poolConfig, err := pgxpool.ParseConfig(config.ConnString()) + require.NoError(t, err) + + poolConfig.ConnConfig = config + pool1, err := pgxpool.NewWithConfig(ctx, poolConfig) + require.NoError(t, err) + defer pool1.Close() + + conn1, err := pool1.Acquire(ctx) + require.NoError(t, err) + + logger.Clear() + conn1.Release() + require.Len(t, logger.logs, 1) + require.Equal(t, "Release", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelDebug, logger.logs[0].lvl) +} + +func TestLogPrepare(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + }, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.Clear() // Clear any logs written when establishing connection + + _, err := conn.Exec(ctx, `select $1::text`, "testing") + require.NoError(t, err) + + logs := logger.FilterByMsg("Prepare") + require.Len(t, logs, 1) + require.Equal(t, tracelog.LogLevelInfo, logs[0].lvl) + + logger.Clear() + + _, err = conn.Exec(ctx, `foo aaaa`, "testing") + require.Error(t, err) + + logs = logger.FilterByMsg("Prepare") + require.Len(t, logs, 1) + require.Equal(t, err, logs[0].data["err"]) + }) + + ctx, cancel = context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.Clear() // Clear any logs written when establishing connection + + _, err := conn.Prepare(ctx, "test_query_1", `select $1::int`) + require.NoError(t, err) + + require.Len(t, logger.logs, 1) + require.Equal(t, "Prepare", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl) + + logger.Clear() + + _, err = conn.Prepare(ctx, `test_query_2`, "foo aaaa") + require.Error(t, err) + + require.Len(t, logger.logs, 1) + require.Equal(t, "Prepare", logger.logs[0].msg) + require.Equal(t, err, logger.logs[0].data["err"]) + }) +} + +// https://github.com/jackc/pgx/pull/2120 +func TestConcurrentUsage(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.ConnConfig.Tracer = tracer + + for i := 0; i < 50; i++ { + func() { + pool, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + + defer pool.Close() + + eg := errgroup.Group{} + + for i := 0; i < 5; i++ { + eg.Go(func() error { + _, err := pool.Exec(ctx, `select 1`) + return err + }) + } + + err = eg.Wait() + require.NoError(t, err) + }() + } +} diff --git a/tracer.go b/tracer.go new file mode 100644 index 000000000..58ca99f7e --- /dev/null +++ b/tracer.go @@ -0,0 +1,107 @@ +package pgx + +import ( + "context" + + "github.com/jackc/pgx/v5/pgconn" +) + +// QueryTracer traces Query, QueryRow, and Exec. +type QueryTracer interface { + // TraceQueryStart is called at the beginning of Query, QueryRow, and Exec calls. The returned context is used for the + // rest of the call and will be passed to TraceQueryEnd. + TraceQueryStart(ctx context.Context, conn *Conn, data TraceQueryStartData) context.Context + + TraceQueryEnd(ctx context.Context, conn *Conn, data TraceQueryEndData) +} + +type TraceQueryStartData struct { + SQL string + Args []any +} + +type TraceQueryEndData struct { + CommandTag pgconn.CommandTag + Err error +} + +// BatchTracer traces SendBatch. +type BatchTracer interface { + // TraceBatchStart is called at the beginning of SendBatch calls. The returned context is used for the + // rest of the call and will be passed to TraceBatchQuery and TraceBatchEnd. + TraceBatchStart(ctx context.Context, conn *Conn, data TraceBatchStartData) context.Context + + TraceBatchQuery(ctx context.Context, conn *Conn, data TraceBatchQueryData) + TraceBatchEnd(ctx context.Context, conn *Conn, data TraceBatchEndData) +} + +type TraceBatchStartData struct { + Batch *Batch +} + +type TraceBatchQueryData struct { + SQL string + Args []any + CommandTag pgconn.CommandTag + Err error +} + +type TraceBatchEndData struct { + Err error +} + +// CopyFromTracer traces CopyFrom. +type CopyFromTracer interface { + // TraceCopyFromStart is called at the beginning of CopyFrom calls. The returned context is used for the + // rest of the call and will be passed to TraceCopyFromEnd. + TraceCopyFromStart(ctx context.Context, conn *Conn, data TraceCopyFromStartData) context.Context + + TraceCopyFromEnd(ctx context.Context, conn *Conn, data TraceCopyFromEndData) +} + +type TraceCopyFromStartData struct { + TableName Identifier + ColumnNames []string +} + +type TraceCopyFromEndData struct { + CommandTag pgconn.CommandTag + Err error +} + +// PrepareTracer traces Prepare. +type PrepareTracer interface { + // TracePrepareStart is called at the beginning of Prepare calls. The returned context is used for the + // rest of the call and will be passed to TracePrepareEnd. + TracePrepareStart(ctx context.Context, conn *Conn, data TracePrepareStartData) context.Context + + TracePrepareEnd(ctx context.Context, conn *Conn, data TracePrepareEndData) +} + +type TracePrepareStartData struct { + Name string + SQL string +} + +type TracePrepareEndData struct { + AlreadyPrepared bool + Err error +} + +// ConnectTracer traces Connect and ConnectConfig. +type ConnectTracer interface { + // TraceConnectStart is called at the beginning of Connect and ConnectConfig calls. The returned context is used for + // the rest of the call and will be passed to TraceConnectEnd. + TraceConnectStart(ctx context.Context, data TraceConnectStartData) context.Context + + TraceConnectEnd(ctx context.Context, data TraceConnectEndData) +} + +type TraceConnectStartData struct { + ConnConfig *ConnConfig +} + +type TraceConnectEndData struct { + Conn *Conn + Err error +} diff --git a/tracer_test.go b/tracer_test.go new file mode 100644 index 000000000..8920313ff --- /dev/null +++ b/tracer_test.go @@ -0,0 +1,608 @@ +package pgx_test + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +type testTracer struct { + traceQueryStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context + traceQueryEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) + traceBatchStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context + traceBatchQuery func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) + traceBatchEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) + traceCopyFromStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context + traceCopyFromEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) + tracePrepareStart func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context + tracePrepareEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) + traceConnectStart func(ctx context.Context, data pgx.TraceConnectStartData) context.Context + traceConnectEnd func(ctx context.Context, data pgx.TraceConnectEndData) +} + +type ctxKey string + +func (tt *testTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + if tt.traceQueryStart != nil { + return tt.traceQueryStart(ctx, conn, data) + } + return ctx +} + +func (tt *testTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + if tt.traceQueryEnd != nil { + tt.traceQueryEnd(ctx, conn, data) + } +} + +func (tt *testTracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + if tt.traceBatchStart != nil { + return tt.traceBatchStart(ctx, conn, data) + } + return ctx +} + +func (tt *testTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + if tt.traceBatchQuery != nil { + tt.traceBatchQuery(ctx, conn, data) + } +} + +func (tt *testTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + if tt.traceBatchEnd != nil { + tt.traceBatchEnd(ctx, conn, data) + } +} + +func (tt *testTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + if tt.traceCopyFromStart != nil { + return tt.traceCopyFromStart(ctx, conn, data) + } + return ctx +} + +func (tt *testTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + if tt.traceCopyFromEnd != nil { + tt.traceCopyFromEnd(ctx, conn, data) + } +} + +func (tt *testTracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context { + if tt.tracePrepareStart != nil { + return tt.tracePrepareStart(ctx, conn, data) + } + return ctx +} + +func (tt *testTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + if tt.tracePrepareEnd != nil { + tt.tracePrepareEnd(ctx, conn, data) + } +} + +func (tt *testTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context { + if tt.traceConnectStart != nil { + return tt.traceConnectStart(ctx, data) + } + return ctx +} + +func (tt *testTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) { + if tt.traceConnectEnd != nil { + tt.traceConnectEnd(ctx, data) + } +} + +func TestTraceExec(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceQueryStartCalled := false + tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + traceQueryStartCalled = true + require.Equal(t, `select $1::text`, data.SQL) + require.Len(t, data.Args, 1) + require.Equal(t, `testing`, data.Args[0]) + return context.WithValue(ctx, ctxKey(ctxKey("fromTraceQueryStart")), "foo") + } + + traceQueryEndCalled := false + tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + traceQueryEndCalled = true + require.Equal(t, "foo", ctx.Value(ctxKey(ctxKey("fromTraceQueryStart")))) + require.Equal(t, `SELECT 1`, data.CommandTag.String()) + require.NoError(t, data.Err) + } + + _, err := conn.Exec(ctx, `select $1::text`, "testing") + require.NoError(t, err) + require.True(t, traceQueryStartCalled) + require.True(t, traceQueryEndCalled) + }) +} + +func TestTraceQuery(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceQueryStartCalled := false + tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + traceQueryStartCalled = true + require.Equal(t, `select $1::text`, data.SQL) + require.Len(t, data.Args, 1) + require.Equal(t, `testing`, data.Args[0]) + return context.WithValue(ctx, ctxKey("fromTraceQueryStart"), "foo") + } + + traceQueryEndCalled := false + tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + traceQueryEndCalled = true + require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceQueryStart"))) + require.Equal(t, `SELECT 1`, data.CommandTag.String()) + require.NoError(t, data.Err) + } + + var s string + err := conn.QueryRow(ctx, `select $1::text`, "testing").Scan(&s) + require.NoError(t, err) + require.Equal(t, "testing", s) + require.True(t, traceQueryStartCalled) + require.True(t, traceQueryEndCalled) + }) +} + +func TestTraceBatchNormal(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceBatchStartCalled := false + tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + traceBatchStartCalled = true + require.NotNil(t, data.Batch) + require.Equal(t, 2, data.Batch.Len()) + return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo") + } + + traceBatchQueryCalledCount := 0 + tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + traceBatchQueryCalledCount++ + require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart"))) + require.NoError(t, data.Err) + } + + traceBatchEndCalled := false + tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + traceBatchEndCalled = true + require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart"))) + require.NoError(t, data.Err) + } + + batch := &pgx.Batch{} + batch.Queue(`select 1`) + batch.Queue(`select 2`) + + br := conn.SendBatch(context.Background(), batch) + require.True(t, traceBatchStartCalled) + + var n int32 + err := br.QueryRow().Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + require.EqualValues(t, 1, traceBatchQueryCalledCount) + + err = br.QueryRow().Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 2, n) + require.EqualValues(t, 2, traceBatchQueryCalledCount) + + err = br.Close() + require.NoError(t, err) + + require.True(t, traceBatchEndCalled) + }) +} + +func TestTraceBatchClose(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceBatchStartCalled := false + tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + traceBatchStartCalled = true + require.NotNil(t, data.Batch) + require.Equal(t, 2, data.Batch.Len()) + return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo") + } + + traceBatchQueryCalledCount := 0 + tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + traceBatchQueryCalledCount++ + require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart"))) + require.NoError(t, data.Err) + } + + traceBatchEndCalled := false + tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + traceBatchEndCalled = true + require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart"))) + require.NoError(t, data.Err) + } + + batch := &pgx.Batch{} + batch.Queue(`select 1`) + batch.Queue(`select 2`) + + br := conn.SendBatch(context.Background(), batch) + require.True(t, traceBatchStartCalled) + err := br.Close() + require.NoError(t, err) + require.EqualValues(t, 2, traceBatchQueryCalledCount) + require.True(t, traceBatchEndCalled) + }) +} + +func TestTraceBatchErrorWhileReadingResults(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceBatchStartCalled := false + tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + traceBatchStartCalled = true + require.NotNil(t, data.Batch) + require.Equal(t, 3, data.Batch.Len()) + return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo") + } + + traceBatchQueryCalledCount := 0 + tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + traceBatchQueryCalledCount++ + require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart"))) + if traceBatchQueryCalledCount == 2 { + require.Error(t, data.Err) + } else { + require.NoError(t, data.Err) + } + } + + traceBatchEndCalled := false + tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + traceBatchEndCalled = true + require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart"))) + require.Error(t, data.Err) + } + + batch := &pgx.Batch{} + batch.Queue(`select 1`) + batch.Queue(`select 2/n-2 from generate_series(0,10) n`) + batch.Queue(`select 3`) + + br := conn.SendBatch(context.Background(), batch) + require.True(t, traceBatchStartCalled) + + commandTag, err := br.Exec() + require.NoError(t, err) + require.Equal(t, "SELECT 1", commandTag.String()) + + commandTag, err = br.Exec() + require.Error(t, err) + require.Equal(t, "", commandTag.String()) + + commandTag, err = br.Exec() + require.Error(t, err) + require.Equal(t, "", commandTag.String()) + + err = br.Close() + require.Error(t, err) + require.EqualValues(t, 2, traceBatchQueryCalledCount) + require.True(t, traceBatchEndCalled) + }) +} + +func TestTraceBatchErrorWhileReadingResultsWhileClosing(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceBatchStartCalled := false + tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + traceBatchStartCalled = true + require.NotNil(t, data.Batch) + require.Equal(t, 3, data.Batch.Len()) + return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo") + } + + traceBatchQueryCalledCount := 0 + tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + traceBatchQueryCalledCount++ + require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart"))) + if traceBatchQueryCalledCount == 2 { + require.Error(t, data.Err) + } else { + require.NoError(t, data.Err) + } + } + + traceBatchEndCalled := false + tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + traceBatchEndCalled = true + require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart"))) + require.Error(t, data.Err) + } + + batch := &pgx.Batch{} + batch.Queue(`select 1`) + batch.Queue(`select 2/n-2 from generate_series(0,10) n`) + batch.Queue(`select 3`) + + br := conn.SendBatch(context.Background(), batch) + require.True(t, traceBatchStartCalled) + err := br.Close() + require.Error(t, err) + require.EqualValues(t, 2, traceBatchQueryCalledCount) + require.True(t, traceBatchEndCalled) + }) +} + +func TestTraceCopyFrom(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + traceCopyFromStartCalled := false + tracer.traceCopyFromStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + traceCopyFromStartCalled = true + require.Equal(t, pgx.Identifier{"foo"}, data.TableName) + require.Equal(t, []string{"a"}, data.ColumnNames) + return context.WithValue(ctx, ctxKey("fromTraceCopyFromStart"), "foo") + } + + traceCopyFromEndCalled := false + tracer.traceCopyFromEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + traceCopyFromEndCalled = true + require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceCopyFromStart"))) + require.Equal(t, `COPY 2`, data.CommandTag.String()) + require.NoError(t, data.Err) + } + + _, err := conn.Exec(ctx, `create temporary table foo(a int4)`) + require.NoError(t, err) + + inputRows := [][]any{ + {int32(1)}, + {nil}, + } + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + require.True(t, traceCopyFromStartCalled) + require.True(t, traceCopyFromEndCalled) + }) +} + +func TestTracePrepare(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + tracePrepareStartCalled := false + tracer.tracePrepareStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context { + tracePrepareStartCalled = true + require.Equal(t, `ps`, data.Name) + require.Equal(t, `select $1::text`, data.SQL) + return context.WithValue(ctx, ctxKey("fromTracePrepareStart"), "foo") + } + + tracePrepareEndCalled := false + tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + tracePrepareEndCalled = true + require.False(t, data.AlreadyPrepared) + require.NoError(t, data.Err) + } + + _, err := conn.Prepare(ctx, "ps", `select $1::text`) + require.NoError(t, err) + require.True(t, tracePrepareStartCalled) + require.True(t, tracePrepareEndCalled) + + tracePrepareStartCalled = false + tracePrepareEndCalled = false + tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + tracePrepareEndCalled = true + require.True(t, data.AlreadyPrepared) + require.NoError(t, data.Err) + } + + _, err = conn.Prepare(ctx, "ps", `select $1::text`) + require.NoError(t, err) + require.True(t, tracePrepareStartCalled) + require.True(t, tracePrepareEndCalled) + }) +} + +func TestTraceConnect(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + config := defaultConnTestRunner.CreateConfig(context.Background(), t) + config.Tracer = tracer + + traceConnectStartCalled := false + tracer.traceConnectStart = func(ctx context.Context, data pgx.TraceConnectStartData) context.Context { + traceConnectStartCalled = true + require.NotNil(t, data.ConnConfig) + return context.WithValue(ctx, ctxKey("fromTraceConnectStart"), "foo") + } + + traceConnectEndCalled := false + tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) { + traceConnectEndCalled = true + require.NotNil(t, data.Conn) + require.NoError(t, data.Err) + } + + conn1, err := pgx.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer conn1.Close(context.Background()) + require.True(t, traceConnectStartCalled) + require.True(t, traceConnectEndCalled) + + config, err = pgx.ParseConfig("host=/invalid") + require.NoError(t, err) + config.Tracer = tracer + + traceConnectStartCalled = false + traceConnectEndCalled = false + tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) { + traceConnectEndCalled = true + require.Nil(t, data.Conn) + require.Error(t, data.Err) + } + + conn2, err := pgx.ConnectConfig(context.Background(), config) + require.Nil(t, conn2) + require.Error(t, err) + require.True(t, traceConnectStartCalled) + require.True(t, traceConnectEndCalled) +} + +// Ensure tracer runs within a transaction. +// +// https://github.com/jackc/pgx/issues/2304 +func TestTraceWithinTx(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var queries []string + tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + queries = append(queries, data.SQL) + return ctx + } + + tx, err := conn.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + _, err = tx.Exec(ctx, `select $1::text`, "testing") + require.NoError(t, err) + err = tx.Commit(ctx) + require.NoError(t, err) + + require.Len(t, queries, 3) + require.Equal(t, `begin`, queries[0]) + require.Equal(t, `select $1::text`, queries[1]) + require.Equal(t, `commit`, queries[2]) + }) +} diff --git a/travis/before_install.bash b/travis/before_install.bash deleted file mode 100755 index 23c7d9cf0..000000000 --- a/travis/before_install.bash +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env bash -set -eux - -if [ "${PGVERSION-}" != "" ] -then - sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common - sudo rm -rf /var/lib/postgresql - wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - - sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list" - sudo apt-get update -qq - sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION - sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf - if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then - echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf - echo "max_wal_senders=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf - echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf - fi - sudo /etc/init.d/postgresql restart -fi - -if [ "${CRATEVERSION-}" != "" ] -then - docker run \ - -p "6543:5432" \ - -d \ - crate:"$CRATEVERSION" \ - crate \ - -Cnetwork.host=0.0.0.0 \ - -Ctransport.host=localhost \ - -Clicense.enterprise=false -fi diff --git a/travis/before_script.bash b/travis/before_script.bash deleted file mode 100755 index 7e206e7ad..000000000 --- a/travis/before_script.bash +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env bash -set -eux - -mv conn_config_test.go.travis conn_config_test.go - -if [ "${PGVERSION-}" != "" ] -then - # The tricky test user, below, has to actually exist so that it can be used in a test - # of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. - psql -U postgres -c 'create database pgx_test' - psql -U postgres pgx_test -c 'create extension hstore' - psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_replication with replication password 'secret'" - psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" -fi diff --git a/travis/install.bash b/travis/install.bash deleted file mode 100755 index 61b683ec8..000000000 --- a/travis/install.bash +++ /dev/null @@ -1,13 +0,0 @@ -#!/usr/bin/env bash -set -eux - -go get -u github.com/cockroachdb/apd -go get -u github.com/shopspring/decimal -go get -u gopkg.in/inconshreveable/log15.v2 -go get -u github.com/jackc/fake -go get -u github.com/lib/pq -go get -u github.com/hashicorp/go-version -go get -u github.com/satori/go.uuid -go get -u github.com/sirupsen/logrus -go get -u github.com/pkg/errors -go get -u go.uber.org/zap diff --git a/travis/script.bash b/travis/script.bash deleted file mode 100755 index 5bf1b77e7..000000000 --- a/travis/script.bash +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env bash -set -eux - -if [ "${PGVERSION-}" != "" ] -then - go test -v -race ./... -elif [ "${CRATEVERSION-}" != "" ] -then - go test -v -race -run 'TestCrateDBConnect' -fi diff --git a/tx.go b/tx.go index 81fcfa267..571e5e00f 100644 --- a/tx.go +++ b/tx.go @@ -1,69 +1,82 @@ package pgx import ( - "bytes" "context" - "fmt" - "time" + "errors" + "strconv" + "strings" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/pgconn" ) +// TxIsoLevel is the transaction isolation level (serializable, repeatable read, read committed or read uncommitted) type TxIsoLevel string // Transaction isolation levels const ( - Serializable = TxIsoLevel("serializable") - RepeatableRead = TxIsoLevel("repeatable read") - ReadCommitted = TxIsoLevel("read committed") - ReadUncommitted = TxIsoLevel("read uncommitted") + Serializable TxIsoLevel = "serializable" + RepeatableRead TxIsoLevel = "repeatable read" + ReadCommitted TxIsoLevel = "read committed" + ReadUncommitted TxIsoLevel = "read uncommitted" ) +// TxAccessMode is the transaction access mode (read write or read only) type TxAccessMode string // Transaction access modes const ( - ReadWrite = TxAccessMode("read write") - ReadOnly = TxAccessMode("read only") + ReadWrite TxAccessMode = "read write" + ReadOnly TxAccessMode = "read only" ) +// TxDeferrableMode is the transaction deferrable mode (deferrable or not deferrable) type TxDeferrableMode string // Transaction deferrable modes const ( - Deferrable = TxDeferrableMode("deferrable") - NotDeferrable = TxDeferrableMode("not deferrable") -) - -const ( - TxStatusInProgress = 0 - TxStatusCommitFailure = -1 - TxStatusRollbackFailure = -2 - TxStatusCommitSuccess = 1 - TxStatusRollbackSuccess = 2 + Deferrable TxDeferrableMode = "deferrable" + NotDeferrable TxDeferrableMode = "not deferrable" ) +// TxOptions are transaction modes within a transaction block type TxOptions struct { IsoLevel TxIsoLevel AccessMode TxAccessMode DeferrableMode TxDeferrableMode + + // BeginQuery is the SQL query that will be executed to begin the transaction. This allows using non-standard syntax + // such as BEGIN PRIORITY HIGH with CockroachDB. If set this will override the other settings. + BeginQuery string + // CommitQuery is the SQL query that will be executed to commit the transaction. + CommitQuery string } -func (txOptions *TxOptions) beginSQL() string { - if txOptions == nil { +var emptyTxOptions TxOptions + +func (txOptions TxOptions) beginSQL() string { + if txOptions == emptyTxOptions { return "begin" } - buf := &bytes.Buffer{} + if txOptions.BeginQuery != "" { + return txOptions.BeginQuery + } + + var buf strings.Builder + buf.Grow(64) // 64 - maximum length of string with available options buf.WriteString("begin") + if txOptions.IsoLevel != "" { - fmt.Fprintf(buf, " isolation level %s", txOptions.IsoLevel) + buf.WriteString(" isolation level ") + buf.WriteString(string(txOptions.IsoLevel)) } if txOptions.AccessMode != "" { - fmt.Fprintf(buf, " %s", txOptions.AccessMode) + buf.WriteByte(' ') + buf.WriteString(string(txOptions.AccessMode)) } if txOptions.DeferrableMode != "" { - fmt.Fprintf(buf, " %s", txOptions.DeferrableMode) + buf.WriteByte(' ') + buf.WriteString(string(txOptions.DeferrableMode)) } return buf.String() @@ -76,172 +89,354 @@ var ErrTxClosed = errors.New("tx is closed") // it is treated as ROLLBACK. var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback") -// Begin starts a transaction with the default transaction mode for the -// current connection. To use a specific transaction mode see BeginEx. -func (c *Conn) Begin() (*Tx, error) { - return c.BeginEx(context.Background(), nil) +// Begin starts a transaction. Unlike database/sql, the context only affects the begin command. i.e. there is no +// auto-rollback on context cancellation. +func (c *Conn) Begin(ctx context.Context) (Tx, error) { + return c.BeginTx(ctx, TxOptions{}) } -// BeginEx starts a transaction with txOptions determining the transaction -// mode. Unlike database/sql, the context only affects the begin command. i.e. -// there is no auto-rollback on context cancelation. -func (c *Conn) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) { - _, err := c.ExecEx(ctx, txOptions.beginSQL(), nil) +// BeginTx starts a transaction with txOptions determining the transaction mode. Unlike database/sql, the context only +// affects the begin command. i.e. there is no auto-rollback on context cancellation. +func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) { + _, err := c.Exec(ctx, txOptions.beginSQL()) if err != nil { // begin should never fail unless there is an underlying connection issue or // a context timeout. In either case, the connection is possibly broken. - c.die(errors.New("failed to begin transaction")) + c.die() return nil, err } - return &Tx{conn: c}, nil + return &dbTx{ + conn: c, + commitQuery: txOptions.CommitQuery, + }, nil } // Tx represents a database transaction. // -// All Tx methods return ErrTxClosed if Commit or Rollback has already been -// called on the Tx. -type Tx struct { - conn *Conn - connPool *ConnPool - err error - status int8 +// Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx +// state, to support pseudo-nested transactions with savepoints, and to allow tests to mock transactions. However, +// adding a method to an interface is technically a breaking change. If new methods are added to Conn it may be +// desirable to add them to Tx as well. Because of this the Tx interface is partially excluded from semantic version +// requirements. Methods will not be removed or changed, but new methods may be added. +type Tx interface { + // Begin starts a pseudo nested transaction. + Begin(ctx context.Context) (Tx, error) + + // Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested + // transaction. Commit will return an error where errors.Is(ErrTxClosed) is true if the Tx is already closed, but is + // otherwise safe to call multiple times. If the commit fails with a rollback status (e.g. the transaction was already + // in a broken state) then an error where errors.Is(ErrTxCommitRollback) is true will be returned. + Commit(ctx context.Context) error + + // Rollback rolls back the transaction if this is a real transaction or rolls back to the savepoint if this is a + // pseudo nested transaction. Rollback will return an error where errors.Is(ErrTxClosed) is true if the Tx is already + // closed, but is otherwise safe to call multiple times. Hence, a defer tx.Rollback() is safe even if tx.Commit() will + // be called first in a non-error condition. Any other failure of a real transaction will result in the connection + // being closed. + Rollback(ctx context.Context) error + + CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) + SendBatch(ctx context.Context, b *Batch) BatchResults + LargeObjects() LargeObjects + + Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) + + Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) + Query(ctx context.Context, sql string, args ...any) (Rows, error) + QueryRow(ctx context.Context, sql string, args ...any) Row + + // Conn returns the underlying *Conn that on which this transaction is executing. + Conn() *Conn } -// Commit commits the transaction -func (tx *Tx) Commit() error { - return tx.CommitEx(context.Background()) +// dbTx represents a database transaction. +// +// All dbTx methods return ErrTxClosed if Commit or Rollback has already been +// called on the dbTx. +type dbTx struct { + conn *Conn + savepointNum int64 + closed bool + commitQuery string } -// CommitEx commits the transaction with a context. -func (tx *Tx) CommitEx(ctx context.Context) error { - if tx.status != TxStatusInProgress { +// Begin starts a pseudo nested transaction implemented with a savepoint. +func (tx *dbTx) Begin(ctx context.Context) (Tx, error) { + if tx.closed { + return nil, ErrTxClosed + } + + tx.savepointNum++ + _, err := tx.conn.Exec(ctx, "savepoint sp_"+strconv.FormatInt(tx.savepointNum, 10)) + if err != nil { + return nil, err + } + + return &dbSimulatedNestedTx{tx: tx, savepointNum: tx.savepointNum}, nil +} + +// Commit commits the transaction. +func (tx *dbTx) Commit(ctx context.Context) error { + if tx.closed { return ErrTxClosed } - commandTag, err := tx.conn.ExecEx(ctx, "commit", nil) - if err == nil && commandTag == "COMMIT" { - tx.status = TxStatusCommitSuccess - } else if err == nil && commandTag == "ROLLBACK" { - tx.status = TxStatusCommitFailure - tx.err = ErrTxCommitRollback - } else { - tx.status = TxStatusCommitFailure - tx.err = err - // A commit failure leaves the connection in an undefined state - tx.conn.die(errors.New("commit failed")) + commandSQL := "commit" + if tx.commitQuery != "" { + commandSQL = tx.commitQuery } - if tx.connPool != nil { - tx.connPool.Release(tx.conn) + commandTag, err := tx.conn.Exec(ctx, commandSQL) + tx.closed = true + if err != nil { + if tx.conn.PgConn().TxStatus() != 'I' { + _ = tx.conn.Close(ctx) // already have error to return + } + return err + } + if commandTag.String() == "ROLLBACK" { + return ErrTxCommitRollback } - return tx.err + return nil } // Rollback rolls back the transaction. Rollback will return ErrTxClosed if the // Tx is already closed, but is otherwise safe to call multiple times. Hence, a // defer tx.Rollback() is safe even if tx.Commit() will be called first in a // non-error condition. -func (tx *Tx) Rollback() error { - ctx, _ := context.WithTimeout(context.Background(), 15*time.Second) - return tx.RollbackEx(ctx) -} - -// RollbackEx is the context version of Rollback -func (tx *Tx) RollbackEx(ctx context.Context) error { - if tx.status != TxStatusInProgress { +func (tx *dbTx) Rollback(ctx context.Context) error { + if tx.closed { return ErrTxClosed } - _, tx.err = tx.conn.ExecEx(ctx, "rollback", nil) - if tx.err == nil { - tx.status = TxStatusRollbackSuccess - } else { - tx.status = TxStatusRollbackFailure + _, err := tx.conn.Exec(ctx, "rollback") + tx.closed = true + if err != nil { // A rollback failure leaves the connection in an undefined state - tx.conn.die(errors.New("rollback failed")) + tx.conn.die() + return err } - if tx.connPool != nil { - tx.connPool.Release(tx.conn) + return nil +} + +// Exec delegates to the underlying *Conn +func (tx *dbTx) Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) { + if tx.closed { + return pgconn.CommandTag{}, ErrTxClosed } - return tx.err + return tx.conn.Exec(ctx, sql, arguments...) } -// Exec delegates to the underlying *Conn -func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { - return tx.ExecEx(context.Background(), sql, nil, arguments...) +// Prepare delegates to the underlying *Conn +func (tx *dbTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) { + if tx.closed { + return nil, ErrTxClosed + } + + return tx.conn.Prepare(ctx, name, sql) } -// ExecEx delegates to the underlying *Conn -func (tx *Tx) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { - if tx.status != TxStatusInProgress { - return CommandTag(""), ErrTxClosed +// Query delegates to the underlying *Conn +func (tx *dbTx) Query(ctx context.Context, sql string, args ...any) (Rows, error) { + if tx.closed { + // Because checking for errors can be deferred to the *Rows, build one with the error + err := ErrTxClosed + return &baseRows{closed: true, err: err}, err } - return tx.conn.ExecEx(ctx, sql, options, arguments...) + return tx.conn.Query(ctx, sql, args...) } -// Prepare delegates to the underlying *Conn -func (tx *Tx) Prepare(name, sql string) (*PreparedStatement, error) { - return tx.PrepareEx(context.Background(), name, sql, nil) +// QueryRow delegates to the underlying *Conn +func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...any) Row { + rows, _ := tx.Query(ctx, sql, args...) + return (*connRow)(rows.(*baseRows)) +} + +// CopyFrom delegates to the underlying *Conn +func (tx *dbTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { + if tx.closed { + return 0, ErrTxClosed + } + + return tx.conn.CopyFrom(ctx, tableName, columnNames, rowSrc) +} + +// SendBatch delegates to the underlying *Conn +func (tx *dbTx) SendBatch(ctx context.Context, b *Batch) BatchResults { + if tx.closed { + return &batchResults{err: ErrTxClosed} + } + + return tx.conn.SendBatch(ctx, b) } -// PrepareEx delegates to the underlying *Conn -func (tx *Tx) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) { - if tx.status != TxStatusInProgress { +// LargeObjects returns a LargeObjects instance for the transaction. +func (tx *dbTx) LargeObjects() LargeObjects { + return LargeObjects{tx: tx} +} + +func (tx *dbTx) Conn() *Conn { + return tx.conn +} + +// dbSimulatedNestedTx represents a simulated nested transaction implemented by a savepoint. +type dbSimulatedNestedTx struct { + tx Tx + savepointNum int64 + closed bool +} + +// Begin starts a pseudo nested transaction implemented with a savepoint. +func (sp *dbSimulatedNestedTx) Begin(ctx context.Context) (Tx, error) { + if sp.closed { return nil, ErrTxClosed } - return tx.conn.PrepareEx(ctx, name, sql, opts) + return sp.tx.Begin(ctx) } -// Query delegates to the underlying *Conn -func (tx *Tx) Query(sql string, args ...interface{}) (*Rows, error) { - return tx.QueryEx(context.Background(), sql, nil, args...) +// Commit releases the savepoint essentially committing the pseudo nested transaction. +func (sp *dbSimulatedNestedTx) Commit(ctx context.Context) error { + if sp.closed { + return ErrTxClosed + } + + _, err := sp.Exec(ctx, "release savepoint sp_"+strconv.FormatInt(sp.savepointNum, 10)) + sp.closed = true + return err } -// QueryEx delegates to the underlying *Conn -func (tx *Tx) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) { - if tx.status != TxStatusInProgress { - // Because checking for errors can be deferred to the *Rows, build one with the error - err := ErrTxClosed - return &Rows{closed: true, err: err}, err +// Rollback rolls back to the savepoint essentially rolling back the pseudo nested transaction. Rollback will return +// ErrTxClosed if the dbSavepoint is already closed, but is otherwise safe to call multiple times. Hence, a defer sp.Rollback() +// is safe even if sp.Commit() will be called first in a non-error condition. +func (sp *dbSimulatedNestedTx) Rollback(ctx context.Context) error { + if sp.closed { + return ErrTxClosed } - return tx.conn.QueryEx(ctx, sql, options, args...) + _, err := sp.Exec(ctx, "rollback to savepoint sp_"+strconv.FormatInt(sp.savepointNum, 10)) + sp.closed = true + return err } -// QueryRow delegates to the underlying *Conn -func (tx *Tx) QueryRow(sql string, args ...interface{}) *Row { - rows, _ := tx.Query(sql, args...) - return (*Row)(rows) +// Exec delegates to the underlying Tx +func (sp *dbSimulatedNestedTx) Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) { + if sp.closed { + return pgconn.CommandTag{}, ErrTxClosed + } + + return sp.tx.Exec(ctx, sql, arguments...) } -// QueryRowEx delegates to the underlying *Conn -func (tx *Tx) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row { - rows, _ := tx.QueryEx(ctx, sql, options, args...) - return (*Row)(rows) +// Prepare delegates to the underlying Tx +func (sp *dbSimulatedNestedTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) { + if sp.closed { + return nil, ErrTxClosed + } + + return sp.tx.Prepare(ctx, name, sql) +} + +// Query delegates to the underlying Tx +func (sp *dbSimulatedNestedTx) Query(ctx context.Context, sql string, args ...any) (Rows, error) { + if sp.closed { + // Because checking for errors can be deferred to the *Rows, build one with the error + err := ErrTxClosed + return &baseRows{closed: true, err: err}, err + } + + return sp.tx.Query(ctx, sql, args...) +} + +// QueryRow delegates to the underlying Tx +func (sp *dbSimulatedNestedTx) QueryRow(ctx context.Context, sql string, args ...any) Row { + rows, _ := sp.Query(ctx, sql, args...) + return (*connRow)(rows.(*baseRows)) } // CopyFrom delegates to the underlying *Conn -func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { - if tx.status != TxStatusInProgress { +func (sp *dbSimulatedNestedTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { + if sp.closed { return 0, ErrTxClosed } - return tx.conn.CopyFrom(tableName, columnNames, rowSrc) + return sp.tx.CopyFrom(ctx, tableName, columnNames, rowSrc) +} + +// SendBatch delegates to the underlying *Conn +func (sp *dbSimulatedNestedTx) SendBatch(ctx context.Context, b *Batch) BatchResults { + if sp.closed { + return &batchResults{err: ErrTxClosed} + } + + return sp.tx.SendBatch(ctx, b) +} + +func (sp *dbSimulatedNestedTx) LargeObjects() LargeObjects { + return LargeObjects{tx: sp} +} + +func (sp *dbSimulatedNestedTx) Conn() *Conn { + return sp.tx.Conn() +} + +// BeginFunc calls Begin on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn +// returns an error it calls Rollback on db. The context will be used when executing the transaction control statements +// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn. +func BeginFunc( + ctx context.Context, + db interface { + Begin(ctx context.Context) (Tx, error) + }, + fn func(Tx) error, +) (err error) { + var tx Tx + tx, err = db.Begin(ctx) + if err != nil { + return err + } + + return beginFuncExec(ctx, tx, fn) } -// Status returns the status of the transaction from the set of -// pgx.TxStatus* constants. -func (tx *Tx) Status() int8 { - return tx.status +// BeginTxFunc calls BeginTx on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn +// returns an error it calls Rollback on db. The context will be used when executing the transaction control statements +// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn. +func BeginTxFunc( + ctx context.Context, + db interface { + BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) + }, + txOptions TxOptions, + fn func(Tx) error, +) (err error) { + var tx Tx + tx, err = db.BeginTx(ctx, txOptions) + if err != nil { + return err + } + + return beginFuncExec(ctx, tx, fn) } -// Err returns the final error state, if any, of calling Commit or Rollback. -func (tx *Tx) Err() error { - return tx.err +func beginFuncExec(ctx context.Context, tx Tx, fn func(Tx) error) (err error) { + defer func() { + rollbackErr := tx.Rollback(ctx) + if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) { + err = rollbackErr + } + }() + + fErr := fn(tx) + if fErr != nil { + _ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return + return fErr + } + + return tx.Commit(ctx) } diff --git a/tx_test.go b/tx_test.go index b25e1c9fe..cd4fb2074 100644 --- a/tx_test.go +++ b/tx_test.go @@ -2,49 +2,51 @@ package pgx_test import ( "context" - "fmt" + "errors" + "os" "testing" "time" - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgmock" - "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" ) func TestTransactionSuccessfulCommit(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) createSql := ` create temporary table foo( id integer, - unique (id) initially deferred + unique (id) ); ` - if _, err := conn.Exec(createSql); err != nil { + if _, err := conn.Exec(context.Background(), createSql); err != nil { t.Fatalf("Failed to create table: %v", err) } - tx, err := conn.Begin() + tx, err := conn.Begin(context.Background()) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } - _, err = tx.Exec("insert into foo(id) values (1)") + _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") if err != nil { t.Fatalf("tx.Exec failed: %v", err) } - err = tx.Commit() + err = tx.Commit(context.Background()) if err != nil { t.Fatalf("tx.Commit failed: %v", err) } var n int64 - err = conn.QueryRow("select count(*) from foo").Scan(&n) + err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) if err != nil { t.Fatalf("QueryRow Scan failed: %v", err) } @@ -56,41 +58,88 @@ func TestTransactionSuccessfulCommit(t *testing.T) { func TestTxCommitWhenTxBroken(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) createSql := ` create temporary table foo( id integer, - unique (id) initially deferred + unique (id) ); ` - if _, err := conn.Exec(createSql); err != nil { + if _, err := conn.Exec(context.Background(), createSql); err != nil { t.Fatalf("Failed to create table: %v", err) } - tx, err := conn.Begin() + tx, err := conn.Begin(context.Background()) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } - if _, err := tx.Exec("insert into foo(id) values (1)"); err != nil { + if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil { t.Fatalf("tx.Exec failed: %v", err) } // Purposely break transaction - if _, err := tx.Exec("syntax error"); err == nil { + if _, err := tx.Exec(context.Background(), "syntax error"); err == nil { t.Fatal("Unexpected success") } - err = tx.Commit() + err = tx.Commit(context.Background()) if err != pgx.ErrTxCommitRollback { t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err) } var n int64 - err = conn.QueryRow("select count(*) from foo").Scan(&n) + err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 0 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } +} + +func TestTxCommitWhenDeferredConstraintFailure(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + + createSql := ` + create temporary table foo( + id integer, + unique (id) initially deferred + ); + ` + + if _, err := conn.Exec(context.Background(), createSql); err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + tx, err := conn.Begin(context.Background()) + if err != nil { + t.Fatalf("conn.Begin failed: %v", err) + } + + if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil { + t.Fatalf("tx.Exec failed: %v", err) + } + + if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil { + t.Fatalf("tx.Exec failed: %v", err) + } + + err = tx.Commit(context.Background()) + if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "23505" { + t.Fatalf("Expected unique constraint violation 23505, got %#v", err) + } + + var n int64 + err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) if err != nil { t.Fatalf("QueryRow Scan failed: %v", err) } @@ -102,83 +151,96 @@ func TestTxCommitWhenTxBroken(t *testing.T) { func TestTxCommitSerializationFailure(t *testing.T) { t.Parallel() - pool := createConnPool(t, 5) - defer pool.Close() + c1 := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, c1) - pool.Exec(`drop table if exists tx_serializable_sums`) - _, err := pool.Exec(`create table tx_serializable_sums(num integer);`) + if c1.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/60754)") + } + + c2 := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, c2) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + c1.Exec(ctx, `drop table if exists tx_serializable_sums`) + _, err := c1.Exec(ctx, `create table tx_serializable_sums(num integer);`) if err != nil { t.Fatalf("Unable to create temporary table: %v", err) } - defer pool.Exec(`drop table tx_serializable_sums`) + defer c1.Exec(ctx, `drop table tx_serializable_sums`) - tx1, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) + tx1, err := c1.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { - t.Fatalf("BeginEx failed: %v", err) + t.Fatalf("Begin failed: %v", err) } - defer tx1.Rollback() + defer tx1.Rollback(ctx) - tx2, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) + tx2, err := c2.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { - t.Fatalf("BeginEx failed: %v", err) + t.Fatalf("Begin failed: %v", err) } - defer tx2.Rollback() + defer tx2.Rollback(ctx) - _, err = tx1.Exec(`insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`) + _, err = tx1.Exec(ctx, `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`) if err != nil { t.Fatalf("Exec failed: %v", err) } - _, err = tx2.Exec(`insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`) + _, err = tx2.Exec(ctx, `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`) if err != nil { t.Fatalf("Exec failed: %v", err) } - err = tx1.Commit() + err = tx1.Commit(ctx) if err != nil { t.Fatalf("Commit failed: %v", err) } - err = tx2.Commit() - if pgErr, ok := err.(pgx.PgError); !ok || pgErr.Code != "40001" { + err = tx2.Commit(ctx) + if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "40001" { t.Fatalf("Expected serialization error 40001, got %#v", err) } + + ensureConnValid(t, c1) + ensureConnValid(t, c2) } func TestTransactionSuccessfulRollback(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) createSql := ` create temporary table foo( id integer, - unique (id) initially deferred + unique (id) ); ` - if _, err := conn.Exec(createSql); err != nil { + if _, err := conn.Exec(context.Background(), createSql); err != nil { t.Fatalf("Failed to create table: %v", err) } - tx, err := conn.Begin() + tx, err := conn.Begin(context.Background()) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } - _, err = tx.Exec("insert into foo(id) values (1)") + _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") if err != nil { t.Fatalf("tx.Exec failed: %v", err) } - err = tx.Rollback() + err = tx.Rollback(context.Background()) if err != nil { t.Fatalf("tx.Rollback failed: %v", err) } var n int64 - err = conn.QueryRow("select count(*) from foo").Scan(&n) + err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) if err != nil { t.Fatalf("QueryRow Scan failed: %v", err) } @@ -187,198 +249,396 @@ func TestTransactionSuccessfulRollback(t *testing.T) { } } -func TestBeginExIsoLevels(t *testing.T) { +func TestTransactionRollbackFailsClosesConnection(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + ctx, cancel := context.WithCancel(context.Background()) + + tx, err := conn.Begin(ctx) + require.NoError(t, err) + + cancel() + + err = tx.Rollback(ctx) + require.Error(t, err) + + require.True(t, conn.IsClosed()) +} + +func TestBeginIsoLevels(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + pgxtest.SkipCockroachDB(t, conn, "Server always uses SERIALIZABLE isolation (https://www.cockroachlabs.com/docs/stable/demo-serializable.html)") + isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} for _, iso := range isoLevels { - tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: iso}) + tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{IsoLevel: iso}) if err != nil { - t.Fatalf("conn.BeginEx failed: %v", err) + t.Fatalf("conn.Begin failed: %v", err) } var level pgx.TxIsoLevel - conn.QueryRow("select current_setting('transaction_isolation')").Scan(&level) + conn.QueryRow(context.Background(), "select current_setting('transaction_isolation')").Scan(&level) if level != iso { t.Errorf("Expected to be in isolation level %v but was %v", iso, level) } - err = tx.Rollback() + err = tx.Rollback(context.Background()) if err != nil { t.Fatalf("tx.Rollback failed: %v", err) } } } -func TestBeginExReadOnly(t *testing.T) { +func TestBeginFunc(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{AccessMode: pgx.ReadOnly}) + createSql := ` + create temporary table foo( + id integer, + unique (id) + ); + ` + + _, err := conn.Exec(context.Background(), createSql) + require.NoError(t, err) + + err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error { + _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") + require.NoError(t, err) + return nil + }) + require.NoError(t, err) + + var n int64 + err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) +} + +func TestBeginFuncRollbackOnError(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + createSql := ` + create temporary table foo( + id integer, + unique (id) + ); + ` + + _, err := conn.Exec(context.Background(), createSql) + require.NoError(t, err) + + err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error { + _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") + require.NoError(t, err) + return errors.New("some error") + }) + require.EqualError(t, err, "some error") + + var n int64 + err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 0, n) +} + +func TestBeginReadOnly(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{AccessMode: pgx.ReadOnly}) if err != nil { - t.Fatalf("conn.BeginEx failed: %v", err) + t.Fatalf("conn.Begin failed: %v", err) } - defer tx.Rollback() + defer tx.Rollback(context.Background()) - _, err = conn.Exec("create table foo(id serial primary key)") - if pgErr, ok := err.(pgx.PgError); !ok || pgErr.Code != "25006" { + _, err = conn.Exec(context.Background(), "create table foo(id serial primary key)") + if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "25006" { t.Errorf("Expected error SQLSTATE 25006, but got %#v", err) } } -func TestConnBeginExContextCancel(t *testing.T) { +func TestBeginTxBeginQuery(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + tx, err := conn.BeginTx(ctx, pgx.TxOptions{BeginQuery: "begin read only"}) + require.NoError(t, err) + defer tx.Rollback(ctx) + + var readOnly bool + conn.QueryRow(ctx, "select current_setting('transaction_read_only')::bool").Scan(&readOnly) + require.True(t, readOnly) + + err = tx.Rollback(ctx) + require.NoError(t, err) + }) +} + +func TestTxNestedTransactionCommit(t *testing.T) { t.Parallel() - script := &pgmock.Script{ - Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + createSql := ` + create temporary table foo( + id integer, + unique (id) + ); + ` + + if _, err := conn.Exec(context.Background(), createSql); err != nil { + t.Fatalf("Failed to create table: %v", err) } - script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) - script.Steps = append(script.Steps, - pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}), - pgmock.WaitForClose(), - ) - server, err := pgmock.NewServer(script) + tx, err := conn.Begin(context.Background()) if err != nil { t.Fatal(err) } - defer server.Close() - errChan := make(chan error, 1) - go func() { - errChan <- server.ServeOne() - }() + _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") + if err != nil { + t.Fatalf("tx.Exec failed: %v", err) + } - mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + nestedTx, err := tx.Begin(context.Background()) if err != nil { t.Fatal(err) } - conn := mustConnect(t, mockConfig) + _, err = nestedTx.Exec(context.Background(), "insert into foo(id) values (2)") + if err != nil { + t.Fatalf("nestedTx.Exec failed: %v", err) + } - ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond) + doubleNestedTx, err := nestedTx.Begin(context.Background()) + if err != nil { + t.Fatal(err) + } - _, err = conn.BeginEx(ctx, nil) - if err != context.DeadlineExceeded { - t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + _, err = doubleNestedTx.Exec(context.Background(), "insert into foo(id) values (3)") + if err != nil { + t.Fatalf("doubleNestedTx.Exec failed: %v", err) } - if conn.IsAlive() { - t.Error("expected conn to be dead after BeginEx failure") + err = doubleNestedTx.Commit(context.Background()) + if err != nil { + t.Fatalf("doubleNestedTx.Commit failed: %v", err) } - if err := <-errChan; err != nil { - t.Errorf("mock server err: %v", err) + err = nestedTx.Commit(context.Background()) + if err != nil { + t.Fatalf("nestedTx.Commit failed: %v", err) + } + + err = tx.Commit(context.Background()) + if err != nil { + t.Fatalf("tx.Commit failed: %v", err) + } + + var n int64 + err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 3 { + t.Fatalf("Did not receive correct number of rows: %v", n) } } -func TestTxCommitExCancel(t *testing.T) { +func TestTxNestedTransactionRollback(t *testing.T) { t.Parallel() - script := &pgmock.Script{ - Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + createSql := ` + create temporary table foo( + id integer, + unique (id) + ); + ` + + if _, err := conn.Exec(context.Background(), createSql); err != nil { + t.Fatalf("Failed to create table: %v", err) } - script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) - script.Steps = append(script.Steps, - pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}), - pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: "BEGIN"}), - pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'T'}), - pgmock.WaitForClose(), - ) - server, err := pgmock.NewServer(script) + tx, err := conn.Begin(context.Background()) if err != nil { t.Fatal(err) } - defer server.Close() - errChan := make(chan error, 1) - go func() { - errChan <- server.ServeOne() - }() + _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") + if err != nil { + t.Fatalf("tx.Exec failed: %v", err) + } - mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + nestedTx, err := tx.Begin(context.Background()) if err != nil { t.Fatal(err) } - conn := mustConnect(t, mockConfig) - defer conn.Close() + _, err = nestedTx.Exec(context.Background(), "insert into foo(id) values (2)") + if err != nil { + t.Fatalf("nestedTx.Exec failed: %v", err) + } - tx, err := conn.Begin() + err = nestedTx.Rollback(context.Background()) if err != nil { - t.Fatal(err) + t.Fatalf("nestedTx.Rollback failed: %v", err) } - ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond) - err = tx.CommitEx(ctx) - if err != context.DeadlineExceeded { - t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + _, err = tx.Exec(context.Background(), "insert into foo(id) values (3)") + if err != nil { + t.Fatalf("tx.Exec failed: %v", err) } - if conn.IsAlive() { - t.Error("expected conn to be dead after CommitEx failure") + err = tx.Commit(context.Background()) + if err != nil { + t.Fatalf("tx.Commit failed: %v", err) } - if err := <-errChan; err != nil { - t.Errorf("mock server err: %v", err) + var n int64 + err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 2 { + t.Fatalf("Did not receive correct number of rows: %v", n) } } -func TestTxStatus(t *testing.T) { +func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, db) - tx, err := conn.Begin() - if err != nil { - t.Fatal(err) - } + createSql := ` + create temporary table foo( + id integer, + unique (id) + ); + ` - if status := tx.Status(); status != pgx.TxStatusInProgress { - t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status) - } + _, err := db.Exec(context.Background(), createSql) + require.NoError(t, err) - if err := tx.Rollback(); err != nil { - t.Fatal(err) - } + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { + _, err := db.Exec(context.Background(), "insert into foo(id) values (1)") + require.NoError(t, err) - if status := tx.Status(); status != pgx.TxStatusRollbackSuccess { - t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusRollbackSuccess, status) - } + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { + _, err := db.Exec(context.Background(), "insert into foo(id) values (2)") + require.NoError(t, err) + + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { + _, err := db.Exec(context.Background(), "insert into foo(id) values (3)") + require.NoError(t, err) + return nil + }) + require.NoError(t, err) + + return nil + }) + require.NoError(t, err) + return nil + }) + require.NoError(t, err) + + var n int64 + err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 3, n) } -func TestTxErr(t *testing.T) { +func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, db) - tx, err := conn.Begin() - if err != nil { - t.Fatal(err) - } + createSql := ` + create temporary table foo( + id integer, + unique (id) + ); + ` - // Purposely break transaction - if _, err := tx.Exec("syntax error"); err == nil { - t.Fatal("Unexpected success") - } + _, err := db.Exec(context.Background(), createSql) + require.NoError(t, err) - if err := tx.Commit(); err != pgx.ErrTxCommitRollback { - t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err) - } + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { + _, err := db.Exec(context.Background(), "insert into foo(id) values (1)") + require.NoError(t, err) - if status := tx.Status(); status != pgx.TxStatusCommitFailure { - t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusRollbackSuccess, status) - } + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { + _, err := db.Exec(context.Background(), "insert into foo(id) values (2)") + require.NoError(t, err) + return errors.New("do a rollback") + }) + require.EqualError(t, err, "do a rollback") - if err := tx.Err(); err != pgx.ErrTxCommitRollback { - t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err) - } + _, err = db.Exec(context.Background(), "insert into foo(id) values (3)") + require.NoError(t, err) + + return nil + }) + require.NoError(t, err) + + var n int64 + err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 2, n) +} + +func TestTxSendBatchClosed(t *testing.T) { + t.Parallel() + + db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, db) + + tx, err := db.Begin(context.Background()) + require.NoError(t, err) + defer tx.Rollback(context.Background()) + + err = tx.Commit(context.Background()) + require.NoError(t, err) + + batch := &pgx.Batch{} + batch.Queue("select 1") + batch.Queue("select 2") + batch.Queue("select 3") + + br := tx.SendBatch(context.Background(), batch) + defer br.Close() + + var n int + + _, err = br.Exec() + require.Error(t, err) + + err = br.QueryRow().Scan(&n) + require.Error(t, err) + + _, err = br.Query() + require.Error(t, err) } diff --git a/values.go b/values.go index 6a1c4f084..6e2ff3003 100644 --- a/values.go +++ b/values.go @@ -1,15 +1,10 @@ package pgx import ( - "database/sql/driver" - "fmt" - "math" - "reflect" - "time" + "errors" - "github.com/jackc/pgx/pgio" - "github.com/jackc/pgx/pgtype" - "github.com/pkg/errors" + "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgtype" ) // PostgreSQL format codes @@ -18,242 +13,51 @@ const ( BinaryFormatCode = 1 ) -// SerializationError occurs on failure to encode or decode a value -type SerializationError string - -func (e SerializationError) Error() string { - return string(e) -} - -func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, error) { - if arg == nil { - return nil, nil - } - - switch arg := arg.(type) { - case driver.Valuer: - return callValuerValue(arg) - case pgtype.TextEncoder: - buf, err := arg.EncodeText(ci, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil - case int64: - return arg, nil - case float64: - return arg, nil - case bool: - return arg, nil - case time.Time: - return arg, nil - case string: - return arg, nil - case []byte: - return arg, nil - case int8: - return int64(arg), nil - case int16: - return int64(arg), nil - case int32: - return int64(arg), nil - case int: - return int64(arg), nil - case uint8: - return int64(arg), nil - case uint16: - return int64(arg), nil - case uint32: - return int64(arg), nil - case uint64: - if arg > math.MaxInt64 { - return nil, errors.Errorf("arg too big for int64: %v", arg) - } - return int64(arg), nil - case uint: - if uint64(arg) > math.MaxInt64 { - return nil, errors.Errorf("arg too big for int64: %v", arg) - } - return int64(arg), nil - case float32: - return float64(arg), nil - } - - refVal := reflect.ValueOf(arg) - - if refVal.Kind() == reflect.Ptr { - if refVal.IsNil() { - return nil, nil - } - arg = refVal.Elem().Interface() - return convertSimpleArgument(ci, arg) +func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) { + buf, err := m.Encode(0, TextFormatCode, arg, []byte{}) + if err != nil { + return nil, err } - - if strippedArg, ok := stripNamedType(&refVal); ok { - return convertSimpleArgument(ci, strippedArg) + if buf == nil { + return nil, nil } - return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) + return string(buf), nil } -func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype.OID, arg interface{}) ([]byte, error) { - if arg == nil { - return pgio.AppendInt32(buf, -1), nil - } - - switch arg := arg.(type) { - case pgtype.BinaryEncoder: - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - argBuf, err := arg.EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if argBuf != nil { - buf = argBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - return buf, nil - case pgtype.TextEncoder: - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - argBuf, err := arg.EncodeText(ci, buf) - if err != nil { - return nil, err - } - if argBuf != nil { - buf = argBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - return buf, nil - case string: - buf = pgio.AppendInt32(buf, int32(len(arg))) - buf = append(buf, arg...) - return buf, nil - } - - refVal := reflect.ValueOf(arg) - - if refVal.Kind() == reflect.Ptr { - if refVal.IsNil() { - return pgio.AppendInt32(buf, -1), nil - } - arg = refVal.Elem().Interface() - return encodePreparedStatementArgument(ci, buf, oid, arg) - } - - if dt, ok := ci.DataTypeForOID(oid); ok { - value := dt.Value - err := value.Set(arg) - if err != nil { - { - if arg, ok := arg.(driver.Valuer); ok { - v, err := callValuerValue(arg) - if err != nil { - return nil, err - } - return encodePreparedStatementArgument(ci, buf, oid, v) - } - } - - return nil, err - } - - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if argBuf != nil { - buf = argBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - return buf, nil - } - - if arg, ok := arg.(driver.Valuer); ok { - v, err := callValuerValue(arg) - if err != nil { +func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf) + if err != nil { + if argBuf2, err2 := tryScanStringCopyValueThenEncode(m, buf, oid, arg); err2 == nil { + argBuf = argBuf2 + } else { return nil, err } - return encodePreparedStatementArgument(ci, buf, oid, v) } - if strippedArg, ok := stripNamedType(&refVal); ok { - return encodePreparedStatementArgument(ci, buf, oid, strippedArg) + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + return buf, nil } -// chooseParameterFormatCode determines the correct format code for an -// argument to a prepared statement. It defaults to TextFormatCode if no -// determination can be made. -func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid pgtype.OID, arg interface{}) int16 { - switch arg.(type) { - case pgtype.BinaryEncoder: - return BinaryFormatCode - case string, *string, pgtype.TextEncoder: - return TextFormatCode - } - - if dt, ok := ci.DataTypeForOID(oid); ok { - if _, ok := dt.Value.(pgtype.BinaryEncoder); ok { - if arg, ok := arg.(driver.Valuer); ok { - if err := dt.Value.Set(arg); err != nil { - if value, err := callValuerValue(arg); err == nil { - if _, ok := value.(string); ok { - return TextFormatCode - } - } - } - } - - return BinaryFormatCode +func tryScanStringCopyValueThenEncode(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { + s, ok := arg.(string) + if !ok { + textBuf, err := m.Encode(oid, TextFormatCode, arg, nil) + if err != nil { + return nil, errors.New("not a string and cannot be encoded as text") } + s = string(textBuf) } - return TextFormatCode -} - -func stripNamedType(val *reflect.Value) (interface{}, bool) { - switch val.Kind() { - case reflect.Int: - convVal := int(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Int8: - convVal := int8(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Int16: - convVal := int16(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Int32: - convVal := int32(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Int64: - convVal := int64(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint: - convVal := uint(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint8: - convVal := uint8(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint16: - convVal := uint16(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint32: - convVal := uint32(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint64: - convVal := uint64(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.String: - convVal := val.String() - return convVal, reflect.TypeOf(convVal) != val.Type() + var v any + err := m.Scan(oid, TextFormatCode, []byte(s), &v) + if err != nil { + return nil, err } - return nil, false + return m.Encode(oid, BinaryFormatCode, v, buf) } diff --git a/values_test.go b/values_test.go index ddaf5468b..116577d42 100644 --- a/values_test.go +++ b/values_test.go @@ -2,71 +2,82 @@ package pgx_test import ( "bytes" + "context" + "fmt" "net" + "os" "reflect" + "strings" "testing" "time" - "github.com/jackc/pgx" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDateTranscode(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - dates := []time.Time{ - time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(1600, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(1700, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), - time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(2001, 1, 2, 0, 0, 0, 0, time.UTC), - time.Date(2004, 2, 29, 0, 0, 0, 0, time.UTC), - time.Date(2013, 7, 4, 0, 0, 0, 0, time.UTC), - time.Date(2013, 12, 25, 0, 0, 0, 0, time.UTC), - time.Date(2029, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(2081, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(2096, 2, 29, 0, 0, 0, 0, time.UTC), - time.Date(2550, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(9999, 12, 31, 0, 0, 0, 0, time.UTC), - } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + dates := []time.Time{ + time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1600, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1700, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), + time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2001, 1, 2, 0, 0, 0, 0, time.UTC), + time.Date(2004, 2, 29, 0, 0, 0, 0, time.UTC), + time.Date(2013, 7, 4, 0, 0, 0, 0, time.UTC), + time.Date(2013, 12, 25, 0, 0, 0, 0, time.UTC), + time.Date(2029, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2081, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2096, 2, 29, 0, 0, 0, 0, time.UTC), + time.Date(2550, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(9999, 12, 31, 0, 0, 0, 0, time.UTC), + } - for _, actualDate := range dates { - var d time.Time + for _, actualDate := range dates { + var d time.Time - err := conn.QueryRow("select $1::date", actualDate).Scan(&d) - if err != nil { - t.Fatalf("Unexpected failure on QueryRow Scan: %v", err) - } - if !actualDate.Equal(d) { - t.Errorf("Did not transcode date successfully: %v is not %v", d, actualDate) + err := conn.QueryRow(context.Background(), "select $1::date", actualDate).Scan(&d) + if err != nil { + t.Fatalf("Unexpected failure on QueryRow Scan: %v", err) + } + if !actualDate.Equal(d) { + t.Errorf("Did not transcode date successfully: %v is not %v", d, actualDate) + } } - } + }) } func TestTimestampTzTranscode(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - inputTime := time.Date(2013, 1, 2, 3, 4, 5, 6000, time.Local) + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + inputTime := time.Date(2013, 1, 2, 3, 4, 5, 6000, time.Local) - var outputTime time.Time + var outputTime time.Time - err := conn.QueryRow("select $1::timestamptz", inputTime).Scan(&outputTime) - if err != nil { - t.Fatalf("QueryRow Scan failed: %v", err) - } - if !inputTime.Equal(outputTime) { - t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime) - } + err := conn.QueryRow(context.Background(), "select $1::timestamptz", inputTime).Scan(&outputTime) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if !inputTime.Equal(outputTime) { + t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime) + } + }) } // TODO - move these tests to pgtype @@ -74,16 +85,31 @@ func TestTimestampTzTranscode(t *testing.T) { func TestJSONAndJSONBTranscode(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + for _, typename := range []string{"json", "jsonb"} { + if _, ok := conn.TypeMap().TypeForName(typename); !ok { + continue // No JSON/JSONB type -- must be running against old PostgreSQL + } + + testJSONString(t, conn, typename) + testJSONStringPointer(t, conn, typename) + } + }) +} + +func TestJSONAndJSONBTranscodeExtendedOnly(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) for _, typename := range []string{"json", "jsonb"} { - if _, ok := conn.ConnInfo.DataTypeForName(typename); !ok { + if _, ok := conn.TypeMap().TypeForName(typename); !ok { continue // No JSON/JSONB type -- must be running against old PostgreSQL } - - testJSONString(t, conn, typename) - testJSONStringPointer(t, conn, typename) testJSONSingleLevelStringMap(t, conn, typename) testJSONNestedMap(t, conn, typename) testJSONStringArray(t, conn, typename) @@ -93,11 +119,11 @@ func TestJSONAndJSONBTranscode(t *testing.T) { } } -func testJSONString(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONString(t testing.TB, conn *pgx.Conn, typename string) { input := `{"key": "value"}` expectedOutput := map[string]string{"key": "value"} var output map[string]string - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) if err != nil { t.Errorf("%s: QueryRow Scan failed: %v", typename, err) return @@ -109,11 +135,11 @@ func testJSONString(t *testing.T, conn *pgx.Conn, typename string) { } } -func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONStringPointer(t testing.TB, conn *pgx.Conn, typename string) { input := `{"key": "value"}` expectedOutput := map[string]string{"key": "value"} var output map[string]string - err := conn.QueryRow("select $1::"+typename, &input).Scan(&output) + err := conn.QueryRow(context.Background(), "select $1::"+typename, &input).Scan(&output) if err != nil { t.Errorf("%s: QueryRow Scan failed: %v", typename, err) return @@ -128,7 +154,7 @@ func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string) { func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) { input := map[string]string{"key": "value"} var output map[string]string - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) if err != nil { t.Errorf("%s: QueryRow Scan failed: %v", typename, err) return @@ -141,20 +167,20 @@ func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) } func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) { - input := map[string]interface{}{ + input := map[string]any{ "name": "Uncanny", - "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, - "inventory": []interface{}{"phone", "key"}, + "stats": map[string]any{"hp": float64(107), "maxhp": float64(150)}, + "inventory": []any{"phone", "key"}, } - var output map[string]interface{} - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + var output map[string]any + err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) if err != nil { t.Errorf("%s: QueryRow Scan failed: %v", typename, err) return } if !reflect.DeepEqual(input, output) { - t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, input, output) + t.Errorf("%s: Did not transcode map[string]any successfully: %v is not %v", typename, input, output) return } } @@ -162,7 +188,7 @@ func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) { func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string) { input := []string{"foo", "bar", "baz"} var output []string - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) if err != nil { t.Errorf("%s: QueryRow Scan failed: %v", typename, err) } @@ -175,7 +201,7 @@ func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string) { func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string) { input := []int64{1, 2, 234432} var output []int64 - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) if err != nil { t.Errorf("%s: QueryRow Scan failed: %v", typename, err) } @@ -188,9 +214,14 @@ func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string) { func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string) { input := []int{1, 2, 234432} var output []int16 - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" { - t.Errorf("%s: Expected *json.UnmarkalTypeError, but got %v", typename, err) + err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) + fieldName := typename + if conn.PgConn().ParameterStatus("crdb_version") != "" && typename == "json" { + fieldName = "jsonb" // Seems like CockroachDB treats json as jsonb. + } + expectedMessage := fmt.Sprintf("can't scan into dest[0] (col: %s): json: cannot unmarshal number 234432 into Go value of type int16", fieldName) + if err == nil || err.Error() != expectedMessage { + t.Errorf("%s: Expected *json.UnmarshalTypeError, but got %v", typename, err) } } @@ -207,7 +238,7 @@ func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) { var output person - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) if err != nil { t.Errorf("%s: QueryRow Scan failed: %v", typename, err) } @@ -217,7 +248,7 @@ func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) { } } -func mustParseCIDR(t *testing.T, s string) *net.IPNet { +func mustParseCIDR(t testing.TB, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { t.Fatal(err) @@ -226,725 +257,860 @@ func mustParseCIDR(t *testing.T, s string) *net.IPNet { return ipnet } -func TestStringToNotTextTypeTranscode(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - input := "01086ee0-4963-4e35-9116-30c173a8d0bd" - - var output string - err := conn.QueryRow("select $1::uuid", input).Scan(&output) - if err != nil { - t.Fatal(err) - } - if input != output { - t.Errorf("uuid: Did not transcode string successfully: %s is not %s", input, output) - } - - err = conn.QueryRow("select $1::uuid", &input).Scan(&output) - if err != nil { - t.Fatal(err) - } - if input != output { - t.Errorf("uuid: Did not transcode pointer to string successfully: %s is not %s", input, output) - } -} - func TestInetCIDRTranscodeIPNet(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + tests := []struct { + sql string + value *net.IPNet + }{ + {"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")}, + {"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")}, + {"select $1::inet", mustParseCIDR(t, "12.34.56.0/32")}, + {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, + {"select $1::inet", mustParseCIDR(t, "255.0.0.0/8")}, + {"select $1::inet", mustParseCIDR(t, "255.255.255.255/32")}, + {"select $1::inet", mustParseCIDR(t, "::/128")}, + {"select $1::inet", mustParseCIDR(t, "::/0")}, + {"select $1::inet", mustParseCIDR(t, "::1/128")}, + {"select $1::inet", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, + {"select $1::cidr", mustParseCIDR(t, "0.0.0.0/32")}, + {"select $1::cidr", mustParseCIDR(t, "127.0.0.1/32")}, + {"select $1::cidr", mustParseCIDR(t, "12.34.56.0/32")}, + {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, + {"select $1::cidr", mustParseCIDR(t, "255.0.0.0/8")}, + {"select $1::cidr", mustParseCIDR(t, "255.255.255.255/32")}, + {"select $1::cidr", mustParseCIDR(t, "::/128")}, + {"select $1::cidr", mustParseCIDR(t, "::/0")}, + {"select $1::cidr", mustParseCIDR(t, "::1/128")}, + {"select $1::cidr", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, + } - tests := []struct { - sql string - value *net.IPNet - }{ - {"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")}, - {"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")}, - {"select $1::inet", mustParseCIDR(t, "12.34.56.0/32")}, - {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, - {"select $1::inet", mustParseCIDR(t, "255.0.0.0/8")}, - {"select $1::inet", mustParseCIDR(t, "255.255.255.255/32")}, - {"select $1::inet", mustParseCIDR(t, "::/128")}, - {"select $1::inet", mustParseCIDR(t, "::/0")}, - {"select $1::inet", mustParseCIDR(t, "::1/128")}, - {"select $1::inet", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, - {"select $1::cidr", mustParseCIDR(t, "0.0.0.0/32")}, - {"select $1::cidr", mustParseCIDR(t, "127.0.0.1/32")}, - {"select $1::cidr", mustParseCIDR(t, "12.34.56.0/32")}, - {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, - {"select $1::cidr", mustParseCIDR(t, "255.0.0.0/8")}, - {"select $1::cidr", mustParseCIDR(t, "255.255.255.255/32")}, - {"select $1::cidr", mustParseCIDR(t, "::/128")}, - {"select $1::cidr", mustParseCIDR(t, "::/0")}, - {"select $1::cidr", mustParseCIDR(t, "::1/128")}, - {"select $1::cidr", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, - } + for i, tt := range tests { + if conn.PgConn().ParameterStatus("crdb_version") != "" && strings.Contains(tt.sql, "cidr") { + t.Log("Server does not support cidr type (https://github.com/cockroachdb/cockroach/issues/18846)") + continue + } - for i, tt := range tests { - var actual net.IPNet + var actual net.IPNet - err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) - continue - } + err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } - if actual.String() != tt.value.String() { - t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + if actual.String() != tt.value.String() { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } } - - ensureConnValid(t, conn) - } + }) } func TestInetCIDRTranscodeIP(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + tests := []struct { + sql string + value net.IP + }{ + {"select $1::inet", net.ParseIP("0.0.0.0")}, + {"select $1::inet", net.ParseIP("127.0.0.1")}, + {"select $1::inet", net.ParseIP("12.34.56.0")}, + {"select $1::inet", net.ParseIP("255.255.255.255")}, + {"select $1::inet", net.ParseIP("::1")}, + {"select $1::inet", net.ParseIP("2607:f8b0:4009:80b::200e")}, + {"select $1::cidr", net.ParseIP("0.0.0.0")}, + {"select $1::cidr", net.ParseIP("127.0.0.1")}, + {"select $1::cidr", net.ParseIP("12.34.56.0")}, + {"select $1::cidr", net.ParseIP("255.255.255.255")}, + {"select $1::cidr", net.ParseIP("::1")}, + {"select $1::cidr", net.ParseIP("2607:f8b0:4009:80b::200e")}, + } - tests := []struct { - sql string - value net.IP - }{ - {"select $1::inet", net.ParseIP("0.0.0.0")}, - {"select $1::inet", net.ParseIP("127.0.0.1")}, - {"select $1::inet", net.ParseIP("12.34.56.0")}, - {"select $1::inet", net.ParseIP("255.255.255.255")}, - {"select $1::inet", net.ParseIP("::1")}, - {"select $1::inet", net.ParseIP("2607:f8b0:4009:80b::200e")}, - {"select $1::cidr", net.ParseIP("0.0.0.0")}, - {"select $1::cidr", net.ParseIP("127.0.0.1")}, - {"select $1::cidr", net.ParseIP("12.34.56.0")}, - {"select $1::cidr", net.ParseIP("255.255.255.255")}, - {"select $1::cidr", net.ParseIP("::1")}, - {"select $1::cidr", net.ParseIP("2607:f8b0:4009:80b::200e")}, - } + for i, tt := range tests { + if conn.PgConn().ParameterStatus("crdb_version") != "" && strings.Contains(tt.sql, "cidr") { + t.Log("Server does not support cidr type (https://github.com/cockroachdb/cockroach/issues/18846)") + continue + } - for i, tt := range tests { - var actual net.IP + var actual net.IP - err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) - continue + err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + if !actual.Equal(tt.value) { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } + + ensureConnValid(t, conn) } - if !actual.Equal(tt.value) { - t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + failTests := []struct { + sql string + value *net.IPNet + }{ + {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, + {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, } + for i, tt := range failTests { + var actual net.IP - ensureConnValid(t, conn) - } + err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) + if err == nil { + t.Errorf("%d. Expected failure but got none", i) + continue + } - failTests := []struct { - sql string - value *net.IPNet - }{ - {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, - {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, - } - for i, tt := range failTests { - var actual net.IP - - err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) - if err == nil { - t.Errorf("%d. Expected failure but got none", i) - continue + ensureConnValid(t, conn) } - - ensureConnValid(t, conn) - } + }) } func TestInetCIDRArrayTranscodeIPNet(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - tests := []struct { - sql string - value []*net.IPNet - }{ - { - "select $1::inet[]", - []*net.IPNet{ - mustParseCIDR(t, "0.0.0.0/32"), - mustParseCIDR(t, "127.0.0.1/32"), - mustParseCIDR(t, "12.34.56.0/32"), - mustParseCIDR(t, "192.168.1.0/24"), - mustParseCIDR(t, "255.0.0.0/8"), - mustParseCIDR(t, "255.255.255.255/32"), - mustParseCIDR(t, "::/128"), - mustParseCIDR(t, "::/0"), - mustParseCIDR(t, "::1/128"), - mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + tests := []struct { + sql string + value []*net.IPNet + }{ + { + "select $1::inet[]", + []*net.IPNet{ + mustParseCIDR(t, "0.0.0.0/32"), + mustParseCIDR(t, "127.0.0.1/32"), + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + mustParseCIDR(t, "255.0.0.0/8"), + mustParseCIDR(t, "255.255.255.255/32"), + mustParseCIDR(t, "::/128"), + mustParseCIDR(t, "::/0"), + mustParseCIDR(t, "::1/128"), + mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), + }, }, - }, - { - "select $1::cidr[]", - []*net.IPNet{ - mustParseCIDR(t, "0.0.0.0/32"), - mustParseCIDR(t, "127.0.0.1/32"), - mustParseCIDR(t, "12.34.56.0/32"), - mustParseCIDR(t, "192.168.1.0/24"), - mustParseCIDR(t, "255.0.0.0/8"), - mustParseCIDR(t, "255.255.255.255/32"), - mustParseCIDR(t, "::/128"), - mustParseCIDR(t, "::/0"), - mustParseCIDR(t, "::1/128"), - mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), + { + "select $1::cidr[]", + []*net.IPNet{ + mustParseCIDR(t, "0.0.0.0/32"), + mustParseCIDR(t, "127.0.0.1/32"), + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + mustParseCIDR(t, "255.0.0.0/8"), + mustParseCIDR(t, "255.255.255.255/32"), + mustParseCIDR(t, "::/128"), + mustParseCIDR(t, "::/0"), + mustParseCIDR(t, "::1/128"), + mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), + }, }, - }, - } + } - for i, tt := range tests { - var actual []*net.IPNet + for i, tt := range tests { + if conn.PgConn().ParameterStatus("crdb_version") != "" && strings.Contains(tt.sql, "cidr") { + t.Log("Server does not support cidr type (https://github.com/cockroachdb/cockroach/issues/18846)") + continue + } - err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) - continue - } + var actual []*net.IPNet - if !reflect.DeepEqual(actual, tt.value) { - t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) - } + err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } - ensureConnValid(t, conn) - } + if !reflect.DeepEqual(actual, tt.value) { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } + + ensureConnValid(t, conn) + } + }) } func TestInetCIDRArrayTranscodeIP(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - tests := []struct { - sql string - value []net.IP - }{ - { - "select $1::inet[]", - []net.IP{ - net.ParseIP("0.0.0.0"), - net.ParseIP("127.0.0.1"), - net.ParseIP("12.34.56.0"), - net.ParseIP("255.255.255.255"), - net.ParseIP("2607:f8b0:4009:80b::200e"), + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + tests := []struct { + sql string + value []net.IP + }{ + { + "select $1::inet[]", + []net.IP{ + net.ParseIP("0.0.0.0"), + net.ParseIP("127.0.0.1"), + net.ParseIP("12.34.56.0"), + net.ParseIP("255.255.255.255"), + net.ParseIP("2607:f8b0:4009:80b::200e"), + }, }, - }, - { - "select $1::cidr[]", - []net.IP{ - net.ParseIP("0.0.0.0"), - net.ParseIP("127.0.0.1"), - net.ParseIP("12.34.56.0"), - net.ParseIP("255.255.255.255"), - net.ParseIP("2607:f8b0:4009:80b::200e"), + { + "select $1::cidr[]", + []net.IP{ + net.ParseIP("0.0.0.0"), + net.ParseIP("127.0.0.1"), + net.ParseIP("12.34.56.0"), + net.ParseIP("255.255.255.255"), + net.ParseIP("2607:f8b0:4009:80b::200e"), + }, }, - }, - } + } - for i, tt := range tests { - var actual []net.IP + for i, tt := range tests { + if conn.PgConn().ParameterStatus("crdb_version") != "" && strings.Contains(tt.sql, "cidr") { + t.Log("Server does not support cidr type (https://github.com/cockroachdb/cockroach/issues/18846)") + continue + } - err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) - continue - } + var actual []net.IP - if !reflect.DeepEqual(actual, tt.value) { - t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) - } + err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } - ensureConnValid(t, conn) - } + assert.Equal(t, len(tt.value), len(actual), "%d", i) + for j := range actual { + assert.True(t, actual[j].Equal(tt.value[j]), "%d", i) + } - failTests := []struct { - sql string - value []*net.IPNet - }{ - { - "select $1::inet[]", - []*net.IPNet{ - mustParseCIDR(t, "12.34.56.0/32"), - mustParseCIDR(t, "192.168.1.0/24"), + ensureConnValid(t, conn) + } + + failTests := []struct { + sql string + value []*net.IPNet + }{ + { + "select $1::inet[]", + []*net.IPNet{ + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + }, }, - }, - { - "select $1::cidr[]", - []*net.IPNet{ - mustParseCIDR(t, "12.34.56.0/32"), - mustParseCIDR(t, "192.168.1.0/24"), + { + "select $1::cidr[]", + []*net.IPNet{ + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + }, }, - }, - } + } - for i, tt := range failTests { - var actual []net.IP + for i, tt := range failTests { + var actual []net.IP - err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) - if err == nil { - t.Errorf("%d. Expected failure but got none", i) - continue - } + err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) + if err == nil { + t.Errorf("%d. Expected failure but got none", i) + continue + } - ensureConnValid(t, conn) - } + ensureConnValid(t, conn) + } + }) } func TestInetCIDRTranscodeWithJustIP(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + tests := []struct { + sql string + value string + }{ + {"select $1::inet", "0.0.0.0/32"}, + {"select $1::inet", "127.0.0.1/32"}, + {"select $1::inet", "12.34.56.0/32"}, + {"select $1::inet", "255.255.255.255/32"}, + {"select $1::inet", "::/128"}, + {"select $1::inet", "2607:f8b0:4009:80b::200e/128"}, + {"select $1::cidr", "0.0.0.0/32"}, + {"select $1::cidr", "127.0.0.1/32"}, + {"select $1::cidr", "12.34.56.0/32"}, + {"select $1::cidr", "255.255.255.255/32"}, + {"select $1::cidr", "::/128"}, + {"select $1::cidr", "2607:f8b0:4009:80b::200e/128"}, + } - tests := []struct { - sql string - value string - }{ - {"select $1::inet", "0.0.0.0/32"}, - {"select $1::inet", "127.0.0.1/32"}, - {"select $1::inet", "12.34.56.0/32"}, - {"select $1::inet", "255.255.255.255/32"}, - {"select $1::inet", "::/128"}, - {"select $1::inet", "2607:f8b0:4009:80b::200e/128"}, - {"select $1::cidr", "0.0.0.0/32"}, - {"select $1::cidr", "127.0.0.1/32"}, - {"select $1::cidr", "12.34.56.0/32"}, - {"select $1::cidr", "255.255.255.255/32"}, - {"select $1::cidr", "::/128"}, - {"select $1::cidr", "2607:f8b0:4009:80b::200e/128"}, - } + for i, tt := range tests { + if conn.PgConn().ParameterStatus("crdb_version") != "" && strings.Contains(tt.sql, "cidr") { + t.Log("Server does not support cidr type (https://github.com/cockroachdb/cockroach/issues/18846)") + continue + } - for i, tt := range tests { - expected := mustParseCIDR(t, tt.value) - var actual net.IPNet + expected := mustParseCIDR(t, tt.value) + var actual net.IPNet - err := conn.QueryRow(tt.sql, expected.IP).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) - continue - } + err := conn.QueryRow(context.Background(), tt.sql, expected.IP).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } - if actual.String() != expected.String() { - t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) - } + if actual.String() != expected.String() { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } - ensureConnValid(t, conn) - } + ensureConnValid(t, conn) + } + }) } func TestArrayDecoding(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - tests := []struct { - sql string - query interface{} - scan interface{} - assert func(*testing.T, interface{}, interface{}) - }{ - { - "select $1::bool[]", []bool{true, false, true}, &[]bool{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]bool))) { - t.Errorf("failed to encode bool[]") - } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + tests := []struct { + sql string + query any + scan any + assert func(testing.TB, any, any) + }{ + { + "select $1::bool[]", + []bool{true, false, true}, + &[]bool{}, + func(t testing.TB, query, scan any) { + if !reflect.DeepEqual(query, *(scan.(*[]bool))) { + t.Errorf("failed to encode bool[]") + } + }, }, - }, - { - "select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]int16))) { - t.Errorf("failed to encode smallint[]") - } + { + "select $1::smallint[]", + []int16{2, 4, 484, 32767}, + &[]int16{}, + func(t testing.TB, query, scan any) { + if !reflect.DeepEqual(query, *(scan.(*[]int16))) { + t.Errorf("failed to encode smallint[]") + } + }, }, - }, - { - "select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]uint16))) { - t.Errorf("failed to encode smallint[]") - } + { + "select $1::smallint[]", + []uint16{2, 4, 484, 32767}, + &[]uint16{}, + func(t testing.TB, query, scan any) { + if !reflect.DeepEqual(query, *(scan.(*[]uint16))) { + t.Errorf("failed to encode smallint[]") + } + }, }, - }, - { - "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]int32))) { - t.Errorf("failed to encode int[]") - } + { + "select $1::int[]", + []int32{2, 4, 484}, + &[]int32{}, + func(t testing.TB, query, scan any) { + if !reflect.DeepEqual(query, *(scan.(*[]int32))) { + t.Errorf("failed to encode int[]") + } + }, }, - }, - { - "select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]uint32))) { - t.Errorf("failed to encode int[]") - } + { + "select $1::int[]", + []uint32{2, 4, 484, 2147483647}, + &[]uint32{}, + func(t testing.TB, query, scan any) { + if !reflect.DeepEqual(query, *(scan.(*[]uint32))) { + t.Errorf("failed to encode int[]") + } + }, }, - }, - { - "select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]int64))) { - t.Errorf("failed to encode bigint[]") - } + { + "select $1::bigint[]", + []int64{2, 4, 484, 9223372036854775807}, + &[]int64{}, + func(t testing.TB, query, scan any) { + if !reflect.DeepEqual(query, *(scan.(*[]int64))) { + t.Errorf("failed to encode bigint[]") + } + }, }, - }, - { - "select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]uint64))) { - t.Errorf("failed to encode bigint[]") - } + { + "select $1::bigint[]", + []uint64{2, 4, 484, 9223372036854775807}, + &[]uint64{}, + func(t testing.TB, query, scan any) { + if !reflect.DeepEqual(query, *(scan.(*[]uint64))) { + t.Errorf("failed to encode bigint[]") + } + }, }, - }, - { - "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]string))) { - t.Errorf("failed to encode text[]") - } + { + "select $1::text[]", + []string{"it's", "over", "9000!"}, + &[]string{}, + func(t testing.TB, query, scan any) { + if !reflect.DeepEqual(query, *(scan.(*[]string))) { + t.Errorf("failed to encode text[]") + } + }, }, - }, - { - "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) { - t.Errorf("failed to encode time.Time[] to timestamptz[]") - } + { + "select $1::timestamptz[]", + []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 0o0)}, + &[]time.Time{}, + func(t testing.TB, query, scan any) { + queryTimeSlice := query.([]time.Time) + scanTimeSlice := *(scan.(*[]time.Time)) + require.Equal(t, len(queryTimeSlice), len(scanTimeSlice)) + for i := range queryTimeSlice { + assert.Truef(t, queryTimeSlice[i].Equal(scanTimeSlice[i]), "%d", i) + } + }, }, - }, - { - "select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{}, - func(t *testing.T, query, scan interface{}) { - queryBytesSliceSlice := query.([][]byte) - scanBytesSliceSlice := *(scan.(*[][]byte)) - if len(queryBytesSliceSlice) != len(scanBytesSliceSlice) { - t.Errorf("failed to encode byte[][] to bytea[]: expected %d to equal %d", len(queryBytesSliceSlice), len(scanBytesSliceSlice)) - } - for i := range queryBytesSliceSlice { - qb := queryBytesSliceSlice[i] - sb := scanBytesSliceSlice[i] - if !bytes.Equal(qb, sb) { - t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb) + { + "select $1::bytea[]", + [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, + &[][]byte{}, + func(t testing.TB, query, scan any) { + queryBytesSliceSlice := query.([][]byte) + scanBytesSliceSlice := *(scan.(*[][]byte)) + if len(queryBytesSliceSlice) != len(scanBytesSliceSlice) { + t.Errorf("failed to encode byte[][] to bytea[]: expected %d to equal %d", len(queryBytesSliceSlice), len(scanBytesSliceSlice)) } - } + for i := range queryBytesSliceSlice { + qb := queryBytesSliceSlice[i] + sb := scanBytesSliceSlice[i] + if !bytes.Equal(qb, sb) { + t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb) + } + } + }, }, - }, - } + } - for i, tt := range tests { - err := conn.QueryRow(tt.sql, tt.query).Scan(tt.scan) - if err != nil { - t.Errorf(`%d. error reading array: %v`, i, err) - continue + for i, tt := range tests { + err := conn.QueryRow(context.Background(), tt.sql, tt.query).Scan(tt.scan) + if err != nil { + t.Errorf(`%d. error reading array: %v`, i, err) + continue + } + tt.assert(t, tt.query, tt.scan) + ensureConnValid(t, conn) } - tt.assert(t, tt.query, tt.scan) - ensureConnValid(t, conn) - } + }) } func TestEmptyArrayDecoding(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var val []string + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - err := conn.QueryRow("select array[]::text[]").Scan(&val) - if err != nil { - t.Errorf(`error reading array: %v`, err) - } - if len(val) != 0 { - t.Errorf("Expected 0 values, got %d", len(val)) - } - - var n, m int32 + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var val []string - err = conn.QueryRow("select 1::integer, array[]::text[], 42::integer").Scan(&n, &val, &m) - if err != nil { - t.Errorf(`error reading array: %v`, err) - } - if len(val) != 0 { - t.Errorf("Expected 0 values, got %d", len(val)) - } - if n != 1 { - t.Errorf("Expected n to be 1, but it was %d", n) - } - if m != 42 { - t.Errorf("Expected n to be 42, but it was %d", n) - } + err := conn.QueryRow(context.Background(), "select array[]::text[]").Scan(&val) + if err != nil { + t.Errorf(`error reading array: %v`, err) + } + if len(val) != 0 { + t.Errorf("Expected 0 values, got %d", len(val)) + } - rows, err := conn.Query("select 1::integer, array['test']::text[] union select 2::integer, array[]::text[] union select 3::integer, array['test']::text[]") - if err != nil { - t.Errorf(`error retrieving rows with array: %v`, err) - } - defer rows.Close() + var n, m int32 - for rows.Next() { - err = rows.Scan(&n, &val) + err = conn.QueryRow(context.Background(), "select 1::integer, array[]::text[], 42::integer").Scan(&n, &val, &m) if err != nil { t.Errorf(`error reading array: %v`, err) } - } + if len(val) != 0 { + t.Errorf("Expected 0 values, got %d", len(val)) + } + if n != 1 { + t.Errorf("Expected n to be 1, but it was %d", n) + } + if m != 42 { + t.Errorf("Expected n to be 42, but it was %d", n) + } - ensureConnValid(t, conn) + rows, err := conn.Query(context.Background(), "select 1::integer, array['test']::text[] union select 2::integer, array[]::text[] union select 3::integer, array['test']::text[]") + if err != nil { + t.Errorf(`error retrieving rows with array: %v`, err) + } + defer rows.Close() + + for rows.Next() { + err = rows.Scan(&n, &val) + if err != nil { + t.Errorf(`error reading array: %v`, err) + } + } + }) } func TestPointerPointer(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type allTypes struct { - s *string - i16 *int16 - i32 *int32 - i64 *int64 - f32 *float32 - f64 *float64 - b *bool - t *time.Time - } - - var actual, zero, expected allTypes - - { - s := "foo" - expected.s = &s - i16 := int16(1) - expected.i16 = &i16 - i32 := int32(1) - expected.i32 = &i32 - i64 := int64(1) - expected.i64 = &i64 - f32 := float32(1.23) - expected.f32 = &f32 - f64 := float64(1.23) - expected.f64 = &f64 - b := true - expected.b = &b - t := time.Unix(123, 5000) - expected.t = &t - } - - tests := []struct { - sql string - queryArgs []interface{} - scanArgs []interface{} - expected allTypes - }{ - {"select $1::text", []interface{}{expected.s}, []interface{}{&actual.s}, allTypes{s: expected.s}}, - {"select $1::text", []interface{}{zero.s}, []interface{}{&actual.s}, allTypes{}}, - {"select $1::int2", []interface{}{expected.i16}, []interface{}{&actual.i16}, allTypes{i16: expected.i16}}, - {"select $1::int2", []interface{}{zero.i16}, []interface{}{&actual.i16}, allTypes{}}, - {"select $1::int4", []interface{}{expected.i32}, []interface{}{&actual.i32}, allTypes{i32: expected.i32}}, - {"select $1::int4", []interface{}{zero.i32}, []interface{}{&actual.i32}, allTypes{}}, - {"select $1::int8", []interface{}{expected.i64}, []interface{}{&actual.i64}, allTypes{i64: expected.i64}}, - {"select $1::int8", []interface{}{zero.i64}, []interface{}{&actual.i64}, allTypes{}}, - {"select $1::float4", []interface{}{expected.f32}, []interface{}{&actual.f32}, allTypes{f32: expected.f32}}, - {"select $1::float4", []interface{}{zero.f32}, []interface{}{&actual.f32}, allTypes{}}, - {"select $1::float8", []interface{}{expected.f64}, []interface{}{&actual.f64}, allTypes{f64: expected.f64}}, - {"select $1::float8", []interface{}{zero.f64}, []interface{}{&actual.f64}, allTypes{}}, - {"select $1::bool", []interface{}{expected.b}, []interface{}{&actual.b}, allTypes{b: expected.b}}, - {"select $1::bool", []interface{}{zero.b}, []interface{}{&actual.b}, allTypes{}}, - {"select $1::timestamptz", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}}, - {"select $1::timestamptz", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}}, - } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server auto converts ints to bigint and test relies on exact types") + + type allTypes struct { + s *string + i16 *int16 + i32 *int32 + i64 *int64 + f32 *float32 + f64 *float64 + b *bool + t *time.Time + } - for i, tt := range tests { - actual = zero + var actual, zero, expected allTypes - err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs) + { + s := "foo" + expected.s = &s + i16 := int16(1) + expected.i16 = &i16 + i32 := int32(1) + expected.i32 = &i32 + i64 := int64(1) + expected.i64 = &i64 + f32 := float32(1.23) + expected.f32 = &f32 + f64 := float64(1.23) + expected.f64 = &f64 + b := true + expected.b = &b + t := time.Unix(123, 5000) + expected.t = &t } - if !reflect.DeepEqual(actual, tt.expected) { - t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs) + tests := []struct { + sql string + queryArgs []any + scanArgs []any + expected allTypes + }{ + {"select $1::text", []any{expected.s}, []any{&actual.s}, allTypes{s: expected.s}}, + {"select $1::text", []any{zero.s}, []any{&actual.s}, allTypes{}}, + {"select $1::int2", []any{expected.i16}, []any{&actual.i16}, allTypes{i16: expected.i16}}, + {"select $1::int2", []any{zero.i16}, []any{&actual.i16}, allTypes{}}, + {"select $1::int4", []any{expected.i32}, []any{&actual.i32}, allTypes{i32: expected.i32}}, + {"select $1::int4", []any{zero.i32}, []any{&actual.i32}, allTypes{}}, + {"select $1::int8", []any{expected.i64}, []any{&actual.i64}, allTypes{i64: expected.i64}}, + {"select $1::int8", []any{zero.i64}, []any{&actual.i64}, allTypes{}}, + {"select $1::float4", []any{expected.f32}, []any{&actual.f32}, allTypes{f32: expected.f32}}, + {"select $1::float4", []any{zero.f32}, []any{&actual.f32}, allTypes{}}, + {"select $1::float8", []any{expected.f64}, []any{&actual.f64}, allTypes{f64: expected.f64}}, + {"select $1::float8", []any{zero.f64}, []any{&actual.f64}, allTypes{}}, + {"select $1::bool", []any{expected.b}, []any{&actual.b}, allTypes{b: expected.b}}, + {"select $1::bool", []any{zero.b}, []any{&actual.b}, allTypes{}}, + {"select $1::timestamptz", []any{expected.t}, []any{&actual.t}, allTypes{t: expected.t}}, + {"select $1::timestamptz", []any{zero.t}, []any{&actual.t}, allTypes{}}, } - ensureConnValid(t, conn) - } + for i, tt := range tests { + actual = zero + + err := conn.QueryRow(context.Background(), tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs) + } + + assert.Equal(t, tt.expected.s, actual.s) + assert.Equal(t, tt.expected.i16, actual.i16) + assert.Equal(t, tt.expected.i32, actual.i32) + assert.Equal(t, tt.expected.i64, actual.i64) + assert.Equal(t, tt.expected.f32, actual.f32) + assert.Equal(t, tt.expected.f64, actual.f64) + assert.Equal(t, tt.expected.b, actual.b) + if tt.expected.t != nil || actual.t != nil { + assert.True(t, tt.expected.t.Equal(*actual.t)) + } + + ensureConnValid(t, conn) + } + }) } func TestPointerPointerNonZero(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - f := "foo" - dest := &f + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + f := "foo" + dest := &f - err := conn.QueryRow("select $1::text", nil).Scan(&dest) - if err != nil { - t.Errorf("Unexpected failure scanning: %v", err) - } - if dest != nil { - t.Errorf("Expected dest to be nil, got %#v", dest) - } + err := conn.QueryRow(context.Background(), "select $1::text", nil).Scan(&dest) + if err != nil { + t.Errorf("Unexpected failure scanning: %v", err) + } + if dest != nil { + t.Errorf("Expected dest to be nil, got %#v", dest) + } + }) } func TestEncodeTypeRename(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type _int int - inInt := _int(1) - var outInt _int - - type _int8 int8 - inInt8 := _int8(2) - var outInt8 _int8 - - type _int16 int16 - inInt16 := _int16(3) - var outInt16 _int16 + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + type _int int + inInt := _int(1) + var outInt _int + + type _int8 int8 + inInt8 := _int8(2) + var outInt8 _int8 + + type _int16 int16 + inInt16 := _int16(3) + var outInt16 _int16 + + type _int32 int32 + inInt32 := _int32(4) + var outInt32 _int32 + + type _int64 int64 + inInt64 := _int64(5) + var outInt64 _int64 + + type _uint uint + inUint := _uint(6) + var outUint _uint + + type _uint8 uint8 + inUint8 := _uint8(7) + var outUint8 _uint8 + + type _uint16 uint16 + inUint16 := _uint16(8) + var outUint16 _uint16 + + type _uint32 uint32 + inUint32 := _uint32(9) + var outUint32 _uint32 + + type _uint64 uint64 + inUint64 := _uint64(10) + var outUint64 _uint64 + + type _string string + inString := _string("foo") + var outString _string + + type _bool bool + inBool := _bool(true) + var outBool _bool + + // pgx.QueryExecModeExec requires all types to be registered. + conn.TypeMap().RegisterDefaultPgType(inInt, "int8") + conn.TypeMap().RegisterDefaultPgType(inInt8, "int8") + conn.TypeMap().RegisterDefaultPgType(inInt16, "int8") + conn.TypeMap().RegisterDefaultPgType(inInt32, "int8") + conn.TypeMap().RegisterDefaultPgType(inInt64, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint8, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint16, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint32, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint64, "int8") + conn.TypeMap().RegisterDefaultPgType(inString, "text") + conn.TypeMap().RegisterDefaultPgType(inBool, "bool") + + err := conn.QueryRow(context.Background(), "select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text, $12::bool", + inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString, inBool, + ).Scan(&outInt, &outInt8, &outInt16, &outInt32, &outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString, &outBool) + if err != nil { + t.Fatalf("Failed with type rename: %v", err) + } - type _int32 int32 - inInt32 := _int32(4) - var outInt32 _int32 + if inInt != outInt { + t.Errorf("int rename: expected %v, got %v", inInt, outInt) + } - type _int64 int64 - inInt64 := _int64(5) - var outInt64 _int64 + if inInt8 != outInt8 { + t.Errorf("int8 rename: expected %v, got %v", inInt8, outInt8) + } - type _uint uint - inUint := _uint(6) - var outUint _uint + if inInt16 != outInt16 { + t.Errorf("int16 rename: expected %v, got %v", inInt16, outInt16) + } - type _uint8 uint8 - inUint8 := _uint8(7) - var outUint8 _uint8 + if inInt32 != outInt32 { + t.Errorf("int32 rename: expected %v, got %v", inInt32, outInt32) + } - type _uint16 uint16 - inUint16 := _uint16(8) - var outUint16 _uint16 + if inInt64 != outInt64 { + t.Errorf("int64 rename: expected %v, got %v", inInt64, outInt64) + } - type _uint32 uint32 - inUint32 := _uint32(9) - var outUint32 _uint32 + if inUint != outUint { + t.Errorf("uint rename: expected %v, got %v", inUint, outUint) + } - type _uint64 uint64 - inUint64 := _uint64(10) - var outUint64 _uint64 + if inUint8 != outUint8 { + t.Errorf("uint8 rename: expected %v, got %v", inUint8, outUint8) + } - type _string string - inString := _string("foo") - var outString _string + if inUint16 != outUint16 { + t.Errorf("uint16 rename: expected %v, got %v", inUint16, outUint16) + } - err := conn.QueryRow("select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text", - inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString, - ).Scan(&outInt, &outInt8, &outInt16, &outInt32, &outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString) - if err != nil { - t.Fatalf("Failed with type rename: %v", err) - } + if inUint32 != outUint32 { + t.Errorf("uint32 rename: expected %v, got %v", inUint32, outUint32) + } - if inInt != outInt { - t.Errorf("int rename: expected %v, got %v", inInt, outInt) - } + if inUint64 != outUint64 { + t.Errorf("uint64 rename: expected %v, got %v", inUint64, outUint64) + } - if inInt8 != outInt8 { - t.Errorf("int8 rename: expected %v, got %v", inInt8, outInt8) - } + if inString != outString { + t.Errorf("string rename: expected %v, got %v", inString, outString) + } - if inInt16 != outInt16 { - t.Errorf("int16 rename: expected %v, got %v", inInt16, outInt16) - } + if inBool != outBool { + t.Errorf("bool rename: expected %v, got %v", inBool, outBool) + } + }) +} - if inInt32 != outInt32 { - t.Errorf("int32 rename: expected %v, got %v", inInt32, outInt32) - } +// func TestRowDecodeBinary(t *testing.T) { +// t.Parallel() + +// conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) +// defer closeConn(t, conn) + +// tests := []struct { +// sql string +// expected []any +// }{ +// { +// "select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)", +// []any{ +// int32(1), +// "cat", +// time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(), +// }, +// }, +// { +// "select row(100.0::float, 1.09::float)", +// []any{ +// float64(100), +// float64(1.09), +// }, +// }, +// } + +// for i, tt := range tests { +// var actual []any + +// err := conn.QueryRow(context.Background(), tt.sql).Scan(&actual) +// if err != nil { +// t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) +// continue +// } + +// for j := range tt.expected { +// assert.EqualValuesf(t, tt.expected[j], actual[j], "%d. [%d]", i, j) + +// } + +// ensureConnValid(t, conn) +// } +// } + +// https://github.com/jackc/pgx/issues/810 +func TestRowsScanNilThenScanValue(t *testing.T) { + t.Parallel() - if inInt64 != outInt64 { - t.Errorf("int64 rename: expected %v, got %v", inInt64, outInt64) - } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() - if inUint != outUint { - t.Errorf("uint rename: expected %v, got %v", inUint, outUint) - } + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + sql := `select null as a, null as b +union +select 1, 2 +order by a nulls first +` + rows, err := conn.Query(context.Background(), sql) + require.NoError(t, err) - if inUint8 != outUint8 { - t.Errorf("uint8 rename: expected %v, got %v", inUint8, outUint8) - } + require.True(t, rows.Next()) - if inUint16 != outUint16 { - t.Errorf("uint16 rename: expected %v, got %v", inUint16, outUint16) - } + err = rows.Scan(nil, nil) + require.NoError(t, err) - if inUint32 != outUint32 { - t.Errorf("uint32 rename: expected %v, got %v", inUint32, outUint32) - } + require.True(t, rows.Next()) - if inUint64 != outUint64 { - t.Errorf("uint64 rename: expected %v, got %v", inUint64, outUint64) - } + var a int + var b int + err = rows.Scan(&a, &b) + require.NoError(t, err) - if inString != outString { - t.Errorf("string rename: expected %v, got %v", inString, outString) - } + require.EqualValues(t, 1, a) + require.EqualValues(t, 2, b) - ensureConnValid(t, conn) + rows.Close() + require.NoError(t, rows.Err()) + }) } -func TestRowDecode(t *testing.T) { +func TestScanIntoByteSlice(t *testing.T) { t.Parallel() - conn := mustConnect(t, *defaultConnConfig) + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - - tests := []struct { - sql string - expected []interface{} + // Success cases + for _, tt := range []struct { + name string + sql string + resultFormatCode int16 + output []byte }{ - { - "select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)", - []interface{}{ - int32(1), - "cat", - time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(), - }, - }, - { - "select row(100.0::float, 1.09::float)", - []interface{}{ - float64(100), - float64(1.09), - }, - }, - } - - for i, tt := range tests { - var actual []interface{} - - err := conn.QueryRow(tt.sql).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) - continue - } - - if !reflect.DeepEqual(actual, tt.expected) { - t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql) - } - - ensureConnValid(t, conn) + {"int - text", "select 42", pgx.TextFormatCode, []byte("42")}, + {"int - binary", "select 42", pgx.BinaryFormatCode, []byte("42")}, + {"text - text", "select 'hi'", pgx.TextFormatCode, []byte("hi")}, + {"text - binary", "select 'hi'", pgx.BinaryFormatCode, []byte("hi")}, + {"json - text", "select '{}'::json", pgx.TextFormatCode, []byte("{}")}, + {"json - binary", "select '{}'::json", pgx.BinaryFormatCode, []byte("{}")}, + {"jsonb - text", "select '{}'::jsonb", pgx.TextFormatCode, []byte("{}")}, + {"jsonb - binary", "select '{}'::jsonb", pgx.BinaryFormatCode, []byte("{}")}, + } { + t.Run(tt.name, func(t *testing.T) { + var buf []byte + err := conn.QueryRow(context.Background(), tt.sql, pgx.QueryResultFormats{tt.resultFormatCode}).Scan(&buf) + require.NoError(t, err) + require.Equal(t, tt.output, buf) + }) } }