diff --git a/.codespellrc b/.codespellrc new file mode 100644 index 000000000..0fcbfb7be --- /dev/null +++ b/.codespellrc @@ -0,0 +1,5 @@ +[codespell] +skip = */testdata,./LICENSE,./datetime/timezones.go +ignore-words-list = ro,gost,warmup +count = +quiet-level = 3 diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 000000000..70638f98a --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,9 @@ +What has been done? Why? What problem is being solved? + +I didn't forget about (remove if it is not applicable): + +- [ ] Tests (see [documentation](https://pkg.go.dev/testing) for a testing package) +- [ ] Changelog (see [documentation](https://keepachangelog.com/en/1.0.0/) for changelog format) +- [ ] Documentation (see [documentation](https://go.dev/blog/godoc) for documentation style guide) + +Related issues: diff --git a/.github/workflows/check.yaml b/.github/workflows/check.yaml new file mode 100644 index 000000000..29b6f21d2 --- /dev/null +++ b/.github/workflows/check.yaml @@ -0,0 +1,72 @@ +name: Run checks + +on: + push: + pull_request: + +jobs: + luacheck: + runs-on: ubuntu-24.04 + if: | + github.event_name == 'push' || + github.event_name == 'pull_request' && + github.event.pull_request.head.repo.full_name != github.repository + steps: + - uses: actions/checkout@v5 + + - name: Setup Tarantool + uses: tarantool/setup-tarantool@v4 + with: + tarantool-version: '3.4' + + - name: Setup tt + run: | + curl -L https://tarantool.io/release/3/installer.sh | sudo bash + sudo apt install -y tt + tt version + + - name: Setup luacheck + run: tt rocks install luacheck 0.25.0 + + - name: Run luacheck + run: ./.rocks/bin/luacheck . + + golangci-lint: + runs-on: ubuntu-24.04 + if: | + github.event_name == 'push' || + github.event_name == 'pull_request' && + github.event.pull_request.head.repo.full_name != github.repository + steps: + - uses: actions/setup-go@v5 + + - uses: actions/checkout@v5 + + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + continue-on-error: true + with: + # The first run is for GitHub Actions error format. + args: --config=.golangci.yaml + + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + # The second run is for human-readable error format with a file name + # and a line number. + args: --out-${NO_FUTURE}format colored-line-number --config=.golangci.yaml + + codespell: + runs-on: ubuntu-24.04 + if: | + github.event_name == 'push' || + github.event_name == 'pull_request' && + github.event.pull_request.head.repo.full_name != github.repository + steps: + - uses: actions/checkout@v5 + + - name: Install codespell + run: pip3 install codespell + + - name: Run codespell + run: make codespell diff --git a/.github/workflows/reusable-run.yml b/.github/workflows/reusable-run.yml new file mode 100644 index 000000000..e6ec05228 --- /dev/null +++ b/.github/workflows/reusable-run.yml @@ -0,0 +1,98 @@ +name: Reusable Test Run + +on: + workflow_call: + inputs: + os: + required: true + type: string + tarantool-version: + required: true + type: string + go-version: + required: true + type: string + coveralls: + required: false + type: boolean + default: false + fuzzing: + required: false + type: boolean + default: false + +jobs: + run-tests: + runs-on: ${{ inputs.os }} + steps: + - name: Clone the connector + uses: actions/checkout@v5 + + - name: Setup tt + run: | + curl -L https://tarantool.io/release/3/installer.sh | sudo bash + sudo apt install -y tt + + - name: Setup tt environment + run: tt init + + - name: Setup Tarantool ${{ inputs.tarantool-version }} + if: inputs.tarantool-version != 'master' + uses: tarantool/setup-tarantool@v4 + with: + tarantool-version: ${{ inputs.tarantool-version }} + + - name: Get Tarantool master commit + if: inputs.tarantool-version == 'master' + run: | + commit_hash=$(git ls-remote https://github.com/tarantool/tarantool.git --branch master | head -c 8) + echo "LATEST_COMMIT=${commit_hash}" >> $GITHUB_ENV + shell: bash + + - name: Cache Tarantool master + if: inputs.tarantool-version == 'master' + id: cache-latest + uses: actions/cache@v4 + with: + path: | + ${{ github.workspace }}/bin + ${{ github.workspace }}/include + key: cache-latest-${{ env.LATEST_COMMIT }} + + - name: Setup Tarantool master + if: inputs.tarantool-version == 'master' && steps.cache-latest.outputs.cache-hit != 'true' + run: | + sudo tt install tarantool master + + - name: Add Tarantool master to PATH + if: inputs.tarantool-version == 'master' + run: echo "${GITHUB_WORKSPACE}/bin" >> $GITHUB_PATH + + - name: Setup golang for the connector and tests + uses: actions/setup-go@v5 + with: + go-version: ${{ inputs.go-version }} + + - name: Install test dependencies + run: make deps + + - name: Run regression tests + run: make test + + - name: Run race tests + run: make testrace + + - name: Run fuzzing tests + if: ${{ inputs.fuzzing }} + run: make fuzzing TAGS="go_tarantool_decimal_fuzzing" + + - name: Run tests, collect code coverage data and send to Coveralls + if: ${{ inputs.coveralls }} + env: + COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + make coveralls + + - name: Check workability of benchmark tests + if: inputs.go-version == 'stable' + run: make bench-deps bench DURATION=1x COUNT=1 diff --git a/.github/workflows/reusable_testing.yml b/.github/workflows/reusable_testing.yml new file mode 100644 index 000000000..beb4c84c8 --- /dev/null +++ b/.github/workflows/reusable_testing.yml @@ -0,0 +1,51 @@ +name: reusable_testing + +on: + workflow_call: + inputs: + artifact_name: + description: The name of the tarantool build artifact + default: ubuntu-focal + required: false + type: string + +jobs: + run_tests: + runs-on: ubuntu-22.04 + steps: + - name: Clone the go-tarantool connector + uses: actions/checkout@v5 + with: + repository: ${{ github.repository_owner }}/go-tarantool + + - name: Download the tarantool build artifact + uses: actions/download-artifact@v5 + with: + name: ${{ inputs.artifact_name }} + + - name: Install tarantool + # Now we're lucky: all dependencies are already installed. Check package + # dependencies when migrating to other OS version. + run: sudo dpkg -i tarantool*.deb + + - name: Get the tarantool version + run: | + TNT_VERSION=$(tarantool --version | grep -e '^Tarantool') + echo "TNT_VERSION=$TNT_VERSION" >> $GITHUB_ENV + + - name: Setup golang for connector and tests + uses: actions/setup-go@v5 + with: + go-version: '1.24' + + - name: Setup tt + run: | + curl -L https://tarantool.io/release/3/installer.sh | sudo bash + sudo apt install -y tt + tt version + + - name: Install test dependencies + run: make deps + + - name: Run tests + run: make test diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml new file mode 100644 index 000000000..12e41dae6 --- /dev/null +++ b/.github/workflows/testing.yml @@ -0,0 +1,143 @@ +name: testing + +on: + push: + pull_request: + pull_request_target: + types: [labeled] + workflow_dispatch: + +jobs: + run-tests-tarantool-1-10: + if: (github.event_name == 'push') || + (github.event_name == 'pull_request' && + github.event.pull_request.head.repo.full_name != github.repository) || + (github.event_name == 'workflow_dispatch') + strategy: + fail-fast: false + matrix: + golang: ['1.24', 'stable'] + tarantool: ['1.10'] + coveralls: [false] + fuzzing: [false] + uses: ./.github/workflows/reusable-run.yml + with: + os: ubuntu-22.04 + go-version: ${{ matrix.golang }} + tarantool-version: ${{ matrix.tarantool }} + coveralls: ${{ matrix.coveralls }} + fuzzing: ${{ matrix.fuzzing }} + + run-tests: + if: (github.event_name == 'push') || + (github.event_name == 'pull_request' && + github.event.pull_request.head.repo.full_name != github.repository) || + (github.event_name == 'workflow_dispatch') + strategy: + fail-fast: false + matrix: + golang: ['1.24', 'stable'] + tarantool: ['2.11', '3.4', 'master'] + coveralls: [false] + fuzzing: [false] + include: + - golang: '1.24' + tarantool: 'master' + coveralls: true + fuzzing: false + - golang: '1.24' + tarantool: 'master' + coveralls: false + fuzzing: true + uses: ./.github/workflows/reusable-run.yml + with: + os: ubuntu-24.04 + go-version: ${{ matrix.golang }} + tarantool-version: ${{ matrix.tarantool }} + coveralls: ${{ matrix.coveralls }} + fuzzing: ${{ matrix.fuzzing }} + + testing_mac_os: + # We want to run on external PRs, but not on our own internal + # PRs as they'll be run by the push to the branch. + # + # The main trick is described here: + # https://github.com/Dart-Code/Dart-Code/pull/2375 + if: (github.event_name == 'push') || + (github.event_name == 'pull_request' && + github.event.pull_request.head.repo.full_name != github.repository) || + (github.event_name == 'workflow_dispatch') + + strategy: + fail-fast: false + matrix: + golang: + - '1.24' + - 'stable' + runs-on: + - macos-14 + - macos-15 + - macos-26 + + env: + # Make sense only for non-brew jobs. + # + # Set as absolute paths to avoid any possible confusion + # after changing a current directory. + T_VERSION: ${{ matrix.tarantool }} + T_SRCDIR: ${{ format('{0}/tarantool-{1}', github.workspace, matrix.tarantool) }} + T_TARDIR: ${{ format('{0}/tarantool-{1}-build', github.workspace, matrix.tarantool) }} + SRCDIR: ${{ format('{0}/{1}', github.workspace, github.repository) }} + + runs-on: ${{ matrix.runs-on }} + steps: + - name: Clone the connector + uses: actions/checkout@v5 + with: + path: ${{ env.SRCDIR }} + + - name: Setup cmake + uses: jwlawson/actions-setup-cmake@v2 + with: + cmake-version: '3.29.x' + + - name: Install latest tarantool from brew + run: brew install tarantool + + - name: Setup golang for the connector and tests + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.golang }} + + # Workaround issue https://github.com/tarantool/tt/issues/640 + - name: Fix tt rocks + run: | + brew ls --verbose tarantool | grep macosx.lua | xargs rm -f + + - name: Install test dependencies + run: | + brew install tt + cd "${SRCDIR}" + make deps + + - name: Run regression tests + run: | + cd "${SRCDIR}" + make test + + - name: Run race tests + run: | + cd "${SRCDIR}" + make testrace + + - name: Run fuzzing tests + if: ${{ matrix.fuzzing }} + run: | + cd "${SRCDIR}" + make fuzzing TAGS="go_tarantool_decimal_fuzzing" + + - name: Check workability of benchmark tests + if: matrix.golang == 'stable' + run: | + cd "${SRCDIR}" + make bench-deps bench DURATION=1x COUNT=1 diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..c9f687eb1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +*.DS_Store +*.swp +.idea/ +work_dir* +.rocks +bench* +testdata/sidecar/main diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 000000000..7fedeac31 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,30 @@ +run: + timeout: 3m + +linters: + disable: + - errcheck + enable: + - forbidigo + - gocritic + - goimports + - lll + - reassign + - stylecheck + - unconvert + +linters-settings: + gocritic: + disabled-checks: + - ifElseChain + lll: + line-length: 100 + tab-width: 4 + stylecheck: + checks: ["all", "-ST1003"] + +issues: + exclude-rules: + - linters: + - lll + source: "^\\s*//\\s*(\\S+\\s){0,3}https?://\\S+$" diff --git a/.luacheckrc b/.luacheckrc new file mode 100644 index 000000000..4e8998348 --- /dev/null +++ b/.luacheckrc @@ -0,0 +1,28 @@ +redefined = false + +globals = { + 'box', + 'utf8', + 'checkers', + '_TARANTOOL' +} + +include_files = { + '**/*.lua', + '*.luacheckrc', + '*.rockspec' +} + +exclude_files = { + '**/*.rocks/' +} + +max_line_length = 120 + +ignore = { + "212/self", -- Unused argument . + "411", -- Redefining a local variable. + "421", -- Shadowing a local variable. + "431", -- Shadowing an upvalue. + "432", -- Shadowing an upvalue argument. +} diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..6121a22df --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,604 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) +and this project adheres to [Semantic +Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. + +## [Unreleased] + +### Added + +* New types for MessagePack extensions compatible with go-option (#459). +* Added `box.MustNew` wrapper for `box.New` without an error (#448). + +### Changed + +* Required Go version is `1.24` now (#456). +* `box.New` returns an error instead of panic (#448). +* Now cases of `<-ctx.Done()` returns wrapped error provided by `ctx.Cause()`. + Allows you compare it using `errors.Is/As` (#457). +* Removed deprecated `pool` methods, related interfaces and tests are updated (#478). +* Removed deprecated `box.session.push()` support: Future.AppendPush() + and Future.GetIterator() methods, ResponseIterator and TimeoutResponseIterator types, + Future.pushes[] (#480). +* `LogAppendPushFailed` replaced with `LogBoxSessionPushUnsupported` (#480) + +### Fixed + +## [v2.4.1] - 2025-10-16 + +This maintenance release marks the end of active development on the `v2` +branch. + +## [v2.4.0] - 2025-07-11 + +This release focuses on adding schema/user/session operations, synchronous transaction +flag handling, and fixes watcher panic. + +### Added + +- Implemented all box.schema.user operations requests and sugar interface (#426). +- Implemented box.session.su request and sugar interface only for current session granting (#426). +- Defined `ErrConcurrentSchemaUpdate` constant for "concurrent schema update" error (#404). + Now you can check this error with `errors.Is(err, tarantool.ErrConcurrentSchemaUpdate)`. +- Implemented support for `IPROTO_IS_SYNC` flag in stream transactions, + added `IsSync(bool)` method for `BeginRequest`/`CommitRequest` (#447). +- Added missing IPROTO feature flags to greeting negotiation + (iproto.IPROTO_FEATURE_IS_SYNC, iproto.IPROTO_FEATURE_INSERT_ARROW) (#466). + +### Fixed + +- Fixed panic when calling NewWatcher() during reconnection or after + connection is closed (#438). + +## [v2.3.2] - 2025-04-14 + +This release improves the logic of `Connect` and `pool.Connect` in case of a +hung I/O connection. + +### Changed + +- Previously, `pool.Connect` attempted to establish a connection one after + another instance. It could cause the entire chain to hang if one connection + hanged. Now connections are established in parallel. After the first + successful connection, the remaining connections wait with a timeout of + `pool.Opts.CheckTimeout` (#444). + +### Fixed + +- Connect() may not cancel Dial() call on context expiration if network + connection hangs (#443). +- pool.Connect() failed to connect to any instance if a first instance + connection hangs (#444). + +## [v2.3.1] - 2025-04-03 + +The patch releases fixes expected Connect() behavior and reduces allocations. + +### Added + +- A usage of sync.Pool of msgpack.Decoder saves 2 object allocations per + a response decoding (#440). + +### Changed + +- Connect() now retry the connection if a failure occurs and opts.Reconnect > 0. + The number of attempts is equal to opts.MaxReconnects or unlimited if + opts.MaxReconnects == 0. Connect() blocks until a connection is established, + the context is cancelled, or the number of attempts is exhausted (#436). + +## [v2.3.0] - 2025-03-11 + +The release extends box.info responses and ConnectionPool.GetInfo return data. + +Be careful, we have changed the test_helpers package a little since we do not +support backward compatibility for it. + +### Added + +- Extend box with replication information (#427). +- The Instance info has been added to ConnectionInfo for ConnectionPool.GetInfo + response (#429). +- Added helpers to run Tarantool config storage (#431). + +### Changed + +- Changed helpers API `StartTarantool` and `StopTarantool`, now it uses + pointer on `TarantoolInstance`: + * `StartTarantool()` returns `*TarantoolInstance`; + * `StopTarantool()` and `StopTarantoolWithCleanup()` accepts + `*TarantoolInstance` as arguments. +- Field `Cmd` in `TarantoolInstance` struct declared as deprecated. + Suggested `Wait()`, `Stop()` and `Signal()` methods as safer to use + instead of direct `Cmd.Process` access (#431). + +### Fixed + +- Test helpers does not detect a fail to start a Tarantool instance if + another Tarantool instance already listens a port (#431). + +## [v2.2.1] - 2024-12-17 + +The release fixes a schema lost after a reconnect. + +### Fixed + +- `unable to use an index name because schema is not loaded` error after + a reconnect (#424). + +## [v2.2.0] - 2024-12-16 + +The release introduces the IPROTO_INSERT_ARROW request (arrow.InsertRequest) +and a request to archive `box.info` values (box.InfoRequest). Additionally, it +includes some improvements to logging. + +### Added + +- Error logging to `ConnectionPool.Add()` in case, when unable to establish + connection and ctx is not canceled (#389). +- Error logging for error case of `ConnectionPool.tryConnect()` calls in + `ConnectionPool.controller()` and `ConnectionPool.reconnect()` (#389). +- Methods that are implemented but not included in the pooler interface (#395). +- Implemented stringer methods for pool.Role (#405). +- Support the IPROTO_INSERT_ARROW request (#399). +- A simple implementation of using the box interface (#410). + +### Changed + +- More informative request canceling: log the probable reason for unexpected + request ID and add request ID info to context done error message (#407). + +## [v2.1.0] - 2024-03-06 + +The small release improves the ConnectionPool. The ConnectionPool now does not +require execute access for `box.info` from a user for Tarantool >= 3.0.0. + +### Changed + +- `execute` access for `box.info` is no longer required for ConnectionPool + for a Tarantool version >= 3.0.0 (#380). + +### Fixed + +- `ConnectionPool.Remove()` does not notify a `ConnectionHandler` after + an instance is already removed from the pool (#385). + +## [v2.0.0] - 2024-02-12 + +There are a lot of changes in the new major version. The main ones: + +* The `go_tarantool_call_17` build tag is no longer needed, since by default + the `CallRequest` is `Call17Request`. +* The `go_tarantool_msgpack_v5` build tag is no longer needed, since only the + `msgpack/v5` library is used. +* The `go_tarantool_ssl_disable` build tag is no longer needed, since the + connector is no longer depends on `OpenSSL` by default. You could use the + external library [go-tlsdialer](https://github.com/tarantool/go-tlsdialer) to + create a connection with the `ssl` transport. +* Required Go version is `1.20` now. +* The `Connect` function became more flexible. It now allows to create a + connection with cancellation and a custom `Dialer` implementation. +* It is required to use `Request` implementation types with the `Connection.Do` + method instead of `Connection.` methods. +* The `connection_pool` package renamed to `pool`. + +See the [migration guide](./MIGRATION.md) for more details. + +### Added + +- Type() method to the Request interface (#158). +- Enumeration types for RLimitAction/iterators (#158). +- IsNullable flag for Field (#302). +- More linters on CI (#310). +- Meaningful description for read/write socket errors (#129). +- Support `operation_data` in `crud.Error` (#330). +- Support `fetch_latest_metadata` option for crud requests with + metadata (#335). +- Support `noreturn` option for data change crud requests (#335). +- Support `crud.schema` request (#336, #351). +- Support `IPROTO_WATCH_ONCE` request type for Tarantool + version >= 3.0.0-alpha1 (#337). +- Support `yield_every` option for crud select requests (#350). +- Support `IPROTO_FEATURE_SPACE_AND_INDEX_NAMES` for Tarantool + version >= 3.0.0-alpha1 (#338). It allows to use space and index names + in requests instead of their IDs. +- `GetSchema` function to get the actual schema (#7). +- Support connection via an existing socket fd (#321). +- `Header` struct for the response header (#237). It can be accessed via + `Header()` method of the `Response` interface. +- `Response` method added to the `Request` interface (#237). +- New `LogAppendPushFailed` connection log constant (#237). + It is logged when connection fails to append a push response. +- `ErrorNo` constant that indicates that no error has occurred while getting + the response (#237). +- Ability to mock connections for tests (#237). Added new types `MockDoer`, + `MockRequest` to `test_helpers`. +- `AuthDialer` type for creating a dialer with authentication (#301). +- `ProtocolDialer` type for creating a dialer with `ProtocolInfo` receiving and + check (#301). +- `GreetingDialer` type for creating a dialer, that fills `Greeting` of a + connection (#301). +- New method `Pool.DoInstance` to execute a request on a target instance in + a pool (#376). + +### Changed + +- connection_pool renamed to pool (#239). +- Use msgpack/v5 instead of msgpack.v2 (#236). +- Call/NewCallRequest = Call17/NewCall17Request (#235). +- Change encoding of the queue.Identify() UUID argument from binary blob to + plain string. Needed for upgrade to Tarantool 3.0, where a binary blob is + decoded to a varbinary object (#313). +- Use objects of the Decimal type instead of pointers (#238). +- Use objects of the Datetime type instead of pointers (#238). +- `connection.Connect` no longer return non-working + connection objects (#136). This function now does not attempt to reconnect + and tries to establish a connection only once. Function might be canceled + via context. Context accepted as first argument. + `pool.Connect` and `pool.Add` now accept context as the first argument, which + user may cancel in process. If `pool.Connect` is canceled in progress, an + error will be returned. All created connections will be closed. +- `iproto.Feature` type now used instead of `ProtocolFeature` (#337). +- `iproto.IPROTO_FEATURE_` constants now used instead of local `Feature` + constants for `protocol` (#337). +- Change `crud` operations `Timeout` option type to `crud.OptFloat64` + instead of `crud.OptUint` (#342). +- Change all `Upsert` and `Update` requests to accept `*tarantool.Operations` + as `ops` parameters instead of `interface{}` (#348). +- Change `OverrideSchema(*Schema)` to `SetSchema(Schema)` (#7). +- Change values, stored by pointers in the `Schema`, `Space`, `Index` structs, + to be stored by their values (#7). +- Make `Dialer` mandatory for creation a single connection (#321). +- Remove `Connection.RemoteAddr()`, `Connection.LocalAddr()`. + Add `Addr()` function instead (#321). +- Remove `Connection.ClientProtocolInfo`, `Connection.ServerProtocolInfo`. + Add `ProtocolInfo()` function, which returns the server protocol info (#321). +- `NewWatcher` checks the actual features of the server, rather than relying + on the features provided by the user during connection creation (#321). +- `pool.NewWatcher` does not create watchers for connections that do not support + it (#321). +- Rename `pool.GetPoolInfo` to `pool.GetInfo`. Change return type to + `map[string]ConnectionInfo` (#321). +- `Response` is now an interface (#237). +- All responses are now implementations of the `Response` interface (#237). + `SelectResponse`, `ExecuteResponse`, `PrepareResponse`, `PushResponse` are part + of a public API. `Pos()`, `MetaData()`, `SQLInfo()` methods created for them + to get specific info. + Special types of responses are used with special requests. +- `IsPush()` method is added to the response iterator (#237). It returns + the information if the current response is a `PushResponse`. + `PushCode` constant is removed. +- Method `Get` for `Future` now returns response data (#237). To get the actual + response new `GetResponse` method has been added. Methods `AppendPush` and + `SetResponse` accept response `Header` and data as their arguments. +- `Future` constructors now accept `Request` as their argument (#237). +- Operations `Ping`, `Select`, `Insert`, `Replace`, `Delete`, `Update`, `Upsert`, + `Call`, `Call16`, `Call17`, `Eval`, `Execute` of a `Connector` and `Pooler` + return response data instead of an actual responses (#237). +- Renamed `StrangerResponse` to `MockResponse` (#237). +- `pool.Connect`, `pool.ConnetcWithOpts` and `pool.Add` use a new type + `pool.Instance` to determinate connection options (#356). +- `pool.Connect`, `pool.ConnectWithOpts` and `pool.Add` add connections to + the pool even it is unable to connect to it (#372). +- Required Go version updated from `1.13` to `1.20` (#378). + +### Deprecated + +- All Connection., Connection.Typed and + Connection.Async methods. Instead you should use requests objects + + Connection.Do() (#241). +- All ConnectionPool., ConnectionPool.Typed and + ConnectionPool.Async methods. Instead you should use requests + objects + ConnectionPool.Do() (#241). +- box.session.push() usage: Future.AppendPush() and Future.GetIterator() + methods, ResponseIterator and TimeoutResponseIterator types (#324). + +### Removed + +- multi subpackage (#240). +- msgpack.v2 support (#236). +- pool/RoundRobinStrategy (#158). +- DeadlineIO (#158). +- UUID_extId (#158). +- IPROTO constants (#158). +- Code() method from the Request interface (#158). +- `Schema` field from the `Connection` struct (#7). +- `OkCode` and `PushCode` constants (#237). +- SSL support (#301). +- `Future.Err()` method (#382). + +### Fixed + +- Flaky decimal/TestSelect (#300). +- Race condition at roundRobinStrategy.GetNextConnection() (#309). +- Incorrect decoding of an MP_DECIMAL when the `scale` value is + negative (#314). +- Incorrect options (`after`, `batch_size` and `force_map_call`) setup for + crud.SelectRequest (#320). +- Incorrect options (`vshard_router`, `fields`, `bucket_id`, `mode`, + `prefer_replica`, `balance`) setup for crud.GetRequest (#335). +- Tests with crud 1.4.0 (#336). +- Tests with case sensitive SQL (#341). +- Splice update operation accepts 3 arguments instead of 5 (#348). +- Unable to use a slice of custom types as a slice of tuples or objects for + `crud.*ManyRequest/crud.*ObjectManyRequest` (#365). + +## [v1.12.0] - 2023-06-07 + +The release introduces the ability to gracefully close Connection +and ConnectionPool and also provides methods for adding or removing an endpoint +from a ConnectionPool. + +### Added + +- Connection.CloseGraceful() unlike Connection.Close() waits for all + requests to complete (#257). +- ConnectionPool.CloseGraceful() unlike ConnectionPool.Close() waits for all + requests to complete (#257). +- ConnectionPool.Add()/ConnectionPool.Remove() to add/remove endpoints + from a pool (#290). + +### Changed + +### Fixed + +- crud tests with Tarantool 3.0 (#293). +- SQL tests with Tarantool 3.0 (#295). + +## [v1.11.0] - 2023-05-18 + +The release adds pagination support and wrappers for the +[crud](https://github.com/tarantool/crud) module. + +### Added + +- Support pagination (#246). +- A Makefile target to test with race detector (#218). +- Support CRUD API (#108). +- An ability to replace a base network connection to a Tarantool + instance (#265). +- Missed iterator constant (#285). + +### Changed + +- queue module version bumped to 1.3.0 (#278). + +### Fixed + +- Several non-critical data race issues (#218). +- Build on Apple M1 with OpenSSL (#260). +- ConnectionPool does not properly handle disconnection with Opts.Reconnect + set (#272). +- Watcher events loss with a small per-request timeout (#284). +- Connect() panics on concurrent schema update (#278). +- Wrong Ttr setup by Queue.Cfg() (#278). +- Flaky queue/Example_connectionPool (#278). +- Flaky queue/Example_simpleQueueCustomMsgPack (#277). + +## [v1.10.0] - 2022-12-31 + +The release improves compatibility with new Tarantool versions. + +### Added + +- Support iproto feature discovery (#120). +- Support errors extended information (#209). +- Support error type in MessagePack (#209). +- Support event subscription (#119). +- Support session settings (#215). +- Support pap-sha256 authorization method (Tarantool EE feature) (#243). +- Support graceful shutdown (#214). + +### Fixed + +- Decimal package uses a test variable DecimalPrecision instead of a + package-level variable decimalPrecision (#233). +- Flaky test TestClientRequestObjectsWithContext (#244). +- Flaky test multi/TestDisconnectAll (#234). +- Build on macOS with Apple M1 (#260). + +## [v1.9.0] - 2022-11-02 + +The release adds support for the latest version of the +[queue package](https://github.com/tarantool/queue) with master-replica +switching. + +### Added + +- Support the queue 1.2.1 (#177). +- ConnectionHandler interface for handling changes of connections in + ConnectionPool (#178). +- Execute, ExecuteTyped and ExecuteAsync methods to ConnectionPool (#176). +- ConnectorAdapter type to use ConnectionPool as Connector interface (#176). +- An example how to use queue and connection_pool subpackages together (#176). + +### Fixed + +- Mode type description in the connection_pool subpackage (#208). +- Missed Role type constants in the connection_pool subpackage (#208). +- ConnectionPool does not close UnknownRole connections (#208). +- Segmentation faults in ConnectionPool requests after disconnect (#208). +- Addresses in ConnectionPool may be changed from an external code (#208). +- ConnectionPool recreates connections too often (#208). +- A connection is still opened after ConnectionPool.Close() (#208). +- Future.GetTyped() after Future.Get() does not decode response + correctly (#213). +- Decimal package uses a test function GetNumberLength instead of a + package-level function getNumberLength (#219). +- Datetime location after encode + decode is unequal (#217). +- Wrong interval arithmetic with timezones (#221). +- Invalid MsgPack if STREAM_ID > 127 (#224). +- queue.Take() returns an invalid task (#222). + +## [v1.8.0] - 2022-08-17 + +The minor release with time zones and interval support for datetime. + +### Added + +- Optional msgpack.v5 usage (#124). +- TZ support for datetime (#163). +- Interval support for datetime (#165). + +### Fixed + +- Markdown of documentation for the decimal subpackage (#201). + +## [v1.7.0] - 2022-08-02 + +This release adds a number of features. The extending of the public API has +become possible with a new way of creating requests. New types of requests are +created via chain calls. Streams, context and prepared statements support are +based on this idea. + +### Added + +- SSL support (#155). +- IPROTO_PUSH messages support (#67). +- Public API with request object types (#126). +- Support decimal type in msgpack (#96). +- Support datetime type in msgpack (#118). +- Prepared SQL statements (#117). +- Context support for request objects (#48). +- Streams and interactive transactions support (#101). +- `Call16` method, support build tag `go_tarantool_call_17` to choose + default behavior for `Call` method as Call17 (#125). + +### Changed + +- `IPROTO_*` constants that identify requests renamed from `Request` to + `RequestCode` (#126). + +### Removed + +- NewErrorFuture function (#190). + +### Fixed + +- Add `ExecuteAsync` and `ExecuteTyped` to common connector interface (#62). + +## [v1.6.0] - 2022-06-01 + +This release adds a number of features. Also it significantly improves testing, +CI and documentation. + +### Added + +- Coveralls support (#149). +- Reusable testing workflow (integration testing with latest Tarantool) (#112). +- Simple CI based on GitHub actions (#114). +- Support UUID type in msgpack (#90). +- Go modules support (#91). +- queue-utube handling (#85). +- Master discovery (#113). +- SQL support (#62). + +### Changed + +- Handle everything with `go test` (#115). +- Use plain package instead of module for UUID submodule (#134). +- Reset buffer if its average use size smaller than quarter of capacity (#95). +- Update API documentation: comments and examples (#123). + +### Fixed + +- Fix queue tests (#107). +- Make test case consistent with comments (#105). + +## [1.5] - 2019-12-29 + +First release. + +### Fixed + +- Fix infinite recursive call of `Upsert` method for `ConnectionMulti`. +- Fix index out of range panic on `dial()` to short address. +- Fix cast in `defaultLogger.Report` (#49). +- Fix race condition on extremely small request timeouts (#43). +- Fix notify for `Connected` transition. +- Fix reconnection logic and add `Opts.SkipSchema` method. +- Fix future sending. +- Fix panic on disconnect + timeout. +- Fix block on msgpack error. +- Fix ratelimit. +- Fix `timeouts` method for `Connection`. +- Fix possible race condition on extremely small request timeouts. +- Fix race condition on future channel creation. +- Fix block on forever closed connection. +- Fix race condition in `Connection`. +- Fix extra map fields. +- Fix response header parsing. +- Fix reconnect logic in `Connection`. + +### Changed + +- Make logger configurable. +- Report user mismatch error immediately. +- Set limit timeout by 0.9 of connection to queue request timeout. +- Update fields could be negative. +- Require `RLimitAction` to be specified if `RateLimit` is specified. +- Use newer typed msgpack interface. +- Do not start timeouts goroutine if no timeout specified. +- Clear buffers on connection close. +- Update `BenchmarkClientParallelMassive`. +- Remove array requirements for keys and opts. +- Do not allocate `Response` inplace. +- Respect timeout on request sending. +- Use `AfterFunc(fut.timeouted)` instead of `time.NewTimer()`. +- Use `_vspace`/`_vindex` for introspection. +- Method `Tuples()` always returns table for response. + +### Removed + +- Remove `UpsertTyped()` method (#23). + +### Added + +- Add methods `Future.WaitChan` and `Future.Err` (#86). +- Get node list from nodes (#81). +- Add method `deleteConnectionFromPool`. +- Add multiconnections support. +- Add `Addr` method for the connection (#64). +- Add `Delete` method for the queue. +- Implemented typed taking from queue (#55). +- Add `OverrideSchema` method for the connection. +- Add default case to default logger. +- Add license (BSD-2 clause as for Tarantool). +- Add `GetTyped` method for the connection (#40). +- Add `ConfiguredTimeout` method for the connection, change queue interface. +- Add an example for queue. +- Add `GetQueue` method for the queue. +- Add queue support. +- Add support of Unix socket address. +- Add check for prefix "tcp:". +- Add the ability to work with the Tarantool via Unix socket. +- Add note about magic way to pack tuples. +- Add notification about connection state change. +- Add workaround for tarantool/tarantool#2060 (#32). +- Add `ConnectedNow` method for the connection. +- Add IO deadline and use `net.Conn.Set(Read|Write)Deadline`. +- Add a couple of benchmarks. +- Add timeout on connection attempt. +- Add `RLimitAction` option. +- Add `Call17` method for the connection to make a call compatible with + Tarantool 1.7. +- Add `ClientParallelMassive` benchmark. +- Add `runtime.Gosched` for decreasing `writer.flush` count. +- Add `Eval`, `EvalTyped`, `SelectTyped`, `InsertTyped`, `ReplaceTyped`, + `DeleteRequest`, `UpdateTyped`, `UpsertTyped` methods. +- Add `UpdateTyped` method. +- Add `CallTyped` method. +- Add possibility to pass `Space` and `Index` objects into `Select` etc. +- Add custom MsgPack pack/unpack functions. +- Add support of Tarantool 1.6.8 schema format. +- Add support of Tarantool 1.6.5 schema format. +- Add schema loading. +- Add `LocalAddr` and `RemoteAddr` methods for the connection. +- Add `Upsert` method for the connection. +- Add `Eval` and `EvalAsync` methods for the connection. +- Add Tarantool error codes. +- Add auth support. +- Add auth during reconnect. +- Add auth request. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..f3ffe4d96 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,162 @@ +# Contribution Guide + +## First steps + +Clone the repository and install dependencies. + +```sh +$ git clone https://github.com/tarantool/go-tarantool +$ cd go-tarantool +$ go get . +``` + +## Running tests + +You need to [install Tarantool](https://tarantool.io/en/download/) to run tests. +See the Installation section in the README for requirements. + +To install test dependencies (such as the +[tarantool/queue](https://github.com/tarantool/queue) module), run: +```bash +make deps +``` + +To run tests for the main package and each subpackage: +```bash +make test +``` + +To run tests for the main package and each subpackage with race detector: +```bash +make testrace +``` + +The tests set up all required `tarantool` processes before run and clean up +afterwards. + +If you want to run the tests for a specific package: +```bash +make test- +``` +For example, for running tests in `pool`, `uuid` and `main` packages, call +```bash +make test-pool test-uuid test-main +``` + +To run [fuzz tests](https://go.dev/doc/tutorial/fuzz) for the main package and each subpackage: +```bash +make TAGS="go_tarantool_decimal_fuzzing" fuzzing +``` + +To check if the current changes will pass the linter in CI, install +golangci-lint from [sources](https://golangci-lint.run/usage/install/) +and run it with next command: +```bash +make golangci-lint +``` + +To format the code install [goimports](https://pkg.go.dev/golang.org/x/tools/cmd/goimports) +and run it with next command: +```bash +make format +``` + +## Benchmarking + +### Quick start + +To run all benchmark tests from the current branch run: + +```bash +make bench +``` + +To measure performance difference between master and the current branch run: + +```bash +make bench-diff +``` + +Note: `benchstat` should be in `PATH`. If it is not set, call: + +```bash +export PATH="/home/${USER}/go/bin:${PATH}" +``` + +or + +```bash +export PATH="${HOME}/go/bin:${PATH}" +``` + +### Customize benchmarking + +Before running benchmark or measuring performance degradation, install benchmark dependencies: +```bash +make bench-deps BENCH_PATH=custom_path +``` + +Use the variable `BENCH_PATH` to specify the path of benchmark artifacts. +It is set to `bench` by default. + +To run benchmark tests, call: +```bash +make bench DURATION=5s COUNT=7 BENCH_PATH=custom_path TEST_PATH=. +``` + +Use the variable `DURATION` to set the duration of perf tests. That variable is mapped on +testing [flag](https://pkg.go.dev/cmd/go#hdr-Testing_flags) `-benchtime` for gotest. +It may take the values in seconds (e.g, `5s`) or count of iterations (e.g, `1000x`). +It is set to `3s` by default. + +Use the variable `COUNT` to control the count of benchmark runs for each test. +It is set to `5` by default. That variable is mapped on testing flag `-count`. +Use higher values if the benchmark numbers aren't stable. + +Use the variable `TEST_PATH` to set the directory of test files. +It is set to `./...` by default, so it runs all the Benchmark tests in the project. + +To measure performance degradation after changes in code, run: +```bash +make bench-diff BENCH_PATH=custom_path +``` + +Note: the variable `BENCH_PATH` is not purposed to be used with absolute paths. + +## Recommendations for how to achieve stable results + +Before any judgments, verify whether results are stable on given host and how +large the noise. Run `make bench-diff` without changes and look on the report. +Several times. + +There are suggestions how to achieve best results, see +https://github.com/tarantool/tarantool/wiki/Benchmarking + +## Code review checklist + +- Public API contains functions, variables, constants that are needed from + outside by users. All the rest should be left closed. +- Public functions, variables and constants contain at least a single-line + comment. +- Code is DRY (see "Do not Repeat Yourself" principle). +- New features have functional and probably performance tests. +- There are no changes in files not related to the issue. +- There are no obvious flaky tests. +- Commits with bugfixes have tests based on reproducers. +- Changelog entry is present in `CHANGELOG.md`. +- Public methods contain executable examples (contains a comment with + reference output). +- Autogenerated documentation looks good. Run `godoc -http=:6060` and point + your web browser to address "http://127.0.0.1:6060/" for evaluating. +- Commit message header may start with a prefix with a short description + follows after colon. It is applicable to changes in a README, examples, tests + and CI configuration files. Examples: `github-ci: add Tarantool 2.x-latest` + and `readme: describe how to run tests`. +- Check your comments, commit title, and even variable names to be + grammatically correct. Start sentences from a capital letter, end with a dot. + Everywhere - in the code, in the tests, in the commit message. + +See also: + +- https://github.com/tarantool/tarantool/wiki/Code-review-procedure +- https://www.tarantool.io/en/doc/latest/contributing/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..4ec1f2574 --- /dev/null +++ b/LICENSE @@ -0,0 +1,27 @@ +BSD 2-Clause License + +Copyright (c) 2014-2022, Tarantool AUTHORS +Copyright (c) 2014-2017, Dmitry Smal +Copyright (c) 2014-2017, Yura Sokolov aka funny_falcon +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/MIGRATION.md b/MIGRATION.md new file mode 100644 index 000000000..b6849da32 --- /dev/null +++ b/MIGRATION.md @@ -0,0 +1,321 @@ +# Migration guide + +## Migration from v2.x.x to v3.x.x + +* [Major changes](#major-changes-v3) + +TODO + +### Major changes + +* Required Go version is `1.24` now. +* `box.New` returns an error instead of panic +* Added `box.MustNew` wrapper for `box.New` without an error +* Removed deprecated `pool` methods, related interfaces and tests are updated. +* Removed `box.session.push()` support: Future.AppendPush() and Future.GetIterator() + methods, ResponseIterator and TimeoutResponseIterator types. + +## Migration from v1.x.x to v2.x.x + +* [Major changes](#major-changes-v2) +* [Main package](#main-package) + * [Go version](#go-version) + * [msgpack/v5](#msgpackv5) + * [Call = Call17](#call--call17) + * [IPROTO constants](#iproto-constants) + * [Request interface](#request-interface) + * [Request changes](#request-changes) + * [Response interface](#response-interface) + * [Response changes](#response-changes) + * [Future type](#future-type) + * [Protocol types](#protocol-types) + * [Connector interface](#connector-interface) + * [Connect function](#connect-function) + * [Connection schema](#connection-schema) + * [Schema type](#schema-type) +* [datetime package](#datetime-package) +* [decimal package](#decimal-package) +* [multi package](#multi-package) +* [pool package](#pool-package) +* [crud package](#crud-package) +* [test_helpers package](#test_helpers-package) + +### Major changes + +* The `go_tarantool_call_17` build tag is no longer needed, since by default + the `CallRequest` is `Call17Request`. +* The `go_tarantool_msgpack_v5` build tag is no longer needed, since only the + `msgpack/v5` library is used. +* The `go_tarantool_ssl_disable` build tag is no longer needed, since the + connector is no longer depends on `OpenSSL` by default. You could use the + external library [go-tlsdialer](https://github.com/tarantool/go-tlsdialer) to + create a connection with the `ssl` transport. +* Required Go version is `1.20` now. +* The `Connect` function became more flexible. It now allows to create a + connection with cancellation and a custom `Dialer` implementation. +* It is required to use `Request` implementation types with the `Connection.Do` + method instead of `Connection.` methods. +* The `connection_pool` package renamed to `pool`. + +The basic code for the `v1.12.2` release: +```Go +package tarantool + +import ( + "fmt" + + "github.com/tarantool/go-tarantool" + _ "github.com/tarantool/go-tarantool/v3/datetime" + _ "github.com/tarantool/go-tarantool/v3/decimal" + _ "github.com/tarantool/go-tarantool/v3/uuid" +) + +func main() { + opts := tarantool.Opts{User: "guest"} + conn, err := tarantool.Connect("127.0.0.1:3301", opts) + if err != nil { + fmt.Println("Connection refused:", err) + return + } + + resp, err := conn.Insert(999, []interface{}{99999, "BB"}) + if err != nil { + fmt.Println("Error:", err) + fmt.Println("Code:", resp.Code) + } else { + fmt.Println("Data:", resp.Data) + } +} +``` + +At now became: +```Go +package tarantool + +import ( + "context" + "fmt" + "time" + + "github.com/tarantool/go-tarantool/v3" + _ "github.com/tarantool/go-tarantool/v3/datetime" + _ "github.com/tarantool/go-tarantool/v3/decimal" + _ "github.com/tarantool/go-tarantool/v3/uuid" +) + +func main() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3301", + User: "guest", + } + opts := tarantool.Opts{ + Timeout: time.Second, + } + + conn, err := tarantool.Connect(ctx, dialer, opts) + if err != nil { + fmt.Println("Connection refused:", err) + return + } + + data, err := conn.Do( + tarantool.NewInsertRequest(999).Tuple([]interface{}{99999, "BB"})).Get() + if err != nil { + fmt.Println("Error:", err) + } else { + fmt.Println("Data:", data) + } +} +``` + +### Main package + +#### Go version + +Required Go version is updated from `1.13` to `1.20`. + +#### msgpack/v5 + +At now the `msgpack/v5` library is used for the `msgpack` encoding/decondig. + +Most function names and argument types in `msgpack/v5` and `msgpack.v2` +have not changed (in our code, we noticed changes in `EncodeInt`, `EncodeUint` +and `RegisterExt`). But there are a lot of changes in a logic of encoding and +decoding. On the plus side the migration seems easy, but on the minus side you +need to be very careful. + +First of all, `EncodeInt8`, `EncodeInt16`, `EncodeInt32`, `EncodeInt64` +and `EncodeUint*` analogues at `msgpack/v5` encode numbers as is without loss of +type. In `msgpack.v2` the type of a number is reduced to a value. + +Secondly, a base decoding function does not convert numbers to `int64` or +`uint64`. It converts numbers to an exact type defined by MessagePack. The +change makes manual type conversions much more difficult and can lead to +runtime errors with an old code. We do not recommend to use type conversions +and give preference to `*Typed` functions (besides, it's faster). + +There are also changes in the logic that can lead to errors in the old code, +[as example](https://github.com/vmihailenco/msgpack/issues/327). Although in +`msgpack/v5` some functions for the logic tuning were added (see +[UseLooseInterfaceDecoding](https://pkg.go.dev/github.com/vmihailenco/msgpack/v5#Decoder.UseLooseInterfaceDecoding), [UseCompactInts](https://pkg.go.dev/github.com/vmihailenco/msgpack/v5#Encoder.UseCompactInts) etc), it is still impossible +to achieve full compliance of behavior between `msgpack/v5` and `msgpack.v2`. +So we don't go this way. We use standard settings if it possible. + +#### Call = Call17 + +Call requests uses `IPROTO_CALL` instead of `IPROTO_CALL_16`. + +So now `Call` = `Call17` and `NewCallRequest` = `NewCall17Request`. A result +of the requests is an array instead of array of arrays. + +#### IPROTO constants + +* IPROTO constants have been moved to a separate package [go-iproto](https://github.com/tarantool/go-iproto). +* `PushCode` constant is removed. To check whether the current response is + a push response, use `IsPush()` method of the response iterator instead. +* `ErrorNo` constant is added to indicate that no error has occurred while + getting the response. It should be used instead of the removed `OkCode`. + See `ExampleErrorNo`. + +#### Request interface + +* The method `Code() uint32` replaced by the `Type() iproto.Type`. +* `Response` method added to the `Request` interface. + +#### Request changes + +* Requests `Update`, `UpdateAsync`, `UpdateTyped`, `Upsert`, `UpsertAsync` no +longer accept `ops` argument (operations) as an `interface{}`. `*Operations` +needs to be passed instead. +* `Op` struct for update operations made private. +* Removed `OpSplice` struct. +* `Operations.Splice` method now accepts 5 arguments instead of 3. +* `UpdateRequest` and `UpsertRequest` structs no longer accept `interface{}` +for an `ops` field. `*Operations` needs to be used instead. + +#### Response interface + +* `Response` is now an interface. +* Response header stored in a new `Header` struct. It could be accessed via + `Header()` method. + +#### Response changes + +* `ResponseIterator` interface now has `IsPush()` method. + It returns true if the current response is a push response. +* For each request type, a different response type is created. They all + implement a `Response` interface. `SelectResponse`, `PrepareResponse`, + `ExecuteResponse`, `PushResponse` are a part of a public API. + `Pos()`, `MetaData()`, `SQLInfo()` methods created for them to get specific + info. Special types of responses are used with special requests. + +#### Future type + +* Method `Get` now returns response data instead of the actual response. +* New method `GetResponse` added to get an actual response. +* `Future` constructors now accept `Request` as their argument. +* Methods `AppendPush` and `SetResponse` accepts response `Header` and data + as their arguments. +* Method `Err` was removed because it was causing improper error handling. + You need to check an error from `Get`, `GetTyped` or `GetResponse` with + an addition check of a value `Response.Header().Error`, see `ExampleErrorNo`. + +#### Connector interface + +* Operations `Ping`, `Select`, `Insert`, `Replace`, `Delete`, `Update`, + `Upsert`, `Call`, `Call16`, `Call17`, `Eval`, `Execute` of a `Connector` + return response data instead of an actual responses. +* New interface `Doer` is added as a child-interface instead of a `Do` method. + +#### Connect function + +`connection.Connect` no longer return non-working connection objects. This +function now does not attempt to reconnect and tries to establish a connection +only once. Function might be canceled via context. Context accepted as first +argument, and user may cancel it in process. + +Now you need to pass `Dialer` as the second argument instead of URI. +If you were using a non-SSL connection, you need to create `NetDialer`. +For SSL-enabled connections, use `OpenSSLDialer` from the +[go-tlsdialer](https://github.com/tarantool/go-tlsdialer) package. + +Please note that the options for creating a connection are now stored in +corresponding `Dialer`, not in `Opts`. + +#### Connection schema + +* Removed `Schema` field from the `Connection` struct. Instead, new + `GetSchema(Doer)` function was added to get the actual connection + schema on demand. +* `OverrideSchema(*Schema)` method replaced with the `SetSchema(Schema)`. + +#### Protocol types + +* `iproto.Feature` type used instead of `ProtocolFeature`. +* `iproto.IPROTO_FEATURE_` constants used instead of local ones. + +#### Schema type + +* `ResolveSpaceIndex` function for `SchemaResolver` interface split into two: +`ResolveSpace` and `ResolveIndex`. `NamesUseSupported` function added into the +interface to get information if the usage of space and index names in requests +is supported. +* `Schema` structure no longer implements `SchemaResolver` interface. +* `Spaces` and `SpacesById` fields of the `Schema` struct store spaces by value. +* `Fields` and `FieldsById` fields of the `Space` struct store fields by value. +`Index` and `IndexById` fields of the `Space` struct store indexes by value. +* `Fields` field of the `Index` struct store `IndexField` by value. + +### datetime package + +Now you need to use objects of the Datetime type instead of pointers to it. A +new constructor `MakeDatetime` returns an object. `NewDatetime` has been +removed. + +### decimal package + +Now you need to use objects of the Decimal type instead of pointers to it. A +new constructor `MakeDecimal` returns an object. `NewDecimal` has been removed. + +### multi package + +The subpackage has been deleted. You could use `pool` instead. + +### pool package + +* The `connection_pool` subpackage has been renamed to `pool`. +* The type `PoolOpts` has been renamed to `Opts`. +* `pool.Connect` and `pool.ConnectWithOpts` now accept context as the first + argument, which user may cancel in process. If it is canceled in progress, + an error will be returned and all created connections will be closed. +* `pool.Connect` and `pool.ConnectWithOpts` now accept `[]pool.Instance` as + the second argument instead of a list of addresses. Each instance is + associated with a unique string name, `Dialer` and connection options which + allows instances to be independently configured. +* `pool.Connect`, `pool.ConnectWithOpts` and `pool.Add` add instances into + the pool even it is unable to connect to it. The pool will try to connect to + the instance later. +* `pool.Add` now accepts context as the first argument, which user may cancel + in process. +* `pool.Add` now accepts `pool.Instance` as the second argument instead of + an address, it allows to configure a new instance more flexible. +* `pool.GetPoolInfo` has been renamed to `pool.GetInfo`. Return type has been + changed to `map[string]ConnectionInfo`. +* Operations `Ping`, `Select`, `Insert`, `Replace`, `Delete`, `Update`, `Upsert`, + `Call`, `Call16`, `Call17`, `Eval`, `Execute` of a `Pooler` return + response data instead of an actual responses. + +### crud package + +* `crud` operations `Timeout` option has `crud.OptFloat64` type + instead of `crud.OptUint`. +* A slice of a custom type could be used as tuples for `ReplaceManyRequest` and + `InsertManyRequest`, `ReplaceObjectManyRequest`. +* A slice of a custom type could be used as objects for `ReplaceObjectManyRequest` + and `InsertObjectManyRequest`. + +### test_helpers package + +* Renamed `StrangerResponse` to `MockResponse`. diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..556fad418 --- /dev/null +++ b/Makefile @@ -0,0 +1,160 @@ +SHELL := /bin/bash +COVERAGE_FILE := coverage.out +MAKEFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST))) +PROJECT_DIR := $(patsubst %/,%,$(dir $(MAKEFILE_PATH))) +DURATION ?= 3s +COUNT ?= 5 +BENCH_PATH ?= bench-dir +TEST_PATH ?= ${PROJECT_DIR}/... +BENCH_FILE := ${PROJECT_DIR}/${BENCH_PATH}/bench.txt +REFERENCE_FILE := ${PROJECT_DIR}/${BENCH_PATH}/reference.txt +BENCH_FILES := ${REFERENCE_FILE} ${BENCH_FILE} +BENCH_REFERENCE_REPO := ${BENCH_PATH}/go-tarantool +BENCH_OPTIONS := -bench=. -run=^Benchmark -benchmem -benchtime=${DURATION} -count=${COUNT} +GO_TARANTOOL_URL := https://github.com/tarantool/go-tarantool +GO_TARANTOOL_DIR := ${PROJECT_DIR}/${BENCH_PATH}/go-tarantool +TAGS := + +.PHONY: clean +clean: + ( rm -rf queue/testdata/.rocks crud/testdata/.rocks ) + rm -f $(COVERAGE_FILE) + +.PHONY: deps +deps: clean + @(command -v tt > /dev/null || (echo "error: tt not found" && exit 1)) + ( cd ./queue/testdata; tt rocks install queue 1.3.0 ) + ( cd ./crud/testdata; tt rocks install crud ) + +.PHONY: datetime-timezones +datetime-timezones: + (cd ./datetime; ./gen-timezones.sh) + +.PHONY: format +format: + goimports -l -w . + +.PHONY: golangci-lint +golangci-lint: + golangci-lint run --config=.golangci.yaml + +.PHONY: test +test: + @echo "Running all packages tests" + go clean -testcache + go test -tags "$(TAGS)" ./... -v -p 1 + +.PHONY: testdata +testdata: + (cd ./testdata; ./generate.sh) + +.PHONY: testrace +testrace: + @echo "Running all packages tests with data race detector" + go clean -testcache + go test -race -tags "$(TAGS)" ./... -v -p 1 + +.PHONY: test-pool +test-pool: + @echo "Running tests in pool package" + go clean -testcache + go test -tags "$(TAGS)" ./pool/ -v -p 1 + +.PHONY: test-datetime +test-datetime: + @echo "Running tests in datetime package" + go clean -testcache + go test -tags "$(TAGS)" ./datetime/ -v -p 1 + +.PHONY: test-decimal +test-decimal: + @echo "Running tests in decimal package" + go clean -testcache + go test -tags "$(TAGS)" ./decimal/ -v -p 1 + +.PHONY: test-queue +test-queue: + @echo "Running tests in queue package" + cd ./queue/ && tarantool -e "require('queue')" + go clean -testcache + go test -tags "$(TAGS)" ./queue/ -v -p 1 + +.PHONY: test-uuid +test-uuid: + @echo "Running tests in UUID package" + go clean -testcache + go test -tags "$(TAGS)" ./uuid/ -v -p 1 + +.PHONY: test-settings +test-settings: + @echo "Running tests in settings package" + go clean -testcache + go test -tags "$(TAGS)" ./settings/ -v -p 1 + +.PHONY: test-crud +test-crud: + @echo "Running tests in crud package" + cd ./crud/testdata && tarantool -e "require('crud')" + go clean -testcache + go test -tags "$(TAGS)" ./crud/ -v -p 1 + +.PHONY: test-main +test-main: + @echo "Running tests in main package" + go clean -testcache + go test -tags "$(TAGS)" . -v -p 1 + +.PHONY: coverage +coverage: + go clean -testcache + go get golang.org/x/tools/cmd/cover + go test -tags "$(TAGS)" $(go list ./... | grep -v test_helpers) -v -p 1 -covermode=atomic -coverprofile=$(COVERAGE_FILE) + go tool cover -func=$(COVERAGE_FILE) + +.PHONY: coveralls +coveralls: coverage + go get github.com/mattn/goveralls + go install github.com/mattn/goveralls + goveralls -coverprofile=$(COVERAGE_FILE) -service=github + +.PHONY: bench-deps +${BENCH_PATH} bench-deps: + @echo "Installing benchstat tool" + rm -rf ${BENCH_PATH} + mkdir ${BENCH_PATH} + go clean -testcache + cd ${BENCH_PATH} && \ + go get golang.org/x/perf/cmd/benchstat + go install golang.org/x/perf/cmd/benchstat + +.PHONY: bench +${BENCH_FILE} bench: ${BENCH_PATH} + @echo "Running benchmark tests from the current branch" + go test -tags "$(TAGS)" ${TEST_PATH} ${BENCH_OPTIONS} 2>&1 \ + | tee ${BENCH_FILE} + benchstat ${BENCH_FILE} + +${GO_TARANTOOL_DIR}: + @echo "Cloning the repository into ${GO_TARANTOOL_DIR}" + [ ! -e ${GO_TARANTOOL_DIR} ] && git clone --depth=1 ${GO_TARANTOOL_URL} ${GO_TARANTOOL_DIR} + +${REFERENCE_FILE}: ${GO_TARANTOOL_DIR} + @echo "Running benchmark tests from master for using results in bench-diff target" + cd ${GO_TARANTOOL_DIR} && git pull && go test ./... -tags "$(TAGS)" ${BENCH_OPTIONS} 2>&1 \ + | tee ${REFERENCE_FILE} + +bench-diff: ${BENCH_FILES} + @echo "Comparing performance between master and the current branch" + @echo "'old' is a version in master branch, 'new' is a version in a current branch" + benchstat ${BENCH_FILES} | grep -v pkg: + +.PHONY: fuzzing +fuzzing: + @echo "Running fuzzing tests" + go clean -testcache + go test -tags "$(TAGS)" ./... -run=^Fuzz -v -p 1 + +.PHONY: codespell +codespell: + @echo "Running codespell" + codespell diff --git a/README.md b/README.md index 6fea3477f..3c31d2506 100644 --- a/README.md +++ b/README.md @@ -1,145 +1,276 @@ -# Tarantool + + + -[Tarantool 1.6](http://tarantool.org/) client on Go. +[![Go Reference][godoc-badge]][godoc-url] +[![Actions Status][actions-badge]][actions-url] +[![Code Coverage][coverage-badge]][coverage-url] +[![Telegram][telegram-badge]][telegram-url] +[![GitHub Discussions][discussions-badge]][discussions-url] +[![Stack Overflow][stackoverflow-badge]][stackoverflow-url] -## Usage +# ⚠️ Development Status Notice + +**The current `main` branch is under active development for the next major +release (v3).** + +The API on this branch is **unstable and subject to change**. + +**For production use and stable API, please use the +[`v2`](https://github.com/tarantool/go-tarantool/tree/v2) branch of the +repository.** + +# Client in Go for Tarantool + +The package `go-tarantool` contains everything you need to connect to +[Tarantool 1.10+][tarantool-site]. + +The advantage of integrating Go with Tarantool, which is an application server +plus a DBMS, is that Go programmers can handle databases and perform on-the-fly +recompilations of embedded Lua routines, just as in C, with responses that are +faster than other packages according to public benchmarks. + +## Table of contents + +* [Installation](#installation) + * [Build tags](#build-tags) +* [Documentation](#documentation) + * [API reference](#api-reference) + * [Walking\-through example](#walking-through-example) + * [Example with encrypting traffic](#example-with-encrypting-traffic) +* [Migration guide](#migration-guide) +* [Contributing](#contributing) +* [Alternative connectors](#alternative-connectors) + +## Installation + +We assume that you have Tarantool version 1.10+ and a modern Linux or BSD +operating system. + +You need a current version of `go`, version 1.24 or later (use `go version` to +check the version number). Do not use `gccgo-go`. + +**Note:** If your `go` version is older than 1.24 or if `go` is not installed, +download and run the latest tarball from [golang.org][golang-dl]. + +The package `go-tarantool` is located in [tarantool/go-tarantool][go-tarantool] +repository. To download and install, say: + +``` +$ go get github.com/tarantool/go-tarantool/v3 +``` + +This should put the source and binary files in subdirectories of +`/usr/local/go`, so that you can access them by adding +`github.com/tarantool/go-tarantool` to the `import {...}` section at the start +of any Go program. + +### Build tags + +We define multiple [build tags](https://pkg.go.dev/go/build#hdr-Build_Constraints). + +This allows us to introduce new features without losing backward compatibility. + +1. To run fuzz tests with decimals, you can use the build tag: + ``` + go_tarantool_decimal_fuzzing + ``` + **Note:** It crashes old Tarantool versions. + +## Documentation + +Read the [Tarantool documentation][tarantool-doc-data-model-url] +to find descriptions of terms such as "connect", "space", "index", and the +requests to create and manipulate database objects or Lua functions. + +In general, connector methods can be divided into two main parts: + +* `Connect()` function and functions related to connecting, and +* Data manipulation functions and Lua invocations such as `Insert()` or `Call()`. + +The supported requests have parameters and results equivalent to requests in +the [Tarantool CRUD operations][tarantool-doc-box-space-url]. +There are also Typed and Async versions of each data-manipulation function. + +### API Reference + +Learn API documentation and examples at [pkg.go.dev][godoc-url]. + +### Walking-through example + +We can now have a closer look at the example and make some observations +about what it does. ```go -package main +package tarantool import ( - "github.com/fl00r/go-tarantool-1.6" - "fmt" + "context" + "fmt" + "time" + + "github.com/tarantool/go-tarantool/v3" + _ "github.com/tarantool/go-tarantool/v3/datetime" + _ "github.com/tarantool/go-tarantool/v3/decimal" + _ "github.com/tarantool/go-tarantool/v3/uuid" ) func main() { - server := "127.0.0.1:3013" - spaceNo := uint32(514) - indexNo := uint32(0) - limit := uint32(10) - offset := uint32(0) - iterator := tarantool.IterAll - key := []interface{}{ 12 } - tuple1 := []interface{}{ 12, "Hello World", "Olga" } - tuple2 := []interface{}{ 12, "Hello Mars", "Anna" } - upd_tuple := []interface{}{ []interface{}{ "=", 1, "Hello Moon" }, []interface{}{ "#", 2, 1 } } - - functionName := "box.cfg()" - functionTuple := []interface{}{ "box.schema.SPACE_ID" } - - - client, err := tarantool.Connect(server) - - var resp *tarantool.Response - - resp, err = client.Ping() - fmt.Println("Ping") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - - resp, err = client.Insert(spaceNo, tuple1) - fmt.Println("Insert") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - - resp, err = client.Select(spaceNo, indexNo, offset, limit, iterator, key) - fmt.Println("Select") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - - resp, err = client.Replace(spaceNo, tuple2) - fmt.Println("Replace") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - - resp, err = client.Select(spaceNo, indexNo, offset, limit, iterator, key) - fmt.Println("Select") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - - resp, err = client.Update(spaceNo, indexNo, key, upd_tuple) - fmt.Println("Update") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - - resp, err = client.Select(spaceNo, indexNo, offset, limit, iterator, key) - fmt.Println("Select") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - - resp, err = client.Delete(spaceNo, indexNo, key) - fmt.Println("Delete") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - - resp, err = client.Call(functionName, functionTuple) - fmt.Println("Call") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3301", + User: "guest", + } + opts := tarantool.Opts{ + Timeout: time.Second, + } + + conn, err := tarantool.Connect(ctx, dialer, opts) + if err != nil { + fmt.Println("Connection refused:", err) + return + } + + data, err := conn.Do( + tarantool.NewInsertRequest(999).Tuple([]interface{}{99999, "BB"})).Get() + if err != nil { + fmt.Println("Error:", err) + } else { + fmt.Println("Data:", data) + } } +``` + +**Observation 1:** The line "`github.com/tarantool/go-tarantool/v3`" in the +`import(...)` section brings in all Tarantool-related functions and structures. + +**Observation 2:** Unused import lines are required to initialize encoders and +decoders for external `msgpack` types. + +**Observation 3:** The line starting with "`ctx, cancel :=`" creates a context +object for `Connect()`. The `Connect()` call will return an error when a +timeout expires before the connection is established. + +**Observation 4:** The line starting with "`dialer :=`" creates dialer for +`Connect()`. This structure contains fields required to establish a connection. + +**Observation 5:** The line starting with "`opts :=`" sets up the options for +`Connect()`. In this example, the structure contains only a single value, the +timeout. The structure may also contain other settings, see more in +[documentation][godoc-opts-url] for the "`Opts`" structure. + +**Observation 6:** The line containing "`tarantool.Connect`" is essential for +starting a session. There are three parameters: + +* a context, +* the dialer that was set up earlier, +* the option structure that was set up earlier. + +There will be only one attempt to connect. If multiple attempts needed, +"`tarantool.Connect`" could be placed inside the loop with some timeout +between each try. Example could be found in the [example_test](./example_test.go), +name - `ExampleConnect_reconnects`. -// #=> Connecting to 127.0.0.1:3013 ... -// #=> Connected ... -// #=> Greeting ... Success -// #=> Version: Tarantool 1.6.2-34-ga53cf4a -// #=> -// #=> Insert -// #=> ERROR -// #=> Code 0 -// #=> Data [[12 Hello World Olga]] -// #=> ---- -// #=> Select -// #=> ERROR -// #=> Code 0 -// #=> Data [[12 Hello World Olga]] -// #=> ---- -// #=> Replace -// #=> ERROR -// #=> Code 0 -// #=> Data [[12 Hello Mars Anna]] -// #=> ---- -// #=> Select -// #=> ERROR -// #=> Code 0 -// #=> Data [[12 Hello Mars Anna]] -// #=> ---- -// #=> Update -// #=> ERROR -// #=> Code 0 -// #=> Data [[12 Hello Moon]] -// #=> ---- -// #=> Select -// #=> ERROR -// #=> Code 0 -// #=> Data [[12 Hello Moon]] -// #=> ---- -// #=> Delete -// #=> ERROR -// #=> Code 0 -// #=> Data [[12 Hello Moon]] -// #=> ---- -// #=> Call -// #=> ERROR Execute access denied for user 'guest' to function 'box.cfg()' -// #=> Code 13570 -// #=> Data [] -// #=> ---- +**Observation 7:** The `err` structure will be `nil` if there is no error, +otherwise it will have a description which can be retrieved with `err.Error()`. + +**Observation 8:** The `Insert` request, like almost all requests, is preceded +by the method `Do` of object `conn` which is the object that was returned +by `Connect()`. + +### Example with encrypting traffic + +For SSL-enabled connections, use `OpenSSLDialer` from the +[go-tlsdialer](https://github.com/tarantool/go-tlsdialer) package. + +Here is small example with importing the `go-tlsdialer` library and using the +`OpenSSLDialer`: + +```go +package tarantool + +import ( + "context" + "fmt" + "time" + + "github.com/tarantool/go-tarantool/v3" + _ "github.com/tarantool/go-tarantool/v3/datetime" + _ "github.com/tarantool/go-tarantool/v3/decimal" + _ "github.com/tarantool/go-tarantool/v3/uuid" + "github.com/tarantool/go-tlsdialer" +) + +func main() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + dialer := tlsdialer.OpenSSLDialer{ + Address: "127.0.0.1:3013", + User: "test", + Password: "test", + SslKeyFile: "testdata/localhost.key", + SslCertFile: "testdata/localhost.crt", + SslCaFile: "testdata/ca.crt", + } + opts := tarantool.Opts{ + Timeout: time.Second, + } + + conn, err := tarantool.Connect(ctx, dialer, opts) + if err != nil { + fmt.Println("Connection refused:", err) + return + } + + data, err := conn.Do( + tarantool.NewInsertRequest(999).Tuple([]interface{}{99999, "BB"})).Get() + if err != nil { + fmt.Println("Error:", err) + } else { + fmt.Println("Data:", data) + } +} ``` +Note that [traffic encryption](https://www.tarantool.io/en/doc/latest/enterprise/security/#encrypting-traffic) +is only available in Tarantool Enterprise Edition 2.10 or newer. + +## Migration guide + +You can review the changes between major versions in the +[migration guide](./MIGRATION.md). + +## Contributing + +See [the contributing guide](CONTRIBUTING.md) for detailed instructions on how +to get started with our project. + +## Alternative connectors + +There are two other connectors available from the open source community: + +* [viciious/go-tarantool](https://github.com/viciious/go-tarantool), +* [FZambia/tarantool](https://github.com/FZambia/tarantool). + +See feature comparison in the [documentation][tarantool-doc-connectors-comparison]. + +[tarantool-site]: https://tarantool.io/ +[godoc-badge]: https://pkg.go.dev/badge/github.com/tarantool/go-tarantool/v3.svg +[godoc-url]: https://pkg.go.dev/github.com/tarantool/go-tarantool/v3 +[actions-badge]: https://github.com/tarantool/go-tarantool/actions/workflows/testing.yml/badge.svg +[actions-url]: https://github.com/tarantool/go-tarantool/actions/workflows/testing.yml +[coverage-badge]: https://coveralls.io/repos/github/tarantool/go-tarantool/badge.svg?branch=master +[coverage-url]: https://coveralls.io/github/tarantool/go-tarantool?branch=master +[telegram-badge]: https://img.shields.io/badge/Telegram-join%20chat-blue.svg +[telegram-url]: http://telegram.me/tarantool +[discussions-badge]: https://img.shields.io/github/discussions/tarantool/tarantool +[discussions-url]: https://github.com/tarantool/tarantool/discussions +[stackoverflow-badge]: https://img.shields.io/badge/stackoverflow-tarantool-orange.svg +[stackoverflow-url]: https://stackoverflow.com/questions/tagged/tarantool +[golang-dl]: https://go.dev/dl/ +[go-tarantool]: https://github.com/tarantool/go-tarantool +[tarantool-doc-data-model-url]: https://www.tarantool.io/en/doc/latest/book/box/data_model/ +[tarantool-doc-box-space-url]: https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_space/ +[godoc-opts-url]: https://pkg.go.dev/github.com/tarantool/go-tarantool/v3#Opts +[tarantool-doc-connectors-comparison]: https://www.tarantool.io/en/doc/latest/book/connectors/#go-feature-comparison diff --git a/arrow/arrow.go b/arrow/arrow.go new file mode 100644 index 000000000..a7767c459 --- /dev/null +++ b/arrow/arrow.go @@ -0,0 +1,69 @@ +package arrow + +import ( + "fmt" + "reflect" + + "github.com/vmihailenco/msgpack/v5" +) + +//go:generate go tool gentypes -ext-code 8 Arrow + +// Arrow MessagePack extension type. +const arrowExtId = 8 + +// Arrow struct wraps a raw arrow data buffer. +type Arrow struct { + data []byte +} + +// MakeArrow returns a new arrow.Arrow object that contains +// wrapped a raw arrow data buffer. +func MakeArrow(arrow []byte) (Arrow, error) { + return Arrow{arrow}, nil +} + +// Raw returns a []byte that contains Arrow raw data. +func (a Arrow) Raw() []byte { + return a.data +} + +// MarshalMsgpack implements a custom msgpack marshaler for extension type. +func (a Arrow) MarshalMsgpack() ([]byte, error) { + return a.data, nil +} + +// UnmarshalMsgpack implements a custom msgpack unmarshaler for extension type. +func (a *Arrow) UnmarshalMsgpack(data []byte) error { + a.data = data + return nil +} + +func arrowDecoder(d *msgpack.Decoder, v reflect.Value, extLen int) error { + arrow := Arrow{ + data: make([]byte, extLen), + } + n, err := d.Buffered().Read(arrow.data) + if err != nil { + return fmt.Errorf("arrowDecoder: can't read bytes on Arrow decode: %w", err) + } + if n < extLen || n != len(arrow.data) { + return fmt.Errorf("arrowDecoder: unexpected end of stream after %d Arrow bytes", n) + } + + v.Set(reflect.ValueOf(arrow)) + return nil +} + +func arrowEncoder(e *msgpack.Encoder, v reflect.Value) ([]byte, error) { + arr, ok := v.Interface().(Arrow) + if !ok { + return []byte{}, fmt.Errorf("arrowEncoder: not an Arrow type") + } + return arr.data, nil +} + +func init() { + msgpack.RegisterExtDecoder(arrowExtId, Arrow{}, arrowDecoder) + msgpack.RegisterExtEncoder(arrowExtId, Arrow{}, arrowEncoder) +} diff --git a/arrow/arrow_gen.go b/arrow/arrow_gen.go new file mode 100644 index 000000000..c86c72778 --- /dev/null +++ b/arrow/arrow_gen.go @@ -0,0 +1,241 @@ +// Code generated by github.com/tarantool/go-option; DO NOT EDIT. + +package arrow + +import ( + "fmt" + + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" + + "github.com/tarantool/go-option" +) + +// OptionalArrow represents an optional value of type Arrow. +// It can either hold a valid Arrow (IsSome == true) or be empty (IsZero == true). +type OptionalArrow struct { + value Arrow + exists bool +} + +// SomeOptionalArrow creates an optional OptionalArrow with the given Arrow value. +// The returned OptionalArrow will have IsSome() == true and IsZero() == false. +func SomeOptionalArrow(value Arrow) OptionalArrow { + return OptionalArrow{ + value: value, + exists: true, + } +} + +// NoneOptionalArrow creates an empty optional OptionalArrow value. +// The returned OptionalArrow will have IsSome() == false and IsZero() == true. +// +// Example: +// +// o := NoneOptionalArrow() +// if o.IsZero() { +// fmt.Println("value is absent") +// } +func NoneOptionalArrow() OptionalArrow { + return OptionalArrow{} +} + +func (o OptionalArrow) newEncodeError(err error) error { + if err == nil { + return nil + } + return &option.EncodeError{ + Type: "OptionalArrow", + Parent: err, + } +} + +func (o OptionalArrow) newDecodeError(err error) error { + if err == nil { + return nil + } + + return &option.DecodeError{ + Type: "OptionalArrow", + Parent: err, + } +} + +// IsSome returns true if the OptionalArrow contains a value. +// This indicates the value is explicitly set (not None). +func (o OptionalArrow) IsSome() bool { + return o.exists +} + +// IsZero returns true if the OptionalArrow does not contain a value. +// Equivalent to !IsSome(). Useful for consistency with types where +// zero value (e.g. 0, false, zero struct) is valid and needs to be distinguished. +func (o OptionalArrow) IsZero() bool { + return !o.exists +} + +// IsNil is an alias for IsZero. +// +// This method is provided for compatibility with the msgpack Encoder interface. +func (o OptionalArrow) IsNil() bool { + return o.IsZero() +} + +// Get returns the stored value and a boolean flag indicating its presence. +// If the value is present, returns (value, true). +// If the value is absent, returns (zero value of Arrow, false). +// +// Recommended usage: +// +// if value, ok := o.Get(); ok { +// // use value +// } +func (o OptionalArrow) Get() (Arrow, bool) { + return o.value, o.exists +} + +// MustGet returns the stored value if it is present. +// Panics if the value is absent (i.e., IsZero() == true). +// +// Use with caution — only when you are certain the value exists. +// +// Panics with: "optional value is not set" if no value is set. +func (o OptionalArrow) MustGet() Arrow { + if !o.exists { + panic("optional value is not set") + } + + return o.value +} + +// Unwrap returns the stored value regardless of presence. +// If no value is set, returns the zero value for Arrow. +// +// Warning: Does not check presence. Use IsSome() before calling if you need +// to distinguish between absent value and explicit zero value. +func (o OptionalArrow) Unwrap() Arrow { + return o.value +} + +// UnwrapOr returns the stored value if present. +// Otherwise, returns the provided default value. +// +// Example: +// +// o := NoneOptionalArrow() +// v := o.UnwrapOr(someDefaultOptionalArrow) +func (o OptionalArrow) UnwrapOr(defaultValue Arrow) Arrow { + if o.exists { + return o.value + } + + return defaultValue +} + +// UnwrapOrElse returns the stored value if present. +// Otherwise, calls the provided function and returns its result. +// Useful when the default value requires computation or side effects. +// +// Example: +// +// o := NoneOptionalArrow() +// v := o.UnwrapOrElse(func() Arrow { return computeDefault() }) +func (o OptionalArrow) UnwrapOrElse(defaultValue func() Arrow) Arrow { + if o.exists { + return o.value + } + + return defaultValue() +} + +func (o OptionalArrow) encodeValue(encoder *msgpack.Encoder) error { + value, err := o.value.MarshalMsgpack() + if err != nil { + return err + } + + err = encoder.EncodeExtHeader(8, len(value)) + if err != nil { + return err + } + + _, err = encoder.Writer().Write(value) + if err != nil { + return err + } + + return nil +} + +// EncodeMsgpack encodes the OptionalArrow value using MessagePack format. +// - If the value is present, it is encoded as Arrow. +// - If the value is absent (None), it is encoded as nil. +// +// Returns an error if encoding fails. +func (o OptionalArrow) EncodeMsgpack(encoder *msgpack.Encoder) error { + if o.exists { + return o.newEncodeError(o.encodeValue(encoder)) + } + + return o.newEncodeError(encoder.EncodeNil()) +} + +func (o *OptionalArrow) decodeValue(decoder *msgpack.Decoder) error { + tp, length, err := decoder.DecodeExtHeader() + switch { + case err != nil: + return o.newDecodeError(err) + case tp != 8: + return o.newDecodeError(fmt.Errorf("invalid extension code: %d", tp)) + } + + a := make([]byte, length) + if err := decoder.ReadFull(a); err != nil { + return o.newDecodeError(err) + } + + if err := o.value.UnmarshalMsgpack(a); err != nil { + return o.newDecodeError(err) + } + + o.exists = true + return nil +} + +func (o *OptionalArrow) checkCode(code byte) bool { + return msgpcode.IsExt(code) +} + +// DecodeMsgpack decodes a OptionalArrow value from MessagePack format. +// Supports two input types: +// - nil: interpreted as no value (NoneOptionalArrow) +// - Arrow: interpreted as a present value (SomeOptionalArrow) +// +// Returns an error if the input type is unsupported or decoding fails. +// +// After successful decoding: +// - on nil: exists = false, value = default zero value +// - on Arrow: exists = true, value = decoded value +func (o *OptionalArrow) DecodeMsgpack(decoder *msgpack.Decoder) error { + code, err := decoder.PeekCode() + if err != nil { + return o.newDecodeError(err) + } + + switch { + case code == msgpcode.Nil: + o.exists = false + + return o.newDecodeError(decoder.Skip()) + case o.checkCode(code): + err := o.decodeValue(decoder) + if err != nil { + return o.newDecodeError(err) + } + o.exists = true + + return err + default: + return o.newDecodeError(fmt.Errorf("unexpected code: %d", code)) + } +} diff --git a/arrow/arrow_gen_test.go b/arrow/arrow_gen_test.go new file mode 100644 index 000000000..d990499f4 --- /dev/null +++ b/arrow/arrow_gen_test.go @@ -0,0 +1,124 @@ +package arrow + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/vmihailenco/msgpack/v5" +) + +func TestSomeOptionalArrow(t *testing.T) { + val, err := MakeArrow([]byte{1, 2, 3}) + assert.NoError(t, err) + opt := SomeOptionalArrow(val) + + assert.True(t, opt.IsSome()) + assert.False(t, opt.IsZero()) + + v, ok := opt.Get() + assert.True(t, ok) + assert.Equal(t, val, v) +} + +func TestNoneOptionalArrow(t *testing.T) { + opt := NoneOptionalArrow() + + assert.False(t, opt.IsSome()) + assert.True(t, opt.IsZero()) + + _, ok := opt.Get() + assert.False(t, ok) +} + +func TestOptionalArrow_MustGet(t *testing.T) { + val, err := MakeArrow([]byte{1, 2, 3}) + assert.NoError(t, err) + optSome := SomeOptionalArrow(val) + optNone := NoneOptionalArrow() + + assert.Equal(t, val, optSome.MustGet()) + assert.Panics(t, func() { optNone.MustGet() }) +} + +func TestOptionalArrow_Unwrap(t *testing.T) { + val, err := MakeArrow([]byte{1, 2, 3}) + assert.NoError(t, err) + optSome := SomeOptionalArrow(val) + optNone := NoneOptionalArrow() + + assert.Equal(t, val, optSome.Unwrap()) + assert.Equal(t, Arrow{}, optNone.Unwrap()) +} + +func TestOptionalArrow_UnwrapOr(t *testing.T) { + val, err := MakeArrow([]byte{1, 2, 3}) + assert.NoError(t, err) + def, err := MakeArrow([]byte{4, 5, 6}) + assert.NoError(t, err) + optSome := SomeOptionalArrow(val) + optNone := NoneOptionalArrow() + + assert.Equal(t, val, optSome.UnwrapOr(def)) + assert.Equal(t, def, optNone.UnwrapOr(def)) +} + +func TestOptionalArrow_UnwrapOrElse(t *testing.T) { + val, err := MakeArrow([]byte{1, 2, 3}) + assert.NoError(t, err) + def, err := MakeArrow([]byte{4, 5, 6}) + assert.NoError(t, err) + optSome := SomeOptionalArrow(val) + optNone := NoneOptionalArrow() + + assert.Equal(t, val, optSome.UnwrapOrElse(func() Arrow { return def })) + assert.Equal(t, def, optNone.UnwrapOrElse(func() Arrow { return def })) +} + +func TestOptionalArrow_EncodeDecodeMsgpack_Some(t *testing.T) { + val, err := MakeArrow([]byte{1, 2, 3}) + assert.NoError(t, err) + some := SomeOptionalArrow(val) + + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err = enc.Encode(some) + assert.NoError(t, err) + + var decodedSome OptionalArrow + err = dec.Decode(&decodedSome) + assert.NoError(t, err) + assert.True(t, decodedSome.IsSome()) + assert.Equal(t, val, decodedSome.Unwrap()) +} + +func TestOptionalArrow_EncodeDecodeMsgpack_None(t *testing.T) { + none := NoneOptionalArrow() + + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err := enc.Encode(none) + assert.NoError(t, err) + + var decodedNone OptionalArrow + err = dec.Decode(&decodedNone) + assert.NoError(t, err) + assert.True(t, decodedNone.IsZero()) +} + +func TestOptionalArrow_EncodeDecodeMsgpack_InvalidType(t *testing.T) { + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err := enc.Encode(123) + assert.NoError(t, err) + + var decodedInvalid OptionalArrow + err = dec.Decode(&decodedInvalid) + assert.Error(t, err) +} diff --git a/arrow/arrow_test.go b/arrow/arrow_test.go new file mode 100644 index 000000000..c3f40090c --- /dev/null +++ b/arrow/arrow_test.go @@ -0,0 +1,101 @@ +package arrow_test + +import ( + "bytes" + "encoding/hex" + "testing" + + "github.com/stretchr/testify/require" + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3/arrow" +) + +var longArrow, _ = hex.DecodeString("ffffffff70000000040000009effffff0400010004000000" + + "b6ffffff0c00000004000000000000000100000004000000daffffff140000000202" + + "000004000000f0ffffff4000000001000000610000000600080004000c0010000400" + + "080009000c000c000c0000000400000008000a000c00040006000800ffffffff8800" + + "0000040000008affffff0400030010000000080000000000000000000000acffffff" + + "01000000000000003400000008000000000000000200000000000000000000000000" + + "00000000000000000000000000000800000000000000000000000100000001000000" + + "0000000000000000000000000a00140004000c0010000c0014000400060008000c00" + + "00000000000000000000") + +var tests = []struct { + name string + arr []byte + enc []byte +}{ + { + "abc", + []byte{'a', 'b', 'c'}, + []byte{0xc7, 0x3, 0x8, 'a', 'b', 'c'}, + }, + { + "empty", + []byte{}, + []byte{0xc7, 0x0, 0x8}, + }, + { + "one", + []byte{1}, + []byte{0xd4, 0x8, 0x1}, + }, + { + "long", + longArrow, + []byte{ + 0xc8, 0x1, 0x10, 0x8, 0xff, 0xff, 0xff, 0xff, 0x70, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0, + 0x0, 0x9e, 0xff, 0xff, 0xff, 0x4, 0x0, 0x1, 0x0, 0x4, 0x0, 0x0, 0x0, 0xb6, 0xff, 0xff, + 0xff, 0xc, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, + 0x4, 0x0, 0x0, 0x0, 0xda, 0xff, 0xff, 0xff, 0x14, 0x0, 0x0, 0x0, 0x2, 0x2, 0x0, 0x0, + 0x4, 0x0, 0x0, 0x0, 0xf0, 0xff, 0xff, 0xff, 0x40, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, + 0x61, 0x0, 0x0, 0x0, 0x6, 0x0, 0x8, 0x0, 0x4, 0x0, 0xc, 0x0, 0x10, 0x0, 0x4, 0x0, 0x8, + 0x0, 0x9, 0x0, 0xc, 0x0, 0xc, 0x0, 0xc, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0, 0x0, 0x8, 0x0, + 0xa, 0x0, 0xc, 0x0, 0x4, 0x0, 0x6, 0x0, 0x8, 0x0, 0xff, 0xff, 0xff, 0xff, 0x88, 0x0, + 0x0, 0x0, 0x4, 0x0, 0x0, 0x0, 0x8a, 0xff, 0xff, 0xff, 0x4, 0x0, 0x3, 0x0, 0x10, 0x0, + 0x0, 0x0, 0x8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xac, 0xff, 0xff, + 0xff, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x34, 0x0, 0x0, 0x0, 0x8, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x8, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x0, 0x14, 0x0, + 0x4, 0x0, 0xc, 0x0, 0x10, 0x0, 0xc, 0x0, 0x14, 0x0, 0x4, 0x0, 0x6, 0x0, 0x8, 0x0, 0xc, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + }, + }, +} + +func TestEncodeArrow(t *testing.T) { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := bytes.NewBuffer([]byte{}) + enc := msgpack.NewEncoder(buf) + + arr, err := arrow.MakeArrow(tt.arr) + require.NoError(t, err) + + err = enc.Encode(arr) + require.NoError(t, err) + + require.Equal(t, tt.enc, buf.Bytes()) + }) + + } +} + +func TestDecodeArrow(t *testing.T) { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + buf := bytes.NewBuffer(tt.enc) + dec := msgpack.NewDecoder(buf) + + var arr arrow.Arrow + err := dec.Decode(&arr) + require.NoError(t, err) + + require.Equal(t, tt.arr, arr.Raw()) + }) + } +} diff --git a/arrow/example_test.go b/arrow/example_test.go new file mode 100644 index 000000000..3510a777d --- /dev/null +++ b/arrow/example_test.go @@ -0,0 +1,61 @@ +// Run Tarantool Enterprise Edition instance before example execution: +// +// Terminal 1: +// $ cd arrow +// $ TEST_TNT_WORK_DIR=$(mktemp -d -t 'tarantool.XXX') tarantool testdata/config-memcs.lua +// +// Terminal 2: +// $ go test -v example_test.go +package arrow_test + +import ( + "context" + "encoding/hex" + "fmt" + "log" + "time" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/arrow" +) + +var arrowBinData, _ = hex.DecodeString("ffffffff70000000040000009effffff0400010004000000" + + "b6ffffff0c00000004000000000000000100000004000000daffffff140000000202" + + "000004000000f0ffffff4000000001000000610000000600080004000c0010000400" + + "080009000c000c000c0000000400000008000a000c00040006000800ffffffff8800" + + "0000040000008affffff0400030010000000080000000000000000000000acffffff" + + "01000000000000003400000008000000000000000200000000000000000000000000" + + "00000000000000000000000000000800000000000000000000000100000001000000" + + "0000000000000000000000000a00140004000c0010000c0014000400060008000c00" + + "00000000000000000000") + +func Example() { + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3013", + User: "test", + Password: "test", + } + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + client, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + cancel() + if err != nil { + log.Fatalf("Failed to connect: %s", err) + } + + arr, err := arrow.MakeArrow(arrowBinData) + if err != nil { + log.Fatalf("Failed prepare Arrow data: %s", err) + } + + req := arrow.NewInsertRequest("testArrow", arr) + + resp, err := client.Do(req).Get() + if err != nil { + log.Fatalf("Failed insert Arrow: %s", err) + } + if len(resp) > 0 { + log.Fatalf("Unexpected response") + } else { + fmt.Printf("Batch arrow inserted") + } +} diff --git a/arrow/request.go b/arrow/request.go new file mode 100644 index 000000000..82b55f399 --- /dev/null +++ b/arrow/request.go @@ -0,0 +1,82 @@ +package arrow + +import ( + "context" + "io" + + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// InsertRequest helps you to create an insert request object for execution +// by a Connection. +type InsertRequest struct { + arrow Arrow + space interface{} + ctx context.Context +} + +// NewInsertRequest returns a new InsertRequest. +func NewInsertRequest(space interface{}, arrow Arrow) *InsertRequest { + return &InsertRequest{ + space: space, + arrow: arrow, + } +} + +// Type returns a IPROTO_INSERT_ARROW type for the request. +func (r *InsertRequest) Type() iproto.Type { + return iproto.IPROTO_INSERT_ARROW +} + +// Async returns false to the request return a response. +func (r *InsertRequest) Async() bool { + return false +} + +// Ctx returns a context of the request. +func (r *InsertRequest) Ctx() context.Context { + return r.ctx +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (r *InsertRequest) Context(ctx context.Context) *InsertRequest { + r.ctx = ctx + return r +} + +// Arrow sets the arrow for insertion the insert arrow request. +// Note: default value is nil. +func (r *InsertRequest) Arrow(arrow Arrow) *InsertRequest { + r.arrow = arrow + return r +} + +// Body fills an msgpack.Encoder with the insert arrow request body. +func (r *InsertRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + if err := enc.EncodeMapLen(2); err != nil { + return err + } + if err := tarantool.EncodeSpace(res, enc, r.space); err != nil { + return err + } + if err := enc.EncodeUint(uint64(iproto.IPROTO_ARROW)); err != nil { + return err + } + return enc.Encode(r.arrow) +} + +// Response creates a response for the InsertRequest. +func (r *InsertRequest) Response( + header tarantool.Header, + body io.Reader, +) (tarantool.Response, error) { + return tarantool.DecodeBaseResponse(header, body) +} diff --git a/arrow/request_test.go b/arrow/request_test.go new file mode 100644 index 000000000..251b2b31d --- /dev/null +++ b/arrow/request_test.go @@ -0,0 +1,141 @@ +package arrow_test + +import ( + "bytes" + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/arrow" +) + +const validSpace uint32 = 1 // Any valid value != default. + +func TestInsertRequestType(t *testing.T) { + request := arrow.NewInsertRequest(validSpace, arrow.Arrow{}) + require.Equal(t, iproto.IPROTO_INSERT_ARROW, request.Type()) +} + +func TestInsertRequestAsync(t *testing.T) { + request := arrow.NewInsertRequest(validSpace, arrow.Arrow{}) + require.Equal(t, false, request.Async()) +} + +func TestInsertRequestCtx_default(t *testing.T) { + request := arrow.NewInsertRequest(validSpace, arrow.Arrow{}) + require.Equal(t, nil, request.Ctx()) +} + +func TestInsertRequestCtx_setter(t *testing.T) { + ctx := context.Background() + request := arrow.NewInsertRequest(validSpace, arrow.Arrow{}).Context(ctx) + require.Equal(t, ctx, request.Ctx()) +} + +func TestResponseDecode(t *testing.T) { + header := tarantool.Header{} + buf := bytes.NewBuffer([]byte{}) + enc := msgpack.NewEncoder(buf) + + enc.EncodeMapLen(1) + enc.EncodeUint8(uint8(iproto.IPROTO_DATA)) + enc.Encode([]interface{}{'v', '2'}) + + request := arrow.NewInsertRequest(validSpace, arrow.Arrow{}) + resp, err := request.Response(header, bytes.NewBuffer(buf.Bytes())) + require.NoError(t, err) + require.Equal(t, header, resp.Header()) + + decodedInterface, err := resp.Decode() + require.NoError(t, err) + require.Equal(t, []interface{}{'v', '2'}, decodedInterface) +} + +func TestResponseDecodeTyped(t *testing.T) { + header := tarantool.Header{} + buf := bytes.NewBuffer([]byte{}) + enc := msgpack.NewEncoder(buf) + + enc.EncodeMapLen(1) + enc.EncodeUint8(uint8(iproto.IPROTO_DATA)) + enc.EncodeBytes([]byte{'v', '2'}) + + request := arrow.NewInsertRequest(validSpace, arrow.Arrow{}) + resp, err := request.Response(header, bytes.NewBuffer(buf.Bytes())) + require.NoError(t, err) + require.Equal(t, header, resp.Header()) + + var decoded []byte + err = resp.DecodeTyped(&decoded) + require.NoError(t, err) + require.Equal(t, []byte{'v', '2'}, decoded) +} + +type stubSchemeResolver struct { + space interface{} +} + +func (r stubSchemeResolver) ResolveSpace(s interface{}) (uint32, error) { + if id, ok := r.space.(uint32); ok { + return id, nil + } + if _, ok := r.space.(string); ok { + return 0, nil + } + return 0, fmt.Errorf("stub error message: %v", r.space) +} + +func (stubSchemeResolver) ResolveIndex(i interface{}, spaceNo uint32) (uint32, error) { + return 0, nil +} + +func (r stubSchemeResolver) NamesUseSupported() bool { + _, ok := r.space.(string) + return ok +} + +func TestInsertRequestDefaultValues(t *testing.T) { + buf := bytes.NewBuffer([]byte{}) + enc := msgpack.NewEncoder(buf) + + resolver := stubSchemeResolver{validSpace} + req := arrow.NewInsertRequest(resolver.space, arrow.Arrow{}) + err := req.Body(&resolver, enc) + require.NoError(t, err) + + require.Equal(t, []byte{0x82, 0x10, 0x1, 0x36, 0xc7, 0x0, 0x8}, buf.Bytes()) +} + +func TestInsertRequestSpaceByName(t *testing.T) { + buf := bytes.NewBuffer([]byte{}) + enc := msgpack.NewEncoder(buf) + + resolver := stubSchemeResolver{"valid"} + req := arrow.NewInsertRequest(resolver.space, arrow.Arrow{}) + err := req.Body(&resolver, enc) + require.NoError(t, err) + + require.Equal(t, + []byte{0x82, 0x5e, 0xa5, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x36, 0xc7, 0x0, 0x8}, + buf.Bytes()) +} + +func TestInsertRequestSetters(t *testing.T) { + buf := bytes.NewBuffer([]byte{}) + enc := msgpack.NewEncoder(buf) + + arr, err := arrow.MakeArrow([]byte{'a', 'b', 'c'}) + require.NoError(t, err) + + resolver := stubSchemeResolver{validSpace} + req := arrow.NewInsertRequest(resolver.space, arr) + err = req.Body(&resolver, enc) + require.NoError(t, err) + + require.Equal(t, []byte{0x82, 0x10, 0x1, 0x36, 0xc7, 0x3, 0x8, 'a', 'b', 'c'}, buf.Bytes()) +} diff --git a/arrow/tarantool_test.go b/arrow/tarantool_test.go new file mode 100644 index 000000000..cc2ad552e --- /dev/null +++ b/arrow/tarantool_test.go @@ -0,0 +1,121 @@ +package arrow_test + +import ( + "encoding/hex" + "log" + "os" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/tarantool/go-iproto" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/arrow" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +var isArrowSupported = false + +var server = "127.0.0.1:3013" +var dialer = tarantool.NetDialer{ + Address: server, + User: "test", + Password: "test", +} +var space = "testArrow" + +var opts = tarantool.Opts{ + Timeout: 5 * time.Second, +} + +// TestInsert uses Arrow sequence from Tarantool's test. +// See: https://github.com/tarantool/tarantool/blob/d628b71bc537a75b69c253f45ec790462cf1a5cd/test/box-luatest/gh_10508_iproto_insert_arrow_test.lua#L56 +func TestInsert_invalid(t *testing.T) { + arrows := []struct { + arrow string + expected iproto.Error + }{ + { + "", + iproto.ER_INVALID_MSGPACK, + }, + { + "00", + iproto.ER_INVALID_MSGPACK, + }, + { + "ffffffff70000000040000009effffff0400010004000000" + + "b6ffffff0c00000004000000000000000100000004000000daffffff140000000202" + + "000004000000f0ffffff4000000001000000610000000600080004000c0010000400" + + "080009000c000c000c0000000400000008000a000c00040006000800ffffffff8800" + + "0000040000008affffff0400030010000000080000000000000000000000acffffff" + + "01000000000000003400000008000000000000000200000000000000000000000000" + + "00000000000000000000000000000800000000000000000000000100000001000000" + + "0000000000000000000000000a00140004000c0010000c0014000400060008000c00" + + "00000000000000000000", + iproto.ER_UNSUPPORTED, + }, + } + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + for i, a := range arrows { + t.Run(strconv.Itoa(i), func(t *testing.T) { + data, err := hex.DecodeString(a.arrow) + require.NoError(t, err) + + arr, err := arrow.MakeArrow(data) + require.NoError(t, err) + + req := arrow.NewInsertRequest(space, arr) + _, err = conn.Do(req).Get() + ttErr := err.(tarantool.Error) + + require.Equal(t, a.expected, ttErr.Code) + }) + } + +} + +// runTestMain is a body of TestMain function +// (see https://pkg.go.dev/testing#hdr-Main). +// Using defer + os.Exit is not works so TestMain body +// is a separate function, see +// https://stackoverflow.com/questions/27629380/how-to-exit-a-go-program-honoring-deferred-calls +func runTestMain(m *testing.M) int { + isLess, err := test_helpers.IsTarantoolVersionLess(3, 3, 0) + if err != nil { + log.Fatalf("Failed to extract Tarantool version: %s", err) + } + isArrowSupported = !isLess + + if !isArrowSupported { + log.Println("Skipping insert Arrow tests...") + return 0 + } + + instance, err := test_helpers.StartTarantool(test_helpers.StartOpts{ + Dialer: dialer, + InitScript: "testdata/config-memtx.lua", + Listen: server, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 10, + RetryTimeout: 500 * time.Millisecond, + }) + defer test_helpers.StopTarantoolWithCleanup(instance) + + if err != nil { + log.Printf("Failed to prepare test Tarantool: %s", err) + return 1 + } + + return m.Run() +} + +func TestMain(m *testing.M) { + code := runTestMain(m) + os.Exit(code) +} diff --git a/arrow/testdata/config-memcs.lua b/arrow/testdata/config-memcs.lua new file mode 100644 index 000000000..ae1cc430e --- /dev/null +++ b/arrow/testdata/config-memcs.lua @@ -0,0 +1,31 @@ +-- Do not set listen for now so connector won't be +-- able to send requests until everything is configured. +box.cfg { + work_dir = os.getenv("TEST_TNT_WORK_DIR") +} + +box.schema.user.create('test', { + password = 'test', + if_not_exists = true +}) +box.schema.user.grant('test', 'execute', 'universe', nil, { + if_not_exists = true +}) + +local s = box.schema.space.create('testArrow', { + engine = 'memcs', + field_count = 1, + format = {{'a', 'uint64'}}, + if_not_exists = true +}) +s:create_index('primary') +s:truncate() + +box.schema.user.grant('test', 'read,write', 'space', 'testArrow', { + if_not_exists = true +}) + +-- Set listen only when every other thing is configured. +box.cfg { + listen = 3013 +} diff --git a/arrow/testdata/config-memtx.lua b/arrow/testdata/config-memtx.lua new file mode 100644 index 000000000..92c0af096 --- /dev/null +++ b/arrow/testdata/config-memtx.lua @@ -0,0 +1,35 @@ +-- Do not set listen for now so connector won't be +-- able to send requests until everything is configured. +box.cfg { + work_dir = os.getenv("TEST_TNT_WORK_DIR") +} + +box.schema.user.create('test', { + password = 'test', + if_not_exists = true +}) +box.schema.user.grant('test', 'execute', 'universe', nil, { + if_not_exists = true +}) + +local s = box.schema.space.create('testArrow', { + if_not_exists = true +}) +s:create_index('primary', { + type = 'tree', + parts = {{ + field = 1, + type = 'integer' + }}, + if_not_exists = true +}) +s:truncate() + +box.schema.user.grant('test', 'read,write', 'space', 'testArrow', { + if_not_exists = true +}) + +-- Set listen only when every other thing is configured. +box.cfg { + listen = os.getenv("TEST_TNT_LISTEN") +} diff --git a/auth.go b/auth.go new file mode 100644 index 000000000..2e5ddc4c4 --- /dev/null +++ b/auth.go @@ -0,0 +1,80 @@ +package tarantool + +import ( + "crypto/sha1" + "encoding/base64" + "fmt" +) + +const ( + chapSha1 = "chap-sha1" + papSha256 = "pap-sha256" +) + +// Auth is used as a parameter to set up an authentication method. +type Auth int + +const ( + // AutoAuth does not force any authentication method. A method will be + // selected automatically (a value from IPROTO_ID response or + // ChapSha1Auth). + AutoAuth Auth = iota + // ChapSha1Auth forces chap-sha1 authentication method. The method is + // available both in the Tarantool Community Edition (CE) and the + // Tarantool Enterprise Edition (EE) + ChapSha1Auth + // PapSha256Auth forces pap-sha256 authentication method. The method is + // available only for the Tarantool Enterprise Edition (EE) with + // SSL transport. + PapSha256Auth +) + +// String returns a string representation of an authentication method. +func (a Auth) String() string { + switch a { + case AutoAuth: + return "auto" + case ChapSha1Auth: + return chapSha1 + case PapSha256Auth: + return papSha256 + default: + return fmt.Sprintf("unknown auth type (code %d)", a) + } +} + +func scramble(encodedSalt, pass string) (scramble []byte, err error) { + /* ================================================================== + According to: http://tarantool.org/doc/dev_guide/box-protocol.html + + salt = base64_decode(encodedSalt); + step1 = sha1(password); + step2 = sha1(step1); + step3 = sha1(salt, step2); + scramble = xor(step1, step3); + return scramble; + + ===================================================================== */ + scrambleSize := sha1.Size // == 20 + + salt, err := base64.StdEncoding.DecodeString(encodedSalt) + if err != nil { + return + } + step1 := sha1.Sum([]byte(pass)) + step2 := sha1.Sum(step1[0:]) + hash := sha1.New() // May be create it once per connection? + hash.Write(salt[0:scrambleSize]) + hash.Write(step2[0:]) + step3 := hash.Sum(nil) + + return xor(step1[0:], step3[0:], scrambleSize), nil +} + +func xor(left, right []byte, size int) []byte { + result := make([]byte, size) + for i := 0; i < size; i++ { + result[i] = left[i] ^ right[i] + } + return result +} diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 000000000..a9a0b34e9 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,28 @@ +package tarantool_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + . "github.com/tarantool/go-tarantool/v3" +) + +func TestAuth_String(t *testing.T) { + unknownId := int(PapSha256Auth) + 1 + tests := []struct { + auth Auth + expected string + }{ + {AutoAuth, "auto"}, + {ChapSha1Auth, "chap-sha1"}, + {PapSha256Auth, "pap-sha256"}, + {Auth(unknownId), fmt.Sprintf("unknown auth type (code %d)", unknownId)}, + } + + for _, tc := range tests { + t.Run(tc.expected, func(t *testing.T) { + assert.Equal(t, tc.auth.String(), tc.expected) + }) + } +} diff --git a/box/box.go b/box/box.go new file mode 100644 index 000000000..4341768cf --- /dev/null +++ b/box/box.go @@ -0,0 +1,58 @@ +package box + +import ( + "errors" + + "github.com/tarantool/go-tarantool/v3" +) + +// Box is a helper that wraps box.* requests. +// It holds a connection to the Tarantool instance via the Doer interface. +type Box struct { + conn tarantool.Doer // Connection interface for interacting with Tarantool. +} + +// New returns a new instance of the box structure, which implements the Box interface. +func New(conn tarantool.Doer) (*Box, error) { + if conn == nil { + return nil, errors.New("tarantool connection cannot be nil") + } + + return &Box{ + conn: conn, // Assigns the provided Tarantool connection. + }, nil +} + +// MustNew returns a new instance of the box structure, which implements the Box interface. +// It panics if conn == nil. +func MustNew(conn tarantool.Doer) *Box { + b, err := New(conn) + if err != nil { + panic(err) + } + return b +} + +// Schema returns a new Schema instance, providing access to schema-related operations. +// It uses the connection from the Box instance to communicate with Tarantool. +func (b *Box) Schema() *Schema { + return newSchema(b.conn) +} + +// Info retrieves the current information of the Tarantool instance. +// It calls the "box.info" function and parses the result into the Info structure. +func (b *Box) Info() (Info, error) { + var infoResp InfoResponse + + // Call "box.info" to get instance information from Tarantool. + fut := b.conn.Do(NewInfoRequest()) + + // Parse the result into the Info structure. + err := fut.GetTyped(&infoResp) + if err != nil { + return Info{}, err + } + + // Return the parsed info and any potential error. + return infoResp.Info, err +} diff --git a/box/box_test.go b/box/box_test.go new file mode 100644 index 000000000..57d0d5526 --- /dev/null +++ b/box/box_test.go @@ -0,0 +1,112 @@ +package box_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/tarantool/go-tarantool/v3/box" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +func TestNew(t *testing.T) { + t.Parallel() + + _, err := box.New(nil) + require.Error(t, err) +} + +func TestMustNew(t *testing.T) { + t.Parallel() + + // Create a box instance with a nil connection. This should lead to a panic. + require.Panics(t, func() { box.MustNew(nil) }) +} + +func TestMocked_BoxNew(t *testing.T) { + t.Parallel() + + mock := test_helpers.NewMockDoer(t, + test_helpers.NewMockResponse(t, "valid"), + ) + + b, err := box.New(&mock) + require.NoError(t, err) + require.NotNil(t, b) + + assert.Len(t, mock.Requests, 0) + b.Schema().User().Exists(box.NewInfoRequest().Ctx(), "") + require.Len(t, mock.Requests, 1) +} + +func TestMocked_BoxInfo(t *testing.T) { + t.Parallel() + + data := []interface{}{ + map[string]interface{}{ + "version": "1.0.0", + "id": nil, + "ro": false, + "uuid": "uuid", + "pid": 456, + "status": "status", + "lsn": 123, + "replication": nil, + }, + } + mock := test_helpers.NewMockDoer(t, + test_helpers.NewMockResponse(t, data), + ) + b := box.MustNew(&mock) + + info, err := b.Info() + require.NoError(t, err) + + assert.Equal(t, "1.0.0", info.Version) + assert.Equal(t, 456, info.PID) +} + +func TestMocked_BoxSchemaUserInfo(t *testing.T) { + t.Parallel() + + data := []interface{}{ + []interface{}{ + []interface{}{"read,write,execute", "universe", ""}, + }, + } + mock := test_helpers.NewMockDoer(t, + test_helpers.NewMockResponse(t, data), + ) + b := box.MustNew(&mock) + + privs, err := b.Schema().User().Info(context.Background(), "username") + require.NoError(t, err) + + assert.Equal(t, []box.Privilege{ + { + Permissions: []box.Permission{ + box.PermissionRead, + box.PermissionWrite, + box.PermissionExecute, + }, + Type: box.PrivilegeUniverse, + Name: "", + }, + }, privs) +} + +func TestMocked_BoxSessionSu(t *testing.T) { + t.Parallel() + + mock := test_helpers.NewMockDoer(t, + test_helpers.NewMockResponse(t, []interface{}{}), + errors.New("user not found or supplied credentials are invalid"), + ) + b := box.MustNew(&mock) + + err := b.Session().Su(context.Background(), "admin") + require.NoError(t, err) +} diff --git a/box/example_test.go b/box/example_test.go new file mode 100644 index 000000000..39474d8aa --- /dev/null +++ b/box/example_test.go @@ -0,0 +1,244 @@ +// Run Tarantool Common Edition before example execution: +// +// Terminal 1: +// $ cd box +// $ TEST_TNT_LISTEN=127.0.0.1:3013 tarantool testdata/config.lua +// +// Terminal 2: +// $ go test -v example_test.go + +package box_test + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/box" +) + +func ExampleBox_Info() { + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3013", + User: "test", + Password: "test", + } + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + client, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + cancel() + if err != nil { + log.Fatalf("Failed to connect: %s", err) + } + + // You can use Info Request type. + + fut := client.Do(box.NewInfoRequest()) + + resp := &box.InfoResponse{} + + err = fut.GetTyped(resp) + if err != nil { + log.Fatalf("Failed get box info: %s", err) + } + + // Or use simple Box implementation. + + b := box.MustNew(client) + + info, err := b.Info() + if err != nil { + log.Fatalf("Failed get box info: %s", err) + } + + if info.UUID != resp.Info.UUID { + log.Fatalf("Box info uuids are not equal") + } + + fmt.Printf("Box info uuids are equal\n") + fmt.Printf("Current box ro: %+v", resp.Info.RO) + // Output: + // Box info uuids are equal + // Current box ro: false +} + +func ExampleSchemaUser_Exists() { + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3013", + User: "test", + Password: "test", + } + ctx := context.Background() + + client, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + + if err != nil { + log.Fatalf("Failed to connect: %s", err) + } + + // You can use UserExistsRequest type and call it directly. + fut := client.Do(box.NewUserExistsRequest("user")) + + resp := &box.UserExistsResponse{} + + err = fut.GetTyped(resp) + if err != nil { + log.Fatalf("Failed get box schema user exists with error: %s", err) + } + + // Or use simple User implementation. + b := box.MustNew(client) + + exists, err := b.Schema().User().Exists(ctx, "user") + if err != nil { + log.Fatalf("Failed get box schema user exists with error: %s", err) + } + + if exists != resp.Exists { + log.Fatalf("Box schema users exists are not equal") + } + + fmt.Printf("Box schema users exists are equal\n") + fmt.Printf("Current exists state: %+v", exists) + // Output: + // Box schema users exists are equal + // Current exists state: false +} + +func ExampleSchemaUser_Create() { + // Connect to Tarantool. + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3013", + User: "test", + Password: "test", + } + ctx := context.Background() + + client, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + if err != nil { + log.Fatalf("Failed to connect: %s", err) + } + + // Create SchemaUser. + schemaUser := box.MustNew(client).Schema().User() + + // Create a new user. + username := "new_user" + options := box.UserCreateOptions{ + IfNotExists: true, + Password: "secure_password", + } + err = schemaUser.Create(ctx, username, options) + if err != nil { + log.Fatalf("Failed to create user: %s", err) + } + + fmt.Printf("User '%s' created successfully\n", username) + // Output: + // User 'new_user' created successfully +} + +func ExampleSchemaUser_Drop() { + // Connect to Tarantool. + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3013", + User: "test", + Password: "test", + } + ctx := context.Background() + + client, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + if err != nil { + log.Fatalf("Failed to connect: %s", err) + } + + // Create SchemaUser. + schemaUser := box.MustNew(client).Schema().User() + + // Drop an existing user. + username := "new_user" + options := box.UserDropOptions{ + IfExists: true, + } + err = schemaUser.Drop(ctx, username, options) + if err != nil { + log.Fatalf("Failed to drop user: %s", err) + } + + exists, err := schemaUser.Exists(ctx, username) + if err != nil { + log.Fatalf("Failed to get user exists: %s", err) + } + + fmt.Printf("User '%s' dropped successfully\n", username) + fmt.Printf("User '%s' exists status: %v \n", username, exists) + // Output: + // User 'new_user' dropped successfully + // User 'new_user' exists status: false +} + +func ExampleSchemaUser_Password() { + // Connect to Tarantool. + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3013", + User: "test", + Password: "test", + } + ctx := context.Background() + + client, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + if err != nil { + log.Fatalf("Failed to connect: %s", err) + } + + // Create SchemaUser. + schemaUser := box.MustNew(client).Schema().User() + + // Get the password hash. + password := "my-password" + passwordHash, err := schemaUser.Password(ctx, password) + if err != nil { + log.Fatalf("Failed to get password hash: %s", err) + } + + fmt.Printf("Password '%s' hash: %s", password, passwordHash) + // Output: + // Password 'my-password' hash: 3PHNAQGFWFo0KRfToxNgDXHj2i8= +} + +func ExampleSchemaUser_Info() { + // Connect to Tarantool. + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3013", + User: "test", + Password: "test", + } + ctx := context.Background() + + client, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + if err != nil { + log.Fatalf("Failed to connect: %s", err) + } + + // Create SchemaUser. + schemaUser := box.MustNew(client).Schema().User() + + info, err := schemaUser.Info(ctx, "test") + if err != nil { + log.Fatalf("Failed to get password hash: %s", err) + } + + hasSuper := false + for _, i := range info { + if i.Name == "super" && i.Type == box.PrivilegeRole { + hasSuper = true + } + } + + if hasSuper { + fmt.Printf("User have super privileges") + } + // Output: + // User have super privileges +} diff --git a/box/info.go b/box/info.go new file mode 100644 index 000000000..edd4894cd --- /dev/null +++ b/box/info.go @@ -0,0 +1,125 @@ +package box + +import ( + "fmt" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +var _ tarantool.Request = (*InfoRequest)(nil) + +// Info represents detailed information about the Tarantool instance. +// It includes version, node ID, read-only status, process ID, cluster information, and more. +type Info struct { + // The Version of the Tarantool instance. + Version string `msgpack:"version"` + // The node ID (nullable). + ID *int `msgpack:"id"` + // Read-only (RO) status of the instance. + RO bool `msgpack:"ro"` + // UUID - Unique identifier of the instance. + UUID string `msgpack:"uuid"` + // Process ID of the instance. + PID int `msgpack:"pid"` + // Status - Current status of the instance (e.g., running, unconfigured). + Status string `msgpack:"status"` + // LSN - Log sequence number of the instance. + LSN uint64 `msgpack:"lsn"` + // Replication - replication status. + Replication map[int]Replication `msgpack:"replication,omitempty"` +} + +// Replication section of box.info() is a table with statistics for all instances +// in the replica set that the current instance belongs to. +type Replication struct { + // ID is a short numeric identifier of instance n within the replica set. + ID int `msgpack:"id"` + // UUID - Unique identifier of the instance. + UUID string `msgpack:"uuid"` + // LSN - Log sequence number of the instance. + LSN uint64 `msgpack:"lsn"` + // Upstream - information about upstream. + Upstream Upstream `msgpack:"upstream,omitempty"` + // Downstream - information about downstream. + Downstream Downstream `msgpack:"downstream,omitempty"` +} + +// Upstream information. +type Upstream struct { + // Status is replication status of the connection with the instance. + Status string `msgpack:"status"` + // Idle is the time (in seconds) since the last event was received. + Idle float64 `msgpack:"idle"` + // Peer contains instance n’s URI. + Peer string `msgpack:"peer"` + // Lag is the time difference between the local time of instance n, + // recorded when the event was received, and the local time at another master + // recorded when the event was written to the write-ahead log on that master. + Lag float64 `msgpack:"lag"` + // Message contains an error message in case of a degraded state; otherwise, it is nil. + Message string `msgpack:"message,omitempty"` + // SystemMessage contains an error message in case of a degraded state; otherwise, it is nil. + SystemMessage string `msgpack:"system_message,omitempty"` +} + +// Downstream information. +type Downstream struct { + // Status is replication status of the connection with the instance. + Status string `msgpack:"status"` + // Idle is the time (in seconds) since the last event was received. + Idle float64 `msgpack:"idle"` + // VClock contains the vector clock, which is a table of ‘id, lsn’ pairs. + VClock map[int]uint64 `msgpack:"vclock"` + // Lag is the time difference between the local time of instance n, + // recorded when the event was received, and the local time at another master + // recorded when the event was written to the write-ahead log on that master. + Lag float64 `msgpack:"lag"` + // Message contains an error message in case of a degraded state; otherwise, it is nil. + Message string `msgpack:"message,omitempty"` + // SystemMessage contains an error message in case of a degraded state; otherwise, it is nil. + SystemMessage string `msgpack:"system_message,omitempty"` +} + +// InfoResponse represents the response structure +// that holds the information of the Tarantool instance. +// It contains a single field: Info, which holds the instance details (version, UUID, PID, etc.). +type InfoResponse struct { + Info Info +} + +func (ir *InfoResponse) DecodeMsgpack(d *msgpack.Decoder) error { + arrayLen, err := d.DecodeArrayLen() + if err != nil { + return err + } + + if arrayLen != 1 { + return fmt.Errorf("protocol violation; expected 1 array entry, got %d", arrayLen) + } + + i := Info{} + err = d.Decode(&i) + if err != nil { + return err + } + + ir.Info = i + return nil +} + +// InfoRequest represents a request to retrieve information about the Tarantool instance. +// It implements the tarantool.Request interface. +type InfoRequest struct { + *tarantool.CallRequest // Underlying Tarantool call request. +} + +// NewInfoRequest returns a new empty info request. +func NewInfoRequest() InfoRequest { + callReq := tarantool.NewCallRequest("box.info") + + return InfoRequest{ + callReq, + } +} diff --git a/box/info_test.go b/box/info_test.go new file mode 100644 index 000000000..d4e8e9539 --- /dev/null +++ b/box/info_test.go @@ -0,0 +1,128 @@ +package box_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3/box" +) + +func TestInfo(t *testing.T) { + id := 1 + cases := []struct { + Name string + Struct box.Info + Data map[string]interface{} + }{ + { + Name: "Case: base info struct", + Struct: box.Info{ + Version: "2.11.4-0-g8cebbf2cad", + ID: &id, + RO: false, + UUID: "69360e9b-4641-4ec3-ab51-297f46749849", + PID: 1, + Status: "running", + LSN: 8, + }, + Data: map[string]interface{}{ + "version": "2.11.4-0-g8cebbf2cad", + "id": 1, + "ro": false, + "uuid": "69360e9b-4641-4ec3-ab51-297f46749849", + "pid": 1, + "status": "running", + "lsn": 8, + }, + }, + { + Name: "Case: info struct with replication", + Struct: box.Info{ + Version: "2.11.4-0-g8cebbf2cad", + ID: &id, + RO: false, + UUID: "69360e9b-4641-4ec3-ab51-297f46749849", + PID: 1, + Status: "running", + LSN: 8, + Replication: map[int]box.Replication{ + 1: { + ID: 1, + UUID: "69360e9b-4641-4ec3-ab51-297f46749849", + LSN: 8, + }, + 2: { + ID: 2, + UUID: "75f5f5aa-89f0-4d95-b5a9-96a0eaa0ce36", + LSN: 0, + Upstream: box.Upstream{ + Status: "follow", + Idle: 2.4564633660484, + Peer: "other.tarantool:3301", + Lag: 0.00011920928955078, + Message: "'getaddrinfo: Name or service not known'", + SystemMessage: "Input/output error", + }, + Downstream: box.Downstream{ + Status: "follow", + Idle: 2.8306158290943, + VClock: map[int]uint64{1: 8}, + Lag: 0, + Message: "'unexpected EOF when reading from socket'", + SystemMessage: "Broken pipe", + }, + }, + }, + }, + Data: map[string]interface{}{ + "version": "2.11.4-0-g8cebbf2cad", + "id": 1, + "ro": false, + "uuid": "69360e9b-4641-4ec3-ab51-297f46749849", + "pid": 1, + "status": "running", + "lsn": 8, + "replication": map[interface{}]interface{}{ + 1: map[string]interface{}{ + "id": 1, + "uuid": "69360e9b-4641-4ec3-ab51-297f46749849", + "lsn": 8, + }, + 2: map[string]interface{}{ + "id": 2, + "uuid": "75f5f5aa-89f0-4d95-b5a9-96a0eaa0ce36", + "lsn": 0, + "upstream": map[string]interface{}{ + "status": "follow", + "idle": 2.4564633660484, + "peer": "other.tarantool:3301", + "lag": 0.00011920928955078, + "message": "'getaddrinfo: Name or service not known'", + "system_message": "Input/output error", + }, + "downstream": map[string]interface{}{ + "status": "follow", + "idle": 2.8306158290943, + "vclock": map[interface{}]interface{}{1: 8}, + "lag": 0, + "message": "'unexpected EOF when reading from socket'", + "system_message": "Broken pipe", + }, + }, + }, + }, + }, + } + for _, tc := range cases { + data, err := msgpack.Marshal(tc.Data) + require.NoError(t, err, tc.Name) + + var result box.Info + err = msgpack.Unmarshal(data, &result) + require.NoError(t, err, tc.Name) + + require.Equal(t, tc.Struct, result) + } +} diff --git a/box/schema.go b/box/schema.go new file mode 100644 index 000000000..036c76c6d --- /dev/null +++ b/box/schema.go @@ -0,0 +1,23 @@ +package box + +import ( + "github.com/tarantool/go-tarantool/v3" +) + +// Schema represents the schema-related operations in Tarantool. +// It holds a connection to interact with the Tarantool instance. +type Schema struct { + conn tarantool.Doer // Connection interface for interacting with Tarantool. +} + +// newSchema creates a new Schema instance with the provided Tarantool connection. +// It initializes a Schema object that can be used for schema-related operations +// such as managing users, tables, and other schema elements in the Tarantool instance. +func newSchema(conn tarantool.Doer) *Schema { + return &Schema{conn: conn} // Pass the connection to the Schema. +} + +// User returns a new SchemaUser instance, allowing schema-related user operations. +func (s *Schema) User() *SchemaUser { + return newSchemaUser(s.conn) +} diff --git a/box/schema_user.go b/box/schema_user.go new file mode 100644 index 000000000..6e9f558a5 --- /dev/null +++ b/box/schema_user.go @@ -0,0 +1,546 @@ +package box + +import ( + "context" + "fmt" + "strings" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// SchemaUser provides methods to interact with schema-related user operations in Tarantool. +type SchemaUser struct { + conn tarantool.Doer // Connection interface for interacting with Tarantool. +} + +// newSchemaUser creates a new SchemaUser instance with the provided Tarantool connection. +// It initializes a SchemaUser object, which provides methods to perform user-related +// schema operations (such as creating, modifying, or deleting users) in the Tarantool instance. +func newSchemaUser(conn tarantool.Doer) *SchemaUser { + return &SchemaUser{conn: conn} +} + +// UserExistsRequest represents a request to check if a user exists in Tarantool. +type UserExistsRequest struct { + *tarantool.CallRequest // Underlying Tarantool call request. +} + +// UserExistsResponse represents the response to a user existence check. +type UserExistsResponse struct { + Exists bool // True if the user exists, false otherwise. +} + +// DecodeMsgpack decodes the response from a Msgpack-encoded byte slice. +func (uer *UserExistsResponse) DecodeMsgpack(d *msgpack.Decoder) error { + arrayLen, err := d.DecodeArrayLen() + if err != nil { + return err + } + + // Ensure that the response array contains exactly 1 element (the "Exists" field). + if arrayLen != 1 { + return fmt.Errorf("protocol violation; expected 1 array entry, got %d", arrayLen) + } + + // Decode the boolean value indicating whether the user exists. + uer.Exists, err = d.DecodeBool() + + return err +} + +// NewUserExistsRequest creates a new request to check if a user exists. +func NewUserExistsRequest(username string) UserExistsRequest { + callReq := tarantool.NewCallRequest("box.schema.user.exists").Args([]interface{}{username}) + + return UserExistsRequest{ + callReq, + } +} + +// Exists checks if the specified user exists in Tarantool. +func (u *SchemaUser) Exists(ctx context.Context, username string) (bool, error) { + // Create a request and send it to Tarantool. + req := NewUserExistsRequest(username).Context(ctx) + resp := &UserExistsResponse{} + + // Execute the request and parse the response. + err := u.conn.Do(req).GetTyped(resp) + + return resp.Exists, err +} + +// UserCreateOptions represents options for creating a user in Tarantool. +type UserCreateOptions struct { + // IfNotExists - if true, prevents an error if the user already exists. + IfNotExists bool `msgpack:"if_not_exists"` + // Password for the new user. + Password string `msgpack:"password"` +} + +// UserCreateRequest represents a request to create a new user in Tarantool. +type UserCreateRequest struct { + *tarantool.CallRequest // Underlying Tarantool call request. +} + +// NewUserCreateRequest creates a new request to create a user with specified options. +func NewUserCreateRequest(username string, options UserCreateOptions) UserCreateRequest { + callReq := tarantool.NewCallRequest("box.schema.user.create"). + Args([]interface{}{username, options}) + + return UserCreateRequest{ + callReq, + } +} + +// UserCreateResponse represents the response to a user creation request. +type UserCreateResponse struct{} + +// DecodeMsgpack decodes the response for a user creation request. +// In this case, the response does not contain any data. +func (uer *UserCreateResponse) DecodeMsgpack(_ *msgpack.Decoder) error { + return nil +} + +// Create creates a new user in Tarantool with the given username and options. +func (u *SchemaUser) Create(ctx context.Context, username string, options UserCreateOptions) error { + // Create a request and send it to Tarantool. + req := NewUserCreateRequest(username, options).Context(ctx) + resp := &UserCreateResponse{} + + // Execute the request and handle the response. + fut := u.conn.Do(req) + + err := fut.GetTyped(resp) + if err != nil { + return err + } + + return nil +} + +// UserDropOptions represents options for dropping a user in Tarantool. +type UserDropOptions struct { + IfExists bool `msgpack:"if_exists"` // If true, prevents an error if the user does not exist. +} + +// UserDropRequest represents a request to drop a user from Tarantool. +type UserDropRequest struct { + *tarantool.CallRequest // Underlying Tarantool call request. +} + +// NewUserDropRequest creates a new request to drop a user with specified options. +func NewUserDropRequest(username string, options UserDropOptions) UserDropRequest { + callReq := tarantool.NewCallRequest("box.schema.user.drop"). + Args([]interface{}{username, options}) + + return UserDropRequest{ + callReq, + } +} + +// UserDropResponse represents the response to a user drop request. +type UserDropResponse struct{} + +// Drop drops the specified user from Tarantool, with optional conditions. +func (u *SchemaUser) Drop(ctx context.Context, username string, options UserDropOptions) error { + // Create a request and send it to Tarantool. + req := NewUserDropRequest(username, options).Context(ctx) + resp := &UserCreateResponse{} + + // Execute the request and handle the response. + fut := u.conn.Do(req) + + err := fut.GetTyped(resp) + if err != nil { + return err + } + + return nil +} + +// UserPasswordRequest represents a request to retrieve a user's password from Tarantool. +type UserPasswordRequest struct { + *tarantool.CallRequest // Underlying Tarantool call request. +} + +// NewUserPasswordRequest creates a new request to fetch the user's password. +// It takes the username and constructs the request to Tarantool. +func NewUserPasswordRequest(username string) UserPasswordRequest { + // Create a request to get the user's password. + callReq := tarantool.NewCallRequest("box.schema.user.password").Args([]interface{}{username}) + + return UserPasswordRequest{ + callReq, + } +} + +// UserPasswordResponse represents the response to the user password request. +// It contains the password hash. +type UserPasswordResponse struct { + Hash string // The password hash of the user. +} + +// DecodeMsgpack decodes the response from Tarantool in Msgpack format. +// It expects the response to be an array of length 1, containing the password hash string. +func (upr *UserPasswordResponse) DecodeMsgpack(d *msgpack.Decoder) error { + // Decode the array length. + arrayLen, err := d.DecodeArrayLen() + if err != nil { + return err + } + + // Ensure the array contains exactly 1 element (the password hash). + if arrayLen != 1 { + return fmt.Errorf("protocol violation; expected 1 array entry, got %d", arrayLen) + } + + // Decode the string containing the password hash. + upr.Hash, err = d.DecodeString() + + return err +} + +// Password sends a request to retrieve the user's password from Tarantool. +// It returns the password hash as a string or an error if the request fails. +// It works just like hash function. +func (u *SchemaUser) Password(ctx context.Context, password string) (string, error) { + // Create the request and send it to Tarantool. + req := NewUserPasswordRequest(password).Context(ctx) + resp := &UserPasswordResponse{} + + // Execute the request and handle the response. + fut := u.conn.Do(req) + + // Get the decoded response. + err := fut.GetTyped(resp) + if err != nil { + return "", err + } + + // Return the password hash. + return resp.Hash, nil +} + +// UserPasswdRequest represents a request to change a user's password in Tarantool. +type UserPasswdRequest struct { + *tarantool.CallRequest // Underlying Tarantool call request. +} + +// NewUserPasswdRequest creates a new request to change a user's password in Tarantool. +func NewUserPasswdRequest(args ...string) (UserPasswdRequest, error) { + callReq := tarantool.NewCallRequest("box.schema.user.passwd") + + switch len(args) { + case 1: + callReq.Args([]interface{}{args[0]}) + case 2: + callReq.Args([]interface{}{args[0], args[1]}) + default: + return UserPasswdRequest{}, fmt.Errorf("len of fields must be 1 or 2, got %d", len(args)) + + } + + return UserPasswdRequest{callReq}, nil +} + +// UserPasswdResponse represents the response to a user passwd request. +type UserPasswdResponse struct{} + +// Passwd sends a request to set a password for a currently logged in or a specified user. +// A currently logged-in user can change their password using box.schema.user.passwd(password). +// An administrator can change the password of another user +// with box.schema.user.passwd(username, password). +func (u *SchemaUser) Passwd(ctx context.Context, args ...string) error { + req, err := NewUserPasswdRequest(args...) + if err != nil { + return err + } + + req.Context(ctx) + + resp := &UserPasswdResponse{} + + // Execute the request and handle the response. + fut := u.conn.Do(req) + + err = fut.GetTyped(resp) + if err != nil { + return err + } + + return nil +} + +// UserInfoRequest represents a request to get a user's info in Tarantool. +type UserInfoRequest struct { + *tarantool.CallRequest // Underlying Tarantool call request. +} + +// NewUserInfoRequest creates a new request to get user privileges. +func NewUserInfoRequest(username string) UserInfoRequest { + callReq := tarantool.NewCallRequest("box.schema.user.info").Args([]interface{}{username}) + + return UserInfoRequest{ + callReq, + } +} + +// PrivilegeType is a struct based on privilege object types list +// https://www.tarantool.io/en/doc/latest/admin/access_control/#all-object-types-and-permissions +type PrivilegeType string + +const ( + // PrivilegeUniverse - privilege type based on universe. + // A database (box.schema) that contains database objects, including spaces, + // indexes, users, roles, sequences, and functions. + // Granting privileges to universe gives a user access to any object in the database. + PrivilegeUniverse PrivilegeType = "universe" + // PrivilegeTypeUser - privilege type based on user. + // A user identifies a person or program that interacts with a Tarantool instance. + PrivilegeTypeUser PrivilegeType = "user" + // PrivilegeRole - privilege type based on role. + // A role is a container for privileges that can be granted to users. + // Roles can also be assigned to other roles, creating a role hierarchy. + PrivilegeRole PrivilegeType = "role" + // PrivilegeSpace - privilege type based on space. + // Tarantool stores tuples in containers called spaces. + PrivilegeSpace PrivilegeType = "space" + // PrivilegeFunction - privilege type based on functions. + // This allows access control based on function access. + PrivilegeFunction PrivilegeType = "function" + // PrivilegeSequence - privilege type based on sequences. + // A sequence is a generator of ordered integer values. + PrivilegeSequence PrivilegeType = "sequence" + // PrivilegeLuaEval - privilege type based on executing arbitrary Lua code. + PrivilegeLuaEval PrivilegeType = "lua_eval" + // PrivilegeLuaCall - privilege type based on + // calling any global user-defined Lua function. + PrivilegeLuaCall PrivilegeType = "lua_call" + // PrivilegeSQL - privilege type based on + // executing an arbitrary SQL expression. + PrivilegeSQL PrivilegeType = "sql" +) + +// Permission is a struct based on permission tarantool object +// https://www.tarantool.io/en/doc/latest/admin/access_control/#permissions +type Permission string + +const ( + // PermissionRead allows reading data of the specified object. + // For example, this permission can be used to allow a user + // to select data from the specified space. + PermissionRead Permission = "read" + // PermissionWrite allows updating data of the specified object. + // For example, this permission can be used to allow + // a user to modify data in the specified space. + PermissionWrite Permission = "write" + // PermissionCreate allows creating objects of the specified type. + // For example, this permission can be used to allow a user to create new spaces. + // Note that this permission requires read and write access to certain system spaces. + PermissionCreate Permission = "create" + // PermissionAlter allows altering objects of the specified type. + // Note that this permission requires read and write access to certain system spaces. + PermissionAlter Permission = "alter" + // PermissionDrop allows dropping objects of the specified type. + // Note that this permission requires read and write access to certain system spaces. + PermissionDrop Permission = "drop" + // PermissionExecute for role, + // allows using the specified role. For other object types, allows calling a function. + // Can be used only for role, universe, function, lua_eval, lua_call, sql. + PermissionExecute Permission = "execute" + // PermissionSession allows a user to connect to an instance over IPROTO. + PermissionSession Permission = "session" + // PermissionUsage allows a user to use their privileges on database objects + // (for example, read, write, and alter spaces). + PermissionUsage Permission = "usage" +) + +// Privilege is a structure that is used to create new rights, +// as well as obtain information for rights. +type Privilege struct { + // Permissions is a list of privileges that apply to the privileges object type. + Permissions []Permission + // Type - one of privilege object types (it might be space,function, etc.). + Type PrivilegeType + // Name - can be the name of a function or space, + // and can also be empty in case of universe access + Name string +} + +// UserInfoResponse represents the response to a user info request. +type UserInfoResponse struct { + Privileges []Privilege +} + +// DecodeMsgpack decodes the response from Tarantool in Msgpack format. +func (uer *UserInfoResponse) DecodeMsgpack(d *msgpack.Decoder) error { + rawArr := make([][][3]string, 0) + + err := d.Decode(&rawArr) + switch { + case err != nil: + return err + case len(rawArr) != 1: + return fmt.Errorf("protocol violation; expected 1 array, got %d", len(rawArr)) + } + + privileges := make([]Privilege, len(rawArr[0])) + + for i, rawPrivileges := range rawArr[0] { + strPerms := strings.Split(rawPrivileges[0], ",") + + perms := make([]Permission, len(strPerms)) + for j, strPerm := range strPerms { + perms[j] = Permission(strPerm) + } + + privileges[i] = Privilege{ + Permissions: perms, + Type: PrivilegeType(rawPrivileges[1]), + Name: rawPrivileges[2], + } + } + + uer.Privileges = privileges + + return nil +} + +// Info returns a list of user privileges according to the box.schema.user.info method call. +func (u *SchemaUser) Info(ctx context.Context, username string) ([]Privilege, error) { + req := NewUserInfoRequest(username).Context(ctx) + + resp := &UserInfoResponse{} + fut := u.conn.Do(req) + + err := fut.GetTyped(resp) + if err != nil { + return nil, err + } + + return resp.Privileges, nil +} + +// prepareGrantAndRevokeArgs prepares the arguments for granting or revoking user permissions. +// It accepts a username, a privilege, and options for either granting or revoking. +// The generic type T can be UserGrantOptions or UserRevokeOptions. +func prepareGrantAndRevokeArgs[T UserGrantOptions | UserRevokeOptions](username string, + privilege Privilege, opts T) []interface{} { + + args := []interface{}{username} // Initialize args slice with the username. + + switch privilege.Type { + case PrivilegeUniverse: + // Preparing arguments for granting permissions at the universe level. + // box.schema.user.grant(username, permissions, 'universe'[, nil, {options}]) + strPerms := make([]string, len(privilege.Permissions)) + for i, perm := range privilege.Permissions { + strPerms[i] = string(perm) // Convert each Permission to a string. + } + + reqPerms := strings.Join(strPerms, ",") // Join permissions into a single string. + + // Append universe-specific arguments to args. + args = append(args, reqPerms, string(privilege.Type), nil, opts) + case PrivilegeRole: + // Handling the case where the object type is a role name. + // Append role-specific arguments to args. + args = append(args, privilege.Name, nil, nil, opts) + default: + // Preparing arguments for granting permissions on a specific object. + strPerms := make([]string, len(privilege.Permissions)) + for i, perm := range privilege.Permissions { + strPerms[i] = string(perm) // Convert each Permission to a string. + } + + reqPerms := strings.Join(strPerms, ",") // Join permissions into a single string. + // box.schema.user.grant(username, permissions, object-type, object-name[, {options}]) + // Example: box.schema.user.grant('testuser', 'read', 'space', 'writers') + args = append(args, reqPerms, string(privilege.Type), privilege.Name, opts) + } + + return args // Return the prepared arguments. +} + +// UserGrantOptions holds options for granting permissions to a user. +type UserGrantOptions struct { + Grantor string `msgpack:"grantor,omitempty"` // Optional grantor name. + IfNotExists bool `msgpack:"if_not_exists"` // Option to skip if the grant already exists. +} + +// UserGrantRequest wraps a Tarantool call request for granting user permissions. +type UserGrantRequest struct { + *tarantool.CallRequest // Underlying Tarantool call request. +} + +// NewUserGrantRequest creates a new UserGrantRequest based on provided parameters. +func NewUserGrantRequest(username string, privilege Privilege, + opts UserGrantOptions) UserGrantRequest { + args := prepareGrantAndRevokeArgs[UserGrantOptions](username, privilege, opts) + + // Create a new call request for the box.schema.user.grant method with the given args. + callReq := tarantool.NewCallRequest("box.schema.user.grant").Args(args) + + return UserGrantRequest{callReq} // Return the UserGrantRequest. +} + +// UserGrantResponse represents the response from a user grant request. +type UserGrantResponse struct{} + +// Grant executes the user grant operation in Tarantool, returning an error if it fails. +func (u *SchemaUser) Grant(ctx context.Context, username string, privilege Privilege, + opts UserGrantOptions) error { + req := NewUserGrantRequest(username, privilege, opts).Context(ctx) + + resp := &UserGrantResponse{} // Initialize a response object. + fut := u.conn.Do(req) // Execute the request. + + err := fut.GetTyped(resp) // Get the typed response and check for errors. + if err != nil { + return err // Return any errors encountered. + } + + return nil // Return nil if the operation was successful. +} + +// UserRevokeOptions holds options for revoking permissions from a user. +type UserRevokeOptions struct { + IfExists bool `msgpack:"if_exists"` // Option to skip if the revoke does not exist. +} + +// UserRevokeRequest wraps a Tarantool call request for revoking user permissions. +type UserRevokeRequest struct { + *tarantool.CallRequest // Underlying Tarantool call request. +} + +// UserRevokeResponse represents the response from a user revoke request. +type UserRevokeResponse struct{} + +// NewUserRevokeRequest creates a new UserRevokeRequest based on provided parameters. +func NewUserRevokeRequest(username string, privilege Privilege, + opts UserRevokeOptions) UserRevokeRequest { + args := prepareGrantAndRevokeArgs[UserRevokeOptions](username, privilege, opts) + + // Create a new call request for the box.schema.user.revoke method with the given args. + callReq := tarantool.NewCallRequest("box.schema.user.revoke").Args(args) + + return UserRevokeRequest{callReq} +} + +// Revoke executes the user revoke operation in Tarantool, returning an error if it fails. +func (u *SchemaUser) Revoke(ctx context.Context, username string, privilege Privilege, + opts UserRevokeOptions) error { + req := NewUserRevokeRequest(username, privilege, opts).Context(ctx) + + resp := &UserRevokeResponse{} // Initialize a response object. + fut := u.conn.Do(req) // Execute the request. + + err := fut.GetTyped(resp) // Get the typed response and check for errors. + if err != nil { + return err + } + + return nil +} diff --git a/box/schema_user_test.go b/box/schema_user_test.go new file mode 100644 index 000000000..985e9fe86 --- /dev/null +++ b/box/schema_user_test.go @@ -0,0 +1,195 @@ +package box_test + +import ( + "bytes" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/box" +) + +func TestUserExistsResponse_DecodeMsgpack(t *testing.T) { + tCases := map[bool]func() *bytes.Buffer{ + true: func() *bytes.Buffer { + buf := bytes.NewBuffer(nil) + buf.WriteByte(msgpcode.FixedArrayLow | byte(1)) + buf.WriteByte(msgpcode.True) + + return buf + }, + false: func() *bytes.Buffer { + buf := bytes.NewBuffer(nil) + buf.WriteByte(msgpcode.FixedArrayLow | byte(1)) + buf.WriteByte(msgpcode.False) + + return buf + }, + } + + for tCaseBool, tCaseBuf := range tCases { + tCaseBool := tCaseBool + tCaseBuf := tCaseBuf() + + t.Run(fmt.Sprintf("case: %t", tCaseBool), func(t *testing.T) { + t.Parallel() + + resp := box.UserExistsResponse{} + + require.NoError(t, resp.DecodeMsgpack(msgpack.NewDecoder(tCaseBuf))) + require.Equal(t, tCaseBool, resp.Exists) + }) + } + +} + +func TestUserPasswordResponse_DecodeMsgpack(t *testing.T) { + tCases := []string{ + "test", + "$tr0ng_pass", + } + + for _, tCase := range tCases { + tCase := tCase + + t.Run(tCase, func(t *testing.T) { + t.Parallel() + buf := bytes.NewBuffer(nil) + buf.WriteByte(msgpcode.FixedArrayLow | byte(1)) + + bts, err := msgpack.Marshal(tCase) + require.NoError(t, err) + buf.Write(bts) + + resp := box.UserPasswordResponse{} + + err = resp.DecodeMsgpack(msgpack.NewDecoder(buf)) + require.NoError(t, err) + require.Equal(t, tCase, resp.Hash) + }) + } + +} + +func FuzzUserPasswordResponse_DecodeMsgpack(f *testing.F) { + f.Fuzz(func(t *testing.T, orig string) { + buf := bytes.NewBuffer(nil) + buf.WriteByte(msgpcode.FixedArrayLow | byte(1)) + + bts, err := msgpack.Marshal(orig) + require.NoError(t, err) + buf.Write(bts) + + resp := box.UserPasswordResponse{} + + err = resp.DecodeMsgpack(msgpack.NewDecoder(buf)) + require.NoError(t, err) + require.Equal(t, orig, resp.Hash) + }) +} + +func TestNewUserExistsRequest(t *testing.T) { + t.Parallel() + + req := box.UserExistsRequest{} + + require.NotPanics(t, func() { + req = box.NewUserExistsRequest("test") + }) + + require.Implements(t, (*tarantool.Request)(nil), req) +} + +func TestNewUserCreateRequest(t *testing.T) { + t.Parallel() + + req := box.UserCreateRequest{} + + require.NotPanics(t, func() { + req = box.NewUserCreateRequest("test", box.UserCreateOptions{}) + }) + + require.Implements(t, (*tarantool.Request)(nil), req) +} + +func TestNewUserDropRequest(t *testing.T) { + t.Parallel() + + req := box.UserDropRequest{} + + require.NotPanics(t, func() { + req = box.NewUserDropRequest("test", box.UserDropOptions{}) + }) + + require.Implements(t, (*tarantool.Request)(nil), req) +} + +func TestNewUserPasswordRequest(t *testing.T) { + t.Parallel() + + req := box.UserPasswordRequest{} + + require.NotPanics(t, func() { + req = box.NewUserPasswordRequest("test") + }) + + require.Implements(t, (*tarantool.Request)(nil), req) +} + +func TestNewUserPasswdRequest(t *testing.T) { + t.Parallel() + + var err error + req := box.UserPasswdRequest{} + + require.NotPanics(t, func() { + req, err = box.NewUserPasswdRequest("test") + require.NoError(t, err) + }) + + _, err = box.NewUserPasswdRequest() + require.Errorf(t, err, "invalid arguments count") + + require.Implements(t, (*tarantool.Request)(nil), req) +} + +func TestNewUserInfoRequest(t *testing.T) { + t.Parallel() + + var err error + req := box.UserInfoRequest{} + + require.NotPanics(t, func() { + req = box.NewUserInfoRequest("test") + require.NoError(t, err) + }) + + require.Implements(t, (*tarantool.Request)(nil), req) +} + +func TestNewUserGrantRequest(t *testing.T) { + t.Parallel() + + var err error + req := box.UserGrantRequest{} + + require.NotPanics(t, func() { + req = box.NewUserGrantRequest("test", box.Privilege{ + Permissions: []box.Permission{ + box.PermissionAlter, + box.PermissionCreate, + box.PermissionDrop, + }, + Type: box.PrivilegeUniverse, + Name: "test", + }, box.UserGrantOptions{IfNotExists: true}) + require.NoError(t, err) + }) + + assert.Implements(t, (*tarantool.Request)(nil), req) +} diff --git a/box/session.go b/box/session.go new file mode 100644 index 000000000..5a01b32a3 --- /dev/null +++ b/box/session.go @@ -0,0 +1,57 @@ +package box + +import ( + "context" + + "github.com/tarantool/go-tarantool/v3" +) + +// Session struct represents a connection session to Tarantool. +type Session struct { + conn tarantool.Doer // Connection interface for interacting with Tarantool. +} + +// newSession creates a new Session instance, taking a Tarantool connection as an argument. +func newSession(conn tarantool.Doer) *Session { + return &Session{conn: conn} // Pass the connection to the Session structure. +} + +// Session method returns a new Session object associated with the Box instance. +func (b *Box) Session() *Session { + return newSession(b.conn) +} + +// SessionSuRequest struct wraps a Tarantool call request specifically for session switching. +type SessionSuRequest struct { + *tarantool.CallRequest // Underlying Tarantool call request. +} + +// NewSessionSuRequest creates a new SessionSuRequest for switching session to a specified username. +// It returns an error if any execute functions are provided, as they are not supported now. +func NewSessionSuRequest(username string) (SessionSuRequest, error) { + args := []interface{}{username} // Create args slice with the username. + + // Create a new call request for the box.session.su method with the given args. + callReq := tarantool.NewCallRequest("box.session.su").Args(args) + + return SessionSuRequest{ + callReq, // Return the new SessionSuRequest containing the call request. + }, nil +} + +// Su method is used to switch the session to the specified username. +// It sends the request to Tarantool and returns an error. +func (s *Session) Su(ctx context.Context, username string) error { + // Create a request and send it to Tarantool. + req, err := NewSessionSuRequest(username) + if err != nil { + return err // Return any errors encountered while creating the request. + } + + req.Context(ctx) // Attach the context to the request for cancellation and timeout. + + // Execute the request and return the future response, or an error. + fut := s.conn.Do(req) + _, err = fut.GetResponse() + return err +} diff --git a/box/session_test.go b/box/session_test.go new file mode 100644 index 000000000..188360251 --- /dev/null +++ b/box/session_test.go @@ -0,0 +1,20 @@ +package box_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/tarantool/go-tarantool/v3/box" + th "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +func TestBox_Session(t *testing.T) { + b := box.MustNew(th.Ptr(th.NewMockDoer(t))) + require.NotNil(t, b.Session()) +} + +func TestNewSessionSuRequest(t *testing.T) { + _, err := box.NewSessionSuRequest("admin") + require.NoError(t, err) +} diff --git a/box/tarantool_test.go b/box/tarantool_test.go new file mode 100644 index 000000000..ae47932d8 --- /dev/null +++ b/box/tarantool_test.go @@ -0,0 +1,688 @@ +package box_test + +import ( + "context" + "errors" + "log" + "os" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "github.com/tarantool/go-iproto" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/box" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +var server = "127.0.0.1:3013" +var dialer = tarantool.NetDialer{ + Address: server, + User: "test", + Password: "test", +} + +func validateInfo(t testing.TB, info box.Info) { + var err error + + // Check all fields run correctly. + _, err = uuid.Parse(info.UUID) + require.NoErrorf(t, err, "validate instance uuid is valid") + + require.NotEmpty(t, info.Version) + // Check that pid parsed correctly. + require.NotEqual(t, info.PID, 0) + + // Check replication is parsed correctly. + require.NotEmpty(t, info.Replication) + + // Check one replica uuid is equal system uuid. + require.Equal(t, info.UUID, info.Replication[1].UUID) +} + +func TestBox_Sugar_Info(t *testing.T) { + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + info, err := b.Info() + require.NoError(t, err) + + validateInfo(t, info) +} + +func TestBox_Info(t *testing.T) { + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + fut := conn.Do(box.NewInfoRequest()) + require.NotNil(t, fut) + + resp := &box.InfoResponse{} + err = fut.GetTyped(resp) + require.NoError(t, err) + + validateInfo(t, resp.Info) +} + +func TestBox_Sugar_Schema_UserCreate_NoError(t *testing.T) { + const ( + username = "user_create_no_error" + password = "user_create_no_error" + ) + + defer cleanupUser(username) + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + // Create new user. + err = b.Schema().User().Create(ctx, username, box.UserCreateOptions{Password: password}) + require.NoError(t, err) +} + +func TestBox_Sugar_Schema_UserCreate_CanConnectWithNewCred(t *testing.T) { + const ( + username = "can_connect_with_new_cred" + password = "can_connect_with_new_cred" + ) + + defer cleanupUser(username) + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + // Create new user. + err = b.Schema().User().Create(ctx, username, box.UserCreateOptions{Password: password}) + require.NoError(t, err) + + // Can connect with new credentials + // Check that password is valid, and we can connect to tarantool with such credentials + var newUserDialer = tarantool.NetDialer{ + Address: server, + User: username, + Password: password, + } + + // We can connect with our new credentials + newUserConn, err := tarantool.Connect(ctx, newUserDialer, tarantool.Opts{}) + require.NoError(t, err) + require.NotNil(t, newUserConn) + require.NoError(t, newUserConn.Close()) +} + +func TestBox_Sugar_Schema_UserCreate_AlreadyExists(t *testing.T) { + const ( + username = "create_already_exists" + password = "create_already_exists" + ) + + defer cleanupUser(username) + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + // Create new user. + err = b.Schema().User().Create(ctx, username, box.UserCreateOptions{Password: password}) + require.NoError(t, err) + + // Create user already exists error. + // Get error that user already exists. + err = b.Schema().User().Create(ctx, username, box.UserCreateOptions{Password: password}) + require.Error(t, err) + + // Require that error code is ER_USER_EXISTS. + var boxErr tarantool.Error + errors.As(err, &boxErr) + require.Equal(t, iproto.ER_USER_EXISTS, boxErr.Code) +} + +func TestBox_Sugar_Schema_UserCreate_ExistsTrue(t *testing.T) { + const ( + username = "exists_check" + password = "exists_check" + ) + + defer cleanupUser(username) + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + // Create new user. + err = b.Schema().User().Create(ctx, username, box.UserCreateOptions{Password: password}) + require.NoError(t, err) + + // Check that already exists by exists call procedure + exists, err := b.Schema().User().Exists(ctx, username) + require.True(t, exists) + require.NoError(t, err) + +} + +func TestBox_Sugar_Schema_UserCreate_IfNotExistsNoErr(t *testing.T) { + const ( + username = "if_not_exists_no_err" + password = "if_not_exists_no_err" + ) + + defer cleanupUser(username) + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + // Create new user. + err = b.Schema().User().Create(ctx, username, box.UserCreateOptions{Password: password}) + require.NoError(t, err) + + // Again create such user. + err = b.Schema().User().Create(ctx, username, box.UserCreateOptions{ + Password: password, + IfNotExists: true, + }) + require.NoError(t, err) +} + +func TestBox_Sugar_Schema_UserPassword(t *testing.T) { + const ( + password = "passwd" + ) + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + // Require password hash. + hash, err := b.Schema().User().Password(ctx, password) + require.NoError(t, err) + require.NotEmpty(t, hash) +} + +func TestBox_Sugar_Schema_UserDrop_AfterCreate(t *testing.T) { + const ( + username = "to_drop_after_create" + password = "to_drop_after_create" + ) + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + // Create new user + err = b.Schema().User().Create(ctx, username, box.UserCreateOptions{Password: password}) + require.NoError(t, err) + + // Try to drop user + err = b.Schema().User().Drop(ctx, username, box.UserDropOptions{}) + require.NoError(t, err) +} + +func TestBox_Sugar_Schema_UserDrop_DoubleDrop(t *testing.T) { + const ( + username = "to_drop_double_drop" + password = "to_drop_double_drop" + ) + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + // Create new user + err = b.Schema().User().Create(ctx, username, box.UserCreateOptions{Password: password}) + require.NoError(t, err) + + // Try to drop user first time + err = b.Schema().User().Drop(ctx, username, box.UserDropOptions{}) + require.NoError(t, err) + + // Error double drop with IfExists: false option + err = b.Schema().User().Drop(ctx, username, box.UserDropOptions{}) + require.Error(t, err) + + // Require no error with IfExists: true option. + err = b.Schema().User().Drop(ctx, username, + box.UserDropOptions{IfExists: true}) + require.NoError(t, err) +} + +func TestBox_Sugar_Schema_UserDrop_UnknownUser(t *testing.T) { + t.Parallel() + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + // Require error cause user not exists + err = b.Schema().User().Drop(ctx, "some_strange_not_existing_name", box.UserDropOptions{}) + require.Error(t, err) + + var boxErr tarantool.Error + + // Require that error code is ER_NO_SUCH_USER + errors.As(err, &boxErr) + require.Equal(t, iproto.ER_NO_SUCH_USER, boxErr.Code) +} + +func TestSchemaUser_Passwd_NotFound(t *testing.T) { + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + err = b.Schema().User().Passwd(ctx, "not-exists-passwd", "new_password") + require.Error(t, err) + // Require that error code is ER_USER_EXISTS. + var boxErr tarantool.Error + errors.As(err, &boxErr) + require.Equal(t, iproto.ER_NO_SUCH_USER, boxErr.Code) +} + +func TestSchemaUser_Passwd_Ok(t *testing.T) { + const ( + username = "new_password_user" + startPassword = "new_password" + endPassword = "end_password" + ) + + defer cleanupUser(username) + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + // New user change password and connect + + err = b.Schema().User().Create(ctx, username, + box.UserCreateOptions{Password: startPassword, IfNotExists: true}) + require.NoError(t, err) + + err = b.Schema().User().Passwd(ctx, username, endPassword) + require.NoError(t, err) + + dialer := dialer + dialer.User = username + dialer.Password = startPassword + + // Can't connect with old password. + _, err = tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.Error(t, err, "can't connect with old password") + + // Ok connection with new password. + dialer.Password = endPassword + _, err = tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err, "ok connection with new password") +} + +func TestSchemaUser_Passwd_WithoutGrants(t *testing.T) { + const ( + username = "new_password_user_fail_conn" + startPassword = "new_password" + endPassword = "end_password" + ) + defer cleanupUser(username) + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + err = b.Schema().User().Create(ctx, username, + box.UserCreateOptions{Password: startPassword, IfNotExists: true}) + require.NoError(t, err) + + dialer := dialer + dialer.User = username + dialer.Password = startPassword + + conn2Fail, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + require.NotNil(t, conn2Fail) + + bFail := box.MustNew(conn2Fail) + + // can't change self user password without grants + err = bFail.Schema().User().Passwd(ctx, endPassword) + require.Error(t, err) + + // Require that error code is AccessDeniedError, + var boxErr tarantool.Error + errors.As(err, &boxErr) + require.Equal(t, iproto.ER_ACCESS_DENIED, boxErr.Code) + +} + +func TestSchemaUser_Info_TestUserCorrect(t *testing.T) { + t.Parallel() + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + privileges, err := b.Schema().User().Info(ctx, dialer.User) + require.NoError(t, err) + require.NotNil(t, privileges) + + require.Len(t, privileges, 4) +} + +func TestSchemaUser_Info_NonExistsUser(t *testing.T) { + t.Parallel() + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + privileges, err := b.Schema().User().Info(ctx, "non-existing") + require.Error(t, err) + require.Nil(t, privileges) +} + +func TestBox_Sugar_Schema_UserGrant_NoSu(t *testing.T) { + const ( + username = "to_grant_no_su" + password = "to_grant_no_su" + ) + + defer cleanupUser(username) + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + err = b.Schema().User().Create(ctx, username, box.UserCreateOptions{Password: password}) + require.NoError(t, err) + + err = b.Schema().User().Grant(ctx, username, box.Privilege{ + Permissions: []box.Permission{ + box.PermissionRead, + }, + Type: box.PrivilegeSpace, + Name: "space1", + }, box.UserGrantOptions{IfNotExists: false}) + require.Error(t, err) + + // Require that error code is ER_ACCESS_DENIED. + var boxErr tarantool.Error + errors.As(err, &boxErr) + require.Equal(t, iproto.ER_ACCESS_DENIED, boxErr.Code) +} + +func TestBox_Sugar_Schema_UserGrant_WithSu(t *testing.T) { + const ( + username = "to_grant_with_su" + password = "to_grant_with_su" + ) + + defer cleanupUser(username) + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + err = b.Schema().User().Create(ctx, username, box.UserCreateOptions{Password: password}) + require.NoError(t, err) + + startPrivilages, err := b.Schema().User().Info(ctx, username) + require.NoError(t, err) + + err = b.Session().Su(ctx, "admin") + require.NoError(t, err) + + require.NoError(t, err, "default user in super group") + + newPrivilege := box.Privilege{ + Permissions: []box.Permission{ + box.PermissionRead, + }, + Type: box.PrivilegeSpace, + Name: "space1", + } + + require.NotContains(t, startPrivilages, newPrivilege) + + err = b.Schema().User().Grant(ctx, + username, + newPrivilege, + box.UserGrantOptions{ + IfNotExists: false, + }) + require.NoError(t, err) + + endPrivileges, err := b.Schema().User().Info(ctx, username) + require.NoError(t, err) + require.NotEqual(t, startPrivilages, endPrivileges) + require.Contains(t, endPrivileges, newPrivilege) +} + +func TestSchemaUser_Revoke_WithoutSu(t *testing.T) { + const ( + username = "to_revoke_without_su" + password = "to_revoke_without_su" + ) + + defer cleanupUser(username) + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + err = b.Schema().User().Create(ctx, username, box.UserCreateOptions{Password: password}) + require.NoError(t, err) + + // Can`t revoke without su permissions. + err = b.Schema().User().Grant(ctx, username, box.Privilege{ + Permissions: []box.Permission{ + box.PermissionRead, + }, + Type: box.PrivilegeSpace, + Name: "space1", + }, box.UserGrantOptions{IfNotExists: false}) + require.Error(t, err) + + // Require that error code is ER_ACCESS_DENIED. + var boxErr tarantool.Error + errors.As(err, &boxErr) + require.Equal(t, iproto.ER_ACCESS_DENIED, boxErr.Code) +} + +func TestSchemaUser_Revoke_WithSu(t *testing.T) { + const ( + username = "to_revoke_with_su" + password = "to_revoke_with_su" + ) + + defer cleanupUser(username) + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + err = b.Schema().User().Create(ctx, username, box.UserCreateOptions{Password: password}) + require.NoError(t, err) + + // Can revoke with su admin permissions. + startPrivileges, err := b.Schema().User().Info(ctx, username) + require.NoError(t, err) + + err = b.Session().Su(ctx, "admin") + require.NoError(t, err) + + require.NoError(t, err, "dialer user in super group") + + require.NotEmpty(t, startPrivileges) + // Let's choose random first privilege. + examplePriv := startPrivileges[0] + + // Revoke it. + err = b.Schema().User().Revoke(ctx, + username, + examplePriv, + box.UserRevokeOptions{ + IfExists: false, + }) + + require.NoError(t, err) + + privileges, err := b.Schema().User().Info(ctx, username) + require.NoError(t, err) + + require.NotEqual(t, startPrivileges, privileges) + require.NotContains(t, privileges, examplePriv) +} + +func TestSchemaUser_Revoke_NonExistsPermission(t *testing.T) { + const ( + username = "to_revoke_non_exists_permission" + password = "to_revoke_non_exists_permission" + ) + + defer cleanupUser(username) + + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + err = b.Schema().User().Create(ctx, username, box.UserCreateOptions{Password: password}) + require.NoError(t, err) + + startPrivileges, err := b.Schema().User().Info(ctx, username) + require.NoError(t, err) + + err = b.Session().Su(ctx, "admin") + require.NoError(t, err) + + require.NoError(t, err, "dialer user in super group") + + require.NotEmpty(t, startPrivileges) + examplePriv := box.Privilege{ + Permissions: []box.Permission{box.PermissionRead}, + Name: "non_existing_space", + Type: box.PrivilegeSpace, + } + + err = b.Schema().User().Revoke(ctx, + username, + examplePriv, + box.UserRevokeOptions{ + IfExists: false, + }) + + require.Error(t, err) +} + +func TestSession_Su_AdminPermissions(t *testing.T) { + ctx := context.TODO() + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + require.NoError(t, err) + + b := box.MustNew(conn) + + err = b.Session().Su(ctx, "admin") + require.NoError(t, err) +} + +func cleanupUser(username string) { + ctx := context.TODO() + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + if err != nil { + log.Fatal(err) + } + + b := box.MustNew(conn) + + err = b.Schema().User().Drop(ctx, username, box.UserDropOptions{}) + if err != nil { + log.Fatal(err) + } +} + +func runTestMain(m *testing.M) int { + instance, err := test_helpers.StartTarantool(test_helpers.StartOpts{ + Dialer: dialer, + InitScript: "testdata/config.lua", + Listen: server, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 10, + RetryTimeout: 500 * time.Millisecond, + }) + if err != nil { + log.Printf("Failed to prepare test Tarantool: %s", err) + return 1 + } + + defer test_helpers.StopTarantoolWithCleanup(instance) + + return m.Run() +} + +func TestMain(m *testing.M) { + code := runTestMain(m) + os.Exit(code) +} diff --git a/box/testdata/config.lua b/box/testdata/config.lua new file mode 100644 index 000000000..3d0db6acb --- /dev/null +++ b/box/testdata/config.lua @@ -0,0 +1,18 @@ +-- Do not set listen for now so connector won't be +-- able to send requests until everything is configured. +box.cfg{ + work_dir = os.getenv("TEST_TNT_WORK_DIR"), +} + +box.schema.space.create('space1') + +box.schema.user.create('test', { password = 'test' , if_not_exists = true }) +box.schema.user.grant('test', 'super', nil, nil, { if_not_exists = true }) + +-- Set listen only when every other thing is configured. +box.cfg{ + listen = os.getenv("TEST_TNT_LISTEN"), + replication = { + os.getenv("TEST_TNT_LISTEN"), + }, +} diff --git a/boxerror.go b/boxerror.go new file mode 100644 index 000000000..2fb18268f --- /dev/null +++ b/boxerror.go @@ -0,0 +1,309 @@ +package tarantool + +import ( + "bytes" + "fmt" + + "github.com/vmihailenco/msgpack/v5" +) + +const errorExtID = 3 + +//go:generate go tool gentypes -ext-code 3 BoxError + +const ( + keyErrorStack = 0x00 + keyErrorType = 0x00 + keyErrorFile = 0x01 + keyErrorLine = 0x02 + keyErrorMessage = 0x03 + keyErrorErrno = 0x04 + keyErrorErrcode = 0x05 + keyErrorFields = 0x06 +) + +// BoxError is a type representing Tarantool `box.error` object: a single +// MP_ERROR_STACK object with a link to the previous stack error. +// See https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_error/error/ +// +// Since 1.10.0 +type BoxError struct { + // Type is error type that implies its source (for example, "ClientError"). + Type string + // File is a source code file where the error was caught. + File string + // Line is a number of line in the source code file where the error was caught. + Line uint64 + // Msg is the text of reason. + Msg string + // Errno is the ordinal number of the error. + Errno uint64 + // Code is the number of the error as defined in `errcode.h`. + Code uint64 + // Fields are additional fields depending on error type. For example, if + // type is "AccessDeniedError", then it will include "object_type", + // "object_name", "access_type". + Fields map[string]interface{} + // Prev is the previous error in stack. + Prev *BoxError +} + +// Error converts a BoxError to a string. +func (e *BoxError) Error() string { + s := fmt.Sprintf("%s (%s, code 0x%x), see %s line %d", + e.Msg, e.Type, e.Code, e.File, e.Line) + + if e.Prev != nil { + return fmt.Sprintf("%s: %s", s, e.Prev) + } + + return s +} + +// Depth computes the count of errors in stack, including the current one. +func (e *BoxError) Depth() int { + depth := int(0) + + cur := e + for cur != nil { + cur = cur.Prev + depth++ + } + + return depth +} + +func decodeBoxError(d *msgpack.Decoder) (*BoxError, error) { + var l, larr, l1, l2 int + var errorStack []BoxError + var err error + + if l, err = d.DecodeMapLen(); err != nil { + return nil, err + } + + for ; l > 0; l-- { + var cd int + if cd, err = d.DecodeInt(); err != nil { + return nil, err + } + switch cd { + case keyErrorStack: + if larr, err = d.DecodeArrayLen(); err != nil { + return nil, err + } + + errorStack = make([]BoxError, larr) + + for i := 0; i < larr; i++ { + if l1, err = d.DecodeMapLen(); err != nil { + return nil, err + } + + for ; l1 > 0; l1-- { + var cd1 int + if cd1, err = d.DecodeInt(); err != nil { + return nil, err + } + switch cd1 { + case keyErrorType: + if errorStack[i].Type, err = d.DecodeString(); err != nil { + return nil, err + } + case keyErrorFile: + if errorStack[i].File, err = d.DecodeString(); err != nil { + return nil, err + } + case keyErrorLine: + if errorStack[i].Line, err = d.DecodeUint64(); err != nil { + return nil, err + } + case keyErrorMessage: + if errorStack[i].Msg, err = d.DecodeString(); err != nil { + return nil, err + } + case keyErrorErrno: + if errorStack[i].Errno, err = d.DecodeUint64(); err != nil { + return nil, err + } + case keyErrorErrcode: + if errorStack[i].Code, err = d.DecodeUint64(); err != nil { + return nil, err + } + case keyErrorFields: + var mapk string + var mapv interface{} + + errorStack[i].Fields = make(map[string]interface{}) + + if l2, err = d.DecodeMapLen(); err != nil { + return nil, err + } + for ; l2 > 0; l2-- { + if mapk, err = d.DecodeString(); err != nil { + return nil, err + } + if mapv, err = d.DecodeInterface(); err != nil { + return nil, err + } + errorStack[i].Fields[mapk] = mapv + } + default: + if err = d.Skip(); err != nil { + return nil, err + } + } + } + + if i > 0 { + errorStack[i-1].Prev = &errorStack[i] + } + } + default: + if err = d.Skip(); err != nil { + return nil, err + } + } + } + + if len(errorStack) == 0 { + return nil, fmt.Errorf("msgpack: unexpected empty BoxError stack on decode") + } + + return &errorStack[0], nil +} + +func encodeBoxError(enc *msgpack.Encoder, boxError *BoxError) error { + if boxError == nil { + return fmt.Errorf("msgpack: unexpected nil BoxError on encode") + } + + if err := enc.EncodeMapLen(1); err != nil { + return err + } + if err := enc.EncodeUint(keyErrorStack); err != nil { + return err + } + + var stackDepth = boxError.Depth() + if err := enc.EncodeArrayLen(stackDepth); err != nil { + return err + } + + for ; stackDepth > 0; stackDepth-- { + fieldsLen := len(boxError.Fields) + + if fieldsLen > 0 { + if err := enc.EncodeMapLen(7); err != nil { + return err + } + } else { + if err := enc.EncodeMapLen(6); err != nil { + return err + } + } + + if err := enc.EncodeUint(keyErrorType); err != nil { + return err + } + if err := enc.EncodeString(boxError.Type); err != nil { + return err + } + + if err := enc.EncodeUint(keyErrorFile); err != nil { + return err + } + if err := enc.EncodeString(boxError.File); err != nil { + return err + } + + if err := enc.EncodeUint(keyErrorLine); err != nil { + return err + } + if err := enc.EncodeUint64(boxError.Line); err != nil { + return err + } + + if err := enc.EncodeUint(keyErrorMessage); err != nil { + return err + } + if err := enc.EncodeString(boxError.Msg); err != nil { + return err + } + + if err := enc.EncodeUint(keyErrorErrno); err != nil { + return err + } + if err := enc.EncodeUint64(boxError.Errno); err != nil { + return err + } + + if err := enc.EncodeUint(keyErrorErrcode); err != nil { + return err + } + if err := enc.EncodeUint64(boxError.Code); err != nil { + return err + } + + if fieldsLen > 0 { + if err := enc.EncodeUint(keyErrorFields); err != nil { + return err + } + + if err := enc.EncodeMapLen(fieldsLen); err != nil { + return err + } + + for k, v := range boxError.Fields { + if err := enc.EncodeString(k); err != nil { + return err + } + if err := enc.Encode(v); err != nil { + return err + } + } + } + + if stackDepth > 1 { + boxError = boxError.Prev + } + } + + return nil +} + +// UnmarshalMsgpack deserializes a BoxError value from a MessagePack +// representation. +func (e *BoxError) UnmarshalMsgpack(b []byte) error { + if e == nil { + panic("cannot unmarshal to a nil pointer") + } + + buf := bytes.NewBuffer(b) + + dec := getDecoder(buf) + defer putDecoder(dec) + + if val, err := decodeBoxError(dec); err != nil { + return err + } else { + *e = *val + return nil + } +} + +// MarshalMsgpack serializes the BoxError into a MessagePack representation. +func (e *BoxError) MarshalMsgpack() ([]byte, error) { + var buf bytes.Buffer + + enc := msgpack.NewEncoder(&buf) + if err := encodeBoxError(enc, e); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +func init() { + msgpack.RegisterExt(errorExtID, (*BoxError)(nil)) +} diff --git a/boxerror_gen.go b/boxerror_gen.go new file mode 100644 index 000000000..07a1be695 --- /dev/null +++ b/boxerror_gen.go @@ -0,0 +1,241 @@ +// Code generated by github.com/tarantool/go-option; DO NOT EDIT. + +package tarantool + +import ( + "fmt" + + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" + + "github.com/tarantool/go-option" +) + +// OptionalBoxError represents an optional value of type BoxError. +// It can either hold a valid BoxError (IsSome == true) or be empty (IsZero == true). +type OptionalBoxError struct { + value BoxError + exists bool +} + +// SomeOptionalBoxError creates an optional OptionalBoxError with the given BoxError value. +// The returned OptionalBoxError will have IsSome() == true and IsZero() == false. +func SomeOptionalBoxError(value BoxError) OptionalBoxError { + return OptionalBoxError{ + value: value, + exists: true, + } +} + +// NoneOptionalBoxError creates an empty optional OptionalBoxError value. +// The returned OptionalBoxError will have IsSome() == false and IsZero() == true. +// +// Example: +// +// o := NoneOptionalBoxError() +// if o.IsZero() { +// fmt.Println("value is absent") +// } +func NoneOptionalBoxError() OptionalBoxError { + return OptionalBoxError{} +} + +func (o OptionalBoxError) newEncodeError(err error) error { + if err == nil { + return nil + } + return &option.EncodeError{ + Type: "OptionalBoxError", + Parent: err, + } +} + +func (o OptionalBoxError) newDecodeError(err error) error { + if err == nil { + return nil + } + + return &option.DecodeError{ + Type: "OptionalBoxError", + Parent: err, + } +} + +// IsSome returns true if the OptionalBoxError contains a value. +// This indicates the value is explicitly set (not None). +func (o OptionalBoxError) IsSome() bool { + return o.exists +} + +// IsZero returns true if the OptionalBoxError does not contain a value. +// Equivalent to !IsSome(). Useful for consistency with types where +// zero value (e.g. 0, false, zero struct) is valid and needs to be distinguished. +func (o OptionalBoxError) IsZero() bool { + return !o.exists +} + +// IsNil is an alias for IsZero. +// +// This method is provided for compatibility with the msgpack Encoder interface. +func (o OptionalBoxError) IsNil() bool { + return o.IsZero() +} + +// Get returns the stored value and a boolean flag indicating its presence. +// If the value is present, returns (value, true). +// If the value is absent, returns (zero value of BoxError, false). +// +// Recommended usage: +// +// if value, ok := o.Get(); ok { +// // use value +// } +func (o OptionalBoxError) Get() (BoxError, bool) { + return o.value, o.exists +} + +// MustGet returns the stored value if it is present. +// Panics if the value is absent (i.e., IsZero() == true). +// +// Use with caution — only when you are certain the value exists. +// +// Panics with: "optional value is not set" if no value is set. +func (o OptionalBoxError) MustGet() BoxError { + if !o.exists { + panic("optional value is not set") + } + + return o.value +} + +// Unwrap returns the stored value regardless of presence. +// If no value is set, returns the zero value for BoxError. +// +// Warning: Does not check presence. Use IsSome() before calling if you need +// to distinguish between absent value and explicit zero value. +func (o OptionalBoxError) Unwrap() BoxError { + return o.value +} + +// UnwrapOr returns the stored value if present. +// Otherwise, returns the provided default value. +// +// Example: +// +// o := NoneOptionalBoxError() +// v := o.UnwrapOr(someDefaultOptionalBoxError) +func (o OptionalBoxError) UnwrapOr(defaultValue BoxError) BoxError { + if o.exists { + return o.value + } + + return defaultValue +} + +// UnwrapOrElse returns the stored value if present. +// Otherwise, calls the provided function and returns its result. +// Useful when the default value requires computation or side effects. +// +// Example: +// +// o := NoneOptionalBoxError() +// v := o.UnwrapOrElse(func() BoxError { return computeDefault() }) +func (o OptionalBoxError) UnwrapOrElse(defaultValue func() BoxError) BoxError { + if o.exists { + return o.value + } + + return defaultValue() +} + +func (o OptionalBoxError) encodeValue(encoder *msgpack.Encoder) error { + value, err := o.value.MarshalMsgpack() + if err != nil { + return err + } + + err = encoder.EncodeExtHeader(3, len(value)) + if err != nil { + return err + } + + _, err = encoder.Writer().Write(value) + if err != nil { + return err + } + + return nil +} + +// EncodeMsgpack encodes the OptionalBoxError value using MessagePack format. +// - If the value is present, it is encoded as BoxError. +// - If the value is absent (None), it is encoded as nil. +// +// Returns an error if encoding fails. +func (o OptionalBoxError) EncodeMsgpack(encoder *msgpack.Encoder) error { + if o.exists { + return o.newEncodeError(o.encodeValue(encoder)) + } + + return o.newEncodeError(encoder.EncodeNil()) +} + +func (o *OptionalBoxError) decodeValue(decoder *msgpack.Decoder) error { + tp, length, err := decoder.DecodeExtHeader() + switch { + case err != nil: + return o.newDecodeError(err) + case tp != 3: + return o.newDecodeError(fmt.Errorf("invalid extension code: %d", tp)) + } + + a := make([]byte, length) + if err := decoder.ReadFull(a); err != nil { + return o.newDecodeError(err) + } + + if err := o.value.UnmarshalMsgpack(a); err != nil { + return o.newDecodeError(err) + } + + o.exists = true + return nil +} + +func (o *OptionalBoxError) checkCode(code byte) bool { + return msgpcode.IsExt(code) +} + +// DecodeMsgpack decodes a OptionalBoxError value from MessagePack format. +// Supports two input types: +// - nil: interpreted as no value (NoneOptionalBoxError) +// - BoxError: interpreted as a present value (SomeOptionalBoxError) +// +// Returns an error if the input type is unsupported or decoding fails. +// +// After successful decoding: +// - on nil: exists = false, value = default zero value +// - on BoxError: exists = true, value = decoded value +func (o *OptionalBoxError) DecodeMsgpack(decoder *msgpack.Decoder) error { + code, err := decoder.PeekCode() + if err != nil { + return o.newDecodeError(err) + } + + switch { + case code == msgpcode.Nil: + o.exists = false + + return o.newDecodeError(decoder.Skip()) + case o.checkCode(code): + err := o.decodeValue(decoder) + if err != nil { + return o.newDecodeError(err) + } + o.exists = true + + return err + default: + return o.newDecodeError(fmt.Errorf("unexpected code: %d", code)) + } +} diff --git a/boxerror_gen_test.go b/boxerror_gen_test.go new file mode 100644 index 000000000..4d7fada3a --- /dev/null +++ b/boxerror_gen_test.go @@ -0,0 +1,116 @@ +package tarantool + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/vmihailenco/msgpack/v5" +) + +func TestSomeOptionalBoxError(t *testing.T) { + val := BoxError{Code: 1, Msg: "error"} + opt := SomeOptionalBoxError(val) + + assert.True(t, opt.IsSome()) + assert.False(t, opt.IsZero()) + + v, ok := opt.Get() + assert.True(t, ok) + assert.Equal(t, val, v) +} + +func TestNoneOptionalBoxError(t *testing.T) { + opt := NoneOptionalBoxError() + + assert.False(t, opt.IsSome()) + assert.True(t, opt.IsZero()) + + _, ok := opt.Get() + assert.False(t, ok) +} + +func TestOptionalBoxError_MustGet(t *testing.T) { + val := BoxError{Code: 1, Msg: "error"} + optSome := SomeOptionalBoxError(val) + optNone := NoneOptionalBoxError() + + assert.Equal(t, val, optSome.MustGet()) + assert.Panics(t, func() { optNone.MustGet() }) +} + +func TestOptionalBoxError_Unwrap(t *testing.T) { + val := BoxError{Code: 1, Msg: "error"} + optSome := SomeOptionalBoxError(val) + optNone := NoneOptionalBoxError() + + assert.Equal(t, val, optSome.Unwrap()) + assert.Equal(t, BoxError{}, optNone.Unwrap()) +} + +func TestOptionalBoxError_UnwrapOr(t *testing.T) { + val := BoxError{Code: 1, Msg: "error"} + def := BoxError{Code: 2, Msg: "default"} + optSome := SomeOptionalBoxError(val) + optNone := NoneOptionalBoxError() + + assert.Equal(t, val, optSome.UnwrapOr(def)) + assert.Equal(t, def, optNone.UnwrapOr(def)) +} + +func TestOptionalBoxError_UnwrapOrElse(t *testing.T) { + val := BoxError{Code: 1, Msg: "error"} + def := BoxError{Code: 2, Msg: "default"} + optSome := SomeOptionalBoxError(val) + optNone := NoneOptionalBoxError() + + assert.Equal(t, val, optSome.UnwrapOrElse(func() BoxError { return def })) + assert.Equal(t, def, optNone.UnwrapOrElse(func() BoxError { return def })) +} + +func TestOptionalBoxError_EncodeDecodeMsgpack_Some(t *testing.T) { + val := BoxError{Code: 1, Msg: "error"} + some := SomeOptionalBoxError(val) + + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err := enc.Encode(some) + assert.NoError(t, err) + + var decodedSome OptionalBoxError + err = dec.Decode(&decodedSome) + assert.NoError(t, err) + assert.True(t, decodedSome.IsSome()) + assert.Equal(t, val, decodedSome.Unwrap()) +} + +func TestOptionalBoxError_EncodeDecodeMsgpack_None(t *testing.T) { + none := NoneOptionalBoxError() + + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err := enc.Encode(none) + assert.NoError(t, err) + + var decodedNone OptionalBoxError + err = dec.Decode(&decodedNone) + assert.NoError(t, err) + assert.True(t, decodedNone.IsZero()) +} + +func TestOptionalBoxError_EncodeDecodeMsgpack_InvalidType(t *testing.T) { + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err := enc.Encode(123) + assert.NoError(t, err) + + var decodedInvalid OptionalBoxError + err = dec.Decode(&decodedInvalid) + assert.Error(t, err) +} diff --git a/boxerror_test.go b/boxerror_test.go new file mode 100644 index 000000000..acb051d31 --- /dev/null +++ b/boxerror_test.go @@ -0,0 +1,499 @@ +package tarantool_test + +import ( + "fmt" + "regexp" + "testing" + + "github.com/stretchr/testify/require" + "github.com/vmihailenco/msgpack/v5" + + . "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +var samples = map[string]BoxError{ + "SimpleError": { + Type: "ClientError", + File: "config.lua", + Line: uint64(202), + Msg: "Unknown error", + Errno: uint64(0), + Code: uint64(0), + }, + "AccessDeniedError": { + Type: "AccessDeniedError", + File: "/__w/sdk/sdk/tarantool-2.10/tarantool/src/box/func.c", + Line: uint64(535), + Msg: "Execute access to function 'forbidden_function' is denied for user 'no_grants'", + Errno: uint64(0), + Code: uint64(42), + Fields: map[string]interface{}{ + "object_type": "function", + "object_name": "forbidden_function", + "access_type": "Execute", + }, + }, + "ChainedError": { + Type: "ClientError", + File: "config.lua", + Line: uint64(205), + Msg: "Timeout exceeded", + Errno: uint64(0), + Code: uint64(78), + Prev: &BoxError{ + Type: "ClientError", + File: "config.lua", + Line: uint64(202), + Msg: "Unknown error", + Errno: uint64(0), + Code: uint64(0), + }, + }, +} + +var stringCases = map[string]struct { + e BoxError + s string +}{ + "SimpleError": { + samples["SimpleError"], + "Unknown error (ClientError, code 0x0), see config.lua line 202", + }, + "AccessDeniedError": { + samples["AccessDeniedError"], + "Execute access to function 'forbidden_function' is denied for user " + + "'no_grants' (AccessDeniedError, code 0x2a), see " + + "/__w/sdk/sdk/tarantool-2.10/tarantool/src/box/func.c line 535", + }, + "ChainedError": { + samples["ChainedError"], + "Timeout exceeded (ClientError, code 0x4e), see config.lua line 205: " + + "Unknown error (ClientError, code 0x0), see config.lua line 202", + }, +} + +func TestBoxErrorStringRepr(t *testing.T) { + for name, testcase := range stringCases { + t.Run(name, func(t *testing.T) { + require.Equal(t, testcase.s, testcase.e.Error()) + }) + } +} + +var mpDecodeSamples = map[string]struct { + b []byte + ok bool + err *regexp.Regexp +}{ + "OuterMapInvalidLen": { + []byte{0xc1}, + false, + regexp.MustCompile(`msgpack: (?:unexpected|invalid|unknown) code.c1 decoding map length`), + }, + "OuterMapInvalidKey": { + []byte{0x81, 0xc1}, + false, + regexp.MustCompile(`msgpack: (?:unexpected|invalid|unknown) code.c1 decoding int64`), + }, + "OuterMapExtraKey": { + []byte{0x82, 0x00, 0x91, 0x81, 0x02, 0x01, 0x11, 0x00}, + true, + regexp.MustCompile(``), + }, + "OuterMapExtraInvalidKey": { + []byte{0x81, 0x11, 0x81}, + false, + regexp.MustCompile(`EOF`), + }, + "ArrayInvalidLen": { + []byte{0x81, 0x00, 0xc1}, + false, + regexp.MustCompile(`msgpack: (?:unexpected|invalid|unknown) code.c1 decoding array length`), + }, + "ArrayZeroLen": { + []byte{0x81, 0x00, 0x90}, + false, + regexp.MustCompile(`msgpack: unexpected empty BoxError stack on decode`), + }, + "InnerMapInvalidLen": { + []byte{0x81, 0x00, 0x91, 0xc1}, + false, + regexp.MustCompile(`msgpack: (?:unexpected|invalid|unknown) code.c1 decoding map length`), + }, + "InnerMapInvalidKey": { + []byte{0x81, 0x00, 0x91, 0x81, 0xc1}, + false, + regexp.MustCompile(`msgpack: (?:unexpected|invalid|unknown) code.c1 decoding int64`), + }, + "InnerMapInvalidErrorType": { + []byte{0x81, 0x00, 0x91, 0x81, 0x00, 0xc1}, + false, + regexp.MustCompile(`msgpack: (?:unexpected|invalid|unknown) code.c1` + + ` decoding (?:string\/bytes|bytes) length`), + }, + "InnerMapInvalidErrorFile": { + []byte{0x81, 0x00, 0x91, 0x81, 0x01, 0xc1}, + false, + regexp.MustCompile(`msgpack: (?:unexpected|invalid|unknown) code.c1` + + ` decoding (?:string\/bytes|bytes) length`), + }, + "InnerMapInvalidErrorLine": { + []byte{0x81, 0x00, 0x91, 0x81, 0x02, 0xc1}, + false, + regexp.MustCompile(`msgpack: (?:unexpected|invalid|unknown) code.c1 decoding uint64`), + }, + "InnerMapInvalidErrorMessage": { + []byte{0x81, 0x00, 0x91, 0x81, 0x03, 0xc1}, + false, + regexp.MustCompile(`msgpack: (?:unexpected|invalid|unknown) code.c1 decoding` + + ` (?:string\/bytes|bytes) length`), + }, + "InnerMapInvalidErrorErrno": { + []byte{0x81, 0x00, 0x91, 0x81, 0x04, 0xc1}, + false, + regexp.MustCompile(`msgpack: (?:unexpected|invalid|unknown) code.c1 decoding uint64`), + }, + "InnerMapInvalidErrorErrcode": { + []byte{0x81, 0x00, 0x91, 0x81, 0x05, 0xc1}, + false, + regexp.MustCompile(`msgpack: (?:unexpected|invalid|unknown) code.c1 decoding uint64`), + }, + "InnerMapInvalidErrorFields": { + []byte{0x81, 0x00, 0x91, 0x81, 0x06, 0xc1}, + false, + regexp.MustCompile(`msgpack: (?:unexpected|invalid|unknown) code.c1 decoding map length`), + }, + "InnerMapInvalidErrorFieldsKey": { + []byte{0x81, 0x00, 0x91, 0x81, 0x06, 0x81, 0xc1}, + false, + regexp.MustCompile(`msgpack: (?:unexpected|invalid|unknown) code.c1` + + ` decoding (?:string\/bytes|bytes) length`), + }, + "InnerMapInvalidErrorFieldsValue": { + []byte{0x81, 0x00, 0x91, 0x81, 0x06, 0x81, 0xa3, 0x6b, 0x65, 0x79, 0xc1}, + false, + regexp.MustCompile(`msgpack: (?:unexpected|invalid|unknown) code.c1 decoding interface{}`), + }, + "InnerMapExtraKey": { + []byte{0x81, 0x00, 0x91, 0x81, 0x21, 0x00}, + true, + regexp.MustCompile(``), + }, + "InnerMapExtraInvalidKey": { + []byte{0x81, 0x00, 0x91, 0x81, 0x21, 0x81}, + false, + regexp.MustCompile(`EOF`), + }, +} + +func TestMessagePackDecode(t *testing.T) { + for name, testcase := range mpDecodeSamples { + t.Run(name, func(t *testing.T) { + var val = &BoxError{} + err := val.UnmarshalMsgpack(testcase.b) + if testcase.ok { + require.Nilf(t, err, "No errors on decode") + } else { + require.Regexp(t, testcase.err, err.Error()) + } + }) + } +} + +func TestMessagePackUnmarshalToNil(t *testing.T) { + var val *BoxError = nil + require.PanicsWithValue(t, "cannot unmarshal to a nil pointer", + func() { val.UnmarshalMsgpack(mpDecodeSamples["InnerMapExtraKey"].b) }) +} + +func TestMessagePackEncodeNil(t *testing.T) { + var val *BoxError + + _, err := val.MarshalMsgpack() + require.NotNil(t, err) + require.Equal(t, "msgpack: unexpected nil BoxError on encode", err.Error()) +} + +var space = "test_error_type" +var index = "primary" + +type TupleBoxError struct { + pk string // BoxError cannot be used as a primary key. + val BoxError +} + +func (t *TupleBoxError) EncodeMsgpack(e *msgpack.Encoder) error { + if err := e.EncodeArrayLen(2); err != nil { + return err + } + + if err := e.EncodeString(t.pk); err != nil { + return err + } + + return e.Encode(&t.val) +} + +func (t *TupleBoxError) DecodeMsgpack(d *msgpack.Decoder) error { + var err error + var l int + if l, err = d.DecodeArrayLen(); err != nil { + return err + } + if l != 2 { + return fmt.Errorf("Array length doesn't match: %d", l) + } + + if t.pk, err = d.DecodeString(); err != nil { + return err + } + + return d.Decode(&t.val) +} + +// Raw bytes encoding test is impossible for +// object with Fields since map iterating is random. +var tupleCases = map[string]struct { + tuple TupleBoxError + ttObj string +}{ + "SimpleError": { + TupleBoxError{ + "simple_error_pk", + samples["SimpleError"], + }, + "simple_error", + }, + "AccessDeniedError": { + TupleBoxError{ + "access_denied_error_pk", + samples["AccessDeniedError"], + }, + "access_denied_error", + }, + "ChainedError": { + TupleBoxError{ + "chained_error_pk", + samples["ChainedError"], + }, + "chained_error", + }, +} + +func TestErrorTypeMPEncodeDecode(t *testing.T) { + for name, testcase := range tupleCases { + t.Run(name, func(t *testing.T) { + buf, err := msgpack.Marshal(&testcase.tuple) + require.Nil(t, err) + + var res TupleBoxError + err = msgpack.Unmarshal(buf, &res) + require.Nil(t, err) + + require.Equal(t, testcase.tuple, res) + }) + } +} + +func TestErrorTypeEval(t *testing.T) { + test_helpers.SkipIfErrorMessagePackTypeUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + for name, testcase := range tupleCases { + t.Run(name, func(t *testing.T) { + data, err := conn.Eval("return ...", []interface{}{&testcase.tuple.val}) + require.Nil(t, err) + require.NotNil(t, data) + require.Equal(t, len(data), 1) + actual, ok := data[0].(*BoxError) + require.Truef(t, ok, "Response data has valid type") + require.Equal(t, testcase.tuple.val, *actual) + }) + } +} + +func TestErrorTypeEvalTyped(t *testing.T) { + test_helpers.SkipIfErrorMessagePackTypeUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + for name, testcase := range tupleCases { + t.Run(name, func(t *testing.T) { + var res []BoxError + err := conn.EvalTyped("return ...", []interface{}{&testcase.tuple.val}, &res) + require.Nil(t, err) + require.NotNil(t, res) + require.Equal(t, len(res), 1) + require.Equal(t, testcase.tuple.val, res[0]) + }) + } +} + +func TestErrorTypeInsert(t *testing.T) { + test_helpers.SkipIfErrorMessagePackTypeUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + truncateEval := fmt.Sprintf("box.space[%q]:truncate()", space) + _, err := conn.Eval(truncateEval, []interface{}{}) + require.Nil(t, err) + + for name, testcase := range tupleCases { + t.Run(name, func(t *testing.T) { + _, err = conn.Insert(space, &testcase.tuple) + require.Nil(t, err) + + checkEval := fmt.Sprintf(` + local err = rawget(_G, %q) + assert(err ~= nil) + + local tuple = box.space[%q]:get(%q) + assert(tuple ~= nil) + + local tuple_err = tuple[2] + assert(tuple_err ~= nil) + + return compare_box_errors(tuple_err, err) + `, testcase.ttObj, space, testcase.tuple.pk) + + // In fact, compare_box_errors does not check than File and Line + // of connector BoxError are equal to the Tarantool ones + // since they may differ between different Tarantool versions + // and editions. + _, err := conn.Eval(checkEval, []interface{}{}) + require.Nilf(t, err, "Tuple has been successfully inserted") + }) + } +} + +func TestErrorTypeInsertTyped(t *testing.T) { + test_helpers.SkipIfErrorMessagePackTypeUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + truncateEval := fmt.Sprintf("box.space[%q]:truncate()", space) + _, err := conn.Eval(truncateEval, []interface{}{}) + require.Nil(t, err) + + for name, testcase := range tupleCases { + t.Run(name, func(t *testing.T) { + var res []TupleBoxError + err = conn.InsertTyped(space, &testcase.tuple, &res) + require.Nil(t, err) + require.NotNil(t, res) + require.Equal(t, len(res), 1) + require.Equal(t, testcase.tuple, res[0]) + + checkEval := fmt.Sprintf(` + local err = rawget(_G, %q) + assert(err ~= nil) + + local tuple = box.space[%q]:get(%q) + assert(tuple ~= nil) + + local tuple_err = tuple[2] + assert(tuple_err ~= nil) + + return compare_box_errors(tuple_err, err) + `, testcase.ttObj, space, testcase.tuple.pk) + + // In fact, compare_box_errors does not check than File and Line + // of connector BoxError are equal to the Tarantool ones + // since they may differ between different Tarantool versions + // and editions. + _, err := conn.Eval(checkEval, []interface{}{}) + require.Nilf(t, err, "Tuple has been successfully inserted") + }) + } +} + +func TestErrorTypeSelect(t *testing.T) { + test_helpers.SkipIfErrorMessagePackTypeUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + truncateEval := fmt.Sprintf("box.space[%q]:truncate()", space) + _, err := conn.Eval(truncateEval, []interface{}{}) + require.Nil(t, err) + + for name, testcase := range tupleCases { + t.Run(name, func(t *testing.T) { + insertEval := fmt.Sprintf(` + local err = rawget(_G, %q) + assert(err ~= nil) + + local tuple = box.space[%q]:insert{%q, err} + assert(tuple ~= nil) + `, testcase.ttObj, space, testcase.tuple.pk) + + _, err := conn.Eval(insertEval, []interface{}{}) + require.Nilf(t, err, "Tuple has been successfully inserted") + + var offset uint32 = 0 + var limit uint32 = 1 + data, err := conn.Select(space, index, offset, limit, IterEq, + []interface{}{testcase.tuple.pk}) + require.Nil(t, err) + require.NotNil(t, data) + require.Equalf(t, len(data), 1, "Exactly one tuple had been found") + tpl, ok := data[0].([]interface{}) + require.Truef(t, ok, "Tuple has valid type") + require.Equal(t, testcase.tuple.pk, tpl[0]) + actual, ok := tpl[1].(*BoxError) + require.Truef(t, ok, "BoxError tuple field has valid type") + // In fact, CheckEqualBoxErrors does not check than File and Line + // of connector BoxError are equal to the Tarantool ones + // since they may differ between different Tarantool versions + // and editions. + test_helpers.CheckEqualBoxErrors(t, testcase.tuple.val, *actual) + }) + } +} + +func TestErrorTypeSelectTyped(t *testing.T) { + test_helpers.SkipIfErrorMessagePackTypeUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + truncateEval := fmt.Sprintf("box.space[%q]:truncate()", space) + _, err := conn.Eval(truncateEval, []interface{}{}) + require.Nil(t, err) + + for name, testcase := range tupleCases { + t.Run(name, func(t *testing.T) { + insertEval := fmt.Sprintf(` + local err = rawget(_G, %q) + assert(err ~= nil) + + local tuple = box.space[%q]:insert{%q, err} + assert(tuple ~= nil) + `, testcase.ttObj, space, testcase.tuple.pk) + + _, err := conn.Eval(insertEval, []interface{}{}) + require.Nilf(t, err, "Tuple has been successfully inserted") + + var offset uint32 = 0 + var limit uint32 = 1 + var resp []TupleBoxError + err = conn.SelectTyped(space, index, offset, limit, IterEq, + []interface{}{testcase.tuple.pk}, &resp) + require.Nil(t, err) + require.NotNil(t, resp) + require.Equalf(t, len(resp), 1, "Exactly one tuple had been found") + require.Equal(t, testcase.tuple.pk, resp[0].pk) + // In fact, CheckEqualBoxErrors does not check than File and Line + // of connector BoxError are equal to the Tarantool ones + // since they may differ between different Tarantool versions + // and editions. + test_helpers.CheckEqualBoxErrors(t, testcase.tuple.val, resp[0].val) + }) + } +} diff --git a/client_tools.go b/client_tools.go new file mode 100644 index 000000000..351b07cae --- /dev/null +++ b/client_tools.go @@ -0,0 +1,177 @@ +package tarantool + +import ( + "github.com/vmihailenco/msgpack/v5" +) + +// IntKey is utility type for passing integer key to Select*, Update*, +// Delete* and GetTyped. It serializes to array with single integer element. +type IntKey struct { + I int +} + +func (k IntKey) EncodeMsgpack(enc *msgpack.Encoder) error { + enc.EncodeArrayLen(1) + enc.EncodeInt(int64(k.I)) + return nil +} + +// UintKey is utility type for passing unsigned integer key to Select*, +// Update*, Delete* and GetTyped. It serializes to array with single unsigned +// integer element. +type UintKey struct { + I uint +} + +func (k UintKey) EncodeMsgpack(enc *msgpack.Encoder) error { + enc.EncodeArrayLen(1) + enc.EncodeUint(uint64(k.I)) + return nil +} + +// StringKey is utility type for passing string key to Select*, Update*, +// Delete* and GetTyped. It serializes to array with single string element. +type StringKey struct { + S string +} + +func (k StringKey) EncodeMsgpack(enc *msgpack.Encoder) error { + enc.EncodeArrayLen(1) + enc.EncodeString(k.S) + return nil +} + +// IntIntKey is utility type for passing two integer keys to Select*, Update*, +// Delete* and GetTyped. It serializes to array with two integer elements. +type IntIntKey struct { + I1, I2 int +} + +func (k IntIntKey) EncodeMsgpack(enc *msgpack.Encoder) error { + enc.EncodeArrayLen(2) + enc.EncodeInt(int64(k.I1)) + enc.EncodeInt(int64(k.I2)) + return nil +} + +// operation - is update operation. +type operation struct { + Op string + Field int + Arg interface{} + // Pos, Len, Replace fields used in the Splice operation. + Pos int + Len int + Replace string +} + +func (o operation) EncodeMsgpack(enc *msgpack.Encoder) error { + isSpliceOperation := o.Op == spliceOperator + argsLen := 3 + if isSpliceOperation { + argsLen = 5 + } + if err := enc.EncodeArrayLen(argsLen); err != nil { + return err + } + if err := enc.EncodeString(o.Op); err != nil { + return err + } + if err := enc.EncodeInt(int64(o.Field)); err != nil { + return err + } + + if isSpliceOperation { + if err := enc.EncodeInt(int64(o.Pos)); err != nil { + return err + } + if err := enc.EncodeInt(int64(o.Len)); err != nil { + return err + } + return enc.EncodeString(o.Replace) + } + + return enc.Encode(o.Arg) +} + +const ( + appendOperator = "+" + subtractionOperator = "-" + bitwiseAndOperator = "&" + bitwiseOrOperator = "|" + bitwiseXorOperator = "^" + spliceOperator = ":" + insertOperator = "!" + deleteOperator = "#" + assignOperator = "=" +) + +// Operations is a collection of update operations. +type Operations struct { + ops []operation +} + +// EncodeMsgpack encodes Operations as an array of operations. +func (ops *Operations) EncodeMsgpack(enc *msgpack.Encoder) error { + return enc.Encode(ops.ops) +} + +// NewOperations returns a new empty collection of update operations. +func NewOperations() *Operations { + return &Operations{[]operation{}} +} + +func (ops *Operations) append(op string, field int, arg interface{}) *Operations { + ops.ops = append(ops.ops, operation{Op: op, Field: field, Arg: arg}) + return ops +} + +func (ops *Operations) appendSplice(op string, field, pos, len int, replace string) *Operations { + ops.ops = append(ops.ops, operation{Op: op, Field: field, Pos: pos, Len: len, Replace: replace}) + return ops +} + +// Add adds an additional operation to the collection of update operations. +func (ops *Operations) Add(field int, arg interface{}) *Operations { + return ops.append(appendOperator, field, arg) +} + +// Subtract adds a subtraction operation to the collection of update operations. +func (ops *Operations) Subtract(field int, arg interface{}) *Operations { + return ops.append(subtractionOperator, field, arg) +} + +// BitwiseAnd adds a bitwise AND operation to the collection of update operations. +func (ops *Operations) BitwiseAnd(field int, arg interface{}) *Operations { + return ops.append(bitwiseAndOperator, field, arg) +} + +// BitwiseOr adds a bitwise OR operation to the collection of update operations. +func (ops *Operations) BitwiseOr(field int, arg interface{}) *Operations { + return ops.append(bitwiseOrOperator, field, arg) +} + +// BitwiseXor adds a bitwise XOR operation to the collection of update operations. +func (ops *Operations) BitwiseXor(field int, arg interface{}) *Operations { + return ops.append(bitwiseXorOperator, field, arg) +} + +// Splice adds a splice operation to the collection of update operations. +func (ops *Operations) Splice(field, pos, len int, replace string) *Operations { + return ops.appendSplice(spliceOperator, field, pos, len, replace) +} + +// Insert adds an insert operation to the collection of update operations. +func (ops *Operations) Insert(field int, arg interface{}) *Operations { + return ops.append(insertOperator, field, arg) +} + +// Delete adds a delete operation to the collection of update operations. +func (ops *Operations) Delete(field int, arg interface{}) *Operations { + return ops.append(deleteOperator, field, arg) +} + +// Assign adds an assign operation to the collection of update operations. +func (ops *Operations) Assign(field int, arg interface{}) *Operations { + return ops.append(assignOperator, field, arg) +} diff --git a/client_tools_test.go b/client_tools_test.go new file mode 100644 index 000000000..fc911acf9 --- /dev/null +++ b/client_tools_test.go @@ -0,0 +1,51 @@ +package tarantool_test + +import ( + "bytes" + "testing" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +func TestOperations_EncodeMsgpack(t *testing.T) { + ops := tarantool.NewOperations(). + Add(1, 2). + Subtract(1, 2). + BitwiseAnd(1, 2). + BitwiseOr(1, 2). + BitwiseXor(1, 2). + Splice(1, 2, 3, "a"). + Insert(1, 2). + Delete(1, 2). + Assign(1, 2) + refOps := []interface{}{ + []interface{}{"+", 1, 2}, + []interface{}{"-", 1, 2}, + []interface{}{"&", 1, 2}, + []interface{}{"|", 1, 2}, + []interface{}{"^", 1, 2}, + []interface{}{":", 1, 2, 3, "a"}, + []interface{}{"!", 1, 2}, + []interface{}{"#", 1, 2}, + []interface{}{"=", 1, 2}, + } + + var refBuf bytes.Buffer + encRef := msgpack.NewEncoder(&refBuf) + if err := encRef.Encode(refOps); err != nil { + t.Errorf("error while encoding: %v", err.Error()) + } + + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + + if err := enc.Encode(ops); err != nil { + t.Errorf("error while encoding: %v", err.Error()) + } + if !bytes.Equal(refBuf.Bytes(), buf.Bytes()) { + t.Errorf("encode response is wrong:\n expected %v\n got: %v", + refBuf, buf.Bytes()) + } +} diff --git a/config.lua b/config.lua new file mode 100644 index 000000000..f0bf05b65 --- /dev/null +++ b/config.lua @@ -0,0 +1,295 @@ +-- Do not set listen for now so connector won't be +-- able to send requests until everything is configured. +local auth_type = os.getenv("TEST_TNT_AUTH_TYPE") +if auth_type == "auto" then + auth_type = nil +end + +box.cfg{ + auth_type = auth_type, + work_dir = os.getenv("TEST_TNT_WORK_DIR"), + memtx_use_mvcc_engine = os.getenv("TEST_TNT_MEMTX_USE_MVCC_ENGINE") == 'true' or nil, +} + +box.once("init", function() + local st = box.schema.space.create('schematest', { + id = 616, + temporary = true, + if_not_exists = true, + field_count = 8, + format = { + {name = "name0", type = "unsigned"}, + {name = "name1", type = "unsigned"}, + {name = "name2", type = "string"}, + {name = "name3", type = "unsigned"}, + {name = "name4", type = "unsigned"}, + {name = "name5", type = "string"}, + {name = "nullable", is_nullable = true}, + }, + }) + st:create_index('primary', { + type = 'hash', + parts = {1, 'uint'}, + unique = true, + if_not_exists = true, + }) + st:create_index('secondary', { + id = 3, + type = 'tree', + unique = false, + parts = { 2, 'uint', 3, 'string' }, + if_not_exists = true, + }) + st:truncate() + + local s = box.schema.space.create('test', { + id = 617, + if_not_exists = true, + }) + s:create_index('primary', { + type = 'tree', + parts = {1, 'uint'}, + if_not_exists = true + }) + + local s = box.schema.space.create('teststring', { + id = 618, + if_not_exists = true, + }) + s:create_index('primary', { + type = 'tree', + parts = {1, 'string'}, + if_not_exists = true + }) + + local s = box.schema.space.create('testintint', { + id = 619, + if_not_exists = true, + }) + s:create_index('primary', { + type = 'tree', + parts = {1, 'int', 2, 'int'}, + if_not_exists = true + }) + + local s = box.schema.space.create('SQL_TEST', { + id = 620, + if_not_exists = true, + format = { + {name = "NAME0", type = "unsigned"}, + {name = "NAME1", type = "string"}, + {name = "NAME2", type = "string"}, + } + }) + s:create_index('primary', { + type = 'tree', + parts = {1, 'uint'}, + if_not_exists = true + }) + s:insert{1, "test", "test"} + + local s = box.schema.space.create('test_perf', { + id = 621, + temporary = true, + if_not_exists = true, + field_count = 3, + format = { + {name = "id", type = "unsigned"}, + {name = "name", type = "string"}, + {name = "arr1", type = "array"}, + }, + }) + s:create_index('primary', { + type = 'tree', + unique = true, + parts = {1, 'unsigned'}, + if_not_exists = true + }) + s:create_index('secondary', { + id = 5, type = 'tree', + unique = false, + parts = {2, 'string'}, + if_not_exists = true + }) + local arr_data = {} + for i = 1,100 do + arr_data[i] = i + end + for i = 1,1000 do + s:insert{ + i, + 'test_name', + arr_data, + } + end + + local s = box.schema.space.create('test_error_type', { + id = 622, + temporary = true, + if_not_exists = true, + field_count = 2, + -- You can't specify box.error as format type, + -- but can put box.error objects. + }) + s:create_index('primary', { + type = 'tree', + unique = true, + parts = {1, 'string'}, + if_not_exists = true + }) + + --box.schema.user.grant('guest', 'read,write,execute', 'universe') + box.schema.func.create('box.info') + box.schema.func.create('simple_concat') + + -- auth testing: access control + box.schema.user.create('test', {password = 'test'}) + box.schema.user.grant('test', 'execute', 'universe') + box.schema.user.grant('test', 'read,write', 'space', 'test') + box.schema.user.grant('test', 'read,write', 'space', 'schematest') + box.schema.user.grant('test', 'read,write', 'space', 'test_perf') + box.schema.user.grant('test', 'read,write', 'space', 'test_error_type') + + -- grants for sql tests + box.schema.user.grant('test', 'create,read,write,drop,alter', 'space') + box.schema.user.grant('test', 'create', 'sequence') + + box.schema.user.create('no_grants') +end) + +local function func_name() + return { + {221, "", { + {"Moscow", 34}, + {"Minsk", 23}, + {"Kiev", 31}, + } + } + } +end +rawset(_G, 'func_name', func_name) + +local function simple_concat(a) + return a .. a +end +rawset(_G, 'simple_concat', simple_concat) + +local function push_func(cnt) + for i = 1, cnt do + box.session.push(i) + end + return cnt +end +rawset(_G, 'push_func', push_func) + +local function create_spaces() + for i=1,10 do + local s = box.schema.space.create('test' .. tostring(i), { + id = 700 + i, + if_not_exists = true, + }) + local idx = s:create_index('test' .. tostring(i) .. 'primary', { + type = 'tree', + parts = {1, 'uint'}, + if_not_exists = true + }) + idx:drop() + s:drop() + end +end +rawset(_G, 'create_spaces', create_spaces) + +local function tarantool_version_at_least(wanted_major, wanted_minor, wanted_patch) + -- https://github.com/tarantool/crud/blob/733528be02c1ffa3dacc12c034ee58c9903127fc/test/helper.lua#L316-L337 + local major_minor_patch = _TARANTOOL:split('-', 1)[1] + local major_minor_patch_parts = major_minor_patch:split('.', 2) + + local major = tonumber(major_minor_patch_parts[1]) + local minor = tonumber(major_minor_patch_parts[2]) + local patch = tonumber(major_minor_patch_parts[3]) + + if major < (wanted_major or 0) then return false end + if major > (wanted_major or 0) then return true end + + if minor < (wanted_minor or 0) then return false end + if minor > (wanted_minor or 0) then return true end + + if patch < (wanted_patch or 0) then return false end + if patch > (wanted_patch or 0) then return true end + + return true +end + +if tarantool_version_at_least(2, 4, 1) then + local e1 = box.error.new(box.error.UNKNOWN) + rawset(_G, 'simple_error', e1) + + local e2 = box.error.new(box.error.TIMEOUT) + e2:set_prev(e1) + rawset(_G, 'chained_error', e2) + + local user = box.session.user() + box.schema.func.create('forbidden_function', {body = 'function() end'}) + box.session.su('no_grants') + local _, access_denied_error = pcall(function() box.func.forbidden_function:call() end) + box.session.su(user) + rawset(_G, 'access_denied_error', access_denied_error) + + -- cdata structure is as follows: + -- + -- tarantool> err:unpack() + -- - code: val + -- base_type: val + -- type: val + -- message: val + -- field1: val + -- field2: val + -- trace: + -- - file: val + -- line: val + + local function compare_box_error_attributes(expected, actual) + for attr, _ in pairs(expected:unpack()) do + if (attr ~= 'prev') and (attr ~= 'trace') then + if expected[attr] ~= actual[attr] then + error(('%s expected %s is not equal to actual %s'):format( + attr, expected[attr], actual[attr])) + end + end + end + end + + local function compare_box_errors(expected, actual) + if (expected == nil) and (actual ~= nil) then + error(('Expected error stack is empty, but actual error ' .. + 'has previous %s (%s) error'):format( + actual.type, actual.message)) + end + + if (expected ~= nil) and (actual == nil) then + error(('Actual error stack is empty, but expected error ' .. + 'has previous %s (%s) error'):format( + expected.type, expected.message)) + end + + compare_box_error_attributes(expected, actual) + + if (expected.prev ~= nil) or (actual.prev ~= nil) then + return compare_box_errors(expected.prev, actual.prev) + end + + return true + end + + rawset(_G, 'compare_box_errors', compare_box_errors) +end + +box.space.test:truncate() + +--box.schema.user.revoke('guest', 'read,write,execute', 'universe') + +-- Set listen only when every other thing is configured. +box.cfg{ + auth_type = auth_type, + listen = os.getenv("TEST_TNT_LISTEN"), +} diff --git a/connection.go b/connection.go index 08d4fa6be..62829772c 100644 --- a/connection.go +++ b/connection.go @@ -1,140 +1,1601 @@ +// Package with implementation of methods and structures for work with +// Tarantool instance. package tarantool import ( - "net" + "context" + "encoding/binary" + "errors" "fmt" - "github.com/vmihailenco/msgpack" - "sync/atomic" - "bytes" + "io" + "log" + "math" + "net" + "runtime" "sync" + "sync/atomic" + "time" + + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" +) + +const requestsMap = 128 +const ignoreStreamId = 0 +const ( + connDisconnected = 0 + connConnected = 1 + connShutdown = 2 + connClosed = 3 ) +const shutdownEventKey = "box.shutdown" + +type ConnEventKind int +type ConnLogKind int + +var ( + errUnknownRequest = errors.New("the passed connected request doesn't belong " + + "to the current connection or connection pool") +) + +const ( + // Connected signals that connection is established or reestablished. + Connected ConnEventKind = iota + 1 + // Disconnected signals that connection is broken. + Disconnected + // ReconnectFailed signals that attempt to reconnect has failed. + ReconnectFailed + // Shutdown signals that shutdown callback is processing. + Shutdown + // Either reconnect attempts exhausted, or explicit Close is called. + Closed + + // LogReconnectFailed is logged when reconnect attempt failed. + LogReconnectFailed ConnLogKind = iota + 1 + // LogLastReconnectFailed is logged when last reconnect attempt failed, + // connection will be closed after that. + LogLastReconnectFailed + // LogUnexpectedResultId is logged when response with unknown id was received. + // Most probably it is due to request timeout. + LogUnexpectedResultId + // LogWatchEventReadFailed is logged when failed to read a watch event. + LogWatchEventReadFailed + // LogBoxSessionPushUnsupported is logged when response type turned IPROTO_CHUNK. + LogBoxSessionPushUnsupported +) + +// ConnEvent is sent throw Notify channel specified in Opts. +type ConnEvent struct { + Conn *Connection + Kind ConnEventKind + When time.Time +} + +// A raw watch event. +type connWatchEvent struct { + key string + value interface{} +} + +var epoch = time.Now() + +// Logger is logger type expected to be passed in options. +type Logger interface { + Report(event ConnLogKind, conn *Connection, v ...interface{}) +} + +type defaultLogger struct{} + +func (d defaultLogger) Report(event ConnLogKind, conn *Connection, v ...interface{}) { + switch event { + case LogReconnectFailed: + reconnects := v[0].(uint) + err := v[1].(error) + addr := conn.Addr() + if addr == nil { + log.Printf("tarantool: connect (%d/%d) failed: %s", + reconnects, conn.opts.MaxReconnects, err) + } else { + log.Printf("tarantool: reconnect (%d/%d) to %s failed: %s", + reconnects, conn.opts.MaxReconnects, addr, err) + } + case LogLastReconnectFailed: + err := v[0].(error) + addr := conn.Addr() + if addr == nil { + log.Printf("tarantool: last connect failed: %s, giving it up", + err) + } else { + log.Printf("tarantool: last reconnect to %s failed: %s, giving it up", + addr, err) + } + case LogUnexpectedResultId: + header := v[0].(Header) + log.Printf("tarantool: connection %s got unexpected request ID (%d) in response "+ + "(probably cancelled request)", + conn.Addr(), header.RequestId) + case LogWatchEventReadFailed: + err := v[0].(error) + log.Printf("tarantool: unable to parse watch event: %s", err) + case LogBoxSessionPushUnsupported: + header := v[0].(Header) + log.Printf("tarantool: unsupported box.session.push() for request %d", header.RequestId) + default: + args := append([]interface{}{"tarantool: unexpected event ", event, conn}, v...) + log.Print(args...) + } +} + +// Connection is a handle with a single connection to a Tarantool instance. +// +// It is created and configured with Connect function, and could not be +// reconfigured later. +// +// Connection could be in three possible states: +// +// - In "Connected" state it sends queries to Tarantool. +// +// - In "Disconnected" state it rejects queries with ClientError{Code: +// ErrConnectionNotReady} +// +// - In "Shutdown" state it rejects queries with ClientError{Code: +// ErrConnectionShutdown}. The state indicates that a graceful shutdown +// in progress. The connection waits for all active requests to +// complete. +// +// - In "Closed" state it rejects queries with ClientError{Code: +// ErrConnectionClosed}. Connection could become "Closed" when +// Connection.Close() method called, or when Tarantool disconnected and +// Reconnect pause is not specified or MaxReconnects is specified and +// MaxReconnect reconnect attempts already performed. +// +// You may perform data manipulation operation by calling its methods: +// Call*, Insert*, Replace*, Update*, Upsert*, Call*, Eval*. +// +// In any method that accepts space you my pass either space number or space +// name (in this case it will be looked up in schema). Same is true for index. +// +// ATTENTION: tuple, key, ops and args arguments for any method should be +// and array or should serialize to msgpack array. +// +// ATTENTION: result argument for *Typed methods should deserialize from +// msgpack array, cause Tarantool always returns result as an array. +// For all space related methods and Call16* (but not Call17*) methods Tarantool +// always returns array of array (array of tuples for space related methods). +// For Eval* and Call* Tarantool always returns array, but does not forces +// array of arrays. +// +// If connected to Tarantool 2.10 or newer, connection supports server graceful +// shutdown. In this case, server will wait until all client requests will be +// finished and client disconnects before going down (server also may go down +// by timeout). Client reconnect will happen if connection options enable +// reconnect. Beware that graceful shutdown event initialization is asynchronous. +// +// More on graceful shutdown: +// https://www.tarantool.io/en/doc/latest/dev_guide/internals/iproto/graceful_shutdown/ type Connection struct { - connection net.Conn - mutex *sync.Mutex - requestId uint32 - Greeting *Greeting - requests map[uint32]chan *Response - packets chan []byte + addr net.Addr + dialer Dialer + c Conn + mutex sync.Mutex + cond *sync.Cond + // schemaResolver contains a SchemaResolver implementation. + schemaResolver SchemaResolver + // requestId contains the last request ID for requests with nil context. + requestId uint32 + // contextRequestId contains the last request ID for requests with context. + contextRequestId uint32 + // Greeting contains first message sent by Tarantool. + Greeting *Greeting + + shard []connShard + dirtyShard chan uint32 + + control chan struct{} + rlimit chan struct{} + opts Opts + state uint32 + dec *msgpack.Decoder + lenbuf [packetLengthBytes]byte + + lastStreamId uint64 + + serverProtocolInfo ProtocolInfo + // watchMap is a map of key -> chan watchState. + watchMap sync.Map + + // shutdownWatcher is the "box.shutdown" event watcher. + shutdownWatcher Watcher + // requestCnt is a counter of active requests. + requestCnt int64 } -type Greeting struct { - version string - auth string +var _ = Connector(&Connection{}) // Check compatibility with connector interface. + +type futureList struct { + first *Future + last **Future } -func Connect(addr string) (conn *Connection, err error) { - fmt.Printf("Connecting to %s ...\n", addr) - connection, err := net.Dial("tcp", addr) - if err != nil { - return +func (list *futureList) findFuture(reqid uint32, fetch bool) *Future { + root := &list.first + for { + fut := *root + if fut == nil { + return nil + } + if fut.requestId == reqid { + if fetch { + *root = fut.next + if fut.next == nil { + list.last = root + } else { + fut.next = nil + } + } + return fut + } + root = &fut.next + } +} + +func (list *futureList) addFuture(fut *Future) { + *list.last = fut + list.last = &fut.next +} + +func (list *futureList) clear(err error, conn *Connection) { + fut := list.first + list.first = nil + list.last = &list.first + for fut != nil { + fut.SetError(err) + conn.markDone(fut) + fut, fut.next = fut.next, nil } - connection.(*net.TCPConn).SetNoDelay(true) +} - fmt.Println("Connected ...") +type connShard struct { + rmut sync.Mutex + requests [requestsMap]futureList + requestsWithCtx [requestsMap]futureList + bufmut sync.Mutex + buf smallWBuf + enc *msgpack.Encoder +} - conn = &Connection{ connection, &sync.Mutex{}, 0, &Greeting{}, make(map[uint32]chan *Response), make(chan []byte) } - err = conn.handShake() +// RLimitActions is an enumeration type for an action to do when a rate limit +// is reached. +type RLimitAction int - go conn.writer() - go conn.reader() +const ( + // RLimitDrop immediately aborts the request. + RLimitDrop RLimitAction = iota + // RLimitWait waits during timeout period for some request to be answered. + // If no request answered during timeout period, this request is aborted. + // If no timeout period is set, it will wait forever. + RLimitWait +) - return +// Opts is a way to configure Connection +type Opts struct { + // Timeout for response to a particular request. The timeout is reset when + // push messages are received. If Timeout is zero, any request can be + // blocked infinitely. + // Also used to setup net.TCPConn.Set(Read|Write)Deadline. + // + // Pay attention, when using contexts with request objects, + // the timeout option for Connection does not affect the lifetime + // of the request. For those purposes use context.WithTimeout() as + // the root context. + Timeout time.Duration + // Timeout between reconnect attempts. If Reconnect is zero, no + // reconnect attempts will be made. + // If specified, then when Tarantool is not reachable or disconnected, + // new connect attempt is performed after pause. + // By default, no reconnection attempts are performed, + // so once disconnected, connection becomes Closed. + Reconnect time.Duration + // Maximum number of reconnect failures; after that we give it up to + // on. If MaxReconnects is zero, the client will try to reconnect + // endlessly. + // After MaxReconnects attempts Connection becomes closed. + MaxReconnects uint + // RateLimit limits number of 'in-fly' request, i.e. already put into + // requests queue, but not yet answered by server or timeouted. + // It is disabled by default. + // See RLimitAction for possible actions when RateLimit.reached. + RateLimit uint + // RLimitAction tells what to do when RateLimit is reached. + // It is required if RateLimit is specified. + RLimitAction RLimitAction + // Concurrency is amount of separate mutexes for request + // queues and buffers inside of connection. + // It is rounded up to nearest power of 2. + // By default it is runtime.GOMAXPROCS(-1) * 4 + Concurrency uint32 + // SkipSchema disables schema loading. Without disabling schema loading, + // there is no way to create Connection for currently not accessible Tarantool. + SkipSchema bool + // Notify is a channel which receives notifications about Connection status + // changes. + Notify chan<- ConnEvent + // Handle is user specified value, that could be retrivied with + // Handle() method. + Handle interface{} + // Logger is user specified logger used for error messages. + Logger Logger } -func (conn *Connection) handShake() (err error) { - fmt.Printf("Greeting ... ") - greeting := make([]byte, 128) - _, err = conn.connection.Read(greeting) +// Connect creates and configures a new Connection. +func Connect(ctx context.Context, dialer Dialer, opts Opts) (conn *Connection, err error) { + conn = &Connection{ + dialer: dialer, + requestId: 0, + contextRequestId: 1, + Greeting: &Greeting{}, + control: make(chan struct{}), + opts: opts, + dec: msgpack.NewDecoder(&smallBuf{}), + } + maxprocs := uint32(runtime.GOMAXPROCS(-1)) + if conn.opts.Concurrency == 0 || conn.opts.Concurrency > maxprocs*128 { + conn.opts.Concurrency = maxprocs * 4 + } + if c := conn.opts.Concurrency; c&(c-1) != 0 { + for i := uint(1); i < 32; i *= 2 { + c |= c >> i + } + conn.opts.Concurrency = c + 1 + } + conn.dirtyShard = make(chan uint32, conn.opts.Concurrency*2) + conn.shard = make([]connShard, conn.opts.Concurrency) + for i := range conn.shard { + shard := &conn.shard[i] + requestsLists := []*[requestsMap]futureList{&shard.requests, &shard.requestsWithCtx} + for _, requests := range requestsLists { + for j := range requests { + requests[j].last = &requests[j].first + } + } + } + + if conn.opts.RateLimit > 0 { + conn.rlimit = make(chan struct{}, conn.opts.RateLimit) + if conn.opts.RLimitAction != RLimitDrop && conn.opts.RLimitAction != RLimitWait { + return nil, errors.New("RLimitAction should be specified to RLimitDone nor RLimitWait") + } + } + + if conn.opts.Logger == nil { + conn.opts.Logger = defaultLogger{} + } + + conn.cond = sync.NewCond(&conn.mutex) + + if conn.opts.Reconnect > 0 { + // We don't need these mutex.Lock()/mutex.Unlock() here, but + // runReconnects() expects mutex.Lock() to be set, so it's + // easier to add them instead of reworking runReconnects(). + conn.mutex.Lock() + err = conn.runReconnects(ctx) + conn.mutex.Unlock() + if err != nil { + return nil, err + } + } else { + if err = conn.connect(ctx); err != nil { + return nil, err + } + } + + go conn.pinger() + if conn.opts.Timeout > 0 { + go conn.timeouts() + } + + // TODO: reload schema after reconnect. + if !conn.opts.SkipSchema { + schema, err := GetSchema(conn) + if err != nil { + conn.mutex.Lock() + defer conn.mutex.Unlock() + conn.closeConnection(err, true) + return nil, err + } + conn.SetSchema(schema) + } + + return conn, err +} + +// ConnectedNow reports if connection is established at the moment. +func (conn *Connection) ConnectedNow() bool { + return atomic.LoadUint32(&conn.state) == connConnected +} + +// ClosedNow reports if connection is closed by user or after reconnect. +func (conn *Connection) ClosedNow() bool { + return atomic.LoadUint32(&conn.state) == connClosed +} + +// Close closes Connection. +// After this method called, there is no way to reopen this Connection. +func (conn *Connection) Close() error { + err := ClientError{ErrConnectionClosed, "connection closed by client"} + conn.mutex.Lock() + defer conn.mutex.Unlock() + return conn.closeConnection(err, true) +} + +// CloseGraceful closes Connection gracefully. It waits for all requests to +// complete. +// After this method called, there is no way to reopen this Connection. +func (conn *Connection) CloseGraceful() error { + return conn.shutdown(true) +} + +// Addr returns a configured address of Tarantool socket. +func (conn *Connection) Addr() net.Addr { + return conn.addr +} + +// Handle returns a user-specified handle from Opts. +func (conn *Connection) Handle() interface{} { + return conn.opts.Handle +} + +func (conn *Connection) cancelFuture(fut *Future, err error) { + if fut = conn.fetchFuture(fut.requestId); fut != nil { + fut.SetError(err) + conn.markDone(fut) + } +} + +func (conn *Connection) dial(ctx context.Context) error { + opts := conn.opts + + var c Conn + c, err := conn.dialer.Dial(ctx, DialOpts{ + IoTimeout: opts.Timeout, + }) if err != nil { - fmt.Println("Error") + return err + } + + conn.addr = c.Addr() + connGreeting := c.Greeting() + conn.Greeting.Version = connGreeting.Version + conn.Greeting.Salt = connGreeting.Salt + conn.serverProtocolInfo = c.ProtocolInfo() + + if conn.schemaResolver == nil { + namesSupported := isFeatureInSlice(iproto.IPROTO_FEATURE_SPACE_AND_INDEX_NAMES, + conn.serverProtocolInfo.Features) + + conn.schemaResolver = &noSchemaResolver{ + SpaceAndIndexNamesSupported: namesSupported, + } + } + + // Watchers. + conn.watchMap.Range(func(key, value interface{}) bool { + st := value.(chan watchState) + state := <-st + if state.unready != nil { + st <- state + return true + } + + req := newWatchRequest(key.(string)) + if err = writeRequest(ctx, c, req); err != nil { + st <- state + return false + } + state.ack = true + + st <- state + return true + }) + + if err != nil { + c.Close() + return fmt.Errorf("unable to register watch: %w", err) + } + + // Only if connected and fully initialized. + conn.lockShards() + conn.c = c + atomic.StoreUint32(&conn.state, connConnected) + conn.cond.Broadcast() + conn.unlockShards() + go conn.writer(c, c) + go conn.reader(c, c) + + // Subscribe shutdown event to process graceful shutdown. + if conn.shutdownWatcher == nil && + isFeatureInSlice(iproto.IPROTO_FEATURE_WATCHERS, + conn.serverProtocolInfo.Features) { + watcher, werr := conn.newWatcherImpl(shutdownEventKey, shutdownEventCallback) + if werr != nil { + return werr + } + conn.shutdownWatcher = watcher + } + + return nil +} + +func pack(h *smallWBuf, enc *msgpack.Encoder, reqid uint32, + req Request, streamId uint64, res SchemaResolver) (err error) { + const uint32Code = 0xce + const uint64Code = 0xcf + const streamBytesLenUint64 = 10 + const streamBytesLenUint32 = 6 + + hl := h.Len() + + var streamBytesLen = 0 + var streamBytes [streamBytesLenUint64]byte + hMapLen := byte(0x82) // 2 element map. + if streamId != ignoreStreamId { + hMapLen = byte(0x83) // 3 element map. + streamBytes[0] = byte(iproto.IPROTO_STREAM_ID) + if streamId > math.MaxUint32 { + streamBytesLen = streamBytesLenUint64 + streamBytes[1] = uint64Code + binary.BigEndian.PutUint64(streamBytes[2:], streamId) + } else { + streamBytesLen = streamBytesLenUint32 + streamBytes[1] = uint32Code + binary.BigEndian.PutUint32(streamBytes[2:], uint32(streamId)) + } + } + + hBytes := append([]byte{ + uint32Code, 0, 0, 0, 0, // Length. + hMapLen, + byte(iproto.IPROTO_REQUEST_TYPE), byte(req.Type()), // Request type. + byte(iproto.IPROTO_SYNC), uint32Code, + byte(reqid >> 24), byte(reqid >> 16), + byte(reqid >> 8), byte(reqid), + }, streamBytes[:streamBytesLen]...) + + h.Write(hBytes) + + if err = req.Body(res, enc); err != nil { return } - conn.Greeting.version = bytes.NewBuffer(greeting[:64]).String() - conn.Greeting.auth = bytes.NewBuffer(greeting[64:]).String() - fmt.Println("Success") - fmt.Println("Version:", conn.Greeting.version) + l := uint32(h.Len() - 5 - hl) + h.b[hl+1] = byte(l >> 24) + h.b[hl+2] = byte(l >> 16) + h.b[hl+3] = byte(l >> 8) + h.b[hl+4] = byte(l) + return } -func (conn *Connection) writer(){ - var ( - err error - packet []byte - ) - for { - packet = <- conn.packets - err = conn.write(packet) - if err != nil { - panic(err) +func (conn *Connection) connect(ctx context.Context) error { + var err error + if conn.c == nil && conn.state == connDisconnected { + if err = conn.dial(ctx); err == nil { + conn.notify(Connected) + return nil } } + if conn.state == connClosed { + err = ClientError{ErrConnectionClosed, "using closed connection"} + } + return err } -func (conn *Connection) reader() { - var ( - err error - resp_bytes []byte - ) - for { - resp_bytes, err = conn.read() +func (conn *Connection) closeConnection(neterr error, forever bool) (err error) { + conn.lockShards() + defer conn.unlockShards() + if forever { + if conn.state != connClosed { + close(conn.control) + atomic.StoreUint32(&conn.state, connClosed) + conn.cond.Broadcast() + // Free the resources. + if conn.shutdownWatcher != nil { + go conn.shutdownWatcher.Unregister() + conn.shutdownWatcher = nil + } + conn.notify(Closed) + } + } else { + atomic.StoreUint32(&conn.state, connDisconnected) + conn.cond.Broadcast() + conn.notify(Disconnected) + } + if conn.c != nil { + err = conn.c.Close() + conn.c = nil + } + for i := range conn.shard { + conn.shard[i].buf.Reset() + requestsLists := []*[requestsMap]futureList{ + &conn.shard[i].requests, + &conn.shard[i].requestsWithCtx, + } + for _, requests := range requestsLists { + for pos := range requests { + requests[pos].clear(neterr, conn) + } + } + } + return +} + +func (conn *Connection) getDialTimeout() time.Duration { + dialTimeout := conn.opts.Reconnect / 2 + if dialTimeout == 0 { + dialTimeout = 500 * time.Millisecond + } else if dialTimeout > 5*time.Second { + dialTimeout = 5 * time.Second + } + return dialTimeout +} + +func (conn *Connection) runReconnects(ctx context.Context) error { + dialTimeout := conn.getDialTimeout() + var reconnects uint + var err error + + t := time.NewTicker(conn.opts.Reconnect) + defer t.Stop() + for conn.opts.MaxReconnects == 0 || reconnects <= conn.opts.MaxReconnects { + localCtx, cancel := context.WithTimeout(ctx, dialTimeout) + err = conn.connect(localCtx) + cancel() + if err != nil { - panic(err) + // The error will most likely be the one that Dialer + // returns to us due to the context being cancelled. + // Although this is not guaranteed. For example, + // if the dialer may throw another error before checking + // the context, and the context has already been + // canceled. Or the context was not canceled after + // the error was thrown, but before the context was + // checked here. + if ctx.Err() != nil { + return err + } + if clientErr, ok := err.(ClientError); ok && + clientErr.Code == ErrConnectionClosed { + return err + } + } else { + return nil } - resp := NewResponse(resp_bytes) - respChan := conn.requests[resp.RequestId] - conn.mutex.Lock() - delete(conn.requests, resp.RequestId) + conn.opts.Logger.Report(LogReconnectFailed, conn, reconnects, err) + conn.notify(ReconnectFailed) + reconnects++ conn.mutex.Unlock() - respChan <- resp + + select { + case <-ctx.Done(): + // Since the context is cancelled, we don't need to do anything. + // Conn.connect() will return the correct error. + case <-t.C: + } + + conn.mutex.Lock() } + + conn.opts.Logger.Report(LogLastReconnectFailed, conn, err) + // mark connection as closed to avoid reopening by another goroutine + return ClientError{ErrConnectionClosed, "last reconnect failed"} } -func (conn *Connection) write(data []byte) (err error) { - l, err := conn.connection.Write(data) - if l != len(data) { - panic("Wrong length writed") +func (conn *Connection) reconnectImpl(neterr error, c Conn) { + if conn.opts.Reconnect > 0 { + if c == conn.c { + conn.closeConnection(neterr, false) + if err := conn.runReconnects(context.Background()); err != nil { + conn.closeConnection(err, true) + } + } + } else { + conn.closeConnection(neterr, true) } - return } -func (conn *Connection) read() (response []byte, err error){ - var length_uint uint32 - var l, tl int - length := make([]byte, PacketLengthBytes) +func (conn *Connection) reconnect(neterr error, c Conn) { + conn.mutex.Lock() + defer conn.mutex.Unlock() + conn.reconnectImpl(neterr, c) + conn.cond.Broadcast() +} - tl = 0 - for tl < int(PacketLengthBytes) { - l, err = conn.connection.Read(length[tl:]) - tl += l - if err != nil { +func (conn *Connection) lockShards() { + for i := range conn.shard { + conn.shard[i].rmut.Lock() + conn.shard[i].bufmut.Lock() + } +} + +func (conn *Connection) unlockShards() { + for i := range conn.shard { + conn.shard[i].rmut.Unlock() + conn.shard[i].bufmut.Unlock() + } +} + +func (conn *Connection) pinger() { + to := conn.opts.Timeout + if to == 0 { + to = 3 * time.Second + } + t := time.NewTicker(to / 3) + defer t.Stop() + for { + select { + case <-conn.control: + return + case <-t.C: + } + conn.Ping() + } +} + +func (conn *Connection) notify(kind ConnEventKind) { + if conn.opts.Notify != nil { + select { + case conn.opts.Notify <- ConnEvent{Kind: kind, Conn: conn, When: time.Now()}: + default: + } + } +} + +func (conn *Connection) writer(w writeFlusher, c Conn) { + var shardn uint32 + var packet smallWBuf + for atomic.LoadUint32(&conn.state) != connClosed { + select { + case shardn = <-conn.dirtyShard: + default: + runtime.Gosched() + if len(conn.dirtyShard) == 0 { + if err := w.Flush(); err != nil { + err = ClientError{ + ErrIoError, + fmt.Sprintf("failed to flush data to the connection: %s", err), + } + conn.reconnect(err, c) + return + } + } + select { + case shardn = <-conn.dirtyShard: + case <-conn.control: + return + } + } + shard := &conn.shard[shardn] + shard.bufmut.Lock() + if conn.c != c { + conn.dirtyShard <- shardn + shard.bufmut.Unlock() + return + } + packet, shard.buf = shard.buf, packet + shard.bufmut.Unlock() + if packet.Len() == 0 { + continue + } + if _, err := w.Write(packet.b); err != nil { + err = ClientError{ + ErrIoError, + fmt.Sprintf("failed to write data to the connection: %s", err), + } + conn.reconnect(err, c) return } + packet.Reset() } +} - err = msgpack.Unmarshal(length, &length_uint) +func readWatchEvent(reader io.Reader) (connWatchEvent, error) { + keyExist := false + event := connWatchEvent{} + + d := getDecoder(reader) + defer putDecoder(d) + + l, err := d.DecodeMapLen() if err != nil { + return event, err + } + + for ; l > 0; l-- { + cd, err := d.DecodeInt() + if err != nil { + return event, err + } + + switch iproto.Key(cd) { + case iproto.IPROTO_EVENT_KEY: + if event.key, err = d.DecodeString(); err != nil { + return event, err + } + keyExist = true + case iproto.IPROTO_EVENT_DATA: + if event.value, err = d.DecodeInterface(); err != nil { + return event, err + } + default: + if err = d.Skip(); err != nil { + return event, err + } + } + } + + if !keyExist { + return event, errors.New("watch event does not have a key") + } + + return event, nil +} + +func (conn *Connection) reader(r io.Reader, c Conn) { + events := make(chan connWatchEvent, 1024) + defer close(events) + + go conn.eventer(events) + + for atomic.LoadUint32(&conn.state) != connClosed { + respBytes, err := read(r, conn.lenbuf[:]) + if err != nil { + err = ClientError{ + ErrIoError, + fmt.Sprintf("failed to read data from the connection: %s", err), + } + conn.reconnect(err, c) + return + } + buf := smallBuf{b: respBytes} + header, code, err := decodeHeader(conn.dec, &buf) + if err != nil { + err = ClientError{ + ErrProtocolError, + fmt.Sprintf("failed to decode IPROTO header: %s", err), + } + conn.reconnect(err, c) + return + } + + var fut *Future = nil + if code == iproto.IPROTO_EVENT { + if event, err := readWatchEvent(&buf); err == nil { + events <- event + } else { + err = ClientError{ + ErrProtocolError, + fmt.Sprintf("failed to decode IPROTO_EVENT: %s", err), + } + conn.opts.Logger.Report(LogWatchEventReadFailed, conn, err) + } + continue + } else if code == iproto.IPROTO_CHUNK { + conn.opts.Logger.Report(LogBoxSessionPushUnsupported, conn, header) + } else { + if fut = conn.fetchFuture(header.RequestId); fut != nil { + if err := fut.SetResponse(header, &buf); err != nil { + fut.SetError(fmt.Errorf("failed to set response: %w", err)) + } + conn.markDone(fut) + } + } + + if fut == nil { + conn.opts.Logger.Report(LogUnexpectedResultId, conn, header) + } + } +} + +// eventer goroutine gets watch events and updates values for watchers. +func (conn *Connection) eventer(events <-chan connWatchEvent) { + for event := range events { + if value, ok := conn.watchMap.Load(event.key); ok { + st := value.(chan watchState) + state := <-st + state.value = event.value + if state.version == math.MaxUint { + state.version = initWatchEventVersion + 1 + } else { + state.version += 1 + } + state.ack = false + if state.changed != nil { + close(state.changed) + state.changed = nil + } + st <- state + } + // It is possible to get IPROTO_EVENT after we already send + // IPROTO_UNWATCH due to processing on a Tarantool side or slow + // read from the network, so it looks like an expected behavior. + } +} + +func (conn *Connection) newFuture(req Request) (fut *Future) { + ctx := req.Ctx() + fut = NewFuture(req) + if conn.rlimit != nil && conn.opts.RLimitAction == RLimitDrop { + select { + case conn.rlimit <- struct{}{}: + default: + fut.err = ClientError{ + ErrRateLimited, + "Request is rate limited on client", + } + fut.ready = nil + fut.done = nil + return + } + } + fut.requestId = conn.nextRequestId(ctx != nil) + shardn := fut.requestId & (conn.opts.Concurrency - 1) + shard := &conn.shard[shardn] + shard.rmut.Lock() + switch atomic.LoadUint32(&conn.state) { + case connClosed: + fut.err = ClientError{ + ErrConnectionClosed, + "using closed connection", + } + fut.ready = nil + fut.done = nil + shard.rmut.Unlock() + return + case connDisconnected: + fut.err = ClientError{ + ErrConnectionNotReady, + "client connection is not ready", + } + fut.ready = nil + fut.done = nil + shard.rmut.Unlock() + return + case connShutdown: + fut.err = ClientError{ + ErrConnectionShutdown, + "server shutdown in progress", + } + fut.ready = nil + fut.done = nil + shard.rmut.Unlock() + return + } + pos := (fut.requestId / conn.opts.Concurrency) & (requestsMap - 1) + if ctx != nil { + select { + case <-ctx.Done(): + fut.SetError(fmt.Errorf("context is done (request ID %d): %w", + fut.requestId, context.Cause(ctx))) + shard.rmut.Unlock() + return + default: + } + shard.requestsWithCtx[pos].addFuture(fut) + } else { + shard.requests[pos].addFuture(fut) + if conn.opts.Timeout > 0 { + fut.timeout = time.Since(epoch) + conn.opts.Timeout + } + } + shard.rmut.Unlock() + if conn.rlimit != nil && conn.opts.RLimitAction == RLimitWait { + select { + case conn.rlimit <- struct{}{}: + default: + runtime.Gosched() + select { + case conn.rlimit <- struct{}{}: + case <-fut.done: + if fut.err == nil { + panic("fut.done is closed, but err is nil") + } + } + } + } + return +} + +// This method removes a future from the internal queue if the context +// is "done" before the response is come. +func (conn *Connection) contextWatchdog(fut *Future, ctx context.Context) { + select { + case <-fut.done: + case <-ctx.Done(): + } + + select { + case <-fut.done: return + default: + conn.cancelFuture(fut, fmt.Errorf("context is done (request ID %d): %w", + fut.requestId, context.Cause(ctx))) + } +} + +func (conn *Connection) incrementRequestCnt() { + atomic.AddInt64(&conn.requestCnt, int64(1)) +} + +func (conn *Connection) decrementRequestCnt() { + if atomic.AddInt64(&conn.requestCnt, int64(-1)) == 0 { + conn.cond.Broadcast() } +} - response = make([]byte, length_uint) - if(length_uint > 0){ - tl = 0 - for tl < int(length_uint) { - l, err = conn.connection.Read(response[tl:]) - tl += l - if err != nil { - return +func (conn *Connection) send(req Request, streamId uint64) *Future { + conn.incrementRequestCnt() + + fut := conn.newFuture(req) + if fut.ready == nil { + conn.decrementRequestCnt() + return fut + } + + if req.Ctx() != nil { + select { + case <-req.Ctx().Done(): + conn.cancelFuture(fut, fmt.Errorf("context is done (request ID %d)", fut.requestId)) + return fut + default: + } + go conn.contextWatchdog(fut, req.Ctx()) + } + conn.putFuture(fut, req, streamId) + + return fut +} + +func (conn *Connection) putFuture(fut *Future, req Request, streamId uint64) { + shardn := fut.requestId & (conn.opts.Concurrency - 1) + shard := &conn.shard[shardn] + shard.bufmut.Lock() + select { + case <-fut.done: + shard.bufmut.Unlock() + return + default: + } + firstWritten := shard.buf.Len() == 0 + if shard.buf.Cap() == 0 { + shard.buf.b = make([]byte, 0, 128) + shard.enc = msgpack.NewEncoder(&shard.buf) + } + blen := shard.buf.Len() + reqid := fut.requestId + if err := pack(&shard.buf, shard.enc, reqid, req, streamId, conn.schemaResolver); err != nil { + shard.buf.Trunc(blen) + shard.bufmut.Unlock() + if f := conn.fetchFuture(reqid); f == fut { + fut.SetError(err) + conn.markDone(fut) + } else if f != nil { + /* in theory, it is possible. In practice, you have + * to have race condition that lasts hours */ + panic("Unknown future") + } else { + fut.wait() + if fut.err == nil { + panic("Future removed from queue without error") + } + if _, ok := fut.err.(ClientError); ok { + // packing error is more important than connection + // error, because it is indication of programmer's + // mistake. + fut.SetError(err) } } + return + } + shard.bufmut.Unlock() + + if firstWritten { + conn.dirtyShard <- shardn + } + + if req.Async() { + if fut = conn.fetchFuture(reqid); fut != nil { + header := Header{ + RequestId: reqid, + Error: ErrorNo, + } + fut.SetResponse(header, nil) + conn.markDone(fut) + } + } +} + +func (conn *Connection) markDone(fut *Future) { + if conn.rlimit != nil { + <-conn.rlimit + } + conn.decrementRequestCnt() +} + +func (conn *Connection) fetchFuture(reqid uint32) (fut *Future) { + shard := &conn.shard[reqid&(conn.opts.Concurrency-1)] + shard.rmut.Lock() + fut = conn.getFutureImp(reqid, true) + shard.rmut.Unlock() + return fut +} + +func (conn *Connection) getFutureImp(reqid uint32, fetch bool) *Future { + shard := &conn.shard[reqid&(conn.opts.Concurrency-1)] + pos := (reqid / conn.opts.Concurrency) & (requestsMap - 1) + // futures with even requests id belong to requests list with nil context + if reqid%2 == 0 { + return shard.requests[pos].findFuture(reqid, fetch) + } else { + return shard.requestsWithCtx[pos].findFuture(reqid, fetch) + } +} + +func (conn *Connection) timeouts() { + timeout := conn.opts.Timeout + t := time.NewTimer(timeout) + for { + var nowepoch time.Duration + select { + case <-conn.control: + t.Stop() + return + case <-t.C: + } + minNext := time.Since(epoch) + timeout + for i := range conn.shard { + nowepoch = time.Since(epoch) + shard := &conn.shard[i] + for pos := range shard.requests { + shard.rmut.Lock() + pair := &shard.requests[pos] + for pair.first != nil && pair.first.timeout < nowepoch { + shard.bufmut.Lock() + fut := pair.first + pair.first = fut.next + if fut.next == nil { + pair.last = &pair.first + } else { + fut.next = nil + } + fut.SetError(ClientError{ + Code: ErrTimeouted, + Msg: fmt.Sprintf("client timeout for request %d", fut.requestId), + }) + conn.markDone(fut) + shard.bufmut.Unlock() + } + if pair.first != nil && pair.first.timeout < minNext { + minNext = pair.first.timeout + } + shard.rmut.Unlock() + } + } + nowepoch = time.Since(epoch) + if nowepoch+time.Microsecond < minNext { + t.Reset(minNext - nowepoch) + } else { + t.Reset(time.Microsecond) + } } +} + +func read(r io.Reader, lenbuf []byte) (response []byte, err error) { + var length uint64 + + if _, err = io.ReadFull(r, lenbuf); err != nil { + return + } + if lenbuf[0] != 0xce { + err = errors.New("wrong response header") + return + } + length = (uint64(lenbuf[1]) << 24) + + (uint64(lenbuf[2]) << 16) + + (uint64(lenbuf[3]) << 8) + + uint64(lenbuf[4]) + + switch { + case length == 0: + err = errors.New("response should not be 0 length") + return + case length > math.MaxUint32: + err = errors.New("response is too big") + return + } + + response = make([]byte, length) + _, err = io.ReadFull(r, response) return } -func (conn *Connection) nextRequestId() (requestId uint32) { - conn.requestId = atomic.AddUint32(&conn.requestId, 1) - return conn.requestId +func (conn *Connection) nextRequestId(context bool) (requestId uint32) { + if context { + return atomic.AddUint32(&conn.contextRequestId, 2) + } else { + return atomic.AddUint32(&conn.requestId, 2) + } +} + +// Do performs a request asynchronously on the connection. +// +// An error is returned if the request was formed incorrectly, or failed to +// create the future. +func (conn *Connection) Do(req Request) *Future { + if connectedReq, ok := req.(ConnectedRequest); ok { + if connectedReq.Conn() != conn { + fut := NewFuture(req) + fut.SetError(errUnknownRequest) + return fut + } + } + return conn.send(req, ignoreStreamId) +} + +// ConfiguredTimeout returns a timeout from connection config. +func (conn *Connection) ConfiguredTimeout() time.Duration { + return conn.opts.Timeout +} + +// SetSchema sets Schema for the connection. +func (conn *Connection) SetSchema(s Schema) { + sCopy := s.copy() + spaceAndIndexNamesSupported := + isFeatureInSlice(iproto.IPROTO_FEATURE_SPACE_AND_INDEX_NAMES, + conn.serverProtocolInfo.Features) + + conn.mutex.Lock() + defer conn.mutex.Unlock() + conn.lockShards() + defer conn.unlockShards() + + conn.schemaResolver = &loadedSchemaResolver{ + Schema: sCopy, + SpaceAndIndexNamesSupported: spaceAndIndexNamesSupported, + } +} + +// NewPrepared passes a sql statement to Tarantool for preparation synchronously. +func (conn *Connection) NewPrepared(expr string) (*Prepared, error) { + req := NewPrepareRequest(expr) + resp, err := conn.Do(req).GetResponse() + if err != nil { + return nil, err + } + return NewPreparedFromResponse(conn, resp) +} + +// NewStream creates new Stream object for connection. +// +// Since v. 2.10.0, Tarantool supports streams and interactive transactions over them. +// To use interactive transactions, memtx_use_mvcc_engine box option should be set to true. +// Since 1.7.0 +func (conn *Connection) NewStream() (*Stream, error) { + next := atomic.AddUint64(&conn.lastStreamId, 1) + return &Stream{ + Id: next, + Conn: conn, + }, nil +} + +// watchState is the current state of the watcher. See the idea at p. 70, 105: +// https://drive.google.com/file/d/1nPdvhB0PutEJzdCq5ms6UI58dp50fcAN/view +type watchState struct { + // value is a current value. + value interface{} + // version is a current version of the value. + version uint + // ack true if the acknowledge is already sent. + ack bool + // cnt is a count of active watchers for the key. + cnt int + // changed is a channel for broadcast the value changes. + changed chan struct{} + // unready channel exists if a state is not ready to work (subscription + // or unsubscription in progress). + unready chan struct{} +} + +// initWatchEventVersion is an initial version until no events from Tarantool. +const initWatchEventVersion uint = 0 + +// connWatcher is an internal implementation of the Watcher interface. +type connWatcher struct { + unregister sync.Once + // done is closed when the watcher is unregistered, but the watcher + // goroutine is not yet finished. + done chan struct{} + // finished is closed when the watcher is unregistered and the watcher + // goroutine is finished. + finished chan struct{} +} + +// Unregister unregisters the connection watcher. +func (w *connWatcher) Unregister() { + w.unregister.Do(func() { + close(w.done) + }) + <-w.finished +} + +// subscribeWatchChannel returns an existing one or a new watch state channel +// for the key. It also increases a counter of active watchers for the channel. +func subscribeWatchChannel(conn *Connection, key string) (chan watchState, error) { + var st chan watchState + + for st == nil { + if val, ok := conn.watchMap.Load(key); !ok { + st = make(chan watchState, 1) + state := watchState{ + value: nil, + version: initWatchEventVersion, + ack: false, + cnt: 0, + changed: nil, + unready: make(chan struct{}), + } + st <- state + + if val, loaded := conn.watchMap.LoadOrStore(key, st); !loaded { + if _, err := conn.Do(newWatchRequest(key)).Get(); err != nil { + conn.watchMap.Delete(key) + close(state.unready) + return nil, err + } + // It is a successful subsctiption to a watch events by itself. + state = <-st + state.cnt = 1 + close(state.unready) + state.unready = nil + st <- state + continue + } else { + close(state.unready) + close(st) + st = val.(chan watchState) + } + } else { + st = val.(chan watchState) + } + + // It is an existing channel created outside. It may be in the + // unready state. + state := <-st + if state.unready == nil { + state.cnt += 1 + } + st <- state + + if state.unready != nil { + // Wait for an update and retry. + <-state.unready + st = nil + } + } + + return st, nil +} + +func isFeatureInSlice(expected iproto.Feature, actualSlice []iproto.Feature) bool { + for _, actual := range actualSlice { + if expected == actual { + return true + } + } + return false +} + +// NewWatcher creates a new Watcher object for the connection. +// +// Server must support IPROTO_FEATURE_WATCHERS to use watchers. +// +// After watcher creation, the watcher callback is invoked for the first time. +// In this case, the callback is triggered whether or not the key has already +// been broadcast. All subsequent invocations are triggered with +// box.broadcast() called on the remote host. If a watcher is subscribed for a +// key that has not been broadcast yet, the callback is triggered only once, +// after the registration of the watcher. +// +// The watcher callbacks are always invoked in a separate goroutine. A watcher +// callback is never executed in parallel with itself, but they can be executed +// in parallel to other watchers. +// +// If the key is updated while the watcher callback is running, the callback +// will be invoked again with the latest value as soon as it returns. +// +// Watchers survive reconnection. All registered watchers are automatically +// resubscribed when the connection is reestablished. +// +// Keep in mind that garbage collection of a watcher handle doesn’t lead to the +// watcher’s destruction. In this case, the watcher remains registered. You +// need to call Unregister() directly. +// +// Unregister() guarantees that there will be no the watcher's callback calls +// after it, but Unregister() call from the callback leads to a deadlock. +// +// See: +// https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_events/#box-watchers +// +// Since 1.10.0 +func (conn *Connection) NewWatcher(key string, callback WatchCallback) (Watcher, error) { + // We need to check the feature because the IPROTO_WATCH request is + // asynchronous. We do not expect any response from a Tarantool instance + // That's why we can't just check the Tarantool response for an unsupported + // request error. + if !isFeatureInSlice(iproto.IPROTO_FEATURE_WATCHERS, + conn.serverProtocolInfo.Features) { + err := fmt.Errorf("the feature %s must be supported by connection "+ + "to create a watcher", iproto.IPROTO_FEATURE_WATCHERS) + return nil, err + } + + return conn.newWatcherImpl(key, callback) +} + +func (conn *Connection) newWatcherImpl(key string, callback WatchCallback) (Watcher, error) { + st, err := subscribeWatchChannel(conn, key) + if err != nil { + return nil, err + } + + // Start the watcher goroutine. + done := make(chan struct{}) + finished := make(chan struct{}) + + go func() { + version := initWatchEventVersion + for { + state := <-st + if state.changed == nil { + state.changed = make(chan struct{}) + } + st <- state + + if state.version != version { + callback(WatchEvent{ + Conn: conn, + Key: key, + Value: state.value, + }) + version = state.version + + // Do we need to acknowledge the notification? + state = <-st + sendAck := !state.ack && version == state.version + if sendAck { + state.ack = true + } + st <- state + + if sendAck { + // We expect a reconnect and re-subscribe if it fails to + // send the watch request. So it looks ok do not check a + // result. But we need to make sure that the re-watch + // request will not be finished by a small per-request + // timeout. + req := newWatchRequest(key).Context(context.Background()) + conn.Do(req).Get() + } + } + + select { + case <-done: + state := <-st + state.cnt -= 1 + if state.cnt == 0 { + state.unready = make(chan struct{}) + } + st <- state + + if state.cnt == 0 { + // The last one sends IPROTO_UNWATCH. + if !conn.ClosedNow() { + // conn.ClosedNow() check is a workaround for calling + // Unregister from connectionClose(). + // + // We need to make sure that the unwatch request will + // not be finished by a small per-request timeout to + // avoid lost of the request. + req := newUnwatchRequest(key).Context(context.Background()) + conn.Do(req).Get() + } + conn.watchMap.Delete(key) + close(state.unready) + } + + close(finished) + return + case <-state.changed: + } + } + }() + + return &connWatcher{ + done: done, + finished: finished, + }, nil +} + +// ProtocolInfo returns protocol version and protocol features +// supported by connected Tarantool server. Beware that values might be +// outdated if connection is in a disconnected state. +// Since 2.0.0 +func (conn *Connection) ProtocolInfo() ProtocolInfo { + return conn.serverProtocolInfo.Clone() +} + +func shutdownEventCallback(event WatchEvent) { + // Receives "true" on server shutdown. + // See https://www.tarantool.io/en/doc/latest/dev_guide/internals/iproto/graceful_shutdown/ + // step 2. + val, ok := event.Value.(bool) + if ok && val { + go event.Conn.shutdown(false) + } +} + +func (conn *Connection) shutdown(forever bool) error { + // Forbid state changes. + conn.mutex.Lock() + defer conn.mutex.Unlock() + + if !atomic.CompareAndSwapUint32(&conn.state, connConnected, connShutdown) { + if forever { + err := ClientError{ErrConnectionClosed, "connection closed by client"} + return conn.closeConnection(err, true) + } + return nil + } + + if forever { + // We don't want to reconnect any more. + conn.opts.Reconnect = 0 + conn.opts.MaxReconnects = 0 + } + + conn.cond.Broadcast() + conn.notify(Shutdown) + + c := conn.c + for { + if (atomic.LoadUint32(&conn.state) != connShutdown) || (c != conn.c) { + return nil + } + if atomic.LoadInt64(&conn.requestCnt) == 0 { + break + } + // Use cond var on conn.mutex since request execution may + // call reconnect(). It is ok if state changes as part of + // reconnect since Tarantool server won't allow to reconnect + // in the middle of shutting down. + conn.cond.Wait() + } + + if forever { + err := ClientError{ErrConnectionClosed, "connection closed by client"} + return conn.closeConnection(err, true) + } else { + // Start to reconnect based on common rules, same as in net.box. + // Reconnect also closes the connection: server waits until all + // subscribed connections are terminated. + // See https://www.tarantool.io/en/doc/latest/dev_guide/internals/iproto/graceful_shutdown/ + // step 3. + conn.reconnectImpl(ClientError{ + ErrConnectionClosed, + "connection closed after server shutdown", + }, conn.c) + return nil + } } diff --git a/connector.go b/connector.go new file mode 100644 index 000000000..7112eb099 --- /dev/null +++ b/connector.go @@ -0,0 +1,19 @@ +package tarantool + +import "time" + +// Doer is an interface that performs requests asynchronously. +type Doer interface { + // Do performs a request asynchronously. + Do(req Request) (fut *Future) +} + +type Connector interface { + Doer + ConnectedNow() bool + Close() error + ConfiguredTimeout() time.Duration + NewPrepared(expr string) (*Prepared, error) + NewStream() (*Stream, error) + NewWatcher(key string, callback WatchCallback) (Watcher, error) +} diff --git a/const.go b/const.go index 54eb4b4bf..e8f389253 100644 --- a/const.go +++ b/const.go @@ -1,42 +1,15 @@ package tarantool -const ( - SelectRequest = 1 - InsertRequest = 2 - ReplaceRequest = 3 - UpdateRequest = 4 - DeleteRequest = 5 - CallRequest = 6 - AuthRequest = 7 - PingRequest = 64 - SubscribeRequest = 66 - - KeyCode = 0x00 - KeySync = 0x01 - KeySpaceNo = 0x10 - KeyIndexNo = 0x11 - KeyLimit = 0x12 - KeyOffset = 0x13 - KeyIterator = 0x14 - KeyKey = 0x20 - KeyTuple = 0x21 - KeyFunctionName = 0x22 - KeyData = 0x30 - KeyError = 0x31 - - // https://github.com/fl00r/go-tarantool-1.6/issues/2 - IterEq = uint32(0) // key == x ASC order - IterReq = uint32(1) // key == x DESC order - IterAll = uint32(2) // all tuples - IterLt = uint32(3) // key < x - IterLe = uint32(4) // key <= x - IterGe = uint32(5) // key > x - IterGt = uint32(6) // key >= x - IterBitsAllSet = uint32(7) // all bits from x are set in key - IterBitsAnySet = uint32(8) // at least one x's bit is set - IterBitsAllNotSet = uint32(9) // all bits are not set +import ( + "github.com/tarantool/go-iproto" +) - OkCode = uint32(0) +const ( + packetLengthBytes = 5 +) - PacketLengthBytes = 5 +const ( + // ErrorNo indicates that no error has occurred. It could be used to + // check that a response has an error without the response body decoding. + ErrorNo = iproto.ER_UNKNOWN ) diff --git a/crud/common.go b/crud/common.go new file mode 100644 index 000000000..df3c4a795 --- /dev/null +++ b/crud/common.go @@ -0,0 +1,97 @@ +// Package crud with support of API of Tarantool's CRUD module. +// +// Supported CRUD methods: +// +// - insert +// +// - insert_object +// +// - insert_many +// +// - insert_object_many +// +// - get +// +// - update +// +// - delete +// +// - replace +// +// - replace_object +// +// - replace_many +// +// - replace_object_many +// +// - upsert +// +// - upsert_object +// +// - upsert_many +// +// - upsert_object_many +// +// - select +// +// - min +// +// - max +// +// - truncate +// +// - len +// +// - storage_info +// +// - count +// +// - stats +// +// - unflatten_rows +// +// Since: 1.11.0. +package crud + +import ( + "context" + "io" + + "github.com/tarantool/go-iproto" + + "github.com/tarantool/go-tarantool/v3" +) + +type baseRequest struct { + impl *tarantool.CallRequest +} + +func newCall(method string) *tarantool.CallRequest { + return tarantool.NewCall17Request(method) +} + +// Type returns IPROTO type for CRUD request. +func (req baseRequest) Type() iproto.Type { + return req.impl.Type() +} + +// Ctx returns a context of CRUD request. +func (req baseRequest) Ctx() context.Context { + return req.impl.Ctx() +} + +// Async returns is CRUD request expects a response. +func (req baseRequest) Async() bool { + return req.impl.Async() +} + +// Response creates a response for the baseRequest. +func (req baseRequest) Response(header tarantool.Header, + body io.Reader) (tarantool.Response, error) { + return req.impl.Response(header, body) +} + +type spaceRequest struct { + baseRequest + space string +} diff --git a/crud/conditions.go b/crud/conditions.go new file mode 100644 index 000000000..8945adcd1 --- /dev/null +++ b/crud/conditions.go @@ -0,0 +1,28 @@ +package crud + +// Operator is a type to describe operator of operation. +type Operator string + +const ( + // Eq - comparison operator for "equal". + Eq Operator = "=" + // Lt - comparison operator for "less than". + Lt Operator = "<" + // Le - comparison operator for "less than or equal". + Le Operator = "<=" + // Gt - comparison operator for "greater than". + Gt Operator = ">" + // Ge - comparison operator for "greater than or equal". + Ge Operator = ">=" +) + +// Condition describes CRUD condition as a table +// {operator, field-identifier, value}. +type Condition struct { + // Instruct msgpack to pack this struct as array, so no custom packer + // is needed. + _msgpack struct{} `msgpack:",asArray"` //nolint: structcheck,unused + Operator Operator + Field string // Field name or index name. + Value interface{} +} diff --git a/crud/count.go b/crud/count.go new file mode 100644 index 000000000..b90198658 --- /dev/null +++ b/crud/count.go @@ -0,0 +1,118 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// CountResult describes result for `crud.count` method. +type CountResult = NumberResult + +// CountOpts describes options for `crud.count` method. +type CountOpts struct { + // Timeout is a `vshard.call` timeout and vshard + // master discovery timeout (in seconds). + Timeout OptFloat64 + // VshardRouter is cartridge vshard group name or + // vshard router instance. + VshardRouter OptString + // Mode is a parameter with `write`/`read` possible values, + // if `write` is specified then operation is performed on master. + Mode OptString + // PreferReplica is a parameter to specify preferred target + // as one of the replicas. + PreferReplica OptBool + // Balance is a parameter to use replica according to vshard + // load balancing policy. + Balance OptBool + // YieldEvery describes number of tuples processed to yield after. + // Should be positive. + YieldEvery OptUint + // BucketId is a bucket ID. + BucketId OptUint + // ForceMapCall describes the map call is performed without any + // optimizations even if full primary key equal condition is specified. + ForceMapCall OptBool + // Fullscan describes if a critical log entry will be skipped on + // potentially long count. + Fullscan OptBool +} + +// EncodeMsgpack provides custom msgpack encoder. +func (opts CountOpts) EncodeMsgpack(enc *msgpack.Encoder) error { + const optsCnt = 9 + + names := [optsCnt]string{timeoutOptName, vshardRouterOptName, + modeOptName, preferReplicaOptName, balanceOptName, + yieldEveryOptName, bucketIdOptName, forceMapCallOptName, + fullscanOptName} + values := [optsCnt]interface{}{} + exists := [optsCnt]bool{} + values[0], exists[0] = opts.Timeout.Get() + values[1], exists[1] = opts.VshardRouter.Get() + values[2], exists[2] = opts.Mode.Get() + values[3], exists[3] = opts.PreferReplica.Get() + values[4], exists[4] = opts.Balance.Get() + values[5], exists[5] = opts.YieldEvery.Get() + values[6], exists[6] = opts.BucketId.Get() + values[7], exists[7] = opts.ForceMapCall.Get() + values[8], exists[8] = opts.Fullscan.Get() + + return encodeOptions(enc, names[:], values[:], exists[:]) +} + +// CountRequest helps you to create request object to call `crud.count` +// for execution by a Connection. +type CountRequest struct { + spaceRequest + conditions []Condition + opts CountOpts +} + +type countArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Conditions []Condition + Opts CountOpts +} + +// MakeCountRequest returns a new empty CountRequest. +func MakeCountRequest(space string) CountRequest { + req := CountRequest{} + req.impl = newCall("crud.count") + req.space = space + req.conditions = nil + req.opts = CountOpts{} + return req +} + +// Conditions sets the conditions for the CountRequest request. +// Note: default value is nil. +func (req CountRequest) Conditions(conditions []Condition) CountRequest { + req.conditions = conditions + return req +} + +// Opts sets the options for the CountRequest request. +// Note: default value is nil. +func (req CountRequest) Opts(opts CountOpts) CountRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req CountRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + args := countArgs{Space: req.space, Conditions: req.conditions, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req CountRequest) Context(ctx context.Context) CountRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/crud/delete.go b/crud/delete.go new file mode 100644 index 000000000..075b25b3c --- /dev/null +++ b/crud/delete.go @@ -0,0 +1,67 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// DeleteOpts describes options for `crud.delete` method. +type DeleteOpts = SimpleOperationOpts + +// DeleteRequest helps you to create request object to call `crud.delete` +// for execution by a Connection. +type DeleteRequest struct { + spaceRequest + key Tuple + opts DeleteOpts +} + +type deleteArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Key Tuple + Opts DeleteOpts +} + +// MakeDeleteRequest returns a new empty DeleteRequest. +func MakeDeleteRequest(space string) DeleteRequest { + req := DeleteRequest{} + req.impl = newCall("crud.delete") + req.space = space + req.opts = DeleteOpts{} + return req +} + +// Key sets the key for the DeleteRequest request. +// Note: default value is nil. +func (req DeleteRequest) Key(key Tuple) DeleteRequest { + req.key = key + return req +} + +// Opts sets the options for the DeleteRequest request. +// Note: default value is nil. +func (req DeleteRequest) Opts(opts DeleteOpts) DeleteRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req DeleteRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + if req.key == nil { + req.key = []interface{}{} + } + args := deleteArgs{Space: req.space, Key: req.key, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req DeleteRequest) Context(ctx context.Context) DeleteRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/crud/error.go b/crud/error.go new file mode 100644 index 000000000..9233de5c3 --- /dev/null +++ b/crud/error.go @@ -0,0 +1,142 @@ +package crud + +import ( + "reflect" + "strings" + + "github.com/vmihailenco/msgpack/v5" +) + +// Error describes CRUD error object. +type Error struct { + // ClassName is an error class that implies its source (for example, "CountError"). + ClassName string + // Err is the text of reason. + Err string + // File is a source code file where the error was caught. + File string + // Line is a number of line in the source code file where the error was caught. + Line uint64 + // Stack is an information about the call stack when an error + // occurs in a string format. + Stack string + // Str is the text of reason with error class. + Str string + // OperationData is the object/tuple with which an error occurred. + OperationData interface{} + // operationDataType contains the type of OperationData. + operationDataType reflect.Type +} + +// newError creates an Error object with a custom operation data type to decoding. +func newError(operationDataType reflect.Type) *Error { + return &Error{operationDataType: operationDataType} +} + +// DecodeMsgpack provides custom msgpack decoder. +func (e *Error) DecodeMsgpack(d *msgpack.Decoder) error { + l, err := d.DecodeMapLen() + if err != nil { + return err + } + for i := 0; i < l; i++ { + key, err := d.DecodeString() + if err != nil { + return err + } + switch key { + case "class_name": + if e.ClassName, err = d.DecodeString(); err != nil { + return err + } + case "err": + if e.Err, err = d.DecodeString(); err != nil { + return err + } + case "file": + if e.File, err = d.DecodeString(); err != nil { + return err + } + case "line": + if e.Line, err = d.DecodeUint64(); err != nil { + return err + } + case "stack": + if e.Stack, err = d.DecodeString(); err != nil { + return err + } + case "str": + if e.Str, err = d.DecodeString(); err != nil { + return err + } + case "operation_data": + if e.operationDataType != nil { + tuple := reflect.New(e.operationDataType) + if err = d.DecodeValue(tuple); err != nil { + return err + } + e.OperationData = tuple.Elem().Interface() + } else { + if err = d.Decode(&e.OperationData); err != nil { + return err + } + } + default: + if err := d.Skip(); err != nil { + return err + } + } + } + + return nil +} + +// Error converts an Error to a string. +func (e Error) Error() string { + return e.Str +} + +// ErrorMany describes CRUD error object for `_many` methods. +type ErrorMany struct { + Errors []Error + // operationDataType contains the type of OperationData for each Error. + operationDataType reflect.Type +} + +// newErrorMany creates an ErrorMany object with a custom operation data type to decoding. +func newErrorMany(operationDataType reflect.Type) *ErrorMany { + return &ErrorMany{operationDataType: operationDataType} +} + +// DecodeMsgpack provides custom msgpack decoder. +func (e *ErrorMany) DecodeMsgpack(d *msgpack.Decoder) error { + l, err := d.DecodeArrayLen() + if err != nil { + return err + } + + var errs []Error + for i := 0; i < l; i++ { + crudErr := newError(e.operationDataType) + if err := d.Decode(&crudErr); err != nil { + return err + } + errs = append(errs, *crudErr) + } + + if len(errs) > 0 { + e.Errors = errs + } + + return nil +} + +// Error converts an Error to a string. +func (e ErrorMany) Error() string { + var str []string + for _, err := range e.Errors { + str = append(str, err.Str) + } + + return strings.Join(str, "\n") +} diff --git a/crud/error_test.go b/crud/error_test.go new file mode 100644 index 000000000..71fda30d4 --- /dev/null +++ b/crud/error_test.go @@ -0,0 +1,28 @@ +package crud_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/tarantool/go-tarantool/v3/crud" +) + +func TestErrorMany(t *testing.T) { + errs := crud.ErrorMany{Errors: []crud.Error{ + { + ClassName: "a", + Str: "msg 1", + }, + { + ClassName: "b", + Str: "msg 2", + }, + { + ClassName: "c", + Str: "msg 3", + }, + }} + + require.Equal(t, "msg 1\nmsg 2\nmsg 3", errs.Error()) +} diff --git a/crud/example_test.go b/crud/example_test.go new file mode 100644 index 000000000..1b97308ae --- /dev/null +++ b/crud/example_test.go @@ -0,0 +1,420 @@ +package crud_test + +import ( + "context" + "fmt" + "reflect" + "time" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/crud" +) + +const ( + exampleServer = "127.0.0.1:3013" + exampleSpace = "test" +) + +var exampleOpts = tarantool.Opts{ + Timeout: 5 * time.Second, +} + +var exampleDialer = tarantool.NetDialer{ + Address: exampleServer, + User: "test", + Password: "test", +} + +func exampleConnect() *tarantool.Connection { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, exampleDialer, exampleOpts) + if err != nil { + panic("Connection is not established: " + err.Error()) + } + return conn +} + +// ExampleResult_rowsInterface demonstrates how to use a helper type Result +// to decode a crud response. In this example, rows are decoded as an +// interface{} type. +func ExampleResult_rowsInterface() { + conn := exampleConnect() + req := crud.MakeReplaceRequest(exampleSpace). + Tuple([]interface{}{uint(2010), nil, "bla"}) + + ret := crud.Result{} + if err := conn.Do(req).GetTyped(&ret); err != nil { + fmt.Printf("Failed to execute request: %s", err) + return + } + + fmt.Println(ret.Metadata) + fmt.Println(ret.Rows) + // Output: + // [{id unsigned false} {bucket_id unsigned true} {name string false}] + // [[2010 45 bla]] +} + +// ExampleResult_rowsCustomType demonstrates how to use a helper type Result +// to decode a crud response. In this example, rows are decoded as a +// custom type. +func ExampleResult_rowsCustomType() { + conn := exampleConnect() + req := crud.MakeReplaceRequest(exampleSpace). + Tuple([]interface{}{uint(2010), nil, "bla"}) + + type Tuple struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Id uint64 + BucketId uint64 + Name string + } + ret := crud.MakeResult(reflect.TypeOf(Tuple{})) + + if err := conn.Do(req).GetTyped(&ret); err != nil { + fmt.Printf("Failed to execute request: %s", err) + return + } + + fmt.Println(ret.Metadata) + rows := ret.Rows.([]Tuple) + fmt.Println(rows) + // Output: + // [{id unsigned false} {bucket_id unsigned true} {name string false}] + // [{{} 2010 45 bla}] +} + +// ExampleTuples_customType demonstrates how to use a slice of objects of a +// custom type as Tuples to make a ReplaceManyRequest. +func ExampleTuples_customType() { + conn := exampleConnect() + + // The type will be encoded/decoded as an array. + type Tuple struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Id uint64 + BucketId *uint64 + Name string + } + req := crud.MakeReplaceManyRequest(exampleSpace).Tuples([]Tuple{ + Tuple{ + Id: 2010, + BucketId: nil, + Name: "bla", + }, + }) + + ret := crud.MakeResult(reflect.TypeOf(Tuple{})) + if err := conn.Do(req).GetTyped(&ret); err != nil { + fmt.Printf("Failed to execute request: %s", err) + return + } + + fmt.Println(ret.Metadata) + rows := ret.Rows.([]Tuple) + if len(rows) == 1 { + fmt.Println(rows[0].Id) + fmt.Println(*rows[0].BucketId) + fmt.Println(rows[0].Name) + } else { + fmt.Printf("Unexpected result tuples count: %d", len(rows)) + } + // Output: + // [{id unsigned false} {bucket_id unsigned true} {name string false}] + // 2010 + // 45 + // bla +} + +// ExampleObjects_customType demonstrates how to use a slice of objects of +// a custom type as Objects to make a ReplaceObjectManyRequest. +func ExampleObjects_customType() { + conn := exampleConnect() + + // The type will be encoded/decoded as a map. + type Tuple struct { + Id uint64 `msgpack:"id,omitempty"` + BucketId *uint64 `msgpack:"bucket_id,omitempty"` + Name string `msgpack:"name,omitempty"` + } + req := crud.MakeReplaceObjectManyRequest(exampleSpace).Objects([]Tuple{ + Tuple{ + Id: 2010, + BucketId: nil, + Name: "bla", + }, + }) + + ret := crud.MakeResult(reflect.TypeOf(Tuple{})) + if err := conn.Do(req).GetTyped(&ret); err != nil { + fmt.Printf("Failed to execute request: %s", err) + return + } + + fmt.Println(ret.Metadata) + rows := ret.Rows.([]Tuple) + if len(rows) == 1 { + fmt.Println(rows[0].Id) + fmt.Println(*rows[0].BucketId) + fmt.Println(rows[0].Name) + } else { + fmt.Printf("Unexpected result tuples count: %d", len(rows)) + } + // Output: + // [{id unsigned false} {bucket_id unsigned true} {name string false}] + // 2010 + // 45 + // bla +} + +// ExampleResult_operationData demonstrates how to obtain information +// about erroneous objects from crud.Error using `OperationData` field. +func ExampleResult_operationData() { + conn := exampleConnect() + req := crud.MakeInsertObjectManyRequest(exampleSpace).Objects([]crud.Object{ + crud.MapObject{ + "id": 2, + "bucket_id": 3, + "name": "Makar", + }, + crud.MapObject{ + "id": 2, + "bucket_id": 3, + "name": "Vasya", + }, + crud.MapObject{ + "id": 3, + "bucket_id": 5, + }, + }) + + ret := crud.Result{} + if err := conn.Do(req).GetTyped(&ret); err != nil { + crudErrs := err.(crud.ErrorMany) + fmt.Println("Erroneous data:") + for _, crudErr := range crudErrs.Errors { + fmt.Println(crudErr.OperationData) + } + } else { + fmt.Println(ret.Metadata) + fmt.Println(ret.Rows) + } + + // Output: + // Erroneous data: + // [2 3 Vasya] + // map[bucket_id:5 id:3] +} + +// ExampleResult_operationDataCustomType demonstrates the ability +// to cast `OperationData` field, extracted from a CRUD error during decoding +// using crud.Result, to a custom type. +// The type of `OperationData` is determined as the crud.Result row type. +func ExampleResult_operationDataCustomType() { + conn := exampleConnect() + req := crud.MakeInsertObjectManyRequest(exampleSpace).Objects([]crud.Object{ + crud.MapObject{ + "id": 1, + "bucket_id": 3, + "name": "Makar", + }, + crud.MapObject{ + "id": 1, + "bucket_id": 3, + "name": "Vasya", + }, + crud.MapObject{ + "id": 3, + "bucket_id": 5, + }, + }) + + type Tuple struct { + Id uint64 `msgpack:"id,omitempty"` + BucketId uint64 `msgpack:"bucket_id,omitempty"` + Name string `msgpack:"name,omitempty"` + } + + ret := crud.MakeResult(reflect.TypeOf(Tuple{})) + if err := conn.Do(req).GetTyped(&ret); err != nil { + crudErrs := err.(crud.ErrorMany) + fmt.Println("Erroneous data:") + for _, crudErr := range crudErrs.Errors { + operationData := crudErr.OperationData.(Tuple) + fmt.Println(operationData) + } + } else { + fmt.Println(ret.Metadata) + fmt.Println(ret.Rows) + } + // Output: + // Erroneous data: + // {1 3 Vasya} + // {3 5 } +} + +// ExampleResult_many demonstrates that there is no difference in a +// response from *ManyRequest. +func ExampleResult_many() { + conn := exampleConnect() + req := crud.MakeReplaceManyRequest(exampleSpace). + Tuples([]crud.Tuple{ + []interface{}{uint(2010), nil, "bla"}, + []interface{}{uint(2011), nil, "bla"}, + }) + + ret := crud.Result{} + if err := conn.Do(req).GetTyped(&ret); err != nil { + fmt.Printf("Failed to execute request: %s", err) + return + } + + fmt.Println(ret.Metadata) + fmt.Println(ret.Rows) + // Output: + // [{id unsigned false} {bucket_id unsigned true} {name string false}] + // [[2010 45 bla] [2011 4 bla]] +} + +// ExampleResult_noreturn demonstrates noreturn request: a data change +// request where you don't need to retrieve the result, just want to know +// whether it was successful or not. +func ExampleResult_noreturn() { + conn := exampleConnect() + req := crud.MakeReplaceManyRequest(exampleSpace). + Tuples([]crud.Tuple{ + []interface{}{uint(2010), nil, "bla"}, + []interface{}{uint(2011), nil, "bla"}, + }). + Opts(crud.ReplaceManyOpts{ + Noreturn: crud.MakeOptBool(true), + }) + + ret := crud.Result{} + if err := conn.Do(req).GetTyped(&ret); err != nil { + fmt.Printf("Failed to execute request: %s", err) + return + } + + fmt.Println(ret.Metadata) + fmt.Println(ret.Rows) + // Output: + // [] + // +} + +// ExampleResult_error demonstrates how to use a helper type Result +// to handle a crud error. +func ExampleResult_error() { + conn := exampleConnect() + req := crud.MakeReplaceRequest("not_exist"). + Tuple([]interface{}{uint(2010), nil, "bla"}) + + ret := crud.Result{} + if err := conn.Do(req).GetTyped(&ret); err != nil { + crudErr := err.(crud.Error) + fmt.Printf("Failed to execute request: %s", crudErr) + } else { + fmt.Println(ret.Metadata) + fmt.Println(ret.Rows) + } + // Output: + // Failed to execute request: ReplaceError: Space "not_exist" doesn't exist +} + +// ExampleResult_errorMany demonstrates how to use a helper type Result +// to handle a crud error for a *ManyRequest. +func ExampleResult_errorMany() { + conn := exampleConnect() + initReq := crud.MakeReplaceRequest("not_exist"). + Tuple([]interface{}{uint(2010), nil, "bla"}) + if _, err := conn.Do(initReq).Get(); err != nil { + fmt.Printf("Failed to initialize the example: %s\n", err) + } + + req := crud.MakeInsertManyRequest(exampleSpace). + Tuples([]crud.Tuple{ + []interface{}{uint(2010), nil, "bla"}, + []interface{}{uint(2010), nil, "bla"}, + }) + ret := crud.Result{} + if err := conn.Do(req).GetTyped(&ret); err != nil { + crudErr := err.(crud.ErrorMany) + // We need to trim the error message to make the example repeatable. + errmsg := crudErr.Error()[:10] + fmt.Printf("Failed to execute request: %s", errmsg) + } else { + fmt.Println(ret.Metadata) + fmt.Println(ret.Rows) + } + // Output: + // Failed to execute request: CallError: +} + +func ExampleSelectRequest_pagination() { + conn := exampleConnect() + + const ( + fromTuple = 5 + allTuples = 10 + ) + var tuple interface{} + for i := 0; i < allTuples; i++ { + req := crud.MakeReplaceRequest(exampleSpace). + Tuple([]interface{}{uint(3000 + i), nil, "bla"}) + ret := crud.Result{} + if err := conn.Do(req).GetTyped(&ret); err != nil { + fmt.Printf("Failed to initialize the example: %s\n", err) + return + } + if i == fromTuple { + tuple = ret.Rows.([]interface{})[0] + } + } + + req := crud.MakeSelectRequest(exampleSpace). + Opts(crud.SelectOpts{ + First: crud.MakeOptInt(2), + After: crud.MakeOptTuple(tuple), + }) + ret := crud.Result{} + if err := conn.Do(req).GetTyped(&ret); err != nil { + fmt.Printf("Failed to execute request: %s", err) + return + } + fmt.Println(ret.Metadata) + fmt.Println(ret.Rows) + // Output: + // [{id unsigned false} {bucket_id unsigned true} {name string false}] + // [[3006 32 bla] [3007 33 bla]] +} + +func ExampleSchema() { + conn := exampleConnect() + + req := crud.MakeSchemaRequest() + var result crud.SchemaResult + + if err := conn.Do(req).GetTyped(&result); err != nil { + fmt.Printf("Failed to execute request: %s", err) + return + } + + // Schema may differ between different Tarantool versions. + // https://github.com/tarantool/tarantool/issues/4091 + // https://github.com/tarantool/tarantool/commit/17c9c034933d726925910ce5bf8b20e8e388f6e3 + for spaceName, spaceSchema := range result.Value { + fmt.Printf("Space format for '%s' is as follows:\n", spaceName) + + for _, field := range spaceSchema.Format { + fmt.Printf(" - field '%s' with type '%s'\n", field.Name, field.Type) + } + } + + // Output: + // Space format for 'test' is as follows: + // - field 'id' with type 'unsigned' + // - field 'bucket_id' with type 'unsigned' + // - field 'name' with type 'string' +} diff --git a/crud/get.go b/crud/get.go new file mode 100644 index 000000000..5a31473ef --- /dev/null +++ b/crud/get.go @@ -0,0 +1,113 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// GetOpts describes options for `crud.get` method. +type GetOpts struct { + // Timeout is a `vshard.call` timeout and vshard + // master discovery timeout (in seconds). + Timeout OptFloat64 + // VshardRouter is cartridge vshard group name or + // vshard router instance. + VshardRouter OptString + // Fields is field names for getting only a subset of fields. + Fields OptTuple + // BucketId is a bucket ID. + BucketId OptUint + // Mode is a parameter with `write`/`read` possible values, + // if `write` is specified then operation is performed on master. + Mode OptString + // PreferReplica is a parameter to specify preferred target + // as one of the replicas. + PreferReplica OptBool + // Balance is a parameter to use replica according to vshard + // load balancing policy. + Balance OptBool + // FetchLatestMetadata guarantees the up-to-date metadata (space format) + // in first return value, otherwise it may not take into account + // the latest migration of the data format. Performance overhead is up to 15%. + // Disabled by default. + FetchLatestMetadata OptBool +} + +// EncodeMsgpack provides custom msgpack encoder. +func (opts GetOpts) EncodeMsgpack(enc *msgpack.Encoder) error { + const optsCnt = 8 + + names := [optsCnt]string{timeoutOptName, vshardRouterOptName, + fieldsOptName, bucketIdOptName, modeOptName, + preferReplicaOptName, balanceOptName, fetchLatestMetadataOptName} + values := [optsCnt]interface{}{} + exists := [optsCnt]bool{} + values[0], exists[0] = opts.Timeout.Get() + values[1], exists[1] = opts.VshardRouter.Get() + values[2], exists[2] = opts.Fields.Get() + values[3], exists[3] = opts.BucketId.Get() + values[4], exists[4] = opts.Mode.Get() + values[5], exists[5] = opts.PreferReplica.Get() + values[6], exists[6] = opts.Balance.Get() + values[7], exists[7] = opts.FetchLatestMetadata.Get() + + return encodeOptions(enc, names[:], values[:], exists[:]) +} + +// GetRequest helps you to create request object to call `crud.get` +// for execution by a Connection. +type GetRequest struct { + spaceRequest + key Tuple + opts GetOpts +} + +type getArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Key Tuple + Opts GetOpts +} + +// MakeGetRequest returns a new empty GetRequest. +func MakeGetRequest(space string) GetRequest { + req := GetRequest{} + req.impl = newCall("crud.get") + req.space = space + req.opts = GetOpts{} + return req +} + +// Key sets the key for the GetRequest request. +// Note: default value is nil. +func (req GetRequest) Key(key Tuple) GetRequest { + req.key = key + return req +} + +// Opts sets the options for the GetRequest request. +// Note: default value is nil. +func (req GetRequest) Opts(opts GetOpts) GetRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req GetRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + if req.key == nil { + req.key = []interface{}{} + } + args := getArgs{Space: req.space, Key: req.key, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req GetRequest) Context(ctx context.Context) GetRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/crud/insert.go b/crud/insert.go new file mode 100644 index 000000000..4e56c6d91 --- /dev/null +++ b/crud/insert.go @@ -0,0 +1,125 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// InsertOpts describes options for `crud.insert` method. +type InsertOpts = SimpleOperationOpts + +// InsertRequest helps you to create request object to call `crud.insert` +// for execution by a Connection. +type InsertRequest struct { + spaceRequest + tuple Tuple + opts InsertOpts +} + +type insertArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Tuple Tuple + Opts InsertOpts +} + +// MakeInsertRequest returns a new empty InsertRequest. +func MakeInsertRequest(space string) InsertRequest { + req := InsertRequest{} + req.impl = newCall("crud.insert") + req.space = space + req.opts = InsertOpts{} + return req +} + +// Tuple sets the tuple for the InsertRequest request. +// Note: default value is nil. +func (req InsertRequest) Tuple(tuple Tuple) InsertRequest { + req.tuple = tuple + return req +} + +// Opts sets the options for the insert request. +// Note: default value is nil. +func (req InsertRequest) Opts(opts InsertOpts) InsertRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req InsertRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + if req.tuple == nil { + req.tuple = []interface{}{} + } + args := insertArgs{Space: req.space, Tuple: req.tuple, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req InsertRequest) Context(ctx context.Context) InsertRequest { + req.impl = req.impl.Context(ctx) + + return req +} + +// InsertObjectOpts describes options for `crud.insert_object` method. +type InsertObjectOpts = SimpleOperationObjectOpts + +// InsertObjectRequest helps you to create request object to call +// `crud.insert_object` for execution by a Connection. +type InsertObjectRequest struct { + spaceRequest + object Object + opts InsertObjectOpts +} + +type insertObjectArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Object Object + Opts InsertObjectOpts +} + +// MakeInsertObjectRequest returns a new empty InsertObjectRequest. +func MakeInsertObjectRequest(space string) InsertObjectRequest { + req := InsertObjectRequest{} + req.impl = newCall("crud.insert_object") + req.space = space + req.opts = InsertObjectOpts{} + return req +} + +// Object sets the tuple for the InsertObjectRequest request. +// Note: default value is nil. +func (req InsertObjectRequest) Object(object Object) InsertObjectRequest { + req.object = object + return req +} + +// Opts sets the options for the InsertObjectRequest request. +// Note: default value is nil. +func (req InsertObjectRequest) Opts(opts InsertObjectOpts) InsertObjectRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req InsertObjectRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + if req.object == nil { + req.object = MapObject{} + } + args := insertObjectArgs{Space: req.space, Object: req.object, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req InsertObjectRequest) Context(ctx context.Context) InsertObjectRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/crud/insert_many.go b/crud/insert_many.go new file mode 100644 index 000000000..17748564d --- /dev/null +++ b/crud/insert_many.go @@ -0,0 +1,125 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// InsertManyOpts describes options for `crud.insert_many` method. +type InsertManyOpts = OperationManyOpts + +// InsertManyRequest helps you to create request object to call +// `crud.insert_many` for execution by a Connection. +type InsertManyRequest struct { + spaceRequest + tuples Tuples + opts InsertManyOpts +} + +type insertManyArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Tuples Tuples + Opts InsertManyOpts +} + +// MakeInsertManyRequest returns a new empty InsertManyRequest. +func MakeInsertManyRequest(space string) InsertManyRequest { + req := InsertManyRequest{} + req.impl = newCall("crud.insert_many") + req.space = space + req.opts = InsertManyOpts{} + return req +} + +// Tuples sets the tuples for the InsertManyRequest request. +// Note: default value is nil. +func (req InsertManyRequest) Tuples(tuples Tuples) InsertManyRequest { + req.tuples = tuples + return req +} + +// Opts sets the options for the InsertManyRequest request. +// Note: default value is nil. +func (req InsertManyRequest) Opts(opts InsertManyOpts) InsertManyRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req InsertManyRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + if req.tuples == nil { + req.tuples = []Tuple{} + } + args := insertManyArgs{Space: req.space, Tuples: req.tuples, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req InsertManyRequest) Context(ctx context.Context) InsertManyRequest { + req.impl = req.impl.Context(ctx) + + return req +} + +// InsertObjectManyOpts describes options for `crud.insert_object_many` method. +type InsertObjectManyOpts = OperationObjectManyOpts + +// InsertObjectManyRequest helps you to create request object to call +// `crud.insert_object_many` for execution by a Connection. +type InsertObjectManyRequest struct { + spaceRequest + objects Objects + opts InsertObjectManyOpts +} + +type insertObjectManyArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Objects Objects + Opts InsertObjectManyOpts +} + +// MakeInsertObjectManyRequest returns a new empty InsertObjectManyRequest. +func MakeInsertObjectManyRequest(space string) InsertObjectManyRequest { + req := InsertObjectManyRequest{} + req.impl = newCall("crud.insert_object_many") + req.space = space + req.opts = InsertObjectManyOpts{} + return req +} + +// Objects sets the objects for the InsertObjectManyRequest request. +// Note: default value is nil. +func (req InsertObjectManyRequest) Objects(objects Objects) InsertObjectManyRequest { + req.objects = objects + return req +} + +// Opts sets the options for the InsertObjectManyRequest request. +// Note: default value is nil. +func (req InsertObjectManyRequest) Opts(opts InsertObjectManyOpts) InsertObjectManyRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req InsertObjectManyRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + if req.objects == nil { + req.objects = []Object{} + } + args := insertObjectManyArgs{Space: req.space, Objects: req.objects, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req InsertObjectManyRequest) Context(ctx context.Context) InsertObjectManyRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/crud/len.go b/crud/len.go new file mode 100644 index 000000000..a1da72f72 --- /dev/null +++ b/crud/len.go @@ -0,0 +1,58 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// LenResult describes result for `crud.len` method. +type LenResult = NumberResult + +// LenOpts describes options for `crud.len` method. +type LenOpts = BaseOpts + +// LenRequest helps you to create request object to call `crud.len` +// for execution by a Connection. +type LenRequest struct { + spaceRequest + opts LenOpts +} + +type lenArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Opts LenOpts +} + +// MakeLenRequest returns a new empty LenRequest. +func MakeLenRequest(space string) LenRequest { + req := LenRequest{} + req.impl = newCall("crud.len") + req.space = space + req.opts = LenOpts{} + return req +} + +// Opts sets the options for the LenRequest request. +// Note: default value is nil. +func (req LenRequest) Opts(opts LenOpts) LenRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req LenRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + args := lenArgs{Space: req.space, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req LenRequest) Context(ctx context.Context) LenRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/crud/max.go b/crud/max.go new file mode 100644 index 000000000..961e7724b --- /dev/null +++ b/crud/max.go @@ -0,0 +1,64 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// MaxOpts describes options for `crud.max` method. +type MaxOpts = BorderOpts + +// MaxRequest helps you to create request object to call `crud.max` +// for execution by a Connection. +type MaxRequest struct { + spaceRequest + index interface{} + opts MaxOpts +} + +type maxArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Index interface{} + Opts MaxOpts +} + +// MakeMaxRequest returns a new empty MaxRequest. +func MakeMaxRequest(space string) MaxRequest { + req := MaxRequest{} + req.impl = newCall("crud.max") + req.space = space + req.opts = MaxOpts{} + return req +} + +// Index sets the index name/id for the MaxRequest request. +// Note: default value is nil. +func (req MaxRequest) Index(index interface{}) MaxRequest { + req.index = index + return req +} + +// Opts sets the options for the MaxRequest request. +// Note: default value is nil. +func (req MaxRequest) Opts(opts MaxOpts) MaxRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req MaxRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + args := maxArgs{Space: req.space, Index: req.index, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req MaxRequest) Context(ctx context.Context) MaxRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/crud/min.go b/crud/min.go new file mode 100644 index 000000000..2bbf9b816 --- /dev/null +++ b/crud/min.go @@ -0,0 +1,64 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// MinOpts describes options for `crud.min` method. +type MinOpts = BorderOpts + +// MinRequest helps you to create request object to call `crud.min` +// for execution by a Connection. +type MinRequest struct { + spaceRequest + index interface{} + opts MinOpts +} + +type minArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Index interface{} + Opts MinOpts +} + +// MakeMinRequest returns a new empty MinRequest. +func MakeMinRequest(space string) MinRequest { + req := MinRequest{} + req.impl = newCall("crud.min") + req.space = space + req.opts = MinOpts{} + return req +} + +// Index sets the index name/id for the MinRequest request. +// Note: default value is nil. +func (req MinRequest) Index(index interface{}) MinRequest { + req.index = index + return req +} + +// Opts sets the options for the MinRequest request. +// Note: default value is nil. +func (req MinRequest) Opts(opts MinOpts) MinRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req MinRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + args := minArgs{Space: req.space, Index: req.index, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req MinRequest) Context(ctx context.Context) MinRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/crud/object.go b/crud/object.go new file mode 100644 index 000000000..8803d1268 --- /dev/null +++ b/crud/object.go @@ -0,0 +1,24 @@ +package crud + +import ( + "github.com/vmihailenco/msgpack/v5" +) + +// Object is an interface to describe object for CRUD methods. It can be any +// type that msgpack can encode as a map. +type Object = interface{} + +// Objects is a type to describe an array of object for CRUD methods. It can be +// any type that msgpack can encode, but encoded data must be an array of +// objects. +// +// See the reason why not just []Object: +// https://github.com/tarantool/go-tarantool/issues/365 +type Objects = interface{} + +// MapObject is a type to describe object as a map. +type MapObject map[string]interface{} + +func (o MapObject) EncodeMsgpack(enc *msgpack.Encoder) { + enc.Encode(o) +} diff --git a/crud/operations.go b/crud/operations.go new file mode 100644 index 000000000..84953d590 --- /dev/null +++ b/crud/operations.go @@ -0,0 +1,70 @@ +package crud + +import ( + "github.com/vmihailenco/msgpack/v5" +) + +const ( + // Add - operator for addition. + Add Operator = "+" + // Sub - operator for subtraction. + Sub Operator = "-" + // And - operator for bitwise AND. + And Operator = "&" + // Or - operator for bitwise OR. + Or Operator = "|" + // Xor - operator for bitwise XOR. + Xor Operator = "^" + // Splice - operator for string splice. + Splice Operator = ":" + // Insert - operator for insertion of a new field. + Insert Operator = "!" + // Delete - operator for deletion. + Delete Operator = "#" + // Assign - operator for assignment. + Assign Operator = "=" +) + +// Operation describes CRUD operation as a table +// {operator, field_identifier, value}. +// Splice operation described as a table +// {operator, field_identifier, position, length, replace_string}. +type Operation struct { + Operator Operator + Field interface{} // Number or string. + Value interface{} + // Pos, Len, Replace fields used in the Splice operation. + Pos int + Len int + Replace string +} + +// EncodeMsgpack encodes Operation. +func (o Operation) EncodeMsgpack(enc *msgpack.Encoder) error { + isSpliceOperation := o.Operator == Splice + argsLen := 3 + if isSpliceOperation { + argsLen = 5 + } + if err := enc.EncodeArrayLen(argsLen); err != nil { + return err + } + if err := enc.EncodeString(string(o.Operator)); err != nil { + return err + } + if err := enc.Encode(o.Field); err != nil { + return err + } + + if isSpliceOperation { + if err := enc.EncodeInt(int64(o.Pos)); err != nil { + return err + } + if err := enc.EncodeInt(int64(o.Len)); err != nil { + return err + } + return enc.EncodeString(o.Replace) + } + + return enc.Encode(o.Value) +} diff --git a/crud/operations_test.go b/crud/operations_test.go new file mode 100644 index 000000000..a7f61a8a7 --- /dev/null +++ b/crud/operations_test.go @@ -0,0 +1,123 @@ +package crud_test + +import ( + "bytes" + "testing" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3/crud" +) + +func TestOperation_EncodeMsgpack(t *testing.T) { + testCases := []struct { + name string + op crud.Operation + ref []interface{} + }{ + { + "Add", + crud.Operation{ + Operator: crud.Add, + Field: 1, + Value: 2, + }, + []interface{}{"+", 1, 2}, + }, + { + "Sub", + crud.Operation{ + Operator: crud.Sub, + Field: 1, + Value: 2, + }, + []interface{}{"-", 1, 2}, + }, + { + "And", + crud.Operation{ + Operator: crud.And, + Field: 1, + Value: 2, + }, + []interface{}{"&", 1, 2}, + }, + { + "Or", + crud.Operation{ + Operator: crud.Or, + Field: 1, + Value: 2, + }, + []interface{}{"|", 1, 2}, + }, + { + "Xor", + crud.Operation{ + Operator: crud.Xor, + Field: 1, + Value: 2, + }, + []interface{}{"^", 1, 2}, + }, + { + "Splice", + crud.Operation{ + Operator: crud.Splice, + Field: 1, + Pos: 2, + Len: 3, + Replace: "a", + }, + []interface{}{":", 1, 2, 3, "a"}, + }, + { + "Insert", + crud.Operation{ + Operator: crud.Insert, + Field: 1, + Value: 2, + }, + []interface{}{"!", 1, 2}, + }, + { + "Delete", + crud.Operation{ + Operator: crud.Delete, + Field: 1, + Value: 2, + }, + []interface{}{"#", 1, 2}, + }, + { + "Assign", + crud.Operation{ + Operator: crud.Assign, + Field: 1, + Value: 2, + }, + []interface{}{"=", 1, 2}, + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + var refBuf bytes.Buffer + encRef := msgpack.NewEncoder(&refBuf) + if err := encRef.Encode(test.ref); err != nil { + t.Errorf("error while encoding: %v", err.Error()) + } + + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + + if err := enc.Encode(test.op); err != nil { + t.Errorf("error while encoding: %v", err.Error()) + } + if !bytes.Equal(refBuf.Bytes(), buf.Bytes()) { + t.Errorf("encode response is wrong:\n expected %v\n got: %v", + refBuf, buf.Bytes()) + } + }) + } +} diff --git a/crud/options.go b/crud/options.go new file mode 100644 index 000000000..311df522c --- /dev/null +++ b/crud/options.go @@ -0,0 +1,408 @@ +package crud + +import ( + "github.com/vmihailenco/msgpack/v5" +) + +const ( + timeoutOptName = "timeout" + vshardRouterOptName = "vshard_router" + fieldsOptName = "fields" + bucketIdOptName = "bucket_id" + skipNullabilityCheckOnFlattenOptName = "skip_nullability_check_on_flatten" + stopOnErrorOptName = "stop_on_error" + rollbackOnErrorOptName = "rollback_on_error" + modeOptName = "mode" + preferReplicaOptName = "prefer_replica" + balanceOptName = "balance" + yieldEveryOptName = "yield_every" + forceMapCallOptName = "force_map_call" + fullscanOptName = "fullscan" + firstOptName = "first" + afterOptName = "after" + batchSizeOptName = "batch_size" + fetchLatestMetadataOptName = "fetch_latest_metadata" + noreturnOptName = "noreturn" + cachedOptName = "cached" +) + +// OptUint is an optional uint. +type OptUint struct { + value uint + exist bool +} + +// MakeOptUint creates an optional uint from value. +func MakeOptUint(value uint) OptUint { + return OptUint{ + value: value, + exist: true, + } +} + +// Get returns the integer value or an error if not present. +func (opt OptUint) Get() (uint, bool) { + return opt.value, opt.exist +} + +// OptInt is an optional int. +type OptInt struct { + value int + exist bool +} + +// MakeOptInt creates an optional int from value. +func MakeOptInt(value int) OptInt { + return OptInt{ + value: value, + exist: true, + } +} + +// Get returns the integer value or an error if not present. +func (opt OptInt) Get() (int, bool) { + return opt.value, opt.exist +} + +// OptFloat64 is an optional float64. +type OptFloat64 struct { + value float64 + exist bool +} + +// MakeOptFloat64 creates an optional float64 from value. +func MakeOptFloat64(value float64) OptFloat64 { + return OptFloat64{ + value: value, + exist: true, + } +} + +// Get returns the float64 value or an error if not present. +func (opt OptFloat64) Get() (float64, bool) { + return opt.value, opt.exist +} + +// OptString is an optional string. +type OptString struct { + value string + exist bool +} + +// MakeOptString creates an optional string from value. +func MakeOptString(value string) OptString { + return OptString{ + value: value, + exist: true, + } +} + +// Get returns the string value or an error if not present. +func (opt OptString) Get() (string, bool) { + return opt.value, opt.exist +} + +// OptBool is an optional bool. +type OptBool struct { + value bool + exist bool +} + +// MakeOptBool creates an optional bool from value. +func MakeOptBool(value bool) OptBool { + return OptBool{ + value: value, + exist: true, + } +} + +// Get returns the boolean value or an error if not present. +func (opt OptBool) Get() (bool, bool) { + return opt.value, opt.exist +} + +// OptTuple is an optional tuple. +type OptTuple struct { + tuple interface{} +} + +// MakeOptTuple creates an optional tuple from tuple. +func MakeOptTuple(tuple interface{}) OptTuple { + return OptTuple{tuple} +} + +// Get returns the tuple value or an error if not present. +func (o *OptTuple) Get() (interface{}, bool) { + return o.tuple, o.tuple != nil +} + +// BaseOpts describes base options for CRUD operations. +type BaseOpts struct { + // Timeout is a `vshard.call` timeout and vshard + // master discovery timeout (in seconds). + Timeout OptFloat64 + // VshardRouter is cartridge vshard group name or + // vshard router instance. + VshardRouter OptString +} + +// EncodeMsgpack provides custom msgpack encoder. +func (opts BaseOpts) EncodeMsgpack(enc *msgpack.Encoder) error { + const optsCnt = 2 + + names := [optsCnt]string{timeoutOptName, vshardRouterOptName} + values := [optsCnt]interface{}{} + exists := [optsCnt]bool{} + values[0], exists[0] = opts.Timeout.Get() + values[1], exists[1] = opts.VshardRouter.Get() + + return encodeOptions(enc, names[:], values[:], exists[:]) +} + +// SimpleOperationOpts describes options for simple CRUD operations. +// It also covers `upsert_object` options. +type SimpleOperationOpts struct { + // Timeout is a `vshard.call` timeout and vshard + // master discovery timeout (in seconds). + Timeout OptFloat64 + // VshardRouter is cartridge vshard group name or + // vshard router instance. + VshardRouter OptString + // Fields is field names for getting only a subset of fields. + Fields OptTuple + // BucketId is a bucket ID. + BucketId OptUint + // FetchLatestMetadata guarantees the up-to-date metadata (space format) + // in first return value, otherwise it may not take into account + // the latest migration of the data format. Performance overhead is up to 15%. + // Disabled by default. + FetchLatestMetadata OptBool + // Noreturn suppresses successfully processed data (first return value is `nil`). + // Disabled by default. + Noreturn OptBool +} + +// EncodeMsgpack provides custom msgpack encoder. +func (opts SimpleOperationOpts) EncodeMsgpack(enc *msgpack.Encoder) error { + const optsCnt = 6 + + names := [optsCnt]string{timeoutOptName, vshardRouterOptName, + fieldsOptName, bucketIdOptName, fetchLatestMetadataOptName, + noreturnOptName} + values := [optsCnt]interface{}{} + exists := [optsCnt]bool{} + values[0], exists[0] = opts.Timeout.Get() + values[1], exists[1] = opts.VshardRouter.Get() + values[2], exists[2] = opts.Fields.Get() + values[3], exists[3] = opts.BucketId.Get() + values[4], exists[4] = opts.FetchLatestMetadata.Get() + values[5], exists[5] = opts.Noreturn.Get() + + return encodeOptions(enc, names[:], values[:], exists[:]) +} + +// SimpleOperationObjectOpts describes options for simple CRUD +// operations with objects. It doesn't cover `upsert_object` options. +type SimpleOperationObjectOpts struct { + // Timeout is a `vshard.call` timeout and vshard + // master discovery timeout (in seconds). + Timeout OptFloat64 + // VshardRouter is cartridge vshard group name or + // vshard router instance. + VshardRouter OptString + // Fields is field names for getting only a subset of fields. + Fields OptTuple + // BucketId is a bucket ID. + BucketId OptUint + // SkipNullabilityCheckOnFlatten is a parameter to allow + // setting null values to non-nullable fields. + SkipNullabilityCheckOnFlatten OptBool + // FetchLatestMetadata guarantees the up-to-date metadata (space format) + // in first return value, otherwise it may not take into account + // the latest migration of the data format. Performance overhead is up to 15%. + // Disabled by default. + FetchLatestMetadata OptBool + // Noreturn suppresses successfully processed data (first return value is `nil`). + // Disabled by default. + Noreturn OptBool +} + +// EncodeMsgpack provides custom msgpack encoder. +func (opts SimpleOperationObjectOpts) EncodeMsgpack(enc *msgpack.Encoder) error { + const optsCnt = 7 + + names := [optsCnt]string{timeoutOptName, vshardRouterOptName, + fieldsOptName, bucketIdOptName, skipNullabilityCheckOnFlattenOptName, + fetchLatestMetadataOptName, noreturnOptName} + values := [optsCnt]interface{}{} + exists := [optsCnt]bool{} + values[0], exists[0] = opts.Timeout.Get() + values[1], exists[1] = opts.VshardRouter.Get() + values[2], exists[2] = opts.Fields.Get() + values[3], exists[3] = opts.BucketId.Get() + values[4], exists[4] = opts.SkipNullabilityCheckOnFlatten.Get() + values[5], exists[5] = opts.FetchLatestMetadata.Get() + values[6], exists[6] = opts.Noreturn.Get() + + return encodeOptions(enc, names[:], values[:], exists[:]) +} + +// OperationManyOpts describes options for CRUD operations with many tuples. +// It also covers `upsert_object_many` options. +type OperationManyOpts struct { + // Timeout is a `vshard.call` timeout and vshard + // master discovery timeout (in seconds). + Timeout OptFloat64 + // VshardRouter is cartridge vshard group name or + // vshard router instance. + VshardRouter OptString + // Fields is field names for getting only a subset of fields. + Fields OptTuple + // StopOnError is a parameter to stop on a first error and report + // error regarding the failed operation and error about what tuples + // were not performed. + StopOnError OptBool + // RollbackOnError is a parameter because of what any failed operation + // will lead to rollback on a storage, where the operation is failed. + RollbackOnError OptBool + // FetchLatestMetadata guarantees the up-to-date metadata (space format) + // in first return value, otherwise it may not take into account + // the latest migration of the data format. Performance overhead is up to 15%. + // Disabled by default. + FetchLatestMetadata OptBool + // Noreturn suppresses successfully processed data (first return value is `nil`). + // Disabled by default. + Noreturn OptBool +} + +// EncodeMsgpack provides custom msgpack encoder. +func (opts OperationManyOpts) EncodeMsgpack(enc *msgpack.Encoder) error { + const optsCnt = 7 + + names := [optsCnt]string{timeoutOptName, vshardRouterOptName, + fieldsOptName, stopOnErrorOptName, rollbackOnErrorOptName, + fetchLatestMetadataOptName, noreturnOptName} + values := [optsCnt]interface{}{} + exists := [optsCnt]bool{} + values[0], exists[0] = opts.Timeout.Get() + values[1], exists[1] = opts.VshardRouter.Get() + values[2], exists[2] = opts.Fields.Get() + values[3], exists[3] = opts.StopOnError.Get() + values[4], exists[4] = opts.RollbackOnError.Get() + values[5], exists[5] = opts.FetchLatestMetadata.Get() + values[6], exists[6] = opts.Noreturn.Get() + + return encodeOptions(enc, names[:], values[:], exists[:]) +} + +// OperationObjectManyOpts describes options for CRUD operations +// with many objects. It doesn't cover `upsert_object_many` options. +type OperationObjectManyOpts struct { + // Timeout is a `vshard.call` timeout and vshard + // master discovery timeout (in seconds). + Timeout OptFloat64 + // VshardRouter is cartridge vshard group name or + // vshard router instance. + VshardRouter OptString + // Fields is field names for getting only a subset of fields. + Fields OptTuple + // StopOnError is a parameter to stop on a first error and report + // error regarding the failed operation and error about what tuples + // were not performed. + StopOnError OptBool + // RollbackOnError is a parameter because of what any failed operation + // will lead to rollback on a storage, where the operation is failed. + RollbackOnError OptBool + // SkipNullabilityCheckOnFlatten is a parameter to allow + // setting null values to non-nullable fields. + SkipNullabilityCheckOnFlatten OptBool + // FetchLatestMetadata guarantees the up-to-date metadata (space format) + // in first return value, otherwise it may not take into account + // the latest migration of the data format. Performance overhead is up to 15%. + // Disabled by default. + FetchLatestMetadata OptBool + // Noreturn suppresses successfully processed data (first return value is `nil`). + // Disabled by default. + Noreturn OptBool +} + +// EncodeMsgpack provides custom msgpack encoder. +func (opts OperationObjectManyOpts) EncodeMsgpack(enc *msgpack.Encoder) error { + const optsCnt = 8 + + names := [optsCnt]string{timeoutOptName, vshardRouterOptName, + fieldsOptName, stopOnErrorOptName, rollbackOnErrorOptName, + skipNullabilityCheckOnFlattenOptName, fetchLatestMetadataOptName, + noreturnOptName} + values := [optsCnt]interface{}{} + exists := [optsCnt]bool{} + values[0], exists[0] = opts.Timeout.Get() + values[1], exists[1] = opts.VshardRouter.Get() + values[2], exists[2] = opts.Fields.Get() + values[3], exists[3] = opts.StopOnError.Get() + values[4], exists[4] = opts.RollbackOnError.Get() + values[5], exists[5] = opts.SkipNullabilityCheckOnFlatten.Get() + values[6], exists[6] = opts.FetchLatestMetadata.Get() + values[7], exists[7] = opts.Noreturn.Get() + + return encodeOptions(enc, names[:], values[:], exists[:]) +} + +// BorderOpts describes options for `crud.min` and `crud.max`. +type BorderOpts struct { + // Timeout is a `vshard.call` timeout and vshard + // master discovery timeout (in seconds). + Timeout OptFloat64 + // VshardRouter is cartridge vshard group name or + // vshard router instance. + VshardRouter OptString + // Fields is field names for getting only a subset of fields. + Fields OptTuple + // FetchLatestMetadata guarantees the up-to-date metadata (space format) + // in first return value, otherwise it may not take into account + // the latest migration of the data format. Performance overhead is up to 15%. + // Disabled by default. + FetchLatestMetadata OptBool +} + +// EncodeMsgpack provides custom msgpack encoder. +func (opts BorderOpts) EncodeMsgpack(enc *msgpack.Encoder) error { + const optsCnt = 4 + + names := [optsCnt]string{timeoutOptName, vshardRouterOptName, fieldsOptName, + fetchLatestMetadataOptName} + values := [optsCnt]interface{}{} + exists := [optsCnt]bool{} + values[0], exists[0] = opts.Timeout.Get() + values[1], exists[1] = opts.VshardRouter.Get() + values[2], exists[2] = opts.Fields.Get() + values[3], exists[3] = opts.FetchLatestMetadata.Get() + + return encodeOptions(enc, names[:], values[:], exists[:]) +} + +func encodeOptions(enc *msgpack.Encoder, + names []string, values []interface{}, exists []bool) error { + mapLen := 0 + + for _, exist := range exists { + if exist { + mapLen += 1 + } + } + + if err := enc.EncodeMapLen(mapLen); err != nil { + return err + } + + if mapLen > 0 { + for i, name := range names { + if exists[i] { + enc.EncodeString(name) + enc.Encode(values[i]) + } + } + } + + return nil +} diff --git a/crud/replace.go b/crud/replace.go new file mode 100644 index 000000000..b47bba9ab --- /dev/null +++ b/crud/replace.go @@ -0,0 +1,125 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// ReplaceOpts describes options for `crud.replace` method. +type ReplaceOpts = SimpleOperationOpts + +// ReplaceRequest helps you to create request object to call `crud.replace` +// for execution by a Connection. +type ReplaceRequest struct { + spaceRequest + tuple Tuple + opts ReplaceOpts +} + +type replaceArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Tuple Tuple + Opts ReplaceOpts +} + +// MakeReplaceRequest returns a new empty ReplaceRequest. +func MakeReplaceRequest(space string) ReplaceRequest { + req := ReplaceRequest{} + req.impl = newCall("crud.replace") + req.space = space + req.opts = ReplaceOpts{} + return req +} + +// Tuple sets the tuple for the ReplaceRequest request. +// Note: default value is nil. +func (req ReplaceRequest) Tuple(tuple Tuple) ReplaceRequest { + req.tuple = tuple + return req +} + +// Opts sets the options for the ReplaceRequest request. +// Note: default value is nil. +func (req ReplaceRequest) Opts(opts ReplaceOpts) ReplaceRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req ReplaceRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + if req.tuple == nil { + req.tuple = []interface{}{} + } + args := replaceArgs{Space: req.space, Tuple: req.tuple, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req ReplaceRequest) Context(ctx context.Context) ReplaceRequest { + req.impl = req.impl.Context(ctx) + + return req +} + +// ReplaceObjectOpts describes options for `crud.replace_object` method. +type ReplaceObjectOpts = SimpleOperationObjectOpts + +// ReplaceObjectRequest helps you to create request object to call +// `crud.replace_object` for execution by a Connection. +type ReplaceObjectRequest struct { + spaceRequest + object Object + opts ReplaceObjectOpts +} + +type replaceObjectArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Object Object + Opts ReplaceObjectOpts +} + +// MakeReplaceObjectRequest returns a new empty ReplaceObjectRequest. +func MakeReplaceObjectRequest(space string) ReplaceObjectRequest { + req := ReplaceObjectRequest{} + req.impl = newCall("crud.replace_object") + req.space = space + req.opts = ReplaceObjectOpts{} + return req +} + +// Object sets the tuple for the ReplaceObjectRequest request. +// Note: default value is nil. +func (req ReplaceObjectRequest) Object(object Object) ReplaceObjectRequest { + req.object = object + return req +} + +// Opts sets the options for the ReplaceObjectRequest request. +// Note: default value is nil. +func (req ReplaceObjectRequest) Opts(opts ReplaceObjectOpts) ReplaceObjectRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req ReplaceObjectRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + if req.object == nil { + req.object = MapObject{} + } + args := replaceObjectArgs{Space: req.space, Object: req.object, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req ReplaceObjectRequest) Context(ctx context.Context) ReplaceObjectRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/crud/replace_many.go b/crud/replace_many.go new file mode 100644 index 000000000..024b863b7 --- /dev/null +++ b/crud/replace_many.go @@ -0,0 +1,125 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// ReplaceManyOpts describes options for `crud.replace_many` method. +type ReplaceManyOpts = OperationManyOpts + +// ReplaceManyRequest helps you to create request object to call +// `crud.replace_many` for execution by a Connection. +type ReplaceManyRequest struct { + spaceRequest + tuples Tuples + opts ReplaceManyOpts +} + +type replaceManyArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Tuples Tuples + Opts ReplaceManyOpts +} + +// MakeReplaceManyRequest returns a new empty ReplaceManyRequest. +func MakeReplaceManyRequest(space string) ReplaceManyRequest { + req := ReplaceManyRequest{} + req.impl = newCall("crud.replace_many") + req.space = space + req.opts = ReplaceManyOpts{} + return req +} + +// Tuples sets the tuples for the ReplaceManyRequest request. +// Note: default value is nil. +func (req ReplaceManyRequest) Tuples(tuples Tuples) ReplaceManyRequest { + req.tuples = tuples + return req +} + +// Opts sets the options for the ReplaceManyRequest request. +// Note: default value is nil. +func (req ReplaceManyRequest) Opts(opts ReplaceManyOpts) ReplaceManyRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req ReplaceManyRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + if req.tuples == nil { + req.tuples = []Tuple{} + } + args := replaceManyArgs{Space: req.space, Tuples: req.tuples, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req ReplaceManyRequest) Context(ctx context.Context) ReplaceManyRequest { + req.impl = req.impl.Context(ctx) + + return req +} + +// ReplaceObjectManyOpts describes options for `crud.replace_object_many` method. +type ReplaceObjectManyOpts = OperationObjectManyOpts + +// ReplaceObjectManyRequest helps you to create request object to call +// `crud.replace_object_many` for execution by a Connection. +type ReplaceObjectManyRequest struct { + spaceRequest + objects Objects + opts ReplaceObjectManyOpts +} + +type replaceObjectManyArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Objects Objects + Opts ReplaceObjectManyOpts +} + +// MakeReplaceObjectManyRequest returns a new empty ReplaceObjectManyRequest. +func MakeReplaceObjectManyRequest(space string) ReplaceObjectManyRequest { + req := ReplaceObjectManyRequest{} + req.impl = newCall("crud.replace_object_many") + req.space = space + req.opts = ReplaceObjectManyOpts{} + return req +} + +// Objects sets the tuple for the ReplaceObjectManyRequest request. +// Note: default value is nil. +func (req ReplaceObjectManyRequest) Objects(objects Objects) ReplaceObjectManyRequest { + req.objects = objects + return req +} + +// Opts sets the options for the ReplaceObjectManyRequest request. +// Note: default value is nil. +func (req ReplaceObjectManyRequest) Opts(opts ReplaceObjectManyOpts) ReplaceObjectManyRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req ReplaceObjectManyRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + if req.objects == nil { + req.objects = []Object{} + } + args := replaceObjectManyArgs{Space: req.space, Objects: req.objects, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req ReplaceObjectManyRequest) Context(ctx context.Context) ReplaceObjectManyRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/crud/request_test.go b/crud/request_test.go new file mode 100644 index 000000000..ba2bae859 --- /dev/null +++ b/crud/request_test.go @@ -0,0 +1,899 @@ +package crud_test + +import ( + "bytes" + "context" + "fmt" + "testing" + + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/crud" +) + +const validSpace = "test" // Any valid value != default. + +const CrudRequestType = iproto.IPROTO_CALL + +var reqObject = crud.MapObject{ + "id": uint(24), +} + +var reqObjects = []crud.Object{ + crud.MapObject{ + "id": uint(24), + }, + crud.MapObject{ + "id": uint(25), + }, +} + +var reqObjectsOperationsData = []crud.ObjectOperationsData{ + { + Object: crud.MapObject{ + "id": uint(24), + }, + Operations: []crud.Operation{ + { + Operator: crud.Add, + Field: "id", + Value: uint(1020), + }, + }, + }, + { + Object: crud.MapObject{ + "id": uint(25), + }, + Operations: []crud.Operation{ + { + Operator: crud.Add, + Field: "id", + Value: uint(1020), + }, + }, + }, +} + +var expectedOpts = map[string]interface{}{ + "timeout": timeout, +} + +func extractRequestBody(req tarantool.Request) ([]byte, error) { + var reqBuf bytes.Buffer + reqEnc := msgpack.NewEncoder(&reqBuf) + + err := req.Body(nil, reqEnc) + if err != nil { + return nil, fmt.Errorf("An unexpected Response.Body() error: %q", err.Error()) + } + + return reqBuf.Bytes(), nil +} + +func assertBodyEqual(t testing.TB, reference tarantool.Request, req tarantool.Request) { + t.Helper() + + reqBody, err := extractRequestBody(req) + if err != nil { + t.Fatalf("An unexpected Response.Body() error: %q", err.Error()) + } + + refBody, err := extractRequestBody(reference) + if err != nil { + t.Fatalf("An unexpected Response.Body() error: %q", err.Error()) + } + + if !bytes.Equal(reqBody, refBody) { + t.Errorf("Encoded request %v != reference %v", reqBody, refBody) + } +} + +func BenchmarkLenRequest(b *testing.B) { + buf := bytes.Buffer{} + buf.Grow(512 * 1024 * 1024) // Avoid allocs in test. + enc := msgpack.NewEncoder(&buf) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + buf.Reset() + req := crud.MakeLenRequest(spaceName). + Opts(crud.LenOpts{ + Timeout: crud.MakeOptFloat64(3.5), + }) + if err := req.Body(nil, enc); err != nil { + b.Error(err) + } + } +} + +func BenchmarkSelectRequest(b *testing.B) { + buf := bytes.Buffer{} + buf.Grow(512 * 1024 * 1024) // Avoid allocs in test. + enc := msgpack.NewEncoder(&buf) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + buf.Reset() + req := crud.MakeSelectRequest(spaceName). + Opts(crud.SelectOpts{ + Timeout: crud.MakeOptFloat64(3.5), + VshardRouter: crud.MakeOptString("asd"), + Balance: crud.MakeOptBool(true), + }) + if err := req.Body(nil, enc); err != nil { + b.Error(err) + } + } +} + +func TestRequestsCodes(t *testing.T) { + tests := []struct { + req tarantool.Request + rtype iproto.Type + }{ + {req: crud.MakeInsertRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeInsertObjectRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeInsertManyRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeInsertObjectManyRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeGetRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeUpdateRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeDeleteRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeReplaceRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeReplaceObjectRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeReplaceManyRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeReplaceObjectManyRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeUpsertRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeUpsertObjectRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeUpsertManyRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeUpsertObjectManyRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeMinRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeMaxRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeSelectRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeTruncateRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeLenRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeCountRequest(validSpace), rtype: CrudRequestType}, + {req: crud.MakeStorageInfoRequest(), rtype: CrudRequestType}, + {req: crud.MakeStatsRequest(), rtype: CrudRequestType}, + {req: crud.MakeSchemaRequest(), rtype: CrudRequestType}, + } + + for _, test := range tests { + if rtype := test.req.Type(); rtype != test.rtype { + t.Errorf("An invalid request type 0x%x, expected 0x%x", rtype, test.rtype) + } + } +} + +func TestRequestsAsync(t *testing.T) { + tests := []struct { + req tarantool.Request + async bool + }{ + {req: crud.MakeInsertRequest(validSpace), async: false}, + {req: crud.MakeInsertObjectRequest(validSpace), async: false}, + {req: crud.MakeInsertManyRequest(validSpace), async: false}, + {req: crud.MakeInsertObjectManyRequest(validSpace), async: false}, + {req: crud.MakeGetRequest(validSpace), async: false}, + {req: crud.MakeUpdateRequest(validSpace), async: false}, + {req: crud.MakeDeleteRequest(validSpace), async: false}, + {req: crud.MakeReplaceRequest(validSpace), async: false}, + {req: crud.MakeReplaceObjectRequest(validSpace), async: false}, + {req: crud.MakeReplaceManyRequest(validSpace), async: false}, + {req: crud.MakeReplaceObjectManyRequest(validSpace), async: false}, + {req: crud.MakeUpsertRequest(validSpace), async: false}, + {req: crud.MakeUpsertObjectRequest(validSpace), async: false}, + {req: crud.MakeUpsertManyRequest(validSpace), async: false}, + {req: crud.MakeUpsertObjectManyRequest(validSpace), async: false}, + {req: crud.MakeMinRequest(validSpace), async: false}, + {req: crud.MakeMaxRequest(validSpace), async: false}, + {req: crud.MakeSelectRequest(validSpace), async: false}, + {req: crud.MakeTruncateRequest(validSpace), async: false}, + {req: crud.MakeLenRequest(validSpace), async: false}, + {req: crud.MakeCountRequest(validSpace), async: false}, + {req: crud.MakeStorageInfoRequest(), async: false}, + {req: crud.MakeStatsRequest(), async: false}, + {req: crud.MakeSchemaRequest(), async: false}, + } + + for _, test := range tests { + if async := test.req.Async(); async != test.async { + t.Errorf("An invalid async %t, expected %t", async, test.async) + } + } +} + +func TestRequestsCtx_default(t *testing.T) { + tests := []struct { + req tarantool.Request + expected context.Context + }{ + {req: crud.MakeInsertRequest(validSpace), expected: nil}, + {req: crud.MakeInsertObjectRequest(validSpace), expected: nil}, + {req: crud.MakeInsertManyRequest(validSpace), expected: nil}, + {req: crud.MakeInsertObjectManyRequest(validSpace), expected: nil}, + {req: crud.MakeGetRequest(validSpace), expected: nil}, + {req: crud.MakeUpdateRequest(validSpace), expected: nil}, + {req: crud.MakeDeleteRequest(validSpace), expected: nil}, + {req: crud.MakeReplaceRequest(validSpace), expected: nil}, + {req: crud.MakeReplaceObjectRequest(validSpace), expected: nil}, + {req: crud.MakeReplaceManyRequest(validSpace), expected: nil}, + {req: crud.MakeReplaceObjectManyRequest(validSpace), expected: nil}, + {req: crud.MakeUpsertRequest(validSpace), expected: nil}, + {req: crud.MakeUpsertObjectRequest(validSpace), expected: nil}, + {req: crud.MakeUpsertManyRequest(validSpace), expected: nil}, + {req: crud.MakeUpsertObjectManyRequest(validSpace), expected: nil}, + {req: crud.MakeMinRequest(validSpace), expected: nil}, + {req: crud.MakeMaxRequest(validSpace), expected: nil}, + {req: crud.MakeSelectRequest(validSpace), expected: nil}, + {req: crud.MakeTruncateRequest(validSpace), expected: nil}, + {req: crud.MakeLenRequest(validSpace), expected: nil}, + {req: crud.MakeCountRequest(validSpace), expected: nil}, + {req: crud.MakeStorageInfoRequest(), expected: nil}, + {req: crud.MakeStatsRequest(), expected: nil}, + {req: crud.MakeSchemaRequest(), expected: nil}, + } + + for _, test := range tests { + if ctx := test.req.Ctx(); ctx != test.expected { + t.Errorf("An invalid ctx %t, expected %t", ctx, test.expected) + } + } +} + +func TestRequestsCtx_setter(t *testing.T) { + ctx := context.Background() + tests := []struct { + req tarantool.Request + expected context.Context + }{ + {req: crud.MakeInsertRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeInsertObjectRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeInsertManyRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeInsertObjectManyRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeGetRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeUpdateRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeDeleteRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeReplaceRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeReplaceObjectRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeReplaceManyRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeReplaceObjectManyRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeUpsertRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeUpsertObjectRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeUpsertManyRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeUpsertObjectManyRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeMinRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeMaxRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeSelectRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeTruncateRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeLenRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeCountRequest(validSpace).Context(ctx), expected: ctx}, + {req: crud.MakeStorageInfoRequest().Context(ctx), expected: ctx}, + {req: crud.MakeStatsRequest().Context(ctx), expected: ctx}, + {req: crud.MakeSchemaRequest().Context(ctx), expected: ctx}, + } + + for _, test := range tests { + if ctx := test.req.Ctx(); ctx != test.expected { + t.Errorf("An invalid ctx %t, expected %t", ctx, test.expected) + } + } +} + +func TestRequestsDefaultValues(t *testing.T) { + testCases := []struct { + name string + ref tarantool.Request + target tarantool.Request + }{ + { + name: "InsertRequest", + ref: tarantool.NewCall17Request("crud.insert").Args( + []interface{}{validSpace, []interface{}{}, map[string]interface{}{}}), + target: crud.MakeInsertRequest(validSpace), + }, + { + name: "InsertObjectRequest", + ref: tarantool.NewCall17Request("crud.insert_object").Args( + []interface{}{validSpace, map[string]interface{}{}, map[string]interface{}{}}), + target: crud.MakeInsertObjectRequest(validSpace), + }, + { + name: "InsertManyRequest", + ref: tarantool.NewCall17Request("crud.insert_many").Args( + []interface{}{validSpace, []interface{}{}, map[string]interface{}{}}), + target: crud.MakeInsertManyRequest(validSpace), + }, + { + name: "InsertObjectManyRequest", + ref: tarantool.NewCall17Request("crud.insert_object_many").Args( + []interface{}{validSpace, []map[string]interface{}{}, map[string]interface{}{}}), + target: crud.MakeInsertObjectManyRequest(validSpace), + }, + { + name: "GetRequest", + ref: tarantool.NewCall17Request("crud.get").Args( + []interface{}{validSpace, []interface{}{}, map[string]interface{}{}}), + target: crud.MakeGetRequest(validSpace), + }, + { + name: "UpdateRequest", + ref: tarantool.NewCall17Request("crud.update").Args( + []interface{}{validSpace, []interface{}{}, + []interface{}{}, map[string]interface{}{}}), + target: crud.MakeUpdateRequest(validSpace), + }, + { + name: "DeleteRequest", + ref: tarantool.NewCall17Request("crud.delete").Args( + []interface{}{validSpace, []interface{}{}, map[string]interface{}{}}), + target: crud.MakeDeleteRequest(validSpace), + }, + { + name: "ReplaceRequest", + ref: tarantool.NewCall17Request("crud.replace").Args( + []interface{}{validSpace, []interface{}{}, map[string]interface{}{}}), + target: crud.MakeReplaceRequest(validSpace), + }, + { + name: "ReplaceObjectRequest", + ref: tarantool.NewCall17Request("crud.replace_object").Args([]interface{}{validSpace, + map[string]interface{}{}, map[string]interface{}{}}), + target: crud.MakeReplaceObjectRequest(validSpace), + }, + { + name: "ReplaceManyRequest", + ref: tarantool.NewCall17Request("crud.replace_many").Args([]interface{}{validSpace, + []interface{}{}, map[string]interface{}{}}), + target: crud.MakeReplaceManyRequest(validSpace), + }, + { + name: "ReplaceObjectManyRequest", + ref: tarantool.NewCall17Request("crud.replace_object_many").Args( + []interface{}{validSpace, []map[string]interface{}{}, map[string]interface{}{}}), + target: crud.MakeReplaceObjectManyRequest(validSpace), + }, + { + name: "UpsertRequest", + ref: tarantool.NewCall17Request("crud.upsert").Args( + []interface{}{validSpace, []interface{}{}, []interface{}{}, + map[string]interface{}{}}), + target: crud.MakeUpsertRequest(validSpace), + }, + { + name: "UpsertObjectRequest", + ref: tarantool.NewCall17Request("crud.upsert_object").Args( + []interface{}{validSpace, map[string]interface{}{}, []interface{}{}, + map[string]interface{}{}}), + target: crud.MakeUpsertObjectRequest(validSpace), + }, + { + name: "UpsertManyRequest", + ref: tarantool.NewCall17Request("crud.upsert_many").Args( + []interface{}{validSpace, []interface{}{}, map[string]interface{}{}}), + target: crud.MakeUpsertManyRequest(validSpace), + }, + { + name: "UpsertObjectManyRequest", + ref: tarantool.NewCall17Request("crud.upsert_object_many").Args( + []interface{}{validSpace, []interface{}{}, map[string]interface{}{}}), + target: crud.MakeUpsertObjectManyRequest(validSpace), + }, + { + name: "SelectRequest", + ref: tarantool.NewCall17Request("crud.select").Args( + []interface{}{validSpace, nil, map[string]interface{}{}}), + target: crud.MakeSelectRequest(validSpace), + }, + { + name: "MinRequest", + ref: tarantool.NewCall17Request("crud.min").Args( + []interface{}{validSpace, nil, map[string]interface{}{}}), + target: crud.MakeMinRequest(validSpace), + }, + { + name: "MaxRequest", + ref: tarantool.NewCall17Request("crud.max").Args( + []interface{}{validSpace, nil, map[string]interface{}{}}), + target: crud.MakeMaxRequest(validSpace), + }, + { + name: "TruncateRequest", + ref: tarantool.NewCall17Request("crud.truncate").Args( + []interface{}{validSpace, map[string]interface{}{}}), + target: crud.MakeTruncateRequest(validSpace), + }, + { + name: "LenRequest", + ref: tarantool.NewCall17Request("crud.len").Args( + []interface{}{validSpace, map[string]interface{}{}}), + target: crud.MakeLenRequest(validSpace), + }, + { + name: "CountRequest", + ref: tarantool.NewCall17Request("crud.count").Args( + []interface{}{validSpace, nil, map[string]interface{}{}}), + target: crud.MakeCountRequest(validSpace), + }, + { + name: "StorageInfoRequest", + ref: tarantool.NewCall17Request("crud.storage_info").Args( + []interface{}{map[string]interface{}{}}), + target: crud.MakeStorageInfoRequest(), + }, + { + name: "StatsRequest", + ref: tarantool.NewCall17Request("crud.stats").Args( + []interface{}{}), + target: crud.MakeStatsRequest(), + }, + { + name: "SchemaRequest", + ref: tarantool.NewCall17Request("crud.schema").Args( + []interface{}{nil, map[string]interface{}{}}), + target: crud.MakeSchemaRequest(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assertBodyEqual(t, tc.ref, tc.target) + }) + } +} + +func TestRequestsSetters(t *testing.T) { + testCases := []struct { + name string + ref tarantool.Request + target tarantool.Request + }{ + { + name: "InsertRequest", + ref: tarantool.NewCall17Request("crud.insert").Args( + []interface{}{spaceName, tuple, expectedOpts}), + target: crud.MakeInsertRequest(spaceName).Tuple(tuple).Opts(simpleOperationOpts), + }, + { + name: "InsertObjectRequest", + ref: tarantool.NewCall17Request("crud.insert_object").Args( + []interface{}{spaceName, reqObject, expectedOpts}), + target: crud.MakeInsertObjectRequest(spaceName).Object(reqObject). + Opts(simpleOperationObjectOpts), + }, + { + name: "InsertManyRequest", + ref: tarantool.NewCall17Request("crud.insert_many").Args( + []interface{}{spaceName, tuples, expectedOpts}), + target: crud.MakeInsertManyRequest(spaceName).Tuples(tuples).Opts(opManyOpts), + }, + { + name: "InsertObjectManyRequest", + ref: tarantool.NewCall17Request("crud.insert_object_many").Args( + []interface{}{spaceName, reqObjects, expectedOpts}), + target: crud.MakeInsertObjectManyRequest(spaceName).Objects(reqObjects). + Opts(opObjManyOpts), + }, + { + name: "GetRequest", + ref: tarantool.NewCall17Request("crud.get").Args( + []interface{}{spaceName, key, expectedOpts}), + target: crud.MakeGetRequest(spaceName).Key(key).Opts(getOpts), + }, + { + name: "UpdateRequest", + ref: tarantool.NewCall17Request("crud.update").Args( + []interface{}{spaceName, key, operations, expectedOpts}), + target: crud.MakeUpdateRequest(spaceName).Key(key).Operations(operations). + Opts(simpleOperationOpts), + }, + { + name: "DeleteRequest", + ref: tarantool.NewCall17Request("crud.delete").Args( + []interface{}{spaceName, key, expectedOpts}), + target: crud.MakeDeleteRequest(spaceName).Key(key).Opts(simpleOperationOpts), + }, + { + name: "ReplaceRequest", + ref: tarantool.NewCall17Request("crud.replace").Args( + []interface{}{spaceName, tuple, expectedOpts}), + target: crud.MakeReplaceRequest(spaceName).Tuple(tuple).Opts(simpleOperationOpts), + }, + { + name: "ReplaceObjectRequest", + ref: tarantool.NewCall17Request("crud.replace_object").Args( + []interface{}{spaceName, reqObject, expectedOpts}), + target: crud.MakeReplaceObjectRequest(spaceName).Object(reqObject). + Opts(simpleOperationObjectOpts), + }, + { + name: "ReplaceManyRequest", + ref: tarantool.NewCall17Request("crud.replace_many").Args( + []interface{}{spaceName, tuples, expectedOpts}), + target: crud.MakeReplaceManyRequest(spaceName).Tuples(tuples).Opts(opManyOpts), + }, + { + name: "ReplaceObjectManyRequest", + ref: tarantool.NewCall17Request("crud.replace_object_many").Args( + []interface{}{spaceName, reqObjects, expectedOpts}), + target: crud.MakeReplaceObjectManyRequest(spaceName).Objects(reqObjects). + Opts(opObjManyOpts), + }, + { + name: "UpsertRequest", + ref: tarantool.NewCall17Request("crud.upsert").Args( + []interface{}{spaceName, tuple, operations, expectedOpts}), + target: crud.MakeUpsertRequest(spaceName).Tuple(tuple).Operations(operations). + Opts(simpleOperationOpts), + }, + { + name: "UpsertObjectRequest", + ref: tarantool.NewCall17Request("crud.upsert_object").Args( + []interface{}{spaceName, reqObject, operations, expectedOpts}), + target: crud.MakeUpsertObjectRequest(spaceName).Object(reqObject). + Operations(operations).Opts(simpleOperationOpts), + }, + { + name: "UpsertManyRequest", + ref: tarantool.NewCall17Request("crud.upsert_many").Args( + []interface{}{spaceName, tuplesOperationsData, expectedOpts}), + target: crud.MakeUpsertManyRequest(spaceName). + TuplesOperationsData(tuplesOperationsData).Opts(opManyOpts), + }, + { + name: "UpsertObjectManyRequest", + ref: tarantool.NewCall17Request("crud.upsert_object_many").Args( + []interface{}{spaceName, reqObjectsOperationsData, expectedOpts}), + target: crud.MakeUpsertObjectManyRequest(spaceName). + ObjectsOperationsData(reqObjectsOperationsData).Opts(opManyOpts), + }, + { + name: "SelectRequest", + ref: tarantool.NewCall17Request("crud.select").Args( + []interface{}{spaceName, conditions, expectedOpts}), + target: crud.MakeSelectRequest(spaceName).Conditions(conditions).Opts(selectOpts), + }, + { + name: "MinRequest", + ref: tarantool.NewCall17Request("crud.min").Args( + []interface{}{spaceName, indexName, expectedOpts}), + target: crud.MakeMinRequest(spaceName).Index(indexName).Opts(minOpts), + }, + { + name: "MaxRequest", + ref: tarantool.NewCall17Request("crud.max").Args( + []interface{}{spaceName, indexName, expectedOpts}), + target: crud.MakeMaxRequest(spaceName).Index(indexName).Opts(maxOpts), + }, + { + name: "TruncateRequest", + ref: tarantool.NewCall17Request("crud.truncate").Args( + []interface{}{spaceName, expectedOpts}), + target: crud.MakeTruncateRequest(spaceName).Opts(baseOpts), + }, + { + name: "LenRequest", + ref: tarantool.NewCall17Request("crud.len").Args( + []interface{}{spaceName, expectedOpts}), + target: crud.MakeLenRequest(spaceName).Opts(baseOpts), + }, + { + name: "CountRequest", + ref: tarantool.NewCall17Request("crud.count").Args( + []interface{}{spaceName, conditions, expectedOpts}), + target: crud.MakeCountRequest(spaceName).Conditions(conditions).Opts(countOpts), + }, + { + name: "StorageInfoRequest", + ref: tarantool.NewCall17Request("crud.storage_info").Args( + []interface{}{expectedOpts}), + target: crud.MakeStorageInfoRequest().Opts(baseOpts), + }, + { + name: "StatsRequest", + ref: tarantool.NewCall17Request("crud.stats").Args( + []interface{}{spaceName}), + target: crud.MakeStatsRequest().Space(spaceName), + }, + { + name: "SchemaRequest", + ref: tarantool.NewCall17Request("crud.schema").Args( + []interface{}{nil, schemaOpts}, + ), + target: crud.MakeSchemaRequest().Opts(schemaOpts), + }, + { + name: "SchemaRequestWithSpace", + ref: tarantool.NewCall17Request("crud.schema").Args( + []interface{}{spaceName, schemaOpts}, + ), + target: crud.MakeSchemaRequest().Space(spaceName).Opts(schemaOpts), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assertBodyEqual(t, tc.ref, tc.target) + }) + } +} + +func TestRequestsVshardRouter(t *testing.T) { + testCases := []struct { + name string + ref tarantool.Request + target tarantool.Request + }{ + { + name: "InsertRequest", + ref: tarantool.NewCall17Request("crud.insert").Args([]interface{}{ + validSpace, + []interface{}{}, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeInsertRequest(validSpace).Opts(crud.InsertOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "InsertObjectRequest", + ref: tarantool.NewCall17Request("crud.insert_object").Args([]interface{}{ + validSpace, + map[string]interface{}{}, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeInsertObjectRequest(validSpace).Opts(crud.InsertObjectOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "InsertManyRequest", + ref: tarantool.NewCall17Request("crud.insert_many").Args([]interface{}{ + validSpace, + []interface{}{}, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeInsertManyRequest(validSpace).Opts(crud.InsertManyOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "InsertObjectManyRequest", + ref: tarantool.NewCall17Request("crud.insert_object_many").Args([]interface{}{ + validSpace, + []map[string]interface{}{}, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeInsertObjectManyRequest(validSpace).Opts(crud.InsertObjectManyOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "GetRequest", + ref: tarantool.NewCall17Request("crud.get").Args([]interface{}{ + validSpace, + []interface{}{}, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeGetRequest(validSpace).Opts(crud.GetOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "UpdateRequest", + ref: tarantool.NewCall17Request("crud.update").Args([]interface{}{ + validSpace, + []interface{}{}, + []interface{}{}, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeUpdateRequest(validSpace).Opts(crud.UpdateOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "DeleteRequest", + ref: tarantool.NewCall17Request("crud.delete").Args([]interface{}{ + validSpace, + []interface{}{}, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeDeleteRequest(validSpace).Opts(crud.DeleteOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "ReplaceRequest", + ref: tarantool.NewCall17Request("crud.replace").Args([]interface{}{ + validSpace, + []interface{}{}, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeReplaceRequest(validSpace).Opts(crud.ReplaceOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "ReplaceObjectRequest", + ref: tarantool.NewCall17Request("crud.replace_object").Args([]interface{}{ + validSpace, + map[string]interface{}{}, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeReplaceObjectRequest(validSpace).Opts(crud.ReplaceObjectOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "ReplaceManyRequest", + ref: tarantool.NewCall17Request("crud.replace_many").Args([]interface{}{ + validSpace, + []interface{}{}, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeReplaceManyRequest(validSpace).Opts(crud.ReplaceManyOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "ReplaceObjectManyRequest", + ref: tarantool.NewCall17Request("crud.replace_object_many").Args([]interface{}{ + validSpace, + []map[string]interface{}{}, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeReplaceObjectManyRequest(validSpace).Opts(crud.ReplaceObjectManyOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "UpsertRequest", + ref: tarantool.NewCall17Request("crud.upsert").Args([]interface{}{ + validSpace, + []interface{}{}, + []interface{}{}, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeUpsertRequest(validSpace).Opts(crud.UpsertOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "UpsertObjectRequest", + ref: tarantool.NewCall17Request("crud.upsert_object").Args([]interface{}{ + validSpace, + map[string]interface{}{}, + []interface{}{}, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeUpsertObjectRequest(validSpace).Opts(crud.UpsertObjectOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "UpsertManyRequest", + ref: tarantool.NewCall17Request("crud.upsert_many").Args([]interface{}{ + validSpace, + []interface{}{}, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeUpsertManyRequest(validSpace).Opts(crud.UpsertManyOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "UpsertObjectManyRequest", + ref: tarantool.NewCall17Request("crud.upsert_object_many").Args([]interface{}{ + validSpace, + []interface{}{}, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeUpsertObjectManyRequest(validSpace).Opts(crud.UpsertObjectManyOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "SelectRequest", + ref: tarantool.NewCall17Request("crud.select").Args([]interface{}{ + validSpace, + nil, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeSelectRequest(validSpace).Opts(crud.SelectOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "MinRequest", + ref: tarantool.NewCall17Request("crud.min").Args([]interface{}{ + validSpace, + nil, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeMinRequest(validSpace).Opts(crud.MinOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "MaxRequest", + ref: tarantool.NewCall17Request("crud.max").Args([]interface{}{ + validSpace, + nil, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeMaxRequest(validSpace).Opts(crud.MaxOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "TruncateRequest", + ref: tarantool.NewCall17Request("crud.truncate").Args([]interface{}{ + validSpace, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeTruncateRequest(validSpace).Opts(crud.TruncateOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "LenRequest", + ref: tarantool.NewCall17Request("crud.len").Args([]interface{}{ + validSpace, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeLenRequest(validSpace).Opts(crud.LenOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "CountRequest", + ref: tarantool.NewCall17Request("crud.count").Args([]interface{}{ + validSpace, + nil, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeCountRequest(validSpace).Opts(crud.CountOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "StorageInfoRequest", + ref: tarantool.NewCall17Request("crud.storage_info").Args([]interface{}{ + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeStorageInfoRequest().Opts(crud.StorageInfoOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "SchemaRequest", + ref: tarantool.NewCall17Request("crud.schema").Args([]interface{}{ + nil, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeSchemaRequest().Opts(crud.SchemaOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + { + name: "SchemaRequestWithSpace", + ref: tarantool.NewCall17Request("crud.schema").Args([]interface{}{ + validSpace, + map[string]interface{}{"vshard_router": "custom_router"}, + }), + target: crud.MakeSchemaRequest().Space(validSpace).Opts(crud.SchemaOpts{ + VshardRouter: crud.MakeOptString("custom_router"), + }), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assertBodyEqual(t, tc.ref, tc.target) + }) + } +} diff --git a/crud/result.go b/crud/result.go new file mode 100644 index 000000000..7ae00e68f --- /dev/null +++ b/crud/result.go @@ -0,0 +1,260 @@ +package crud + +import ( + "fmt" + "reflect" + + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" +) + +// FieldFormat contains field definition: {name='...',type='...'[,is_nullable=...]}. +type FieldFormat struct { + Name string + Type string + IsNullable bool +} + +// DecodeMsgpack provides custom msgpack decoder. +func (format *FieldFormat) DecodeMsgpack(d *msgpack.Decoder) error { + l, err := d.DecodeMapLen() + if err != nil { + return err + } + for i := 0; i < l; i++ { + key, err := d.DecodeString() + if err != nil { + return err + } + switch key { + case "name": + if format.Name, err = d.DecodeString(); err != nil { + return err + } + case "type": + if format.Type, err = d.DecodeString(); err != nil { + return err + } + case "is_nullable": + if format.IsNullable, err = d.DecodeBool(); err != nil { + return err + } + default: + if err := d.Skip(); err != nil { + return err + } + } + } + + return nil +} + +// Result describes CRUD result as an object containing metadata and rows. +type Result struct { + Metadata []FieldFormat + Rows interface{} + rowType reflect.Type +} + +// MakeResult create a Result object with a custom row type for decoding. +func MakeResult(rowType reflect.Type) Result { + return Result{ + rowType: rowType, + } +} + +func msgpackIsArray(code byte) bool { + return code == msgpcode.Array16 || code == msgpcode.Array32 || + msgpcode.IsFixedArray(code) +} + +// DecodeMsgpack provides custom msgpack decoder. +func (r *Result) DecodeMsgpack(d *msgpack.Decoder) error { + arrLen, err := d.DecodeArrayLen() + if err != nil { + return err + } + + if arrLen == 0 { + return fmt.Errorf("unexpected empty response array") + } + + // DecodeMapLen processes `nil` as zero length map, + // so in `return nil, err` case we don't miss error info. + // https://github.com/vmihailenco/msgpack/blob/3f7bd806fea698e7a9fe80979aa3512dea0a7368/decode_map.go#L79-L81 + l, err := d.DecodeMapLen() + if err != nil { + return err + } + + for i := 0; i < l; i++ { + key, err := d.DecodeString() + if err != nil { + return err + } + + switch key { + case "metadata": + metadataLen, err := d.DecodeArrayLen() + if err != nil { + return err + } + + metadata := make([]FieldFormat, metadataLen) + + for i := 0; i < metadataLen; i++ { + fieldFormat := FieldFormat{} + if err = d.Decode(&fieldFormat); err != nil { + return err + } + + metadata[i] = fieldFormat + } + + r.Metadata = metadata + case "rows": + if r.rowType != nil { + tuples := reflect.New(reflect.SliceOf(r.rowType)) + if err = d.DecodeValue(tuples); err != nil { + return err + } + r.Rows = tuples.Elem().Interface() + } else { + var decoded []interface{} + if err = d.Decode(&decoded); err != nil { + return err + } + r.Rows = decoded + } + default: + if err := d.Skip(); err != nil { + return err + } + } + } + + if arrLen > 1 { + code, err := d.PeekCode() + if err != nil { + return err + } + + if msgpackIsArray(code) { + crudErr := newErrorMany(r.rowType) + if err := d.Decode(&crudErr); err != nil { + return err + } + if crudErr != nil { + return *crudErr + } + } else if code != msgpcode.Nil { + crudErr := newError(r.rowType) + if err := d.Decode(&crudErr); err != nil { + return err + } + if crudErr != nil { + return *crudErr + } + } else { + if err := d.DecodeNil(); err != nil { + return err + } + } + } + + for i := 2; i < arrLen; i++ { + if err := d.Skip(); err != nil { + return err + } + } + + return nil +} + +// NumberResult describes CRUD result as an object containing number. +type NumberResult struct { + Value uint64 +} + +// DecodeMsgpack provides custom msgpack decoder. +func (r *NumberResult) DecodeMsgpack(d *msgpack.Decoder) error { + arrLen, err := d.DecodeArrayLen() + if err != nil { + return err + } + + if arrLen == 0 { + return fmt.Errorf("unexpected empty response array") + } + + // DecodeUint64 processes `nil` as `0`, + // so in `return nil, err` case we don't miss error info. + // https://github.com/vmihailenco/msgpack/blob/3f7bd806fea698e7a9fe80979aa3512dea0a7368/decode_number.go#L91-L93 + if r.Value, err = d.DecodeUint64(); err != nil { + return err + } + + if arrLen > 1 { + var crudErr *Error = nil + + if err := d.Decode(&crudErr); err != nil { + return err + } + + if crudErr != nil { + return crudErr + } + } + + for i := 2; i < arrLen; i++ { + if err := d.Skip(); err != nil { + return err + } + } + + return nil +} + +// BoolResult describes CRUD result as an object containing bool. +type BoolResult struct { + Value bool +} + +// DecodeMsgpack provides custom msgpack decoder. +func (r *BoolResult) DecodeMsgpack(d *msgpack.Decoder) error { + arrLen, err := d.DecodeArrayLen() + if err != nil { + return err + } + + if arrLen == 0 { + return fmt.Errorf("unexpected empty response array") + } + + // DecodeBool processes `nil` as `false`, + // so in `return nil, err` case we don't miss error info. + // https://github.com/vmihailenco/msgpack/blob/3f7bd806fea698e7a9fe80979aa3512dea0a7368/decode.go#L367-L369 + if r.Value, err = d.DecodeBool(); err != nil { + return err + } + + if arrLen > 1 { + var crudErr *Error = nil + + if err := d.Decode(&crudErr); err != nil { + return err + } + + if crudErr != nil { + return crudErr + } + } + + for i := 2; i < arrLen; i++ { + if err := d.Skip(); err != nil { + return err + } + } + + return nil +} diff --git a/crud/result_test.go b/crud/result_test.go new file mode 100644 index 000000000..578eebed1 --- /dev/null +++ b/crud/result_test.go @@ -0,0 +1,34 @@ +package crud_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3/crud" +) + +func TestResult_DecodeMsgpack(t *testing.T) { + sampleCrudResponse := []interface{}{ + map[string]interface{}{ + "rows": []interface{}{"1", "2", "3"}, + }, + nil, + } + responses := []interface{}{sampleCrudResponse, sampleCrudResponse} + + b := bytes.NewBuffer([]byte{}) + enc := msgpack.NewEncoder(b) + err := enc.Encode(responses) + require.NoError(t, err) + + var results []crud.Result + decoder := msgpack.NewDecoder(b) + err = decoder.DecodeValue(reflect.ValueOf(&results)) + require.NoError(t, err) + require.Equal(t, results[0].Rows, []interface{}{"1", "2", "3"}) + require.Equal(t, results[1].Rows, []interface{}{"1", "2", "3"}) +} diff --git a/crud/schema.go b/crud/schema.go new file mode 100644 index 000000000..4c2d661ec --- /dev/null +++ b/crud/schema.go @@ -0,0 +1,249 @@ +package crud + +import ( + "context" + "fmt" + + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" + + "github.com/tarantool/go-tarantool/v3" +) + +func msgpackIsMap(code byte) bool { + return code == msgpcode.Map16 || code == msgpcode.Map32 || msgpcode.IsFixedMap(code) +} + +// SchemaOpts describes options for `crud.schema` method. +type SchemaOpts struct { + // Timeout is a `vshard.call` timeout and vshard + // master discovery timeout (in seconds). + Timeout OptFloat64 + // VshardRouter is cartridge vshard group name or + // vshard router instance. + VshardRouter OptString + // Cached defines whether router should reload storage schema on call. + Cached OptBool +} + +// EncodeMsgpack provides custom msgpack encoder. +func (opts SchemaOpts) EncodeMsgpack(enc *msgpack.Encoder) error { + const optsCnt = 3 + + names := [optsCnt]string{timeoutOptName, vshardRouterOptName, + cachedOptName} + values := [optsCnt]interface{}{} + exists := [optsCnt]bool{} + values[0], exists[0] = opts.Timeout.Get() + values[1], exists[1] = opts.VshardRouter.Get() + values[2], exists[2] = opts.Cached.Get() + + return encodeOptions(enc, names[:], values[:], exists[:]) +} + +// SchemaRequest helps you to create request object to call `crud.schema` +// for execution by a Connection. +type SchemaRequest struct { + baseRequest + space OptString + opts SchemaOpts +} + +// MakeSchemaRequest returns a new empty SchemaRequest. +func MakeSchemaRequest() SchemaRequest { + req := SchemaRequest{} + req.impl = newCall("crud.schema") + return req +} + +// Space sets the space name for the SchemaRequest request. +// Note: default value is nil. +func (req SchemaRequest) Space(space string) SchemaRequest { + req.space = MakeOptString(space) + return req +} + +// Opts sets the options for the SchemaRequest request. +// Note: default value is nil. +func (req SchemaRequest) Opts(opts SchemaOpts) SchemaRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req SchemaRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + if value, ok := req.space.Get(); ok { + req.impl = req.impl.Args([]interface{}{value, req.opts}) + } else { + req.impl = req.impl.Args([]interface{}{nil, req.opts}) + } + + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req SchemaRequest) Context(ctx context.Context) SchemaRequest { + req.impl = req.impl.Context(ctx) + + return req +} + +// Schema contains CRUD cluster schema definition. +type Schema map[string]SpaceSchema + +// DecodeMsgpack provides custom msgpack decoder. +func (schema *Schema) DecodeMsgpack(d *msgpack.Decoder) error { + var l int + + code, err := d.PeekCode() + if err != nil { + return err + } + + if msgpackIsArray(code) { + // Process empty schema case. + l, err = d.DecodeArrayLen() + if err != nil { + return err + } + if l != 0 { + return fmt.Errorf("expected map or empty array, got non-empty array") + } + *schema = make(map[string]SpaceSchema, l) + } else if msgpackIsMap(code) { + l, err := d.DecodeMapLen() + if err != nil { + return err + } + *schema = make(map[string]SpaceSchema, l) + + for i := 0; i < l; i++ { + key, err := d.DecodeString() + if err != nil { + return err + } + + var spaceSchema SpaceSchema + if err := d.Decode(&spaceSchema); err != nil { + return err + } + + (*schema)[key] = spaceSchema + } + } else { + return fmt.Errorf("unexpected code=%d decoding map or empty array", code) + } + + return nil +} + +// SpaceSchema contains a single CRUD space schema definition. +type SpaceSchema struct { + Format []FieldFormat `msgpack:"format"` + Indexes map[uint32]Index `msgpack:"indexes"` +} + +// Index contains a CRUD space index definition. +type Index struct { + Id uint32 `msgpack:"id"` + Name string `msgpack:"name"` + Type string `msgpack:"type"` + Unique bool `msgpack:"unique"` + Parts []IndexPart `msgpack:"parts"` +} + +// IndexField contains a CRUD space index part definition. +type IndexPart struct { + Fieldno uint32 `msgpack:"fieldno"` + Type string `msgpack:"type"` + ExcludeNull bool `msgpack:"exclude_null"` + IsNullable bool `msgpack:"is_nullable"` +} + +// SchemaResult contains a schema request result for all spaces. +type SchemaResult struct { + Value Schema +} + +// DecodeMsgpack provides custom msgpack decoder. +func (result *SchemaResult) DecodeMsgpack(d *msgpack.Decoder) error { + arrLen, err := d.DecodeArrayLen() + if err != nil { + return err + } + + if arrLen == 0 { + return fmt.Errorf("unexpected empty response array") + } + + // DecodeMapLen inside Schema decode processes `nil` as zero length map, + // so in `return nil, err` case we don't miss error info. + // https://github.com/vmihailenco/msgpack/blob/3f7bd806fea698e7a9fe80979aa3512dea0a7368/decode_map.go#L79-L81 + if err = d.Decode(&result.Value); err != nil { + return err + } + + if arrLen > 1 { + var crudErr *Error = nil + + if err := d.Decode(&crudErr); err != nil { + return err + } + + if crudErr != nil { + return crudErr + } + } + + for i := 2; i < arrLen; i++ { + if err := d.Skip(); err != nil { + return err + } + } + + return nil +} + +// SchemaResult contains a schema request result for a single space. +type SpaceSchemaResult struct { + Value SpaceSchema +} + +// DecodeMsgpack provides custom msgpack decoder. +func (result *SpaceSchemaResult) DecodeMsgpack(d *msgpack.Decoder) error { + arrLen, err := d.DecodeArrayLen() + if err != nil { + return err + } + + if arrLen == 0 { + return fmt.Errorf("unexpected empty response array") + } + + // DecodeMapLen inside SpaceSchema decode processes `nil` as zero length map, + // so in `return nil, err` case we don't miss error info. + // https://github.com/vmihailenco/msgpack/blob/3f7bd806fea698e7a9fe80979aa3512dea0a7368/decode_map.go#L79-L81 + if err = d.Decode(&result.Value); err != nil { + return err + } + + if arrLen > 1 { + var crudErr *Error = nil + + if err := d.Decode(&crudErr); err != nil { + return err + } + + if crudErr != nil { + return crudErr + } + } + + for i := 2; i < arrLen; i++ { + if err := d.Skip(); err != nil { + return err + } + } + + return nil +} diff --git a/crud/select.go b/crud/select.go new file mode 100644 index 000000000..b52eb7003 --- /dev/null +++ b/crud/select.go @@ -0,0 +1,135 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// SelectOpts describes options for `crud.select` method. +type SelectOpts struct { + // Timeout is a `vshard.call` timeout and vshard + // master discovery timeout (in seconds). + Timeout OptFloat64 + // VshardRouter is cartridge vshard group name or + // vshard router instance. + VshardRouter OptString + // Fields is field names for getting only a subset of fields. + Fields OptTuple + // BucketId is a bucket ID. + BucketId OptUint + // Mode is a parameter with `write`/`read` possible values, + // if `write` is specified then operation is performed on master. + Mode OptString + // PreferReplica is a parameter to specify preferred target + // as one of the replicas. + PreferReplica OptBool + // Balance is a parameter to use replica according to vshard + // load balancing policy. + Balance OptBool + // First describes the maximum count of the objects to return. + First OptInt + // After is a tuple after which objects should be selected. + After OptTuple + // BatchSize is a number of tuples to process per one request to storage. + BatchSize OptUint + // ForceMapCall describes the map call is performed without any + // optimizations even if full primary key equal condition is specified. + ForceMapCall OptBool + // Fullscan describes if a critical log entry will be skipped on + // potentially long select. + Fullscan OptBool + // FetchLatestMetadata guarantees the up-to-date metadata (space format) + // in first return value, otherwise it may not take into account + // the latest migration of the data format. Performance overhead is up to 15%. + // Disabled by default. + FetchLatestMetadata OptBool + // YieldEvery describes number of tuples processed to yield after. + // Should be positive. + YieldEvery OptUint +} + +// EncodeMsgpack provides custom msgpack encoder. +func (opts SelectOpts) EncodeMsgpack(enc *msgpack.Encoder) error { + const optsCnt = 14 + + names := [optsCnt]string{timeoutOptName, vshardRouterOptName, + fieldsOptName, bucketIdOptName, + modeOptName, preferReplicaOptName, balanceOptName, + firstOptName, afterOptName, batchSizeOptName, + forceMapCallOptName, fullscanOptName, fetchLatestMetadataOptName, + yieldEveryOptName} + values := [optsCnt]interface{}{} + exists := [optsCnt]bool{} + values[0], exists[0] = opts.Timeout.Get() + values[1], exists[1] = opts.VshardRouter.Get() + values[2], exists[2] = opts.Fields.Get() + values[3], exists[3] = opts.BucketId.Get() + values[4], exists[4] = opts.Mode.Get() + values[5], exists[5] = opts.PreferReplica.Get() + values[6], exists[6] = opts.Balance.Get() + values[7], exists[7] = opts.First.Get() + values[8], exists[8] = opts.After.Get() + values[9], exists[9] = opts.BatchSize.Get() + values[10], exists[10] = opts.ForceMapCall.Get() + values[11], exists[11] = opts.Fullscan.Get() + values[12], exists[12] = opts.FetchLatestMetadata.Get() + values[13], exists[13] = opts.YieldEvery.Get() + + return encodeOptions(enc, names[:], values[:], exists[:]) +} + +// SelectRequest helps you to create request object to call `crud.select` +// for execution by a Connection. +type SelectRequest struct { + spaceRequest + conditions []Condition + opts SelectOpts +} + +type selectArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Conditions []Condition + Opts SelectOpts +} + +// MakeSelectRequest returns a new empty SelectRequest. +func MakeSelectRequest(space string) SelectRequest { + req := SelectRequest{} + req.impl = newCall("crud.select") + req.space = space + req.conditions = nil + req.opts = SelectOpts{} + return req +} + +// Conditions sets the conditions for the SelectRequest request. +// Note: default value is nil. +func (req SelectRequest) Conditions(conditions []Condition) SelectRequest { + req.conditions = conditions + return req +} + +// Opts sets the options for the SelectRequest request. +// Note: default value is nil. +func (req SelectRequest) Opts(opts SelectOpts) SelectRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req SelectRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + args := selectArgs{Space: req.space, Conditions: req.conditions, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req SelectRequest) Context(ctx context.Context) SelectRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/crud/stats.go b/crud/stats.go new file mode 100644 index 000000000..c4f6988a0 --- /dev/null +++ b/crud/stats.go @@ -0,0 +1,48 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// StatsRequest helps you to create request object to call `crud.stats` +// for execution by a Connection. +type StatsRequest struct { + baseRequest + space OptString +} + +// MakeStatsRequest returns a new empty StatsRequest. +func MakeStatsRequest() StatsRequest { + req := StatsRequest{} + req.impl = newCall("crud.stats") + return req +} + +// Space sets the space name for the StatsRequest request. +// Note: default value is nil. +func (req StatsRequest) Space(space string) StatsRequest { + req.space = MakeOptString(space) + return req +} + +// Body fills an encoder with the call request body. +func (req StatsRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + if value, ok := req.space.Get(); ok { + req.impl = req.impl.Args([]interface{}{value}) + } else { + req.impl = req.impl.Args([]interface{}{}) + } + + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req StatsRequest) Context(ctx context.Context) StatsRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/crud/storage_info.go b/crud/storage_info.go new file mode 100644 index 000000000..b39bf37a5 --- /dev/null +++ b/crud/storage_info.go @@ -0,0 +1,132 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// StatusTable describes information for instance. +type StatusTable struct { + Status string + IsMaster bool + Message string +} + +// DecodeMsgpack provides custom msgpack decoder. +func (statusTable *StatusTable) DecodeMsgpack(d *msgpack.Decoder) error { + l, err := d.DecodeMapLen() + if err != nil { + return err + } + for i := 0; i < l; i++ { + key, err := d.DecodeString() + if err != nil { + return err + } + + switch key { + case "status": + if statusTable.Status, err = d.DecodeString(); err != nil { + return err + } + case "is_master": + if statusTable.IsMaster, err = d.DecodeBool(); err != nil { + return err + } + case "message": + if statusTable.Message, err = d.DecodeString(); err != nil { + return err + } + default: + if err := d.Skip(); err != nil { + return err + } + } + } + + return nil +} + +// StorageInfoResult describes result for `crud.storage_info` method. +type StorageInfoResult struct { + Info map[string]StatusTable +} + +// DecodeMsgpack provides custom msgpack decoder. +func (r *StorageInfoResult) DecodeMsgpack(d *msgpack.Decoder) error { + _, err := d.DecodeArrayLen() + if err != nil { + return err + } + + l, err := d.DecodeMapLen() + if err != nil { + return err + } + + info := make(map[string]StatusTable) + for i := 0; i < l; i++ { + key, err := d.DecodeString() + if err != nil { + return err + } + + statusTable := StatusTable{} + if err := d.Decode(&statusTable); err != nil { + return nil + } + + info[key] = statusTable + } + + r.Info = info + + return nil +} + +// StorageInfoOpts describes options for `crud.storage_info` method. +type StorageInfoOpts = BaseOpts + +// StorageInfoRequest helps you to create request object to call +// `crud.storage_info` for execution by a Connection. +type StorageInfoRequest struct { + baseRequest + opts StorageInfoOpts +} + +type storageInfoArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Opts StorageInfoOpts +} + +// MakeStorageInfoRequest returns a new empty StorageInfoRequest. +func MakeStorageInfoRequest() StorageInfoRequest { + req := StorageInfoRequest{} + req.impl = newCall("crud.storage_info") + req.opts = StorageInfoOpts{} + return req +} + +// Opts sets the options for the torageInfoRequest request. +// Note: default value is nil. +func (req StorageInfoRequest) Opts(opts StorageInfoOpts) StorageInfoRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req StorageInfoRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + args := storageInfoArgs{Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req StorageInfoRequest) Context(ctx context.Context) StorageInfoRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/crud/tarantool_test.go b/crud/tarantool_test.go new file mode 100644 index 000000000..0e1c1791a --- /dev/null +++ b/crud/tarantool_test.go @@ -0,0 +1,1529 @@ +package crud_test + +import ( + "fmt" + "log" + "os" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/tarantool/go-iproto" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/crud" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +var server = "127.0.0.1:3013" +var spaceNo = uint32(617) +var spaceName = "test" +var invalidSpaceName = "invalid" +var indexNo = uint32(0) +var indexName = "primary_index" + +var dialer = tarantool.NetDialer{ + Address: server, + User: "test", + Password: "test", +} + +var opts = tarantool.Opts{ + Timeout: 5 * time.Second, +} + +var startOpts test_helpers.StartOpts = test_helpers.StartOpts{ + Dialer: dialer, + InitScript: "testdata/config.lua", + Listen: server, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 10, + RetryTimeout: 500 * time.Millisecond, +} + +var timeout = float64(1.1) + +var operations = []crud.Operation{ + // Insert new fields, + // because double update of the same field results in an error. + { + Operator: crud.Insert, + Field: 4, + Value: 0, + }, + { + Operator: crud.Insert, + Field: 5, + Value: 0, + }, + { + Operator: crud.Insert, + Field: 6, + Value: 0, + }, + { + Operator: crud.Insert, + Field: 7, + Value: 0, + }, + { + Operator: crud.Insert, + Field: 8, + Value: 0, + }, + { + Operator: crud.Add, + Field: 4, + Value: 1, + }, + { + Operator: crud.Sub, + Field: 5, + Value: 1, + }, + { + Operator: crud.And, + Field: 6, + Value: 1, + }, + { + Operator: crud.Or, + Field: 7, + Value: 1, + }, + { + Operator: crud.Xor, + Field: 8, + Value: 1, + }, + { + Operator: crud.Delete, + Field: 4, + Value: 5, + }, + { + Operator: crud.Assign, + Field: "name", + Value: "bye", + }, +} + +var selectOpts = crud.SelectOpts{ + Timeout: crud.MakeOptFloat64(timeout), +} + +var countOpts = crud.CountOpts{ + Timeout: crud.MakeOptFloat64(timeout), +} + +var getOpts = crud.GetOpts{ + Timeout: crud.MakeOptFloat64(timeout), +} + +var minOpts = crud.MinOpts{ + Timeout: crud.MakeOptFloat64(timeout), +} + +var maxOpts = crud.MaxOpts{ + Timeout: crud.MakeOptFloat64(timeout), +} + +var baseOpts = crud.BaseOpts{ + Timeout: crud.MakeOptFloat64(timeout), +} + +var simpleOperationOpts = crud.SimpleOperationOpts{ + Timeout: crud.MakeOptFloat64(timeout), +} + +var simpleOperationObjectOpts = crud.SimpleOperationObjectOpts{ + Timeout: crud.MakeOptFloat64(timeout), +} + +var opManyOpts = crud.OperationManyOpts{ + Timeout: crud.MakeOptFloat64(timeout), +} + +var opObjManyOpts = crud.OperationObjectManyOpts{ + Timeout: crud.MakeOptFloat64(timeout), +} + +var schemaOpts = crud.SchemaOpts{ + Timeout: crud.MakeOptFloat64(timeout), + Cached: crud.MakeOptBool(false), +} + +var conditions = []crud.Condition{ + { + Operator: crud.Lt, + Field: "id", + Value: uint(1020), + }, +} + +var key = []interface{}{uint(1019)} + +var tuples = generateTuples() +var objects = generateObjects() + +var tuple = []interface{}{uint(1019), nil, "bla"} +var object = crud.MapObject{ + "id": uint(1019), + "name": "bla", +} + +func connect(t testing.TB) *tarantool.Connection { + for i := 0; i < 10; i++ { + ctx, cancel := test_helpers.GetConnectContext() + conn, err := tarantool.Connect(ctx, dialer, opts) + cancel() + if err != nil { + t.Fatalf("Failed to connect: %s", err) + } + + ret := struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Result bool + }{} + err = conn.Do(tarantool.NewCall17Request("is_ready")).GetTyped(&ret) + if err != nil { + t.Fatalf("Failed to check is_ready: %s", err) + } + + if ret.Result { + return conn + } + + time.Sleep(time.Second) + } + + t.Fatalf("Failed to wait for a ready state connect.") + return nil +} + +var testProcessDataCases = []struct { + name string + expectedRespLen int + req tarantool.Request +}{ + { + "Select", + 2, + crud.MakeSelectRequest(spaceName). + Conditions(conditions). + Opts(selectOpts), + }, + { + "Get", + 2, + crud.MakeGetRequest(spaceName). + Key(key). + Opts(getOpts), + }, + { + "Update", + 2, + crud.MakeUpdateRequest(spaceName). + Key(key). + Operations(operations). + Opts(simpleOperationOpts), + }, + { + "Delete", + 2, + crud.MakeDeleteRequest(spaceName). + Key(key). + Opts(simpleOperationOpts), + }, + { + "Min", + 2, + crud.MakeMinRequest(spaceName).Opts(minOpts), + }, + { + "Min", + 2, + crud.MakeMinRequest(spaceName).Index(indexName).Opts(minOpts), + }, + { + "Max", + 2, + crud.MakeMaxRequest(spaceName).Opts(maxOpts), + }, + { + "Max", + 2, + crud.MakeMaxRequest(spaceName).Index(indexName).Opts(maxOpts), + }, + { + "Truncate", + 1, + crud.MakeTruncateRequest(spaceName).Opts(baseOpts), + }, + { + "Len", + 1, + crud.MakeLenRequest(spaceName).Opts(baseOpts), + }, + { + "Count", + 2, + crud.MakeCountRequest(spaceName). + Conditions(conditions). + Opts(countOpts), + }, + { + "Stats", + 1, + crud.MakeStatsRequest().Space(spaceName), + }, + { + "StorageInfo", + 1, + crud.MakeStorageInfoRequest().Opts(baseOpts), + }, + { + "Schema", + 1, + crud.MakeSchemaRequest().Opts(schemaOpts), + }, +} + +var testResultWithErrCases = []struct { + name string + resp interface{} + req tarantool.Request +}{ + { + "BaseResult", + &crud.Result{}, + crud.MakeSelectRequest(invalidSpaceName).Opts(selectOpts), + }, + { + "ManyResult", + &crud.Result{}, + crud.MakeReplaceManyRequest(invalidSpaceName).Tuples(tuples).Opts(opManyOpts), + }, + { + "NumberResult", + &crud.CountResult{}, + crud.MakeCountRequest(invalidSpaceName).Opts(countOpts), + }, + { + "BoolResult", + &crud.TruncateResult{}, + crud.MakeTruncateRequest(invalidSpaceName).Opts(baseOpts), + }, +} + +var tuplesOperationsData = generateTuplesOperationsData(tuples, operations) +var objectsOperationData = generateObjectsOperationsData(objects, operations) + +var testGenerateDataCases = []struct { + name string + expectedRespLen int + expectedTuplesCount int + req tarantool.Request +}{ + { + "Insert", + 1, + 1, + crud.MakeInsertRequest(spaceName). + Tuple(tuple). + Opts(simpleOperationOpts), + }, + { + "InsertObject", + 1, + 1, + crud.MakeInsertObjectRequest(spaceName). + Object(object). + Opts(simpleOperationObjectOpts), + }, + { + "InsertMany", + 1, + 10, + crud.MakeInsertManyRequest(spaceName). + Tuples(tuples). + Opts(opManyOpts), + }, + { + "InsertObjectMany", + 1, + 10, + crud.MakeInsertObjectManyRequest(spaceName). + Objects(objects). + Opts(opObjManyOpts), + }, + { + "Replace", + 1, + 1, + crud.MakeReplaceRequest(spaceName). + Tuple(tuple). + Opts(simpleOperationOpts), + }, + { + "ReplaceObject", + 1, + 1, + crud.MakeReplaceObjectRequest(spaceName). + Object(object). + Opts(simpleOperationObjectOpts), + }, + { + "ReplaceMany", + 1, + 10, + crud.MakeReplaceManyRequest(spaceName). + Tuples(tuples). + Opts(opManyOpts), + }, + { + "ReplaceObjectMany", + 1, + 10, + crud.MakeReplaceObjectManyRequest(spaceName). + Objects(objects). + Opts(opObjManyOpts), + }, + { + "Upsert", + 1, + 1, + crud.MakeUpsertRequest(spaceName). + Tuple(tuple). + Operations(operations). + Opts(simpleOperationOpts), + }, + { + "UpsertObject", + 1, + 1, + crud.MakeUpsertObjectRequest(spaceName). + Object(object). + Operations(operations). + Opts(simpleOperationOpts), + }, + { + "UpsertMany", + 1, + 10, + crud.MakeUpsertManyRequest(spaceName). + TuplesOperationsData(tuplesOperationsData). + Opts(opManyOpts), + }, + { + "UpsertObjectMany", + 1, + 10, + crud.MakeUpsertObjectManyRequest(spaceName). + ObjectsOperationsData(objectsOperationData). + Opts(opManyOpts), + }, +} + +func generateTuples() []crud.Tuple { + tpls := []crud.Tuple{} + for i := 1010; i < 1020; i++ { + tpls = append(tpls, []interface{}{uint(i), nil, "bla"}) + } + + return tpls +} + +func generateTuplesOperationsData(tpls []crud.Tuple, + operations []crud.Operation) []crud.TupleOperationsData { + tuplesOperationsData := []crud.TupleOperationsData{} + for _, tpl := range tpls { + tuplesOperationsData = append(tuplesOperationsData, crud.TupleOperationsData{ + Tuple: tpl, + Operations: operations, + }) + } + + return tuplesOperationsData +} + +func generateObjects() []crud.Object { + objs := []crud.Object{} + for i := 1010; i < 1020; i++ { + objs = append(objs, crud.MapObject{ + "id": uint(i), + "name": "bla", + }) + } + + return objs +} + +func generateObjectsOperationsData(objs []crud.Object, + operations []crud.Operation) []crud.ObjectOperationsData { + objectsOperationsData := []crud.ObjectOperationsData{} + for _, obj := range objs { + objectsOperationsData = append(objectsOperationsData, crud.ObjectOperationsData{ + Object: obj, + Operations: operations, + }) + } + + return objectsOperationsData +} + +func getCrudError(req tarantool.Request, crudError interface{}) (interface{}, error) { + var err []interface{} + var ok bool + + rtype := req.Type() + if crudError != nil { + if rtype == iproto.IPROTO_CALL { + return crudError, nil + } + + if err, ok = crudError.([]interface{}); !ok { + return nil, fmt.Errorf("Incorrect CRUD error format") + } + + if len(err) < 1 { + return nil, fmt.Errorf("Incorrect CRUD error format") + } + + if err[0] != nil { + return err[0], nil + } + } + + return nil, nil +} + +func testCrudRequestPrepareData(t *testing.T, conn tarantool.Connector) { + t.Helper() + + for i := 1010; i < 1020; i++ { + req := tarantool.NewReplaceRequest(spaceName).Tuple( + []interface{}{uint(i), nil, "bla"}) + if _, err := conn.Do(req).Get(); err != nil { + t.Fatalf("Unable to prepare tuples: %s", err) + } + } +} + +func testSelectGeneratedData(t *testing.T, conn tarantool.Connector, + expectedTuplesCount int) { + req := tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(20). + Iterator(tarantool.IterGe). + Key([]interface{}{uint(1010)}) + data, err := conn.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Select: %s", err.Error()) + } + if len(data) != expectedTuplesCount { + t.Fatalf("Response Data len %d != %d", len(data), expectedTuplesCount) + } +} + +func testCrudRequestCheck(t *testing.T, req tarantool.Request, + data []interface{}, err error, expectedLen int) { + t.Helper() + + if err != nil { + t.Fatalf("Failed to Do CRUD request: %s", err.Error()) + } + + if len(data) < expectedLen { + t.Fatalf("Response Body len < %#v, actual len %#v", + expectedLen, len(data)) + } + + // resp.Data[0] - CRUD res. + // resp.Data[1] - CRUD err. + if expectedLen >= 2 && data[1] != nil { + if crudErr, err := getCrudError(req, data[1]); err != nil { + t.Fatalf("Failed to get CRUD error: %#v", err) + } else if crudErr != nil { + t.Fatalf("Failed to perform CRUD request on CRUD side: %#v", crudErr) + } + } +} + +func TestCrudGenerateData(t *testing.T) { + conn := connect(t) + defer conn.Close() + + for _, testCase := range testGenerateDataCases { + t.Run(testCase.name, func(t *testing.T) { + for i := 1010; i < 1020; i++ { + req := tarantool.NewDeleteRequest(spaceName). + Key([]interface{}{uint(i)}) + conn.Do(req).Get() + } + + data, err := conn.Do(testCase.req).Get() + testCrudRequestCheck(t, testCase.req, data, + err, testCase.expectedRespLen) + + testSelectGeneratedData(t, conn, testCase.expectedTuplesCount) + + for i := 1010; i < 1020; i++ { + req := tarantool.NewDeleteRequest(spaceName). + Key([]interface{}{uint(i)}) + conn.Do(req).Get() + } + }) + } +} + +func TestCrudProcessData(t *testing.T) { + conn := connect(t) + defer conn.Close() + + for _, testCase := range testProcessDataCases { + t.Run(testCase.name, func(t *testing.T) { + testCrudRequestPrepareData(t, conn) + data, err := conn.Do(testCase.req).Get() + testCrudRequestCheck(t, testCase.req, data, + err, testCase.expectedRespLen) + for i := 1010; i < 1020; i++ { + req := tarantool.NewDeleteRequest(spaceName). + Key([]interface{}{uint(i)}) + conn.Do(req).Get() + } + }) + } +} + +func TestCrudUpdateSplice(t *testing.T) { + test_helpers.SkipIfCrudSpliceBroken(t) + + conn := connect(t) + defer conn.Close() + + req := crud.MakeUpdateRequest(spaceName). + Key(key). + Operations([]crud.Operation{ + { + Operator: crud.Splice, + Field: "name", + Pos: 1, + Len: 1, + Replace: "!!", + }, + }). + Opts(simpleOperationOpts) + + testCrudRequestPrepareData(t, conn) + data, err := conn.Do(req).Get() + testCrudRequestCheck(t, req, data, + err, 2) +} + +func TestCrudUpsertSplice(t *testing.T) { + test_helpers.SkipIfCrudSpliceBroken(t) + + conn := connect(t) + defer conn.Close() + + req := crud.MakeUpsertRequest(spaceName). + Tuple(tuple). + Operations([]crud.Operation{ + { + Operator: crud.Splice, + Field: "name", + Pos: 1, + Len: 1, + Replace: "!!", + }, + }). + Opts(simpleOperationOpts) + + testCrudRequestPrepareData(t, conn) + data, err := conn.Do(req).Get() + testCrudRequestCheck(t, req, data, + err, 2) +} + +func TestCrudUpsertObjectSplice(t *testing.T) { + test_helpers.SkipIfCrudSpliceBroken(t) + + conn := connect(t) + defer conn.Close() + + req := crud.MakeUpsertObjectRequest(spaceName). + Object(object). + Operations([]crud.Operation{ + { + Operator: crud.Splice, + Field: "name", + Pos: 1, + Len: 1, + Replace: "!!", + }, + }). + Opts(simpleOperationOpts) + + testCrudRequestPrepareData(t, conn) + data, err := conn.Do(req).Get() + testCrudRequestCheck(t, req, data, + err, 2) +} + +func TestUnflattenRows_IncorrectParams(t *testing.T) { + invalidMetadata := []interface{}{ + map[interface{}]interface{}{ + "name": true, + "type": "number", + }, + map[interface{}]interface{}{ + "name": "name", + "type": "string", + }, + } + + tpls := []interface{}{ + tuple, + } + + // Format tuples with invalid format with UnflattenRows. + objs, err := crud.UnflattenRows(tpls, invalidMetadata) + require.Nil(t, objs) + require.NotNil(t, err) + require.Contains(t, err.Error(), "unexpected space format") +} + +func TestUnflattenRows(t *testing.T) { + var ( + ok bool + err error + expectedId uint64 + actualId uint64 + res map[interface{}]interface{} + metadata []interface{} + tpls []interface{} + ) + + conn := connect(t) + defer conn.Close() + + // Do `replace`. + req := crud.MakeReplaceRequest(spaceName). + Tuple(tuple). + Opts(simpleOperationOpts) + data, err := conn.Do(req).Get() + testCrudRequestCheck(t, req, data, err, 2) + + if res, ok = data[0].(map[interface{}]interface{}); !ok { + t.Fatalf("Unexpected CRUD result: %#v", data[0]) + } + + if rawMetadata, ok := res["metadata"]; !ok { + t.Fatalf("Failed to get CRUD metadata") + } else { + if metadata, ok = rawMetadata.([]interface{}); !ok { + t.Fatalf("Unexpected CRUD metadata: %#v", rawMetadata) + } + } + + if rawTuples, ok := res["rows"]; !ok { + t.Fatalf("Failed to get CRUD rows") + } else { + if tpls, ok = rawTuples.([]interface{}); !ok { + t.Fatalf("Unexpected CRUD rows: %#v", rawTuples) + } + } + + // Format `replace` result with UnflattenRows. + objs, err := crud.UnflattenRows(tpls, metadata) + if err != nil { + t.Fatalf("Failed to unflatten rows: %#v", err) + } + if len(objs) < 1 { + t.Fatalf("Unexpected unflatten rows result: %#v", objs) + } + + if _, ok := objs[0]["bucket_id"]; ok { + delete(objs[0], "bucket_id") + } else { + t.Fatalf("Expected `bucket_id` field") + } + + require.Equal(t, len(object), len(objs[0])) + if expectedId, err = test_helpers.ConvertUint64(object["id"]); err != nil { + t.Fatalf("Unexpected `id` type") + } + + if actualId, err = test_helpers.ConvertUint64(objs[0]["id"]); err != nil { + t.Fatalf("Unexpected `id` type") + } + + require.Equal(t, expectedId, actualId) + require.Equal(t, object["name"], objs[0]["name"]) +} + +func TestResultWithErr(t *testing.T) { + conn := connect(t) + defer conn.Close() + + for _, testCase := range testResultWithErrCases { + t.Run(testCase.name, func(t *testing.T) { + err := conn.Do(testCase.req).GetTyped(testCase.resp) + if err == nil { + t.Fatalf("Expected CRUD fails with error, but error is not received") + } + require.Contains(t, err.Error(), "Space \"invalid\" doesn't exist") + }) + } +} + +func TestBoolResult(t *testing.T) { + conn := connect(t) + defer conn.Close() + + req := crud.MakeTruncateRequest(spaceName).Opts(baseOpts) + resp := crud.TruncateResult{} + + testCrudRequestPrepareData(t, conn) + + err := conn.Do(req).GetTyped(&resp) + if err != nil { + t.Fatalf("Failed to Do CRUD request: %s", err.Error()) + } + + if resp.Value != true { + t.Fatalf("Unexpected response value: %#v != %#v", resp.Value, true) + } + + for i := 1010; i < 1020; i++ { + req := tarantool.NewDeleteRequest(spaceName). + Key([]interface{}{uint(i)}) + conn.Do(req).Get() + } +} + +func TestNumberResult(t *testing.T) { + conn := connect(t) + defer conn.Close() + + req := crud.MakeCountRequest(spaceName).Opts(countOpts) + resp := crud.CountResult{} + + testCrudRequestPrepareData(t, conn) + + err := conn.Do(req).GetTyped(&resp) + if err != nil { + t.Fatalf("Failed to Do CRUD request: %s", err.Error()) + } + + if resp.Value != 10 { + t.Fatalf("Unexpected response value: %#v != %#v", resp.Value, 10) + } + + for i := 1010; i < 1020; i++ { + req := tarantool.NewDeleteRequest(spaceName). + Key([]interface{}{uint(i)}) + conn.Do(req).Get() + } +} + +func TestBaseResult(t *testing.T) { + expectedMetadata := []crud.FieldFormat{ + { + Name: "bucket_id", + Type: "unsigned", + IsNullable: true, + }, + { + Name: "id", + Type: "unsigned", + IsNullable: false, + }, + { + Name: "name", + Type: "string", + IsNullable: false, + }, + } + + conn := connect(t) + defer conn.Close() + + req := crud.MakeSelectRequest(spaceName).Opts(selectOpts) + resp := crud.Result{} + + testCrudRequestPrepareData(t, conn) + + err := conn.Do(req).GetTyped(&resp) + if err != nil { + t.Fatalf("Failed to Do CRUD request: %s", err) + } + + require.ElementsMatch(t, resp.Metadata, expectedMetadata) + + if len(resp.Rows.([]interface{})) != 10 { + t.Fatalf("Unexpected rows: %#v", resp.Rows) + } + + for i := 1010; i < 1020; i++ { + req := tarantool.NewDeleteRequest(spaceName). + Key([]interface{}{uint(i)}) + conn.Do(req).Get() + } +} + +func TestManyResult(t *testing.T) { + expectedMetadata := []crud.FieldFormat{ + { + Name: "bucket_id", + Type: "unsigned", + IsNullable: true, + }, + { + Name: "id", + Type: "unsigned", + IsNullable: false, + }, + { + Name: "name", + Type: "string", + IsNullable: false, + }, + } + + conn := connect(t) + defer conn.Close() + + req := crud.MakeReplaceManyRequest(spaceName).Tuples(tuples).Opts(opManyOpts) + resp := crud.Result{} + + testCrudRequestPrepareData(t, conn) + + err := conn.Do(req).GetTyped(&resp) + if err != nil { + t.Fatalf("Failed to Do CRUD request: %s", err.Error()) + } + + require.ElementsMatch(t, resp.Metadata, expectedMetadata) + + if len(resp.Rows.([]interface{})) != 10 { + t.Fatalf("Unexpected rows: %#v", resp.Rows) + } + + for i := 1010; i < 1020; i++ { + req := tarantool.NewDeleteRequest(spaceName). + Key([]interface{}{uint(i)}) + conn.Do(req).Get() + } +} + +func TestStorageInfoResult(t *testing.T) { + conn := connect(t) + defer conn.Close() + + req := crud.MakeStorageInfoRequest().Opts(baseOpts) + resp := crud.StorageInfoResult{} + + err := conn.Do(req).GetTyped(&resp) + if err != nil { + t.Fatalf("Failed to Do CRUD request: %s", err.Error()) + } + + if resp.Info == nil { + t.Fatalf("Failed to Do CRUD storage info request") + } + + for _, info := range resp.Info { + if info.Status != "running" { + t.Fatalf("Unexpected Status: %s != running", info.Status) + } + + if info.IsMaster != true { + t.Fatalf("Unexpected IsMaster: %v != true", info.IsMaster) + } + + if msg := info.Message; msg != "" { + t.Fatalf("Unexpected Message: %s", msg) + } + } +} + +func TestGetAdditionalOpts(t *testing.T) { + conn := connect(t) + defer conn.Close() + + req := crud.MakeGetRequest(spaceName).Key(key).Opts(crud.GetOpts{ + Timeout: crud.MakeOptFloat64(1.1), + Fields: crud.MakeOptTuple([]interface{}{"name"}), + Mode: crud.MakeOptString("read"), + PreferReplica: crud.MakeOptBool(true), + Balance: crud.MakeOptBool(true), + }) + resp := crud.Result{} + + testCrudRequestPrepareData(t, conn) + + err := conn.Do(req).GetTyped(&resp) + if err != nil { + t.Fatalf("Failed to Do CRUD request: %s", err) + } +} + +var testMetadataCases = []struct { + name string + req tarantool.Request +}{ + { + "Insert", + crud.MakeInsertRequest(spaceName). + Tuple(tuple). + Opts(crud.InsertOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "InsertObject", + crud.MakeInsertObjectRequest(spaceName). + Object(object). + Opts(crud.InsertObjectOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "InsertMany", + crud.MakeInsertManyRequest(spaceName). + Tuples(tuples). + Opts(crud.InsertManyOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "InsertObjectMany", + crud.MakeInsertObjectManyRequest(spaceName). + Objects(objects). + Opts(crud.InsertObjectManyOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "Replace", + crud.MakeReplaceRequest(spaceName). + Tuple(tuple). + Opts(crud.ReplaceOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "ReplaceObject", + crud.MakeReplaceObjectRequest(spaceName). + Object(object). + Opts(crud.ReplaceObjectOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "ReplaceMany", + crud.MakeReplaceManyRequest(spaceName). + Tuples(tuples). + Opts(crud.ReplaceManyOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "ReplaceObjectMany", + crud.MakeReplaceObjectManyRequest(spaceName). + Objects(objects). + Opts(crud.ReplaceObjectManyOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "Upsert", + crud.MakeUpsertRequest(spaceName). + Tuple(tuple). + Operations(operations). + Opts(crud.UpsertOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "UpsertObject", + crud.MakeUpsertObjectRequest(spaceName). + Object(object). + Operations(operations). + Opts(crud.UpsertObjectOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "UpsertMany", + crud.MakeUpsertManyRequest(spaceName). + TuplesOperationsData(tuplesOperationsData). + Opts(crud.UpsertManyOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "UpsertObjectMany", + crud.MakeUpsertObjectManyRequest(spaceName). + ObjectsOperationsData(objectsOperationData). + Opts(crud.UpsertObjectManyOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "Select", + crud.MakeSelectRequest(spaceName). + Conditions(conditions). + Opts(crud.SelectOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "Get", + crud.MakeGetRequest(spaceName). + Key(key). + Opts(crud.GetOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "Update", + crud.MakeUpdateRequest(spaceName). + Key(key). + Operations(operations). + Opts(crud.UpdateOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "Delete", + crud.MakeDeleteRequest(spaceName). + Key(key). + Opts(crud.DeleteOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "Min", + crud.MakeMinRequest(spaceName). + Opts(crud.MinOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, + { + "Max", + crud.MakeMaxRequest(spaceName). + Opts(crud.MaxOpts{ + FetchLatestMetadata: crud.MakeOptBool(true), + }), + }, +} + +func TestFetchLatestMetadataOption(t *testing.T) { + conn := connect(t) + defer conn.Close() + + for _, testCase := range testMetadataCases { + t.Run(testCase.name, func(t *testing.T) { + for i := 1010; i < 1020; i++ { + req := tarantool.NewDeleteRequest(spaceName). + Key([]interface{}{uint(i)}) + conn.Do(req).Get() + } + + resp := crud.Result{} + + err := conn.Do(testCase.req).GetTyped(&resp) + if err != nil { + t.Fatalf("Failed to Do CRUD request: %s", err) + } + + if len(resp.Metadata) == 0 { + t.Fatalf("Failed to get relevant metadata") + } + + for i := 1010; i < 1020; i++ { + req := tarantool.NewDeleteRequest(spaceName). + Key([]interface{}{uint(i)}) + conn.Do(req).Get() + } + }) + } +} + +var testNoreturnCases = []struct { + name string + req tarantool.Request +}{ + { + "Insert", + crud.MakeInsertRequest(spaceName). + Tuple(tuple). + Opts(crud.InsertOpts{ + Noreturn: crud.MakeOptBool(true), + }), + }, + { + "InsertObject", + crud.MakeInsertObjectRequest(spaceName). + Object(object). + Opts(crud.InsertObjectOpts{ + Noreturn: crud.MakeOptBool(true), + }), + }, + { + "InsertMany", + crud.MakeInsertManyRequest(spaceName). + Tuples(tuples). + Opts(crud.InsertManyOpts{ + Noreturn: crud.MakeOptBool(true), + }), + }, + { + "InsertObjectMany", + crud.MakeInsertObjectManyRequest(spaceName). + Objects(objects). + Opts(crud.InsertObjectManyOpts{ + Noreturn: crud.MakeOptBool(true), + }), + }, + { + "Replace", + crud.MakeReplaceRequest(spaceName). + Tuple(tuple). + Opts(crud.ReplaceOpts{ + Noreturn: crud.MakeOptBool(true), + }), + }, + { + "ReplaceObject", + crud.MakeReplaceObjectRequest(spaceName). + Object(object). + Opts(crud.ReplaceObjectOpts{ + Noreturn: crud.MakeOptBool(true), + }), + }, + { + "ReplaceMany", + crud.MakeReplaceManyRequest(spaceName). + Tuples(tuples). + Opts(crud.ReplaceManyOpts{ + Noreturn: crud.MakeOptBool(true), + }), + }, + { + "ReplaceObjectMany", + crud.MakeReplaceObjectManyRequest(spaceName). + Objects(objects). + Opts(crud.ReplaceObjectManyOpts{ + Noreturn: crud.MakeOptBool(true), + }), + }, + { + "Upsert", + crud.MakeUpsertRequest(spaceName). + Tuple(tuple). + Operations(operations). + Opts(crud.UpsertOpts{ + Noreturn: crud.MakeOptBool(true), + }), + }, + { + "UpsertObject", + crud.MakeUpsertObjectRequest(spaceName). + Object(object). + Operations(operations). + Opts(crud.UpsertObjectOpts{ + Noreturn: crud.MakeOptBool(true), + }), + }, + { + "UpsertMany", + crud.MakeUpsertManyRequest(spaceName). + TuplesOperationsData(tuplesOperationsData). + Opts(crud.UpsertManyOpts{ + Noreturn: crud.MakeOptBool(true), + }), + }, + { + "UpsertObjectMany", + crud.MakeUpsertObjectManyRequest(spaceName). + ObjectsOperationsData(objectsOperationData). + Opts(crud.UpsertObjectManyOpts{ + Noreturn: crud.MakeOptBool(true), + }), + }, + { + "Update", + crud.MakeUpdateRequest(spaceName). + Key(key). + Operations(operations). + Opts(crud.UpdateOpts{ + Noreturn: crud.MakeOptBool(true), + }), + }, + { + "Delete", + crud.MakeDeleteRequest(spaceName). + Key(key). + Opts(crud.DeleteOpts{ + Noreturn: crud.MakeOptBool(true), + }), + }, +} + +func TestNoreturnOption(t *testing.T) { + conn := connect(t) + defer conn.Close() + + for _, testCase := range testNoreturnCases { + t.Run(testCase.name, func(t *testing.T) { + for i := 1010; i < 1020; i++ { + req := tarantool.NewDeleteRequest(spaceName). + Key([]interface{}{uint(i)}) + conn.Do(req).Get() + } + + data, err := conn.Do(testCase.req).Get() + if err != nil { + t.Fatalf("Failed to Do CRUD request: %s", err) + } + + if len(data) == 0 { + t.Fatalf("Expected explicit nil") + } + + if data[0] != nil { + t.Fatalf("Expected nil result, got %v", data[0]) + } + + if len(data) >= 2 && data[1] != nil { + t.Fatalf("Expected no returned errors, got %v", data[1]) + } + + for i := 1010; i < 1020; i++ { + req := tarantool.NewDeleteRequest(spaceName). + Key([]interface{}{uint(i)}) + conn.Do(req).Get() + } + }) + } +} + +func TestNoreturnOptionTyped(t *testing.T) { + conn := connect(t) + defer conn.Close() + + for _, testCase := range testNoreturnCases { + t.Run(testCase.name, func(t *testing.T) { + for i := 1010; i < 1020; i++ { + req := tarantool.NewDeleteRequest(spaceName). + Key([]interface{}{uint(i)}) + conn.Do(req).Get() + } + + resp := crud.Result{} + + err := conn.Do(testCase.req).GetTyped(&resp) + if err != nil { + t.Fatalf("Failed to Do CRUD request: %s", err) + } + + if resp.Rows != nil { + t.Fatalf("Expected nil rows, got %v", resp.Rows) + } + + if len(resp.Metadata) != 0 { + t.Fatalf("Expected no metadata") + } + + for i := 1010; i < 1020; i++ { + req := tarantool.NewDeleteRequest(spaceName). + Key([]interface{}{uint(i)}) + conn.Do(req).Get() + } + }) + } +} + +func getTestSchema(t *testing.T) crud.Schema { + schema := crud.Schema{ + "test": crud.SpaceSchema{ + Format: []crud.FieldFormat{ + crud.FieldFormat{ + Name: "id", + Type: "unsigned", + IsNullable: false, + }, + { + Name: "bucket_id", + Type: "unsigned", + IsNullable: true, + }, + { + Name: "name", + Type: "string", + IsNullable: false, + }, + }, + Indexes: map[uint32]crud.Index{ + 0: { + Id: 0, + Name: "primary_index", + Type: "TREE", + Unique: true, + Parts: []crud.IndexPart{ + { + Fieldno: 1, + Type: "unsigned", + ExcludeNull: false, + IsNullable: false, + }, + }, + }, + }, + }, + } + + // https://github.com/tarantool/tarantool/issues/4091 + uniqueIssue, err := test_helpers.IsTarantoolVersionLess(2, 2, 1) + require.Equal(t, err, nil, "expected version check to succeed") + + if uniqueIssue { + for sk, sv := range schema { + for ik, iv := range sv.Indexes { + iv.Unique = false + sv.Indexes[ik] = iv + } + schema[sk] = sv + } + } + + // https://github.com/tarantool/tarantool/commit/17c9c034933d726925910ce5bf8b20e8e388f6e3 + excludeNullUnsupported, err := test_helpers.IsTarantoolVersionLess(2, 8, 1) + require.Equal(t, err, nil, "expected version check to succeed") + + if excludeNullUnsupported { + for sk, sv := range schema { + for ik, iv := range sv.Indexes { + for pk, pv := range iv.Parts { + // Struct default value. + pv.ExcludeNull = false + iv.Parts[pk] = pv + } + sv.Indexes[ik] = iv + } + schema[sk] = sv + } + } + + return schema +} + +func TestSchemaTyped(t *testing.T) { + conn := connect(t) + defer conn.Close() + + req := crud.MakeSchemaRequest() + var result crud.SchemaResult + + err := conn.Do(req).GetTyped(&result) + require.Equal(t, err, nil, "Expected CRUD request to succeed") + require.Equal(t, result.Value, getTestSchema(t), "map with \"test\" schema expected") +} + +func TestSpaceSchemaTyped(t *testing.T) { + conn := connect(t) + defer conn.Close() + + req := crud.MakeSchemaRequest().Space("test") + var result crud.SpaceSchemaResult + + err := conn.Do(req).GetTyped(&result) + require.Equal(t, err, nil, "Expected CRUD request to succeed") + require.Equal(t, result.Value, getTestSchema(t)["test"], "map with \"test\" schema expected") +} + +func TestSpaceSchemaTypedError(t *testing.T) { + conn := connect(t) + defer conn.Close() + + req := crud.MakeSchemaRequest().Space("not_exist") + var result crud.SpaceSchemaResult + + err := conn.Do(req).GetTyped(&result) + require.NotEqual(t, err, nil, "Expected CRUD request to fail") + require.Regexp(t, "Space \"not_exist\" doesn't exist", err.Error()) +} + +func TestUnitEmptySchema(t *testing.T) { + // We need to create another cluster with no spaces + // to test `{}` schema, so let's at least add a unit test. + conn := connect(t) + defer conn.Close() + + req := tarantool.NewEvalRequest("return {}") + var result crud.SchemaResult + + err := conn.Do(req).GetTyped(&result) + require.Equal(t, err, nil, "Expected CRUD request to succeed") + require.Equal(t, result.Value, crud.Schema{}, "empty schema expected") +} + +var testStorageYieldCases = []struct { + name string + req tarantool.Request +}{ + { + "Count", + crud.MakeCountRequest(spaceName). + Opts(crud.CountOpts{ + YieldEvery: crud.MakeOptUint(500), + }), + }, + { + "Select", + crud.MakeSelectRequest(spaceName). + Opts(crud.SelectOpts{ + YieldEvery: crud.MakeOptUint(500), + }), + }, +} + +func TestYieldEveryOption(t *testing.T) { + conn := connect(t) + defer conn.Close() + + for _, testCase := range testStorageYieldCases { + t.Run(testCase.name, func(t *testing.T) { + _, err := conn.Do(testCase.req).Get() + if err != nil { + t.Fatalf("Failed to Do CRUD request: %s", err) + } + }) + } +} + +// runTestMain is a body of TestMain function +// (see https://pkg.go.dev/testing#hdr-Main). +// Using defer + os.Exit is not works so TestMain body +// is a separate function, see +// https://stackoverflow.com/questions/27629380/how-to-exit-a-go-program-honoring-deferred-calls +func runTestMain(m *testing.M) int { + inst, err := test_helpers.StartTarantool(startOpts) + defer test_helpers.StopTarantoolWithCleanup(inst) + + if err != nil { + log.Printf("Failed to prepare test tarantool: %s", err) + return 1 + } + + return m.Run() +} + +func TestMain(m *testing.M) { + code := runTestMain(m) + os.Exit(code) +} diff --git a/crud/testdata/config.lua b/crud/testdata/config.lua new file mode 100644 index 000000000..81f486b38 --- /dev/null +++ b/crud/testdata/config.lua @@ -0,0 +1,117 @@ +-- configure path so that you can run application +-- from outside the root directory +if package.setsearchroot ~= nil then + package.setsearchroot() +else + -- Workaround for rocks loading in tarantool 1.10 + -- It can be removed in tarantool > 2.2 + -- By default, when you do require('mymodule'), tarantool looks into + -- the current working directory and whatever is specified in + -- package.path and package.cpath. If you run your app while in the + -- root directory of that app, everything goes fine, but if you try to + -- start your app with "tarantool myapp/init.lua", it will fail to load + -- its modules, and modules from myapp/.rocks. + local fio = require('fio') + local app_dir = fio.abspath(fio.dirname(arg[0])) + package.path = app_dir .. '/?.lua;' .. package.path + package.path = app_dir .. '/?/init.lua;' .. package.path + package.path = app_dir .. '/.rocks/share/tarantool/?.lua;' .. package.path + package.path = app_dir .. '/.rocks/share/tarantool/?/init.lua;' .. package.path + package.cpath = app_dir .. '/?.so;' .. package.cpath + package.cpath = app_dir .. '/?.dylib;' .. package.cpath + package.cpath = app_dir .. '/.rocks/lib/tarantool/?.so;' .. package.cpath + package.cpath = app_dir .. '/.rocks/lib/tarantool/?.dylib;' .. package.cpath +end + +local crud = require('crud') +local vshard = require('vshard') + +-- Do not set listen for now so connector won't be +-- able to send requests until everything is configured. +box.cfg{ + work_dir = os.getenv("TEST_TNT_WORK_DIR"), +} + +box.schema.user.grant( + 'guest', + 'read,write,execute', + 'universe' +) + +local s = box.schema.space.create('test', { + id = 617, + if_not_exists = true, + format = { + {name = 'id', type = 'unsigned'}, + {name = 'bucket_id', type = 'unsigned', is_nullable = true}, + {name = 'name', type = 'string'}, + } +}) +s:create_index('primary_index', { + parts = { + {field = 1, type = 'unsigned'}, + }, +}) +s:create_index('bucket_id', { + parts = { + {field = 2, type = 'unsigned'}, + }, + unique = false, +}) + +local function is_ready_false() + return false +end + +local function is_ready_true() + return true +end + +rawset(_G, 'is_ready', is_ready_false) + +-- Setup vshard. +_G.vshard = vshard +box.once('guest', function() + box.schema.user.grant('guest', 'super') +end) +local uri = 'guest@127.0.0.1:3013' +local box_info = box.info() + +local replicaset_uuid +if box_info.replicaset then + -- Since Tarantool 3.0. + replicaset_uuid = box_info.replicaset.uuid +else + replicaset_uuid = box_info.cluster.uuid +end + +local cfg = { + bucket_count = 300, + sharding = { + [replicaset_uuid] = { + replicas = { + [box_info.uuid] = { + uri = uri, + name = 'storage', + master = true, + }, + }, + }, + }, +} +vshard.storage.cfg(cfg, box_info.uuid) +vshard.router.cfg(cfg) +vshard.router.bootstrap() + +-- Initialize crud. +crud.init_storage() +crud.init_router() +crud.cfg{stats = true} + +box.schema.user.create('test', { password = 'test' , if_not_exists = true }) +box.schema.user.grant('test', 'execute', 'universe', nil, { if_not_exists = true }) +box.schema.user.grant('test', 'create,read,write,drop,alter', 'space', nil, { if_not_exists = true }) +box.schema.user.grant('test', 'create', 'sequence', nil, { if_not_exists = true }) + +-- Set is_ready = is_ready_true only when every other thing is configured. +rawset(_G, 'is_ready', is_ready_true) diff --git a/crud/truncate.go b/crud/truncate.go new file mode 100644 index 000000000..8313785d9 --- /dev/null +++ b/crud/truncate.go @@ -0,0 +1,58 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// TruncateResult describes result for `crud.truncate` method. +type TruncateResult = BoolResult + +// TruncateOpts describes options for `crud.truncate` method. +type TruncateOpts = BaseOpts + +// TruncateRequest helps you to create request object to call `crud.truncate` +// for execution by a Connection. +type TruncateRequest struct { + spaceRequest + opts TruncateOpts +} + +type truncateArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Opts TruncateOpts +} + +// MakeTruncateRequest returns a new empty TruncateRequest. +func MakeTruncateRequest(space string) TruncateRequest { + req := TruncateRequest{} + req.impl = newCall("crud.truncate") + req.space = space + req.opts = TruncateOpts{} + return req +} + +// Opts sets the options for the TruncateRequest request. +// Note: default value is nil. +func (req TruncateRequest) Opts(opts TruncateOpts) TruncateRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req TruncateRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + args := truncateArgs{Space: req.space, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req TruncateRequest) Context(ctx context.Context) TruncateRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/crud/tuple.go b/crud/tuple.go new file mode 100644 index 000000000..1f6850521 --- /dev/null +++ b/crud/tuple.go @@ -0,0 +1,13 @@ +package crud + +// Tuple is a type to describe tuple for CRUD methods. It can be any type that +// msgpask can encode as an array. +type Tuple = interface{} + +// Tuples is a type to describe an array of tuples for CRUD methods. It can be +// any type that msgpack can encode, but encoded data must be an array of +// tuples. +// +// See the reason why not just []Tuple: +// https://github.com/tarantool/go-tarantool/issues/365 +type Tuples = interface{} diff --git a/crud/unflatten_rows.go b/crud/unflatten_rows.go new file mode 100644 index 000000000..f360f8ad5 --- /dev/null +++ b/crud/unflatten_rows.go @@ -0,0 +1,35 @@ +package crud + +import ( + "fmt" +) + +// UnflattenRows can be used to convert received tuples to objects. +func UnflattenRows(tuples []interface{}, format []interface{}) ([]MapObject, error) { + var ( + ok bool + fieldName string + fieldInfo map[interface{}]interface{} + ) + + objects := []MapObject{} + + for _, tuple := range tuples { + object := make(map[string]interface{}) + for fieldIdx, field := range tuple.([]interface{}) { + if fieldInfo, ok = format[fieldIdx].(map[interface{}]interface{}); !ok { + return nil, fmt.Errorf("unexpected space format: %q", format) + } + + if fieldName, ok = fieldInfo["name"].(string); !ok { + return nil, fmt.Errorf("unexpected space format: %q", format) + } + + object[fieldName] = field + } + + objects = append(objects, object) + } + + return objects, nil +} diff --git a/crud/update.go b/crud/update.go new file mode 100644 index 000000000..4bdeb01ce --- /dev/null +++ b/crud/update.go @@ -0,0 +1,78 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// UpdateOpts describes options for `crud.update` method. +type UpdateOpts = SimpleOperationOpts + +// UpdateRequest helps you to create request object to call `crud.update` +// for execution by a Connection. +type UpdateRequest struct { + spaceRequest + key Tuple + operations []Operation + opts UpdateOpts +} + +type updateArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Key Tuple + Operations []Operation + Opts UpdateOpts +} + +// MakeUpdateRequest returns a new empty UpdateRequest. +func MakeUpdateRequest(space string) UpdateRequest { + req := UpdateRequest{} + req.impl = newCall("crud.update") + req.space = space + req.operations = []Operation{} + req.opts = UpdateOpts{} + return req +} + +// Key sets the key for the UpdateRequest request. +// Note: default value is nil. +func (req UpdateRequest) Key(key Tuple) UpdateRequest { + req.key = key + return req +} + +// Operations sets the operations for UpdateRequest request. +// Note: default value is nil. +func (req UpdateRequest) Operations(operations []Operation) UpdateRequest { + req.operations = operations + return req +} + +// Opts sets the options for the UpdateRequest request. +// Note: default value is nil. +func (req UpdateRequest) Opts(opts UpdateOpts) UpdateRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req UpdateRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + if req.key == nil { + req.key = []interface{}{} + } + args := updateArgs{Space: req.space, Key: req.key, + Operations: req.operations, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req UpdateRequest) Context(ctx context.Context) UpdateRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/crud/upsert.go b/crud/upsert.go new file mode 100644 index 000000000..d55d1da1b --- /dev/null +++ b/crud/upsert.go @@ -0,0 +1,147 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// UpsertOpts describes options for `crud.upsert` method. +type UpsertOpts = SimpleOperationOpts + +// UpsertRequest helps you to create request object to call `crud.upsert` +// for execution by a Connection. +type UpsertRequest struct { + spaceRequest + tuple Tuple + operations []Operation + opts UpsertOpts +} + +type upsertArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Tuple Tuple + Operations []Operation + Opts UpsertOpts +} + +// MakeUpsertRequest returns a new empty UpsertRequest. +func MakeUpsertRequest(space string) UpsertRequest { + req := UpsertRequest{} + req.impl = newCall("crud.upsert") + req.space = space + req.operations = []Operation{} + req.opts = UpsertOpts{} + return req +} + +// Tuple sets the tuple for the UpsertRequest request. +// Note: default value is nil. +func (req UpsertRequest) Tuple(tuple Tuple) UpsertRequest { + req.tuple = tuple + return req +} + +// Operations sets the operations for the UpsertRequest request. +// Note: default value is nil. +func (req UpsertRequest) Operations(operations []Operation) UpsertRequest { + req.operations = operations + return req +} + +// Opts sets the options for the UpsertRequest request. +// Note: default value is nil. +func (req UpsertRequest) Opts(opts UpsertOpts) UpsertRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req UpsertRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + if req.tuple == nil { + req.tuple = []interface{}{} + } + args := upsertArgs{Space: req.space, Tuple: req.tuple, + Operations: req.operations, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req UpsertRequest) Context(ctx context.Context) UpsertRequest { + req.impl = req.impl.Context(ctx) + + return req +} + +// UpsertObjectOpts describes options for `crud.upsert_object` method. +type UpsertObjectOpts = SimpleOperationOpts + +// UpsertObjectRequest helps you to create request object to call +// `crud.upsert_object` for execution by a Connection. +type UpsertObjectRequest struct { + spaceRequest + object Object + operations []Operation + opts UpsertObjectOpts +} + +type upsertObjectArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + Object Object + Operations []Operation + Opts UpsertObjectOpts +} + +// MakeUpsertObjectRequest returns a new empty UpsertObjectRequest. +func MakeUpsertObjectRequest(space string) UpsertObjectRequest { + req := UpsertObjectRequest{} + req.impl = newCall("crud.upsert_object") + req.space = space + req.operations = []Operation{} + req.opts = UpsertObjectOpts{} + return req +} + +// Object sets the tuple for the UpsertObjectRequest request. +// Note: default value is nil. +func (req UpsertObjectRequest) Object(object Object) UpsertObjectRequest { + req.object = object + return req +} + +// Operations sets the operations for the UpsertObjectRequest request. +// Note: default value is nil. +func (req UpsertObjectRequest) Operations(operations []Operation) UpsertObjectRequest { + req.operations = operations + return req +} + +// Opts sets the options for the UpsertObjectRequest request. +// Note: default value is nil. +func (req UpsertObjectRequest) Opts(opts UpsertObjectOpts) UpsertObjectRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req UpsertObjectRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + if req.object == nil { + req.object = MapObject{} + } + args := upsertObjectArgs{Space: req.space, Object: req.object, + Operations: req.operations, Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req UpsertObjectRequest) Context(ctx context.Context) UpsertObjectRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/crud/upsert_many.go b/crud/upsert_many.go new file mode 100644 index 000000000..dad7dd158 --- /dev/null +++ b/crud/upsert_many.go @@ -0,0 +1,141 @@ +package crud + +import ( + "context" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// UpsertManyOpts describes options for `crud.upsert_many` method. +type UpsertManyOpts = OperationManyOpts + +// TupleOperationsData contains tuple with operations to be applied to tuple. +type TupleOperationsData struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Tuple Tuple + Operations []Operation +} + +// UpsertManyRequest helps you to create request object to call +// `crud.upsert_many` for execution by a Connection. +type UpsertManyRequest struct { + spaceRequest + tuplesOperationsData []TupleOperationsData + opts UpsertManyOpts +} + +type upsertManyArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + TuplesOperationsData []TupleOperationsData + Opts UpsertManyOpts +} + +// MakeUpsertManyRequest returns a new empty UpsertManyRequest. +func MakeUpsertManyRequest(space string) UpsertManyRequest { + req := UpsertManyRequest{} + req.impl = newCall("crud.upsert_many") + req.space = space + req.tuplesOperationsData = []TupleOperationsData{} + req.opts = UpsertManyOpts{} + return req +} + +// TuplesOperationsData sets tuples and operations for +// the UpsertManyRequest request. +// Note: default value is nil. +func (req UpsertManyRequest) TuplesOperationsData( + tuplesOperationData []TupleOperationsData) UpsertManyRequest { + req.tuplesOperationsData = tuplesOperationData + return req +} + +// Opts sets the options for the UpsertManyRequest request. +// Note: default value is nil. +func (req UpsertManyRequest) Opts(opts UpsertManyOpts) UpsertManyRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req UpsertManyRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + args := upsertManyArgs{Space: req.space, TuplesOperationsData: req.tuplesOperationsData, + Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req UpsertManyRequest) Context(ctx context.Context) UpsertManyRequest { + req.impl = req.impl.Context(ctx) + + return req +} + +// UpsertObjectManyOpts describes options for `crud.upsert_object_many` method. +type UpsertObjectManyOpts = OperationManyOpts + +// ObjectOperationsData contains object with operations to be applied to object. +type ObjectOperationsData struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Object Object + Operations []Operation +} + +// UpsertObjectManyRequest helps you to create request object to call +// `crud.upsert_object_many` for execution by a Connection. +type UpsertObjectManyRequest struct { + spaceRequest + objectsOperationsData []ObjectOperationsData + opts UpsertObjectManyOpts +} + +type upsertObjectManyArgs struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Space string + ObjectsOperationsData []ObjectOperationsData + Opts UpsertObjectManyOpts +} + +// MakeUpsertObjectManyRequest returns a new empty UpsertObjectManyRequest. +func MakeUpsertObjectManyRequest(space string) UpsertObjectManyRequest { + req := UpsertObjectManyRequest{} + req.impl = newCall("crud.upsert_object_many") + req.space = space + req.objectsOperationsData = []ObjectOperationsData{} + req.opts = UpsertObjectManyOpts{} + return req +} + +// ObjectOperationsData sets objects and operations +// for the UpsertObjectManyRequest request. +// Note: default value is nil. +func (req UpsertObjectManyRequest) ObjectsOperationsData( + objectsOperationData []ObjectOperationsData) UpsertObjectManyRequest { + req.objectsOperationsData = objectsOperationData + return req +} + +// Opts sets the options for the UpsertObjectManyRequest request. +// Note: default value is nil. +func (req UpsertObjectManyRequest) Opts(opts UpsertObjectManyOpts) UpsertObjectManyRequest { + req.opts = opts + return req +} + +// Body fills an encoder with the call request body. +func (req UpsertObjectManyRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + args := upsertObjectManyArgs{Space: req.space, ObjectsOperationsData: req.objectsOperationsData, + Opts: req.opts} + req.impl = req.impl.Args(args) + return req.impl.Body(res, enc) +} + +// Context sets a passed context to CRUD request. +func (req UpsertObjectManyRequest) Context(ctx context.Context) UpsertObjectManyRequest { + req.impl = req.impl.Context(ctx) + + return req +} diff --git a/datetime/adjust.go b/datetime/adjust.go new file mode 100644 index 000000000..35812f45f --- /dev/null +++ b/datetime/adjust.go @@ -0,0 +1,31 @@ +package datetime + +// An Adjust is used as a parameter for date adjustions, see: +// https://github.com/tarantool/tarantool/wiki/Datetime-Internals#date-adjustions-and-leap-years +type Adjust int + +const ( + NoneAdjust Adjust = 0 // adjust = "none" in Tarantool + ExcessAdjust Adjust = 1 // adjust = "excess" in Tarantool + LastAdjust Adjust = 2 // adjust = "last" in Tarantool +) + +// We need the mappings to make NoneAdjust as a default value instead of +// dtExcess. +const ( + dtExcess = 0 // DT_EXCESS from dt-c/dt_arithmetic.h + dtLimit = 1 // DT_LIMIT + dtSnap = 2 // DT_SNAP +) + +var adjustToDt = map[Adjust]int64{ + NoneAdjust: dtLimit, + ExcessAdjust: dtExcess, + LastAdjust: dtSnap, +} + +var dtToAdjust = map[int64]Adjust{ + dtExcess: ExcessAdjust, + dtLimit: NoneAdjust, + dtSnap: LastAdjust, +} diff --git a/datetime/config.lua b/datetime/config.lua new file mode 100644 index 000000000..9b5baf719 --- /dev/null +++ b/datetime/config.lua @@ -0,0 +1,79 @@ +local has_datetime, datetime = pcall(require, 'datetime') + +if not has_datetime then + error('Datetime unsupported, use Tarantool 2.10 or newer') +end + +-- Do not set listen for now so connector won't be +-- able to send requests until everything is configured. +box.cfg{ + work_dir = os.getenv("TEST_TNT_WORK_DIR"), +} + +box.schema.user.create('test', { password = 'test' , if_not_exists = true }) +box.schema.user.grant('test', 'execute', 'universe', nil, { if_not_exists = true }) + +box.once("init", function() + local s_1 = box.schema.space.create('testDatetime_1', { + id = 524, + if_not_exists = true, + }) + s_1:create_index('primary', { + type = 'TREE', + parts = { + { field = 1, type = 'datetime' }, + }, + if_not_exists = true + }) + s_1:truncate() + + local s_3 = box.schema.space.create('testDatetime_2', { + id = 526, + if_not_exists = true, + }) + s_3:create_index('primary', { + type = 'tree', + parts = { + {1, 'uint'}, + }, + if_not_exists = true + }) + s_3:truncate() + + box.schema.func.create('call_datetime_testdata') + box.schema.user.grant('test', 'read,write', 'space', 'testDatetime_1', { if_not_exists = true }) + box.schema.user.grant('test', 'read,write', 'space', 'testDatetime_2', { if_not_exists = true }) +end) + +local function call_datetime_testdata() + local dt1 = datetime.new({ year = 1934 }) + local dt2 = datetime.new({ year = 1961 }) + local dt3 = datetime.new({ year = 1968 }) + return { + { + 5, "Go!", { + {"Klushino", dt1}, + {"Baikonur", dt2}, + {"Novoselovo", dt3}, + }, + } + } +end +rawset(_G, 'call_datetime_testdata', call_datetime_testdata) + +local function call_interval_testdata(interval) + return interval +end +rawset(_G, 'call_interval_testdata', call_interval_testdata) + +local function call_datetime_interval(dtleft, dtright) + return dtright - dtleft +end +rawset(_G, 'call_datetime_interval', call_datetime_interval) + +-- Set listen only when every other thing is configured. +box.cfg{ + listen = os.getenv("TEST_TNT_LISTEN"), +} + +require('console').start() diff --git a/datetime/datetime.go b/datetime/datetime.go new file mode 100644 index 000000000..23901305b --- /dev/null +++ b/datetime/datetime.go @@ -0,0 +1,356 @@ +// Package with support of Tarantool's datetime data type. +// +// Datetime data type supported in Tarantool since 2.10. +// +// Since: 1.7.0 +// +// See also: +// +// * Datetime Internals https://github.com/tarantool/tarantool/wiki/Datetime-Internals +package datetime + +import ( + "encoding/binary" + "fmt" + "reflect" + "time" + + "github.com/vmihailenco/msgpack/v5" +) + +// Datetime MessagePack serialization schema is an MP_EXT extension, which +// creates container of 8 or 16 bytes long payload. +// +// +---------+--------+===============+-------------------------------+ +// |0xd7/0xd8|type (4)| seconds (8b) | nsec; tzoffset; tzindex; (8b) | +// +---------+--------+===============+-------------------------------+ +// +// MessagePack data encoded using fixext8 (0xd7) or fixext16 (0xd8), and may +// contain: +// +// * [required] seconds parts as full, unencoded, signed 64-bit integer, +// stored in little-endian order; +// +// * [optional] all the other fields (nsec, tzoffset, tzindex) if any of them +// were having not 0 value. They are packed naturally in little-endian order; + +// Datetime external type. Supported since Tarantool 2.10. See more details in +// issue https://github.com/tarantool/tarantool/issues/5946. +const datetimeExtID = 4 + +// datetime structure keeps a number of seconds and nanoseconds since Unix Epoch. +// Time is normalized by UTC, so time-zone offset is informative only. +type datetime struct { + // Seconds since Epoch, where the epoch is the point where the time + // starts, and is platform dependent. For Unix, the epoch is January 1, + // 1970, 00:00:00 (UTC). Tarantool uses a double type, see a structure + // definition in src/lib/core/datetime.h and reasons in + // https://github.com/tarantool/tarantool/wiki/Datetime-internals#intervals-in-c + seconds int64 + // Nanoseconds, fractional part of seconds. Tarantool uses int32_t, see + // a definition in src/lib/core/datetime.h. + nsec int32 + // Timezone offset in minutes from UTC. Tarantool uses a int16_t type, + // see a structure definition in src/lib/core/datetime.h. + tzOffset int16 + // Olson timezone id. Tarantool uses a int16_t type, see a structure + // definition in src/lib/core/datetime.h. + tzIndex int16 +} + +// Size of datetime fields in a MessagePack value. +const ( + secondsSize = 8 + nsecSize = 4 + tzIndexSize = 2 + tzOffsetSize = 2 +) + +// Limits are from c-dt library: +// https://github.com/tarantool/c-dt/blob/e6214325fe8d4336464ebae859ac2b456fd22b77/API.pod#introduction +// https://github.com/tarantool/tarantool/blob/a99ccce5f517d2a04670289d3d09a8cc2f5916f9/src/lib/core/datetime.h#L44-L61 +const ( + minSeconds = -185604722870400 + maxSeconds = 185480451417600 +) + +const maxSize = secondsSize + nsecSize + tzIndexSize + tzOffsetSize + +//go:generate go tool gentypes -ext-code 4 Datetime +type Datetime struct { + time time.Time +} + +const ( + // NoTimezone allows to create a datetime without UTC timezone for + // Tarantool. The problem is that Golang by default creates a time value + // with UTC timezone. So it is a way to create a datetime without timezone. + NoTimezone = "" +) + +var noTimezoneLoc = time.FixedZone(NoTimezone, 0) + +const ( + offsetMin = -12 * 60 * 60 + offsetMax = 14 * 60 * 60 +) + +// MakeDatetime returns a datetime.Datetime object that contains a +// specified time.Time. It may return an error if the Time value is out of +// supported range: [-5879610-06-22T00:00Z .. 5879611-07-11T00:00Z] or +// an invalid timezone or offset value is out of supported range: +// [-12 * 60 * 60, 14 * 60 * 60]. +// +// NOTE: Tarantool's datetime.tz value is picked from t.Location().String(). +// "Local" location is unsupported, see ExampleMakeDatetime_localUnsupported. +func MakeDatetime(t time.Time) (Datetime, error) { + dt := Datetime{} + seconds := t.Unix() + + if seconds < minSeconds || seconds > maxSeconds { + return dt, fmt.Errorf("time %s is out of supported range", t) + } + + zone := t.Location().String() + _, offset := t.Zone() + if zone != NoTimezone { + if _, ok := timezoneToIndex[zone]; !ok { + return dt, fmt.Errorf("unknown timezone %s with offset %d", + zone, offset) + } + } + + if offset < offsetMin || offset > offsetMax { + return dt, fmt.Errorf("offset must be between %d and %d hours", + offsetMin, offsetMax) + } + + dt.time = t + return dt, nil +} + +func intervalFromDatetime(dtime Datetime) (ival Interval) { + ival.Year = int64(dtime.time.Year()) + ival.Month = int64(dtime.time.Month()) + ival.Day = int64(dtime.time.Day()) + ival.Hour = int64(dtime.time.Hour()) + ival.Min = int64(dtime.time.Minute()) + ival.Sec = int64(dtime.time.Second()) + ival.Nsec = int64(dtime.time.Nanosecond()) + ival.Adjust = NoneAdjust + + return ival +} + +func daysInMonth(year int64, month int64) int64 { + if month == 12 { + year++ + month = 1 + } else { + month += 1 + } + + // We use the fact that time.Date accepts values outside their usual + // ranges - the values are normalized during the conversion. + // + // So we got a day (year, month - 1, last day of the month) before + // (year, month, 1) because we pass (year, month, 0). + return int64(time.Date(int(year), time.Month(month), 0, 0, 0, 0, 0, time.UTC).Day()) +} + +// C implementation: +// https://github.com/tarantool/c-dt/blob/cec6acebb54d9e73ea0b99c63898732abd7683a6/dt_arithmetic.c#L74-L98 +func addMonth(ival Interval, delta int64, adjust Adjust) Interval { + oldYear := ival.Year + oldMonth := ival.Month + + ival.Month += delta + if ival.Month < 1 || ival.Month > 12 { + ival.Year += ival.Month / 12 + ival.Month %= 12 + if ival.Month < 1 { + ival.Year-- + ival.Month += 12 + } + } + if adjust == ExcessAdjust || ival.Day < 28 { + return ival + } + + dim := daysInMonth(ival.Year, ival.Month) + if ival.Day > dim || (adjust == LastAdjust && ival.Day == daysInMonth(oldYear, oldMonth)) { + ival.Day = dim + } + return ival +} + +// MarshalMsgpack implements a custom msgpack marshaler. +func (d Datetime) MarshalMsgpack() ([]byte, error) { + tm := d.ToTime() + + var dt datetime + dt.seconds = tm.Unix() + dt.nsec = int32(tm.Nanosecond()) + + zone := tm.Location().String() + _, offset := tm.Zone() + if zone != NoTimezone { + // The zone value already checked in MakeDatetime() or + // UnmarshalMsgpack() calls. + dt.tzIndex = int16(timezoneToIndex[zone]) + } + dt.tzOffset = int16(offset / 60) + + var bytesSize = secondsSize + if dt.nsec != 0 || dt.tzOffset != 0 || dt.tzIndex != 0 { + bytesSize += nsecSize + tzIndexSize + tzOffsetSize + } + + buf := make([]byte, bytesSize) + binary.LittleEndian.PutUint64(buf, uint64(dt.seconds)) + if bytesSize == maxSize { + binary.LittleEndian.PutUint32(buf[secondsSize:], uint32(dt.nsec)) + binary.LittleEndian.PutUint16(buf[secondsSize+nsecSize:], uint16(dt.tzOffset)) + binary.LittleEndian.PutUint16(buf[secondsSize+nsecSize+tzOffsetSize:], uint16(dt.tzIndex)) + } + + return buf, nil +} + +// UnmarshalMsgpack implements a custom msgpack unmarshaler. +func (d *Datetime) UnmarshalMsgpack(data []byte) error { + var dt datetime + + sec := binary.LittleEndian.Uint64(data) + dt.seconds = int64(sec) + dt.nsec = 0 + if len(data) == maxSize { + dt.nsec = int32(binary.LittleEndian.Uint32(data[secondsSize:])) + dt.tzOffset = int16(binary.LittleEndian.Uint16(data[secondsSize+nsecSize:])) + dt.tzIndex = int16(binary.LittleEndian.Uint16(data[secondsSize+nsecSize+tzOffsetSize:])) + } + + tt := time.Unix(dt.seconds, int64(dt.nsec)) + + loc := noTimezoneLoc + if dt.tzIndex != 0 || dt.tzOffset != 0 { + zone := NoTimezone + offset := int(dt.tzOffset) * 60 + + if dt.tzIndex != 0 { + if _, ok := indexToTimezone[int(dt.tzIndex)]; !ok { + return fmt.Errorf("unknown timezone index %d", dt.tzIndex) + } + zone = indexToTimezone[int(dt.tzIndex)] + } + if zone != NoTimezone { + if loadLoc, err := time.LoadLocation(zone); err == nil { + loc = loadLoc + } else { + // Unable to load location. + loc = time.FixedZone(zone, offset) + } + } else { + // Only offset. + loc = time.FixedZone(zone, offset) + } + } + tt = tt.In(loc) + + newDatetime, err := MakeDatetime(tt) + if err != nil { + return err + } + + *d = newDatetime + + return nil +} + +func (d Datetime) add(ival Interval, positive bool) (Datetime, error) { + newVal := intervalFromDatetime(d) + + var direction int64 + if positive { + direction = 1 + } else { + direction = -1 + } + + newVal = addMonth(newVal, direction*ival.Year*12+direction*ival.Month, ival.Adjust) + newVal.Day += direction * 7 * ival.Week + newVal.Day += direction * ival.Day + newVal.Hour += direction * ival.Hour + newVal.Min += direction * ival.Min + newVal.Sec += direction * ival.Sec + newVal.Nsec += direction * ival.Nsec + + tm := time.Date(int(newVal.Year), time.Month(newVal.Month), + int(newVal.Day), int(newVal.Hour), int(newVal.Min), + int(newVal.Sec), int(newVal.Nsec), d.time.Location()) + + return MakeDatetime(tm) +} + +// Add creates a new Datetime as addition of the Datetime and Interval. It may +// return an error if a new Datetime is out of supported range. +func (d Datetime) Add(ival Interval) (Datetime, error) { + return d.add(ival, true) +} + +// Sub creates a new Datetime as subtraction of the Datetime and Interval. It +// may return an error if a new Datetime is out of supported range. +func (d Datetime) Sub(ival Interval) (Datetime, error) { + return d.add(ival, false) +} + +// Interval returns an Interval value to a next Datetime value. +func (d Datetime) Interval(next Datetime) Interval { + curIval := intervalFromDatetime(d) + nextIval := intervalFromDatetime(next) + _, curOffset := d.time.Zone() + _, nextOffset := next.time.Zone() + curIval.Min -= int64(curOffset-nextOffset) / 60 + return nextIval.Sub(curIval) +} + +// ToTime returns a time.Time that Datetime contains. +// +// If a Datetime created from time.Time value then an original location is used +// for the time value. +// +// If a Datetime created via unmarshaling Tarantool's datetime then we try to +// create a location with time.LoadLocation() first. In case of failure, we use +// a location created with time.FixedZone(). +func (d *Datetime) ToTime() time.Time { + return d.time +} + +func datetimeEncoder(e *msgpack.Encoder, v reflect.Value) ([]byte, error) { + dtime := v.Interface().(Datetime) + + return dtime.MarshalMsgpack() +} + +func datetimeDecoder(d *msgpack.Decoder, v reflect.Value, extLen int) error { + if extLen != maxSize && extLen != secondsSize { + return fmt.Errorf("invalid data length: got %d, wanted %d or %d", + extLen, secondsSize, maxSize) + } + + b := make([]byte, extLen) + switch n, err := d.Buffered().Read(b); { + case err != nil: + return err + case n < extLen: + return fmt.Errorf("msgpack: unexpected end of stream after %d datetime bytes", n) + } + + ptr := v.Addr().Interface().(*Datetime) + return ptr.UnmarshalMsgpack(b) +} + +func init() { + msgpack.RegisterExtDecoder(datetimeExtID, Datetime{}, datetimeDecoder) + msgpack.RegisterExtEncoder(datetimeExtID, Datetime{}, datetimeEncoder) +} diff --git a/datetime/datetime_gen.go b/datetime/datetime_gen.go new file mode 100644 index 000000000..753d9c371 --- /dev/null +++ b/datetime/datetime_gen.go @@ -0,0 +1,241 @@ +// Code generated by github.com/tarantool/go-option; DO NOT EDIT. + +package datetime + +import ( + "fmt" + + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" + + "github.com/tarantool/go-option" +) + +// OptionalDatetime represents an optional value of type Datetime. +// It can either hold a valid Datetime (IsSome == true) or be empty (IsZero == true). +type OptionalDatetime struct { + value Datetime + exists bool +} + +// SomeOptionalDatetime creates an optional OptionalDatetime with the given Datetime value. +// The returned OptionalDatetime will have IsSome() == true and IsZero() == false. +func SomeOptionalDatetime(value Datetime) OptionalDatetime { + return OptionalDatetime{ + value: value, + exists: true, + } +} + +// NoneOptionalDatetime creates an empty optional OptionalDatetime value. +// The returned OptionalDatetime will have IsSome() == false and IsZero() == true. +// +// Example: +// +// o := NoneOptionalDatetime() +// if o.IsZero() { +// fmt.Println("value is absent") +// } +func NoneOptionalDatetime() OptionalDatetime { + return OptionalDatetime{} +} + +func (o OptionalDatetime) newEncodeError(err error) error { + if err == nil { + return nil + } + return &option.EncodeError{ + Type: "OptionalDatetime", + Parent: err, + } +} + +func (o OptionalDatetime) newDecodeError(err error) error { + if err == nil { + return nil + } + + return &option.DecodeError{ + Type: "OptionalDatetime", + Parent: err, + } +} + +// IsSome returns true if the OptionalDatetime contains a value. +// This indicates the value is explicitly set (not None). +func (o OptionalDatetime) IsSome() bool { + return o.exists +} + +// IsZero returns true if the OptionalDatetime does not contain a value. +// Equivalent to !IsSome(). Useful for consistency with types where +// zero value (e.g. 0, false, zero struct) is valid and needs to be distinguished. +func (o OptionalDatetime) IsZero() bool { + return !o.exists +} + +// IsNil is an alias for IsZero. +// +// This method is provided for compatibility with the msgpack Encoder interface. +func (o OptionalDatetime) IsNil() bool { + return o.IsZero() +} + +// Get returns the stored value and a boolean flag indicating its presence. +// If the value is present, returns (value, true). +// If the value is absent, returns (zero value of Datetime, false). +// +// Recommended usage: +// +// if value, ok := o.Get(); ok { +// // use value +// } +func (o OptionalDatetime) Get() (Datetime, bool) { + return o.value, o.exists +} + +// MustGet returns the stored value if it is present. +// Panics if the value is absent (i.e., IsZero() == true). +// +// Use with caution — only when you are certain the value exists. +// +// Panics with: "optional value is not set" if no value is set. +func (o OptionalDatetime) MustGet() Datetime { + if !o.exists { + panic("optional value is not set") + } + + return o.value +} + +// Unwrap returns the stored value regardless of presence. +// If no value is set, returns the zero value for Datetime. +// +// Warning: Does not check presence. Use IsSome() before calling if you need +// to distinguish between absent value and explicit zero value. +func (o OptionalDatetime) Unwrap() Datetime { + return o.value +} + +// UnwrapOr returns the stored value if present. +// Otherwise, returns the provided default value. +// +// Example: +// +// o := NoneOptionalDatetime() +// v := o.UnwrapOr(someDefaultOptionalDatetime) +func (o OptionalDatetime) UnwrapOr(defaultValue Datetime) Datetime { + if o.exists { + return o.value + } + + return defaultValue +} + +// UnwrapOrElse returns the stored value if present. +// Otherwise, calls the provided function and returns its result. +// Useful when the default value requires computation or side effects. +// +// Example: +// +// o := NoneOptionalDatetime() +// v := o.UnwrapOrElse(func() Datetime { return computeDefault() }) +func (o OptionalDatetime) UnwrapOrElse(defaultValue func() Datetime) Datetime { + if o.exists { + return o.value + } + + return defaultValue() +} + +func (o OptionalDatetime) encodeValue(encoder *msgpack.Encoder) error { + value, err := o.value.MarshalMsgpack() + if err != nil { + return err + } + + err = encoder.EncodeExtHeader(4, len(value)) + if err != nil { + return err + } + + _, err = encoder.Writer().Write(value) + if err != nil { + return err + } + + return nil +} + +// EncodeMsgpack encodes the OptionalDatetime value using MessagePack format. +// - If the value is present, it is encoded as Datetime. +// - If the value is absent (None), it is encoded as nil. +// +// Returns an error if encoding fails. +func (o OptionalDatetime) EncodeMsgpack(encoder *msgpack.Encoder) error { + if o.exists { + return o.newEncodeError(o.encodeValue(encoder)) + } + + return o.newEncodeError(encoder.EncodeNil()) +} + +func (o *OptionalDatetime) decodeValue(decoder *msgpack.Decoder) error { + tp, length, err := decoder.DecodeExtHeader() + switch { + case err != nil: + return o.newDecodeError(err) + case tp != 4: + return o.newDecodeError(fmt.Errorf("invalid extension code: %d", tp)) + } + + a := make([]byte, length) + if err := decoder.ReadFull(a); err != nil { + return o.newDecodeError(err) + } + + if err := o.value.UnmarshalMsgpack(a); err != nil { + return o.newDecodeError(err) + } + + o.exists = true + return nil +} + +func (o *OptionalDatetime) checkCode(code byte) bool { + return msgpcode.IsExt(code) +} + +// DecodeMsgpack decodes a OptionalDatetime value from MessagePack format. +// Supports two input types: +// - nil: interpreted as no value (NoneOptionalDatetime) +// - Datetime: interpreted as a present value (SomeOptionalDatetime) +// +// Returns an error if the input type is unsupported or decoding fails. +// +// After successful decoding: +// - on nil: exists = false, value = default zero value +// - on Datetime: exists = true, value = decoded value +func (o *OptionalDatetime) DecodeMsgpack(decoder *msgpack.Decoder) error { + code, err := decoder.PeekCode() + if err != nil { + return o.newDecodeError(err) + } + + switch { + case code == msgpcode.Nil: + o.exists = false + + return o.newDecodeError(decoder.Skip()) + case o.checkCode(code): + err := o.decodeValue(decoder) + if err != nil { + return o.newDecodeError(err) + } + o.exists = true + + return err + default: + return o.newDecodeError(fmt.Errorf("unexpected code: %d", code)) + } +} diff --git a/datetime/datetime_gen_test.go b/datetime/datetime_gen_test.go new file mode 100644 index 000000000..6b45b6fb1 --- /dev/null +++ b/datetime/datetime_gen_test.go @@ -0,0 +1,125 @@ +package datetime + +import ( + "bytes" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/vmihailenco/msgpack/v5" +) + +func TestSomeOptionalDatetime(t *testing.T) { + val, err := MakeDatetime(time.Now().In(time.UTC)) + assert.NoError(t, err) + opt := SomeOptionalDatetime(val) + + assert.True(t, opt.IsSome()) + assert.False(t, opt.IsZero()) + + v, ok := opt.Get() + assert.True(t, ok) + assert.Equal(t, val, v) +} + +func TestNoneOptionalDatetime(t *testing.T) { + opt := NoneOptionalDatetime() + + assert.False(t, opt.IsSome()) + assert.True(t, opt.IsZero()) + + _, ok := opt.Get() + assert.False(t, ok) +} + +func TestOptionalDatetime_MustGet(t *testing.T) { + val, err := MakeDatetime(time.Now().In(time.UTC)) + assert.NoError(t, err) + optSome := SomeOptionalDatetime(val) + optNone := NoneOptionalDatetime() + + assert.Equal(t, val, optSome.MustGet()) + assert.Panics(t, func() { optNone.MustGet() }) +} + +func TestOptionalDatetime_Unwrap(t *testing.T) { + val, err := MakeDatetime(time.Now().In(time.UTC)) + assert.NoError(t, err) + optSome := SomeOptionalDatetime(val) + optNone := NoneOptionalDatetime() + + assert.Equal(t, val, optSome.Unwrap()) + assert.Equal(t, Datetime{}, optNone.Unwrap()) +} + +func TestOptionalDatetime_UnwrapOr(t *testing.T) { + val, err := MakeDatetime(time.Now().In(time.UTC)) + assert.NoError(t, err) + def, err := MakeDatetime(time.Now().Add(1 * time.Hour).In(time.UTC)) + assert.NoError(t, err) + optSome := SomeOptionalDatetime(val) + optNone := NoneOptionalDatetime() + + assert.Equal(t, val, optSome.UnwrapOr(def)) + assert.Equal(t, def, optNone.UnwrapOr(def)) +} + +func TestOptionalDatetime_UnwrapOrElse(t *testing.T) { + val, err := MakeDatetime(time.Now().In(time.UTC)) + assert.NoError(t, err) + def, err := MakeDatetime(time.Now().Add(1 * time.Hour).In(time.UTC)) + assert.NoError(t, err) + optSome := SomeOptionalDatetime(val) + optNone := NoneOptionalDatetime() + + assert.Equal(t, val, optSome.UnwrapOrElse(func() Datetime { return def })) + assert.Equal(t, def, optNone.UnwrapOrElse(func() Datetime { return def })) +} + +func TestOptionalDatetime_EncodeDecodeMsgpack_Some(t *testing.T) { + val, err := MakeDatetime(time.Now().In(time.UTC)) + assert.NoError(t, err) + some := SomeOptionalDatetime(val) + + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err = enc.Encode(some) + assert.NoError(t, err) + + var decodedSome OptionalDatetime + err = dec.Decode(&decodedSome) + assert.NoError(t, err) + assert.True(t, decodedSome.IsSome()) + assert.Equal(t, val, decodedSome.Unwrap()) +} + +func TestOptionalDatetime_EncodeDecodeMsgpack_None(t *testing.T) { + none := NoneOptionalDatetime() + + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err := enc.Encode(none) + assert.NoError(t, err) + + var decodedNone OptionalDatetime + err = dec.Decode(&decodedNone) + assert.NoError(t, err) + assert.True(t, decodedNone.IsZero()) +} + +func TestOptionalDatetime_EncodeDecodeMsgpack_InvalidType(t *testing.T) { + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err := enc.Encode(123) + assert.NoError(t, err) + + var decodedInvalid OptionalDatetime + err = dec.Decode(&decodedInvalid) + assert.Error(t, err) +} diff --git a/datetime/datetime_test.go b/datetime/datetime_test.go new file mode 100644 index 000000000..d01153892 --- /dev/null +++ b/datetime/datetime_test.go @@ -0,0 +1,1180 @@ +package datetime_test + +import ( + "encoding/hex" + "fmt" + "log" + "os" + "reflect" + "testing" + "time" + + "github.com/vmihailenco/msgpack/v5" + + . "github.com/tarantool/go-tarantool/v3" + . "github.com/tarantool/go-tarantool/v3/datetime" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +var noTimezoneLoc = time.FixedZone(NoTimezone, 0) + +var lesserBoundaryTimes = []time.Time{ + time.Date(-5879610, 06, 22, 0, 0, 1, 0, time.UTC), + time.Date(-5879610, 06, 22, 0, 0, 0, 1, time.UTC), + time.Date(5879611, 07, 10, 23, 59, 59, 0, time.UTC), + time.Date(5879611, 07, 10, 23, 59, 59, 999999999, time.UTC), +} + +var boundaryTimes = []time.Time{ + time.Date(-5879610, 06, 22, 0, 0, 0, 0, time.UTC), + time.Date(5879611, 07, 11, 0, 0, 0, 999999999, time.UTC), +} + +var greaterBoundaryTimes = []time.Time{ + time.Date(-5879610, 06, 21, 23, 59, 59, 999999999, time.UTC), + time.Date(5879611, 07, 11, 0, 0, 1, 0, time.UTC), +} + +var isDatetimeSupported = false + +var server = "127.0.0.1:3013" +var opts = Opts{ + Timeout: 5 * time.Second, +} +var dialer = NetDialer{ + Address: server, + User: "test", + Password: "test", +} + +var spaceTuple1 = "testDatetime_1" +var spaceTuple2 = "testDatetime_2" +var index = "primary" + +func skipIfDatetimeUnsupported(t *testing.T) { + t.Helper() + + if isDatetimeSupported == false { + t.Skip("Skipping test for Tarantool without datetime support in msgpack") + } +} + +func TestDatetimeAdd(t *testing.T) { + tm := time.Unix(0, 0).UTC() + dt, err := MakeDatetime(tm) + if err != nil { + t.Fatalf("Unexpected error: %s", err.Error()) + } + + newdt, err := dt.Add(Interval{ + Year: 1, + Month: -3, + Week: 3, + Day: 4, + Hour: -5, + Min: 5, + Sec: 6, + Nsec: -3, + }) + if err != nil { + t.Fatalf("Unexpected error: %s", err.Error()) + } + + expected := "1970-10-25 19:05:05.999999997 +0000 UTC" + if newdt.ToTime().String() != expected { + t.Fatalf("Unexpected result: %s, expected: %s", newdt.ToTime().String(), expected) + } +} + +func TestDatetimeAddAdjust(t *testing.T) { + /* + How-to test in Tarantool: + > date = require("datetime") + > date.parse("2012-12-31T00:00:00Z") + {month = -1, adjust = "excess"} + */ + cases := []struct { + year int64 + month int64 + adjust Adjust + date string + want string + }{ + { + year: 0, + month: 1, + adjust: NoneAdjust, + date: "2013-02-28T00:00:00Z", + want: "2013-03-28T00:00:00Z", + }, + { + year: 0, + month: 1, + adjust: LastAdjust, + date: "2013-02-28T00:00:00Z", + want: "2013-03-31T00:00:00Z", + }, + { + year: 0, + month: 1, + adjust: ExcessAdjust, + date: "2013-02-28T00:00:00Z", + want: "2013-03-28T00:00:00Z", + }, + { + year: 0, + month: 1, + adjust: NoneAdjust, + date: "2013-01-31T00:00:00Z", + want: "2013-02-28T00:00:00Z", + }, + { + year: 0, + month: 1, + adjust: LastAdjust, + date: "2013-01-31T00:00:00Z", + want: "2013-02-28T00:00:00Z", + }, + { + year: 0, + month: 1, + adjust: ExcessAdjust, + date: "2013-01-31T00:00:00Z", + want: "2013-03-03T00:00:00Z", + }, + { + year: 2, + month: 2, + adjust: NoneAdjust, + date: "2011-12-31T00:00:00Z", + want: "2014-02-28T00:00:00Z", + }, + { + year: 2, + month: 2, + adjust: LastAdjust, + date: "2011-12-31T00:00:00Z", + want: "2014-02-28T00:00:00Z", + }, + { + year: 2, + month: 2, + adjust: ExcessAdjust, + date: "2011-12-31T00:00:00Z", + want: "2014-03-03T00:00:00Z", + }, + { + year: 0, + month: -1, + adjust: NoneAdjust, + date: "2013-02-28T00:00:00Z", + want: "2013-01-28T00:00:00Z", + }, + { + year: 0, + month: -1, + adjust: LastAdjust, + date: "2013-02-28T00:00:00Z", + want: "2013-01-31T00:00:00Z", + }, + { + year: 0, + month: -1, + adjust: ExcessAdjust, + date: "2013-02-28T00:00:00Z", + want: "2013-01-28T00:00:00Z", + }, + { + year: 0, + month: -1, + adjust: NoneAdjust, + date: "2012-12-31T00:00:00Z", + want: "2012-11-30T00:00:00Z", + }, + { + year: 0, + month: -1, + adjust: LastAdjust, + date: "2012-12-31T00:00:00Z", + want: "2012-11-30T00:00:00Z", + }, + { + year: 0, + month: -1, + adjust: ExcessAdjust, + date: "2012-12-31T00:00:00Z", + want: "2012-12-01T00:00:00Z", + }, + { + year: -2, + month: -2, + adjust: NoneAdjust, + date: "2011-01-31T00:00:00Z", + want: "2008-11-30T00:00:00Z", + }, + { + year: -2, + month: -2, + adjust: LastAdjust, + date: "2011-12-31T00:00:00Z", + want: "2009-10-31T00:00:00Z", + }, + { + year: -2, + month: -2, + adjust: ExcessAdjust, + date: "2011-12-31T00:00:00Z", + want: "2009-10-31T00:00:00Z", + }, + } + + for _, tc := range cases { + tm, err := time.Parse(time.RFC3339, tc.date) + if err != nil { + t.Fatalf("Unexpected error: %s", err.Error()) + } + dt, err := MakeDatetime(tm) + if err != nil { + t.Fatalf("Unexpected error: %s", err.Error()) + } + t.Run(fmt.Sprintf("%d_%d_%d_%s", tc.year, tc.month, tc.adjust, tc.date), + func(t *testing.T) { + newdt, err := dt.Add(Interval{ + Year: tc.year, + Month: tc.month, + Adjust: tc.adjust, + }) + if err != nil { + t.Fatalf("Unable to add: %s", err.Error()) + } + res := newdt.ToTime().Format(time.RFC3339) + if res != tc.want { + t.Fatalf("Unexpected result %s, expected %s", res, tc.want) + } + }) + } +} + +func TestDatetimeAddSubSymmetric(t *testing.T) { + tm := time.Unix(0, 0).UTC() + dt, err := MakeDatetime(tm) + if err != nil { + t.Fatalf("Unexpected error: %s", err.Error()) + } + + newdtadd, err := dt.Add(Interval{ + Year: 1, + Month: -3, + Week: 3, + Day: 4, + Hour: -5, + Min: 5, + Sec: 6, + Nsec: -3, + }) + if err != nil { + t.Fatalf("Unexpected error: %s", err.Error()) + } + + newdtsub, err := dt.Sub(Interval{ + Year: -1, + Month: 3, + Week: -3, + Day: -4, + Hour: 5, + Min: -5, + Sec: -6, + Nsec: 3, + }) + if err != nil { + t.Fatalf("Unexpected error: %s", err.Error()) + } + + expected := "1970-10-25 19:05:05.999999997 +0000 UTC" + addstr := newdtadd.ToTime().String() + substr := newdtsub.ToTime().String() + + if addstr != expected { + t.Fatalf("Unexpected Add result: %s, expected: %s", addstr, expected) + } + if substr != expected { + t.Fatalf("Unexpected Sub result: %s, expected: %s", substr, expected) + } +} + +// We have a separate test for accurate Datetime boundaries. +func TestDatetimeAddOutOfRange(t *testing.T) { + tm := time.Unix(0, 0).UTC() + dt, err := MakeDatetime(tm) + if err != nil { + t.Fatalf("Unexpected error: %s", err.Error()) + } + + newdt, err := dt.Add(Interval{Year: 1000000000}) + if err == nil { + t.Fatalf("Unexpected success: %v", newdt) + } + expected := "time 1000001970-01-01 00:00:00 +0000 UTC is out of supported range" + if err.Error() != expected { + t.Fatalf("Unexpected error: %s", err.Error()) + } +} + +func TestDatetimeInterval(t *testing.T) { + var first = "2015-03-20T17:50:56.000000009+02:00" + var second = "2013-01-31T11:51:58.00000009+01:00" + + tmFirst, err := time.Parse(time.RFC3339, first) + if err != nil { + t.Fatalf("Error in time.Parse(): %s", err) + } + tmSecond, err := time.Parse(time.RFC3339, second) + if err != nil { + t.Fatalf("Error in time.Parse(): %s", err) + } + + dtFirst, err := MakeDatetime(tmFirst) + if err != nil { + t.Fatalf("Unable to create Datetime from %s: %s", tmFirst, err) + } + dtSecond, err := MakeDatetime(tmSecond) + if err != nil { + t.Fatalf("Unable to create Datetime from %s: %s", tmSecond, err) + } + + ivalFirst := dtFirst.Interval(dtSecond) + ivalSecond := dtSecond.Interval(dtFirst) + + expectedFirst := Interval{-2, -2, 0, 11, -6, 61, 2, 81, NoneAdjust} + expectedSecond := Interval{2, 2, 0, -11, 6, -61, -2, -81, NoneAdjust} + + if !reflect.DeepEqual(ivalFirst, expectedFirst) { + t.Errorf("Unexpected interval %v, expected %v", ivalFirst, expectedFirst) + } + if !reflect.DeepEqual(ivalSecond, expectedSecond) { + t.Errorf("Unexpected interval %v, expected %v", ivalSecond, expectedSecond) + } + + dtFirst, err = dtFirst.Add(ivalFirst) + if err != nil { + t.Fatalf("Unable to add an interval: %s", err) + } + if !dtFirst.ToTime().Equal(dtSecond.ToTime()) { + t.Errorf("Incorrect add an interval result: %s, expected %s", + dtFirst.ToTime(), dtSecond.ToTime()) + } +} + +func TestDatetimeTarantoolInterval(t *testing.T) { + skipIfDatetimeUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + dates := []string{ + // We could return tests with timezones after a release with a fix of + // the bug: + // https://github.com/tarantool/tarantool/issues/7698 + // + // "2010-02-24T23:03:56.0000013-04:00", + // "2015-03-20T17:50:56.000000009+01:00", + // "2020-01-01T01:01:01+11:30", + // "2025-08-01T00:00:00.000000003+11:00", + "2010-02-24T23:03:56.0000013Z", + "2015-03-20T17:50:56.000000009Z", + "2020-01-01T01:01:01Z", + "2025-08-01T00:00:00.000000003Z", + "2015-12-21T17:50:53Z", + "1980-03-28T13:18:39.000099Z", + } + datetimes := []Datetime{} + for _, date := range dates { + tm, err := time.Parse(time.RFC3339, date) + if err != nil { + t.Fatalf("Error in time.Parse(%s): %s", date, err) + } + dt, err := MakeDatetime(tm) + if err != nil { + t.Fatalf("Error in MakeDatetime(%s): %s", tm, err) + } + datetimes = append(datetimes, dt) + } + + for _, dti := range datetimes { + for _, dtj := range datetimes { + t.Run(fmt.Sprintf("%s_to_%s", dti.ToTime(), dtj.ToTime()), + func(t *testing.T) { + req := NewCallRequest("call_datetime_interval"). + Args([]interface{}{dti, dtj}) + data, err := conn.Do(req).Get() + if err != nil { + t.Fatalf("Unable to call call_datetime_interval: %s", err) + } + ival := dti.Interval(dtj) + ret := data[0].(Interval) + if !reflect.DeepEqual(ival, ret) { + t.Fatalf("%v != %v", ival, ret) + } + }) + } + } +} + +// Expect that first element of tuple is time.Time. Compare extracted actual +// and expected datetime values. +func assertDatetimeIsEqual(t *testing.T, tuples []interface{}, tm time.Time) { + t.Helper() + + dtIndex := 0 + if tpl, ok := tuples[dtIndex].([]interface{}); !ok { + t.Fatalf("Unexpected return value body") + } else { + if len(tpl) != 2 { + t.Fatalf("Unexpected return value body (tuple len = %d)", len(tpl)) + } + if val, ok := tpl[dtIndex].(Datetime); !ok || !val.ToTime().Equal(tm) { + t.Fatalf("Unexpected tuple %d field %v, expected %v", + dtIndex, + val, + tm) + } + } +} + +func TestTimezonesIndexMapping(t *testing.T) { + for _, index := range TimezoneToIndex { + if _, ok := IndexToTimezone[index]; !ok { + t.Errorf("Index %d not found", index) + } + } +} + +func TestTimezonesZonesMapping(t *testing.T) { + for _, zone := range IndexToTimezone { + if _, ok := TimezoneToIndex[zone]; !ok { + t.Errorf("Zone %s not found", zone) + } + } +} + +func TestInvalidTimezone(t *testing.T) { + invalidLoc := time.FixedZone("AnyInvalid", 0) + tm, err := time.Parse(time.RFC3339, "2010-08-12T11:39:14Z") + if err != nil { + t.Fatalf("Time parse failed: %s", err) + } + tm = tm.In(invalidLoc) + dt, err := MakeDatetime(tm) + if err == nil { + t.Fatalf("Unexpected success: %v", dt) + } + if err.Error() != "unknown timezone AnyInvalid with offset 0" { + t.Fatalf("Unexpected error: %s", err.Error()) + } +} + +func TestInvalidOffset(t *testing.T) { + tests := []struct { + ok bool + offset int + }{ + {ok: true, offset: -12 * 60 * 60}, + {ok: true, offset: -12*60*60 + 1}, + {ok: true, offset: 14*60*60 - 1}, + {ok: true, offset: 14 * 60 * 60}, + {ok: false, offset: -12*60*60 - 1}, + {ok: false, offset: 14*60*60 + 1}, + } + + for _, testcase := range tests { + name := "" + if testcase.ok { + name = fmt.Sprintf("in_boundary_%d", testcase.offset) + } else { + name = fmt.Sprintf("out_of_boundary_%d", testcase.offset) + } + t.Run(name, func(t *testing.T) { + loc := time.FixedZone("MSK", testcase.offset) + tm, err := time.Parse(time.RFC3339, "2010-08-12T11:39:14Z") + if err != nil { + t.Fatalf("Time parse failed: %s", err) + } + tm = tm.In(loc) + dt, err := MakeDatetime(tm) + if testcase.ok && err != nil { + t.Fatalf("Unexpected error: %s", err.Error()) + } + if !testcase.ok && err == nil { + t.Fatalf("Unexpected success: %v", dt) + } + if testcase.ok && isDatetimeSupported { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + tupleInsertSelectDelete(t, conn, tm) + } + }) + } +} + +func TestCustomTimezone(t *testing.T) { + skipIfDatetimeUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + customZone := "Europe/Moscow" + customOffset := 180 * 60 + // Tarantool does not use a custom offset value if a time zone is provided. + // So it will change to an actual one. + zoneOffset := 240 * 60 + + customLoc := time.FixedZone(customZone, customOffset) + tm, err := time.Parse(time.RFC3339, "2010-08-12T11:44:14Z") + if err != nil { + t.Fatalf("Time parse failed: %s", err) + } + tm = tm.In(customLoc) + dt, err := MakeDatetime(tm) + if err != nil { + t.Fatalf("Unable to create datetime: %s", err.Error()) + } + + req := NewReplaceRequest(spaceTuple1).Tuple([]interface{}{dt, "payload"}) + data, err := conn.Do(req).Get() + if err != nil { + t.Fatalf("Datetime replace failed %s", err.Error()) + } + assertDatetimeIsEqual(t, data, tm) + + tpl := data[0].([]interface{}) + if respDt, ok := tpl[0].(Datetime); ok { + zone := respDt.ToTime().Location().String() + _, offset := respDt.ToTime().Zone() + if zone != customZone { + t.Fatalf("Expected zone %s instead of %s", customZone, zone) + } + if offset != zoneOffset { + t.Fatalf("Expected offset %d instead of %d", customOffset, offset) + } + + req := NewDeleteRequest(spaceTuple1).Key([]interface{}{dt}) + _, err = conn.Do(req).Get() + if err != nil { + t.Fatalf("Datetime delete failed: %s", err.Error()) + } + } else { + t.Fatalf("Datetime doesn't match") + } + +} + +func tupleInsertSelectDelete(t *testing.T, conn *Connection, tm time.Time) { + t.Helper() + + dt, err := MakeDatetime(tm) + if err != nil { + t.Fatalf("Unable to create Datetime from %s: %s", tm, err) + } + + // Insert tuple with datetime. + ins := NewInsertRequest(spaceTuple1).Tuple([]interface{}{dt, "payload"}) + _, err = conn.Do(ins).Get() + if err != nil { + t.Fatalf("Datetime insert failed: %s", err.Error()) + } + + // Select tuple with datetime. + var offset uint32 = 0 + var limit uint32 = 1 + sel := NewSelectRequest(spaceTuple1). + Index(index). + Offset(offset). + Limit(limit). + Iterator(IterEq). + Key([]interface{}{dt}) + data, err := conn.Do(sel).Get() + if err != nil { + t.Fatalf("Datetime select failed: %s", err.Error()) + } + assertDatetimeIsEqual(t, data, tm) + + // Delete tuple with datetime. + del := NewDeleteRequest(spaceTuple1).Index(index).Key([]interface{}{dt}) + data, err = conn.Do(del).Get() + if err != nil { + t.Fatalf("Datetime delete failed: %s", err.Error()) + } + assertDatetimeIsEqual(t, data, tm) +} + +var datetimeSample = []struct { + fmt string + dt string + mpBuf string // MessagePack buffer. + zone string +}{ + /* Cases for base encoding without a timezone. */ + {time.RFC3339, "2012-01-31T23:59:59.000000010Z", "d8047f80284f000000000a00000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.000000010Z", "d80400000000000000000a00000000000000", ""}, + {time.RFC3339, "2010-08-12T11:39:14Z", "d70462dd634c00000000", ""}, + {time.RFC3339, "1984-03-24T18:04:05Z", "d7041530c31a00000000", ""}, + {time.RFC3339, "2010-01-12T00:00:00Z", "d70480bb4b4b00000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00Z", "d7040000000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.123456789Z", "d804000000000000000015cd5b0700000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.12345678Z", "d80400000000000000000ccd5b0700000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.1234567Z", "d8040000000000000000bccc5b0700000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.123456Z", "d804000000000000000000ca5b0700000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.12345Z", "d804000000000000000090b25b0700000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.1234Z", "d804000000000000000040ef5a0700000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.123Z", "d8040000000000000000c0d4540700000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.12Z", "d8040000000000000000000e270700000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.1Z", "d804000000000000000000e1f50500000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.01Z", "d80400000000000000008096980000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.001Z", "d804000000000000000040420f0000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.0001Z", "d8040000000000000000a086010000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.00001Z", "d80400000000000000001027000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.000001Z", "d8040000000000000000e803000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.0000001Z", "d80400000000000000006400000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.00000001Z", "d80400000000000000000a00000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.000000001Z", "d80400000000000000000100000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.000000009Z", "d80400000000000000000900000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.00000009Z", "d80400000000000000005a00000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.0000009Z", "d80400000000000000008403000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.000009Z", "d80400000000000000002823000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.00009Z", "d8040000000000000000905f010000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.0009Z", "d8040000000000000000a0bb0d0000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.009Z", "d80400000000000000004054890000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.09Z", "d8040000000000000000804a5d0500000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.9Z", "d804000000000000000000e9a43500000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.99Z", "d80400000000000000008033023b00000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.999Z", "d8040000000000000000c0878b3b00000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.9999Z", "d80400000000000000006043993b00000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.99999Z", "d8040000000000000000f0a29a3b00000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.999999Z", "d804000000000000000018c69a3b00000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.9999999Z", "d80400000000000000009cc99a3b00000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.99999999Z", "d8040000000000000000f6c99a3b00000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.999999999Z", "d8040000000000000000ffc99a3b00000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.0Z", "d7040000000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.00Z", "d7040000000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.000Z", "d7040000000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.0000Z", "d7040000000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.00000Z", "d7040000000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.000000Z", "d7040000000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.0000000Z", "d7040000000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.00000000Z", "d7040000000000000000", ""}, + {time.RFC3339, "1970-01-01T00:00:00.000000000Z", "d7040000000000000000", ""}, + {time.RFC3339, "1973-11-29T21:33:09Z", "d70415cd5b0700000000", ""}, + {time.RFC3339, "2013-10-28T17:51:56Z", "d7043ca46e5200000000", ""}, + {time.RFC3339, "9999-12-31T23:59:59Z", "d7047f41f4ff3a000000", ""}, + /* Cases for encoding with a timezone. */ + {time.RFC3339, "2006-01-02T15:04:00Z", "d804e040b9430000000000000000b400b303", "Europe/Moscow"}, +} + +func TestDatetimeInsertSelectDelete(t *testing.T) { + skipIfDatetimeUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + for _, testcase := range datetimeSample { + t.Run(testcase.dt, func(t *testing.T) { + tm, err := time.Parse(testcase.fmt, testcase.dt) + if testcase.zone == "" { + tm = tm.In(noTimezoneLoc) + } else { + loc, err := time.LoadLocation(testcase.zone) + if err != nil { + t.Fatalf("Unable to load location: %s", err) + } + tm = tm.In(loc) + } + if err != nil { + t.Fatalf("Time (%s) parse failed: %s", testcase.dt, err) + } + tupleInsertSelectDelete(t, conn, tm) + }) + } +} + +// time.Parse() could not parse formatted string with datetime where year is +// bigger than 9999. That's why testcase with maximum datetime value represented +// as a separate testcase. Testcase with minimal value added for consistency. +func TestDatetimeBoundaryRange(t *testing.T) { + skipIfDatetimeUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + for _, tm := range append(lesserBoundaryTimes, boundaryTimes...) { + t.Run(tm.String(), func(t *testing.T) { + tupleInsertSelectDelete(t, conn, tm) + }) + } +} + +func TestDatetimeOutOfRange(t *testing.T) { + skipIfDatetimeUnsupported(t) + + for _, tm := range greaterBoundaryTimes { + t.Run(tm.String(), func(t *testing.T) { + _, err := MakeDatetime(tm) + if err == nil { + t.Errorf("Time %s should be unsupported!", tm) + } + }) + } +} + +func TestDatetimeReplace(t *testing.T) { + skipIfDatetimeUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + tm, err := time.Parse(time.RFC3339, "2007-01-02T15:04:05Z") + if err != nil { + t.Fatalf("Time parse failed: %s", err) + } + + dt, err := MakeDatetime(tm) + if err != nil { + t.Fatalf("Unable to create Datetime from %s: %s", tm, err) + } + rep := NewReplaceRequest(spaceTuple1).Tuple([]interface{}{dt, "payload"}) + data, err := conn.Do(rep).Get() + if err != nil { + t.Fatalf("Datetime replace failed: %s", err) + } + assertDatetimeIsEqual(t, data, tm) + + sel := NewSelectRequest(spaceTuple1). + Index(index). + Limit(1). + Iterator(IterEq). + Key([]interface{}{dt}) + data, err = conn.Do(sel).Get() + if err != nil { + t.Fatalf("Datetime select failed: %s", err) + } + assertDatetimeIsEqual(t, data, tm) + + // Delete tuple with datetime. + del := NewDeleteRequest(spaceTuple1).Index(index).Key([]interface{}{dt}) + _, err = conn.Do(del).Get() + if err != nil { + t.Fatalf("Datetime delete failed: %s", err.Error()) + } +} + +type Event struct { + Datetime Datetime + Location string +} + +type Tuple2 struct { + Cid uint + Orig string + Events []Event +} + +type Tuple1 struct { + Datetime Datetime +} + +func (t *Tuple1) EncodeMsgpack(e *msgpack.Encoder) error { + if err := e.EncodeArrayLen(2); err != nil { + return err + } + if err := e.Encode(&t.Datetime); err != nil { + return err + } + return nil +} + +func (t *Tuple1) DecodeMsgpack(d *msgpack.Decoder) error { + var err error + var l int + if l, err = d.DecodeArrayLen(); err != nil { + return err + } + if l != 1 { + return fmt.Errorf("Array len doesn't match: %d", l) + } + err = d.Decode(&t.Datetime) + if err != nil { + return err + } + return nil +} + +func (ev *Event) EncodeMsgpack(e *msgpack.Encoder) error { + if err := e.EncodeArrayLen(2); err != nil { + return err + } + if err := e.EncodeString(ev.Location); err != nil { + return err + } + if err := e.Encode(&ev.Datetime); err != nil { + return err + } + return nil +} + +func (ev *Event) DecodeMsgpack(d *msgpack.Decoder) error { + var err error + var l int + if l, err = d.DecodeArrayLen(); err != nil { + return err + } + if l != 2 { + return fmt.Errorf("Array len doesn't match: %d", l) + } + if ev.Location, err = d.DecodeString(); err != nil { + return err + } + res, err := d.DecodeInterface() + if err != nil { + return err + } + + if dt, ok := res.(Datetime); !ok { + return fmt.Errorf("Datetime doesn't match") + } else { + ev.Datetime = dt + } + return nil +} + +func (c *Tuple2) EncodeMsgpack(e *msgpack.Encoder) error { + if err := e.EncodeArrayLen(3); err != nil { + return err + } + if err := e.EncodeUint64(uint64(c.Cid)); err != nil { + return err + } + if err := e.EncodeString(c.Orig); err != nil { + return err + } + e.Encode(c.Events) + return nil +} + +func (c *Tuple2) DecodeMsgpack(d *msgpack.Decoder) error { + var err error + var l int + if l, err = d.DecodeArrayLen(); err != nil { + return err + } + if l != 3 { + return fmt.Errorf("Array len doesn't match: %d", l) + } + if c.Cid, err = d.DecodeUint(); err != nil { + return err + } + if c.Orig, err = d.DecodeString(); err != nil { + return err + } + if l, err = d.DecodeArrayLen(); err != nil { + return err + } + c.Events = make([]Event, l) + for i := 0; i < l; i++ { + d.Decode(&c.Events[i]) + } + return nil +} + +func TestCustomEncodeDecodeTuple1(t *testing.T) { + skipIfDatetimeUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + tm1, _ := time.Parse(time.RFC3339, "2010-05-24T17:51:56.000000009Z") + tm2, _ := time.Parse(time.RFC3339, "2022-05-24T17:51:56.000000009Z") + dt1, err := MakeDatetime(tm1) + if err != nil { + t.Fatalf("Unable to create Datetime from %s: %s", tm1, err) + } + dt2, err := MakeDatetime(tm2) + if err != nil { + t.Fatalf("Unable to create Datetime from %s: %s", tm2, err) + } + const cid = 13 + const orig = "orig" + + tuple := Tuple2{Cid: cid, + Orig: orig, + Events: []Event{ + {dt1, "Minsk"}, + {dt2, "Moscow"}, + }, + } + rep := NewReplaceRequest(spaceTuple2).Tuple(&tuple) + data, err := conn.Do(rep).Get() + if err != nil { + t.Fatalf("Failed to replace: %s", err.Error()) + } + if len(data) != 1 { + t.Fatalf("Response Body len != 1") + } + + tpl, ok := data[0].([]interface{}) + if !ok { + t.Fatalf("Unexpected body of Replace") + } + + // Delete the tuple. + del := NewDeleteRequest(spaceTuple2).Index(index).Key([]interface{}{cid}) + _, err = conn.Do(del).Get() + if err != nil { + t.Fatalf("Datetime delete failed: %s", err.Error()) + } + + if len(tpl) != 3 { + t.Fatalf("Unexpected body of Replace (tuple len)") + } + if id, ok := tpl[0].(uint64); !ok || id != cid { + t.Fatalf("Unexpected body of Replace (%d)", cid) + } + if o, ok := tpl[1].(string); !ok || o != orig { + t.Fatalf("Unexpected body of Replace (%s)", orig) + } + + events, ok := tpl[2].([]interface{}) + if !ok { + t.Fatalf("Unable to convert 2 field to []interface{}") + } + + for i, tv := range []time.Time{tm1, tm2} { + dt, ok := events[i].([]interface{})[1].(Datetime) + if !ok || !dt.ToTime().Equal(tv) { + t.Fatalf("%v != %v", dt.ToTime(), tv) + } + } +} + +func TestCustomDecodeFunction(t *testing.T) { + skipIfDatetimeUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // Call function 'call_datetime_testdata' returning a custom tuples. + var tuple [][]Tuple2 + call := NewCallRequest("call_datetime_testdata").Args([]interface{}{1}) + err := conn.Do(call).GetTyped(&tuple) + if err != nil { + t.Fatalf("Failed to CallTyped: %s", err.Error()) + } + + if cid := tuple[0][0].Cid; cid != 5 { + t.Fatalf("Wrong Cid (%d), should be 5", cid) + } + if orig := tuple[0][0].Orig; orig != "Go!" { + t.Fatalf("Wrong Orig (%s), should be 'Hello, there!'", orig) + } + + events := tuple[0][0].Events + if len(events) != 3 { + t.Fatalf("Wrong a number of Events (%d), should be 3", len(events)) + } + + locations := []string{ + "Klushino", + "Baikonur", + "Novoselovo", + } + + for i, ev := range events { + loc := ev.Location + dt := ev.Datetime + if loc != locations[i] || dt.ToTime().IsZero() { + t.Fatalf("Expected: %s non-zero time, got %s %v", + locations[i], + loc, + dt.ToTime()) + } + } +} + +func TestCustomEncodeDecodeTuple5(t *testing.T) { + skipIfDatetimeUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + tm := time.Unix(500, 1000).In(time.FixedZone(NoTimezone, 0)) + dt, err := MakeDatetime(tm) + if err != nil { + t.Fatalf("Unable to create Datetime from %s: %s", tm, err) + } + + ins := NewInsertRequest(spaceTuple1).Tuple([]interface{}{dt}) + _, err = conn.Do(ins).Get() + if err != nil { + t.Fatalf("Datetime insert failed: %s", err.Error()) + } + + sel := NewSelectRequest(spaceTuple1). + Index(index). + Limit(1). + Iterator(IterEq). + Key([]interface{}{dt}) + data, errSel := conn.Do(sel).Get() + if errSel != nil { + t.Errorf("Failed to Select: %s", errSel.Error()) + } + if tpl, ok := data[0].([]interface{}); !ok { + t.Errorf("Unexpected body of Select") + } else { + if val, ok := tpl[0].(Datetime); !ok || !val.ToTime().Equal(tm) { + t.Fatalf("Unexpected body of Select") + } + } + + // Teardown: delete a value. + del := NewDeleteRequest(spaceTuple1).Index(index).Key([]interface{}{dt}) + _, err = conn.Do(del).Get() + if err != nil { + t.Fatalf("Datetime delete failed: %s", err.Error()) + } +} + +func TestMPEncode(t *testing.T) { + for _, testcase := range datetimeSample { + t.Run(testcase.dt, func(t *testing.T) { + tm, err := time.Parse(testcase.fmt, testcase.dt) + if testcase.zone == "" { + tm = tm.In(noTimezoneLoc) + } else { + loc, err := time.LoadLocation(testcase.zone) + if err != nil { + t.Fatalf("Unable to load location: %s", err) + } + tm = tm.In(loc) + } + if err != nil { + t.Fatalf("Time (%s) parse failed: %s", testcase.dt, err) + } + dt, err := MakeDatetime(tm) + if err != nil { + t.Fatalf("Unable to create Datetime from %s: %s", tm, err) + } + buf, err := msgpack.Marshal(dt) + if err != nil { + t.Fatalf("Marshalling failed: %s", err.Error()) + } + refBuf, _ := hex.DecodeString(testcase.mpBuf) + if reflect.DeepEqual(buf, refBuf) != true { + t.Fatalf("Failed to encode datetime '%s', actual %x, expected %x", + tm, + buf, + refBuf) + } + }) + } +} + +func TestMPDecode(t *testing.T) { + for _, testcase := range datetimeSample { + t.Run(testcase.dt, func(t *testing.T) { + tm, err := time.Parse(testcase.fmt, testcase.dt) + if testcase.zone == "" { + tm = tm.In(noTimezoneLoc) + } else { + loc, err := time.LoadLocation(testcase.zone) + if err != nil { + t.Fatalf("Unable to load location: %s", err) + } + tm = tm.In(loc) + } + if err != nil { + t.Fatalf("Time (%s) parse failed: %s", testcase.dt, err) + } + buf, _ := hex.DecodeString(testcase.mpBuf) + var v Datetime + err = msgpack.Unmarshal(buf, &v) + if err != nil { + t.Fatalf("Unmarshalling failed: %s", err.Error()) + } + if !tm.Equal(v.ToTime()) { + t.Fatalf("Failed to decode datetime buf '%s', actual %v, expected %v", + testcase.mpBuf, + testcase.dt, + v.ToTime()) + } + }) + } +} + +func TestUnmarshalMsgpackInvalidLength(t *testing.T) { + var v Datetime + + err := msgpack.Unmarshal([]byte{0xd4, 0x04, 0x04}, &v) + if err == nil { + t.Fatalf("Unexpected success %v", v) + } + if err.Error() != "invalid data length: got 1, wanted 8 or 16" { + t.Fatalf("Unexpected error: %s", err.Error()) + } +} + +func TestUnmarshalMsgpackInvalidZone(t *testing.T) { + var v Datetime + + // The original value from datetimeSample array: + // {time.RFC3339 + " MST", + // "2006-01-02T15:04:00+03:00 MSK", + // "d804b016b9430000000000000000b400ee00"} + buf, _ := hex.DecodeString("d804b016b9430000000000000000b400ee01") + err := msgpack.Unmarshal(buf, &v) + if err == nil { + t.Fatalf("Unexpected success %v", v) + } + if err.Error() != "unknown timezone index 494" { + t.Fatalf("Unexpected error: %s", err.Error()) + } +} + +// runTestMain is a body of TestMain function +// (see https://pkg.go.dev/testing#hdr-Main). +// Using defer + os.Exit is not works so TestMain body +// is a separate function, see +// https://stackoverflow.com/questions/27629380/how-to-exit-a-go-program-honoring-deferred-calls +func runTestMain(m *testing.M) int { + isLess, err := test_helpers.IsTarantoolVersionLess(2, 10, 0) + if err != nil { + log.Fatalf("Failed to extract Tarantool version: %s", err) + } + + if isLess { + log.Println("Skipping datetime tests...") + isDatetimeSupported = false + return m.Run() + } else { + isDatetimeSupported = true + } + + instance, err := test_helpers.StartTarantool(test_helpers.StartOpts{ + Dialer: dialer, + InitScript: "config.lua", + Listen: server, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 10, + RetryTimeout: 500 * time.Millisecond, + }) + defer test_helpers.StopTarantoolWithCleanup(instance) + + if err != nil { + log.Printf("Failed to prepare test Tarantool: %s", err) + return 1 + } + + return m.Run() +} + +func TestMain(m *testing.M) { + code := runTestMain(m) + os.Exit(code) +} diff --git a/datetime/example_test.go b/datetime/example_test.go new file mode 100644 index 000000000..df5d55563 --- /dev/null +++ b/datetime/example_test.go @@ -0,0 +1,302 @@ +// Run a Tarantool instance before example execution: +// Terminal 1: +// $ cd datetime +// $ TEST_TNT_LISTEN=3013 TEST_TNT_WORK_DIR=$(mktemp -d -t 'tarantool.XXX') tarantool config.lua +// +// Terminal 2: +// $ cd datetime +// $ go test -v example_test.go +package datetime_test + +import ( + "context" + "fmt" + "time" + + "github.com/tarantool/go-tarantool/v3" + . "github.com/tarantool/go-tarantool/v3/datetime" +) + +// Example demonstrates how to use tuples with datetime. To enable support of +// datetime import tarantool/datetime package. +func Example() { + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3013", + User: "test", + Password: "test", + } + opts := tarantool.Opts{} + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, dialer, opts) + if err != nil { + fmt.Printf("Error in connect is %v", err) + return + } + + var datetime = "2013-10-28T17:51:56.000000009Z" + tm, err := time.Parse(time.RFC3339, datetime) + if err != nil { + fmt.Printf("Error in time.Parse() is %v", err) + return + } + dt, err := MakeDatetime(tm) + if err != nil { + fmt.Printf("Unable to create Datetime from %s: %s", tm, err) + return + } + + space := "testDatetime_1" + index := "primary" + + // Replace a tuple with datetime. + data, err := conn.Do(tarantool.NewReplaceRequest(space). + Tuple([]interface{}{dt}), + ).Get() + if err != nil { + fmt.Printf("Error in replace is %v", err) + return + } + respDt := data[0].([]interface{})[0].(Datetime) + fmt.Println("Datetime tuple replace") + fmt.Printf("Data: %v\n", respDt.ToTime()) + + // Select a tuple with datetime. + var offset uint32 = 0 + var limit uint32 = 1 + data, err = conn.Do(tarantool.NewSelectRequest(space). + Index(index). + Offset(offset). + Limit(limit). + Iterator(tarantool.IterEq). + Key([]interface{}{dt}), + ).Get() + if err != nil { + fmt.Printf("Error in select is %v", err) + return + } + respDt = data[0].([]interface{})[0].(Datetime) + fmt.Println("Datetime tuple select") + fmt.Printf("Data: %v\n", respDt.ToTime()) + + // Delete a tuple with datetime. + data, err = conn.Do(tarantool.NewDeleteRequest(space). + Index(index). + Key([]interface{}{dt}), + ).Get() + if err != nil { + fmt.Printf("Error in delete is %v", err) + return + } + respDt = data[0].([]interface{})[0].(Datetime) + fmt.Println("Datetime tuple delete") + fmt.Printf("Data: %v\n", respDt.ToTime()) +} + +// ExampleMakeDatetime_localUnsupported demonstrates that "Local" location is +// unsupported. +func ExampleMakeDatetime_localUnsupported() { + tm := time.Now().Local() + loc := tm.Location() + fmt.Println("Location:", loc) + if _, err := MakeDatetime(tm); err != nil { + fmt.Printf("Could not create a Datetime with %s location.\n", loc) + } else { + fmt.Printf("A Datetime with %s location created.\n", loc) + } + // Output: + // Location: Local + // Could not create a Datetime with Local location. +} + +// Example demonstrates how to create a datetime for Tarantool without UTC +// timezone in datetime. +func ExampleMakeDatetime_noTimezone() { + var datetime = "2013-10-28T17:51:56.000000009Z" + tm, err := time.Parse(time.RFC3339, datetime) + if err != nil { + fmt.Printf("Error in time.Parse() is %v", err) + return + } + + tm = tm.In(time.FixedZone(NoTimezone, 0)) + + dt, err := MakeDatetime(tm) + if err != nil { + fmt.Printf("Unable to create Datetime from %s: %s", tm, err) + return + } + + fmt.Printf("Time value: %v\n", dt.ToTime()) +} + +// ExampleDatetime_Interval demonstrates how to get an Interval value between +// two Datetime values. +func ExampleDatetime_Interval() { + var first = "2013-01-31T17:51:56.000000009Z" + var second = "2015-03-20T17:50:56.000000009Z" + + tmFirst, err := time.Parse(time.RFC3339, first) + if err != nil { + fmt.Printf("Error in time.Parse() is %v", err) + return + } + tmSecond, err := time.Parse(time.RFC3339, second) + if err != nil { + fmt.Printf("Error in time.Parse() is %v", err) + return + } + + dtFirst, err := MakeDatetime(tmFirst) + if err != nil { + fmt.Printf("Unable to create Datetime from %s: %s", tmFirst, err) + return + } + dtSecond, err := MakeDatetime(tmSecond) + if err != nil { + fmt.Printf("Unable to create Datetime from %s: %s", tmSecond, err) + return + } + + ival := dtFirst.Interval(dtSecond) + fmt.Printf("%v", ival) + // Output: + // {2 2 0 -11 0 -1 0 0 0} +} + +// ExampleDatetime_Add demonstrates how to add an Interval to a Datetime value. +func ExampleDatetime_Add() { + var datetime = "2013-01-31T17:51:56.000000009Z" + tm, err := time.Parse(time.RFC3339, datetime) + if err != nil { + fmt.Printf("Error in time.Parse() is %s", err) + return + } + dt, err := MakeDatetime(tm) + if err != nil { + fmt.Printf("Unable to create Datetime from %s: %s", tm, err) + return + } + + newdt, err := dt.Add(Interval{ + Year: 1, + Month: 1, + Sec: 333, + Adjust: LastAdjust, + }) + if err != nil { + fmt.Printf("Unable to add to Datetime: %s", err) + return + } + + fmt.Printf("New time: %s\n", newdt.ToTime().String()) + // Output: + // New time: 2014-02-28 17:57:29.000000009 +0000 UTC +} + +// ExampleDatetime_Add_dst demonstrates how to add an Interval to a +// Datetime value with a DST location. +func ExampleDatetime_Add_dst() { + loc, err := time.LoadLocation("Europe/Moscow") + if err != nil { + fmt.Printf("Unable to load location: %s", err) + return + } + tm := time.Date(2008, 1, 1, 1, 1, 1, 1, loc) + dt, err := MakeDatetime(tm) + if err != nil { + fmt.Printf("Unable to create Datetime: %s", err) + return + } + + fmt.Printf("Datetime time:\n") + fmt.Printf("%s\n", dt.ToTime()) + fmt.Printf("Datetime time + 6 month:\n") + fmt.Printf("%s\n", dt.ToTime().AddDate(0, 6, 0)) + dt, err = dt.Add(Interval{Month: 6}) + if err != nil { + fmt.Printf("Unable to add 6 month: %s", err) + return + } + fmt.Printf("Datetime + 6 month time:\n") + fmt.Printf("%s\n", dt.ToTime()) + + // Output: + // Datetime time: + // 2008-01-01 01:01:01.000000001 +0300 MSK + // Datetime time + 6 month: + // 2008-07-01 01:01:01.000000001 +0400 MSD + // Datetime + 6 month time: + // 2008-07-01 01:01:01.000000001 +0400 MSD +} + +// ExampleDatetime_Sub demonstrates how to subtract an Interval from a +// Datetime value. +func ExampleDatetime_Sub() { + var datetime = "2013-01-31T17:51:56.000000009Z" + tm, err := time.Parse(time.RFC3339, datetime) + if err != nil { + fmt.Printf("Error in time.Parse() is %s", err) + return + } + dt, err := MakeDatetime(tm) + if err != nil { + fmt.Printf("Unable to create Datetime from %s: %s", tm, err) + return + } + + newdt, err := dt.Sub(Interval{ + Year: 1, + Month: 1, + Sec: 333, + Adjust: LastAdjust, + }) + if err != nil { + fmt.Printf("Unable to sub from Datetime: %s", err) + return + } + + fmt.Printf("New time: %s\n", newdt.ToTime().String()) + // Output: + // New time: 2011-12-31 17:46:23.000000009 +0000 UTC +} + +// ExampleInterval_Add demonstrates how to add two intervals. +func ExampleInterval_Add() { + orig := Interval{ + Year: 1, + Month: 2, + Week: 3, + Sec: 10, + Adjust: ExcessAdjust, + } + ival := orig.Add(Interval{ + Year: 10, + Min: 30, + Adjust: LastAdjust, + }) + + fmt.Printf("%v", ival) + // Output: + // {11 2 3 0 0 30 10 0 1} +} + +// ExampleInterval_Sub demonstrates how to subtract two intervals. +func ExampleInterval_Sub() { + orig := Interval{ + Year: 1, + Month: 2, + Week: 3, + Sec: 10, + Adjust: ExcessAdjust, + } + ival := orig.Sub(Interval{ + Year: 10, + Min: 30, + Adjust: LastAdjust, + }) + + fmt.Printf("%v", ival) + // Output: + // {-9 2 3 0 0 -30 10 0 1} +} diff --git a/datetime/export_test.go b/datetime/export_test.go new file mode 100644 index 000000000..f138b7a7a --- /dev/null +++ b/datetime/export_test.go @@ -0,0 +1,5 @@ +package datetime + +/* It's kind of an integration test data from an external data source. */ +var IndexToTimezone = indexToTimezone +var TimezoneToIndex = timezoneToIndex diff --git a/datetime/gen-timezones.sh b/datetime/gen-timezones.sh new file mode 100755 index 000000000..e251d7db6 --- /dev/null +++ b/datetime/gen-timezones.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +SRC_COMMIT="9ee45289e01232b8df1413efea11db170ae3b3b4" +SRC_FILE=timezones.h +DST_FILE=timezones.go + +[ -e ${SRC_FILE} ] && rm ${SRC_FILE} +wget -O ${SRC_FILE} \ + https://raw.githubusercontent.com/tarantool/tarantool/${SRC_COMMIT}/src/lib/tzcode/timezones.h + +# We don't need aliases in indexToTimezone because Tarantool always replace it: +# +# tarantool> T = date.parse '2022-01-01T00:00 Pacific/Enderbury' +# --- +# ... +# tarantool> T +# --- +# - 2022-01-01T00:00:00 Pacific/Kanton +# ... +# +# So we can do the same and don't worry, be happy. + +cat < ${DST_FILE} +package datetime + +/* Automatically generated by gen-timezones.sh */ + +var indexToTimezone = map[int]string{ +EOF + +grep ZONE_ABBREV ${SRC_FILE} | sed "s/ZONE_ABBREV( *//g" | sed "s/[),]//g" \ + | awk '{printf("\t%s : %s,\n", $1, $3)}' >> ${DST_FILE} +grep ZONE_UNIQUE ${SRC_FILE} | sed "s/ZONE_UNIQUE( *//g" | sed "s/[),]//g" \ + | awk '{printf("\t%s : %s,\n", $1, $2)}' >> ${DST_FILE} + +cat <> ${DST_FILE} +} + +var timezoneToIndex = map[string]int{ +EOF + +grep ZONE_ABBREV ${SRC_FILE} | sed "s/ZONE_ABBREV( *//g" | sed "s/[),]//g" \ + | awk '{printf("\t%s : %s,\n", $3, $1)}' >> ${DST_FILE} +grep ZONE_UNIQUE ${SRC_FILE} | sed "s/ZONE_UNIQUE( *//g" | sed "s/[),]//g" \ + | awk '{printf("\t%s : %s,\n", $2, $1)}' >> ${DST_FILE} +grep ZONE_ALIAS ${SRC_FILE} | sed "s/ZONE_ALIAS( *//g" | sed "s/[),]//g" \ + | awk '{printf("\t%s : %s,\n", $2, $1)}' >> ${DST_FILE} + +echo "}" >> ${DST_FILE} + +rm timezones.h + +gofmt -s -w ${DST_FILE} diff --git a/datetime/interval.go b/datetime/interval.go new file mode 100644 index 000000000..e6d39e4d8 --- /dev/null +++ b/datetime/interval.go @@ -0,0 +1,235 @@ +package datetime + +import ( + "bytes" + "reflect" + + "github.com/vmihailenco/msgpack/v5" +) + +const interval_extId = 6 + +const ( + fieldYear = 0 + fieldMonth = 1 + fieldWeek = 2 + fieldDay = 3 + fieldHour = 4 + fieldMin = 5 + fieldSec = 6 + fieldNSec = 7 + fieldAdjust = 8 +) + +// Interval type is GoLang implementation of Tarantool intervals. +// +//go:generate go tool gentypes -ext-code 6 Interval +type Interval struct { + Year int64 + Month int64 + Week int64 + Day int64 + Hour int64 + Min int64 + Sec int64 + Nsec int64 + Adjust Adjust +} + +func (ival Interval) countNonZeroFields() int { + count := 0 + + for _, field := range []int64{ + ival.Year, ival.Month, ival.Week, ival.Day, ival.Hour, + ival.Min, ival.Sec, ival.Nsec, adjustToDt[ival.Adjust], + } { + if field != 0 { + count++ + } + } + + return count +} + +// We use int64 for every field to avoid changes in the future, see: +// https://github.com/tarantool/tarantool/blob/943ce3caf8401510ced4f074bca7006c3d73f9b3/src/lib/core/datetime.h#L106 + +// Add creates a new Interval as addition of intervals. +func (ival Interval) Add(add Interval) Interval { + ival.Year += add.Year + ival.Month += add.Month + ival.Week += add.Week + ival.Day += add.Day + ival.Hour += add.Hour + ival.Min += add.Min + ival.Sec += add.Sec + ival.Nsec += add.Nsec + + return ival +} + +// Sub creates a new Interval as subtraction of intervals. +func (ival Interval) Sub(sub Interval) Interval { + ival.Year -= sub.Year + ival.Month -= sub.Month + ival.Week -= sub.Week + ival.Day -= sub.Day + ival.Hour -= sub.Hour + ival.Min -= sub.Min + ival.Sec -= sub.Sec + ival.Nsec -= sub.Nsec + + return ival +} + +// MarshalMsgpack implements a custom msgpack marshaler. +func (ival Interval) MarshalMsgpack() ([]byte, error) { + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + + if err := ival.MarshalMsgpackTo(enc); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// MarshalMsgpackTo implements a custom msgpack marshaler. +func (ival Interval) MarshalMsgpackTo(e *msgpack.Encoder) error { + var fieldNum = uint64(ival.countNonZeroFields()) + if err := e.EncodeUint(fieldNum); err != nil { + return err + } + + if err := encodeIntervalValue(e, fieldYear, ival.Year); err != nil { + return err + } + if err := encodeIntervalValue(e, fieldMonth, ival.Month); err != nil { + return err + } + if err := encodeIntervalValue(e, fieldWeek, ival.Week); err != nil { + return err + } + if err := encodeIntervalValue(e, fieldDay, ival.Day); err != nil { + return err + } + if err := encodeIntervalValue(e, fieldHour, ival.Hour); err != nil { + return err + } + if err := encodeIntervalValue(e, fieldMin, ival.Min); err != nil { + return err + } + if err := encodeIntervalValue(e, fieldSec, ival.Sec); err != nil { + return err + } + if err := encodeIntervalValue(e, fieldNSec, ival.Nsec); err != nil { + return err + } + if err := encodeIntervalValue(e, fieldAdjust, adjustToDt[ival.Adjust]); err != nil { + return err + } + + return nil +} + +// UnmarshalMsgpackFrom implements a custom msgpack unmarshaler. +func (ival *Interval) UnmarshalMsgpackFrom(d *msgpack.Decoder) error { + fieldNum, err := d.DecodeUint() + if err != nil { + return err + } + + ival.Adjust = dtToAdjust[int64(NoneAdjust)] + + for i := 0; i < int(fieldNum); i++ { + var fieldType uint + if fieldType, err = d.DecodeUint(); err != nil { + return err + } + + var fieldVal int64 + if fieldVal, err = d.DecodeInt64(); err != nil { + return err + } + + switch fieldType { + case fieldYear: + ival.Year = fieldVal + case fieldMonth: + ival.Month = fieldVal + case fieldWeek: + ival.Week = fieldVal + case fieldDay: + ival.Day = fieldVal + case fieldHour: + ival.Hour = fieldVal + case fieldMin: + ival.Min = fieldVal + case fieldSec: + ival.Sec = fieldVal + case fieldNSec: + ival.Nsec = fieldVal + case fieldAdjust: + ival.Adjust = dtToAdjust[fieldVal] + } + } + + return nil +} + +// UnmarshalMsgpack implements a custom msgpack unmarshaler. +func (ival *Interval) UnmarshalMsgpack(data []byte) error { + dec := msgpack.NewDecoder(bytes.NewReader(data)) + return ival.UnmarshalMsgpackFrom(dec) +} + +func encodeIntervalValue(e *msgpack.Encoder, typ uint64, value int64) error { + if value == 0 { + return nil + } + + err := e.EncodeUint(typ) + if err != nil { + return err + } + + switch { + case value > 0: + return e.EncodeUint(uint64(value)) + default: + return e.EncodeInt(value) + } +} + +func encodeInterval(e *msgpack.Encoder, v reflect.Value) (err error) { + val := v.Interface().(Interval) + return val.MarshalMsgpackTo(e) +} + +func decodeInterval(d *msgpack.Decoder, v reflect.Value) (err error) { + val := Interval{} + if err = val.UnmarshalMsgpackFrom(d); err != nil { + return + } + + v.Set(reflect.ValueOf(val)) + return nil +} + +func init() { + msgpack.RegisterExtEncoder(interval_extId, Interval{}, + func(e *msgpack.Encoder, v reflect.Value) (ret []byte, err error) { + var b bytes.Buffer + + enc := msgpack.NewEncoder(&b) + if err = encodeInterval(enc, v); err == nil { + ret = b.Bytes() + } + + return + }) + msgpack.RegisterExtDecoder(interval_extId, Interval{}, + func(d *msgpack.Decoder, v reflect.Value, extLen int) error { + return decodeInterval(d, v) + }) +} diff --git a/datetime/interval_gen.go b/datetime/interval_gen.go new file mode 100644 index 000000000..2cccaaca0 --- /dev/null +++ b/datetime/interval_gen.go @@ -0,0 +1,241 @@ +// Code generated by github.com/tarantool/go-option; DO NOT EDIT. + +package datetime + +import ( + "fmt" + + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" + + "github.com/tarantool/go-option" +) + +// OptionalInterval represents an optional value of type Interval. +// It can either hold a valid Interval (IsSome == true) or be empty (IsZero == true). +type OptionalInterval struct { + value Interval + exists bool +} + +// SomeOptionalInterval creates an optional OptionalInterval with the given Interval value. +// The returned OptionalInterval will have IsSome() == true and IsZero() == false. +func SomeOptionalInterval(value Interval) OptionalInterval { + return OptionalInterval{ + value: value, + exists: true, + } +} + +// NoneOptionalInterval creates an empty optional OptionalInterval value. +// The returned OptionalInterval will have IsSome() == false and IsZero() == true. +// +// Example: +// +// o := NoneOptionalInterval() +// if o.IsZero() { +// fmt.Println("value is absent") +// } +func NoneOptionalInterval() OptionalInterval { + return OptionalInterval{} +} + +func (o OptionalInterval) newEncodeError(err error) error { + if err == nil { + return nil + } + return &option.EncodeError{ + Type: "OptionalInterval", + Parent: err, + } +} + +func (o OptionalInterval) newDecodeError(err error) error { + if err == nil { + return nil + } + + return &option.DecodeError{ + Type: "OptionalInterval", + Parent: err, + } +} + +// IsSome returns true if the OptionalInterval contains a value. +// This indicates the value is explicitly set (not None). +func (o OptionalInterval) IsSome() bool { + return o.exists +} + +// IsZero returns true if the OptionalInterval does not contain a value. +// Equivalent to !IsSome(). Useful for consistency with types where +// zero value (e.g. 0, false, zero struct) is valid and needs to be distinguished. +func (o OptionalInterval) IsZero() bool { + return !o.exists +} + +// IsNil is an alias for IsZero. +// +// This method is provided for compatibility with the msgpack Encoder interface. +func (o OptionalInterval) IsNil() bool { + return o.IsZero() +} + +// Get returns the stored value and a boolean flag indicating its presence. +// If the value is present, returns (value, true). +// If the value is absent, returns (zero value of Interval, false). +// +// Recommended usage: +// +// if value, ok := o.Get(); ok { +// // use value +// } +func (o OptionalInterval) Get() (Interval, bool) { + return o.value, o.exists +} + +// MustGet returns the stored value if it is present. +// Panics if the value is absent (i.e., IsZero() == true). +// +// Use with caution — only when you are certain the value exists. +// +// Panics with: "optional value is not set" if no value is set. +func (o OptionalInterval) MustGet() Interval { + if !o.exists { + panic("optional value is not set") + } + + return o.value +} + +// Unwrap returns the stored value regardless of presence. +// If no value is set, returns the zero value for Interval. +// +// Warning: Does not check presence. Use IsSome() before calling if you need +// to distinguish between absent value and explicit zero value. +func (o OptionalInterval) Unwrap() Interval { + return o.value +} + +// UnwrapOr returns the stored value if present. +// Otherwise, returns the provided default value. +// +// Example: +// +// o := NoneOptionalInterval() +// v := o.UnwrapOr(someDefaultOptionalInterval) +func (o OptionalInterval) UnwrapOr(defaultValue Interval) Interval { + if o.exists { + return o.value + } + + return defaultValue +} + +// UnwrapOrElse returns the stored value if present. +// Otherwise, calls the provided function and returns its result. +// Useful when the default value requires computation or side effects. +// +// Example: +// +// o := NoneOptionalInterval() +// v := o.UnwrapOrElse(func() Interval { return computeDefault() }) +func (o OptionalInterval) UnwrapOrElse(defaultValue func() Interval) Interval { + if o.exists { + return o.value + } + + return defaultValue() +} + +func (o OptionalInterval) encodeValue(encoder *msgpack.Encoder) error { + value, err := o.value.MarshalMsgpack() + if err != nil { + return err + } + + err = encoder.EncodeExtHeader(6, len(value)) + if err != nil { + return err + } + + _, err = encoder.Writer().Write(value) + if err != nil { + return err + } + + return nil +} + +// EncodeMsgpack encodes the OptionalInterval value using MessagePack format. +// - If the value is present, it is encoded as Interval. +// - If the value is absent (None), it is encoded as nil. +// +// Returns an error if encoding fails. +func (o OptionalInterval) EncodeMsgpack(encoder *msgpack.Encoder) error { + if o.exists { + return o.newEncodeError(o.encodeValue(encoder)) + } + + return o.newEncodeError(encoder.EncodeNil()) +} + +func (o *OptionalInterval) decodeValue(decoder *msgpack.Decoder) error { + tp, length, err := decoder.DecodeExtHeader() + switch { + case err != nil: + return o.newDecodeError(err) + case tp != 6: + return o.newDecodeError(fmt.Errorf("invalid extension code: %d", tp)) + } + + a := make([]byte, length) + if err := decoder.ReadFull(a); err != nil { + return o.newDecodeError(err) + } + + if err := o.value.UnmarshalMsgpack(a); err != nil { + return o.newDecodeError(err) + } + + o.exists = true + return nil +} + +func (o *OptionalInterval) checkCode(code byte) bool { + return msgpcode.IsExt(code) +} + +// DecodeMsgpack decodes a OptionalInterval value from MessagePack format. +// Supports two input types: +// - nil: interpreted as no value (NoneOptionalInterval) +// - Interval: interpreted as a present value (SomeOptionalInterval) +// +// Returns an error if the input type is unsupported or decoding fails. +// +// After successful decoding: +// - on nil: exists = false, value = default zero value +// - on Interval: exists = true, value = decoded value +func (o *OptionalInterval) DecodeMsgpack(decoder *msgpack.Decoder) error { + code, err := decoder.PeekCode() + if err != nil { + return o.newDecodeError(err) + } + + switch { + case code == msgpcode.Nil: + o.exists = false + + return o.newDecodeError(decoder.Skip()) + case o.checkCode(code): + err := o.decodeValue(decoder) + if err != nil { + return o.newDecodeError(err) + } + o.exists = true + + return err + default: + return o.newDecodeError(fmt.Errorf("unexpected code: %d", code)) + } +} diff --git a/datetime/interval_gen_test.go b/datetime/interval_gen_test.go new file mode 100644 index 000000000..162db3336 --- /dev/null +++ b/datetime/interval_gen_test.go @@ -0,0 +1,116 @@ +package datetime + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/vmihailenco/msgpack/v5" +) + +func TestSomeOptionalInterval(t *testing.T) { + val := Interval{Year: 1} + opt := SomeOptionalInterval(val) + + assert.True(t, opt.IsSome()) + assert.False(t, opt.IsZero()) + + v, ok := opt.Get() + assert.True(t, ok) + assert.Equal(t, val, v) +} + +func TestNoneOptionalInterval(t *testing.T) { + opt := NoneOptionalInterval() + + assert.False(t, opt.IsSome()) + assert.True(t, opt.IsZero()) + + _, ok := opt.Get() + assert.False(t, ok) +} + +func TestOptionalInterval_MustGet(t *testing.T) { + val := Interval{Year: 1} + optSome := SomeOptionalInterval(val) + optNone := NoneOptionalInterval() + + assert.Equal(t, val, optSome.MustGet()) + assert.Panics(t, func() { optNone.MustGet() }) +} + +func TestOptionalInterval_Unwrap(t *testing.T) { + val := Interval{Year: 1} + optSome := SomeOptionalInterval(val) + optNone := NoneOptionalInterval() + + assert.Equal(t, val, optSome.Unwrap()) + assert.Equal(t, Interval{}, optNone.Unwrap()) +} + +func TestOptionalInterval_UnwrapOr(t *testing.T) { + val := Interval{Year: 1} + def := Interval{Year: 2} + optSome := SomeOptionalInterval(val) + optNone := NoneOptionalInterval() + + assert.Equal(t, val, optSome.UnwrapOr(def)) + assert.Equal(t, def, optNone.UnwrapOr(def)) +} + +func TestOptionalInterval_UnwrapOrElse(t *testing.T) { + val := Interval{Year: 1} + def := Interval{Year: 2} + optSome := SomeOptionalInterval(val) + optNone := NoneOptionalInterval() + + assert.Equal(t, val, optSome.UnwrapOrElse(func() Interval { return def })) + assert.Equal(t, def, optNone.UnwrapOrElse(func() Interval { return def })) +} + +func TestOptionalInterval_EncodeDecodeMsgpack_Some(t *testing.T) { + val := Interval{Year: 1} + some := SomeOptionalInterval(val) + + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err := enc.Encode(some) + assert.NoError(t, err) + + var decodedSome OptionalInterval + err = dec.Decode(&decodedSome) + assert.NoError(t, err) + assert.True(t, decodedSome.IsSome()) + assert.Equal(t, val, decodedSome.Unwrap()) +} + +func TestOptionalInterval_EncodeDecodeMsgpack_None(t *testing.T) { + none := NoneOptionalInterval() + + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err := enc.Encode(none) + assert.NoError(t, err) + + var decodedNone OptionalInterval + err = dec.Decode(&decodedNone) + assert.NoError(t, err) + assert.True(t, decodedNone.IsZero()) +} + +func TestOptionalInterval_EncodeDecodeMsgpack_InvalidType(t *testing.T) { + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err := enc.Encode(123) + assert.NoError(t, err) + + var decodedInvalid OptionalInterval + err = dec.Decode(&decodedInvalid) + assert.Error(t, err) +} diff --git a/datetime/interval_test.go b/datetime/interval_test.go new file mode 100644 index 000000000..2f4bb8a66 --- /dev/null +++ b/datetime/interval_test.go @@ -0,0 +1,135 @@ +package datetime_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/tarantool/go-tarantool/v3" + . "github.com/tarantool/go-tarantool/v3/datetime" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +func TestIntervalAdd(t *testing.T) { + orig := Interval{ + Year: 1, + Month: 2, + Week: 3, + Day: 4, + Hour: -5, + Min: 6, + Sec: -7, + Nsec: 8, + Adjust: LastAdjust, + } + cpyOrig := orig + add := Interval{ + Year: 2, + Month: 3, + Week: -4, + Day: 5, + Hour: -6, + Min: 7, + Sec: -8, + Nsec: 0, + Adjust: ExcessAdjust, + } + expected := Interval{ + Year: orig.Year + add.Year, + Month: orig.Month + add.Month, + Week: orig.Week + add.Week, + Day: orig.Day + add.Day, + Hour: orig.Hour + add.Hour, + Min: orig.Min + add.Min, + Sec: orig.Sec + add.Sec, + Nsec: orig.Nsec + add.Nsec, + Adjust: orig.Adjust, + } + + ival := orig.Add(add) + + if !reflect.DeepEqual(ival, expected) { + t.Fatalf("Unexpected %v, expected %v", ival, expected) + } + if !reflect.DeepEqual(cpyOrig, orig) { + t.Fatalf("Original value changed %v, expected %v", orig, cpyOrig) + } +} + +func TestIntervalSub(t *testing.T) { + orig := Interval{ + Year: 1, + Month: 2, + Week: 3, + Day: 4, + Hour: -5, + Min: 6, + Sec: -7, + Nsec: 8, + Adjust: LastAdjust, + } + cpyOrig := orig + sub := Interval{ + Year: 2, + Month: 3, + Week: -4, + Day: 5, + Hour: -6, + Min: 7, + Sec: -8, + Nsec: 0, + Adjust: ExcessAdjust, + } + expected := Interval{ + Year: orig.Year - sub.Year, + Month: orig.Month - sub.Month, + Week: orig.Week - sub.Week, + Day: orig.Day - sub.Day, + Hour: orig.Hour - sub.Hour, + Min: orig.Min - sub.Min, + Sec: orig.Sec - sub.Sec, + Nsec: orig.Nsec - sub.Nsec, + Adjust: orig.Adjust, + } + + ival := orig.Sub(sub) + + if !reflect.DeepEqual(ival, expected) { + t.Fatalf("Unexpected %v, expected %v", ival, expected) + } + if !reflect.DeepEqual(cpyOrig, orig) { + t.Fatalf("Original value changed %v, expected %v", orig, cpyOrig) + } +} + +func TestIntervalTarantoolEncoding(t *testing.T) { + skipIfDatetimeUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + cases := []Interval{ + {}, + {1, 2, 3, 4, -5, 6, -7, 8, LastAdjust}, + {1, 2, 3, 4, -5, 6, -7, 8, ExcessAdjust}, + {1, 2, 3, 4, -5, 6, -7, 8, LastAdjust}, + {0, 2, 3, 4, -5, 0, -7, 8, LastAdjust}, + {0, 0, 3, 0, -5, 6, -7, 8, ExcessAdjust}, + {0, 0, 0, 4, 0, 0, 0, 8, LastAdjust}, + } + for _, tc := range cases { + t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) { + req := tarantool.NewCallRequest("call_interval_testdata"). + Args([]interface{}{tc}) + data, err := conn.Do(req).Get() + if err != nil { + t.Fatalf("Unexpected error: %s", err.Error()) + } + + ret := data[0].(Interval) + if !reflect.DeepEqual(ret, tc) { + t.Fatalf("Unexpected response: %v, expected %v", ret, tc) + } + }) + } +} diff --git a/datetime/timezones.go b/datetime/timezones.go new file mode 100644 index 000000000..ca4234fd1 --- /dev/null +++ b/datetime/timezones.go @@ -0,0 +1,1484 @@ +package datetime + +/* Automatically generated by gen-timezones.sh */ + +var indexToTimezone = map[int]string{ + 1: "A", + 2: "B", + 3: "C", + 4: "D", + 5: "E", + 6: "F", + 7: "G", + 8: "H", + 9: "I", + 10: "K", + 11: "L", + 12: "M", + 13: "N", + 14: "O", + 15: "P", + 16: "Q", + 17: "R", + 18: "S", + 19: "T", + 20: "U", + 21: "V", + 22: "W", + 23: "X", + 24: "Y", + 25: "Z", + 32: "AT", + 40: "BT", + 48: "CT", + 56: "ET", + 64: "GT", + 72: "IT", + 80: "KT", + 88: "MT", + 96: "PT", + 104: "ST", + 112: "UT", + 120: "WT", + 128: "ACT", + 129: "ADT", + 130: "AET", + 131: "AFT", + 132: "AMT", + 133: "AoE", + 134: "ART", + 135: "AST", + 136: "AZT", + 144: "BDT", + 145: "BNT", + 146: "BOT", + 147: "BRT", + 148: "BST", + 149: "BTT", + 152: "CAT", + 153: "CCT", + 154: "CDT", + 155: "CET", + 156: "CIT", + 157: "CKT", + 158: "CLT", + 159: "COT", + 160: "CST", + 161: "CVT", + 162: "CXT", + 168: "EAT", + 169: "ECT", + 170: "EDT", + 171: "EET", + 172: "EGT", + 173: "EST", + 176: "FET", + 177: "FJT", + 178: "FKT", + 179: "FNT", + 184: "GET", + 185: "GFT", + 186: "GMT", + 187: "GST", + 188: "GYT", + 192: "HAA", + 193: "HAC", + 194: "HAE", + 195: "HAP", + 196: "HAR", + 197: "HAT", + 198: "HDT", + 199: "HKT", + 200: "HLV", + 201: "HNA", + 202: "HNC", + 203: "HNE", + 204: "HNP", + 205: "HNR", + 206: "HNT", + 207: "HST", + 208: "ICT", + 209: "IDT", + 210: "IOT", + 211: "IST", + 216: "JST", + 224: "KGT", + 225: "KIT", + 226: "KST", + 232: "MCK", + 233: "MDT", + 234: "MEZ", + 235: "MHT", + 236: "MMT", + 237: "MSD", + 238: "MSK", + 239: "MST", + 240: "MUT", + 241: "MVT", + 242: "MYT", + 248: "NCT", + 249: "NDT", + 250: "NFT", + 251: "NPT", + 252: "NRT", + 253: "NST", + 254: "NUT", + 256: "OEZ", + 264: "PDT", + 265: "PET", + 266: "PGT", + 267: "PHT", + 268: "PKT", + 269: "PST", + 270: "PWT", + 271: "PYT", + 272: "RET", + 280: "SBT", + 281: "SCT", + 282: "SGT", + 283: "SRT", + 284: "SST", + 288: "TFT", + 289: "TJT", + 290: "TKT", + 291: "TLT", + 292: "TMT", + 293: "TOT", + 294: "TRT", + 295: "TVT", + 296: "UTC", + 297: "UYT", + 298: "UZT", + 304: "VET", + 305: "VUT", + 312: "WAT", + 313: "WDT", + 314: "WET", + 315: "WEZ", + 316: "WFT", + 317: "WGT", + 318: "WIB", + 319: "WIT", + 320: "WST", + 328: "ACDT", + 329: "ACST", + 330: "ADST", + 331: "AEDT", + 332: "AEST", + 333: "AKDT", + 334: "AKST", + 335: "ALMT", + 336: "AMDT", + 337: "AMST", + 338: "ANAT", + 339: "AQTT", + 340: "AWDT", + 341: "AWST", + 342: "AZOT", + 343: "AZST", + 344: "BDST", + 345: "BRST", + 352: "CAST", + 353: "CDST", + 354: "CEDT", + 355: "CEST", + 356: "CHOT", + 357: "ChST", + 358: "CHUT", + 359: "CIST", + 360: "CLDT", + 361: "CLST", + 368: "DAVT", + 369: "DDUT", + 376: "EADT", + 377: "EAST", + 378: "ECST", + 379: "EDST", + 380: "EEDT", + 381: "EEST", + 382: "EGST", + 384: "FJDT", + 385: "FJST", + 386: "FKDT", + 387: "FKST", + 392: "GALT", + 393: "GAMT", + 394: "GILT", + 400: "HADT", + 401: "HAST", + 402: "HOVT", + 408: "IRDT", + 409: "IRKT", + 410: "IRST", + 416: "KOST", + 417: "KRAT", + 418: "KUYT", + 424: "LHDT", + 425: "LHST", + 426: "LINT", + 432: "MAGT", + 433: "MART", + 434: "MAWT", + 435: "MDST", + 436: "MESZ", + 440: "NFDT", + 441: "NOVT", + 442: "NZDT", + 443: "NZST", + 448: "OESZ", + 449: "OMST", + 450: "ORAT", + 456: "PDST", + 457: "PETT", + 458: "PHOT", + 459: "PMDT", + 460: "PMST", + 461: "PONT", + 462: "PYST", + 464: "QYZT", + 472: "ROTT", + 480: "SAKT", + 481: "SAMT", + 482: "SAST", + 483: "SRET", + 484: "SYOT", + 488: "TAHT", + 489: "TOST", + 496: "ULAT", + 497: "UYST", + 504: "VLAT", + 505: "VOST", + 512: "WAKT", + 513: "WAST", + 514: "WEDT", + 515: "WEST", + 516: "WESZ", + 517: "WGST", + 518: "WITA", + 520: "YAKT", + 521: "YAPT", + 522: "YEKT", + 528: "ACWST", + 529: "ANAST", + 530: "AZODT", + 531: "AZOST", + 536: "CHADT", + 537: "CHAST", + 538: "CHODT", + 539: "CHOST", + 540: "CIDST", + 544: "EASST", + 545: "EFATE", + 552: "HOVDT", + 553: "HOVST", + 560: "IRKST", + 568: "KRAST", + 576: "MAGST", + 584: "NACDT", + 585: "NACST", + 586: "NAEDT", + 587: "NAEST", + 588: "NAMDT", + 589: "NAMST", + 590: "NAPDT", + 591: "NAPST", + 592: "NOVST", + 600: "OMSST", + 608: "PETST", + 616: "SAMST", + 624: "ULAST", + 632: "VLAST", + 640: "WARST", + 648: "YAKST", + 649: "YEKST", + 656: "CHODST", + 664: "HOVDST", + 672: "Africa/Abidjan", + 673: "Africa/Algiers", + 674: "Africa/Bissau", + 675: "Africa/Cairo", + 676: "Africa/Casablanca", + 677: "Africa/Ceuta", + 678: "Africa/El_Aaiun", + 679: "Africa/Johannesburg", + 680: "Africa/Juba", + 681: "Africa/Khartoum", + 682: "Africa/Lagos", + 683: "Africa/Maputo", + 684: "Africa/Monrovia", + 685: "Africa/Nairobi", + 686: "Africa/Ndjamena", + 687: "Africa/Sao_Tome", + 688: "Africa/Tripoli", + 689: "Africa/Tunis", + 690: "Africa/Windhoek", + 691: "America/Adak", + 692: "America/Anchorage", + 693: "America/Araguaina", + 694: "America/Argentina/Buenos_Aires", + 695: "America/Argentina/Catamarca", + 696: "America/Argentina/Cordoba", + 697: "America/Argentina/Jujuy", + 698: "America/Argentina/La_Rioja", + 699: "America/Argentina/Mendoza", + 700: "America/Argentina/Rio_Gallegos", + 701: "America/Argentina/Salta", + 702: "America/Argentina/San_Juan", + 703: "America/Argentina/San_Luis", + 704: "America/Argentina/Tucuman", + 705: "America/Argentina/Ushuaia", + 706: "America/Asuncion", + 707: "America/Bahia", + 708: "America/Bahia_Banderas", + 709: "America/Barbados", + 710: "America/Belem", + 711: "America/Belize", + 712: "America/Boa_Vista", + 713: "America/Bogota", + 714: "America/Boise", + 715: "America/Cambridge_Bay", + 716: "America/Campo_Grande", + 717: "America/Cancun", + 718: "America/Caracas", + 719: "America/Cayenne", + 720: "America/Chicago", + 721: "America/Chihuahua", + 722: "America/Costa_Rica", + 723: "America/Cuiaba", + 724: "America/Danmarkshavn", + 725: "America/Dawson", + 726: "America/Dawson_Creek", + 727: "America/Denver", + 728: "America/Detroit", + 729: "America/Edmonton", + 730: "America/Eirunepe", + 731: "America/El_Salvador", + 732: "America/Fort_Nelson", + 733: "America/Fortaleza", + 734: "America/Glace_Bay", + 735: "America/Goose_Bay", + 736: "America/Grand_Turk", + 737: "America/Guatemala", + 738: "America/Guayaquil", + 739: "America/Guyana", + 740: "America/Halifax", + 741: "America/Havana", + 742: "America/Hermosillo", + 743: "America/Indiana/Indianapolis", + 744: "America/Indiana/Knox", + 745: "America/Indiana/Marengo", + 746: "America/Indiana/Petersburg", + 747: "America/Indiana/Tell_City", + 748: "America/Indiana/Vevay", + 749: "America/Indiana/Vincennes", + 750: "America/Indiana/Winamac", + 751: "America/Inuvik", + 752: "America/Iqaluit", + 753: "America/Jamaica", + 754: "America/Juneau", + 755: "America/Kentucky/Louisville", + 756: "America/Kentucky/Monticello", + 757: "America/La_Paz", + 758: "America/Lima", + 759: "America/Los_Angeles", + 760: "America/Maceio", + 761: "America/Managua", + 762: "America/Manaus", + 763: "America/Martinique", + 764: "America/Matamoros", + 765: "America/Mazatlan", + 766: "America/Menominee", + 767: "America/Merida", + 768: "America/Metlakatla", + 769: "America/Mexico_City", + 770: "America/Miquelon", + 771: "America/Moncton", + 772: "America/Monterrey", + 773: "America/Montevideo", + 774: "America/New_York", + 775: "America/Nipigon", + 776: "America/Nome", + 777: "America/Noronha", + 778: "America/North_Dakota/Beulah", + 779: "America/North_Dakota/Center", + 780: "America/North_Dakota/New_Salem", + 781: "America/Nuuk", + 782: "America/Ojinaga", + 783: "America/Panama", + 784: "America/Pangnirtung", + 785: "America/Paramaribo", + 786: "America/Phoenix", + 787: "America/Port-au-Prince", + 788: "America/Porto_Velho", + 789: "America/Puerto_Rico", + 790: "America/Punta_Arenas", + 791: "America/Rainy_River", + 792: "America/Rankin_Inlet", + 793: "America/Recife", + 794: "America/Regina", + 795: "America/Resolute", + 796: "America/Rio_Branco", + 797: "America/Santarem", + 798: "America/Santiago", + 799: "America/Santo_Domingo", + 800: "America/Sao_Paulo", + 801: "America/Scoresbysund", + 802: "America/Sitka", + 803: "America/St_Johns", + 804: "America/Swift_Current", + 805: "America/Tegucigalpa", + 806: "America/Thule", + 807: "America/Thunder_Bay", + 808: "America/Tijuana", + 809: "America/Toronto", + 810: "America/Vancouver", + 811: "America/Whitehorse", + 812: "America/Winnipeg", + 813: "America/Yakutat", + 814: "America/Yellowknife", + 815: "Antarctica/Casey", + 816: "Antarctica/Davis", + 817: "Antarctica/Macquarie", + 818: "Antarctica/Mawson", + 819: "Antarctica/Palmer", + 820: "Antarctica/Rothera", + 821: "Antarctica/Troll", + 822: "Antarctica/Vostok", + 823: "Asia/Almaty", + 824: "Asia/Amman", + 825: "Asia/Anadyr", + 826: "Asia/Aqtau", + 827: "Asia/Aqtobe", + 828: "Asia/Ashgabat", + 829: "Asia/Atyrau", + 830: "Asia/Baghdad", + 831: "Asia/Baku", + 832: "Asia/Bangkok", + 833: "Asia/Barnaul", + 834: "Asia/Beirut", + 835: "Asia/Bishkek", + 836: "Asia/Brunei", + 837: "Asia/Chita", + 838: "Asia/Choibalsan", + 839: "Asia/Colombo", + 840: "Asia/Damascus", + 841: "Asia/Dhaka", + 842: "Asia/Dili", + 843: "Asia/Dubai", + 844: "Asia/Dushanbe", + 845: "Asia/Famagusta", + 846: "Asia/Gaza", + 847: "Asia/Hebron", + 848: "Asia/Ho_Chi_Minh", + 849: "Asia/Hong_Kong", + 850: "Asia/Hovd", + 851: "Asia/Irkutsk", + 852: "Asia/Jakarta", + 853: "Asia/Jayapura", + 854: "Asia/Jerusalem", + 855: "Asia/Kabul", + 856: "Asia/Kamchatka", + 857: "Asia/Karachi", + 858: "Asia/Kathmandu", + 859: "Asia/Khandyga", + 860: "Asia/Kolkata", + 861: "Asia/Krasnoyarsk", + 862: "Asia/Kuala_Lumpur", + 863: "Asia/Kuching", + 864: "Asia/Macau", + 865: "Asia/Magadan", + 866: "Asia/Makassar", + 867: "Asia/Manila", + 868: "Asia/Nicosia", + 869: "Asia/Novokuznetsk", + 870: "Asia/Novosibirsk", + 871: "Asia/Omsk", + 872: "Asia/Oral", + 873: "Asia/Pontianak", + 874: "Asia/Pyongyang", + 875: "Asia/Qatar", + 876: "Asia/Qostanay", + 877: "Asia/Qyzylorda", + 878: "Asia/Riyadh", + 879: "Asia/Sakhalin", + 880: "Asia/Samarkand", + 881: "Asia/Seoul", + 882: "Asia/Shanghai", + 883: "Asia/Singapore", + 884: "Asia/Srednekolymsk", + 885: "Asia/Taipei", + 886: "Asia/Tashkent", + 887: "Asia/Tbilisi", + 888: "Asia/Tehran", + 889: "Asia/Thimphu", + 890: "Asia/Tokyo", + 891: "Asia/Tomsk", + 892: "Asia/Ulaanbaatar", + 893: "Asia/Urumqi", + 894: "Asia/Ust-Nera", + 895: "Asia/Vladivostok", + 896: "Asia/Yakutsk", + 897: "Asia/Yangon", + 898: "Asia/Yekaterinburg", + 899: "Asia/Yerevan", + 900: "Atlantic/Azores", + 901: "Atlantic/Bermuda", + 902: "Atlantic/Canary", + 903: "Atlantic/Cape_Verde", + 904: "Atlantic/Faroe", + 905: "Atlantic/Madeira", + 906: "Atlantic/Reykjavik", + 907: "Atlantic/South_Georgia", + 908: "Atlantic/Stanley", + 909: "Australia/Adelaide", + 910: "Australia/Brisbane", + 911: "Australia/Broken_Hill", + 912: "Australia/Darwin", + 913: "Australia/Eucla", + 914: "Australia/Hobart", + 915: "Australia/Lindeman", + 916: "Australia/Lord_Howe", + 917: "Australia/Melbourne", + 918: "Australia/Perth", + 919: "Australia/Sydney", + 920: "Etc/GMT", + 921: "Etc/UTC", + 922: "Europe/Amsterdam", + 923: "Europe/Andorra", + 924: "Europe/Astrakhan", + 925: "Europe/Athens", + 926: "Europe/Belgrade", + 927: "Europe/Berlin", + 928: "Europe/Brussels", + 929: "Europe/Bucharest", + 930: "Europe/Budapest", + 931: "Europe/Chisinau", + 932: "Europe/Copenhagen", + 933: "Europe/Dublin", + 934: "Europe/Gibraltar", + 935: "Europe/Helsinki", + 936: "Europe/Istanbul", + 937: "Europe/Kaliningrad", + 938: "Europe/Kiev", + 939: "Europe/Kirov", + 940: "Europe/Lisbon", + 941: "Europe/London", + 942: "Europe/Luxembourg", + 943: "Europe/Madrid", + 944: "Europe/Malta", + 945: "Europe/Minsk", + 946: "Europe/Monaco", + 947: "Europe/Moscow", + 948: "Europe/Oslo", + 949: "Europe/Paris", + 950: "Europe/Prague", + 951: "Europe/Riga", + 952: "Europe/Rome", + 953: "Europe/Samara", + 954: "Europe/Saratov", + 955: "Europe/Simferopol", + 956: "Europe/Sofia", + 957: "Europe/Stockholm", + 958: "Europe/Tallinn", + 959: "Europe/Tirane", + 960: "Europe/Ulyanovsk", + 961: "Europe/Uzhgorod", + 962: "Europe/Vienna", + 963: "Europe/Vilnius", + 964: "Europe/Volgograd", + 965: "Europe/Warsaw", + 966: "Europe/Zaporozhye", + 967: "Europe/Zurich", + 968: "Indian/Chagos", + 969: "Indian/Christmas", + 970: "Indian/Cocos", + 971: "Indian/Kerguelen", + 972: "Indian/Mahe", + 973: "Indian/Maldives", + 974: "Indian/Mauritius", + 975: "Indian/Reunion", + 976: "Pacific/Apia", + 977: "Pacific/Auckland", + 978: "Pacific/Bougainville", + 979: "Pacific/Chatham", + 980: "Pacific/Chuuk", + 981: "Pacific/Easter", + 982: "Pacific/Efate", + 983: "Pacific/Fakaofo", + 984: "Pacific/Fiji", + 985: "Pacific/Funafuti", + 986: "Pacific/Galapagos", + 987: "Pacific/Gambier", + 988: "Pacific/Guadalcanal", + 989: "Pacific/Guam", + 990: "Pacific/Honolulu", + 991: "Pacific/Kanton", + 992: "Pacific/Kiritimati", + 993: "Pacific/Kosrae", + 994: "Pacific/Kwajalein", + 995: "Pacific/Majuro", + 996: "Pacific/Marquesas", + 997: "Pacific/Nauru", + 998: "Pacific/Niue", + 999: "Pacific/Norfolk", + 1000: "Pacific/Noumea", + 1001: "Pacific/Pago_Pago", + 1002: "Pacific/Palau", + 1003: "Pacific/Pitcairn", + 1004: "Pacific/Pohnpei", + 1005: "Pacific/Port_Moresby", + 1006: "Pacific/Rarotonga", + 1007: "Pacific/Tahiti", + 1008: "Pacific/Tarawa", + 1009: "Pacific/Tongatapu", + 1010: "Pacific/Wake", + 1011: "Pacific/Wallis", +} + +var timezoneToIndex = map[string]int{ + "A": 1, + "B": 2, + "C": 3, + "D": 4, + "E": 5, + "F": 6, + "G": 7, + "H": 8, + "I": 9, + "K": 10, + "L": 11, + "M": 12, + "N": 13, + "O": 14, + "P": 15, + "Q": 16, + "R": 17, + "S": 18, + "T": 19, + "U": 20, + "V": 21, + "W": 22, + "X": 23, + "Y": 24, + "Z": 25, + "AT": 32, + "BT": 40, + "CT": 48, + "ET": 56, + "GT": 64, + "IT": 72, + "KT": 80, + "MT": 88, + "PT": 96, + "ST": 104, + "UT": 112, + "WT": 120, + "ACT": 128, + "ADT": 129, + "AET": 130, + "AFT": 131, + "AMT": 132, + "AoE": 133, + "ART": 134, + "AST": 135, + "AZT": 136, + "BDT": 144, + "BNT": 145, + "BOT": 146, + "BRT": 147, + "BST": 148, + "BTT": 149, + "CAT": 152, + "CCT": 153, + "CDT": 154, + "CET": 155, + "CIT": 156, + "CKT": 157, + "CLT": 158, + "COT": 159, + "CST": 160, + "CVT": 161, + "CXT": 162, + "EAT": 168, + "ECT": 169, + "EDT": 170, + "EET": 171, + "EGT": 172, + "EST": 173, + "FET": 176, + "FJT": 177, + "FKT": 178, + "FNT": 179, + "GET": 184, + "GFT": 185, + "GMT": 186, + "GST": 187, + "GYT": 188, + "HAA": 192, + "HAC": 193, + "HAE": 194, + "HAP": 195, + "HAR": 196, + "HAT": 197, + "HDT": 198, + "HKT": 199, + "HLV": 200, + "HNA": 201, + "HNC": 202, + "HNE": 203, + "HNP": 204, + "HNR": 205, + "HNT": 206, + "HST": 207, + "ICT": 208, + "IDT": 209, + "IOT": 210, + "IST": 211, + "JST": 216, + "KGT": 224, + "KIT": 225, + "KST": 226, + "MCK": 232, + "MDT": 233, + "MEZ": 234, + "MHT": 235, + "MMT": 236, + "MSD": 237, + "MSK": 238, + "MST": 239, + "MUT": 240, + "MVT": 241, + "MYT": 242, + "NCT": 248, + "NDT": 249, + "NFT": 250, + "NPT": 251, + "NRT": 252, + "NST": 253, + "NUT": 254, + "OEZ": 256, + "PDT": 264, + "PET": 265, + "PGT": 266, + "PHT": 267, + "PKT": 268, + "PST": 269, + "PWT": 270, + "PYT": 271, + "RET": 272, + "SBT": 280, + "SCT": 281, + "SGT": 282, + "SRT": 283, + "SST": 284, + "TFT": 288, + "TJT": 289, + "TKT": 290, + "TLT": 291, + "TMT": 292, + "TOT": 293, + "TRT": 294, + "TVT": 295, + "UTC": 296, + "UYT": 297, + "UZT": 298, + "VET": 304, + "VUT": 305, + "WAT": 312, + "WDT": 313, + "WET": 314, + "WEZ": 315, + "WFT": 316, + "WGT": 317, + "WIB": 318, + "WIT": 319, + "WST": 320, + "ACDT": 328, + "ACST": 329, + "ADST": 330, + "AEDT": 331, + "AEST": 332, + "AKDT": 333, + "AKST": 334, + "ALMT": 335, + "AMDT": 336, + "AMST": 337, + "ANAT": 338, + "AQTT": 339, + "AWDT": 340, + "AWST": 341, + "AZOT": 342, + "AZST": 343, + "BDST": 344, + "BRST": 345, + "CAST": 352, + "CDST": 353, + "CEDT": 354, + "CEST": 355, + "CHOT": 356, + "ChST": 357, + "CHUT": 358, + "CIST": 359, + "CLDT": 360, + "CLST": 361, + "DAVT": 368, + "DDUT": 369, + "EADT": 376, + "EAST": 377, + "ECST": 378, + "EDST": 379, + "EEDT": 380, + "EEST": 381, + "EGST": 382, + "FJDT": 384, + "FJST": 385, + "FKDT": 386, + "FKST": 387, + "GALT": 392, + "GAMT": 393, + "GILT": 394, + "HADT": 400, + "HAST": 401, + "HOVT": 402, + "IRDT": 408, + "IRKT": 409, + "IRST": 410, + "KOST": 416, + "KRAT": 417, + "KUYT": 418, + "LHDT": 424, + "LHST": 425, + "LINT": 426, + "MAGT": 432, + "MART": 433, + "MAWT": 434, + "MDST": 435, + "MESZ": 436, + "NFDT": 440, + "NOVT": 441, + "NZDT": 442, + "NZST": 443, + "OESZ": 448, + "OMST": 449, + "ORAT": 450, + "PDST": 456, + "PETT": 457, + "PHOT": 458, + "PMDT": 459, + "PMST": 460, + "PONT": 461, + "PYST": 462, + "QYZT": 464, + "ROTT": 472, + "SAKT": 480, + "SAMT": 481, + "SAST": 482, + "SRET": 483, + "SYOT": 484, + "TAHT": 488, + "TOST": 489, + "ULAT": 496, + "UYST": 497, + "VLAT": 504, + "VOST": 505, + "WAKT": 512, + "WAST": 513, + "WEDT": 514, + "WEST": 515, + "WESZ": 516, + "WGST": 517, + "WITA": 518, + "YAKT": 520, + "YAPT": 521, + "YEKT": 522, + "ACWST": 528, + "ANAST": 529, + "AZODT": 530, + "AZOST": 531, + "CHADT": 536, + "CHAST": 537, + "CHODT": 538, + "CHOST": 539, + "CIDST": 540, + "EASST": 544, + "EFATE": 545, + "HOVDT": 552, + "HOVST": 553, + "IRKST": 560, + "KRAST": 568, + "MAGST": 576, + "NACDT": 584, + "NACST": 585, + "NAEDT": 586, + "NAEST": 587, + "NAMDT": 588, + "NAMST": 589, + "NAPDT": 590, + "NAPST": 591, + "NOVST": 592, + "OMSST": 600, + "PETST": 608, + "SAMST": 616, + "ULAST": 624, + "VLAST": 632, + "WARST": 640, + "YAKST": 648, + "YEKST": 649, + "CHODST": 656, + "HOVDST": 664, + "Africa/Abidjan": 672, + "Africa/Algiers": 673, + "Africa/Bissau": 674, + "Africa/Cairo": 675, + "Africa/Casablanca": 676, + "Africa/Ceuta": 677, + "Africa/El_Aaiun": 678, + "Africa/Johannesburg": 679, + "Africa/Juba": 680, + "Africa/Khartoum": 681, + "Africa/Lagos": 682, + "Africa/Maputo": 683, + "Africa/Monrovia": 684, + "Africa/Nairobi": 685, + "Africa/Ndjamena": 686, + "Africa/Sao_Tome": 687, + "Africa/Tripoli": 688, + "Africa/Tunis": 689, + "Africa/Windhoek": 690, + "America/Adak": 691, + "America/Anchorage": 692, + "America/Araguaina": 693, + "America/Argentina/Buenos_Aires": 694, + "America/Argentina/Catamarca": 695, + "America/Argentina/Cordoba": 696, + "America/Argentina/Jujuy": 697, + "America/Argentina/La_Rioja": 698, + "America/Argentina/Mendoza": 699, + "America/Argentina/Rio_Gallegos": 700, + "America/Argentina/Salta": 701, + "America/Argentina/San_Juan": 702, + "America/Argentina/San_Luis": 703, + "America/Argentina/Tucuman": 704, + "America/Argentina/Ushuaia": 705, + "America/Asuncion": 706, + "America/Bahia": 707, + "America/Bahia_Banderas": 708, + "America/Barbados": 709, + "America/Belem": 710, + "America/Belize": 711, + "America/Boa_Vista": 712, + "America/Bogota": 713, + "America/Boise": 714, + "America/Cambridge_Bay": 715, + "America/Campo_Grande": 716, + "America/Cancun": 717, + "America/Caracas": 718, + "America/Cayenne": 719, + "America/Chicago": 720, + "America/Chihuahua": 721, + "America/Costa_Rica": 722, + "America/Cuiaba": 723, + "America/Danmarkshavn": 724, + "America/Dawson": 725, + "America/Dawson_Creek": 726, + "America/Denver": 727, + "America/Detroit": 728, + "America/Edmonton": 729, + "America/Eirunepe": 730, + "America/El_Salvador": 731, + "America/Fort_Nelson": 732, + "America/Fortaleza": 733, + "America/Glace_Bay": 734, + "America/Goose_Bay": 735, + "America/Grand_Turk": 736, + "America/Guatemala": 737, + "America/Guayaquil": 738, + "America/Guyana": 739, + "America/Halifax": 740, + "America/Havana": 741, + "America/Hermosillo": 742, + "America/Indiana/Indianapolis": 743, + "America/Indiana/Knox": 744, + "America/Indiana/Marengo": 745, + "America/Indiana/Petersburg": 746, + "America/Indiana/Tell_City": 747, + "America/Indiana/Vevay": 748, + "America/Indiana/Vincennes": 749, + "America/Indiana/Winamac": 750, + "America/Inuvik": 751, + "America/Iqaluit": 752, + "America/Jamaica": 753, + "America/Juneau": 754, + "America/Kentucky/Louisville": 755, + "America/Kentucky/Monticello": 756, + "America/La_Paz": 757, + "America/Lima": 758, + "America/Los_Angeles": 759, + "America/Maceio": 760, + "America/Managua": 761, + "America/Manaus": 762, + "America/Martinique": 763, + "America/Matamoros": 764, + "America/Mazatlan": 765, + "America/Menominee": 766, + "America/Merida": 767, + "America/Metlakatla": 768, + "America/Mexico_City": 769, + "America/Miquelon": 770, + "America/Moncton": 771, + "America/Monterrey": 772, + "America/Montevideo": 773, + "America/New_York": 774, + "America/Nipigon": 775, + "America/Nome": 776, + "America/Noronha": 777, + "America/North_Dakota/Beulah": 778, + "America/North_Dakota/Center": 779, + "America/North_Dakota/New_Salem": 780, + "America/Nuuk": 781, + "America/Ojinaga": 782, + "America/Panama": 783, + "America/Pangnirtung": 784, + "America/Paramaribo": 785, + "America/Phoenix": 786, + "America/Port-au-Prince": 787, + "America/Porto_Velho": 788, + "America/Puerto_Rico": 789, + "America/Punta_Arenas": 790, + "America/Rainy_River": 791, + "America/Rankin_Inlet": 792, + "America/Recife": 793, + "America/Regina": 794, + "America/Resolute": 795, + "America/Rio_Branco": 796, + "America/Santarem": 797, + "America/Santiago": 798, + "America/Santo_Domingo": 799, + "America/Sao_Paulo": 800, + "America/Scoresbysund": 801, + "America/Sitka": 802, + "America/St_Johns": 803, + "America/Swift_Current": 804, + "America/Tegucigalpa": 805, + "America/Thule": 806, + "America/Thunder_Bay": 807, + "America/Tijuana": 808, + "America/Toronto": 809, + "America/Vancouver": 810, + "America/Whitehorse": 811, + "America/Winnipeg": 812, + "America/Yakutat": 813, + "America/Yellowknife": 814, + "Antarctica/Casey": 815, + "Antarctica/Davis": 816, + "Antarctica/Macquarie": 817, + "Antarctica/Mawson": 818, + "Antarctica/Palmer": 819, + "Antarctica/Rothera": 820, + "Antarctica/Troll": 821, + "Antarctica/Vostok": 822, + "Asia/Almaty": 823, + "Asia/Amman": 824, + "Asia/Anadyr": 825, + "Asia/Aqtau": 826, + "Asia/Aqtobe": 827, + "Asia/Ashgabat": 828, + "Asia/Atyrau": 829, + "Asia/Baghdad": 830, + "Asia/Baku": 831, + "Asia/Bangkok": 832, + "Asia/Barnaul": 833, + "Asia/Beirut": 834, + "Asia/Bishkek": 835, + "Asia/Brunei": 836, + "Asia/Chita": 837, + "Asia/Choibalsan": 838, + "Asia/Colombo": 839, + "Asia/Damascus": 840, + "Asia/Dhaka": 841, + "Asia/Dili": 842, + "Asia/Dubai": 843, + "Asia/Dushanbe": 844, + "Asia/Famagusta": 845, + "Asia/Gaza": 846, + "Asia/Hebron": 847, + "Asia/Ho_Chi_Minh": 848, + "Asia/Hong_Kong": 849, + "Asia/Hovd": 850, + "Asia/Irkutsk": 851, + "Asia/Jakarta": 852, + "Asia/Jayapura": 853, + "Asia/Jerusalem": 854, + "Asia/Kabul": 855, + "Asia/Kamchatka": 856, + "Asia/Karachi": 857, + "Asia/Kathmandu": 858, + "Asia/Khandyga": 859, + "Asia/Kolkata": 860, + "Asia/Krasnoyarsk": 861, + "Asia/Kuala_Lumpur": 862, + "Asia/Kuching": 863, + "Asia/Macau": 864, + "Asia/Magadan": 865, + "Asia/Makassar": 866, + "Asia/Manila": 867, + "Asia/Nicosia": 868, + "Asia/Novokuznetsk": 869, + "Asia/Novosibirsk": 870, + "Asia/Omsk": 871, + "Asia/Oral": 872, + "Asia/Pontianak": 873, + "Asia/Pyongyang": 874, + "Asia/Qatar": 875, + "Asia/Qostanay": 876, + "Asia/Qyzylorda": 877, + "Asia/Riyadh": 878, + "Asia/Sakhalin": 879, + "Asia/Samarkand": 880, + "Asia/Seoul": 881, + "Asia/Shanghai": 882, + "Asia/Singapore": 883, + "Asia/Srednekolymsk": 884, + "Asia/Taipei": 885, + "Asia/Tashkent": 886, + "Asia/Tbilisi": 887, + "Asia/Tehran": 888, + "Asia/Thimphu": 889, + "Asia/Tokyo": 890, + "Asia/Tomsk": 891, + "Asia/Ulaanbaatar": 892, + "Asia/Urumqi": 893, + "Asia/Ust-Nera": 894, + "Asia/Vladivostok": 895, + "Asia/Yakutsk": 896, + "Asia/Yangon": 897, + "Asia/Yekaterinburg": 898, + "Asia/Yerevan": 899, + "Atlantic/Azores": 900, + "Atlantic/Bermuda": 901, + "Atlantic/Canary": 902, + "Atlantic/Cape_Verde": 903, + "Atlantic/Faroe": 904, + "Atlantic/Madeira": 905, + "Atlantic/Reykjavik": 906, + "Atlantic/South_Georgia": 907, + "Atlantic/Stanley": 908, + "Australia/Adelaide": 909, + "Australia/Brisbane": 910, + "Australia/Broken_Hill": 911, + "Australia/Darwin": 912, + "Australia/Eucla": 913, + "Australia/Hobart": 914, + "Australia/Lindeman": 915, + "Australia/Lord_Howe": 916, + "Australia/Melbourne": 917, + "Australia/Perth": 918, + "Australia/Sydney": 919, + "Etc/GMT": 920, + "Etc/UTC": 921, + "Europe/Amsterdam": 922, + "Europe/Andorra": 923, + "Europe/Astrakhan": 924, + "Europe/Athens": 925, + "Europe/Belgrade": 926, + "Europe/Berlin": 927, + "Europe/Brussels": 928, + "Europe/Bucharest": 929, + "Europe/Budapest": 930, + "Europe/Chisinau": 931, + "Europe/Copenhagen": 932, + "Europe/Dublin": 933, + "Europe/Gibraltar": 934, + "Europe/Helsinki": 935, + "Europe/Istanbul": 936, + "Europe/Kaliningrad": 937, + "Europe/Kiev": 938, + "Europe/Kirov": 939, + "Europe/Lisbon": 940, + "Europe/London": 941, + "Europe/Luxembourg": 942, + "Europe/Madrid": 943, + "Europe/Malta": 944, + "Europe/Minsk": 945, + "Europe/Monaco": 946, + "Europe/Moscow": 947, + "Europe/Oslo": 948, + "Europe/Paris": 949, + "Europe/Prague": 950, + "Europe/Riga": 951, + "Europe/Rome": 952, + "Europe/Samara": 953, + "Europe/Saratov": 954, + "Europe/Simferopol": 955, + "Europe/Sofia": 956, + "Europe/Stockholm": 957, + "Europe/Tallinn": 958, + "Europe/Tirane": 959, + "Europe/Ulyanovsk": 960, + "Europe/Uzhgorod": 961, + "Europe/Vienna": 962, + "Europe/Vilnius": 963, + "Europe/Volgograd": 964, + "Europe/Warsaw": 965, + "Europe/Zaporozhye": 966, + "Europe/Zurich": 967, + "Indian/Chagos": 968, + "Indian/Christmas": 969, + "Indian/Cocos": 970, + "Indian/Kerguelen": 971, + "Indian/Mahe": 972, + "Indian/Maldives": 973, + "Indian/Mauritius": 974, + "Indian/Reunion": 975, + "Pacific/Apia": 976, + "Pacific/Auckland": 977, + "Pacific/Bougainville": 978, + "Pacific/Chatham": 979, + "Pacific/Chuuk": 980, + "Pacific/Easter": 981, + "Pacific/Efate": 982, + "Pacific/Fakaofo": 983, + "Pacific/Fiji": 984, + "Pacific/Funafuti": 985, + "Pacific/Galapagos": 986, + "Pacific/Gambier": 987, + "Pacific/Guadalcanal": 988, + "Pacific/Guam": 989, + "Pacific/Honolulu": 990, + "Pacific/Kanton": 991, + "Pacific/Kiritimati": 992, + "Pacific/Kosrae": 993, + "Pacific/Kwajalein": 994, + "Pacific/Majuro": 995, + "Pacific/Marquesas": 996, + "Pacific/Nauru": 997, + "Pacific/Niue": 998, + "Pacific/Norfolk": 999, + "Pacific/Noumea": 1000, + "Pacific/Pago_Pago": 1001, + "Pacific/Palau": 1002, + "Pacific/Pitcairn": 1003, + "Pacific/Pohnpei": 1004, + "Pacific/Port_Moresby": 1005, + "Pacific/Rarotonga": 1006, + "Pacific/Tahiti": 1007, + "Pacific/Tarawa": 1008, + "Pacific/Tongatapu": 1009, + "Pacific/Wake": 1010, + "Pacific/Wallis": 1011, + "Africa/Accra": 672, + "Africa/Addis_Ababa": 685, + "Africa/Asmara": 685, + "Africa/Asmera": 685, + "Africa/Bamako": 672, + "Africa/Bangui": 682, + "Africa/Banjul": 672, + "Africa/Blantyre": 683, + "Africa/Brazzaville": 682, + "Africa/Bujumbura": 683, + "Africa/Conakry": 672, + "Africa/Dakar": 672, + "Africa/Dar_es_Salaam": 685, + "Africa/Djibouti": 685, + "Africa/Douala": 682, + "Africa/Freetown": 672, + "Africa/Gaborone": 683, + "Africa/Harare": 683, + "Africa/Kampala": 685, + "Africa/Kigali": 683, + "Africa/Kinshasa": 682, + "Africa/Libreville": 682, + "Africa/Lome": 672, + "Africa/Luanda": 682, + "Africa/Lubumbashi": 683, + "Africa/Lusaka": 683, + "Africa/Malabo": 682, + "Africa/Maseru": 679, + "Africa/Mbabane": 679, + "Africa/Mogadishu": 685, + "Africa/Niamey": 682, + "Africa/Nouakchott": 672, + "Africa/Ouagadougou": 672, + "Africa/Porto-Novo": 682, + "Africa/Timbuktu": 672, + "America/Anguilla": 789, + "America/Antigua": 789, + "America/Argentina/ComodRivadavia": 695, + "America/Aruba": 789, + "America/Atikokan": 783, + "America/Atka": 691, + "America/Blanc-Sablon": 789, + "America/Buenos_Aires": 694, + "America/Catamarca": 695, + "America/Cayman": 783, + "America/Coral_Harbour": 783, + "America/Cordoba": 696, + "America/Creston": 786, + "America/Curacao": 789, + "America/Dominica": 789, + "America/Ensenada": 808, + "America/Fort_Wayne": 743, + "America/Godthab": 781, + "America/Grenada": 789, + "America/Guadeloupe": 789, + "America/Indianapolis": 743, + "America/Jujuy": 697, + "America/Knox_IN": 744, + "America/Kralendijk": 789, + "America/Louisville": 755, + "America/Lower_Princes": 789, + "America/Marigot": 789, + "America/Mendoza": 699, + "America/Montreal": 809, + "America/Montserrat": 789, + "America/Nassau": 809, + "America/Port_of_Spain": 789, + "America/Porto_Acre": 796, + "America/Rosario": 696, + "America/Santa_Isabel": 808, + "America/Shiprock": 727, + "America/St_Barthelemy": 789, + "America/St_Kitts": 789, + "America/St_Lucia": 789, + "America/St_Thomas": 789, + "America/St_Vincent": 789, + "America/Tortola": 789, + "America/Virgin": 789, + "Antarctica/DumontDUrville": 1005, + "Antarctica/McMurdo": 977, + "Antarctica/South_Pole": 977, + "Antarctica/Syowa": 878, + "Arctic/Longyearbyen": 948, + "Asia/Aden": 878, + "Asia/Ashkhabad": 828, + "Asia/Bahrain": 875, + "Asia/Calcutta": 860, + "Asia/Chongqing": 882, + "Asia/Chungking": 882, + "Asia/Dacca": 841, + "Asia/Harbin": 882, + "Asia/Istanbul": 936, + "Asia/Kashgar": 893, + "Asia/Katmandu": 858, + "Asia/Kuwait": 878, + "Asia/Macao": 864, + "Asia/Muscat": 843, + "Asia/Phnom_Penh": 832, + "Asia/Rangoon": 897, + "Asia/Saigon": 848, + "Asia/Tel_Aviv": 854, + "Asia/Thimbu": 889, + "Asia/Ujung_Pandang": 866, + "Asia/Ulan_Bator": 892, + "Asia/Vientiane": 832, + "Atlantic/Faeroe": 904, + "Atlantic/Jan_Mayen": 948, + "Atlantic/St_Helena": 672, + "Australia/ACT": 919, + "Australia/Canberra": 919, + "Australia/Currie": 914, + "Australia/LHI": 916, + "Australia/NSW": 919, + "Australia/North": 912, + "Australia/Queensland": 910, + "Australia/South": 909, + "Australia/Tasmania": 914, + "Australia/Victoria": 917, + "Australia/West": 918, + "Australia/Yancowinna": 911, + "Brazil/Acre": 796, + "Brazil/DeNoronha": 777, + "Brazil/East": 800, + "Brazil/West": 762, + "Canada/Atlantic": 740, + "Canada/Central": 812, + "Canada/Eastern": 809, + "Canada/Mountain": 729, + "Canada/Newfoundland": 803, + "Canada/Pacific": 810, + "Canada/Saskatchewan": 794, + "Canada/Yukon": 811, + "Chile/Continental": 798, + "Chile/EasterIsland": 981, + "Cuba": 741, + "Egypt": 675, + "Eire": 933, + "Etc/GMT+0": 920, + "Etc/GMT-0": 920, + "Etc/GMT0": 920, + "Etc/Greenwich": 920, + "Etc/UCT": 921, + "Etc/Universal": 921, + "Etc/Zulu": 921, + "Europe/Belfast": 941, + "Europe/Bratislava": 950, + "Europe/Busingen": 967, + "Europe/Guernsey": 941, + "Europe/Isle_of_Man": 941, + "Europe/Jersey": 941, + "Europe/Ljubljana": 926, + "Europe/Mariehamn": 935, + "Europe/Nicosia": 868, + "Europe/Podgorica": 926, + "Europe/San_Marino": 952, + "Europe/Sarajevo": 926, + "Europe/Skopje": 926, + "Europe/Tiraspol": 931, + "Europe/Vaduz": 967, + "Europe/Vatican": 952, + "Europe/Zagreb": 926, + "GB": 941, + "GB-Eire": 941, + "GMT+0": 920, + "GMT-0": 920, + "GMT0": 920, + "Greenwich": 920, + "Hongkong": 849, + "Iceland": 906, + "Indian/Antananarivo": 685, + "Indian/Comoro": 685, + "Indian/Mayotte": 685, + "Iran": 888, + "Israel": 854, + "Jamaica": 753, + "Japan": 890, + "Kwajalein": 994, + "Libya": 688, + "Mexico/BajaNorte": 808, + "Mexico/BajaSur": 765, + "Mexico/General": 769, + "NZ": 977, + "NZ-CHAT": 979, + "Navajo": 727, + "PRC": 882, + "Pacific/Enderbury": 991, + "Pacific/Johnston": 990, + "Pacific/Midway": 1001, + "Pacific/Ponape": 1004, + "Pacific/Saipan": 989, + "Pacific/Samoa": 1001, + "Pacific/Truk": 980, + "Pacific/Yap": 980, + "Poland": 965, + "Portugal": 940, + "ROC": 885, + "ROK": 881, + "Singapore": 883, + "Turkey": 936, + "UCT": 921, + "US/Alaska": 692, + "US/Aleutian": 691, + "US/Arizona": 786, + "US/Central": 720, + "US/East-Indiana": 743, + "US/Eastern": 774, + "US/Hawaii": 990, + "US/Indiana-Starke": 744, + "US/Michigan": 728, + "US/Mountain": 727, + "US/Pacific": 759, + "US/Samoa": 1001, + "Universal": 921, + "W-SU": 947, + "Zulu": 921, +} diff --git a/deadline_io.go b/deadline_io.go new file mode 100644 index 000000000..da547f5e2 --- /dev/null +++ b/deadline_io.go @@ -0,0 +1,27 @@ +package tarantool + +import ( + "net" + "time" +) + +type deadlineIO struct { + to time.Duration + c net.Conn +} + +func (d *deadlineIO) Write(b []byte) (n int, err error) { + if d.to > 0 { + d.c.SetWriteDeadline(time.Now().Add(d.to)) + } + n, err = d.c.Write(b) + return +} + +func (d *deadlineIO) Read(b []byte) (n int, err error) { + if d.to > 0 { + d.c.SetReadDeadline(time.Now().Add(d.to)) + } + n, err = d.c.Read(b) + return +} diff --git a/decimal/bcd.go b/decimal/bcd.go new file mode 100644 index 000000000..61a437276 --- /dev/null +++ b/decimal/bcd.go @@ -0,0 +1,250 @@ +package decimal + +// Package decimal implements methods to encode and decode BCD. +// +// BCD (Binary-Coded Decimal) is a sequence of bytes representing decimal +// digits of the encoded number (each byte has two decimal digits each encoded +// using 4-bit nibbles), so byte >> 4 is the first digit and byte & 0x0f is the +// second digit. The leftmost digit in the array is the most significant. The +// rightmost digit in the array is the least significant. +// +// The first byte of the BCD array contains the first digit of the number, +// represented as follows: +// +// | 4 bits | 4 bits | +// = 0x = the 1st digit +// +// (The first nibble contains 0 if the decimal number has an even number of +// digits). The last byte of the BCD array contains the last digit of the +// number and the final nibble, represented as follows: +// +// | 4 bits | 4 bits | +// = the last digit = nibble +// +// The final nibble represents the number's sign: 0x0a, 0x0c, 0x0e, 0x0f stand +// for plus, 0x0b and 0x0d stand for minus. +// +// Examples: +// +// The decimal -12.34 will be encoded as 0xd6, 0x01, 0x02, 0x01, 0x23, 0x4d: +// +// | MP_EXT (fixext 4) | MP_DECIMAL | scale | 1 | 2,3 | 4 (minus) | +// | 0xd6 | 0x01 | 0x02 | 0x01 | 0x23 | 0x4d | +// +// The decimal 0.000000000000000000000000000000000010 will be encoded as +// 0xc7, 0x03, 0x01, 0x24, 0x01, 0x0c: +// +// | MP_EXT (ext 8) | length | MP_DECIMAL | scale | 1 | 0 (plus) | +// | 0xc7 | 0x03 | 0x01 | 0x24 | 0x01 | 0x0c | +// +// See also: +// +// * MessagePack extensions: +// https://www.tarantool.io/en/doc/latest/dev_guide/internals/msgpack_extensions/ +// +// * An implementation in C language: +// https://github.com/tarantool/decNumber/blob/master/decPacked.c + +import ( + "bytes" + "fmt" + "strings" + + "github.com/vmihailenco/msgpack/v5" +) + +const ( + bytePlus = byte(0x0c) + byteMinus = byte(0x0d) +) + +var isNegative = [256]bool{ + 0x0a: false, + 0x0b: true, + 0x0c: false, + 0x0d: true, + 0x0e: false, + 0x0f: false, +} + +// Calculate a number of digits in a buffer with decimal number. +// +// Plus, minus, point and leading zeroes do not count. +// Contains a quirk for a zero - returns 1. +// +// Examples (see more examples in tests): +// +// - 0.0000000000000001 - 1 digit +// +// - 00012.34 - 4 digits +// +// - 0.340 - 3 digits +// +// - 0 - 1 digit +func getNumberLength(buf string) int { + if len(buf) == 0 { + return 0 + } + n := 0 + for _, ch := range []byte(buf) { + if ch >= '1' && ch <= '9' { + n += 1 + } else if ch == '0' && n != 0 { + n += 1 + } + } + + // Fix a case with a single 0. + if n == 0 { + n = 1 + } + + return n +} + +// encodeStringToBCD converts a string buffer to BCD Packed Decimal. +// +// The number is converted to a BCD packed decimal byte array, right aligned in +// the BCD array, whose length is indicated by the second parameter. The final +// 4-bit nibble in the array will be a sign nibble, 0x0c for "+" and 0x0d for +// "-". Unused bytes and nibbles to the left of the number are set to 0. scale +// is set to the scale of the number (this is the exponent, negated). +func encodeStringToBCD(buf string) ([]byte, error) { + if len(buf) == 0 { + return nil, fmt.Errorf("length of number is zero") + } + signByte := bytePlus // By default number is positive. + if buf[0] == '-' { + signByte = byteMinus + } + + // The first nibble should contain 0, if the decimal number has an even + // number of digits. Therefore highNibble is false when decimal number + // is even. + highNibble := true + l := getNumberLength(buf) + if l%2 == 0 { + highNibble = false + } + scale := 0 // By default decimal number is integer. + var byteBuf []byte + for i, ch := range []byte(buf) { + // Skip leading zeroes. + if (len(byteBuf) == 0) && ch == '0' { + continue + } + if (i == 0) && (ch == '-' || ch == '+') { + continue + } + // Calculate a number of digits after the decimal point. + if ch == '.' { + if scale != 0 { + return nil, fmt.Errorf("number contains more than one point") + } + scale = len(buf) - i - 1 + continue + } + + if ch < '0' || ch > '9' { + return nil, fmt.Errorf("failed to convert symbol '%c' to a digit", ch) + } + digit := ch - '0' + if highNibble { + // Add a digit to a high nibble. + digit <<= 4 + byteBuf = append(byteBuf, digit) + highNibble = false + } else { + if len(byteBuf) == 0 { + byteBuf = make([]byte, 1) + } + // Add a digit to a low nibble. + lowByteIdx := len(byteBuf) - 1 + byteBuf[lowByteIdx] |= digit + highNibble = true + } + } + if len(byteBuf) == 0 { + // a special case: -0 + signByte = bytePlus + } + if highNibble { + // Put a sign to a high nibble. + byteBuf = append(byteBuf, signByte) + } else { + // Put a sign to a low nibble. + lowByteIdx := len(byteBuf) - 1 + byteBuf[lowByteIdx] |= signByte + } + byteBuf = append([]byte{byte(scale)}, byteBuf...) + + return byteBuf, nil +} + +// decodeStringFromBCD converts a BCD Packed Decimal to a string buffer. +// +// The BCD packed decimal byte array, together with an associated scale, is +// converted to a string. The BCD array is assumed full of digits, and must be +// ended by a 4-bit sign nibble in the least significant four bits of the final +// byte. The scale is used (negated) as the exponent of the decimal number. +// Note that zeroes may have a sign and/or a scale. +func decodeStringFromBCD(bcdBuf []byte) (string, int, error) { + // Read scale. + buf := bytes.NewBuffer(bcdBuf) + dec := msgpack.NewDecoder(buf) + scale, err := dec.DecodeInt() + if err != nil { + return "", 0, fmt.Errorf("unable to decode the decimal scale: %w", err) + } + + // Get the data without the scale. + bcdBuf = buf.Bytes() + bufLen := len(bcdBuf) + + // Every nibble contains a digit, and the last low nibble contains a + // sign. + ndigits := bufLen*2 - 1 + + // The first nibble contains 0 if the decimal number has an even number of + // digits. Decrease a number of digits if so. + if bcdBuf[0]&0xf0 == 0 { + ndigits -= 1 + } + + // Reserve bytes for dot and sign. + numLen := ndigits + 2 + + var bld strings.Builder + bld.Grow(numLen) + + // Add a sign, it is encoded in a low nibble of a last byte. + lastByte := bcdBuf[bufLen-1] + sign := lastByte & 0x0f + if isNegative[sign] { + bld.WriteByte('-') + } + + const MaxDigit = 0x09 + // Builds a buffer with symbols of decimal number (digits, dot and sign). + processNibble := func(nibble byte) { + if nibble <= MaxDigit { + bld.WriteByte(nibble + '0') + ndigits-- + } + } + + for i, bcdByte := range bcdBuf { + highNibble := bcdByte >> 4 + lowNibble := bcdByte & 0x0f + // Skip a first high nibble as no digit there. + if i != 0 || highNibble != 0 { + processNibble(highNibble) + } + processNibble(lowNibble) + } + + if bld.Len() == 0 || isNegative[sign] && bld.Len() == 1 { + bld.WriteByte('0') + } + return bld.String(), -1 * scale, nil +} diff --git a/decimal/config.lua b/decimal/config.lua new file mode 100644 index 000000000..58b038958 --- /dev/null +++ b/decimal/config.lua @@ -0,0 +1,41 @@ +local decimal = require('decimal') +local msgpack = require('msgpack') + +-- Do not set listen for now so connector won't be +-- able to send requests until everything is configured. +box.cfg{ + work_dir = os.getenv("TEST_TNT_WORK_DIR"), +} + +box.schema.user.create('test', { password = 'test' , if_not_exists = true }) +box.schema.user.grant('test', 'execute', 'universe', nil, { if_not_exists = true }) + +local decimal_msgpack_supported = pcall(msgpack.encode, decimal.new(1)) +if not decimal_msgpack_supported then + error('Decimal unsupported, use Tarantool 2.2 or newer') +end + +local s = box.schema.space.create('testDecimal', { + id = 524, + if_not_exists = true, +}) +s:create_index('primary', { + type = 'TREE', + parts = { + { + field = 1, + type = 'decimal', + }, + }, + if_not_exists = true +}) +s:truncate() + +box.schema.user.grant('test', 'read,write', 'space', 'testDecimal', { if_not_exists = true }) + +-- Set listen only when every other thing is configured. +box.cfg{ + listen = os.getenv("TEST_TNT_LISTEN"), +} + +require('console').start() diff --git a/decimal/decimal.go b/decimal/decimal.go new file mode 100644 index 000000000..3c2681238 --- /dev/null +++ b/decimal/decimal.go @@ -0,0 +1,150 @@ +// Package decimal with support of Tarantool's decimal data type. +// +// Decimal data type supported in Tarantool since 2.2. +// +// Since: 1.7.0 +// +// See also: +// +// - Tarantool MessagePack extensions: +// https://www.tarantool.io/en/doc/latest/dev_guide/internals/msgpack_extensions/#the-decimal-type +// +// - Tarantool data model: +// https://www.tarantool.io/en/doc/latest/book/box/data_model/ +// +// - Tarantool issue for support decimal type: +// https://github.com/tarantool/tarantool/issues/692 +// +// - Tarantool module decimal: +// https://www.tarantool.io/en/doc/latest/reference/reference_lua/decimal/ +package decimal + +import ( + "fmt" + "reflect" + + "github.com/shopspring/decimal" + "github.com/vmihailenco/msgpack/v5" +) + +// Decimal numbers have 38 digits of precision, that is, the total +// number of digits before and after the decimal point can be 38. +// A decimal operation will fail if overflow happens (when a number is +// greater than 10^38 - 1 or less than -10^38 - 1). +// +// See also: +// +// - Tarantool module decimal: +// https://www.tarantool.io/en/doc/latest/reference/reference_lua/decimal/ + +const ( + // Decimal external type. + decimalExtID = 1 + decimalPrecision = 38 +) + +var ( + one = decimal.NewFromInt(1) + // -10^decimalPrecision - 1 + minSupportedDecimal = maxSupportedDecimal.Neg().Sub(one) + // 10^decimalPrecision - 1 + maxSupportedDecimal = decimal.New(1, decimalPrecision).Sub(one) +) + +//go:generate go tool gentypes -ext-code 1 Decimal +type Decimal struct { + decimal.Decimal +} + +// MakeDecimal creates a new Decimal from a decimal.Decimal. +func MakeDecimal(decimal decimal.Decimal) Decimal { + return Decimal{Decimal: decimal} +} + +// MakeDecimalFromString creates a new Decimal from a string. +func MakeDecimalFromString(src string) (Decimal, error) { + result := Decimal{} + dec, err := decimal.NewFromString(src) + if err != nil { + return result, err + } + result = MakeDecimal(dec) + return result, nil +} + +var ( + ErrDecimalOverflow = fmt.Errorf("msgpack: decimal number is bigger than"+ + " maximum supported number (10^%d - 1)", decimalPrecision) + ErrDecimalUnderflow = fmt.Errorf("msgpack: decimal number is lesser than"+ + " minimum supported number (-10^%d - 1)", decimalPrecision) +) + +// MarshalMsgpack implements a custom msgpack marshaler. +func (d Decimal) MarshalMsgpack() ([]byte, error) { + switch { + case d.GreaterThan(maxSupportedDecimal): + return nil, ErrDecimalOverflow + case d.LessThan(minSupportedDecimal): + return nil, ErrDecimalUnderflow + } + + // Decimal values can be encoded to fixext MessagePack, where buffer + // has a fixed length encoded by first byte, and ext MessagePack, where + // buffer length is not fixed and encoded by a number in a separate + // field: + // + // +--------+-------------------+------------+===============+ + // | MP_EXT | length (optional) | MP_DECIMAL | PackedDecimal | + // +--------+-------------------+------------+===============+ + strBuf := d.String() + bcdBuf, err := encodeStringToBCD(strBuf) + if err != nil { + return nil, fmt.Errorf("msgpack: can't encode string (%s) to a BCD buffer: %w", strBuf, err) + } + return bcdBuf, nil +} + +// UnmarshalMsgpack implements a custom msgpack unmarshaler. +func (d *Decimal) UnmarshalMsgpack(data []byte) error { + digits, exp, err := decodeStringFromBCD(data) + if err != nil { + return fmt.Errorf("msgpack: can't decode string from BCD buffer (%x): %w", data, err) + } + + dec, err := decimal.NewFromString(digits) + if err != nil { + return fmt.Errorf("msgpack: can't encode string (%s) to a decimal number: %w", digits, err) + } + + if exp != 0 { + dec = dec.Shift(int32(exp)) + } + + *d = MakeDecimal(dec) + return nil +} + +func decimalEncoder(e *msgpack.Encoder, v reflect.Value) ([]byte, error) { + dec := v.Interface().(Decimal) + + return dec.MarshalMsgpack() +} + +func decimalDecoder(d *msgpack.Decoder, v reflect.Value, extLen int) error { + b := make([]byte, extLen) + + switch n, err := d.Buffered().Read(b); { + case err != nil: + return err + case n < extLen: + return fmt.Errorf("msgpack: unexpected end of stream after %d decimal bytes", n) + } + + ptr := v.Addr().Interface().(*Decimal) + return ptr.UnmarshalMsgpack(b) +} + +func init() { + msgpack.RegisterExtDecoder(decimalExtID, Decimal{}, decimalDecoder) + msgpack.RegisterExtEncoder(decimalExtID, Decimal{}, decimalEncoder) +} diff --git a/decimal/decimal_gen.go b/decimal/decimal_gen.go new file mode 100644 index 000000000..0f9b18e3a --- /dev/null +++ b/decimal/decimal_gen.go @@ -0,0 +1,241 @@ +// Code generated by github.com/tarantool/go-option; DO NOT EDIT. + +package decimal + +import ( + "fmt" + + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" + + "github.com/tarantool/go-option" +) + +// OptionalDecimal represents an optional value of type Decimal. +// It can either hold a valid Decimal (IsSome == true) or be empty (IsZero == true). +type OptionalDecimal struct { + value Decimal + exists bool +} + +// SomeOptionalDecimal creates an optional OptionalDecimal with the given Decimal value. +// The returned OptionalDecimal will have IsSome() == true and IsZero() == false. +func SomeOptionalDecimal(value Decimal) OptionalDecimal { + return OptionalDecimal{ + value: value, + exists: true, + } +} + +// NoneOptionalDecimal creates an empty optional OptionalDecimal value. +// The returned OptionalDecimal will have IsSome() == false and IsZero() == true. +// +// Example: +// +// o := NoneOptionalDecimal() +// if o.IsZero() { +// fmt.Println("value is absent") +// } +func NoneOptionalDecimal() OptionalDecimal { + return OptionalDecimal{} +} + +func (o OptionalDecimal) newEncodeError(err error) error { + if err == nil { + return nil + } + return &option.EncodeError{ + Type: "OptionalDecimal", + Parent: err, + } +} + +func (o OptionalDecimal) newDecodeError(err error) error { + if err == nil { + return nil + } + + return &option.DecodeError{ + Type: "OptionalDecimal", + Parent: err, + } +} + +// IsSome returns true if the OptionalDecimal contains a value. +// This indicates the value is explicitly set (not None). +func (o OptionalDecimal) IsSome() bool { + return o.exists +} + +// IsZero returns true if the OptionalDecimal does not contain a value. +// Equivalent to !IsSome(). Useful for consistency with types where +// zero value (e.g. 0, false, zero struct) is valid and needs to be distinguished. +func (o OptionalDecimal) IsZero() bool { + return !o.exists +} + +// IsNil is an alias for IsZero. +// +// This method is provided for compatibility with the msgpack Encoder interface. +func (o OptionalDecimal) IsNil() bool { + return o.IsZero() +} + +// Get returns the stored value and a boolean flag indicating its presence. +// If the value is present, returns (value, true). +// If the value is absent, returns (zero value of Decimal, false). +// +// Recommended usage: +// +// if value, ok := o.Get(); ok { +// // use value +// } +func (o OptionalDecimal) Get() (Decimal, bool) { + return o.value, o.exists +} + +// MustGet returns the stored value if it is present. +// Panics if the value is absent (i.e., IsZero() == true). +// +// Use with caution — only when you are certain the value exists. +// +// Panics with: "optional value is not set" if no value is set. +func (o OptionalDecimal) MustGet() Decimal { + if !o.exists { + panic("optional value is not set") + } + + return o.value +} + +// Unwrap returns the stored value regardless of presence. +// If no value is set, returns the zero value for Decimal. +// +// Warning: Does not check presence. Use IsSome() before calling if you need +// to distinguish between absent value and explicit zero value. +func (o OptionalDecimal) Unwrap() Decimal { + return o.value +} + +// UnwrapOr returns the stored value if present. +// Otherwise, returns the provided default value. +// +// Example: +// +// o := NoneOptionalDecimal() +// v := o.UnwrapOr(someDefaultOptionalDecimal) +func (o OptionalDecimal) UnwrapOr(defaultValue Decimal) Decimal { + if o.exists { + return o.value + } + + return defaultValue +} + +// UnwrapOrElse returns the stored value if present. +// Otherwise, calls the provided function and returns its result. +// Useful when the default value requires computation or side effects. +// +// Example: +// +// o := NoneOptionalDecimal() +// v := o.UnwrapOrElse(func() Decimal { return computeDefault() }) +func (o OptionalDecimal) UnwrapOrElse(defaultValue func() Decimal) Decimal { + if o.exists { + return o.value + } + + return defaultValue() +} + +func (o OptionalDecimal) encodeValue(encoder *msgpack.Encoder) error { + value, err := o.value.MarshalMsgpack() + if err != nil { + return err + } + + err = encoder.EncodeExtHeader(1, len(value)) + if err != nil { + return err + } + + _, err = encoder.Writer().Write(value) + if err != nil { + return err + } + + return nil +} + +// EncodeMsgpack encodes the OptionalDecimal value using MessagePack format. +// - If the value is present, it is encoded as Decimal. +// - If the value is absent (None), it is encoded as nil. +// +// Returns an error if encoding fails. +func (o OptionalDecimal) EncodeMsgpack(encoder *msgpack.Encoder) error { + if o.exists { + return o.newEncodeError(o.encodeValue(encoder)) + } + + return o.newEncodeError(encoder.EncodeNil()) +} + +func (o *OptionalDecimal) decodeValue(decoder *msgpack.Decoder) error { + tp, length, err := decoder.DecodeExtHeader() + switch { + case err != nil: + return o.newDecodeError(err) + case tp != 1: + return o.newDecodeError(fmt.Errorf("invalid extension code: %d", tp)) + } + + a := make([]byte, length) + if err := decoder.ReadFull(a); err != nil { + return o.newDecodeError(err) + } + + if err := o.value.UnmarshalMsgpack(a); err != nil { + return o.newDecodeError(err) + } + + o.exists = true + return nil +} + +func (o *OptionalDecimal) checkCode(code byte) bool { + return msgpcode.IsExt(code) +} + +// DecodeMsgpack decodes a OptionalDecimal value from MessagePack format. +// Supports two input types: +// - nil: interpreted as no value (NoneOptionalDecimal) +// - Decimal: interpreted as a present value (SomeOptionalDecimal) +// +// Returns an error if the input type is unsupported or decoding fails. +// +// After successful decoding: +// - on nil: exists = false, value = default zero value +// - on Decimal: exists = true, value = decoded value +func (o *OptionalDecimal) DecodeMsgpack(decoder *msgpack.Decoder) error { + code, err := decoder.PeekCode() + if err != nil { + return o.newDecodeError(err) + } + + switch { + case code == msgpcode.Nil: + o.exists = false + + return o.newDecodeError(decoder.Skip()) + case o.checkCode(code): + err := o.decodeValue(decoder) + if err != nil { + return o.newDecodeError(err) + } + o.exists = true + + return err + default: + return o.newDecodeError(fmt.Errorf("unexpected code: %d", code)) + } +} diff --git a/decimal/decimal_gen_test.go b/decimal/decimal_gen_test.go new file mode 100644 index 000000000..50f22bf23 --- /dev/null +++ b/decimal/decimal_gen_test.go @@ -0,0 +1,117 @@ +package decimal + +import ( + "bytes" + "testing" + + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" + "github.com/vmihailenco/msgpack/v5" +) + +func TestSomeOptionalDecimal(t *testing.T) { + val := MakeDecimal(decimal.NewFromFloat(1.23)) + opt := SomeOptionalDecimal(val) + + assert.True(t, opt.IsSome()) + assert.False(t, opt.IsZero()) + + v, ok := opt.Get() + assert.True(t, ok) + assert.Equal(t, val, v) +} + +func TestNoneOptionalDecimal(t *testing.T) { + opt := NoneOptionalDecimal() + + assert.False(t, opt.IsSome()) + assert.True(t, opt.IsZero()) + + _, ok := opt.Get() + assert.False(t, ok) +} + +func TestOptionalDecimal_MustGet(t *testing.T) { + val := MakeDecimal(decimal.NewFromFloat(1.23)) + optSome := SomeOptionalDecimal(val) + optNone := NoneOptionalDecimal() + + assert.Equal(t, val, optSome.MustGet()) + assert.Panics(t, func() { optNone.MustGet() }) +} + +func TestOptionalDecimal_Unwrap(t *testing.T) { + val := MakeDecimal(decimal.NewFromFloat(1.23)) + optSome := SomeOptionalDecimal(val) + optNone := NoneOptionalDecimal() + + assert.Equal(t, val, optSome.Unwrap()) + assert.Equal(t, Decimal{}, optNone.Unwrap()) +} + +func TestOptionalDecimal_UnwrapOr(t *testing.T) { + val := MakeDecimal(decimal.NewFromFloat(1.23)) + def := MakeDecimal(decimal.NewFromFloat(4.56)) + optSome := SomeOptionalDecimal(val) + optNone := NoneOptionalDecimal() + + assert.Equal(t, val, optSome.UnwrapOr(def)) + assert.Equal(t, def, optNone.UnwrapOr(def)) +} + +func TestOptionalDecimal_UnwrapOrElse(t *testing.T) { + val := MakeDecimal(decimal.NewFromFloat(1.23)) + def := MakeDecimal(decimal.NewFromFloat(4.56)) + optSome := SomeOptionalDecimal(val) + optNone := NoneOptionalDecimal() + + assert.Equal(t, val, optSome.UnwrapOrElse(func() Decimal { return def })) + assert.Equal(t, def, optNone.UnwrapOrElse(func() Decimal { return def })) +} + +func TestOptionalDecimal_EncodeDecodeMsgpack_Some(t *testing.T) { + val := MakeDecimal(decimal.NewFromFloat(1.23)) + some := SomeOptionalDecimal(val) + + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err := enc.Encode(some) + assert.NoError(t, err) + + var decodedSome OptionalDecimal + err = dec.Decode(&decodedSome) + assert.NoError(t, err) + assert.True(t, decodedSome.IsSome()) + assert.Equal(t, val, decodedSome.Unwrap()) +} + +func TestOptionalDecimal_EncodeDecodeMsgpack_None(t *testing.T) { + none := NoneOptionalDecimal() + + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err := enc.Encode(none) + assert.NoError(t, err) + + var decodedNone OptionalDecimal + err = dec.Decode(&decodedNone) + assert.NoError(t, err) + assert.True(t, decodedNone.IsZero()) +} + +func TestOptionalDecimal_EncodeDecodeMsgpack_InvalidType(t *testing.T) { + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err := enc.Encode(123) + assert.NoError(t, err) + + var decodedInvalid OptionalDecimal + err = dec.Decode(&decodedInvalid) + assert.Error(t, err) +} diff --git a/decimal/decimal_test.go b/decimal/decimal_test.go new file mode 100644 index 000000000..f75494204 --- /dev/null +++ b/decimal/decimal_test.go @@ -0,0 +1,703 @@ +package decimal_test + +import ( + "encoding/hex" + "fmt" + "log" + "os" + "reflect" + "testing" + "time" + + "github.com/shopspring/decimal" + "github.com/vmihailenco/msgpack/v5" + + . "github.com/tarantool/go-tarantool/v3" + . "github.com/tarantool/go-tarantool/v3/decimal" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +var isDecimalSupported = false + +var server = "127.0.0.1:3013" +var dialer = NetDialer{ + Address: server, + User: "test", + Password: "test", +} +var opts = Opts{ + Timeout: 5 * time.Second, +} + +func skipIfDecimalUnsupported(t *testing.T) { + t.Helper() + + if isDecimalSupported == false { + t.Skip("Skipping test for Tarantool without decimal support in msgpack") + } +} + +var space = "testDecimal" +var index = "primary" + +type TupleDecimal struct { + number Decimal +} + +func (t *TupleDecimal) EncodeMsgpack(e *msgpack.Encoder) error { + if err := e.EncodeArrayLen(1); err != nil { + return err + } + return e.EncodeValue(reflect.ValueOf(&t.number)) +} + +func (t *TupleDecimal) DecodeMsgpack(d *msgpack.Decoder) error { + var err error + var l int + if l, err = d.DecodeArrayLen(); err != nil { + return err + } + if l != 1 { + return fmt.Errorf("Array length doesn't match: %d", l) + } + + res, err := d.DecodeInterface() + if err != nil { + return err + } + + if dec, ok := res.(Decimal); !ok { + return fmt.Errorf("decimal doesn't match") + } else { + t.number = dec + } + return nil +} + +var benchmarkSamples = []struct { + numString string + mpBuf string + fixExt bool +}{ + {"0.7", "d501017c", true}, + {"0.3", "d501013c", true}, + {"0.00000000000000000000000000000000000001", "d501261c", true}, + {"0.00000000000000000000000000000000000009", "d501269c", true}, + {"-18.34", "d6010201834d", true}, + {"-108.123456789", "d701090108123456789d", true}, + {"-11111111111111111111111111111111111111", "c7150100011111111111111111111111111111111111111d", + false}, +} + +var correctnessSamples = []struct { + numString string + mpBuf string + fixExt bool +}{ + {"100", "c7030100100c", false}, + {"0.1", "d501011c", true}, + {"-0.1", "d501011d", true}, + {"0.0000000000000000000000000000000000001", "d501251c", true}, + {"-0.0000000000000000000000000000000000001", "d501251d", true}, + {"0.00000000000000000000000000000000000001", "d501261c", true}, + {"-0.00000000000000000000000000000000000001", "d501261d", true}, + {"1", "d501001c", true}, + {"-1", "d501001d", true}, + {"0", "d501000c", true}, + {"-0", "d501000c", true}, + {"0.01", "d501021c", true}, + {"0.001", "d501031c", true}, + {"99999999999999999999999999999999999999", "c7150100099999999999999999999999999999999999999c", + false}, + {"-99999999999999999999999999999999999999", "c7150100099999999999999999999999999999999999999d", + false}, + {"-12.34", "d6010201234d", true}, + {"12.34", "d6010201234c", true}, + {"1.4", "c7030101014c", false}, + {"2.718281828459045", "c70a010f02718281828459045c", false}, + {"-2.718281828459045", "c70a010f02718281828459045d", false}, + {"3.141592653589793", "c70a010f03141592653589793c", false}, + {"-3.141592653589793", "c70a010f03141592653589793d", false}, + {"1234567891234567890.0987654321987654321", "c7150113012345678912345678900987654321987654321c", + false}, + {"-1234567891234567890.0987654321987654321", + "c7150113012345678912345678900987654321987654321d", false}, +} + +var correctnessDecodeSamples = []struct { + numString string + mpBuf string + fixExt bool +}{ + {"1e2", "d501fe1c", true}, + {"1e33", "c70301d0df1c", false}, + {"1.1e31", "c70301e2011c", false}, + {"13e-2", "c7030102013c", false}, + {"-1e3", "d501fd1d", true}, +} + +// There is a difference between encoding result from a raw string and from +// decimal.Decimal. It's expected because decimal.Decimal simplifies decimals: +// 0.00010000 -> 0.0001 + +var rawSamples = []struct { + numString string + mpBuf string + fixExt bool +}{ + {"0.000000000000000000000000000000000010", "c7030124010c", false}, + {"0.010", "c7030103010c", false}, + {"123.456789000000000", "c70b010f0123456789000000000c", false}, +} + +var decimalSamples = []struct { + numString string + mpBuf string + fixExt bool +}{ + {"0.000000000000000000000000000000000010", "d501231c", true}, + {"0.010", "d501021c", true}, + {"123.456789000000000", "c7060106123456789c", false}, +} + +func TestMPEncodeDecode(t *testing.T) { + for _, testcase := range benchmarkSamples { + t.Run(testcase.numString, func(t *testing.T) { + decNum, err := MakeDecimalFromString(testcase.numString) + if err != nil { + t.Fatal(err) + } + var buf []byte + tuple := TupleDecimal{number: decNum} + if buf, err = msgpack.Marshal(&tuple); err != nil { + t.Fatalf( + "Failed to msgpack.Encoder decimal number '%s' to a MessagePack buffer: %s", + testcase.numString, err) + } + var v TupleDecimal + if err = msgpack.Unmarshal(buf, &v); err != nil { + t.Fatalf("Failed to decode MessagePack buffer '%x' to a decimal number: %s", + buf, err) + } + if !decNum.Equal(v.number.Decimal) { + t.Fatal("Decimal numbers are not equal") + } + }) + } +} + +var lengthSamples = []struct { + numString string + length int +}{ + {"0.010", 2}, + {"0.01", 1}, + {"-0.1", 1}, + {"0.1", 1}, + {"0", 1}, + {"00.1", 1}, + {"100", 3}, + {"0100", 3}, + {"+1", 1}, + {"-1", 1}, + {"1", 1}, + {"-12.34", 4}, + {"123.456789000000000", 18}, +} + +func TestGetNumberLength(t *testing.T) { + for _, testcase := range lengthSamples { + t.Run(testcase.numString, func(t *testing.T) { + l := GetNumberLength(testcase.numString) + if l != testcase.length { + t.Fatalf("Length is wrong: correct %d, incorrect %d", testcase.length, l) + } + }) + } + + if l := GetNumberLength(""); l != 0 { + t.Fatalf("Length is wrong: correct 0, incorrect %d", l) + } + + if l := GetNumberLength("0"); l != 1 { + t.Fatalf("Length is wrong: correct 0, incorrect %d", l) + } + + if l := GetNumberLength("10"); l != 2 { + t.Fatalf("Length is wrong: correct 0, incorrect %d", l) + } +} + +func TestEncodeStringToBCDIncorrectNumber(t *testing.T) { + referenceErrMsg := "number contains more than one point" + var numString = "0.1.0" + buf, err := EncodeStringToBCD(numString) + if err == nil { + t.Fatalf("no error on encoding a string with incorrect number") + } + if buf != nil { + t.Fatalf("buf is not nil on encoding of a string with double points") + } + if err.Error() != referenceErrMsg { + t.Fatalf("wrong error message on encoding of a string double points") + } + + referenceErrMsg = "length of number is zero" + numString = "" + buf, err = EncodeStringToBCD(numString) + if err == nil { + t.Fatalf("no error on encoding of an empty string") + } + if buf != nil { + t.Fatalf("buf is not nil on encoding of an empty string") + } + if err.Error() != referenceErrMsg { + t.Fatalf("wrong error message on encoding of an empty string") + } + + referenceErrMsg = "failed to convert symbol 'a' to a digit" + numString = "0.1a" + buf, err = EncodeStringToBCD(numString) + if err == nil { + t.Fatalf("no error on encoding of a string number with non-digit symbol") + } + if buf != nil { + t.Fatalf("buf is not nil on encoding of a string number with non-digit symbol") + } + if err.Error() != referenceErrMsg { + t.Fatalf("wrong error message on encoding of a string number with non-digit symbol") + } +} + +func TestEncodeMaxNumber(t *testing.T) { + referenceErrMsg := "msgpack: decimal number is bigger than maximum " + + "supported number (10^38 - 1)" + decNum := decimal.New(1, DecimalPrecision) // // 10^DecimalPrecision + tuple := TupleDecimal{number: MakeDecimal(decNum)} + _, err := msgpack.Marshal(&tuple) + if err == nil { + t.Fatalf("It is possible to msgpack.Encoder a number unsupported by Tarantool") + } + if err.Error() != referenceErrMsg { + t.Fatalf("Incorrect error message on attempt to msgpack.Encoder number unsupported") + } +} + +func TestEncodeMinNumber(t *testing.T) { + referenceErrMsg := "msgpack: decimal number is lesser than minimum " + + "supported number (-10^38 - 1)" + two := decimal.NewFromInt(2) + decNum := decimal.New(1, DecimalPrecision).Neg().Sub(two) // -10^DecimalPrecision - 2 + tuple := TupleDecimal{number: MakeDecimal(decNum)} + _, err := msgpack.Marshal(&tuple) + if err == nil { + t.Fatalf("It is possible to msgpack.Encoder a number unsupported by Tarantool") + } + if err.Error() != referenceErrMsg { + t.Fatalf("Incorrect error message on attempt to msgpack.Encoder number unsupported") + } +} + +func benchmarkMPEncodeDecode(b *testing.B, src decimal.Decimal, dst interface{}) { + b.ResetTimer() + + var v TupleDecimal + var buf []byte + var err error + for i := 0; i < b.N; i++ { + tuple := TupleDecimal{number: MakeDecimal(src)} + if buf, err = msgpack.Marshal(&tuple); err != nil { + b.Fatal(err) + } + if err = msgpack.Unmarshal(buf, &v); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkMPEncodeDecodeDecimal(b *testing.B) { + for _, testcase := range benchmarkSamples { + b.Run(testcase.numString, func(b *testing.B) { + dec, err := decimal.NewFromString(testcase.numString) + if err != nil { + b.Fatal(err) + } + benchmarkMPEncodeDecode(b, dec, &dec) + }) + } +} + +func BenchmarkMPEncodeDecimal(b *testing.B) { + for _, testcase := range benchmarkSamples { + b.Run(testcase.numString, func(b *testing.B) { + decNum, err := MakeDecimalFromString(testcase.numString) + if err != nil { + b.Fatal(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := msgpack.Marshal(decNum); err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkMPDecodeDecimal(b *testing.B) { + for _, testcase := range benchmarkSamples { + b.Run(testcase.numString, func(b *testing.B) { + decNum, err := MakeDecimalFromString(testcase.numString) + if err != nil { + b.Fatal(err) + } + buf, err := msgpack.Marshal(decNum) + if err != nil { + b.Fatal(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := msgpack.Unmarshal(buf, &decNum); err != nil { + b.Fatal(err) + } + } + }) + } +} + +func tupleValueIsDecimal(t *testing.T, tuples []interface{}, number decimal.Decimal) { + if len(tuples) != 1 { + t.Fatalf("Response Data len (%d) != 1", len(tuples)) + } + + if tpl, ok := tuples[0].([]interface{}); !ok { + t.Fatalf("Unexpected return value body") + } else { + if len(tpl) != 1 { + t.Fatalf("Unexpected return value body (tuple len)") + } + if val, ok := tpl[0].(Decimal); !ok || !val.Equal(number) { + t.Fatalf("Unexpected return value body (tuple 0 field)") + } + } +} + +func trimMPHeader(mpBuf []byte, fixExt bool) []byte { + mpHeaderLen := 2 + if fixExt == false { + mpHeaderLen = 3 + } + return mpBuf[mpHeaderLen:] +} + +func TestEncodeStringToBCD(t *testing.T) { + samples := correctnessSamples + samples = append(samples, rawSamples...) + samples = append(samples, benchmarkSamples...) + for _, testcase := range samples { + t.Run(testcase.numString, func(t *testing.T) { + buf, err := EncodeStringToBCD(testcase.numString) + if err != nil { + t.Fatalf("Failed to msgpack.Encoder decimal '%s' to BCD: %s", + testcase.numString, err) + } + b, _ := hex.DecodeString(testcase.mpBuf) + bcdBuf := trimMPHeader(b, testcase.fixExt) + if reflect.DeepEqual(buf, bcdBuf) != true { + t.Fatalf( + "Failed to msgpack.Encoder decimal '%s' to BCD: expected '%x', actual '%x'", + testcase.numString, bcdBuf, buf) + } + }) + } +} + +func TestDecodeStringFromBCD(t *testing.T) { + samples := correctnessSamples + samples = append(samples, correctnessDecodeSamples...) + samples = append(samples, rawSamples...) + samples = append(samples, benchmarkSamples...) + for _, testcase := range samples { + t.Run(testcase.numString, func(t *testing.T) { + b, _ := hex.DecodeString(testcase.mpBuf) + bcdBuf := trimMPHeader(b, testcase.fixExt) + s, exp, err := DecodeStringFromBCD(bcdBuf) + if err != nil { + t.Fatalf("Failed to decode BCD '%x' to decimal: %s", bcdBuf, err) + } + + decActual, err := decimal.NewFromString(s) + if exp != 0 { + decActual = decActual.Shift(int32(exp)) + } + if err != nil { + t.Fatalf("Failed to msgpack.Encoder string ('%s') to decimal", s) + } + decExpected, err := decimal.NewFromString(testcase.numString) + if err != nil { + t.Fatalf("Failed to msgpack.Encoder string ('%s') to decimal", testcase.numString) + } + if !decExpected.Equal(decActual) { + t.Fatalf( + "Decoded decimal from BCD ('%x') is incorrect: expected '%s', actual '%s'", + bcdBuf, testcase.numString, s) + } + }) + } +} + +func TestMPEncode(t *testing.T) { + samples := correctnessSamples + samples = append(samples, decimalSamples...) + samples = append(samples, benchmarkSamples...) + for _, testcase := range samples { + t.Run(testcase.numString, func(t *testing.T) { + dec, err := MakeDecimalFromString(testcase.numString) + if err != nil { + t.Fatalf("MakeDecimalFromString() failed: %s", err.Error()) + } + buf, err := msgpack.Marshal(dec) + if err != nil { + t.Fatalf("Marshalling failed: %s", err.Error()) + } + refBuf, _ := hex.DecodeString(testcase.mpBuf) + if reflect.DeepEqual(buf, refBuf) != true { + t.Fatalf("Failed to msgpack.Encoder decimal '%s', actual %x, expected %x", + testcase.numString, + buf, + refBuf) + } + }) + } +} + +func TestMPDecode(t *testing.T) { + samples := correctnessSamples + samples = append(samples, decimalSamples...) + samples = append(samples, benchmarkSamples...) + for _, testcase := range samples { + t.Run(testcase.numString, func(t *testing.T) { + mpBuf, err := hex.DecodeString(testcase.mpBuf) + if err != nil { + t.Fatalf("hex.DecodeString() failed: %s", err) + } + var v interface{} + err = msgpack.Unmarshal(mpBuf, &v) + if err != nil { + t.Fatalf("Unmsgpack.Marshalling failed: %s", err.Error()) + } + decActual, ok := v.(Decimal) + if !ok { + t.Fatalf("Unable to convert to Decimal") + } + + decExpected, err := decimal.NewFromString(testcase.numString) + if err != nil { + t.Fatalf("decimal.NewFromString() failed: %s", err.Error()) + } + if !decExpected.Equal(decActual.Decimal) { + t.Fatalf("Decoded decimal ('%s') is incorrect", testcase.mpBuf) + } + }) + } +} + +func BenchmarkEncodeStringToBCD(b *testing.B) { + for _, testcase := range benchmarkSamples { + b.Run(testcase.numString, func(b *testing.B) { + b.ResetTimer() + for n := 0; n < b.N; n++ { + EncodeStringToBCD(testcase.numString) + } + }) + } +} + +func BenchmarkDecodeStringFromBCD(b *testing.B) { + for _, testcase := range benchmarkSamples { + b.Run(testcase.numString, func(b *testing.B) { + buf, _ := hex.DecodeString(testcase.mpBuf) + bcdBuf := trimMPHeader(buf, testcase.fixExt) + b.ResetTimer() + for n := 0; n < b.N; n++ { + DecodeStringFromBCD(bcdBuf) + } + }) + } +} + +func TestSelect(t *testing.T) { + skipIfDecimalUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + number, err := decimal.NewFromString("-12.34") + if err != nil { + t.Fatalf("Failed to prepare test decimal: %s", err) + } + + ins := NewInsertRequest(space).Tuple([]interface{}{MakeDecimal(number)}) + data, err := conn.Do(ins).Get() + if err != nil { + t.Fatalf("Decimal insert failed: %s", err) + } + tupleValueIsDecimal(t, data, number) + + var offset uint32 = 0 + var limit uint32 = 1 + sel := NewSelectRequest(space). + Index(index). + Offset(offset). + Limit(limit). + Iterator(IterEq). + Key([]interface{}{MakeDecimal(number)}) + data, err = conn.Do(sel).Get() + if err != nil { + t.Fatalf("Decimal select failed: %s", err.Error()) + } + tupleValueIsDecimal(t, data, number) + + del := NewDeleteRequest(space).Index(index).Key([]interface{}{MakeDecimal(number)}) + data, err = conn.Do(del).Get() + if err != nil { + t.Fatalf("Decimal delete failed: %s", err) + } + tupleValueIsDecimal(t, data, number) +} + +func TestUnmarshal_from_decimal_new(t *testing.T) { + skipIfDecimalUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + samples := correctnessSamples + samples = append(samples, correctnessDecodeSamples...) + samples = append(samples, benchmarkSamples...) + for _, testcase := range samples { + str := testcase.numString + t.Run(str, func(t *testing.T) { + number, err := decimal.NewFromString(str) + if err != nil { + t.Fatalf("Failed to prepare test decimal: %s", err) + } + + call := NewEvalRequest("return require('decimal').new(...)"). + Args([]interface{}{str}) + data, err := conn.Do(call).Get() + if err != nil { + t.Fatalf("Decimal create failed: %s", err) + } + tupleValueIsDecimal(t, []interface{}{data}, number) + }) + } +} + +func assertInsert(t *testing.T, conn *Connection, numString string) { + number, err := decimal.NewFromString(numString) + if err != nil { + t.Fatalf("Failed to prepare test decimal: %s", err) + } + + ins := NewInsertRequest(space).Tuple([]interface{}{MakeDecimal(number)}) + data, err := conn.Do(ins).Get() + if err != nil { + t.Fatalf("Decimal insert failed: %s", err) + } + tupleValueIsDecimal(t, data, number) + + del := NewDeleteRequest(space).Index(index).Key([]interface{}{MakeDecimal(number)}) + data, err = conn.Do(del).Get() + if err != nil { + t.Fatalf("Decimal delete failed: %s", err) + } + tupleValueIsDecimal(t, data, number) +} + +func TestInsert(t *testing.T) { + skipIfDecimalUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + samples := correctnessSamples + samples = append(samples, benchmarkSamples...) + for _, testcase := range samples { + t.Run(testcase.numString, func(t *testing.T) { + assertInsert(t, conn, testcase.numString) + }) + } +} + +func TestReplace(t *testing.T) { + skipIfDecimalUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + number, err := decimal.NewFromString("-12.34") + if err != nil { + t.Fatalf("Failed to prepare test decimal: %s", err) + } + + rep := NewReplaceRequest(space).Tuple([]interface{}{MakeDecimal(number)}) + dataRep, errRep := conn.Do(rep).Get() + if errRep != nil { + t.Fatalf("Decimal replace failed: %s", errRep) + } + tupleValueIsDecimal(t, dataRep, number) + + sel := NewSelectRequest(space). + Index(index). + Limit(1). + Iterator(IterEq). + Key([]interface{}{MakeDecimal(number)}) + dataSel, errSel := conn.Do(sel).Get() + if errSel != nil { + t.Fatalf("Decimal select failed: %s", errSel) + } + tupleValueIsDecimal(t, dataSel, number) +} + +// runTestMain is a body of TestMain function +// (see https://pkg.go.dev/testing#hdr-Main). +// Using defer + os.Exit is not works so TestMain body +// is a separate function, see +// https://stackoverflow.com/questions/27629380/how-to-exit-a-go-program-honoring-deferred-calls +func runTestMain(m *testing.M) int { + isLess, err := test_helpers.IsTarantoolVersionLess(2, 2, 0) + if err != nil { + log.Fatalf("Failed to extract Tarantool version: %s", err) + } + + if isLess { + log.Println("Skipping decimal tests...") + isDecimalSupported = false + return m.Run() + } else { + isDecimalSupported = true + } + + instance, err := test_helpers.StartTarantool(test_helpers.StartOpts{ + Dialer: dialer, + InitScript: "config.lua", + Listen: server, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 10, + RetryTimeout: 500 * time.Millisecond, + }) + defer test_helpers.StopTarantoolWithCleanup(instance) + + if err != nil { + log.Printf("Failed to prepare test Tarantool: %s", err) + return 1 + } + + return m.Run() +} + +func TestMain(m *testing.M) { + code := runTestMain(m) + os.Exit(code) +} diff --git a/decimal/example_test.go b/decimal/example_test.go new file mode 100644 index 000000000..7903b077e --- /dev/null +++ b/decimal/example_test.go @@ -0,0 +1,57 @@ +// Run Tarantool instance before example execution: +// +// Terminal 1: +// $ cd decimal +// $ TEST_TNT_LISTEN=3013 TEST_TNT_WORK_DIR=$(mktemp -d -t 'tarantool.XXX') tarantool config.lua +// +// Terminal 2: +// $ go test -v example_test.go +package decimal_test + +import ( + "context" + "log" + "time" + + "github.com/tarantool/go-tarantool/v3" + . "github.com/tarantool/go-tarantool/v3/decimal" +) + +// To enable support of decimal in msgpack with +// https://github.com/shopspring/decimal, +// import tarantool/decimal submodule. +func Example() { + server := "127.0.0.1:3013" + dialer := tarantool.NetDialer{ + Address: server, + User: "test", + Password: "test", + } + opts := tarantool.Opts{ + Timeout: 5 * time.Second, + } + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + client, err := tarantool.Connect(ctx, dialer, opts) + cancel() + if err != nil { + log.Fatalf("Failed to connect: %s", err.Error()) + } + + spaceNo := uint32(524) + + number, err := MakeDecimalFromString("-22.804") + if err != nil { + log.Fatalf("Failed to prepare test decimal: %s", err) + } + + data, err := client.Do(tarantool.NewReplaceRequest(spaceNo). + Tuple([]interface{}{number}), + ).Get() + if err != nil { + log.Fatalf("Decimal replace failed: %s", err) + } + + log.Println("Decimal tuple replace") + log.Println("Error", err) + log.Println("Data", data) +} diff --git a/decimal/export_test.go b/decimal/export_test.go new file mode 100644 index 000000000..2c8fda1c7 --- /dev/null +++ b/decimal/export_test.go @@ -0,0 +1,17 @@ +package decimal + +func EncodeStringToBCD(buf string) ([]byte, error) { + return encodeStringToBCD(buf) +} + +func DecodeStringFromBCD(bcdBuf []byte) (string, int, error) { + return decodeStringFromBCD(bcdBuf) +} + +func GetNumberLength(buf string) int { + return getNumberLength(buf) +} + +const ( + DecimalPrecision = decimalPrecision +) diff --git a/decimal/fuzzing_test.go b/decimal/fuzzing_test.go new file mode 100644 index 000000000..b6c49dcd9 --- /dev/null +++ b/decimal/fuzzing_test.go @@ -0,0 +1,50 @@ +//go:build go_tarantool_decimal_fuzzing +// +build go_tarantool_decimal_fuzzing + +package decimal_test + +import ( + "testing" + + "github.com/shopspring/decimal" + + . "github.com/tarantool/go-tarantool/v3/decimal" +) + +func strToDecimal(t *testing.T, buf string, exp int) decimal.Decimal { + decNum, err := decimal.NewFromString(buf) + if err != nil { + t.Fatal(err) + } + if exp != 0 { + decNum = decNum.Shift(int32(exp)) + } + return decNum +} + +func FuzzEncodeDecodeBCD(f *testing.F) { + samples := append(correctnessSamples, benchmarkSamples...) + for _, testcase := range samples { + if len(testcase.numString) > 0 { + f.Add(testcase.numString) // Use f.Add to provide a seed corpus. + } + } + f.Fuzz(func(t *testing.T, orig string) { + if l := GetNumberLength(orig); l > DecimalPrecision { + t.Skip("max number length is exceeded") + } + bcdBuf, err := EncodeStringToBCD(orig) + if err != nil { + t.Skip("Only correct requests are interesting: %w", err) + } + + dec, exp, err := DecodeStringFromBCD(bcdBuf) + if err != nil { + t.Fatalf("Failed to decode encoded value ('%s')", orig) + } + + if !strToDecimal(t, dec, exp).Equal(strToDecimal(t, orig, 0)) { + t.Fatal("Decimal numbers are not equal") + } + }) +} diff --git a/decoder.go b/decoder.go new file mode 100644 index 000000000..d1c682889 --- /dev/null +++ b/decoder.go @@ -0,0 +1,24 @@ +package tarantool + +import ( + "io" + + "github.com/vmihailenco/msgpack/v5" +) + +func untypedMapDecoder(dec *msgpack.Decoder) (interface{}, error) { + return dec.DecodeUntypedMap() +} + +func getDecoder(r io.Reader) *msgpack.Decoder { + d := msgpack.GetDecoder() + + d.Reset(r) + d.SetMapDecoder(untypedMapDecoder) + + return d +} + +func putDecoder(dec *msgpack.Decoder) { + msgpack.PutDecoder(dec) +} diff --git a/dial.go b/dial.go new file mode 100644 index 000000000..9faeaa98f --- /dev/null +++ b/dial.go @@ -0,0 +1,649 @@ +package tarantool + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "os" + "strings" + "time" + + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" +) + +const bufSize = 128 * 1024 + +// Greeting is a message sent by Tarantool on connect. +type Greeting struct { + // Version is the supported protocol version. + Version string + // Salt is used to authenticate a user. + Salt string +} + +// writeFlusher is the interface that groups the basic Write and Flush methods. +type writeFlusher interface { + io.Writer + Flush() error +} + +// Conn is a generic stream-oriented network connection to a Tarantool +// instance. +type Conn interface { + // Read reads data from the connection. + Read(b []byte) (int, error) + // Write writes data to the connection. There may be an internal buffer for + // better performance control from a client side. + Write(b []byte) (int, error) + // Flush writes any buffered data. + Flush() error + // Close closes the connection. + // Any blocked Read or Flush operations will be unblocked and return + // errors. + Close() error + // Greeting returns server greeting. + Greeting() Greeting + // ProtocolInfo returns server protocol info. + ProtocolInfo() ProtocolInfo + // Addr returns the connection address. + Addr() net.Addr +} + +// DialOpts is a way to configure a Dial method to create a new Conn. +type DialOpts struct { + // IoTimeout is a timeout per a network read/write. + IoTimeout time.Duration +} + +// Dialer is the interface that wraps a method to connect to a Tarantool +// instance. The main idea is to provide a ready-to-work connection with +// basic preparation, successful authorization and additional checks. +// +// You can provide your own implementation to Connect() call if +// some functionality is not implemented in the connector. See NetDialer.Dial() +// implementation as example. +type Dialer interface { + // Dial connects to a Tarantool instance to the address with specified + // options. + Dial(ctx context.Context, opts DialOpts) (Conn, error) +} + +type tntConn struct { + net net.Conn + reader io.Reader + writer writeFlusher +} + +// Addr makes tntConn satisfy the Conn interface. +func (c *tntConn) Addr() net.Addr { + return c.net.RemoteAddr() +} + +// Read makes tntConn satisfy the Conn interface. +func (c *tntConn) Read(p []byte) (int, error) { + return c.reader.Read(p) +} + +// Write makes tntConn satisfy the Conn interface. +func (c *tntConn) Write(p []byte) (int, error) { + if l, err := c.writer.Write(p); err != nil { + return l, err + } else if l != len(p) { + return l, errors.New("wrong length written") + } else { + return l, nil + } +} + +// Flush makes tntConn satisfy the Conn interface. +func (c *tntConn) Flush() error { + return c.writer.Flush() +} + +// Close makes tntConn satisfy the Conn interface. +func (c *tntConn) Close() error { + return c.net.Close() +} + +// Greeting makes tntConn satisfy the Conn interface. +func (c *tntConn) Greeting() Greeting { + return Greeting{} +} + +// ProtocolInfo makes tntConn satisfy the Conn interface. +func (c *tntConn) ProtocolInfo() ProtocolInfo { + return ProtocolInfo{} +} + +// protocolConn is a wrapper for connections, so they contain the ProtocolInfo. +type protocolConn struct { + Conn + protocolInfo ProtocolInfo +} + +// ProtocolInfo returns ProtocolInfo of a protocolConn. +func (c *protocolConn) ProtocolInfo() ProtocolInfo { + return c.protocolInfo +} + +// greetingConn is a wrapper for connections, so they contain the Greeting. +type greetingConn struct { + Conn + greeting Greeting +} + +// Greeting returns Greeting of a greetingConn. +func (c *greetingConn) Greeting() Greeting { + return c.greeting +} + +type netDialer struct { + address string +} + +func (d netDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + var err error + conn := new(tntConn) + + network, address := parseAddress(d.address) + dialer := net.Dialer{} + conn.net, err = dialer.DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("failed to dial: %w", err) + } + + dc := &deadlineIO{to: opts.IoTimeout, c: conn.net} + conn.reader = bufio.NewReaderSize(dc, bufSize) + conn.writer = bufio.NewWriterSize(dc, bufSize) + + return conn, nil +} + +// NetDialer is a basic Dialer implementation. +type NetDialer struct { + // Address is an address to connect. + // It could be specified in following ways: + // + // - TCP connections (tcp://192.168.1.1:3013, tcp://my.host:3013, + // tcp:192.168.1.1:3013, tcp:my.host:3013, 192.168.1.1:3013, my.host:3013) + // + // - Unix socket, first '/' or '.' indicates Unix socket + // (unix:///abs/path/tnt.sock, unix:path/tnt.sock, /abs/path/tnt.sock, + // ./rel/path/tnt.sock, unix/:path/tnt.sock) + Address string + // Username for logging in to Tarantool. + User string + // User password for logging in to Tarantool. + Password string + // RequiredProtocol contains minimal protocol version and + // list of protocol features that should be supported by + // Tarantool server. By default, there are no restrictions. + RequiredProtocolInfo ProtocolInfo +} + +// Dial makes NetDialer satisfy the Dialer interface. +func (d NetDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + dialer := AuthDialer{ + Dialer: ProtocolDialer{ + Dialer: GreetingDialer{ + Dialer: netDialer{ + address: d.Address, + }, + }, + RequiredProtocolInfo: d.RequiredProtocolInfo, + }, + Auth: ChapSha1Auth, + Username: d.User, + Password: d.Password, + } + + return dialer.Dial(ctx, opts) +} + +type fdAddr struct { + Fd uintptr +} + +func (a fdAddr) Network() string { + return "fd" +} + +func (a fdAddr) String() string { + return fmt.Sprintf("fd://%d", a.Fd) +} + +type fdConn struct { + net.Conn + Addr fdAddr +} + +func (c *fdConn) RemoteAddr() net.Addr { + return c.Addr +} + +type fdDialer struct { + fd uintptr +} + +func (d fdDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + file := os.NewFile(d.fd, "") + c, err := net.FileConn(file) + if err != nil { + return nil, fmt.Errorf("failed to dial: %w", err) + } + + conn := new(tntConn) + conn.net = &fdConn{Conn: c, Addr: fdAddr{Fd: d.fd}} + + dc := &deadlineIO{to: opts.IoTimeout, c: conn.net} + conn.reader = bufio.NewReaderSize(dc, bufSize) + conn.writer = bufio.NewWriterSize(dc, bufSize) + + return conn, nil +} + +// FdDialer allows using an existing socket fd for connection. +type FdDialer struct { + // Fd is a socket file descriptor. + Fd uintptr + // RequiredProtocol contains minimal protocol version and + // list of protocol features that should be supported by + // Tarantool server. By default, there are no restrictions. + RequiredProtocolInfo ProtocolInfo +} + +// Dial makes FdDialer satisfy the Dialer interface. +func (d FdDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + dialer := ProtocolDialer{ + Dialer: GreetingDialer{ + Dialer: fdDialer{ + fd: d.Fd, + }, + }, + RequiredProtocolInfo: d.RequiredProtocolInfo, + } + + return dialer.Dial(ctx, opts) +} + +// AuthDialer is a dialer-wrapper that does authentication of a user. +type AuthDialer struct { + // Dialer is a base dialer. + Dialer Dialer + // Authentication options. + Auth Auth + // Username is a name of a user for authentication. + Username string + // Password is a user password for authentication. + Password string +} + +// Dial makes AuthDialer satisfy the Dialer interface. +func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + conn, err := d.Dialer.Dial(ctx, opts) + if err != nil { + return conn, err + } + + greeting := conn.Greeting() + if greeting.Salt == "" { + conn.Close() + return nil, fmt.Errorf("failed to authenticate: " + + "an invalid connection without salt") + } + + if d.Username == "" { + return conn, nil + } + + protocolAuth := conn.ProtocolInfo().Auth + if d.Auth == AutoAuth { + if protocolAuth != AutoAuth { + d.Auth = protocolAuth + } else { + d.Auth = ChapSha1Auth + } + } + + if err := authenticate(ctx, conn, d.Auth, d.Username, d.Password, + conn.Greeting().Salt); err != nil { + conn.Close() + return nil, fmt.Errorf("failed to authenticate: %w", err) + } + return conn, nil +} + +// ProtocolDialer is a dialer-wrapper that reads and fills the ProtocolInfo +// of a connection. +type ProtocolDialer struct { + // Dialer is a base dialer. + Dialer Dialer + // RequiredProtocol contains minimal protocol version and + // list of protocol features that should be supported by + // Tarantool server. By default, there are no restrictions. + RequiredProtocolInfo ProtocolInfo +} + +// Dial makes ProtocolDialer satisfy the Dialer interface. +func (d ProtocolDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + conn, err := d.Dialer.Dial(ctx, opts) + if err != nil { + return conn, err + } + + protocolConn := protocolConn{ + Conn: conn, + protocolInfo: d.RequiredProtocolInfo, + } + + protocolConn.protocolInfo, err = identify(ctx, &protocolConn) + if err != nil { + protocolConn.Close() + return nil, fmt.Errorf("failed to identify: %w", err) + } + + err = checkProtocolInfo(d.RequiredProtocolInfo, protocolConn.protocolInfo) + if err != nil { + protocolConn.Close() + return nil, fmt.Errorf("invalid server protocol: %w", err) + } + + return &protocolConn, nil +} + +// GreetingDialer is a dialer-wrapper that reads and fills the Greeting +// of a connection. +type GreetingDialer struct { + // Dialer is a base dialer. + Dialer Dialer +} + +// Dial makes GreetingDialer satisfy the Dialer interface. +func (d GreetingDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + conn, err := d.Dialer.Dial(ctx, opts) + if err != nil { + return conn, err + } + + greetingConn := greetingConn{ + Conn: conn, + } + version, salt, err := readGreeting(ctx, &greetingConn) + if err != nil { + greetingConn.Close() + return nil, fmt.Errorf("failed to read greeting: %w", err) + } + + greetingConn.greeting = Greeting{ + Version: version, + Salt: salt, + } + + return &greetingConn, err +} + +// parseAddress split address into network and address parts. +func parseAddress(address string) (string, string) { + network := "tcp" + addrLen := len(address) + + if addrLen > 0 && (address[0] == '.' || address[0] == '/') { + network = "unix" + } else if addrLen >= 7 && address[0:7] == "unix://" { + network = "unix" + address = address[7:] + } else if addrLen >= 5 && address[0:5] == "unix:" { + network = "unix" + address = address[5:] + } else if addrLen >= 6 && address[0:6] == "unix/:" { + network = "unix" + address = address[6:] + } else if addrLen >= 6 && address[0:6] == "tcp://" { + address = address[6:] + } else if addrLen >= 4 && address[0:4] == "tcp:" { + address = address[4:] + } + + return network, address +} + +// ioWaiter waits in a background until an io operation done or a context +// is expired. It closes the connection and writes a context error into the +// output channel on context expiration. +// +// A user of the helper should close the first output channel after an IO +// operation done and read an error from a second channel to get the result +// of waiting. +func ioWaiter(ctx context.Context, conn Conn) (chan<- struct{}, <-chan error) { + doneIO := make(chan struct{}) + doneWait := make(chan error, 1) + + go func() { + defer close(doneWait) + + select { + case <-ctx.Done(): + conn.Close() + <-doneIO + doneWait <- ctx.Err() + case <-doneIO: + doneWait <- nil + } + }() + + return doneIO, doneWait +} + +// readGreeting reads a greeting message. +func readGreeting(ctx context.Context, conn Conn) (string, string, error) { + var version, salt string + + doneRead, doneWait := ioWaiter(ctx, conn) + + data := make([]byte, 128) + _, err := io.ReadFull(conn, data) + + close(doneRead) + + if err == nil { + version = bytes.NewBuffer(data[:64]).String() + salt = bytes.NewBuffer(data[64:108]).String() + } + + if waitErr := <-doneWait; waitErr != nil { + err = waitErr + } + + return version, salt, err +} + +// identify sends info about client protocol, receives info +// about server protocol in response and stores it in the connection. +func identify(ctx context.Context, conn Conn) (ProtocolInfo, error) { + var info ProtocolInfo + + req := NewIdRequest(clientProtocolInfo) + if err := writeRequest(ctx, conn, req); err != nil { + return info, err + } + + resp, err := readResponse(ctx, conn, req) + if err != nil { + if resp != nil && + resp.Header().Error == iproto.ER_UNKNOWN_REQUEST_TYPE { + // IPROTO_ID requests are not supported by server. + return info, nil + } + return info, err + } + data, err := resp.Decode() + if err != nil { + return info, err + } + + if len(data) == 0 { + return info, errors.New("unexpected response: no data") + } + + info, ok := data[0].(ProtocolInfo) + if !ok { + return info, errors.New("unexpected response: wrong data") + } + + return info, nil +} + +// checkProtocolInfo checks that required protocol version is +// and protocol features are supported. +func checkProtocolInfo(required ProtocolInfo, actual ProtocolInfo) error { + if required.Version > actual.Version { + return fmt.Errorf("protocol version %d is not supported", + required.Version) + } + + // It seems that iterating over a small list is way faster + // than building a map: https://stackoverflow.com/a/52710077/11646599 + var missed []string + for _, requiredFeature := range required.Features { + found := false + for _, actualFeature := range actual.Features { + if requiredFeature == actualFeature { + found = true + } + } + if !found { + missed = append(missed, requiredFeature.String()) + } + } + + switch { + case len(missed) == 1: + return fmt.Errorf("protocol feature %s is not supported", missed[0]) + case len(missed) > 1: + joined := strings.Join(missed, ", ") + return fmt.Errorf("protocol features %s are not supported", joined) + default: + return nil + } +} + +// authenticate authenticates for a connection. +func authenticate(ctx context.Context, c Conn, auth Auth, user, pass, salt string) error { + var req Request + var err error + + switch auth { + case ChapSha1Auth: + req, err = newChapSha1AuthRequest(user, pass, salt) + if err != nil { + return err + } + case PapSha256Auth: + req = newPapSha256AuthRequest(user, pass) + default: + return errors.New("unsupported method " + auth.String()) + } + + if err = writeRequest(ctx, c, req); err != nil { + return err + } + if _, err = readResponse(ctx, c, req); err != nil { + return err + } + return nil +} + +// writeRequest writes a request to the writer. +func writeRequest(ctx context.Context, conn Conn, req Request) error { + var packet smallWBuf + err := pack(&packet, msgpack.NewEncoder(&packet), 0, req, ignoreStreamId, nil) + + if err != nil { + return fmt.Errorf("pack error: %w", err) + } + + doneWrite, doneWait := ioWaiter(ctx, conn) + + _, err = conn.Write(packet.b) + + close(doneWrite) + + if waitErr := <-doneWait; waitErr != nil { + err = waitErr + } + + if err != nil { + return fmt.Errorf("write error: %w", err) + } + + doneWrite, doneWait = ioWaiter(ctx, conn) + + err = conn.Flush() + + close(doneWrite) + + if waitErr := <-doneWait; waitErr != nil { + err = waitErr + } + + if err != nil { + return fmt.Errorf("flush error: %w", err) + } + + if waitErr := <-doneWait; waitErr != nil { + err = waitErr + } + + return err +} + +// readResponse reads a response from the reader. +func readResponse(ctx context.Context, conn Conn, req Request) (Response, error) { + var lenbuf [packetLengthBytes]byte + + doneRead, doneWait := ioWaiter(ctx, conn) + + respBytes, err := read(conn, lenbuf[:]) + + close(doneRead) + + if waitErr := <-doneWait; waitErr != nil { + err = waitErr + } + + if err != nil { + return nil, fmt.Errorf("read error: %w", err) + } + + buf := smallBuf{b: respBytes} + + d := getDecoder(&buf) + defer putDecoder(d) + + header, _, err := decodeHeader(d, &buf) + if err != nil { + return nil, fmt.Errorf("decode response header error: %w", err) + } + + resp, err := req.Response(header, &buf) + if err != nil { + return nil, fmt.Errorf("creating response error: %w", err) + } + + _, err = resp.Decode() + if err != nil { + switch err.(type) { + case Error: + return resp, err + default: + return resp, fmt.Errorf("decode response body error: %w", err) + } + } + + return resp, nil +} diff --git a/dial_test.go b/dial_test.go new file mode 100644 index 000000000..ad55bb95a --- /dev/null +++ b/dial_test.go @@ -0,0 +1,1083 @@ +package tarantool_test + +import ( + "bytes" + "context" + "encoding/base64" + "errors" + "fmt" + "net" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tarantool/go-iproto" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +type mockErrorDialer struct { + err error +} + +func (m mockErrorDialer) Dial(ctx context.Context, + opts tarantool.DialOpts) (tarantool.Conn, error) { + return nil, m.err +} + +func TestDialer_Dial_error(t *testing.T) { + const errMsg = "any msg" + dialer := mockErrorDialer{ + err: errors.New(errMsg), + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{}) + assert.Nil(t, conn) + assert.ErrorContains(t, err, errMsg) +} + +type mockPassedDialer struct { + ctx context.Context + opts tarantool.DialOpts +} + +func (m *mockPassedDialer) Dial(ctx context.Context, + opts tarantool.DialOpts) (tarantool.Conn, error) { + m.opts = opts + if ctx != m.ctx { + return nil, errors.New("wrong context") + } + return nil, errors.New("does not matter") +} + +func TestDialer_Dial_passedOpts(t *testing.T) { + dialer := &mockPassedDialer{} + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + dialer.ctx = ctx + + dialOpts := tarantool.DialOpts{ + IoTimeout: opts.Timeout, + } + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{ + Timeout: opts.Timeout, + }) + + assert.Nil(t, conn) + assert.NotNil(t, err) + assert.NotEqual(t, err.Error(), "wrong context") + assert.Equal(t, dialOpts, dialer.opts) +} + +type mockIoConn struct { + addr net.Addr + // Addr calls count. + addrCnt int + // Sends an event on Read()/Write()/Flush(). + read, written chan struct{} + // Read()/Write() buffers. + readbuf, writebuf bytes.Buffer + // Calls readWg/writeWg.Wait() in Read()/Flush(). + readWg, writeWg sync.WaitGroup + // wgDoneOnClose call Done() on the wait groups on Close(). + wgDoneOnClose bool + // How many times to wait before a wg.Wait() call. + readWgDelay, writeWgDelay int + // Write()/Read()/Flush()/Close() calls count. + writeCnt, readCnt, flushCnt, closeCnt int + // Greeting()/ProtocolInfo() calls count. + greetingCnt, infoCnt int + // Value for Greeting(). + greeting tarantool.Greeting + // Value for ProtocolInfo(). + info tarantool.ProtocolInfo +} + +func (m *mockIoConn) Read(b []byte) (int, error) { + m.readCnt++ + if m.readWgDelay == 0 { + m.readWg.Wait() + } + m.readWgDelay-- + + ret, err := m.readbuf.Read(b) + + if ret != 0 && m.read != nil { + m.read <- struct{}{} + } + + return ret, err +} + +func (m *mockIoConn) Write(b []byte) (int, error) { + m.writeCnt++ + if m.writeWgDelay == 0 { + m.writeWg.Wait() + } + m.writeWgDelay-- + + ret, err := m.writebuf.Write(b) + + if m.written != nil { + m.written <- struct{}{} + } + + return ret, err +} + +func (m *mockIoConn) Flush() error { + m.flushCnt++ + return nil +} + +func (m *mockIoConn) Close() error { + if m.wgDoneOnClose { + m.readWg.Done() + m.writeWg.Done() + m.wgDoneOnClose = false + } + + m.closeCnt++ + return nil +} + +func (m *mockIoConn) Greeting() tarantool.Greeting { + m.greetingCnt++ + return m.greeting +} + +func (m *mockIoConn) ProtocolInfo() tarantool.ProtocolInfo { + m.infoCnt++ + return m.info +} + +func (m *mockIoConn) Addr() net.Addr { + m.addrCnt++ + return m.addr +} + +type mockIoDialer struct { + init func(conn *mockIoConn) + conn *mockIoConn +} + +func newMockIoConn() *mockIoConn { + conn := new(mockIoConn) + conn.readWg.Add(1) + conn.writeWg.Add(1) + conn.wgDoneOnClose = true + return conn +} + +func (m *mockIoDialer) Dial(ctx context.Context, opts tarantool.DialOpts) (tarantool.Conn, error) { + m.conn = newMockIoConn() + if m.init != nil { + m.init(m.conn) + } + return m.conn, nil +} + +func dialIo(t *testing.T, + init func(conn *mockIoConn)) (*tarantool.Connection, mockIoDialer) { + t.Helper() + + dialer := mockIoDialer{ + init: init, + } + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := tarantool.Connect(ctx, &dialer, + tarantool.Opts{ + Timeout: 1000 * time.Second, // Avoid pings. + SkipSchema: true, + }) + require.Nil(t, err) + require.NotNil(t, conn) + + return conn, dialer +} + +func TestConn_Close(t *testing.T) { + conn, dialer := dialIo(t, nil) + conn.Close() + + assert.Equal(t, 1, dialer.conn.closeCnt) +} + +type stubAddr struct { + str string +} + +func (a stubAddr) String() string { + return a.str +} + +func (a stubAddr) Network() string { + return "stub" +} + +func TestConn_Addr(t *testing.T) { + const addr = "any" + conn, dialer := dialIo(t, func(conn *mockIoConn) { + conn.addr = stubAddr{str: addr} + }) + defer func() { + conn.Close() + }() + + assert.Equal(t, addr, conn.Addr().String()) + assert.Equal(t, 1, dialer.conn.addrCnt) +} + +func TestConn_Greeting(t *testing.T) { + greeting := tarantool.Greeting{ + Version: "any", + Salt: "salt", + } + conn, dialer := dialIo(t, func(conn *mockIoConn) { + conn.greeting = greeting + }) + defer func() { + conn.Close() + }() + + assert.Equal(t, &greeting, conn.Greeting) + assert.Equal(t, 1, dialer.conn.greetingCnt) +} + +func TestConn_ProtocolInfo(t *testing.T) { + info := tarantool.ProtocolInfo{ + Auth: tarantool.ChapSha1Auth, + Version: 33, + Features: []iproto.Feature{ + iproto.IPROTO_FEATURE_ERROR_EXTENSION, + }, + } + conn, dialer := dialIo(t, func(conn *mockIoConn) { + conn.info = info + }) + defer func() { + conn.Close() + }() + + assert.Equal(t, info, conn.ProtocolInfo()) + assert.Equal(t, 1, dialer.conn.infoCnt) +} + +func TestConn_ReadWrite(t *testing.T) { + conn, dialer := dialIo(t, func(conn *mockIoConn) { + conn.read = make(chan struct{}) + conn.written = make(chan struct{}) + conn.writeWgDelay = 1 + conn.readbuf.Write([]byte{ + 0xce, 0x00, 0x00, 0x00, 0x0a, // Length. + 0x82, // Header map. + 0x00, 0x00, + 0x01, 0xce, 0x00, 0x00, 0x00, 0x02, + 0x80, // Body map. + }) + conn.wgDoneOnClose = false + }) + defer func() { + dialer.conn.writeWg.Done() + conn.Close() + }() + + fut := conn.Do(tarantool.NewPingRequest()) + + <-dialer.conn.written + dialer.conn.written = nil + + dialer.conn.readWg.Done() + <-dialer.conn.read + <-dialer.conn.read + + assert.Equal(t, []byte{ + 0xce, 0x00, 0x00, 0x00, 0xa, // Length. + 0x82, // Header map. + 0x00, 0x40, + 0x01, 0xce, 0x00, 0x00, 0x00, 0x02, + 0x80, // Empty map. + }, dialer.conn.writebuf.Bytes()) + + _, err := fut.Get() + assert.Nil(t, err) +} + +func TestConn_ContextCancel(t *testing.T) { + dialer := tarantool.NetDialer{Address: "127.0.0.1:8080"} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + + assert.Nil(t, conn) + assert.NotNil(t, err) + assert.Truef(t, strings.Contains(err.Error(), "operation was canceled"), + fmt.Sprintf("unexpected error, expected to contain %s, got %v", + "operation was canceled", err)) +} + +func genSalt() [64]byte { + salt := [64]byte{} + for i := 0; i < 44; i++ { + salt[i] = 'a' + } + return salt +} + +var ( + testDialUser = "test" + testDialPass = "test" + testDialVersion = [64]byte{'t', 'e', 's', 't'} + + // Salt with end zeros. + testDialSalt = genSalt() + + idRequestExpected = []byte{ + 0xce, 0x00, 0x00, 0x00, 31, // Length. + 0x82, // Header map. + 0x00, 0x49, + 0x01, 0xce, 0x00, 0x00, 0x00, 0x00, + + 0x82, // Data map. + 0x54, + 0xcf, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, // Version. + 0x55, + 0x99, // Fixed arrау with 9 elements. + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x0b, 0x0c, // Features (9 elements). + } + + idResponseTyped = tarantool.ProtocolInfo{ + Version: 6, + Features: []iproto.Feature{iproto.Feature(1), iproto.Feature(21)}, + Auth: tarantool.ChapSha1Auth, + } + + idResponse = []byte{ + 0xce, 0x00, 0x00, 0x00, 37, // Length. + 0x83, // Header map. + 0x00, 0xce, 0x00, 0x00, 0x00, 0x00, + 0x01, 0xce, 0x00, 0x00, 0x00, 0x00, + 0x05, 0xce, 0x00, 0x00, 0x00, 0x61, + + 0x83, // Data map. + 0x54, + 0x06, // Version. + 0x55, + 0x92, 0x01, 0x15, // Features. + 0x5b, + 0xa9, 'c', 'h', 'a', 'p', '-', 's', 'h', 'a', '1', + } + + idResponseNotSupported = []byte{ + 0xce, 0x00, 0x00, 0x00, 25, // Length. + 0x83, // Header map. + 0x00, 0xce, 0x00, 0x00, 0x80, 0x30, + 0x01, 0xce, 0x00, 0x00, 0x00, 0x00, + 0x05, 0xce, 0x00, 0x00, 0x00, 0x61, + 0x81, + 0x31, + 0xa3, 'e', 'r', 'r', + } + + authRequestExpectedChapSha1 = []byte{ + 0xce, 0x00, 0x00, 0x00, 57, // Length. + 0x82, // Header map. + 0x00, 0x07, + 0x01, 0xce, 0x00, 0x00, 0x00, 0x00, + + 0x82, // Data map. + 0xce, 0x00, 0x00, 0x00, 0x23, + 0xa4, 't', 'e', 's', 't', // Login. + 0xce, 0x00, 0x00, 0x00, 0x21, + 0x92, // Tuple. + 0xa9, 'c', 'h', 'a', 'p', '-', 's', 'h', 'a', '1', + + // Scramble. + 0xb4, 0x1b, 0xd4, 0x20, 0x45, 0x73, 0x22, + 0xcf, 0xab, 0x05, 0x03, 0xf3, 0x89, 0x4b, + 0xfe, 0xc7, 0x24, 0x5a, 0xe6, 0xe8, 0x31, + } + + authRequestExpectedPapSha256 = []byte{ + 0xce, 0x00, 0x00, 0x00, 0x2a, // Length. + 0x82, // Header map. + 0x00, 0x07, + 0x01, 0xce, 0x00, 0x00, 0x00, 0x00, + + 0x82, // Data map. + 0xce, 0x00, 0x00, 0x00, 0x23, + 0xa4, 't', 'e', 's', 't', // Login. + 0xce, 0x00, 0x00, 0x00, 0x21, + 0x92, // Tuple. + 0xaa, 'p', 'a', 'p', '-', 's', 'h', 'a', '2', '5', '6', + 0xa4, 't', 'e', 's', 't', + } + + okResponse = []byte{ + 0xce, 0x00, 0x00, 0x00, 19, // Length. + 0x83, // Header map. + 0x00, 0xce, 0x00, 0x00, 0x00, 0x00, + 0x01, 0xce, 0x00, 0x00, 0x00, 0x00, + 0x05, 0xce, 0x00, 0x00, 0x00, 0x61, + } + + errResponse = []byte{0xce} +) + +type testDialOpts struct { + name string + wantErr bool + expectedErr string + expectedProtocolInfo tarantool.ProtocolInfo + + // These options configure the behavior of the server. + isErrGreeting bool + isErrId bool + isIdUnsupported bool + isErrAuth bool + isEmptyAuth bool +} + +type dialServerActual struct { + IdRequest []byte + AuthRequest []byte +} + +func testDialAccept(opts testDialOpts, l net.Listener) chan dialServerActual { + ch := make(chan dialServerActual, 1) + + go func() { + client, err := l.Accept() + if err != nil { + return + } + defer client.Close() + if opts.isErrGreeting { + client.Write(errResponse) + return + } else { + // Write greeting. + client.Write(testDialVersion[:]) + client.Write(testDialSalt[:]) + } + + // Read Id request. + idRequestActual := make([]byte, len(idRequestExpected)) + client.Read(idRequestActual) + + // Make Id response. + if opts.isErrId { + client.Write(errResponse) + } else if opts.isIdUnsupported { + client.Write(idResponseNotSupported) + } else { + client.Write(idResponse) + } + + // Read Auth request. + authRequestExpected := authRequestExpectedChapSha1 + if opts.isEmptyAuth { + authRequestExpected = []byte{} + } + authRequestActual := make([]byte, len(authRequestExpected)) + client.Read(authRequestActual) + + // Make Auth response. + if opts.isErrAuth { + client.Write(errResponse) + } else { + client.Write(okResponse) + } + ch <- dialServerActual{ + IdRequest: idRequestActual, + AuthRequest: authRequestActual, + } + }() + + return ch +} + +func testDialer(t *testing.T, l net.Listener, dialer tarantool.Dialer, + opts testDialOpts) { + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + ch := testDialAccept(opts, l) + conn, err := dialer.Dial(ctx, tarantool.DialOpts{ + IoTimeout: time.Second * 2, + }) + if opts.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), opts.expectedErr) + return + } + require.NoError(t, err) + require.Equal(t, opts.expectedProtocolInfo, conn.ProtocolInfo()) + require.Equal(t, testDialVersion[:], []byte(conn.Greeting().Version)) + require.Equal(t, testDialSalt[:44], []byte(conn.Greeting().Salt)) + + actual := <-ch + require.Equal(t, idRequestExpected, actual.IdRequest) + + authRequestExpected := authRequestExpectedChapSha1 + if opts.isEmptyAuth { + authRequestExpected = []byte{} + } + require.Equal(t, authRequestExpected, actual.AuthRequest) + conn.Close() +} + +func TestNetDialer_Dial(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer l.Close() + dialer := tarantool.NetDialer{ + Address: l.Addr().String(), + User: testDialUser, + Password: testDialPass, + } + cases := []testDialOpts{ + { + name: "all is ok", + expectedProtocolInfo: idResponseTyped.Clone(), + }, + { + name: "id request unsupported", + expectedProtocolInfo: tarantool.ProtocolInfo{}, + isIdUnsupported: true, + }, + { + name: "greeting response error", + wantErr: true, + expectedErr: "failed to read greeting", + isErrGreeting: true, + }, + { + name: "id response error", + wantErr: true, + expectedErr: "failed to identify", + isErrId: true, + }, + { + name: "auth response error", + wantErr: true, + expectedErr: "failed to authenticate", + isErrAuth: true, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + testDialer(t, l, dialer, tc) + }) + } +} + +func TestNetDialer_Dial_hang_connection(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer l.Close() + + dialer := tarantool.NetDialer{ + Address: l.Addr().String(), + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + + require.Nil(t, conn) + require.Error(t, err, context.DeadlineExceeded) +} + +func TestNetDialer_Dial_requirements(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer l.Close() + dialer := tarantool.NetDialer{ + Address: l.Addr().String(), + User: testDialUser, + Password: testDialPass, + RequiredProtocolInfo: tarantool.ProtocolInfo{ + Features: []iproto.Feature{42}, + }, + } + testDialAccept(testDialOpts{}, l) + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + if err == nil { + conn.Close() + } + require.Error(t, err) + require.Contains(t, err.Error(), "invalid server protocol") +} + +func TestFdDialer_Dial(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer l.Close() + addr := l.Addr().String() + + cases := []testDialOpts{ + { + name: "all is ok", + expectedProtocolInfo: idResponseTyped.Clone(), + isEmptyAuth: true, + }, + { + name: "id request unsupported", + expectedProtocolInfo: tarantool.ProtocolInfo{}, + isIdUnsupported: true, + isEmptyAuth: true, + }, + { + name: "greeting response error", + wantErr: true, + expectedErr: "failed to read greeting", + isErrGreeting: true, + }, + { + name: "id response error", + wantErr: true, + expectedErr: "failed to identify", + isErrId: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + sock, err := net.Dial("tcp", addr) + require.NoError(t, err) + defer sock.Close() + + // It seems that the file descriptor is not always fully ready + // after the connection is created. These lines help to avoid the + // "bad file descriptor" errors. + // + // We already tried to use the SyscallConn(), but it has the same + // issue. + time.Sleep(time.Millisecond) + sock.(*net.TCPConn).SetLinger(0) + + f, err := sock.(*net.TCPConn).File() + require.NoError(t, err) + defer f.Close() + + dialer := tarantool.FdDialer{ + Fd: f.Fd(), + } + testDialer(t, l, dialer, tc) + }) + } +} + +func TestFdDialer_Dial_requirements(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer l.Close() + addr := l.Addr().String() + + sock, err := net.Dial("tcp", addr) + require.NoError(t, err) + defer sock.Close() + + // It seems that the file descriptor is not always fully ready after the + // connection is created. These lines help to avoid the + // "bad file descriptor" errors. + // + // We already tried to use the SyscallConn(), but it has the same + // issue. + time.Sleep(time.Millisecond) + sock.(*net.TCPConn).SetLinger(0) + + f, err := sock.(*net.TCPConn).File() + require.NoError(t, err) + defer f.Close() + + dialer := tarantool.FdDialer{ + Fd: f.Fd(), + RequiredProtocolInfo: tarantool.ProtocolInfo{ + Features: []iproto.Feature{42}, + }, + } + + testDialAccept(testDialOpts{}, l) + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + if err == nil { + conn.Close() + } + require.Error(t, err) + require.Contains(t, err.Error(), "invalid server protocol") +} + +func TestAuthDialer_Dial_DialerError(t *testing.T) { + dialer := tarantool.AuthDialer{ + Dialer: mockErrorDialer{ + err: fmt.Errorf("some error"), + }, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + if conn != nil { + conn.Close() + } + + assert.NotNil(t, err) + assert.EqualError(t, err, "some error") +} + +func TestAuthDialer_Dial_NoSalt(t *testing.T) { + dialer := tarantool.AuthDialer{ + Dialer: &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.greeting = tarantool.Greeting{ + Salt: "", + } + }, + }, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + + assert.NotNil(t, err) + assert.ErrorContains(t, err, "an invalid connection without salt") + if conn != nil { + conn.Close() + t.Errorf("connection is not nil") + } +} + +func TestConn_AuthDialer_hang_connection(t *testing.T) { + salt := fmt.Sprintf("%s", testDialSalt) + salt = base64.StdEncoding.EncodeToString([]byte(salt)) + mock := &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.greeting.Salt = salt + conn.readWgDelay = 0 + conn.writeWgDelay = 0 + }, + } + dialer := tarantool.AuthDialer{ + Dialer: mock, + Username: "test", + Password: "test", + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + conn, err := tarantool.Connect(ctx, &dialer, + tarantool.Opts{ + Timeout: 1000 * time.Second, // Avoid pings. + SkipSchema: true, + }) + + require.Nil(t, conn) + require.Error(t, err, context.DeadlineExceeded) + require.Equal(t, mock.conn.writeCnt, 1) + require.Equal(t, mock.conn.readCnt, 0) + require.Greater(t, mock.conn.closeCnt, 1) +} + +func TestAuthDialer_Dial(t *testing.T) { + salt := fmt.Sprintf("%s", testDialSalt) + salt = base64.StdEncoding.EncodeToString([]byte(salt)) + dialer := mockIoDialer{ + init: func(conn *mockIoConn) { + conn.greeting.Salt = salt + conn.writeWgDelay = 1 + conn.readWgDelay = 2 + conn.readbuf.Write(okResponse) + conn.wgDoneOnClose = false + }, + } + defer func() { + dialer.conn.writeWg.Done() + }() + + authDialer := tarantool.AuthDialer{ + Dialer: &dialer, + Username: "test", + Password: "test", + } + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := authDialer.Dial(ctx, tarantool.DialOpts{}) + if conn != nil { + conn.Close() + } + + assert.NoError(t, err) + assert.NotNil(t, conn) + assert.Equal(t, authRequestExpectedChapSha1[:41], dialer.conn.writebuf.Bytes()[:41]) +} + +func TestAuthDialer_Dial_PapSha256Auth(t *testing.T) { + salt := fmt.Sprintf("%s", testDialSalt) + salt = base64.StdEncoding.EncodeToString([]byte(salt)) + dialer := mockIoDialer{ + init: func(conn *mockIoConn) { + conn.greeting.Salt = salt + conn.writeWgDelay = 1 + conn.readWgDelay = 2 + conn.readbuf.Write(okResponse) + conn.wgDoneOnClose = false + }, + } + defer func() { + dialer.conn.writeWg.Done() + }() + + authDialer := tarantool.AuthDialer{ + Dialer: &dialer, + Username: "test", + Password: "test", + Auth: tarantool.PapSha256Auth, + } + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := authDialer.Dial(ctx, tarantool.DialOpts{}) + if conn != nil { + conn.Close() + } + + assert.NoError(t, err) + assert.NotNil(t, conn) + assert.Equal(t, authRequestExpectedPapSha256[:41], dialer.conn.writebuf.Bytes()[:41]) +} + +func TestProtocolDialer_Dial_DialerError(t *testing.T) { + dialer := tarantool.ProtocolDialer{ + Dialer: mockErrorDialer{ + err: fmt.Errorf("some error"), + }, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + if conn != nil { + conn.Close() + } + + assert.NotNil(t, err) + assert.EqualError(t, err, "some error") +} + +func TestConn_ProtocolDialer_hang_connection(t *testing.T) { + mock := &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.info = tarantool.ProtocolInfo{Version: 1} + conn.readWgDelay = 0 + conn.writeWgDelay = 0 + }, + } + dialer := tarantool.ProtocolDialer{ + Dialer: mock, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + conn, err := tarantool.Connect(ctx, &dialer, + tarantool.Opts{ + Timeout: 1000 * time.Second, // Avoid pings. + SkipSchema: true, + }) + + require.Nil(t, conn) + require.Error(t, err, context.DeadlineExceeded) + require.Equal(t, mock.conn.writeCnt, 1) + require.Equal(t, mock.conn.readCnt, 0) + require.Greater(t, mock.conn.closeCnt, 1) +} + +func TestProtocolDialer_Dial_IdentifyFailed(t *testing.T) { + dialer := tarantool.ProtocolDialer{ + Dialer: &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.info = tarantool.ProtocolInfo{Version: 1} + conn.writeWgDelay = 1 + conn.readWgDelay = 2 + conn.readbuf.Write(errResponse) + }, + }, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + + assert.NotNil(t, err) + assert.ErrorContains(t, err, "failed to identify") + if conn != nil { + conn.Close() + t.Errorf("connection is not nil") + } +} + +func TestProtocolDialer_Dial_WrongInfo(t *testing.T) { + dialer := tarantool.ProtocolDialer{ + Dialer: &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.info = tarantool.ProtocolInfo{Version: 1} + conn.writeWgDelay = 1 + conn.readWgDelay = 2 + conn.readbuf.Write(idResponse) + }, + }, + RequiredProtocolInfo: validProtocolInfo, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + + assert.NotNil(t, err) + assert.ErrorContains(t, err, "invalid server protocol") + if conn != nil { + conn.Close() + t.Errorf("connection is not nil") + } +} + +func TestProtocolDialer_Dial(t *testing.T) { + protoInfo := tarantool.ProtocolInfo{ + Auth: tarantool.ChapSha1Auth, + Version: 6, + Features: []iproto.Feature{0x01, 0x15}, + } + + dialer := tarantool.ProtocolDialer{ + Dialer: &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.info = tarantool.ProtocolInfo{Version: 1} + conn.writeWgDelay = 1 + conn.readWgDelay = 2 + conn.readbuf.Write(idResponse) + }, + }, + RequiredProtocolInfo: protoInfo, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + if conn != nil { + conn.Close() + } + + assert.NoError(t, err) + assert.NotNil(t, conn) + assert.Equal(t, protoInfo, conn.ProtocolInfo()) +} + +func TestGreetingDialer_Dial_DialerError(t *testing.T) { + dialer := tarantool.GreetingDialer{ + Dialer: mockErrorDialer{ + err: fmt.Errorf("some error"), + }, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + if conn != nil { + conn.Close() + } + + assert.NotNil(t, err) + assert.EqualError(t, err, "some error") +} + +func TestConn_GreetingDialer_hang_connection(t *testing.T) { + mock := &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.readWgDelay = 0 + }, + } + dialer := tarantool.GreetingDialer{ + Dialer: mock, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + conn, err := tarantool.Connect(ctx, &dialer, + tarantool.Opts{ + Timeout: 1000 * time.Second, // Avoid pings. + SkipSchema: true, + }) + + require.Nil(t, conn) + require.Error(t, err, context.DeadlineExceeded) + require.Equal(t, mock.conn.readCnt, 1) + require.Greater(t, mock.conn.closeCnt, 1) +} + +func TestGreetingDialer_Dial_GreetingFailed(t *testing.T) { + dialer := tarantool.GreetingDialer{ + Dialer: &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.writeWgDelay = 1 + conn.readWgDelay = 2 + conn.readbuf.Write(errResponse) + }, + }, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + + assert.NotNil(t, err) + assert.ErrorContains(t, err, "failed to read greeting") + if conn != nil { + conn.Close() + t.Errorf("connection is not nil") + } +} + +func TestGreetingDialer_Dial(t *testing.T) { + dialer := tarantool.GreetingDialer{ + Dialer: &mockIoDialer{ + init: func(conn *mockIoConn) { + conn.info = tarantool.ProtocolInfo{Version: 1} + conn.writeWgDelay = 1 + conn.readWgDelay = 3 + conn.readbuf.Write(append(testDialVersion[:], testDialSalt[:]...)) + conn.readbuf.Write(idResponse) + }, + }, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, tarantool.DialOpts{}) + if conn != nil { + conn.Close() + } + + assert.NoError(t, err) + assert.NotNil(t, conn) + assert.Equal(t, string(testDialVersion[:]), conn.Greeting().Version) + assert.Equal(t, string(testDialSalt[:44]), conn.Greeting().Salt) +} diff --git a/errors.go b/errors.go new file mode 100644 index 000000000..60f71a2b1 --- /dev/null +++ b/errors.go @@ -0,0 +1,64 @@ +package tarantool + +import ( + "fmt" + + "github.com/tarantool/go-iproto" +) + +// Error is wrapper around error returned by Tarantool. +type Error struct { + Code iproto.Error + Msg string + ExtendedInfo *BoxError +} + +// Error converts an Error to a string. +func (tnterr Error) Error() string { + if tnterr.ExtendedInfo != nil { + return tnterr.ExtendedInfo.Error() + } + + return fmt.Sprintf("%s (0x%x)", tnterr.Msg, tnterr.Code) +} + +// ClientError is connection error produced by this client, +// i.e. connection failures or timeouts. +type ClientError struct { + Code uint32 + Msg string +} + +// Error converts a ClientError to a string. +func (clierr ClientError) Error() string { + return fmt.Sprintf("%s (0x%x)", clierr.Msg, clierr.Code) +} + +// Temporary returns true if next attempt to perform request may succeeded. +// +// Currently it returns true when: +// +// - Connection is not connected at the moment +// +// - request is timeouted +// +// - request is aborted due to rate limit +func (clierr ClientError) Temporary() bool { + switch clierr.Code { + case ErrConnectionNotReady, ErrTimeouted, ErrRateLimited, ErrIoError: + return true + default: + return false + } +} + +// Tarantool client error codes. +const ( + ErrConnectionNotReady = 0x4000 + iota + ErrConnectionClosed = 0x4000 + iota + ErrProtocolError = 0x4000 + iota + ErrTimeouted = 0x4000 + iota + ErrRateLimited = 0x4000 + iota + ErrConnectionShutdown = 0x4000 + iota + ErrIoError = 0x4000 + iota +) diff --git a/example_custom_unpacking_test.go b/example_custom_unpacking_test.go new file mode 100644 index 000000000..d8c790a25 --- /dev/null +++ b/example_custom_unpacking_test.go @@ -0,0 +1,145 @@ +package tarantool_test + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +type Tuple2 struct { + Cid uint + Orig string + Members []Member +} + +// Same effect in a "magic" way, but slower. +type Tuple3 struct { + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + + Cid uint + Orig string + Members []Member +} + +func (c *Tuple2) EncodeMsgpack(e *msgpack.Encoder) error { + if err := e.EncodeArrayLen(3); err != nil { + return err + } + if err := e.EncodeUint(uint64(c.Cid)); err != nil { + return err + } + if err := e.EncodeString(c.Orig); err != nil { + return err + } + e.Encode(c.Members) + return nil +} + +func (c *Tuple2) DecodeMsgpack(d *msgpack.Decoder) error { + var err error + var l int + if l, err = d.DecodeArrayLen(); err != nil { + return err + } + if l != 3 { + return fmt.Errorf("array len doesn't match: %d", l) + } + if c.Cid, err = d.DecodeUint(); err != nil { + return err + } + if c.Orig, err = d.DecodeString(); err != nil { + return err + } + if l, err = d.DecodeArrayLen(); err != nil { + return err + } + c.Members = make([]Member, l) + for i := 0; i < l; i++ { + d.Decode(&c.Members[i]) + } + return nil +} + +// Example demonstrates how to use custom (un)packing with typed selects and +// function calls. +// +// You can specify user-defined packing/unpacking functions for your types. +// This allows you to store complex structures within a tuple and may speed up +// your requests. +// +// Alternatively, you can just instruct the msgpack library to encode your +// structure as an array. This is safe "magic". It is easier to implement than +// a custom packer/unpacker, but it will work slower. +func Example_customUnpacking() { + // Establish a connection. + + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3013", + User: "test", + Password: "test", + } + opts := tarantool.Opts{} + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + conn, err := tarantool.Connect(ctx, dialer, opts) + cancel() + if err != nil { + log.Fatalf("Failed to connect: %s", err.Error()) + } + + spaceNo := uint32(617) + indexNo := uint32(0) + + tuple := Tuple2{Cid: 777, Orig: "orig", Members: []Member{{"lol", "", 1}, {"wut", "", 3}}} + // Insert a structure itself. + initReq := tarantool.NewReplaceRequest(spaceNo).Tuple(&tuple) + data, err := conn.Do(initReq).Get() + if err != nil { + log.Fatalf("Failed to insert: %s", err.Error()) + return + } + fmt.Println("Data", data) + + var tuples1 []Tuple2 + selectReq := tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(tarantool.IterEq). + Key([]interface{}{777}) + err = conn.Do(selectReq).GetTyped(&tuples1) + if err != nil { + log.Fatalf("Failed to SelectTyped: %s", err.Error()) + return + } + fmt.Println("Tuples (tuples1)", tuples1) + + // Same result in a "magic" way. + var tuples2 []Tuple3 + err = conn.Do(selectReq).GetTyped(&tuples2) + if err != nil { + log.Fatalf("Failed to SelectTyped: %s", err.Error()) + return + } + fmt.Println("Tuples (tuples2):", tuples2) + + // Call a function "func_name" returning a table of custom tuples. + var tuples3 [][]Tuple3 + callReq := tarantool.NewCallRequest("func_name") + err = conn.Do(callReq).GetTyped(&tuples3) + if err != nil { + log.Fatalf("Failed to CallTyped: %s", err.Error()) + return + } + fmt.Println("Tuples (tuples3):", tuples3) + + // Output: + // Data [[777 orig [[lol 1] [wut 3]]]] + // Tuples (tuples1) [{777 orig [{lol 1} {wut 3}]}] + // Tuples (tuples2): [{{} 777 orig [{lol 1} {wut 3}]}] + // Tuples (tuples3): [[{{} 221 [{Moscow 34} {Minsk 23} {Kiev 31}]}]] + +} diff --git a/example_test.go b/example_test.go new file mode 100644 index 000000000..9eadf5971 --- /dev/null +++ b/example_test.go @@ -0,0 +1,1521 @@ +package tarantool_test + +import ( + "context" + "fmt" + "net" + "regexp" + "time" + + "github.com/tarantool/go-iproto" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +type Tuple struct { + // Instruct msgpack to pack this struct as array, so no custom packer + // is needed. + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Id uint + Msg string + Name string +} + +func exampleConnect(dialer tarantool.Dialer, opts tarantool.Opts) *tarantool.Connection { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, dialer, opts) + if err != nil { + panic("Connection is not established: " + err.Error()) + } + return conn +} + +func ExampleIntKey() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + const space = "test" + const index = "primary" + tuple := []interface{}{int(1111), "hello", "world"} + conn.Do(tarantool.NewReplaceRequest(space).Tuple(tuple)).Get() + + var t []Tuple + err := conn.Do(tarantool.NewSelectRequest(space). + Index(index). + Iterator(tarantool.IterEq). + Key(tarantool.IntKey{1111}), + ).GetTyped(&t) + fmt.Println("Error", err) + fmt.Println("Data", t) + // Output: + // Error + // Data [{{} 1111 hello world}] +} + +func ExampleUintKey() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + const space = "test" + const index = "primary" + tuple := []interface{}{uint(1111), "hello", "world"} + conn.Do(tarantool.NewReplaceRequest(space).Tuple(tuple)).Get() + + var t []Tuple + err := conn.Do(tarantool.NewSelectRequest(space). + Index(index). + Iterator(tarantool.IterEq). + Key(tarantool.UintKey{1111}), + ).GetTyped(&t) + fmt.Println("Error", err) + fmt.Println("Data", t) + // Output: + // Error + // Data [{{} 1111 hello world}] +} + +func ExampleStringKey() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + const space = "teststring" + const index = "primary" + tuple := []interface{}{"any", []byte{0x01, 0x02}} + conn.Do(tarantool.NewReplaceRequest(space).Tuple(tuple)).Get() + + t := []struct { + Key string + Value []byte + }{} + err := conn.Do(tarantool.NewSelectRequest(space). + Index(index). + Iterator(tarantool.IterEq). + Key(tarantool.StringKey{"any"}), + ).GetTyped(&t) + fmt.Println("Error", err) + fmt.Println("Data", t) + // Output: + // Error + // Data [{any [1 2]}] +} + +func ExampleIntIntKey() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + const space = "testintint" + const index = "primary" + tuple := []interface{}{1, 2, "foo"} + conn.Do(tarantool.NewReplaceRequest(space).Tuple(tuple)).Get() + + t := []struct { + Key1 int + Key2 int + Value string + }{} + err := conn.Do(tarantool.NewSelectRequest(space). + Index(index). + Iterator(tarantool.IterEq). + Key(tarantool.IntIntKey{1, 2}), + ).GetTyped(&t) + fmt.Println("Error", err) + fmt.Println("Data", t) + // Output: + // Error + // Data [{1 2 foo}] +} + +func ExamplePingRequest() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + // Ping a Tarantool instance to check connection. + data, err := conn.Do(tarantool.NewPingRequest()).Get() + fmt.Println("Ping Data", data) + fmt.Println("Ping Error", err) + // Output: + // Ping Data [] + // Ping Error +} + +// To pass contexts to request objects, use the Context() method. +// Pay attention that when using context with request objects, +// the timeout option for Connection will not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func ExamplePingRequest_Context() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + timeout := time.Nanosecond + + // This way you may set the a common timeout for requests with a context. + rootCtx, cancelRoot := context.WithTimeout(context.Background(), timeout) + defer cancelRoot() + + // This context will be canceled with the root after commonTimeout. + ctx, cancel := context.WithCancel(rootCtx) + defer cancel() + + req := tarantool.NewPingRequest().Context(ctx) + + // Ping a Tarantool instance to check connection. + data, err := conn.Do(req).Get() + fmt.Println("Ping Resp data", data) + fmt.Println("Ping Error", regexp.MustCompile("[0-9]+").ReplaceAllString(err.Error(), "N")) + // Output: + // Ping Resp data [] + // Ping Error context is done (request ID N): context deadline exceeded +} + +func ExampleSelectRequest() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + for i := 1111; i <= 1112; i++ { + conn.Do(tarantool.NewReplaceRequest(spaceNo). + Tuple([]interface{}{uint(i), "hello", "world"}), + ).Get() + } + + key := []interface{}{uint(1111)} + resp, err := conn.Do(tarantool.NewSelectRequest(617). + Limit(100). + Iterator(tarantool.IterEq). + Key(key), + ).GetResponse() + + if err != nil { + fmt.Printf("error in select is %v", err) + return + } + selResp, ok := resp.(*tarantool.SelectResponse) + if !ok { + fmt.Print("wrong response type") + return + } + + pos, err := selResp.Pos() + if err != nil { + fmt.Printf("error in Pos: %v", err) + return + } + fmt.Printf("pos for Select is %v\n", pos) + + data, err := resp.Decode() + if err != nil { + fmt.Printf("error while decoding: %v", err) + return + } + fmt.Printf("response is %#v\n", data) + + var res []Tuple + err = conn.Do(tarantool.NewSelectRequest("test"). + Index("primary"). + Limit(100). + Iterator(tarantool.IterEq). + Key(key), + ).GetTyped(&res) + if err != nil { + fmt.Printf("error in select is %v", err) + return + } + fmt.Printf("response is %v\n", res) + + // Output: + // pos for Select is [] + // response is []interface {}{[]interface {}{0x457, "hello", "world"}} + // response is [{{} 1111 hello world}] +} + +func ExampleSelectRequest_spaceAndIndexNames() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + req := tarantool.NewSelectRequest(spaceName) + req.Index(indexName) + data, err := conn.Do(req).Get() + + if err != nil { + fmt.Printf("Failed to execute the request: %s\n", err) + } else { + fmt.Println(data) + } +} + +func ExampleInsertRequest() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + // Insert a new tuple { 31, 1 }. + data, err := conn.Do(tarantool.NewInsertRequest(spaceNo). + Tuple([]interface{}{uint(31), "test", "one"}), + ).Get() + fmt.Println("Insert 31") + fmt.Println("Error", err) + fmt.Println("Data", data) + // Insert a new tuple { 32, 1 }. + data, err = conn.Do(tarantool.NewInsertRequest("test"). + Tuple(&Tuple{Id: 32, Msg: "test", Name: "one"}), + ).Get() + fmt.Println("Insert 32") + fmt.Println("Error", err) + fmt.Println("Data", data) + + // Delete tuple with primary key { 31 }. + conn.Do(tarantool.NewDeleteRequest("test"). + Index("primary"). + Key([]interface{}{uint(31)}), + ).Get() + // Delete tuple with primary key { 32 }. + conn.Do(tarantool.NewDeleteRequest("test"). + Index(indexNo). + Key([]interface{}{uint(31)}), + ).Get() + // Output: + // Insert 31 + // Error + // Data [[31 test one]] + // Insert 32 + // Error + // Data [[32 test one]] +} + +func ExampleInsertRequest_spaceAndIndexNames() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + req := tarantool.NewInsertRequest(spaceName) + data, err := conn.Do(req).Get() + + if err != nil { + fmt.Printf("Failed to execute the request: %s\n", err) + } else { + fmt.Println(data) + } +} + +func ExampleDeleteRequest() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + // Insert a new tuple { 35, 1 }. + conn.Do(tarantool.NewInsertRequest(spaceNo). + Tuple([]interface{}{uint(35), "test", "one"}), + ).Get() + // Insert a new tuple { 36, 1 }. + conn.Do(tarantool.NewInsertRequest("test"). + Tuple(&Tuple{Id: 36, Msg: "test", Name: "one"}), + ).Get() + + // Delete tuple with primary key { 35 }. + data, err := conn.Do(tarantool.NewDeleteRequest(spaceNo). + Index(indexNo). + Key([]interface{}{uint(35)}), + ).Get() + fmt.Println("Delete 35") + fmt.Println("Error", err) + fmt.Println("Data", data) + + // Delete tuple with primary key { 36 }. + data, err = conn.Do(tarantool.NewDeleteRequest("test"). + Index("primary"). + Key([]interface{}{uint(36)}), + ).Get() + fmt.Println("Delete 36") + fmt.Println("Error", err) + fmt.Println("Data", data) + // Output: + // Delete 35 + // Error + // Data [[35 test one]] + // Delete 36 + // Error + // Data [[36 test one]] +} + +func ExampleDeleteRequest_spaceAndIndexNames() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + req := tarantool.NewDeleteRequest(spaceName) + req.Index(indexName) + data, err := conn.Do(req).Get() + + if err != nil { + fmt.Printf("Failed to execute the request: %s\n", err) + } else { + fmt.Println(data) + } +} + +func ExampleReplaceRequest() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + // Insert a new tuple { 13, 1 }. + conn.Do(tarantool.NewInsertRequest(spaceNo). + Tuple([]interface{}{uint(13), "test", "one"}), + ).Get() + + // Replace a tuple with primary key 13. + // Note, Tuple is defined within tests, and has EncdodeMsgpack and + // DecodeMsgpack methods. + data, err := conn.Do(tarantool.NewReplaceRequest(spaceNo). + Tuple([]interface{}{uint(13), 1}), + ).Get() + fmt.Println("Replace 13") + fmt.Println("Error", err) + fmt.Println("Data", data) + data, err = conn.Do(tarantool.NewReplaceRequest("test"). + Tuple([]interface{}{uint(13), 1}), + ).Get() + fmt.Println("Replace 13") + fmt.Println("Error", err) + fmt.Println("Data", data) + data, err = conn.Do(tarantool.NewReplaceRequest("test"). + Tuple(&Tuple{Id: 13, Msg: "test", Name: "eleven"}), + ).Get() + fmt.Println("Replace 13") + fmt.Println("Error", err) + fmt.Println("Data", data) + data, err = conn.Do(tarantool.NewReplaceRequest("test"). + Tuple(&Tuple{Id: 13, Msg: "test", Name: "twelve"}), + ).Get() + fmt.Println("Replace 13") + fmt.Println("Error", err) + fmt.Println("Data", data) + // Output: + // Replace 13 + // Error + // Data [[13 1]] + // Replace 13 + // Error + // Data [[13 1]] + // Replace 13 + // Error + // Data [[13 test eleven]] + // Replace 13 + // Error + // Data [[13 test twelve]] +} + +func ExampleReplaceRequest_spaceAndIndexNames() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + req := tarantool.NewReplaceRequest(spaceName) + data, err := conn.Do(req).Get() + + if err != nil { + fmt.Printf("Failed to execute the request: %s\n", err) + } else { + fmt.Println(data) + } +} + +func ExampleUpdateRequest() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + for i := 1111; i <= 1112; i++ { + conn.Do(tarantool.NewReplaceRequest(spaceNo). + Tuple([]interface{}{uint(i), "text", 1, 1, 1, 1, 1}), + ).Get() + } + + req := tarantool.NewUpdateRequest(617). + Key(tarantool.IntKey{1111}). + Operations(tarantool.NewOperations(). + Add(2, 1). + Subtract(3, 1). + BitwiseAnd(4, 1). + BitwiseOr(5, 1). + BitwiseXor(6, 1). + Splice(1, 1, 2, "!!"). + Insert(7, "new"). + Assign(7, "updated")) + data, err := conn.Do(req).Get() + if err != nil { + fmt.Printf("error in do update request is %v", err) + return + } + fmt.Printf("response is %#v\n", data) + // Output: + // response is []interface {}{[]interface {}{0x457, "t!!t", 2, 0, 1, 1, 0, "updated"}} +} + +func ExampleUpdateRequest_spaceAndIndexNames() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + req := tarantool.NewUpdateRequest(spaceName) + req.Index(indexName) + data, err := conn.Do(req).Get() + + if err != nil { + fmt.Printf("Failed to execute the request: %s\n", err) + } else { + fmt.Println(data) + } +} + +func ExampleUpsertRequest() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + var req tarantool.Request + req = tarantool.NewUpsertRequest(617). + Tuple([]interface{}{uint(1113), "first", "first"}). + Operations(tarantool.NewOperations().Assign(1, "updated")) + data, err := conn.Do(req).Get() + if err != nil { + fmt.Printf("error in do select upsert is %v", err) + return + } + fmt.Printf("response is %#v\n", data) + + req = tarantool.NewUpsertRequest("test"). + Tuple([]interface{}{uint(1113), "second", "second"}). + Operations(tarantool.NewOperations().Assign(2, "updated")) + fut := conn.Do(req) + data, err = fut.Get() + if err != nil { + fmt.Printf("error in do async upsert request is %v", err) + return + } + fmt.Printf("response is %#v\n", data) + + req = tarantool.NewSelectRequest(617). + Limit(100). + Key(tarantool.IntKey{1113}) + data, err = conn.Do(req).Get() + if err != nil { + fmt.Printf("error in do select request is %v", err) + return + } + fmt.Printf("response is %#v\n", data) + // Output: + // response is []interface {}{} + // response is []interface {}{} + // response is []interface {}{[]interface {}{0x459, "first", "updated"}} +} + +func ExampleUpsertRequest_spaceAndIndexNames() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + req := tarantool.NewUpsertRequest(spaceName) + data, err := conn.Do(req).Get() + + if err != nil { + fmt.Printf("Failed to execute the request: %s\n", err) + } else { + fmt.Println(data) + } +} + +func ExampleCallRequest() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + // Call a function 'simple_concat' with arguments. + data, err := conn.Do(tarantool.NewCallRequest("simple_concat"). + Args([]interface{}{"1"}), + ).Get() + fmt.Println("Call simple_concat()") + fmt.Println("Error", err) + fmt.Println("Data", data) + // Output: + // Call simple_concat() + // Error + // Data [11] +} + +func ExampleEvalRequest() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + // Run raw Lua code. + data, err := conn.Do(tarantool.NewEvalRequest("return 1 + 2")).Get() + fmt.Println("Eval 'return 1 + 2'") + fmt.Println("Error", err) + fmt.Println("Data", data) + // Output: + // Eval 'return 1 + 2' + // Error + // Data [3] +} + +// To use SQL to query a tarantool instance, use ExecuteRequest. +// +// Pay attention that with different types of queries (DDL, DQL, DML etc.) +// some fields of the response structure (MetaData and InfoAutoincrementIds +// in SQLInfo) may be nil. +func ExampleExecuteRequest() { + // Tarantool supports SQL since version 2.0.0 + isLess, _ := test_helpers.IsTarantoolVersionLess(2, 0, 0) + if isLess { + return + } + + conn := exampleConnect(dialer, opts) + defer conn.Close() + + req := tarantool.NewExecuteRequest( + "CREATE TABLE SQL_TEST (id INTEGER PRIMARY KEY, name STRING)") + resp, err := conn.Do(req).GetResponse() + fmt.Println("Execute") + fmt.Println("Error", err) + + data, err := resp.Decode() + fmt.Println("Error", err) + fmt.Println("Data", data) + + exResp, ok := resp.(*tarantool.ExecuteResponse) + if !ok { + fmt.Printf("wrong response type") + return + } + + metaData, err := exResp.MetaData() + fmt.Println("MetaData", metaData) + fmt.Println("Error", err) + + sqlInfo, err := exResp.SQLInfo() + fmt.Println("SQL Info", sqlInfo) + fmt.Println("Error", err) + + // There are 4 options to pass named parameters to an SQL query: + // 1) The simple map; + sqlBind1 := map[string]interface{}{ + "id": 1, + "name": "test", + } + + // 2) Any type of structure; + sqlBind2 := struct { + Id int + Name string + }{1, "test"} + + // 3) It is possible to use []tarantool.KeyValueBind; + sqlBind3 := []interface{}{ + tarantool.KeyValueBind{Key: "id", Value: 1}, + tarantool.KeyValueBind{Key: "name", Value: "test"}, + } + + // 4) []interface{} slice with tarantool.KeyValueBind items inside; + sqlBind4 := []tarantool.KeyValueBind{ + {"id", 1}, + {"name", "test"}, + } + + // 1) + req = tarantool.NewExecuteRequest( + "CREATE TABLE SQL_TEST (id INTEGER PRIMARY KEY, name STRING)") + req = req.Args(sqlBind1) + resp, err = conn.Do(req).GetResponse() + fmt.Println("Execute") + fmt.Println("Error", err) + data, err = resp.Decode() + fmt.Println("Error", err) + fmt.Println("Data", data) + exResp, ok = resp.(*tarantool.ExecuteResponse) + if !ok { + fmt.Printf("wrong response type") + return + } + metaData, err = exResp.MetaData() + fmt.Println("MetaData", metaData) + fmt.Println("Error", err) + sqlInfo, err = exResp.SQLInfo() + fmt.Println("SQL Info", sqlInfo) + fmt.Println("Error", err) + + // 2) + req = req.Args(sqlBind2) + resp, err = conn.Do(req).GetResponse() + fmt.Println("Execute") + fmt.Println("Error", err) + data, err = resp.Decode() + fmt.Println("Error", err) + fmt.Println("Data", data) + exResp, ok = resp.(*tarantool.ExecuteResponse) + if !ok { + fmt.Printf("wrong response type") + return + } + metaData, err = exResp.MetaData() + fmt.Println("MetaData", metaData) + fmt.Println("Error", err) + sqlInfo, err = exResp.SQLInfo() + fmt.Println("SQL Info", sqlInfo) + fmt.Println("Error", err) + + // 3) + req = req.Args(sqlBind3) + resp, err = conn.Do(req).GetResponse() + fmt.Println("Execute") + fmt.Println("Error", err) + data, err = resp.Decode() + fmt.Println("Error", err) + fmt.Println("Data", data) + exResp, ok = resp.(*tarantool.ExecuteResponse) + if !ok { + fmt.Printf("wrong response type") + return + } + metaData, err = exResp.MetaData() + fmt.Println("MetaData", metaData) + fmt.Println("Error", err) + sqlInfo, err = exResp.SQLInfo() + fmt.Println("SQL Info", sqlInfo) + fmt.Println("Error", err) + + // 4) + req = req.Args(sqlBind4) + resp, err = conn.Do(req).GetResponse() + fmt.Println("Execute") + fmt.Println("Error", err) + data, err = resp.Decode() + fmt.Println("Error", err) + fmt.Println("Data", data) + exResp, ok = resp.(*tarantool.ExecuteResponse) + if !ok { + fmt.Printf("wrong response type") + return + } + metaData, err = exResp.MetaData() + fmt.Println("MetaData", metaData) + fmt.Println("Error", err) + sqlInfo, err = exResp.SQLInfo() + fmt.Println("SQL Info", sqlInfo) + fmt.Println("Error", err) + + // The way to pass positional arguments to an SQL query. + req = tarantool.NewExecuteRequest( + "SELECT id FROM SQL_TEST WHERE id=? AND name=?"). + Args([]interface{}{2, "test"}) + resp, err = conn.Do(req).GetResponse() + fmt.Println("Execute") + fmt.Println("Error", err) + data, err = resp.Decode() + fmt.Println("Error", err) + fmt.Println("Data", data) + exResp, ok = resp.(*tarantool.ExecuteResponse) + if !ok { + fmt.Printf("wrong response type") + return + } + metaData, err = exResp.MetaData() + fmt.Println("MetaData", metaData) + fmt.Println("Error", err) + sqlInfo, err = exResp.SQLInfo() + fmt.Println("SQL Info", sqlInfo) + fmt.Println("Error", err) + + // The way to pass SQL expression with using custom packing/unpacking for + // a type. + var res []Tuple + req = tarantool.NewExecuteRequest( + "SELECT id, name, name FROM SQL_TEST WHERE id=?"). + Args([]interface{}{2}) + err = conn.Do(req).GetTyped(&res) + fmt.Println("ExecuteTyped") + fmt.Println("Error", err) + fmt.Println("Data", res) + + // For using different types of parameters (positioned/named), collect all + // items in []interface{}. + // All "named" items must be passed with tarantool.KeyValueBind{}. + req = tarantool.NewExecuteRequest( + "SELECT id FROM SQL_TEST WHERE id=? AND name=?"). + Args([]interface{}{tarantool.KeyValueBind{"id", 1}, "test"}) + resp, err = conn.Do(req).GetResponse() + fmt.Println("Execute") + fmt.Println("Error", err) + data, err = resp.Decode() + fmt.Println("Error", err) + fmt.Println("Data", data) + exResp, ok = resp.(*tarantool.ExecuteResponse) + if !ok { + fmt.Printf("wrong response type") + return + } + metaData, err = exResp.MetaData() + fmt.Println("MetaData", metaData) + fmt.Println("Error", err) + sqlInfo, err = exResp.SQLInfo() + fmt.Println("SQL Info", sqlInfo) + fmt.Println("Error", err) +} + +func getTestTxnDialer() tarantool.Dialer { + txnDialer := dialer + + // Assert that server supports expected protocol features. + txnDialer.RequiredProtocolInfo = tarantool.ProtocolInfo{ + Version: tarantool.ProtocolVersion(1), + Features: []iproto.Feature{ + iproto.IPROTO_FEATURE_STREAMS, + iproto.IPROTO_FEATURE_TRANSACTIONS, + }, + } + + return txnDialer +} + +func ExampleCommitRequest() { + var req tarantool.Request + var err error + + // Tarantool supports streams and interactive transactions since version 2.10.0 + isLess, _ := test_helpers.IsTarantoolVersionLess(2, 10, 0) + if err != nil || isLess { + return + } + + txnDialer := getTestTxnDialer() + conn := exampleConnect(txnDialer, opts) + defer conn.Close() + + stream, _ := conn.NewStream() + + // Begin transaction + req = tarantool.NewBeginRequest() + data, err := stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Begin: %s", err.Error()) + return + } + fmt.Printf("Begin transaction: response is %#v\n", data) + + // Insert in stream + req = tarantool.NewInsertRequest(spaceName). + Tuple([]interface{}{uint(1001), "commit_hello", "commit_world"}) + data, err = stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Insert: %s", err.Error()) + return + } + fmt.Printf("Insert in stream: response is %#v\n", data) + + // Select not related to the transaction + // while transaction is not committed + // result of select is empty + selectReq := tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(tarantool.IterEq). + Key([]interface{}{uint(1001)}) + data, err = conn.Do(selectReq).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select out of stream before commit: response is %#v\n", data) + + // Select in stream + data, err = stream.Do(selectReq).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select in stream: response is %#v\n", data) + + // Commit transaction + req = tarantool.NewCommitRequest() + data, err = stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Commit: %s", err.Error()) + return + } + fmt.Printf("Commit transaction: response is %#v\n", data) + + // Select outside of transaction + data, err = conn.Do(selectReq).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select after commit: response is %#v\n", data) +} + +func ExampleRollbackRequest() { + var req tarantool.Request + var err error + + // Tarantool supports streams and interactive transactions since version 2.10.0 + isLess, _ := test_helpers.IsTarantoolVersionLess(2, 10, 0) + if err != nil || isLess { + return + } + + txnDialer := getTestTxnDialer() + conn := exampleConnect(txnDialer, opts) + defer conn.Close() + + stream, _ := conn.NewStream() + + // Begin transaction + req = tarantool.NewBeginRequest() + data, err := stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Begin: %s", err.Error()) + return + } + fmt.Printf("Begin transaction: response is %#v\n", data) + + // Insert in stream + req = tarantool.NewInsertRequest(spaceName). + Tuple([]interface{}{uint(2001), "rollback_hello", "rollback_world"}) + data, err = stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Insert: %s", err.Error()) + return + } + fmt.Printf("Insert in stream: response is %#v\n", data) + + // Select not related to the transaction + // while transaction is not committed + // result of select is empty + selectReq := tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(tarantool.IterEq). + Key([]interface{}{uint(2001)}) + data, err = conn.Do(selectReq).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select out of stream: response is %#v\n", data) + + // Select in stream + data, err = stream.Do(selectReq).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select in stream: response is %#v\n", data) + + // Rollback transaction + req = tarantool.NewRollbackRequest() + data, err = stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Rollback: %s", err.Error()) + return + } + fmt.Printf("Rollback transaction: response is %#v\n", data) + + // Select outside of transaction + data, err = conn.Do(selectReq).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select after Rollback: response is %#v\n", data) +} + +func ExampleBeginRequest_TxnIsolation() { + var req tarantool.Request + var err error + + // Tarantool supports streams and interactive transactions since version 2.10.0 + isLess, _ := test_helpers.IsTarantoolVersionLess(2, 10, 0) + if err != nil || isLess { + return + } + + txnDialer := getTestTxnDialer() + conn := exampleConnect(txnDialer, opts) + defer conn.Close() + + stream, _ := conn.NewStream() + + // Begin transaction + req = tarantool.NewBeginRequest(). + TxnIsolation(tarantool.ReadConfirmedLevel). + Timeout(500 * time.Millisecond) + data, err := stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Begin: %s", err.Error()) + return + } + fmt.Printf("Begin transaction: response is %#v\n", data) + + // Insert in stream + req = tarantool.NewInsertRequest(spaceName). + Tuple([]interface{}{uint(2001), "rollback_hello", "rollback_world"}) + data, err = stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Insert: %s", err.Error()) + return + } + fmt.Printf("Insert in stream: response is %#v\n", data) + + // Select not related to the transaction + // while transaction is not committed + // result of select is empty + selectReq := tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(tarantool.IterEq). + Key([]interface{}{uint(2001)}) + data, err = conn.Do(selectReq).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select out of stream: response is %#v\n", data) + + // Select in stream + data, err = stream.Do(selectReq).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select in stream: response is %#v\n", data) + + // Rollback transaction + req = tarantool.NewRollbackRequest() + data, err = stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Rollback: %s", err.Error()) + return + } + fmt.Printf("Rollback transaction: response is %#v\n", data) + + // Select outside of transaction + data, err = conn.Do(selectReq).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select after Rollback: response is %#v\n", data) +} + +func ExampleBeginRequest_IsSync() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + // Tarantool supports IS_SYNC flag for BeginRequest since version 3.1.0. + isLess, err := test_helpers.IsTarantoolVersionLess(3, 1, 0) + if err != nil || isLess { + return + } + + stream, err := conn.NewStream() + if err != nil { + fmt.Printf("error getting the stream: %s\n", err) + return + } + + // Begin transaction with synchronous mode. + req := tarantool.NewBeginRequest().IsSync(true) + resp, err := stream.Do(req).GetResponse() + switch { + case err != nil: + fmt.Printf("error getting the response: %s\n", err) + case resp.Header().Error != tarantool.ErrorNo: + fmt.Printf("response error code: %s\n", resp.Header().Error) + default: + fmt.Println("Success.") + } +} + +func ExampleCommitRequest_IsSync() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + // Tarantool supports IS_SYNC flag for CommitRequest since version 3.1.0. + isLess, err := test_helpers.IsTarantoolVersionLess(3, 1, 0) + if err != nil || isLess { + return + } + + var req tarantool.Request + + stream, err := conn.NewStream() + if err != nil { + fmt.Printf("error getting the stream: %s\n", err) + return + } + + // Begin transaction. + req = tarantool.NewBeginRequest() + resp, err := stream.Do(req).GetResponse() + switch { + case err != nil: + fmt.Printf("error getting the response: %s\n", err) + return + case resp.Header().Error != tarantool.ErrorNo: + fmt.Printf("response error code: %s\n", resp.Header().Error) + return + } + + // Insert in stream. + req = tarantool.NewReplaceRequest("test").Tuple([]interface{}{1, "test"}) + resp, err = stream.Do(req).GetResponse() + switch { + case err != nil: + fmt.Printf("error getting the response: %s\n", err) + return + case resp.Header().Error != tarantool.ErrorNo: + fmt.Printf("response error code: %s\n", resp.Header().Error) + return + } + + // Commit transaction in sync mode. + req = tarantool.NewCommitRequest().IsSync(true) + resp, err = stream.Do(req).GetResponse() + switch { + case err != nil: + fmt.Printf("error getting the response: %s\n", err) + case resp.Header().Error != tarantool.ErrorNo: + fmt.Printf("response error code: %s\n", resp.Header().Error) + default: + fmt.Println("Success.") + } +} + +func ExampleErrorNo() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + req := tarantool.NewPingRequest() + resp, err := conn.Do(req).GetResponse() + if err != nil { + fmt.Printf("error getting the response: %s\n", err) + return + } + + if resp.Header().Error != tarantool.ErrorNo { + fmt.Printf("response error code: %s\n", resp.Header().Error) + } else { + fmt.Println("Success.") + } + // Output: + // Success. +} + +func ExampleConnect() { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + dialer := tarantool.NetDialer{ + Address: server, + User: "test", + Password: "test", + } + + conn, err := tarantool.Connect(ctx, dialer, tarantool.Opts{ + Timeout: 5 * time.Second, + Concurrency: 32, + }) + if err != nil { + fmt.Println("No connection available") + return + } + defer conn.Close() + if conn != nil { + fmt.Println("Connection is ready") + } + // Output: + // Connection is ready +} + +func ExampleConnect_reconnects() { + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3013", + User: "test", + Password: "test", + } + + opts := tarantool.Opts{ + Timeout: 5 * time.Second, + Concurrency: 32, + Reconnect: time.Second, + MaxReconnects: 10, + } + + var conn *tarantool.Connection + var err error + + for i := uint(0); i < opts.MaxReconnects; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + conn, err = tarantool.Connect(ctx, dialer, opts) + cancel() + if err == nil { + break + } + time.Sleep(opts.Reconnect) + } + if err != nil { + fmt.Println("No connection available") + return + } + defer conn.Close() + if conn != nil { + fmt.Println("Connection is ready") + } + // Output: + // Connection is ready +} + +// Example demonstrates how to retrieve information with space schema. +func ExampleSchema() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + schema, err := tarantool.GetSchema(conn) + if err != nil { + fmt.Printf("unexpected error: %s\n", err.Error()) + } + if schema.SpacesById == nil { + fmt.Println("schema.SpacesById is nil") + } + if schema.Spaces == nil { + fmt.Println("schema.Spaces is nil") + } + + space1 := schema.Spaces["test"] + space2 := schema.SpacesById[616] + fmt.Printf("Space 1 ID %d %s\n", space1.Id, space1.Name) + fmt.Printf("Space 2 ID %d %s\n", space2.Id, space2.Name) + // Output: + // Space 1 ID 617 test + // Space 2 ID 616 schematest +} + +// Example demonstrates how to update the connection schema. +func ExampleConnection_SetSchema() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + // Get the actual schema. + schema, err := tarantool.GetSchema(conn) + if err != nil { + fmt.Printf("unexpected error: %s\n", err.Error()) + } + // Update the current schema to match the actual one. + conn.SetSchema(schema) +} + +// Example demonstrates how to retrieve information with space schema. +func ExampleSpace() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + // Save Schema to a local variable to avoid races + schema, err := tarantool.GetSchema(conn) + if err != nil { + fmt.Printf("unexpected error: %s\n", err.Error()) + } + if schema.SpacesById == nil { + fmt.Println("schema.SpacesById is nil") + } + if schema.Spaces == nil { + fmt.Println("schema.Spaces is nil") + } + + // Access Space objects by name or ID. + space1 := schema.Spaces["test"] + space2 := schema.SpacesById[616] // It's a map. + fmt.Printf("Space 1 ID %d %s %s\n", space1.Id, space1.Name, space1.Engine) + fmt.Printf("Space 1 ID %d %t\n", space1.FieldsCount, space1.Temporary) + + // Access index information by name or ID. + index1 := space1.Indexes["primary"] + index2 := space2.IndexesById[3] // It's a map. + fmt.Printf("Index %d %s\n", index1.Id, index1.Name) + + // Access index fields information by index. + indexField1 := index1.Fields[0] // It's a slice. + indexField2 := index2.Fields[1] // It's a slice. + fmt.Println(indexField1, indexField2) + + // Access space fields information by name or id (index). + spaceField1 := space2.Fields["name0"] + spaceField2 := space2.FieldsById[3] + fmt.Printf("SpaceField 1 %s %s\n", spaceField1.Name, spaceField1.Type) + fmt.Printf("SpaceField 2 %s %s\n", spaceField2.Name, spaceField2.Type) + + // Output: + // Space 1 ID 617 test memtx + // Space 1 ID 0 false + // Index 0 primary + // {0 unsigned} {2 string} + // SpaceField 1 name0 unsigned + // SpaceField 2 name3 unsigned +} + +// ExampleConnection_Do demonstrates how to send a request and process +// a response. +func ExampleConnection_Do() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + // It could be any request. + req := tarantool.NewReplaceRequest("test"). + Tuple([]interface{}{int(1111), "foo", "bar"}) + + // We got a future, the request actually not performed yet. + future := conn.Do(req) + + // When the future receives the response, the result of the Future is set + // and becomes available. We could wait for that moment with Future.Get() + // or Future.GetTyped() methods. + data, err := future.Get() + if err != nil { + fmt.Printf("Failed to execute the request: %s\n", err) + } else { + fmt.Println(data) + } + + // Output: + // [[1111 foo bar]] +} + +// ExampleConnection_Do_failure demonstrates how to send a request and process +// failure. +func ExampleConnection_Do_failure() { + conn := exampleConnect(dialer, opts) + defer conn.Close() + + // It could be any request. + req := tarantool.NewCallRequest("not_exist") + + // We got a future, the request actually not performed yet. + future := conn.Do(req) + + // When the future receives the response, the result of the Future is set + // and becomes available. We could wait for that moment with Future.Get(), + // Future.GetResponse() or Future.GetTyped() methods. + resp, err := future.GetResponse() + if err != nil { + fmt.Printf("Error in the future: %s\n", err) + } + // Optional step: check a response error. + // It allows checking that response has or hasn't an error without decoding. + if resp.Header().Error != tarantool.ErrorNo { + fmt.Printf("Response error: %s\n", resp.Header().Error) + } + + data, err := future.Get() + if err != nil { + fmt.Printf("Data: %v\n", data) + } + + if err != nil { + // We don't print the error here to keep the example reproducible. + // fmt.Printf("Failed to execute the request: %s\n", err) + if resp == nil { + // Something happens in a client process (timeout, IO error etc). + fmt.Printf("Resp == nil, ClientErr = %s\n", err.(tarantool.ClientError)) + } else { + // Response exist. So it could be a Tarantool error or a decode + // error. We need to check the error code. + fmt.Printf("Error code from the response: %d\n", resp.Header().Error) + if resp.Header().Error == tarantool.ErrorNo { + fmt.Printf("Decode error: %s\n", err) + } else { + code := err.(tarantool.Error).Code + fmt.Printf("Error code from the error: %d\n", code) + fmt.Printf("Error short from the error: %s\n", code) + } + } + } + + // Output: + // Response error: ER_NO_SUCH_PROC + // Data: [] + // Error code from the response: 33 + // Error code from the error: 33 + // Error short from the error: ER_NO_SUCH_PROC +} + +// To use prepared statements to query a tarantool instance, call NewPrepared. +func ExampleConnection_NewPrepared() { + // Tarantool supports SQL since version 2.0.0 + isLess, err := test_helpers.IsTarantoolVersionLess(2, 0, 0) + if err != nil || isLess { + return + } + + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3013", + User: "test", + Password: "test", + } + opts := tarantool.Opts{ + Timeout: 5 * time.Second, + } + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, dialer, opts) + if err != nil { + fmt.Printf("Failed to connect: %s", err.Error()) + } + + stmt, err := conn.NewPrepared("SELECT 1") + if err != nil { + fmt.Printf("Failed to connect: %s", err.Error()) + } + + executeReq := tarantool.NewExecutePreparedRequest(stmt) + unprepareReq := tarantool.NewUnprepareRequest(stmt) + + _, err = conn.Do(executeReq).Get() + if err != nil { + fmt.Printf("Failed to execute prepared stmt") + } + + _, err = conn.Do(unprepareReq).Get() + if err != nil { + fmt.Printf("Failed to prepare") + } +} + +func ExampleConnection_NewWatcher() { + const key = "foo" + const value = "bar" + + // Tarantool watchers since version 2.10 + isLess, err := test_helpers.IsTarantoolVersionLess(2, 10, 0) + if err != nil || isLess { + return + } + + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3013", + User: "test", + Password: "test", + // You can require the feature explicitly. + RequiredProtocolInfo: tarantool.ProtocolInfo{ + Features: []iproto.Feature{iproto.IPROTO_FEATURE_WATCHERS}, + }, + } + + opts := tarantool.Opts{ + Timeout: 5 * time.Second, + Reconnect: 5 * time.Second, + MaxReconnects: 3, + } + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, dialer, opts) + if err != nil { + fmt.Printf("Failed to connect: %s\n", err) + return + } + defer conn.Close() + + callback := func(event tarantool.WatchEvent) { + fmt.Printf("event connection: %s\n", event.Conn.Addr()) + fmt.Printf("event key: %s\n", event.Key) + fmt.Printf("event value: %v\n", event.Value) + } + watcher, err := conn.NewWatcher(key, callback) + if err != nil { + fmt.Printf("Failed to connect: %s\n", err) + return + } + defer watcher.Unregister() + + conn.Do(tarantool.NewBroadcastRequest(key).Value(value)).Get() + time.Sleep(time.Second) +} + +// ExampleConnection_CloseGraceful_force demonstrates how to force close +// a connection with graceful close in progress after a while. +func ExampleConnection_CloseGraceful_force() { + conn := exampleConnect(dialer, opts) + + eval := `local fiber = require('fiber') + local time = ... + fiber.sleep(time) +` + req := tarantool.NewEvalRequest(eval).Args([]interface{}{10}) + fut := conn.Do(req) + + done := make(chan struct{}) + go func() { + conn.CloseGraceful() + fmt.Println("Connection.CloseGraceful() done!") + close(done) + }() + + select { + case <-done: + case <-time.After(time.Second): + fmt.Println("Force Connection.Close()!") + conn.Close() + } + <-done + + fmt.Println("Result:") + fmt.Println(fut.Get()) + // Output: + // Force Connection.Close()! + // Connection.CloseGraceful() done! + // Result: + // [] connection closed by client (0x4001) +} + +func ExampleWatchOnceRequest() { + const key = "foo" + const value = "bar" + + // WatchOnce request present in Tarantool since version 3.0 + isLess, err := test_helpers.IsTarantoolVersionLess(3, 0, 0) + if err != nil || isLess { + return + } + + conn := exampleConnect(dialer, opts) + defer conn.Close() + + conn.Do(tarantool.NewBroadcastRequest(key).Value(value)).Get() + + data, err := conn.Do(tarantool.NewWatchOnceRequest(key)).Get() + if err != nil { + fmt.Printf("Failed to execute the request: %s\n", err) + } else { + fmt.Println(data) + } +} + +// This example demonstrates how to use an existing socket file descriptor +// to establish a connection with Tarantool. This can be useful if the socket fd +// was inherited from the Tarantool process itself. +// For details, please see TestFdDialer in tarantool_test.go. +func ExampleFdDialer() { + addr := dialer.Address + c, err := net.Dial("tcp", addr) + if err != nil { + fmt.Printf("can't establish connection: %v\n", err) + return + } + f, err := c.(*net.TCPConn).File() + if err != nil { + fmt.Printf("unexpected error: %v\n", err) + return + } + dialer := tarantool.FdDialer{ + Fd: f.Fd(), + } + // Use an existing socket fd to create connection with Tarantool. + conn, err := tarantool.Connect(context.Background(), dialer, opts) + if err != nil { + fmt.Printf("connect error: %v\n", err) + return + } + _, err = conn.Do(tarantool.NewPingRequest()).Get() + fmt.Println(err) + // Output: + // +} diff --git a/future.go b/future.go new file mode 100644 index 000000000..ed3d89cdc --- /dev/null +++ b/future.go @@ -0,0 +1,133 @@ +package tarantool + +import ( + "io" + "sync" + "time" +) + +// Future is a handle for asynchronous request. +type Future struct { + requestId uint32 + req Request + next *Future + timeout time.Duration + mutex sync.Mutex + resp Response + err error + ready chan struct{} + done chan struct{} +} + +func (fut *Future) wait() { + if fut.done == nil { + return + } + <-fut.done +} + +func (fut *Future) isDone() bool { + if fut.done == nil { + return true + } + select { + case <-fut.done: + return true + default: + return false + } +} + +// NewFuture creates a new empty Future for a given Request. +func NewFuture(req Request) (fut *Future) { + fut = &Future{} + fut.ready = make(chan struct{}, 1000000000) + fut.done = make(chan struct{}) + fut.req = req + return fut +} + +// SetResponse sets a response for the future and finishes the future. +func (fut *Future) SetResponse(header Header, body io.Reader) error { + fut.mutex.Lock() + defer fut.mutex.Unlock() + + if fut.isDone() { + return nil + } + + resp, err := fut.req.Response(header, body) + if err != nil { + return err + } + fut.resp = resp + + close(fut.ready) + close(fut.done) + return nil +} + +// SetError sets an error for the future and finishes the future. +func (fut *Future) SetError(err error) { + fut.mutex.Lock() + defer fut.mutex.Unlock() + + if fut.isDone() { + return + } + fut.err = err + + close(fut.ready) + close(fut.done) +} + +// GetResponse waits for Future to be filled and returns Response and error. +// +// Note: Response could be equal to nil if ClientError is returned in error. +// +// "error" could be Error, if it is error returned by Tarantool, +// or ClientError, if something bad happens in a client process. +func (fut *Future) GetResponse() (Response, error) { + fut.wait() + return fut.resp, fut.err +} + +// Get waits for Future to be filled and returns the data of the Response and error. +// +// The data will be []interface{}, so if you want more performance, use GetTyped method. +// +// "error" could be Error, if it is error returned by Tarantool, +// or ClientError, if something bad happens in a client process. +func (fut *Future) Get() ([]interface{}, error) { + fut.wait() + if fut.err != nil { + return nil, fut.err + } + return fut.resp.Decode() +} + +// GetTyped waits for Future and calls msgpack.Decoder.Decode(result) if no error happens. +// It is could be much faster than Get() function. +// +// Note: Tarantool usually returns array of tuples (except for Eval and Call17 actions). +func (fut *Future) GetTyped(result interface{}) error { + fut.wait() + if fut.err != nil { + return fut.err + } + return fut.resp.DecodeTyped(result) +} + +var closedChan = make(chan struct{}) + +func init() { + close(closedChan) +} + +// WaitChan returns channel which becomes closed when response arrived or error occurred. +func (fut *Future) WaitChan() <-chan struct{} { + if fut.done == nil { + return closedChan + } + return fut.done +} diff --git a/future_test.go b/future_test.go new file mode 100644 index 000000000..47f4e3c20 --- /dev/null +++ b/future_test.go @@ -0,0 +1,129 @@ +package tarantool_test + +import ( + "bytes" + "context" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tarantool/go-iproto" + . "github.com/tarantool/go-tarantool/v3" + "github.com/vmihailenco/msgpack/v5" +) + +type futureMockRequest struct { +} + +func (req *futureMockRequest) Type() iproto.Type { + return iproto.Type(0) +} + +func (req *futureMockRequest) Async() bool { + return false +} + +func (req *futureMockRequest) Body(_ SchemaResolver, _ *msgpack.Encoder) error { + return nil +} + +func (req *futureMockRequest) Conn() *Connection { + return &Connection{} +} + +func (req *futureMockRequest) Ctx() context.Context { + return nil +} + +func (req *futureMockRequest) Response(header Header, + body io.Reader) (Response, error) { + resp, err := createFutureMockResponse(header, body) + return resp, err +} + +type futureMockResponse struct { + header Header + data []byte + + decodeCnt int + decodeTypedCnt int +} + +func (resp *futureMockResponse) Header() Header { + return resp.header +} + +func (resp *futureMockResponse) Decode() ([]interface{}, error) { + resp.decodeCnt++ + + dataInt := make([]interface{}, len(resp.data)) + for i := range resp.data { + dataInt[i] = resp.data[i] + } + return dataInt, nil +} + +func (resp *futureMockResponse) DecodeTyped(res interface{}) error { + resp.decodeTypedCnt++ + return nil +} + +func createFutureMockResponse(header Header, body io.Reader) (Response, error) { + data, err := io.ReadAll(body) + if err != nil { + return nil, err + } + return &futureMockResponse{header: header, data: data}, nil +} + +func TestFuture_Get(t *testing.T) { + fut := NewFuture(&futureMockRequest{}) + fut.SetResponse(Header{}, bytes.NewReader([]byte{'v', '2'})) + + resp, err := fut.GetResponse() + assert.NoError(t, err) + mockResp, ok := resp.(*futureMockResponse) + assert.True(t, ok) + + data, err := fut.Get() + assert.NoError(t, err) + assert.Equal(t, []interface{}{uint8('v'), uint8('2')}, data) + assert.Equal(t, 1, mockResp.decodeCnt) + assert.Equal(t, 0, mockResp.decodeTypedCnt) +} + +func TestFuture_GetTyped(t *testing.T) { + fut := NewFuture(&futureMockRequest{}) + fut.SetResponse(Header{}, bytes.NewReader([]byte{'v', '2'})) + + resp, err := fut.GetResponse() + assert.NoError(t, err) + mockResp, ok := resp.(*futureMockResponse) + assert.True(t, ok) + + var data []byte + + err = fut.GetTyped(&data) + assert.NoError(t, err) + assert.Equal(t, 0, mockResp.decodeCnt) + assert.Equal(t, 1, mockResp.decodeTypedCnt) +} + +func TestFuture_GetResponse(t *testing.T) { + mockResp, err := createFutureMockResponse(Header{}, + bytes.NewReader([]byte{'v', '2'})) + assert.NoError(t, err) + + fut := NewFuture(&futureMockRequest{}) + fut.SetResponse(Header{}, bytes.NewReader([]byte{'v', '2'})) + + resp, err := fut.GetResponse() + assert.NoError(t, err) + respConv, ok := resp.(*futureMockResponse) + assert.True(t, ok) + assert.Equal(t, mockResp, respConv) + + data, err := resp.Decode() + assert.NoError(t, err) + assert.Equal(t, []interface{}{uint8('v'), uint8('2')}, data) +} diff --git a/go.mod b/go.mod new file mode 100644 index 000000000..7582412da --- /dev/null +++ b/go.mod @@ -0,0 +1,29 @@ +module github.com/tarantool/go-tarantool/v3 + +go 1.24 + +require ( + github.com/google/uuid v1.6.0 + github.com/shopspring/decimal v1.3.1 + github.com/stretchr/testify v1.11.1 + github.com/tarantool/go-iproto v1.1.0 + github.com/tarantool/go-option v1.0.0 + github.com/vmihailenco/msgpack/v5 v5.4.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect + golang.org/x/mod v0.27.0 // indirect + golang.org/x/sync v0.16.0 // indirect + golang.org/x/tools v0.36.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +tool ( + github.com/tarantool/go-option/cmd/gentypes + golang.org/x/tools/cmd/stringer +) diff --git a/go.sum b/go.sum new file mode 100644 index 000000000..91e1c4f25 --- /dev/null +++ b/go.sum @@ -0,0 +1,38 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= +github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tarantool/go-iproto v1.1.0 h1:HULVOIHsiehI+FnHfM7wMDntuzUddO09DKqu2WnFQ5A= +github.com/tarantool/go-iproto v1.1.0/go.mod h1:LNCtdyZxojUed8SbOiYHoc3v9NvaZTB7p96hUySMlIo= +github.com/tarantool/go-option v1.0.0 h1:+Etw0i3TjsXvADTo5rfZNCfsXe3BfHOs+iVfIrl0Nlo= +github.com/tarantool/go-option v1.0.0/go.mod h1:lXzzeZtL+rPUtLOCDP6ny3FemFBjruG9aHKzNN2bS08= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= +golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= +golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/golden_test.go b/golden_test.go new file mode 100644 index 000000000..c2ee52f89 --- /dev/null +++ b/golden_test.go @@ -0,0 +1,536 @@ +package tarantool_test + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "os" + "path" + "reflect" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// golden_test.go contains tests that will check that the msgpack +// encoding of the requests body matches the golden files. +// +// Algorithm to add new test: +// 1. Create a new test record in TestGolden (name + Request). +// 2. Run the test with the flag -generate-golden to generate the golden file. +// (for example: `go test . -run=TestGolden -v -generate-golden`) +// 3. Verify that JSON representation of the msgpack is the same as expected. +// 4. Commit the test record. +// +// Example of JSON representation of the msgpack +// ``` +// golden_test.go:80: writing golden file testdata/requests/select-with-optionals.msgpack +// golden_test.go:38: { +// golden_test.go:38: "IPROTO_FETCH_POSITION[31]": true, +// golden_test.go:38: "IPROTO_INDEX_ID[17]": 0, +// golden_test.go:38: "IPROTO_ITERATOR[20]": 5, +// golden_test.go:38: "IPROTO_KEY[32]": [], +// golden_test.go:38: "IPROTO_LIMIT[18]": 123, +// golden_test.go:38: "IPROTO_OFFSET[19]": 123, +// golden_test.go:38: "IPROTO_SPACE_NAME[94]": "table_name" +// golden_test.go:38: } +// ``` +// +// +// In case of any changes in the msgpack encoding/tests, the test will fail with next error: +// +// ``` +// === RUN TestGolden/testdata/requests/select-with-after.msgpack +// golden_test.go:109: +// Error Trace: ../go-tarantool/golden_test.go:73 +// ../go-tarantool/golden_test.go:109 +// Error: Not equal: +// expected: []byte{0x87, 0x14, 0x5, 0x13, 0x0, ..., 0x33, 0x33, 0x33} +// actual : []byte{0x87, 0x14, 0x5, 0x13, 0x0, ..., 0x33, 0x33, 0x33} +// +// Diff: +// --- Expected +// +++ Actual +// @@ -1,3 +1,3 @@ +// ([]uint8) (len=105) { +// - 00000000 87 14 05 13 ... 6e |......{^.table_n| +// + 00000000 87 14 05 13 ... 6e |......|^.table_n| +// 00000010 61 6d 65 11 ... ff |ame.. ...args...| +// Test: TestGolden/testdata/requests/select-with-after.msgpack +// Messages: golden file content is not equal to actual +// golden_test.go:109: expected: +// golden_test.go:63: { +// golden_test.go:63: "IPROTO_AFTER_TUPLE[47]": [ +// golden_test.go:63: 1, +// golden_test.go:63: "args", +// golden_test.go:63: 3, +// golden_test.go:63: "2024-01-01T03:00:00+03:00", +// golden_test.go:63: "gZMIqvDBS3SYYcSrWiZjCA==", +// golden_test.go:63: 1.2 +// golden_test.go:63: ], +// golden_test.go:63: "IPROTO_INDEX_ID[17]": 0, +// golden_test.go:63: "IPROTO_ITERATOR[20]": 5, +// golden_test.go:63: "IPROTO_KEY[32]": [ +// golden_test.go:63: 1, +// golden_test.go:63: "args", +// golden_test.go:63: 3, +// golden_test.go:63: "2024-01-01T03:00:00+03:00", +// golden_test.go:63: "gZMIqvDBS3SYYcSrWiZjCA==", +// golden_test.go:63: 1.2 +// golden_test.go:63: ], +// golden_test.go:63: "IPROTO_LIMIT[18]": 123, +// golden_test.go:63: "IPROTO_OFFSET[19]": 0, +// golden_test.go:63: "IPROTO_SPACE_NAME[94]": "table_name" +// golden_test.go:63: } +// golden_test.go:109: actual: +// golden_test.go:63: { +// golden_test.go:63: "IPROTO_AFTER_TUPLE[47]": [ +// golden_test.go:63: 1, +// golden_test.go:63: "args", +// golden_test.go:63: 3, +// golden_test.go:63: "2024-01-01T03:00:00+03:00", +// golden_test.go:63: "gZMIqvDBS3SYYcSrWiZjCA==", +// golden_test.go:63: 1.2 +// golden_test.go:63: ], +// golden_test.go:63: "IPROTO_INDEX_ID[17]": 0, +// golden_test.go:63: "IPROTO_ITERATOR[20]": 5, +// golden_test.go:63: "IPROTO_KEY[32]": [ +// golden_test.go:63: 1, +// golden_test.go:63: "args", +// golden_test.go:63: 3, +// golden_test.go:63: "2024-01-01T03:00:00+03:00", +// golden_test.go:63: "gZMIqvDBS3SYYcSrWiZjCA==", +// golden_test.go:63: 1.2 +// golden_test.go:63: ], +// golden_test.go:63: "IPROTO_LIMIT[18]": 124, +// golden_test.go:63: "IPROTO_OFFSET[19]": 0, +// golden_test.go:63: "IPROTO_SPACE_NAME[94]": "table_name" +// golden_test.go:63: } +// --- FAIL: TestGolden/testdata/requests/select-with-after.msgpack (0.00s) +// ``` +// Use it to debug the test. +// +// If you want to update the golden file, run delete old file and rerun the test. + +func logMsgpackAsJsonConvert(t *testing.T, data []byte) { + t.Helper() + + var decodedMsgpack map[int]interface{} + + decoder := msgpack.NewDecoder(bytes.NewReader(data)) + require.NoError(t, decoder.Decode(&decodedMsgpack)) + + decodedConvertedMsgpack := map[string]interface{}{} + for k, v := range decodedMsgpack { + decodedConvertedMsgpack[fmt.Sprintf("%s[%d]", iproto.Key(k).String(), k)] = v + } + + encodedJson, err := json.MarshalIndent(decodedConvertedMsgpack, "", " ") + require.NoError(t, err, "failed to convert msgpack to json") + + for _, line := range bytes.Split(encodedJson, []byte("\n")) { + t.Log(string(line)) + } +} + +func compareGoldenMsgpackAndPrintDiff(t *testing.T, name string, data []byte) { + t.Helper() + + testContent, err := os.ReadFile(name) + require.NoError(t, err, "failed to read golden file", name) + + if assert.Equal(t, testContent, data, "golden file content is not equal to actual") { + return + } + + t.Logf("expected:\n") + logMsgpackAsJsonConvert(t, testContent) + t.Logf("actual:\n") + logMsgpackAsJsonConvert(t, data) +} + +func fileExists(name string) bool { + _, err := os.Stat(name) + return !os.IsNotExist(err) +} + +const ( + pathPrefix = "testdata/requests" +) + +var ( + generateGolden = flag.Bool("generate-golden", false, + "generate golden files if they do not exist") +) + +type goldenTestCase struct { + Name string + Test func(t *testing.T) tarantool.Request + Request tarantool.Request + Resolver tarantool.SchemaResolver +} + +func (tc goldenTestCase) Execute(t *testing.T) { + t.Helper() + + if tc.Request != nil { + if tc.Test != nil { + require.FailNow(t, "both Test and Request must not be set") + } + + tc.Test = func(t *testing.T) tarantool.Request { + return tc.Request + } + } + if tc.Resolver == nil { + tc.Resolver = &dummySchemaResolver{} + } + + name := path.Join(pathPrefix, tc.Name) + + t.Run(name, func(t *testing.T) { + var out bytes.Buffer + encoder := msgpack.NewEncoder(&out) + + req := tc.Test(t) + require.NotNil(t, req, "failed to create request") + + err := req.Body(&dummySchemaResolver{}, encoder) + require.NoError(t, err, "failed to encode request") + + goldenFileExists := fileExists(name) + generateGoldenIsSet := *generateGolden && !goldenFileExists + + switch { + case !goldenFileExists && generateGoldenIsSet: + t.Logf("writing golden file %s", name) + err := os.WriteFile(name, out.Bytes(), 0644) + require.NoError(t, err, "failed to write golden file", name) + logMsgpackAsJsonConvert(t, out.Bytes()) + case !goldenFileExists && !generateGoldenIsSet: + assert.FailNow(t, "golden file does not exist") + } + + compareGoldenMsgpackAndPrintDiff(t, name, out.Bytes()) + }) +} + +type dummySchemaResolver struct{} + +func interfaceToUint32(in interface{}) (uint32, bool) { + switch val := reflect.ValueOf(in); val.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return uint32(val.Int()), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return uint32(val.Uint()), true + default: + return 0, false + } +} + +func (d dummySchemaResolver) ResolveSpace(in interface{}) (uint32, error) { + if num, ok := interfaceToUint32(in); ok { + return num, nil + } + return 0, fmt.Errorf("unexpected space type %T", in) +} + +func (d dummySchemaResolver) ResolveIndex(in interface{}, _ uint32) (uint32, error) { + if in == nil { + return 0, nil + } else if num, ok := interfaceToUint32(in); ok { + return num, nil + } + return 0, fmt.Errorf("unexpected index type %T", in) +} + +func (d dummySchemaResolver) NamesUseSupported() bool { + return true +} + +func TestGolden(t *testing.T) { + precachedUUID := uuid.MustParse("819308aa-f0c1-4b74-9861-c4ab5a266308") + precachedTime := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + precachedTuple := []interface{}{1, "args", 3, precachedTime, precachedUUID, 1.2} + + precachedUpdateOps := tarantool.NewOperations(). + Add(1, "test"). + Assign(2, "fest"). + Delete(3, ""). + Insert(4, "insert"). + Splice(5, 6, 7, "splice"). + Subtract(6, "subtract"). + BitwiseAnd(7, 10). + BitwiseOr(8, 11). + BitwiseXor(9, 12) + + testCases := []goldenTestCase{ + { + Name: "commit-raw.msgpack", + Request: tarantool.NewCommitRequest(), + }, + { + Name: "commit-with-sync.msgpack", + Request: tarantool.NewCommitRequest().IsSync(true), + }, + { + Name: "commit-with-sync-false.msgpack", + Request: tarantool.NewCommitRequest().IsSync(false), + }, + { + Name: "begin.msgpack", + Request: tarantool.NewBeginRequest(), + }, + { + Name: "begin-with-txn-isolation.msgpack", + Request: tarantool.NewBeginRequest(). + TxnIsolation(tarantool.ReadCommittedLevel), + }, + { + Name: "begin-with-txn-isolation-is-sync.msgpack", + Request: tarantool.NewBeginRequest(). + TxnIsolation(tarantool.ReadCommittedLevel). + IsSync(true), + }, + { + Name: "begin-with-txn-isolation-is-sync-timeout.msgpack", + Request: tarantool.NewBeginRequest(). + TxnIsolation(tarantool.ReadCommittedLevel). + IsSync(true). + Timeout(2 * time.Second), + }, + { + Name: "rollback.msgpack", + Request: tarantool.NewRollbackRequest(), + }, + { + Name: "ping.msgpack", + Request: tarantool.NewPingRequest(), + }, + { + Name: "call-no-args.msgpack", + Request: tarantool.NewCallRequest("function.name"), + }, + { + Name: "call-with-args.msgpack", + Request: tarantool.NewCallRequest("function.name").Args( + []interface{}{1, 2, 3}, + ), + }, + { + Name: "call-with-args-mixed.msgpack", + Request: tarantool.NewCallRequest("function.name").Args( + []interface{}{1, "args", 3, precachedTime, precachedUUID, 1.2}, + ), + }, + { + Name: "call-with-args-nil.msgpack", + Request: tarantool.NewCallRequest("function.name").Args(nil), + }, + { + Name: "call-with-args-empty-array.msgpack", + Request: tarantool.NewCallRequest("function.name").Args([]int{}), + }, + { + Name: "eval.msgpack", + Request: tarantool.NewEvalRequest("function_name()"), + }, + { + Name: "eval-with-nil.msgpack", + Request: tarantool.NewEvalRequest("function_name()").Args(nil), + }, + { + Name: "eval-with-empty-array.msgpack", + Request: tarantool.NewEvalRequest("function_name()").Args([]int{}), + }, + { + Name: "eval-with-args.msgpack", + Request: tarantool.NewEvalRequest("function_name(...)").Args(precachedTuple), + }, + { + Name: "eval-with-single-number.msgpack", + Request: tarantool.NewEvalRequest("function_name()").Args(1), + }, + { + Name: "delete-raw.msgpack", + Request: tarantool.NewDeleteRequest("table_name"), + }, + { + Name: "delete-snumber-inumber.msgpack", + Request: tarantool.NewDeleteRequest(246). + Index(123).Key([]interface{}{123}), + }, + { + Name: "delete-snumber-iname.msgpack", + Request: tarantool.NewDeleteRequest(246). + Index("index_name").Key([]interface{}{123}), + }, + { + Name: "delete-sname-inumber.msgpack", + Request: tarantool.NewDeleteRequest("table_name"). + Index(123).Key([]interface{}{123}), + }, + { + Name: "delete-sname-iname.msgpack", + Request: tarantool.NewDeleteRequest("table_name"). + Index("index_name").Key([]interface{}{123}), + }, + { + Name: "replace-sname.msgpack", + Request: tarantool.NewReplaceRequest("table_name"). + Tuple(precachedTuple), + }, + { + Name: "replace-snumber.msgpack", + Request: tarantool.NewReplaceRequest(123). + Tuple(precachedTuple), + }, + { + Name: "insert-sname.msgpack", + Request: tarantool.NewReplaceRequest("table_name"). + Tuple(precachedTuple), + }, + { + Name: "insert-snumber.msgpack", + Request: tarantool.NewReplaceRequest(123). + Tuple(precachedTuple), + }, + { + Name: "call16.msgpack", + Request: tarantool.NewCall16Request("function.name"), + }, + { + Name: "call16-with-args.msgpack", + Request: tarantool.NewCall16Request("function.name"). + Args(precachedTuple), + }, + { + Name: "call16-with-args-nil.msgpack", + Request: tarantool.NewCall16Request("function.name").Args(nil), + }, + { + Name: "call17.msgpack", + Request: tarantool.NewCall17Request("function.name"), + }, + { + Name: "call17-with-args-nil.msgpack", + Request: tarantool.NewCall17Request("function.name").Args(nil), + }, + { + Name: "call17-with-args.msgpack", + Request: tarantool.NewCall17Request("function.name"). + Args(precachedTuple), + }, + { + Name: "update.msgpack", + Request: tarantool.NewUpdateRequest("table_name"), + }, + { + Name: "update-snumber-iname.msgpack", + Request: tarantool.NewUpdateRequest(123). + Index("index_name").Key([]interface{}{123}). + Operations(precachedUpdateOps), + }, + { + Name: "update-sname-iname.msgpack", + Request: tarantool.NewUpdateRequest("table_name"). + Index("index_name").Key([]interface{}{123}). + Operations(precachedUpdateOps), + }, + { + Name: "update-sname-inumber.msgpack", + Request: tarantool.NewUpdateRequest("table_name"). + Index(123).Key([]interface{}{123}). + Operations(precachedUpdateOps), + }, + { + Name: "upsert.msgpack", + Request: tarantool.NewUpsertRequest("table_name"), + }, + { + Name: "upsert-snumber.msgpack", + Request: tarantool.NewUpsertRequest(123). + Operations(precachedUpdateOps). + Tuple(precachedTuple), + }, + { + Name: "upsert-sname.msgpack", + Request: tarantool.NewUpsertRequest("table_name"). + Operations(precachedUpdateOps). + Tuple(precachedTuple), + }, + { + Name: "select", + Request: tarantool.NewSelectRequest("table_name"), + }, + { + Name: "select-sname-iname.msgpack", + Request: tarantool.NewSelectRequest("table_name"). + Index("index_name"), + }, + { + Name: "select-sname-inumber.msgpack", + Request: tarantool.NewSelectRequest("table_name"). + Index(123), + }, + { + Name: "select-snumber-iname.msgpack", + Request: tarantool.NewSelectRequest(123). + Index("index_name"), + }, + { + Name: "select-snumber-inumber.msgpack", + Request: tarantool.NewSelectRequest(123). + Index(123), + }, + { + Name: "select-with-key.msgpack", + Request: tarantool.NewSelectRequest("table_name"). + Key(precachedTuple), + }, + { + Name: "select-key-sname-iname.msgpack", + Request: tarantool.NewSelectRequest("table_name"). + Index("index_name").Key(precachedTuple), + }, + { + Name: "select-key-sname-inumber.msgpack", + Request: tarantool.NewSelectRequest("table_name"). + Index(123).Key(precachedTuple), + }, + { + Name: "select-key-snumber-iname.msgpack", + Request: tarantool.NewSelectRequest(123). + Index("index_name").Key(precachedTuple), + }, + { + Name: "select-key-snumber-inumber.msgpack", + Request: tarantool.NewSelectRequest(123). + Index(123).Key(precachedTuple), + }, + { + Name: "select-with-optionals.msgpack", + Request: tarantool.NewSelectRequest("table_name"). + Offset(123).Limit(123).Iterator(tarantool.IterGe). + FetchPos(true), + }, + { + Name: "select-with-after.msgpack", + Request: tarantool.NewSelectRequest("table_name"). + After(precachedTuple).Limit(123).Iterator(tarantool.IterGe). + Key(precachedTuple), + }, + } + + for _, tc := range testCases { + tc.Execute(t) + } +} diff --git a/header.go b/header.go new file mode 100644 index 000000000..20a4a465f --- /dev/null +++ b/header.go @@ -0,0 +1,14 @@ +package tarantool + +import "github.com/tarantool/go-iproto" + +// Header is a response header. +type Header struct { + // RequestId is an id of a corresponding request. + RequestId uint32 + // Error is a response error. It could be used + // to check that response has or hasn't an error without decoding. + // Error == ErrorNo (iproto.ER_UNKNOWN) if there is no error. + // Otherwise, it contains an error code from iproto.Error enumeration. + Error iproto.Error +} diff --git a/iterator.go b/iterator.go new file mode 100644 index 000000000..128168d7b --- /dev/null +++ b/iterator.go @@ -0,0 +1,35 @@ +package tarantool + +import ( + "github.com/tarantool/go-iproto" +) + +// Iter is an enumeration type of a select iterator. +type Iter uint32 + +const ( + // Key == x ASC order. + IterEq Iter = Iter(iproto.ITER_EQ) + // Key == x DESC order. + IterReq Iter = Iter(iproto.ITER_REQ) + // All tuples. + IterAll Iter = Iter(iproto.ITER_ALL) + // Key < x. + IterLt Iter = Iter(iproto.ITER_LT) + // Key <= x. + IterLe Iter = Iter(iproto.ITER_LE) + // Key >= x. + IterGe Iter = Iter(iproto.ITER_GE) + // Key > x. + IterGt Iter = Iter(iproto.ITER_GT) + // All bits from x are set in key. + IterBitsAllSet Iter = Iter(iproto.ITER_BITS_ALL_SET) + // All bits are not set. + IterBitsAnySet Iter = Iter(iproto.ITER_BITS_ANY_SET) + // All bits are not set. + IterBitsAllNotSet Iter = Iter(iproto.ITER_BITS_ALL_NOT_SET) + // Key overlaps x. + IterOverlaps Iter = Iter(iproto.ITER_OVERLAPS) + // Tuples in distance ascending order from specified point. + IterNeighbor Iter = Iter(iproto.ITER_NEIGHBOR) +) diff --git a/pool/config.lua b/pool/config.lua new file mode 100644 index 000000000..4e91a0d83 --- /dev/null +++ b/pool/config.lua @@ -0,0 +1,54 @@ +-- Do not set listen for now so connector won't be +-- able to send requests until everything is configured. +box.cfg{ + work_dir = os.getenv("TEST_TNT_WORK_DIR"), + memtx_use_mvcc_engine = os.getenv("TEST_TNT_MEMTX_USE_MVCC_ENGINE") == 'true' or nil, +} + +box.once("init", function() + box.schema.user.create('test', { password = 'test' }) + box.schema.user.grant('test', 'read,write,execute', 'universe') + + box.schema.user.create('test_noexec', { password = 'test' }) + box.schema.user.grant('test_noexec', 'read,write', 'universe') + + local s = box.schema.space.create('testPool', { + id = 520, + if_not_exists = true, + format = { + {name = "key", type = "string"}, + {name = "value", type = "string"}, + }, + }) + s:create_index('pk', { + type = 'tree', + parts = {{ field = 1, type = 'string' }}, + if_not_exists = true + }) + + local sp = box.schema.space.create('SQL_TEST', { + id = 521, + if_not_exists = true, + format = { + {name = "NAME0", type = "unsigned"}, + {name = "NAME1", type = "string"}, + {name = "NAME2", type = "string"}, + } + }) + sp:create_index('primary', {type = 'tree', parts = {1, 'uint'}, if_not_exists = true}) + sp:insert{1, "test", "test"} + -- grants for sql tests + box.schema.user.grant('test', 'create,read,write,drop,alter', 'space') + box.schema.user.grant('test', 'create', 'sequence') +end) + +local function simple_incr(a) + return a + 1 +end + +rawset(_G, 'simple_incr', simple_incr) + +-- Set listen only when every other thing is configured. +box.cfg{ + listen = os.getenv("TEST_TNT_LISTEN"), +} diff --git a/pool/connection_pool.go b/pool/connection_pool.go new file mode 100644 index 000000000..9d020ba8e --- /dev/null +++ b/pool/connection_pool.go @@ -0,0 +1,1026 @@ +// Package with methods to work with a Tarantool cluster +// considering master discovery. +// +// Main features: +// +// - Return available connection from pool according to round-robin strategy. +// +// - Automatic master discovery by mode parameter. +// +// Since: 1.6.0 +package pool + +import ( + "context" + "errors" + "fmt" + "log" + "sync" + "time" + + "github.com/tarantool/go-iproto" + + "github.com/tarantool/go-tarantool/v3" +) + +var ( + ErrWrongCheckTimeout = errors.New("wrong check timeout, must be greater than 0") + ErrTooManyArgs = errors.New("too many arguments") + ErrIncorrectResponse = errors.New("incorrect response format") + ErrIncorrectStatus = errors.New("incorrect instance status: status should be `running`") + ErrNoRwInstance = errors.New("can't find rw instance in pool") + ErrNoRoInstance = errors.New("can't find ro instance in pool") + ErrNoHealthyInstance = errors.New("can't find healthy instance in pool") + ErrExists = errors.New("endpoint exists") + ErrClosed = errors.New("pool is closed") + ErrUnknownRequest = errors.New("the passed connected request doesn't belong to " + + "the current connection pool") + ErrContextCanceled = errors.New("operation was canceled") +) + +// ConnectionHandler provides callbacks for components interested in handling +// changes of connections in a ConnectionPool. +type ConnectionHandler interface { + // Discovered is called when a connection with a role has been detected + // (for the first time or when a role of a connection has been changed), + // but is not yet available to send requests. It allows for a client to + // initialize the connection before using it in a pool. + // + // The client code may cancel adding a connection to the pool. The client + // need to return an error from the Discovered call for that. In this case + // the pool will close connection and will try to reopen it later. + Discovered(name string, conn *tarantool.Connection, role Role) error + // Deactivated is called when a connection with a role has become + // unavaileble to send requests. It happens if the connection is closed or + // the connection role is switched. + // + // So if a connection switches a role, a pool calls: + // Deactivated() + Discovered(). + // + // Deactivated will not be called if a previous Discovered() call returns + // an error. Because in this case, the connection does not become available + // for sending requests. + Deactivated(name string, conn *tarantool.Connection, role Role) error +} + +// Instance describes a single instance configuration in the pool. +type Instance struct { + // Name is an instance name. The name must be unique. + Name string + // Dialer will be used to create a connection to the instance. + Dialer tarantool.Dialer + // Opts configures a connection to the instance. + Opts tarantool.Opts +} + +// Opts provides additional options (configurable via ConnectWithOpts). +type Opts struct { + // Timeout for timer to reopen connections that have been closed by some + // events and to relocate connection between subpools if ro/rw role has + // been updated. + CheckTimeout time.Duration + // ConnectionHandler provides an ability to handle connection updates. + ConnectionHandler ConnectionHandler +} + +/* +ConnectionInfo structure for information about connection statuses: + +- ConnectedNow reports if connection is established at the moment. + +- ConnRole reports master/replica role of instance. +*/ +type ConnectionInfo struct { + ConnectedNow bool + ConnRole Role + Instance Instance +} + +/* +Main features: + +- Return available connection from pool according to round-robin strategy. + +- Automatic master discovery by mode parameter. +*/ +type ConnectionPool struct { + ends map[string]*endpoint + endsMutex sync.RWMutex + + opts Opts + + state state + done chan struct{} + roPool *roundRobinStrategy + rwPool *roundRobinStrategy + anyPool *roundRobinStrategy + poolsMutex sync.RWMutex + watcherContainer watcherContainer +} + +var _ Pooler = (*ConnectionPool)(nil) + +type endpoint struct { + name string + dialer tarantool.Dialer + opts tarantool.Opts + notify chan tarantool.ConnEvent + conn *tarantool.Connection + role Role + // This is used to switch a connection states. + shutdown chan struct{} + close chan struct{} + closed chan struct{} + cancel context.CancelFunc + closeErr error +} + +func newEndpoint(name string, dialer tarantool.Dialer, opts tarantool.Opts) *endpoint { + return &endpoint{ + name: name, + dialer: dialer, + opts: opts, + notify: make(chan tarantool.ConnEvent, 100), + conn: nil, + role: UnknownRole, + shutdown: make(chan struct{}), + close: make(chan struct{}), + closed: make(chan struct{}), + cancel: nil, + } +} + +// ConnectWithOpts creates pool for instances with specified instances and +// opts. Instances must have unique names. +func ConnectWithOpts(ctx context.Context, instances []Instance, + opts Opts) (*ConnectionPool, error) { + unique := make(map[string]bool) + for _, instance := range instances { + if _, ok := unique[instance.Name]; ok { + return nil, fmt.Errorf("duplicate instance name: %q", instance.Name) + } + unique[instance.Name] = true + } + + if opts.CheckTimeout <= 0 { + return nil, ErrWrongCheckTimeout + } + + size := len(instances) + rwPool := newRoundRobinStrategy(size) + roPool := newRoundRobinStrategy(size) + anyPool := newRoundRobinStrategy(size) + + p := &ConnectionPool{ + ends: make(map[string]*endpoint), + opts: opts, + state: connectedState, + done: make(chan struct{}), + rwPool: rwPool, + roPool: roPool, + anyPool: anyPool, + } + + fillCtx, fillCancel := context.WithCancel(ctx) + defer fillCancel() + + var timeout <-chan time.Time + + timeout = make(chan time.Time) + filled := p.fillPools(fillCtx, instances) + done := 0 + success := len(instances) == 0 + + for done < len(instances) { + select { + case <-timeout: + fillCancel() + // To be sure that the branch is called only once. + timeout = make(chan time.Time) + case err := <-filled: + done++ + + if err == nil && !success { + timeout = time.After(opts.CheckTimeout) + success = true + } + } + } + + if !success && ctx.Err() != nil { + p.state.set(closedState) + return nil, ctx.Err() + } + + for _, endpoint := range p.ends { + endpointCtx, cancel := context.WithCancel(context.Background()) + endpoint.cancel = cancel + go p.controller(endpointCtx, endpoint) + } + + return p, nil +} + +// Connect creates pool for instances with specified instances. Instances must +// have unique names. +// +// It is useless to set up tarantool.Opts.Reconnect value for a connection. +// The connection pool has its own reconnection logic. See +// Opts.CheckTimeout description. +func Connect(ctx context.Context, instances []Instance) (*ConnectionPool, error) { + opts := Opts{ + CheckTimeout: 1 * time.Second, + } + return ConnectWithOpts(ctx, instances, opts) +} + +// ConnectedNow gets connected status of pool. +func (p *ConnectionPool) ConnectedNow(mode Mode) (bool, error) { + p.poolsMutex.RLock() + defer p.poolsMutex.RUnlock() + + if p.state.get() != connectedState { + return false, nil + } + switch mode { + case ANY: + return !p.anyPool.IsEmpty(), nil + case RW: + return !p.rwPool.IsEmpty(), nil + case RO: + return !p.roPool.IsEmpty(), nil + case PreferRW: + fallthrough + case PreferRO: + return !p.rwPool.IsEmpty() || !p.roPool.IsEmpty(), nil + default: + return false, ErrNoHealthyInstance + } +} + +// ConfiguredTimeout gets timeout of current connection. +func (p *ConnectionPool) ConfiguredTimeout(mode Mode) (time.Duration, error) { + conn, err := p.getNextConnection(mode) + if err != nil { + return 0, err + } + + return conn.ConfiguredTimeout(), nil +} + +// Add adds a new instance into the pool. The pool will try to connect to the +// instance later if it is unable to establish a connection. +// +// The function may return an error and don't add the instance into the pool +// if the context has been cancelled or on concurrent Close()/CloseGraceful() +// call. +func (p *ConnectionPool) Add(ctx context.Context, instance Instance) error { + e := newEndpoint(instance.Name, instance.Dialer, instance.Opts) + + p.endsMutex.Lock() + // Ensure that Close()/CloseGraceful() not in progress/done. + if p.state.get() != connectedState { + p.endsMutex.Unlock() + return ErrClosed + } + if _, ok := p.ends[instance.Name]; ok { + p.endsMutex.Unlock() + return ErrExists + } + + endpointCtx, endpointCancel := context.WithCancel(context.Background()) + connectCtx, connectCancel := context.WithCancel(ctx) + e.cancel = func() { + connectCancel() + endpointCancel() + } + + p.ends[instance.Name] = e + p.endsMutex.Unlock() + + if err := p.tryConnect(connectCtx, e); err != nil { + var canceled bool + select { + case <-connectCtx.Done(): + canceled = true + case <-endpointCtx.Done(): + canceled = true + default: + canceled = false + } + if canceled { + if p.state.get() != connectedState { + // If it is canceled (or could be canceled) due to a + // Close()/CloseGraceful() call we overwrite the error + // to make behavior expected. + err = ErrClosed + } + + p.endsMutex.Lock() + delete(p.ends, instance.Name) + p.endsMutex.Unlock() + e.cancel() + close(e.closed) + return err + } else { + log.Printf("tarantool: connect to %s failed: %s\n", instance.Name, err) + } + } + + go p.controller(endpointCtx, e) + return nil +} + +// Remove removes an endpoint with the name from the pool. The call +// closes an active connection gracefully. +func (p *ConnectionPool) Remove(name string) error { + p.endsMutex.Lock() + endpoint, ok := p.ends[name] + if !ok { + p.endsMutex.Unlock() + return errors.New("endpoint not exist") + } + + select { + case <-endpoint.close: + // Close() in progress/done. + case <-endpoint.shutdown: + // CloseGraceful()/Remove() in progress/done. + default: + endpoint.cancel() + close(endpoint.shutdown) + } + + delete(p.ends, name) + p.endsMutex.Unlock() + + <-endpoint.closed + return nil +} + +func (p *ConnectionPool) waitClose() []error { + p.endsMutex.RLock() + endpoints := make([]*endpoint, 0, len(p.ends)) + for _, e := range p.ends { + endpoints = append(endpoints, e) + } + p.endsMutex.RUnlock() + + errs := make([]error, 0, len(endpoints)) + for _, e := range endpoints { + <-e.closed + if e.closeErr != nil { + errs = append(errs, e.closeErr) + } + } + return errs +} + +// Close closes connections in the ConnectionPool. +func (p *ConnectionPool) Close() []error { + if p.state.cas(connectedState, closedState) || + p.state.cas(shutdownState, closedState) { + p.endsMutex.RLock() + for _, s := range p.ends { + s.cancel() + close(s.close) + } + p.endsMutex.RUnlock() + } + + return p.waitClose() +} + +// CloseGraceful closes connections in the ConnectionPool gracefully. It waits +// for all requests to complete. +func (p *ConnectionPool) CloseGraceful() []error { + if p.state.cas(connectedState, shutdownState) { + p.endsMutex.RLock() + for _, s := range p.ends { + s.cancel() + close(s.shutdown) + } + p.endsMutex.RUnlock() + } + + return p.waitClose() +} + +// GetInfo gets information of connections (connected status, ro/rw role). +func (p *ConnectionPool) GetInfo() map[string]ConnectionInfo { + info := make(map[string]ConnectionInfo) + + p.endsMutex.RLock() + defer p.endsMutex.RUnlock() + p.poolsMutex.RLock() + defer p.poolsMutex.RUnlock() + + if p.state.get() != connectedState { + return info + } + + for name, end := range p.ends { + conn, role := p.getConnectionFromPool(name) + + connInfo := ConnectionInfo{ + ConnectedNow: false, + ConnRole: UnknownRole, + Instance: Instance{ + Name: name, + Dialer: end.dialer, + Opts: end.opts, + }, + } + + if conn != nil { + connInfo.ConnRole = role + connInfo.ConnectedNow = conn.ConnectedNow() + } + + info[name] = connInfo + } + + return info +} + +// NewStream creates new Stream object for connection selected +// by userMode from pool. +// +// Since v. 2.10.0, Tarantool supports streams and interactive transactions over them. +// To use interactive transactions, memtx_use_mvcc_engine box option should be set to true. +// Since 1.7.0 +func (p *ConnectionPool) NewStream(userMode Mode) (*tarantool.Stream, error) { + conn, err := p.getNextConnection(userMode) + if err != nil { + return nil, err + } + return conn.NewStream() +} + +// NewPrepared passes a sql statement to Tarantool for preparation synchronously. +func (p *ConnectionPool) NewPrepared(expr string, userMode Mode) (*tarantool.Prepared, error) { + conn, err := p.getNextConnection(userMode) + if err != nil { + return nil, err + } + return conn.NewPrepared(expr) +} + +// NewWatcher creates a new Watcher object for the connection pool. +// A watcher could be created only for instances with the support. +// +// The behavior is same as if Connection.NewWatcher() called for each +// connection with a suitable mode. +// +// Keep in mind that garbage collection of a watcher handle doesn’t lead to the +// watcher’s destruction. In this case, the watcher remains registered. You +// need to call Unregister() directly. +// +// Unregister() guarantees that there will be no the watcher's callback calls +// after it, but Unregister() call from the callback leads to a deadlock. +// +// See: +// https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_events/#box-watchers +// +// Since 1.10.0 +func (p *ConnectionPool) NewWatcher(key string, + callback tarantool.WatchCallback, mode Mode) (tarantool.Watcher, error) { + + watcher := &poolWatcher{ + container: &p.watcherContainer, + mode: mode, + key: key, + callback: callback, + watchers: make(map[*tarantool.Connection]tarantool.Watcher), + unregistered: false, + } + + watcher.container.add(watcher) + + rr := p.anyPool + if mode == RW { + rr = p.rwPool + } else if mode == RO { + rr = p.roPool + } + + conns := rr.GetConnections() + for _, conn := range conns { + // Check that connection supports watchers. + if !isFeatureInSlice(iproto.IPROTO_FEATURE_WATCHERS, conn.ProtocolInfo().Features) { + continue + } + if err := watcher.watch(conn); err != nil { + conn.Close() + } + } + + return watcher, nil +} + +// Do sends the request and returns a future. +// For requests that belong to the only one connection (e.g. Unprepare or ExecutePrepared) +// the argument of type Mode is unused. +func (p *ConnectionPool) Do(req tarantool.Request, userMode Mode) *tarantool.Future { + if connectedReq, ok := req.(tarantool.ConnectedRequest); ok { + conns := p.anyPool.GetConnections() + isOurConnection := false + for _, conn := range conns { + // Compare raw pointers. + if conn == connectedReq.Conn() { + isOurConnection = true + break + } + } + if !isOurConnection { + return newErrorFuture(ErrUnknownRequest) + } + return connectedReq.Conn().Do(req) + } + conn, err := p.getNextConnection(userMode) + if err != nil { + return newErrorFuture(err) + } + + return conn.Do(req) +} + +// DoInstance sends the request into a target instance and returns a future. +func (p *ConnectionPool) DoInstance(req tarantool.Request, name string) *tarantool.Future { + conn := p.anyPool.GetConnection(name) + if conn == nil { + return newErrorFuture(ErrNoHealthyInstance) + } + + return conn.Do(req) +} + +// +// private +// + +func (p *ConnectionPool) getConnectionRole(conn *tarantool.Connection) (Role, error) { + var ( + roFieldName string + data []interface{} + err error + ) + + if isFeatureInSlice(iproto.IPROTO_FEATURE_WATCH_ONCE, conn.ProtocolInfo().Features) { + roFieldName = "is_ro" + data, err = conn.Do(tarantool.NewWatchOnceRequest("box.status")).Get() + } else { + roFieldName = "ro" + data, err = conn.Do(tarantool.NewCallRequest("box.info")).Get() + } + + if err != nil { + return UnknownRole, err + } + if len(data) < 1 { + return UnknownRole, ErrIncorrectResponse + } + + respFields, ok := data[0].(map[interface{}]interface{}) + if !ok { + return UnknownRole, ErrIncorrectResponse + } + + instanceStatus, ok := respFields["status"] + if !ok { + return UnknownRole, ErrIncorrectResponse + } + if instanceStatus != "running" { + return UnknownRole, ErrIncorrectStatus + } + + replicaRole, ok := respFields[roFieldName] + if !ok { + return UnknownRole, ErrIncorrectResponse + } + + switch replicaRole { + case false: + return MasterRole, nil + case true: + return ReplicaRole, nil + } + + return UnknownRole, nil +} + +func (p *ConnectionPool) getConnectionFromPool(name string) (*tarantool.Connection, Role) { + if conn := p.rwPool.GetConnection(name); conn != nil { + return conn, MasterRole + } + + if conn := p.roPool.GetConnection(name); conn != nil { + return conn, ReplicaRole + } + + return p.anyPool.GetConnection(name), UnknownRole +} + +func (p *ConnectionPool) deleteConnection(name string) { + if conn := p.anyPool.DeleteConnection(name); conn != nil { + if conn := p.rwPool.DeleteConnection(name); conn == nil { + p.roPool.DeleteConnection(name) + } + // The internal connection deinitialization. + p.watcherContainer.mutex.RLock() + defer p.watcherContainer.mutex.RUnlock() + + p.watcherContainer.foreach(func(watcher *poolWatcher) error { + watcher.unwatch(conn) + return nil + }) + } +} + +func (p *ConnectionPool) addConnection(name string, + conn *tarantool.Connection, role Role) error { + // The internal connection initialization. + p.watcherContainer.mutex.RLock() + defer p.watcherContainer.mutex.RUnlock() + + if isFeatureInSlice(iproto.IPROTO_FEATURE_WATCHERS, conn.ProtocolInfo().Features) { + watched := []*poolWatcher{} + err := p.watcherContainer.foreach(func(watcher *poolWatcher) error { + watch := false + switch watcher.mode { + case RW: + watch = role == MasterRole + case RO: + watch = role == ReplicaRole + default: + watch = true + } + if watch { + if err := watcher.watch(conn); err != nil { + return err + } + watched = append(watched, watcher) + } + return nil + }) + if err != nil { + for _, watcher := range watched { + watcher.unwatch(conn) + } + log.Printf("tarantool: failed initialize watchers for %s: %s", name, err) + return err + } + } + + p.anyPool.AddConnection(name, conn) + + switch role { + case MasterRole: + p.rwPool.AddConnection(name, conn) + case ReplicaRole: + p.roPool.AddConnection(name, conn) + } + return nil +} + +func (p *ConnectionPool) handlerDiscovered(name string, conn *tarantool.Connection, + role Role) bool { + var err error + if p.opts.ConnectionHandler != nil { + err = p.opts.ConnectionHandler.Discovered(name, conn, role) + } + + if err != nil { + log.Printf("tarantool: storing connection to %s canceled: %s\n", name, err) + return false + } + return true +} + +func (p *ConnectionPool) handlerDeactivated(name string, conn *tarantool.Connection, + role Role) { + var err error + if p.opts.ConnectionHandler != nil { + err = p.opts.ConnectionHandler.Deactivated(name, conn, role) + } + + if err != nil { + log.Printf("tarantool: deactivating connection to %s by user failed: %s\n", + name, err) + } +} + +func (p *ConnectionPool) fillPools(ctx context.Context, instances []Instance) <-chan error { + done := make(chan error, len(instances)) + + // It is called before controller() goroutines, so we don't expect + // concurrency issues here. + for _, instance := range instances { + end := newEndpoint(instance.Name, instance.Dialer, instance.Opts) + p.ends[instance.Name] = end + } + + for _, instance := range instances { + name := instance.Name + end := p.ends[name] + + go func() { + if err := p.tryConnect(ctx, end); err != nil { + log.Printf("tarantool: connect to %s failed: %s\n", name, err) + done <- fmt.Errorf("failed to connect to %s :%w", name, err) + + return + } + + done <- nil + }() + } + + return done +} + +func (p *ConnectionPool) updateConnection(e *endpoint) { + p.poolsMutex.Lock() + + if p.state.get() != connectedState { + p.poolsMutex.Unlock() + return + } + + if role, err := p.getConnectionRole(e.conn); err == nil { + if e.role != role { + p.deleteConnection(e.name) + p.poolsMutex.Unlock() + + p.handlerDeactivated(e.name, e.conn, e.role) + opened := p.handlerDiscovered(e.name, e.conn, role) + if !opened { + e.conn.Close() + e.conn = nil + e.role = UnknownRole + return + } + + p.poolsMutex.Lock() + if p.state.get() != connectedState { + p.poolsMutex.Unlock() + + e.conn.Close() + p.handlerDeactivated(e.name, e.conn, role) + e.conn = nil + e.role = UnknownRole + return + } + + if p.addConnection(e.name, e.conn, role) != nil { + p.poolsMutex.Unlock() + + e.conn.Close() + p.handlerDeactivated(e.name, e.conn, role) + e.conn = nil + e.role = UnknownRole + return + } + e.role = role + } + p.poolsMutex.Unlock() + return + } else { + p.deleteConnection(e.name) + p.poolsMutex.Unlock() + + e.conn.Close() + p.handlerDeactivated(e.name, e.conn, e.role) + e.conn = nil + e.role = UnknownRole + return + } +} + +func (p *ConnectionPool) tryConnect(ctx context.Context, e *endpoint) error { + e.conn = nil + e.role = UnknownRole + + connOpts := e.opts + connOpts.Notify = e.notify + conn, err := tarantool.Connect(ctx, e.dialer, connOpts) + + p.poolsMutex.Lock() + + if p.state.get() != connectedState { + if err == nil { + conn.Close() + } + + p.poolsMutex.Unlock() + return ErrClosed + } + + if err == nil { + role, err := p.getConnectionRole(conn) + p.poolsMutex.Unlock() + + if err != nil { + conn.Close() + log.Printf("tarantool: storing connection to %s failed: %s\n", + e.name, err) + return err + } + + opened := p.handlerDiscovered(e.name, conn, role) + if !opened { + conn.Close() + return errors.New("storing connection canceled") + } + + p.poolsMutex.Lock() + if p.state.get() != connectedState { + p.poolsMutex.Unlock() + conn.Close() + p.handlerDeactivated(e.name, conn, role) + return ErrClosed + } + + if err = p.addConnection(e.name, conn, role); err != nil { + p.poolsMutex.Unlock() + conn.Close() + p.handlerDeactivated(e.name, conn, role) + return err + } + e.conn = conn + e.role = role + } + + p.poolsMutex.Unlock() + return err +} + +func (p *ConnectionPool) reconnect(ctx context.Context, e *endpoint) { + p.poolsMutex.Lock() + + if p.state.get() != connectedState { + p.poolsMutex.Unlock() + return + } + + p.deleteConnection(e.name) + p.poolsMutex.Unlock() + + p.handlerDeactivated(e.name, e.conn, e.role) + e.conn = nil + e.role = UnknownRole + + if err := p.tryConnect(ctx, e); err != nil { + log.Printf("tarantool: reconnect to %s failed: %s\n", e.name, err) + } +} + +func (p *ConnectionPool) controller(ctx context.Context, e *endpoint) { + timer := time.NewTicker(p.opts.CheckTimeout) + defer timer.Stop() + + shutdown := false + for { + if shutdown { + // Graceful shutdown in progress. We need to wait for a finish or + // to force close. + select { + case <-e.closed: + case <-e.close: + } + } + + select { + case <-e.closed: + return + default: + } + + select { + // e.close has priority to avoid concurrency with e.shutdown. + case <-e.close: + if e.conn != nil { + p.poolsMutex.Lock() + p.deleteConnection(e.name) + p.poolsMutex.Unlock() + + if !shutdown { + e.closeErr = e.conn.Close() + p.handlerDeactivated(e.name, e.conn, e.role) + close(e.closed) + } else { + // Force close the connection. + e.conn.Close() + // And wait for a finish. + <-e.closed + } + } else { + close(e.closed) + } + default: + select { + case <-e.shutdown: + shutdown = true + if e.conn != nil { + p.poolsMutex.Lock() + p.deleteConnection(e.name) + p.poolsMutex.Unlock() + + // We need to catch s.close in the current goroutine, so + // we need to start an another one for the shutdown. + go func() { + e.closeErr = e.conn.CloseGraceful() + p.handlerDeactivated(e.name, e.conn, e.role) + close(e.closed) + }() + } else { + close(e.closed) + } + default: + select { + case <-e.close: + // Will be processed at an upper level. + case <-e.shutdown: + // Will be processed at an upper level. + case <-e.notify: + if e.conn != nil && e.conn.ClosedNow() { + p.poolsMutex.Lock() + if p.state.get() == connectedState { + p.deleteConnection(e.name) + p.poolsMutex.Unlock() + p.handlerDeactivated(e.name, e.conn, e.role) + e.conn = nil + e.role = UnknownRole + } else { + p.poolsMutex.Unlock() + } + } + case <-timer.C: + // Reopen connection. + // Relocate connection between subpools + // if ro/rw was updated. + if e.conn == nil { + if err := p.tryConnect(ctx, e); err != nil { + log.Printf("tarantool: reopen connection to %s failed: %s\n", + e.name, err) + } + } else if !e.conn.ClosedNow() { + p.updateConnection(e) + } else { + p.reconnect(ctx, e) + } + } + } + } + } +} + +func (p *ConnectionPool) getNextConnection(mode Mode) (*tarantool.Connection, error) { + + switch mode { + case ANY: + if next := p.anyPool.GetNextConnection(); next != nil { + return next, nil + } + case RW: + if next := p.rwPool.GetNextConnection(); next != nil { + return next, nil + } + return nil, ErrNoRwInstance + case RO: + if next := p.roPool.GetNextConnection(); next != nil { + return next, nil + } + return nil, ErrNoRoInstance + case PreferRW: + if next := p.rwPool.GetNextConnection(); next != nil { + return next, nil + } + if next := p.roPool.GetNextConnection(); next != nil { + return next, nil + } + case PreferRO: + if next := p.roPool.GetNextConnection(); next != nil { + return next, nil + } + if next := p.rwPool.GetNextConnection(); next != nil { + return next, nil + } + } + return nil, ErrNoHealthyInstance +} + +func newErrorFuture(err error) *tarantool.Future { + fut := tarantool.NewFuture(nil) + fut.SetError(err) + return fut +} + +func isFeatureInSlice(expected iproto.Feature, actualSlice []iproto.Feature) bool { + for _, actual := range actualSlice { + if expected == actual { + return true + } + } + return false +} diff --git a/pool/connection_pool_test.go b/pool/connection_pool_test.go new file mode 100644 index 000000000..f4a0376d4 --- /dev/null +++ b/pool/connection_pool_test.go @@ -0,0 +1,3708 @@ +package pool_test + +import ( + "bytes" + "context" + "fmt" + "log" + "net" + "os" + "reflect" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/pool" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +var user = "test" +var userNoExec = "test_noexec" +var pass = "test" +var spaceNo = uint32(520) +var spaceName = "testPool" +var indexNo = uint32(0) + +var ports = []string{"3013", "3014", "3015", "3016", "3017"} +var host = "127.0.0.1" + +var servers = []string{ + strings.Join([]string{host, ports[0]}, ":"), + strings.Join([]string{host, ports[1]}, ":"), + strings.Join([]string{host, ports[2]}, ":"), + strings.Join([]string{host, ports[3]}, ":"), + strings.Join([]string{host, ports[4]}, ":"), +} + +func makeDialer(server string) tarantool.Dialer { + return tarantool.NetDialer{ + Address: server, + User: user, + Password: pass, + } +} + +func makeDialers(servers []string) []tarantool.Dialer { + dialers := make([]tarantool.Dialer, 0, len(servers)) + for _, server := range servers { + dialers = append(dialers, makeDialer(server)) + } + return dialers +} + +var dialers = makeDialers(servers) + +func makeInstance(server string, opts tarantool.Opts) pool.Instance { + return pool.Instance{ + Name: server, + Dialer: tarantool.NetDialer{ + Address: server, + User: user, + Password: pass, + }, + Opts: opts, + } +} + +func makeNoExecuteInstance(server string, opts tarantool.Opts) pool.Instance { + return pool.Instance{ + Name: server, + Dialer: tarantool.NetDialer{ + Address: server, + User: userNoExec, + Password: pass, + }, + Opts: opts, + } +} + +func makeInstances(servers []string, opts tarantool.Opts) []pool.Instance { + var instances []pool.Instance + for _, server := range servers { + instances = append(instances, makeInstance(server, opts)) + } + return instances +} + +var instances = makeInstances(servers, connOpts) +var connOpts = tarantool.Opts{ + Timeout: 5 * time.Second, +} + +var defaultCountRetry = 5 +var defaultTimeoutRetry = 500 * time.Millisecond + +var helpInstances []*test_helpers.TarantoolInstance + +func TestConnect_error_duplicate(t *testing.T) { + ctx, cancel := test_helpers.GetPoolConnectContext() + connPool, err := pool.Connect(ctx, makeInstances([]string{"foo", "foo"}, connOpts)) + cancel() + + require.Nilf(t, connPool, "conn is not nil with incorrect param") + require.EqualError(t, err, "duplicate instance name: \"foo\"") +} + +func TestConnectWithOpts_error_no_timeout(t *testing.T) { + ctx, cancel := test_helpers.GetPoolConnectContext() + connPool, err := pool.ConnectWithOpts(ctx, makeInstances([]string{"any"}, connOpts), + pool.Opts{}) + cancel() + require.Nilf(t, connPool, "conn is not nil with incorrect param") + require.ErrorIs(t, err, pool.ErrWrongCheckTimeout) +} + +func TestConnSuccessfully(t *testing.T) { + healthyServ := servers[0] + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, makeInstances([]string{healthyServ, "err"}, connOpts)) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: []string{healthyServ}, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + healthyServ: true, + "err": false, + }, + } + + err = test_helpers.CheckPoolStatuses(args) + require.NoError(t, err) +} + +func TestConn_no_execute_supported(t *testing.T) { + test_helpers.SkipIfWatchOnceUnsupported(t) + + healthyServ := servers[0] + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, + []pool.Instance{makeNoExecuteInstance(healthyServ, connOpts)}) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: []string{healthyServ}, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + healthyServ: true, + }, + } + + err = test_helpers.CheckPoolStatuses(args) + require.Nil(t, err) + + _, err = connPool.Do(tarantool.NewPingRequest(), pool.ANY).Get() + require.Nil(t, err) +} + +func TestConn_no_execute_unsupported(t *testing.T) { + test_helpers.SkipIfWatchOnceSupported(t) + + var buf bytes.Buffer + log.SetOutput(&buf) + defer log.SetOutput(os.Stderr) + + healthyServ := servers[0] + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, + []pool.Instance{makeNoExecuteInstance(healthyServ, connOpts)}) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + require.Contains(t, buf.String(), + fmt.Sprintf("connect to %s failed: Execute access to function "+ + "'box.info' is denied for user '%s'", servers[0], userNoExec)) + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: []string{healthyServ}, + ExpectedPoolStatus: false, + ExpectedStatuses: map[string]bool{ + healthyServ: false, + }, + } + + err = test_helpers.CheckPoolStatuses(args) + require.Nil(t, err) + + _, err = connPool.Do(tarantool.NewPingRequest(), pool.ANY).Get() + require.Error(t, err) + require.Equal(t, "can't find healthy instance in pool", err.Error()) +} + +func TestConnect_empty(t *testing.T) { + cases := []struct { + Name string + Instances []pool.Instance + }{ + {"nil", nil}, + {"empty", []pool.Instance{}}, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, tc.Instances) + if connPool != nil { + defer connPool.Close() + } + require.NoError(t, err, "failed to create a pool") + require.NotNilf(t, connPool, "pool is nil after Connect") + require.Lenf(t, connPool.GetInfo(), 0, "empty pool expected") + }) + } +} + +func TestConnect_unavailable(t *testing.T) { + servers := []string{"err1", "err2"} + ctx, cancel := test_helpers.GetPoolConnectContext() + insts := makeInstances([]string{"err1", "err2"}, connOpts) + + connPool, err := pool.Connect(ctx, insts) + cancel() + + if connPool != nil { + defer connPool.Close() + } + + require.NoError(t, err, "failed to create a pool") + require.NotNilf(t, connPool, "pool is nil after Connect") + require.Equal(t, map[string]pool.ConnectionInfo{ + servers[0]: pool.ConnectionInfo{ + ConnectedNow: false, ConnRole: pool.UnknownRole, Instance: insts[0]}, + servers[1]: pool.ConnectionInfo{ + ConnectedNow: false, ConnRole: pool.UnknownRole, Instance: insts[1]}, + }, connPool.GetInfo()) +} + +func TestConnect_single_server_hang(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer l.Close() + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + insts := makeInstances([]string{l.Addr().String()}, connOpts) + + connPool, err := pool.Connect(ctx, insts) + if connPool != nil { + defer connPool.Close() + } + + require.ErrorIs(t, err, context.DeadlineExceeded) + require.Nil(t, connPool) +} + +func TestConnect_server_hang(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer l.Close() + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + servers := []string{l.Addr().String(), servers[0]} + insts := makeInstances(servers, connOpts) + + connPool, err := pool.Connect(ctx, insts) + if connPool != nil { + defer connPool.Close() + } + + require.NoError(t, err, "failed to create a pool") + require.NotNil(t, connPool, "pool is nil after Connect") + require.Equal(t, map[string]pool.ConnectionInfo{ + servers[0]: pool.ConnectionInfo{ + ConnectedNow: false, ConnRole: pool.UnknownRole, Instance: insts[0]}, + servers[1]: pool.ConnectionInfo{ + ConnectedNow: true, ConnRole: pool.MasterRole, Instance: insts[1]}, + }, connPool.GetInfo()) +} + +func TestConnErrorAfterCtxCancel(t *testing.T) { + var connLongReconnectOpts = tarantool.Opts{ + Timeout: 5 * time.Second, + Reconnect: time.Second, + MaxReconnects: 100, + } + + ctx, cancel := context.WithCancel(context.Background()) + + var connPool *pool.ConnectionPool + var err error + + cancel() + connPool, err = pool.Connect(ctx, makeInstances(servers, connLongReconnectOpts)) + + if connPool != nil || err == nil { + t.Fatalf("ConnectionPool was created after cancel") + } + if !strings.Contains(err.Error(), "context canceled") { + t.Fatalf("Unexpected error, expected to contain %s, got %v", + "operation was canceled", err) + } +} + +type mockClosingDialer struct { + addr string + ctx context.Context + ctxCancel context.CancelFunc +} + +func (m *mockClosingDialer) Dial(ctx context.Context, + opts tarantool.DialOpts) (tarantool.Conn, error) { + dialer := tarantool.NetDialer{ + Address: m.addr, + User: user, + Password: pass, + } + conn, err := dialer.Dial(m.ctx, tarantool.DialOpts{}) + + m.ctxCancel() + + return conn, err +} + +func TestConnectContextCancelAfterConnect(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var instances []pool.Instance + for _, server := range servers { + instances = append(instances, pool.Instance{ + Name: server, + Dialer: &mockClosingDialer{ + addr: server, + ctx: ctx, + ctxCancel: cancel, + }, + Opts: connOpts, + }) + } + + connPool, err := pool.Connect(ctx, instances) + if connPool != nil { + defer connPool.Close() + } + + assert.NoError(t, err, "expected err after ctx cancel") + assert.NotNil(t, connPool) +} + +func TestConnSuccessfullyDuplicates(t *testing.T) { + server := servers[0] + + var instances []pool.Instance + for i := 0; i < 4; i++ { + instances = append(instances, pool.Instance{ + Name: fmt.Sprintf("c%d", i), + Dialer: tarantool.NetDialer{ + Address: server, + User: user, + Password: pass, + }, + Opts: connOpts, + }) + } + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: []string{"c0", "c1", "c2", "c3"}, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + "c0": true, + "c1": true, + "c2": true, + "c3": true, + }, + } + + err = test_helpers.CheckPoolStatuses(args) + require.Nil(t, err) +} + +func TestReconnect(t *testing.T) { + server := servers[0] + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + test_helpers.StopTarantoolWithCleanup(helpInstances[0]) + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: []string{server}, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + server: false, + }, + } + + err = test_helpers.Retry(test_helpers.CheckPoolStatuses, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) + + err = test_helpers.RestartTarantool(helpInstances[0]) + require.Nilf(t, err, "failed to restart tarantool") + + args = test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: []string{server}, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + server: true, + }, + } + + err = test_helpers.Retry(test_helpers.CheckPoolStatuses, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) +} + +func TestDisconnect_withReconnect(t *testing.T) { + serverId := 0 + server := servers[serverId] + + opts := connOpts + opts.Reconnect = 10 * time.Second + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, makeInstances([]string{server}, opts)) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + // Test. + test_helpers.StopTarantoolWithCleanup(helpInstances[serverId]) + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: []string{servers[serverId]}, + ExpectedPoolStatus: false, + ExpectedStatuses: map[string]bool{ + servers[serverId]: false, + }, + } + err = test_helpers.Retry(test_helpers.CheckPoolStatuses, + args, defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) + + // Restart the server after success. + err = test_helpers.RestartTarantool(helpInstances[serverId]) + require.Nilf(t, err, "failed to restart tarantool") + + args = test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: []string{servers[serverId]}, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + servers[serverId]: true, + }, + } + + err = test_helpers.Retry(test_helpers.CheckPoolStatuses, + args, defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) +} + +func TestDisconnectAll(t *testing.T) { + server1 := servers[0] + server2 := servers[1] + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, makeInstances([]string{server1, server2}, connOpts)) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + test_helpers.StopTarantoolWithCleanup(helpInstances[0]) + test_helpers.StopTarantoolWithCleanup(helpInstances[1]) + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: []string{server1, server2}, + ExpectedPoolStatus: false, + ExpectedStatuses: map[string]bool{ + server1: false, + server2: false, + }, + } + + err = test_helpers.Retry(test_helpers.CheckPoolStatuses, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) + + err = test_helpers.RestartTarantool(helpInstances[0]) + require.Nilf(t, err, "failed to restart tarantool") + + err = test_helpers.RestartTarantool(helpInstances[1]) + require.Nilf(t, err, "failed to restart tarantool") + + args = test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: []string{server1, server2}, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + server1: true, + server2: true, + }, + } + + err = test_helpers.Retry(test_helpers.CheckPoolStatuses, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) +} + +func TestAdd(t *testing.T) { + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, []pool.Instance{}) + require.NoError(t, err, "failed to connect") + require.NotNil(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + for _, server := range servers { + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + err = connPool.Add(ctx, makeInstance(server, connOpts)) + require.Nil(t, err) + } + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: servers, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + servers[0]: true, + servers[1]: true, + servers[2]: true, + servers[3]: true, + servers[4]: true, + }, + } + + err = test_helpers.Retry(test_helpers.CheckPoolStatuses, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) +} + +func TestAdd_canceled_ctx(t *testing.T) { + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, []pool.Instance{}) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + ctx, cancel = test_helpers.GetConnectContext() + cancel() + + err = connPool.Add(ctx, makeInstance(servers[0], connOpts)) + require.Error(t, err) +} + +func TestAdd_exist(t *testing.T) { + server := servers[0] + ctx, cancel := test_helpers.GetPoolConnectContext() + connPool, err := pool.Connect(ctx, makeInstances([]string{server}, connOpts)) + cancel() + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + ctx, cancel = test_helpers.GetConnectContext() + defer cancel() + + err = connPool.Add(ctx, makeInstance(server, connOpts)) + require.Equal(t, pool.ErrExists, err) + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: servers, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + server: true, + }, + } + + err = test_helpers.Retry(test_helpers.CheckPoolStatuses, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) +} + +func TestAdd_unreachable(t *testing.T) { + server := servers[0] + + ctx, cancel := test_helpers.GetPoolConnectContext() + connPool, err := pool.Connect(ctx, makeInstances([]string{server}, connOpts)) + cancel() + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + unhealthyServ := "unreachable:6667" + err = connPool.Add(context.Background(), pool.Instance{ + Name: unhealthyServ, + Dialer: tarantool.NetDialer{ + Address: unhealthyServ, + }, + Opts: connOpts, + }) + require.NoError(t, err) + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: servers, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + server: true, + unhealthyServ: false, + }, + } + + err = test_helpers.Retry(test_helpers.CheckPoolStatuses, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) +} + +func TestAdd_afterClose(t *testing.T) { + server := servers[0] + ctx, cancel := test_helpers.GetPoolConnectContext() + connPool, err := pool.Connect(ctx, makeInstances([]string{server}, connOpts)) + cancel() + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + connPool.Close() + + ctx, cancel = test_helpers.GetConnectContext() + defer cancel() + + err = connPool.Add(ctx, makeInstance(server, connOpts)) + assert.Equal(t, err, pool.ErrClosed) +} + +func TestAdd_Close_concurrent(t *testing.T) { + serv0 := servers[0] + serv1 := servers[1] + + ctx, cancel := test_helpers.GetPoolConnectContext() + connPool, err := pool.Connect(ctx, makeInstances([]string{serv0}, connOpts)) + cancel() + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + ctx, cancel = test_helpers.GetConnectContext() + defer cancel() + + err = connPool.Add(ctx, makeInstance(serv1, connOpts)) + if err != nil { + assert.Equal(t, pool.ErrClosed, err) + } + }() + + connPool.Close() + + wg.Wait() +} + +func TestAdd_CloseGraceful_concurrent(t *testing.T) { + serv0 := servers[0] + serv1 := servers[1] + + ctx, cancel := test_helpers.GetPoolConnectContext() + connPool, err := pool.Connect(ctx, makeInstances([]string{serv0}, connOpts)) + cancel() + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + ctx, cancel = test_helpers.GetConnectContext() + defer cancel() + + err = connPool.Add(ctx, makeInstance(serv1, connOpts)) + if err != nil { + assert.Equal(t, pool.ErrClosed, err) + } + }() + + connPool.CloseGraceful() + + wg.Wait() +} + +func TestRemove(t *testing.T) { + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + for _, server := range servers[1:] { + err = connPool.Remove(server) + require.Nil(t, err) + } + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: servers, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + servers[0]: true, + }, + } + + err = test_helpers.Retry(test_helpers.CheckPoolStatuses, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) +} + +func TestRemove_double(t *testing.T) { + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, makeInstances(servers[:2], connOpts)) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + err = connPool.Remove(servers[1]) + require.Nil(t, err) + err = connPool.Remove(servers[1]) + require.ErrorContains(t, err, "endpoint not exist") + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: servers, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + servers[0]: true, + }, + } + + err = test_helpers.Retry(test_helpers.CheckPoolStatuses, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) +} + +func TestRemove_unknown(t *testing.T) { + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, makeInstances(servers[:2], connOpts)) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + err = connPool.Remove("not_exist:6667") + require.ErrorContains(t, err, "endpoint not exist") + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: servers, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + servers[0]: true, + servers[1]: true, + }, + } + + err = test_helpers.Retry(test_helpers.CheckPoolStatuses, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) +} + +func TestRemove_concurrent(t *testing.T) { + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, makeInstances(servers[:2], connOpts)) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + const concurrency = 10 + var ( + wg sync.WaitGroup + ok uint32 + errs uint32 + ) + + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + err := connPool.Remove(servers[1]) + if err == nil { + atomic.AddUint32(&ok, 1) + } else { + assert.ErrorContains(t, err, "endpoint not exist") + atomic.AddUint32(&errs, 1) + } + }() + } + + wg.Wait() + assert.Equal(t, uint32(1), ok) + assert.Equal(t, uint32(concurrency-1), errs) + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: servers, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + servers[0]: true, + }, + } + + err = test_helpers.Retry(test_helpers.CheckPoolStatuses, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) +} + +func TestRemove_Close_concurrent(t *testing.T) { + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, makeInstances(servers[:2], connOpts)) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + err = connPool.Remove(servers[1]) + assert.Nil(t, err) + }() + + connPool.Close() + + wg.Wait() +} + +func TestRemove_CloseGraceful_concurrent(t *testing.T) { + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, makeInstances(servers[:2], connOpts)) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + err = connPool.Remove(servers[1]) + assert.Nil(t, err) + }() + + connPool.CloseGraceful() + + wg.Wait() +} + +func TestClose(t *testing.T) { + server1 := servers[0] + server2 := servers[1] + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, makeInstances([]string{server1, server2}, connOpts)) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: []string{server1, server2}, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + server1: true, + server2: true, + }, + } + + err = test_helpers.CheckPoolStatuses(args) + require.Nil(t, err) + + connPool.Close() + + args = test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: []string{server1, server2}, + ExpectedPoolStatus: false, + ExpectedStatuses: map[string]bool{ + server1: false, + server2: false, + }, + } + + err = test_helpers.Retry(test_helpers.CheckPoolStatuses, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) +} + +func TestCloseGraceful(t *testing.T) { + server1 := servers[0] + server2 := servers[1] + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, makeInstances([]string{server1, server2}, connOpts)) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: []string{server1, server2}, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + server1: true, + server2: true, + }, + } + + err = test_helpers.CheckPoolStatuses(args) + require.Nil(t, err) + + eval := `local fiber = require('fiber') + local time = ... + fiber.sleep(time) + +` + + evalSleep := 3 // In seconds. + req := tarantool.NewEvalRequest(eval).Args([]interface{}{evalSleep}) + fut := connPool.Do(req, pool.ANY) + go func() { + connPool.CloseGraceful() + }() + + // Check that a request rejected if graceful shutdown in progress. + time.Sleep((time.Duration(evalSleep) * time.Second) / 2) + _, err = connPool.Do(tarantool.NewPingRequest(), pool.ANY).Get() + require.ErrorContains(t, err, "can't find healthy instance in pool") + + // Check that a previous request was successful. + _, err = fut.Get() + require.Nilf(t, err, "sleep request no error") + + args = test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: []string{server1, server2}, + ExpectedPoolStatus: false, + ExpectedStatuses: map[string]bool{ + server1: false, + server2: false, + }, + } + + err = test_helpers.Retry(test_helpers.CheckPoolStatuses, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) +} + +type testHandler struct { + discovered, deactivated uint32 + errs []error +} + +func (h *testHandler) addErr(err error) { + h.errs = append(h.errs, err) +} + +func (h *testHandler) Discovered(name string, conn *tarantool.Connection, + role pool.Role) error { + discovered := atomic.AddUint32(&h.discovered, 1) + + if conn == nil { + h.addErr(fmt.Errorf("discovered conn == nil")) + return nil + } + + // discovered < 3 - initial open of connections + // discovered >= 3 - update a connection after a role update + if name == servers[0] { + if discovered < 3 && role != pool.MasterRole { + h.addErr(fmt.Errorf("unexpected init role %d for name %s", role, name)) + } + if discovered >= 3 && role != pool.ReplicaRole { + h.addErr(fmt.Errorf("unexpected updated role %d for name %s", role, name)) + } + } else if name == servers[1] { + if discovered >= 3 { + h.addErr(fmt.Errorf("unexpected discovery for name %s", name)) + } + if role != pool.ReplicaRole { + h.addErr(fmt.Errorf("unexpected role %d for name %s", role, name)) + } + } else { + h.addErr(fmt.Errorf("unexpected discovered name %s", name)) + } + + return nil +} + +func (h *testHandler) Deactivated(name string, conn *tarantool.Connection, + role pool.Role) error { + deactivated := atomic.AddUint32(&h.deactivated, 1) + + if conn == nil { + h.addErr(fmt.Errorf("removed conn == nil")) + return nil + } + + if deactivated == 1 && name == servers[0] { + // A first close is a role update. + if role != pool.MasterRole { + h.addErr(fmt.Errorf("unexpected removed role %d for name %s", role, name)) + } + return nil + } + + if name == servers[0] || name == servers[1] { + // Close. + if role != pool.ReplicaRole { + h.addErr(fmt.Errorf("unexpected removed role %d for name %s", role, name)) + } + } else { + h.addErr(fmt.Errorf("unexpected removed name %s", name)) + } + + return nil +} + +func TestConnectionHandlerOpenUpdateClose(t *testing.T) { + poolServers := []string{servers[0], servers[1]} + poolInstances := makeInstances(poolServers, connOpts) + roles := []bool{false, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, makeDialers(poolServers), connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + h := &testHandler{} + poolOpts := pool.Opts{ + CheckTimeout: 100 * time.Microsecond, + ConnectionHandler: h, + } + connPool, err := pool.ConnectWithOpts(ctx, poolInstances, poolOpts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + _, err = connPool.Do(tarantool.NewCall17Request("box.cfg"). + Args([]interface{}{map[string]bool{ + "read_only": true, + }}), + pool.RW).GetResponse() + require.Nilf(t, err, "failed to make ro") + + for i := 0; i < 100; i++ { + // Wait for read_only update, it should report about close connection + // with old role. + if atomic.LoadUint32(&h.discovered) >= 3 { + break + } + time.Sleep(poolOpts.CheckTimeout) + } + + discovered := atomic.LoadUint32(&h.discovered) + deactivated := atomic.LoadUint32(&h.deactivated) + require.Equalf(t, uint32(3), discovered, + "updated not reported as discovered") + require.Equalf(t, uint32(1), deactivated, + "updated not reported as deactivated") + + connPool.Close() + + for i := 0; i < 100; i++ { + // Wait for close of all connections. + if atomic.LoadUint32(&h.deactivated) >= 3 { + break + } + time.Sleep(poolOpts.CheckTimeout) + } + + for _, err := range h.errs { + t.Errorf("Unexpected error: %s", err) + } + connected, err := connPool.ConnectedNow(pool.ANY) + require.Nilf(t, err, "failed to get connected state") + require.Falsef(t, connected, "connection pool still be connected") + + discovered = atomic.LoadUint32(&h.discovered) + deactivated = atomic.LoadUint32(&h.deactivated) + require.Equalf(t, uint32(len(poolServers)+1), discovered, + "unexpected discovered count") + require.Equalf(t, uint32(len(poolServers)+1), deactivated, + "unexpected deactivated count") +} + +type testAddErrorHandler struct { + discovered, deactivated int +} + +func (h *testAddErrorHandler) Discovered(name string, conn *tarantool.Connection, + role pool.Role) error { + h.discovered++ + return fmt.Errorf("any error") +} + +func (h *testAddErrorHandler) Deactivated(name string, conn *tarantool.Connection, + role pool.Role) error { + h.deactivated++ + return nil +} + +func TestConnectionHandlerOpenError(t *testing.T) { + poolServers := []string{servers[0], servers[1]} + + h := &testAddErrorHandler{} + poolOpts := pool.Opts{ + CheckTimeout: 100 * time.Microsecond, + ConnectionHandler: h, + } + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + insts := makeInstances(poolServers, connOpts) + connPool, err := pool.ConnectWithOpts(ctx, insts, poolOpts) + if err == nil { + defer connPool.Close() + } + require.NoError(t, err, "failed to connect") + require.NotNil(t, connPool, "pool expected") + require.Equal(t, map[string]pool.ConnectionInfo{ + servers[0]: pool.ConnectionInfo{ + ConnectedNow: false, ConnRole: pool.UnknownRole, Instance: insts[0]}, + servers[1]: pool.ConnectionInfo{ + ConnectedNow: false, ConnRole: pool.UnknownRole, Instance: insts[1]}, + }, connPool.GetInfo()) + connPool.Close() + + // It could happen additional reconnect attempts in the background, but + // at least 2 connects on start. + require.GreaterOrEqualf(t, h.discovered, 2, "unexpected discovered count") + require.Equalf(t, 0, h.deactivated, "unexpected deactivated count") +} + +type testUpdateErrorHandler struct { + discovered, deactivated uint32 +} + +func (h *testUpdateErrorHandler) Discovered(name string, conn *tarantool.Connection, + role pool.Role) error { + atomic.AddUint32(&h.discovered, 1) + + if atomic.LoadUint32(&h.deactivated) != 0 { + // Don't add a connection into a pool again after it was deleted. + return fmt.Errorf("any error") + } + return nil +} + +func (h *testUpdateErrorHandler) Deactivated(name string, conn *tarantool.Connection, + role pool.Role) error { + atomic.AddUint32(&h.deactivated, 1) + return nil +} + +func TestConnectionHandlerUpdateError(t *testing.T) { + poolServers := []string{servers[0], servers[1]} + poolInstances := makeInstances(poolServers, connOpts) + roles := []bool{false, false} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, makeDialers(poolServers), connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + h := &testUpdateErrorHandler{} + poolOpts := pool.Opts{ + CheckTimeout: 100 * time.Microsecond, + ConnectionHandler: h, + } + connPool, err := pool.ConnectWithOpts(ctx, poolInstances, poolOpts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + defer connPool.Close() + + connected, err := connPool.ConnectedNow(pool.ANY) + require.Nilf(t, err, "failed to get ConnectedNow()") + require.Truef(t, connected, "should be connected") + + for i := 0; i < len(poolServers); i++ { + _, err = connPool.Do(tarantool.NewCall17Request("box.cfg"). + Args([]interface{}{map[string]bool{ + "read_only": true, + }}), pool.RW).Get() + require.Nilf(t, err, "failed to make ro") + } + + for i := 0; i < 100; i++ { + // Wait for updates done. + connected, err = connPool.ConnectedNow(pool.ANY) + if !connected || err != nil { + break + } + time.Sleep(poolOpts.CheckTimeout) + } + connected, err = connPool.ConnectedNow(pool.ANY) + + require.Nilf(t, err, "failed to get ConnectedNow()") + require.Falsef(t, connected, "should not be any active connection") + + connPool.Close() + + connected, err = connPool.ConnectedNow(pool.ANY) + + require.Nilf(t, err, "failed to get ConnectedNow()") + require.Falsef(t, connected, "should be deactivated") + discovered := atomic.LoadUint32(&h.discovered) + deactivated := atomic.LoadUint32(&h.deactivated) + require.GreaterOrEqualf(t, discovered, deactivated, "discovered < deactivated") + require.Nilf(t, err, "failed to get ConnectedNow()") +} + +type testDeactivatedErrorHandler struct { + mut sync.Mutex + deactivated []string +} + +func (h *testDeactivatedErrorHandler) Discovered(name string, conn *tarantool.Connection, + role pool.Role) error { + return nil +} + +func (h *testDeactivatedErrorHandler) Deactivated(name string, conn *tarantool.Connection, + role pool.Role) error { + h.mut.Lock() + defer h.mut.Unlock() + + h.deactivated = append(h.deactivated, name) + return nil +} + +func TestConnectionHandlerDeactivated_on_remove(t *testing.T) { + poolServers := []string{servers[0], servers[1]} + poolInstances := makeInstances(poolServers, connOpts) + roles := []bool{false, false} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, makeDialers(poolServers), connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + h := &testDeactivatedErrorHandler{} + poolOpts := pool.Opts{ + CheckTimeout: 100 * time.Microsecond, + ConnectionHandler: h, + } + connPool, err := pool.ConnectWithOpts(ctx, poolInstances, poolOpts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + defer connPool.Close() + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: servers, + ExpectedPoolStatus: true, + ExpectedStatuses: map[string]bool{ + servers[0]: true, + servers[1]: true, + }, + } + err = test_helpers.CheckPoolStatuses(args) + require.Nil(t, err) + + for _, server := range poolServers { + connPool.Remove(server) + connPool.Remove(server) + } + + args = test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: servers, + ExpectedPoolStatus: false, + } + err = test_helpers.CheckPoolStatuses(args) + require.Nil(t, err) + + h.mut.Lock() + defer h.mut.Unlock() + require.ElementsMatch(t, poolServers, h.deactivated) +} + +func TestRequestOnClosed(t *testing.T) { + server1 := servers[0] + server2 := servers[1] + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, makeInstances([]string{server1, server2}, connOpts)) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + test_helpers.StopTarantoolWithCleanup(helpInstances[0]) + test_helpers.StopTarantoolWithCleanup(helpInstances[1]) + + args := test_helpers.CheckStatusesArgs{ + ConnPool: connPool, + Mode: pool.ANY, + Servers: []string{server1, server2}, + ExpectedPoolStatus: false, + ExpectedStatuses: map[string]bool{ + server1: false, + server2: false, + }, + } + err = test_helpers.Retry(test_helpers.CheckPoolStatuses, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) + + _, err = connPool.Do(tarantool.NewPingRequest(), pool.ANY).Get() + require.NotNilf(t, err, "err is nil after Do with PingRequest") + + err = test_helpers.RestartTarantool(helpInstances[0]) + require.Nilf(t, err, "failed to restart tarantool") + + err = test_helpers.RestartTarantool(helpInstances[1]) + require.Nilf(t, err, "failed to restart tarantool") +} + +func TestDoWithCallRequest(t *testing.T) { + roles := []bool{false, true, false, false, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + // PreferRO + data, err := connPool.Do( + tarantool.NewCallRequest("box.info"). + Args([]interface{}{}), + pool.PreferRO).Get() + require.Nilf(t, err, "failed to Do with CallRequest") + require.NotNilf(t, data, "response is nil after Do with CallRequest") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with CallRequest") + + val := data[0].(map[interface{}]interface{})["ro"] + ro, ok := val.(bool) + require.Truef(t, ok, "expected `true` with mode `PreferRO`") + require.Truef(t, ro, "expected `true` with mode `PreferRO`") + + // PreferRW + data, err = connPool.Do( + tarantool.NewCallRequest("box.info"). + Args([]interface{}{}), + pool.PreferRW).Get() + require.Nilf(t, err, "failed to Do with CallRequest") + require.NotNilf(t, data, "response is nil after Do with CallRequest") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with CallRequest") + + val = data[0].(map[interface{}]interface{})["ro"] + ro, ok = val.(bool) + require.Truef(t, ok, "expected `false` with mode `PreferRW`") + require.Falsef(t, ro, "expected `false` with mode `PreferRW`") + + // RO + data, err = connPool.Do( + tarantool.NewCallRequest("box.info"). + Args([]interface{}{}), + pool.RO).Get() + require.Nilf(t, err, "failed to Do with CallRequest") + require.NotNilf(t, data, "response is nil after Do with CallRequest") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with CallRequest") + + val = data[0].(map[interface{}]interface{})["ro"] + ro, ok = val.(bool) + require.Truef(t, ok, "expected `true` with mode `RO`") + require.Truef(t, ro, "expected `true` with mode `RO`") + + // RW + data, err = connPool.Do( + tarantool.NewCallRequest("box.info"). + Args([]interface{}{}), + pool.RW).Get() + require.Nilf(t, err, "failed to Do with CallRequest") + require.NotNilf(t, data, "response is nil after Do with CallRequest") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with CallRequest") + + val = data[0].(map[interface{}]interface{})["ro"] + ro, ok = val.(bool) + require.Truef(t, ok, "expected `false` with mode `RW`") + require.Falsef(t, ro, "expected `false` with mode `RW`") +} + +func TestDoWithCall16Request(t *testing.T) { + roles := []bool{false, true, false, false, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + // PreferRO + data, err := connPool.Do( + tarantool.NewCall16Request("box.info"). + Args([]interface{}{}), + pool.PreferRO).Get() + require.Nilf(t, err, "failed to Do with Call16Request") + require.NotNilf(t, data, "response is nil after Do with Call16Request") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with Call16Request") + + val := data[0].([]interface{})[0].(map[interface{}]interface{})["ro"] + ro, ok := val.(bool) + require.Truef(t, ok, "expected `true` with mode `PreferRO`") + require.Truef(t, ro, "expected `true` with mode `PreferRO`") + + // PreferRW + data, err = connPool.Do( + tarantool.NewCall16Request("box.info"). + Args([]interface{}{}), + pool.PreferRW).Get() + require.Nilf(t, err, "failed to Do with Call16Request") + require.NotNilf(t, data, "response is nil after Do with Call16Request") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with Call16Request") + + val = data[0].([]interface{})[0].(map[interface{}]interface{})["ro"] + ro, ok = val.(bool) + require.Truef(t, ok, "expected `false` with mode `PreferRW`") + require.Falsef(t, ro, "expected `false` with mode `PreferRW`") + + // RO + data, err = connPool.Do( + tarantool.NewCall16Request("box.info"). + Args([]interface{}{}), + pool.RO).Get() + require.Nilf(t, err, "failed to Do with Call16Request") + require.NotNilf(t, data, "response is nil after Do with Call16Request") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with Call16Request") + + val = data[0].([]interface{})[0].(map[interface{}]interface{})["ro"] + ro, ok = val.(bool) + require.Truef(t, ok, "expected `true` with mode `RO`") + require.Truef(t, ro, "expected `true` with mode `RO`") + + // RW + data, err = connPool.Do( + tarantool.NewCall16Request("box.info"). + Args([]interface{}{}), + pool.RW).Get() + require.Nilf(t, err, "failed to Do with Call16Request") + require.NotNilf(t, data, "response is nil after Do with Call16Request") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with Call16Request") + + val = data[0].([]interface{})[0].(map[interface{}]interface{})["ro"] + ro, ok = val.(bool) + require.Truef(t, ok, "expected `false` with mode `RW`") + require.Falsef(t, ro, "expected `false` with mode `RW`") +} + +func TestDoWithCall17Request(t *testing.T) { + roles := []bool{false, true, false, false, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + // PreferRO + data, err := connPool.Do( + tarantool.NewCall17Request("box.info"). + Args([]interface{}{}), + pool.PreferRO).Get() + require.Nilf(t, err, "failed to Do with Call17Request") + require.NotNilf(t, data, "response is nil after Do with Call17Request") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with Call17Request") + + val := data[0].(map[interface{}]interface{})["ro"] + ro, ok := val.(bool) + require.Truef(t, ok, "expected `true` with mode `PreferRO`") + require.Truef(t, ro, "expected `true` with mode `PreferRO`") + + // PreferRW + data, err = connPool.Do( + tarantool.NewCall17Request("box.info"). + Args([]interface{}{}), + pool.PreferRW).Get() + require.Nilf(t, err, "failed to Do with Call17Request") + require.NotNilf(t, data, "response is nil after Do with Call17Request") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with Call17Request") + + val = data[0].(map[interface{}]interface{})["ro"] + ro, ok = val.(bool) + require.Truef(t, ok, "expected `false` with mode `PreferRW`") + require.Falsef(t, ro, "expected `false` with mode `PreferRW`") + + // RO + data, err = connPool.Do( + tarantool.NewCall17Request("box.info"). + Args([]interface{}{}), + pool.RO).Get() + require.Nilf(t, err, "failed to Do with Call17Request") + require.NotNilf(t, data, "response is nil after Do with Call17Request") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with Call17Request") + + val = data[0].(map[interface{}]interface{})["ro"] + ro, ok = val.(bool) + require.Truef(t, ok, "expected `true` with mode `RO`") + require.Truef(t, ro, "expected `true` with mode `RO`") + + // RW + data, err = connPool.Do( + tarantool.NewCall17Request("box.info"). + Args([]interface{}{}), + pool.RW).Get() + require.Nilf(t, err, "failed to Do with Call17Request") + require.NotNilf(t, data, "response is nil after Do with Call17Request") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with Call17Request") + + val = data[0].(map[interface{}]interface{})["ro"] + ro, ok = val.(bool) + require.Truef(t, ok, "expected `false` with mode `RW`") + require.Falsef(t, ro, "expected `false` with mode `RW`") +} + +func TestDoWithEvalRequest(t *testing.T) { + roles := []bool{false, true, false, false, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + // PreferRO + data, err := connPool.Do( + tarantool.NewEvalRequest("return box.info().ro"). + Args([]interface{}{}), + pool.PreferRO).Get() + require.Nilf(t, err, "failed to Do with EvalRequest") + require.NotNilf(t, data, "response is nil after Do with EvalRequest") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with EvalRequest") + + val, ok := data[0].(bool) + require.Truef(t, ok, "expected `true` with mode `PreferRO`") + require.Truef(t, val, "expected `true` with mode `PreferRO`") + + // PreferRW + data, err = connPool.Do( + tarantool.NewEvalRequest("return box.info().ro"). + Args([]interface{}{}), + pool.PreferRW).Get() + require.Nilf(t, err, "failed to Do with EvalRequest") + require.NotNilf(t, data, "response is nil after Do with EvalRequest") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with EvalRequest") + + val, ok = data[0].(bool) + require.Truef(t, ok, "expected `false` with mode `PreferRW`") + require.Falsef(t, val, "expected `false` with mode `PreferRW`") + + // RO + data, err = connPool.Do( + tarantool.NewEvalRequest("return box.info().ro"). + Args([]interface{}{}), + pool.RO).Get() + require.Nilf(t, err, "failed to Do with EvalRequest") + require.NotNilf(t, data, "response is nil after Do with EvalRequest") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with EvalRequest") + + val, ok = data[0].(bool) + require.Truef(t, ok, "expected `true` with mode `RO`") + require.Truef(t, val, "expected `true` with mode `RO`") + + // RW + data, err = connPool.Do( + tarantool.NewEvalRequest("return box.info().ro"). + Args([]interface{}{}), + pool.RW).Get() + require.Nilf(t, err, "failed to Do with EvalRequest") + require.NotNilf(t, data, "response is nil after Do with EvalRequest") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with EvalRequest") + + val, ok = data[0].(bool) + require.Truef(t, ok, "expected `false` with mode `RW`") + require.Falsef(t, val, "expected `false` with mode `RW`") +} + +type Member struct { + id uint + val string +} + +func (m *Member) DecodeMsgpack(d *msgpack.Decoder) error { + var err error + var l int + if l, err = d.DecodeArrayLen(); err != nil { + return err + } + if l != 2 { + return fmt.Errorf("array len doesn't match: %d", l) + } + if m.id, err = d.DecodeUint(); err != nil { + return err + } + if m.val, err = d.DecodeString(); err != nil { + return err + } + return nil +} + +func TestDoWithExecuteRequest(t *testing.T) { + test_helpers.SkipIfSQLUnsupported(t) + + roles := []bool{false, true, false, false, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + request := "SELECT NAME0, NAME1 FROM SQL_TEST WHERE NAME0 == 1;" + mem := []Member{} + + fut := connPool.Do(tarantool.NewExecuteRequest(request).Args([]interface{}{}), pool.ANY) + data, err := fut.Get() + require.Nilf(t, err, "failed to Do with ExecuteRequest") + require.NotNilf(t, data, "response is nil after Execute") + require.GreaterOrEqualf(t, len(data), 1, "response.Data is empty after Do with ExecuteRequest") + require.Equalf(t, len(data[0].([]interface{})), 2, "unexpected response") + err = fut.GetTyped(&mem) + require.Nilf(t, err, "Unable to GetTyped of fut") + require.Equalf(t, len(mem), 1, "wrong count of result") +} + +func TestRoundRobinStrategy(t *testing.T) { + roles := []bool{false, true, false, false, true} + + allPorts := map[string]bool{ + servers[0]: true, + servers[1]: true, + servers[2]: true, + servers[3]: true, + servers[4]: true, + } + + masterPorts := map[string]bool{ + servers[0]: true, + servers[2]: true, + servers[3]: true, + } + + replicaPorts := map[string]bool{ + servers[1]: true, + servers[4]: true, + } + + serversNumber := len(servers) + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + // ANY + args := test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: allPorts, + ConnPool: connPool, + Mode: pool.ANY, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) + + // RW + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: masterPorts, + ConnPool: connPool, + Mode: pool.RW, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) + + // RO + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: replicaPorts, + ConnPool: connPool, + Mode: pool.RO, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) + + // PreferRW + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: masterPorts, + ConnPool: connPool, + Mode: pool.PreferRW, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) + + // PreferRO + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: replicaPorts, + ConnPool: connPool, + Mode: pool.PreferRO, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) +} + +func TestRoundRobinStrategy_NoReplica(t *testing.T) { + roles := []bool{false, false, false, false, false} + serversNumber := len(servers) + + allPorts := map[string]bool{ + servers[0]: true, + servers[1]: true, + servers[2]: true, + servers[3]: true, + servers[4]: true, + } + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + // RO + _, err = connPool.Do( + tarantool.NewEvalRequest("return box.cfg.listen"). + Args([]interface{}{}), + pool.RO).Get() + require.NotNilf(t, err, "expected to fail after Do with EvalRequest, but error is nil") + require.Equal(t, "can't find ro instance in pool", err.Error()) + + // ANY + args := test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: allPorts, + ConnPool: connPool, + Mode: pool.ANY, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) + + // RW + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: allPorts, + ConnPool: connPool, + Mode: pool.RW, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) + + // PreferRW + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: allPorts, + ConnPool: connPool, + Mode: pool.PreferRW, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) + + // PreferRO + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: allPorts, + ConnPool: connPool, + Mode: pool.PreferRO, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) +} + +func TestRoundRobinStrategy_NoMaster(t *testing.T) { + roles := []bool{true, true, true, true, true} + serversNumber := len(servers) + + allPorts := map[string]bool{ + servers[0]: true, + servers[1]: true, + servers[2]: true, + servers[3]: true, + servers[4]: true, + } + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + // RW + _, err = connPool.Do( + tarantool.NewEvalRequest("return box.cfg.listen"). + Args([]interface{}{}), + pool.RW).Get() + require.NotNilf(t, err, "expected to fail after Do with EvalRequest, but error is nil") + require.Equal(t, "can't find rw instance in pool", err.Error()) + + // ANY + args := test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: allPorts, + ConnPool: connPool, + Mode: pool.ANY, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) + + // RO + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: allPorts, + ConnPool: connPool, + Mode: pool.RO, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) + + // PreferRW + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: allPorts, + ConnPool: connPool, + Mode: pool.PreferRW, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) + + // PreferRO + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: allPorts, + ConnPool: connPool, + Mode: pool.PreferRO, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) +} + +func TestUpdateInstancesRoles(t *testing.T) { + roles := []bool{false, true, false, false, true} + + allPorts := map[string]bool{ + servers[0]: true, + servers[1]: true, + servers[2]: true, + servers[3]: true, + servers[4]: true, + } + + masterPorts := map[string]bool{ + servers[0]: true, + servers[2]: true, + servers[3]: true, + } + + replicaPorts := map[string]bool{ + servers[1]: true, + servers[4]: true, + } + + serversNumber := len(servers) + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + // ANY + args := test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: allPorts, + ConnPool: connPool, + Mode: pool.ANY, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) + + // RW + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: masterPorts, + ConnPool: connPool, + Mode: pool.RW, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) + + // RO + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: replicaPorts, + ConnPool: connPool, + Mode: pool.RO, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) + + // PreferRW + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: masterPorts, + ConnPool: connPool, + Mode: pool.PreferRW, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) + + // PreferRO + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: replicaPorts, + ConnPool: connPool, + Mode: pool.PreferRO, + } + + err = test_helpers.ProcessListenOnInstance(args) + require.Nil(t, err) + + roles = []bool{true, false, true, true, false} + + masterPorts = map[string]bool{ + servers[1]: true, + servers[4]: true, + } + + replicaPorts = map[string]bool{ + servers[0]: true, + servers[2]: true, + servers[3]: true, + } + + ctxSetRoles, cancelSetRoles := test_helpers.GetPoolConnectContext() + err = test_helpers.SetClusterRO(ctxSetRoles, dialers, connOpts, roles) + cancelSetRoles() + require.Nilf(t, err, "fail to set roles for cluster") + + // ANY + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: allPorts, + ConnPool: connPool, + Mode: pool.ANY, + } + + err = test_helpers.Retry(test_helpers.ProcessListenOnInstance, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) + + // RW + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: masterPorts, + ConnPool: connPool, + Mode: pool.RW, + } + + err = test_helpers.Retry(test_helpers.ProcessListenOnInstance, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) + + // RO + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: replicaPorts, + ConnPool: connPool, + Mode: pool.RO, + } + + err = test_helpers.Retry(test_helpers.ProcessListenOnInstance, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) + + // PreferRW + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: masterPorts, + ConnPool: connPool, + Mode: pool.PreferRW, + } + + err = test_helpers.Retry(test_helpers.ProcessListenOnInstance, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) + + // PreferRO + args = test_helpers.ListenOnInstanceArgs{ + ServersNumber: serversNumber, + ExpectedPorts: replicaPorts, + ConnPool: connPool, + Mode: pool.PreferRO, + } + + err = test_helpers.Retry(test_helpers.ProcessListenOnInstance, args, + defaultCountRetry, defaultTimeoutRetry) + require.Nil(t, err) +} + +func TestDoWithInsertRequest(t *testing.T) { + roles := []bool{true, true, false, true, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + // RW + data, err := connPool.Do(tarantool.NewInsertRequest(spaceName). + Tuple([]interface{}{"rw_insert_key", "rw_insert_value"}), + pool.RW).Get() + require.Nilf(t, err, "failed to Insert") + require.NotNilf(t, data, "response is nil after Insert") + require.Equalf(t, len(data), 1, "response Body len != 1 after Insert") + + tpl, ok := data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Insert") + require.Equalf(t, 2, len(tpl), "unexpected body of Insert") + + key, ok := tpl[0].(string) + require.Truef(t, ok, "unexpected body of Insert (0)") + require.Equalf(t, "rw_insert_key", key, "unexpected body of Insert (0)") + + value, ok := tpl[1].(string) + require.Truef(t, ok, "unexpected body of Insert (1)") + require.Equalf(t, "rw_insert_value", value, "unexpected body of Insert (1)") + + // Connect to servers[2] to check if tuple + // was inserted on RW instance + conn := test_helpers.ConnectWithValidation(t, makeDialer(servers[2]), connOpts) + defer conn.Close() + + sel := tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(tarantool.IterEq). + Key([]interface{}{"rw_insert_key"}) + data, err = conn.Do(sel).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, len(data), 1, "response Body len != 1 after Select") + + tpl, ok = data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Select") + require.Equalf(t, 2, len(tpl), "unexpected body of Select") + + key, ok = tpl[0].(string) + require.Truef(t, ok, "unexpected body of Select (0)") + require.Equalf(t, "rw_insert_key", key, "unexpected body of Select (0)") + + value, ok = tpl[1].(string) + require.Truef(t, ok, "unexpected body of Select (1)") + require.Equalf(t, "rw_insert_value", value, "unexpected body of Select (1)") + + // PreferRW + data, err = connPool.Do(tarantool.NewInsertRequest(spaceName).Tuple( + []interface{}{"preferRW_insert_key", "preferRW_insert_value"}), pool.PreferRW).Get() + require.Nilf(t, err, "failed to Insert") + require.NotNilf(t, data, "response is nil after Insert") + require.Equalf(t, len(data), 1, "response Body len != 1 after Insert") + + tpl, ok = data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Insert") + require.Equalf(t, 2, len(tpl), "unexpected body of Insert") + + key, ok = tpl[0].(string) + require.Truef(t, ok, "unexpected body of Insert (0)") + require.Equalf(t, "preferRW_insert_key", key, "unexpected body of Insert (0)") + + value, ok = tpl[1].(string) + require.Truef(t, ok, "unexpected body of Insert (1)") + require.Equalf(t, "preferRW_insert_value", value, "unexpected body of Insert (1)") + + sel = tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(tarantool.IterEq). + Key([]interface{}{"preferRW_insert_key"}) + data, err = conn.Do(sel).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, len(data), 1, "response Body len != 1 after Select") + + tpl, ok = data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Select") + require.Equalf(t, 2, len(tpl), "unexpected body of Select") + + key, ok = tpl[0].(string) + require.Truef(t, ok, "unexpected body of Select (0)") + require.Equalf(t, "preferRW_insert_key", key, "unexpected body of Select (0)") + + value, ok = tpl[1].(string) + require.Truef(t, ok, "unexpected body of Select (1)") + require.Equalf(t, "preferRW_insert_value", value, "unexpected body of Select (1)") +} + +func TestDoWithDeleteRequest(t *testing.T) { + roles := []bool{true, true, false, true, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + // Connect to servers[2] to check if tuple + // was inserted on RW instance + conn := test_helpers.ConnectWithValidation(t, makeDialer(servers[2]), connOpts) + defer conn.Close() + + ins := tarantool.NewInsertRequest(spaceNo).Tuple([]interface{}{"delete_key", "delete_value"}) + data, err := conn.Do(ins).Get() + require.Nilf(t, err, "failed to Insert") + require.Equalf(t, len(data), 1, "response Body len != 1 after Insert") + + tpl, ok := data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Insert") + require.Equalf(t, 2, len(tpl), "unexpected body of Insert") + + key, ok := tpl[0].(string) + require.Truef(t, ok, "unexpected body of Insert (0)") + require.Equalf(t, "delete_key", key, "unexpected body of Insert (0)") + + value, ok := tpl[1].(string) + require.Truef(t, ok, "unexpected body of Insert (1)") + require.Equalf(t, "delete_value", value, "unexpected body of Insert (1)") + + data, err = connPool.Do( + tarantool.NewDeleteRequest(spaceName). + Index(indexNo). + Key([]interface{}{"delete_key"}), + pool.RW).Get() + require.Nilf(t, err, "failed to Do with DeleteRequest") + require.NotNilf(t, data, "response is nil after Do with DeleteRequest") + require.Equalf(t, len(data), 1, "response Body len != 1 after Do with DeleteRequest") + + tpl, ok = data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Do with DeleteRequest") + require.Equalf(t, 2, len(tpl), "unexpected body of Do with DeleteRequest") + + key, ok = tpl[0].(string) + require.Truef(t, ok, "unexpected body of Do with DeleteRequest (0)") + require.Equalf(t, "delete_key", key, "unexpected body of Do with DeleteRequest (0)") + + value, ok = tpl[1].(string) + require.Truef(t, ok, "unexpected body of Do with DeleteRequest (1)") + require.Equalf(t, "delete_value", value, "unexpected body of Do with DeleteRequest (1)") + + sel := tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(tarantool.IterEq). + Key([]interface{}{"delete_key"}) + data, err = conn.Do(sel).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, 0, len(data), "response Body len != 0 after Select") +} + +func TestDoWithUpsertRequest(t *testing.T) { + roles := []bool{true, true, false, true, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + // Connect to servers[2] to check if tuple + // was inserted on RW instance + conn := test_helpers.ConnectWithValidation(t, makeDialer(servers[2]), connOpts) + defer conn.Close() + + // RW + data, err := connPool.Do(tarantool.NewUpsertRequest(spaceName).Tuple( + []interface{}{"upsert_key", "upsert_value"}).Operations( + tarantool.NewOperations().Assign(1, "new_value")), pool.RW).Get() + require.Nilf(t, err, "failed to Upsert") + require.NotNilf(t, data, "response is nil after Upsert") + + sel := tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(tarantool.IterEq). + Key([]interface{}{"upsert_key"}) + data, err = conn.Do(sel).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, len(data), 1, "response Body len != 1 after Select") + + tpl, ok := data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Select") + require.Equalf(t, 2, len(tpl), "unexpected body of Select") + + key, ok := tpl[0].(string) + require.Truef(t, ok, "unexpected body of Select (0)") + require.Equalf(t, "upsert_key", key, "unexpected body of Select (0)") + + value, ok := tpl[1].(string) + require.Truef(t, ok, "unexpected body of Select (1)") + require.Equalf(t, "upsert_value", value, "unexpected body of Select (1)") + + // PreferRW + data, err = connPool.Do(tarantool.NewUpsertRequest( + spaceName).Tuple([]interface{}{"upsert_key", "upsert_value"}).Operations( + tarantool.NewOperations().Assign(1, "new_value")), pool.PreferRW).Get() + + require.Nilf(t, err, "failed to Upsert") + require.NotNilf(t, data, "response is nil after Upsert") + + data, err = conn.Do(sel).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, len(data), 1, "response Body len != 1 after Select") + + tpl, ok = data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Select") + require.Equalf(t, 2, len(tpl), "unexpected body of Select") + + key, ok = tpl[0].(string) + require.Truef(t, ok, "unexpected body of Select (0)") + require.Equalf(t, "upsert_key", key, "unexpected body of Select (0)") + + value, ok = tpl[1].(string) + require.Truef(t, ok, "unexpected body of Select (1)") + require.Equalf(t, "new_value", value, "unexpected body of Select (1)") +} + +func TestDoWithUpdateRequest(t *testing.T) { + roles := []bool{true, true, false, true, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + // Connect to servers[2] to check if tuple + // was inserted on RW instance + conn := test_helpers.ConnectWithValidation(t, makeDialer(servers[2]), connOpts) + defer conn.Close() + + ins := tarantool.NewInsertRequest(spaceNo). + Tuple([]interface{}{"update_key", "update_value"}) + data, err := conn.Do(ins).Get() + require.Nilf(t, err, "failed to Insert") + require.Equalf(t, len(data), 1, "response Body len != 1 after Insert") + + tpl, ok := data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Insert") + require.Equalf(t, 2, len(tpl), "unexpected body of Insert") + + key, ok := tpl[0].(string) + require.Truef(t, ok, "unexpected body of Insert (0)") + require.Equalf(t, "update_key", key, "unexpected body of Insert (0)") + + value, ok := tpl[1].(string) + require.Truef(t, ok, "unexpected body of Insert (1)") + require.Equalf(t, "update_value", value, "unexpected body of Insert (1)") + + // RW + resp, err := connPool.Do(tarantool.NewUpdateRequest(spaceName). + Index(indexNo). + Key([]interface{}{"update_key"}). + Operations(tarantool.NewOperations().Assign(1, "new_value")), + pool.RW).Get() + require.Nilf(t, err, "failed to Update") + require.NotNilf(t, resp, "response is nil after Update") + + sel := tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(tarantool.IterEq). + Key([]interface{}{"update_key"}) + data, err = conn.Do(sel).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, len(data), 1, "response Body len != 1 after Select") + + tpl, ok = data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Select") + require.Equalf(t, 2, len(tpl), "unexpected body of Select") + + key, ok = tpl[0].(string) + require.Truef(t, ok, "unexpected body of Select (0)") + require.Equalf(t, "update_key", key, "unexpected body of Select (0)") + + value, ok = tpl[1].(string) + require.Truef(t, ok, "unexpected body of Select (1)") + require.Equalf(t, "new_value", value, "unexpected body of Select (1)") + + // PreferRW + resp, err = connPool.Do(tarantool.NewUpdateRequest(spaceName). + Index(indexNo). + Key([]interface{}{"update_key"}). + Operations(tarantool.NewOperations().Assign(1, "another_value")), + pool.PreferRW).Get() + + require.Nilf(t, err, "failed to Update") + require.NotNilf(t, resp, "response is nil after Update") + + data, err = conn.Do(sel).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, len(data), 1, "response Body len != 1 after Select") + + tpl, ok = data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Select") + require.Equalf(t, 2, len(tpl), "unexpected body of Select") + + key, ok = tpl[0].(string) + require.Truef(t, ok, "unexpected body of Select (0)") + require.Equalf(t, "update_key", key, "unexpected body of Select (0)") + + value, ok = tpl[1].(string) + require.Truef(t, ok, "unexpected body of Select (1)") + require.Equalf(t, "another_value", value, "unexpected body of Select (1)") +} + +func TestDoWithReplaceRequest(t *testing.T) { + roles := []bool{true, true, false, true, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + // Connect to servers[2] to check if tuple + // was inserted on RW instance + conn := test_helpers.ConnectWithValidation(t, makeDialer(servers[2]), connOpts) + defer conn.Close() + + ins := tarantool.NewInsertRequest(spaceNo). + Tuple([]interface{}{"replace_key", "replace_value"}) + data, err := conn.Do(ins).Get() + require.Nilf(t, err, "failed to Insert") + require.Equalf(t, len(data), 1, "response Body len != 1 after Insert") + + tpl, ok := data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Insert") + require.Equalf(t, 2, len(tpl), "unexpected body of Insert") + + key, ok := tpl[0].(string) + require.Truef(t, ok, "unexpected body of Insert (0)") + require.Equalf(t, "replace_key", key, "unexpected body of Insert (0)") + + value, ok := tpl[1].(string) + require.Truef(t, ok, "unexpected body of Insert (1)") + require.Equalf(t, "replace_value", value, "unexpected body of Insert (1)") + + // RW + resp, err := connPool.Do(tarantool.NewReplaceRequest(spaceNo). + Tuple([]interface{}{"new_key", "new_value"}), + pool.RW).Get() + require.Nilf(t, err, "failed to Do with ReplaceRequest") + require.NotNilf(t, resp, "response is nil after Do with ReplaceRequest") + + sel := tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(tarantool.IterEq). + Key([]interface{}{"new_key"}) + data, err = conn.Do(sel).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, len(data), 1, "response Body len != 1 after Select") + + tpl, ok = data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Select") + require.Equalf(t, 2, len(tpl), "unexpected body of Select") + + key, ok = tpl[0].(string) + require.Truef(t, ok, "unexpected body of Select (0)") + require.Equalf(t, "new_key", key, "unexpected body of Select (0)") + + value, ok = tpl[1].(string) + require.Truef(t, ok, "unexpected body of Select (1)") + require.Equalf(t, "new_value", value, "unexpected body of Select (1)") + + // PreferRW + resp, err = connPool.Do(tarantool.NewReplaceRequest(spaceNo). + Tuple([]interface{}{"new_key", "new_value"}), + pool.PreferRW).Get() + require.Nilf(t, err, "failed to Do with ReplaceRequest") + require.NotNilf(t, resp, "response is nil after Do with ReplaceRequest") + + data, err = conn.Do(sel).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, len(data), 1, "response Body len != 1 after Select") + + tpl, ok = data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Select") + require.Equalf(t, 2, len(tpl), "unexpected body of Select") + + key, ok = tpl[0].(string) + require.Truef(t, ok, "unexpected body of Select (0)") + require.Equalf(t, "new_key", key, "unexpected body of Select (0)") + + value, ok = tpl[1].(string) + require.Truef(t, ok, "unexpected body of Select (1)") + require.Equalf(t, "new_value", value, "unexpected body of Select (1)") +} + +func TestDoWithSelectRequest(t *testing.T) { + roles := []bool{true, true, false, true, false} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + roServers := []string{servers[0], servers[1], servers[3]} + rwServers := []string{servers[2], servers[4]} + allServers := []string{servers[0], servers[1], servers[2], servers[3], servers[4]} + + roTpl := []interface{}{"ro_select_key", "ro_select_value"} + rwTpl := []interface{}{"rw_select_key", "rw_select_value"} + anyTpl := []interface{}{"any_select_key", "any_select_value"} + + roKey := []interface{}{"ro_select_key"} + rwKey := []interface{}{"rw_select_key"} + anyKey := []interface{}{"any_select_key"} + + err = test_helpers.InsertOnInstances(ctx, makeDialers(roServers), connOpts, spaceNo, roTpl) + require.Nil(t, err) + + err = test_helpers.InsertOnInstances(ctx, makeDialers(rwServers), connOpts, spaceNo, rwTpl) + require.Nil(t, err) + + err = test_helpers.InsertOnInstances(ctx, makeDialers(allServers), connOpts, spaceNo, anyTpl) + require.Nil(t, err) + + // ANY + data, err := connPool.Do(tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Offset(0). + Limit(1). + Iterator(tarantool.IterEq). + Key(anyKey), + pool.ANY).Get() + require.Nilf(t, err, "failed to Do with SelectRequest") + require.NotNilf(t, data, "response is nil after Do with SelectRequest") + require.Equalf(t, len(data), 1, "response Body len != 1 after Do with SelectRequest") + + tpl, ok := data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Do with SelectRequest") + require.Equalf(t, 2, len(tpl), "unexpected body of Do with SelectRequest") + + key, ok := tpl[0].(string) + require.Truef(t, ok, "unexpected body of Do with SelectRequest (0)") + require.Equalf(t, "any_select_key", key, "unexpected body of Do with SelectRequest (0)") + + value, ok := tpl[1].(string) + require.Truef(t, ok, "unexpected body of Do with SelectRequest (1)") + require.Equalf(t, "any_select_value", value, "unexpected body of Do with SelectRequest (1)") + + // PreferRO + data, err = connPool.Do(tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Offset(0). + Limit(1). + Iterator(tarantool.IterEq). + Key(roKey), + pool.PreferRO).Get() + require.Nilf(t, err, "failed to Do with SelectRequest") + require.NotNilf(t, data, "response is nil after Do with SelectRequest") + require.Equalf(t, len(data), 1, "response Body len != 1 after Do with SelectRequest") + + tpl, ok = data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Do with SelectRequest") + require.Equalf(t, 2, len(tpl), "unexpected body of Do with SelectRequest") + + key, ok = tpl[0].(string) + require.Truef(t, ok, "unexpected body of Do with SelectRequest (0)") + require.Equalf(t, "ro_select_key", key, "unexpected body of Do with SelectRequest (0)") + + // PreferRW + data, err = connPool.Do(tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Offset(0). + Limit(1). + Iterator(tarantool.IterEq). + Key(rwKey), + pool.PreferRW).Get() + require.Nilf(t, err, "failed to Do with SelectRequest") + require.NotNilf(t, data, "response is nil after Do with SelectRequest") + require.Equalf(t, len(data), 1, "response Body len != 1 after Do with SelectRequest") + + tpl, ok = data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Do with SelectRequest") + require.Equalf(t, 2, len(tpl), "unexpected body of Do with SelectRequest") + + key, ok = tpl[0].(string) + require.Truef(t, ok, "unexpected body of Do with SelectRequest (0)") + require.Equalf(t, "rw_select_key", key, "unexpected body of Do with SelectRequest (0)") + + value, ok = tpl[1].(string) + require.Truef(t, ok, "unexpected body of Do with SelectRequest (1)") + require.Equalf(t, "rw_select_value", value, "unexpected body of Do with SelectRequest (1)") + + // RO + data, err = connPool.Do(tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Offset(0). + Limit(1). + Iterator(tarantool.IterEq). + Key(roKey), + pool.RO).Get() + require.Nilf(t, err, "failed to Do with SelectRequest") + require.NotNilf(t, data, "response is nil after Do with SelectRequest") + require.Equalf(t, len(data), 1, "response Body len != 1 after Do with SelectRequest") + + tpl, ok = data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Do with SelectRequest") + require.Equalf(t, 2, len(tpl), "unexpected body of Do with SelectRequest") + + key, ok = tpl[0].(string) + require.Truef(t, ok, "unexpected body of Do with SelectRequest (0)") + require.Equalf(t, "ro_select_key", key, "unexpected body of Do with SelectRequest (0)") + + value, ok = tpl[1].(string) + require.Truef(t, ok, "unexpected body of Do with SelectRequest (1)") + require.Equalf(t, "ro_select_value", value, "unexpected body of Do with SelectRequest (1)") + + // RW + data, err = connPool.Do(tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Offset(0). + Limit(1). + Iterator(tarantool.IterEq). + Key(rwKey), + pool.RW).Get() + require.Nilf(t, err, "failed to Do with SelectRequest") + require.NotNilf(t, data, "response is nil after Do with SelectRequest") + require.Equalf(t, len(data), 1, "response Body len != 1 after Do with SelectRequest") + + tpl, ok = data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Do with SelectRequest") + require.Equalf(t, 2, len(tpl), "unexpected body of Do with SelectRequest") + + key, ok = tpl[0].(string) + require.Truef(t, ok, "unexpected body of Do with SelectRequest (0)") + require.Equalf(t, "rw_select_key", key, "unexpected body of Do with SelectRequest (0)") + + value, ok = tpl[1].(string) + require.Truef(t, ok, "unexpected body of Do with SelectRequest (1)") + require.Equalf(t, "rw_select_value", value, "unexpected body of Do with SelectRequest (1)") +} + +func TestDoWithPingRequest(t *testing.T) { + roles := []bool{true, true, false, true, false} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + // ANY + data, err := connPool.Do(tarantool.NewPingRequest(), pool.ANY).Get() + require.Nilf(t, err, "failed to Do with Ping Request") + require.Nilf(t, data, "response data is not nil after Do with Ping Request") + + // RW + data, err = connPool.Do(tarantool.NewPingRequest(), pool.RW).Get() + require.Nilf(t, err, "failed to Do with Ping Request") + require.Nilf(t, data, "response data is not nil after Do with Ping Request") + + // RO + _, err = connPool.Do(tarantool.NewPingRequest(), pool.RO).Get() + require.Nilf(t, err, "failed to Do with Ping Request") + + // PreferRW + data, err = connPool.Do(tarantool.NewPingRequest(), pool.PreferRW).Get() + require.Nilf(t, err, "failed to Do with Ping Request") + require.Nilf(t, data, "response data is not nil after Do with Ping Request") + + // PreferRO + data, err = connPool.Do(tarantool.NewPingRequest(), pool.PreferRO).Get() + require.Nilf(t, err, "failed to Do with Ping Request") + require.Nilf(t, data, "response data is not nil after Do with Ping Request") +} + +func TestDo(t *testing.T) { + roles := []bool{true, true, false, true, false} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + req := tarantool.NewPingRequest() + // ANY + _, err = connPool.Do(req, pool.ANY).Get() + require.Nilf(t, err, "failed to Ping") + + // RW + _, err = connPool.Do(req, pool.RW).Get() + require.Nilf(t, err, "failed to Ping") + + // RO + _, err = connPool.Do(req, pool.RO).Get() + require.Nilf(t, err, "failed to Ping") + + // PreferRW + _, err = connPool.Do(req, pool.PreferRW).Get() + require.Nilf(t, err, "failed to Ping") + + // PreferRO + _, err = connPool.Do(req, pool.PreferRO).Get() + require.Nilf(t, err, "failed to Ping") +} + +func TestDo_concurrent(t *testing.T) { + roles := []bool{true, true, false, true, false} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + req := tarantool.NewPingRequest() + const concurrency = 100 + var wg sync.WaitGroup + wg.Add(concurrency) + + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + + _, err := connPool.Do(req, pool.ANY).Get() + assert.Nil(t, err) + }() + } + + wg.Wait() +} + +func TestDoInstance(t *testing.T) { + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + req := tarantool.NewEvalRequest("return box.cfg.listen") + for _, server := range servers { + data, err := connPool.DoInstance(req, server).Get() + require.NoError(t, err) + assert.Equal(t, []interface{}{server}, data) + } +} + +func TestDoInstance_not_found(t *testing.T) { + roles := []bool{true, true, false, true, false} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, []pool.Instance{}) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + data, err := connPool.DoInstance(tarantool.NewPingRequest(), "not_exist").Get() + assert.Nil(t, data) + require.ErrorIs(t, err, pool.ErrNoHealthyInstance) +} + +func TestDoInstance_concurrent(t *testing.T) { + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + eval := tarantool.NewEvalRequest("return box.cfg.listen") + ping := tarantool.NewPingRequest() + const concurrency = 100 + var wg sync.WaitGroup + wg.Add(concurrency) + + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + + for _, server := range servers { + data, err := connPool.DoInstance(eval, server).Get() + require.NoError(t, err) + assert.Equal(t, []interface{}{server}, data) + } + _, err := connPool.DoInstance(ping, "not_exist").Get() + require.ErrorIs(t, err, pool.ErrNoHealthyInstance) + }() + } + + wg.Wait() +} + +func TestNewPrepared(t *testing.T) { + test_helpers.SkipIfSQLUnsupported(t) + + roles := []bool{true, true, false, true, false} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + stmt, err := connPool.NewPrepared( + "SELECT NAME0, NAME1 FROM SQL_TEST WHERE NAME0=:id AND NAME1=:name;", pool.RO) + require.Nilf(t, err, "fail to prepare statement: %v", err) + + executeReq := tarantool.NewExecutePreparedRequest(stmt) + unprepareReq := tarantool.NewUnprepareRequest(stmt) + + resp, err := connPool.Do(executeReq.Args([]interface{}{1, "test"}), pool.ANY).GetResponse() + if err != nil { + t.Fatalf("failed to execute prepared: %v", err) + } + if resp == nil { + t.Fatalf("nil response") + } + data, err := resp.Decode() + if err != nil { + t.Fatalf("failed to Decode: %s", err.Error()) + } + if reflect.DeepEqual(data[0], []interface{}{1, "test"}) { + t.Error("Select with named arguments failed") + } + prepResp, ok := resp.(*tarantool.ExecuteResponse) + if !ok { + t.Fatalf("Not a Prepare response") + } + metaData, err := prepResp.MetaData() + if err != nil { + t.Errorf("Error while getting MetaData: %s", err.Error()) + } + if metaData[0].FieldType != "unsigned" || + metaData[0].FieldName != "NAME0" || + metaData[1].FieldType != "string" || + metaData[1].FieldName != "NAME1" { + t.Error("Wrong metadata") + } + + // the second argument for unprepare request is unused - it already belongs to some connection + _, err = connPool.Do(unprepareReq, pool.ANY).Get() + if err != nil { + t.Errorf("failed to unprepare prepared statement: %v", err) + } + + _, err = connPool.Do(unprepareReq, pool.ANY).Get() + if err == nil { + t.Errorf("the statement must be already unprepared") + } + require.Contains(t, err.Error(), "Prepared statement with id") + + _, err = connPool.Do(executeReq, pool.ANY).Get() + if err == nil { + t.Errorf("the statement must be already unprepared") + } + require.Contains(t, err.Error(), "Prepared statement with id") +} + +func TestDoWithStrangerConn(t *testing.T) { + expectedErr := fmt.Errorf("the passed connected request doesn't belong to " + + "the current connection pool") + + roles := []bool{true, true, false, true, false} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + req := test_helpers.NewMockRequest() + + _, err = connPool.Do(req, pool.ANY).Get() + if err == nil { + t.Fatalf("nil error caught") + } + if err.Error() != expectedErr.Error() { + t.Fatalf("Unexpected error caught") + } +} + +func TestStream_Commit(t *testing.T) { + var req tarantool.Request + var err error + + test_helpers.SkipIfStreamsUnsupported(t) + + roles := []bool{true, true, false, true, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err = test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + defer connPool.Close() + + stream, err := connPool.NewStream(pool.PreferRW) + require.Nilf(t, err, "failed to create stream") + require.NotNilf(t, connPool, "stream is nil after NewStream") + + // Begin transaction + req = tarantool.NewBeginRequest() + _, err = stream.Do(req).Get() + require.Nilf(t, err, "failed to Begin") + + // Insert in stream + req = tarantool.NewInsertRequest(spaceName). + Tuple([]interface{}{"commit_key", "commit_value"}) + _, err = stream.Do(req).Get() + require.Nilf(t, err, "failed to Insert") + + // Connect to servers[2] to check if tuple + // was inserted outside of stream on RW instance + // before transaction commit + conn := test_helpers.ConnectWithValidation(t, makeDialer(servers[2]), connOpts) + defer conn.Close() + + // Select not related to the transaction + // while transaction is not committed + // result of select is empty + selectReq := tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(tarantool.IterEq). + Key([]interface{}{"commit_key"}) + data, err := conn.Do(selectReq).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, 0, len(data), "response Data len != 0") + + // Select in stream + data, err = stream.Do(selectReq).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, 1, len(data), "response Body len != 1 after Select") + + tpl, ok := data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Select") + require.Equalf(t, 2, len(tpl), "unexpected body of Select") + + key, ok := tpl[0].(string) + require.Truef(t, ok, "unexpected body of Select (0)") + require.Equalf(t, "commit_key", key, "unexpected body of Select (0)") + + value, ok := tpl[1].(string) + require.Truef(t, ok, "unexpected body of Select (1)") + require.Equalf(t, "commit_value", value, "unexpected body of Select (1)") + + // Commit transaction + req = tarantool.NewCommitRequest() + _, err = stream.Do(req).Get() + require.Nilf(t, err, "failed to Commit") + + // Select outside of transaction + data, err = conn.Do(selectReq).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, len(data), 1, "response Body len != 1 after Select") + + tpl, ok = data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Select") + require.Equalf(t, 2, len(tpl), "unexpected body of Select") + + key, ok = tpl[0].(string) + require.Truef(t, ok, "unexpected body of Select (0)") + require.Equalf(t, "commit_key", key, "unexpected body of Select (0)") + + value, ok = tpl[1].(string) + require.Truef(t, ok, "unexpected body of Select (1)") + require.Equalf(t, "commit_value", value, "unexpected body of Select (1)") +} + +func TestStream_Rollback(t *testing.T) { + var req tarantool.Request + var err error + + test_helpers.SkipIfStreamsUnsupported(t) + + roles := []bool{true, true, false, true, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err = test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + defer connPool.Close() + + stream, err := connPool.NewStream(pool.PreferRW) + require.Nilf(t, err, "failed to create stream") + require.NotNilf(t, connPool, "stream is nil after NewStream") + + // Begin transaction + req = tarantool.NewBeginRequest() + _, err = stream.Do(req).Get() + require.Nilf(t, err, "failed to Begin") + + // Insert in stream + req = tarantool.NewInsertRequest(spaceName). + Tuple([]interface{}{"rollback_key", "rollback_value"}) + _, err = stream.Do(req).Get() + require.Nilf(t, err, "failed to Insert") + + // Connect to servers[2] to check if tuple + // was not inserted outside of stream on RW instance + conn := test_helpers.ConnectWithValidation(t, makeDialer(servers[2]), connOpts) + defer conn.Close() + + // Select not related to the transaction + // while transaction is not committed + // result of select is empty + selectReq := tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(tarantool.IterEq). + Key([]interface{}{"rollback_key"}) + data, err := conn.Do(selectReq).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, 0, len(data), "response Data len != 0") + + // Select in stream + data, err = stream.Do(selectReq).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, 1, len(data), "response Body len != 1 after Select") + + tpl, ok := data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Select") + require.Equalf(t, 2, len(tpl), "unexpected body of Select") + + key, ok := tpl[0].(string) + require.Truef(t, ok, "unexpected body of Select (0)") + require.Equalf(t, "rollback_key", key, "unexpected body of Select (0)") + + value, ok := tpl[1].(string) + require.Truef(t, ok, "unexpected body of Select (1)") + require.Equalf(t, "rollback_value", value, "unexpected body of Select (1)") + + // Rollback transaction + req = tarantool.NewRollbackRequest() + _, err = stream.Do(req).Get() + require.Nilf(t, err, "failed to Rollback") + + // Select outside of transaction + data, err = conn.Do(selectReq).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, 0, len(data), "response Body len != 0 after Select") + + // Select inside of stream after rollback + data, err = stream.Do(selectReq).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, 0, len(data), "response Body len != 0 after Select") +} + +func TestStream_TxnIsolationLevel(t *testing.T) { + var req tarantool.Request + var err error + + txnIsolationLevels := []tarantool.TxnIsolationLevel{ + tarantool.DefaultIsolationLevel, + tarantool.ReadCommittedLevel, + tarantool.ReadConfirmedLevel, + tarantool.BestEffortLevel, + } + + test_helpers.SkipIfStreamsUnsupported(t) + + roles := []bool{true, true, false, true, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err = test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + defer connPool.Close() + + stream, err := connPool.NewStream(pool.PreferRW) + require.Nilf(t, err, "failed to create stream") + require.NotNilf(t, connPool, "stream is nil after NewStream") + + // Connect to servers[2] to check if tuple + // was not inserted outside of stream on RW instance + conn := test_helpers.ConnectWithValidation(t, makeDialer(servers[2]), connOpts) + defer conn.Close() + + for _, level := range txnIsolationLevels { + // Begin transaction + req = tarantool.NewBeginRequest().TxnIsolation(level).Timeout(500 * time.Millisecond) + _, err = stream.Do(req).Get() + require.Nilf(t, err, "failed to Begin") + + // Insert in stream + req = tarantool.NewInsertRequest(spaceName). + Tuple([]interface{}{"level_key", "level_value"}) + _, err = stream.Do(req).Get() + require.Nilf(t, err, "failed to Insert") + + // Select not related to the transaction + // while transaction is not committed + // result of select is empty + selectReq := tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(tarantool.IterEq). + Key([]interface{}{"level_key"}) + data, err := conn.Do(selectReq).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, 0, len(data), "response Data len != 0") + + // Select in stream + data, err = stream.Do(selectReq).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, 1, len(data), "response Body len != 1 after Select") + + tpl, ok := data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Select") + require.Equalf(t, 2, len(tpl), "unexpected body of Select") + + key, ok := tpl[0].(string) + require.Truef(t, ok, "unexpected body of Select (0)") + require.Equalf(t, "level_key", key, "unexpected body of Select (0)") + + value, ok := tpl[1].(string) + require.Truef(t, ok, "unexpected body of Select (1)") + require.Equalf(t, "level_value", value, "unexpected body of Select (1)") + + // Rollback transaction + req = tarantool.NewRollbackRequest() + _, err = stream.Do(req).Get() + require.Nilf(t, err, "failed to Rollback") + + // Select outside of transaction + data, err = conn.Do(selectReq).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, 0, len(data), "response Data len != 0") + + // Select inside of stream after rollback + data, err = stream.Do(selectReq).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, 0, len(data), "response Data len != 0") + + test_helpers.DeleteRecordByKey(t, conn, spaceNo, indexNo, []interface{}{"level_key"}) + } +} + +func TestConnectionPool_NewWatcher_no_watchers(t *testing.T) { + test_helpers.SkipIfWatchersSupported(t) + + const key = "TestConnectionPool_NewWatcher_no_watchers" + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nill after Connect") + defer connPool.Close() + + ch := make(chan struct{}) + connPool.NewWatcher(key, func(event tarantool.WatchEvent) { + close(ch) + }, pool.ANY) + + select { + case <-time.After(time.Second): + break + case <-ch: + t.Fatalf("watcher was created for connection that doesn't support it") + } +} + +func TestConnectionPool_NewWatcher_modes(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnectionPool_NewWatcher_modes" + + roles := []bool{true, false, false, true, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + defer connPool.Close() + + modes := []pool.Mode{ + pool.ANY, + pool.RW, + pool.RO, + pool.PreferRW, + pool.PreferRO, + } + for _, mode := range modes { + t.Run(fmt.Sprintf("%d", mode), func(t *testing.T) { + expectedServers := []string{} + for i, server := range servers { + if roles[i] && mode == pool.RW { + continue + } else if !roles[i] && mode == pool.RO { + continue + } + expectedServers = append(expectedServers, server) + } + + events := make(chan tarantool.WatchEvent, 1024) + defer close(events) + + watcher, err := connPool.NewWatcher(key, func(event tarantool.WatchEvent) { + require.Equal(t, key, event.Key) + require.Equal(t, nil, event.Value) + events <- event + }, mode) + require.Nilf(t, err, "failed to register a watcher") + defer watcher.Unregister() + + testMap := make(map[string]int) + + for i := 0; i < len(expectedServers); i++ { + select { + case event := <-events: + require.NotNil(t, event.Conn) + addr := event.Conn.Addr().String() + if val, ok := testMap[addr]; ok { + testMap[addr] = val + 1 + } else { + testMap[addr] = 1 + } + case <-time.After(time.Second): + t.Errorf("Failed to get a watch event.") + break + } + } + + for _, server := range expectedServers { + if val, ok := testMap[server]; !ok { + t.Errorf("Server not found: %s", server) + } else if val != 1 { + t.Errorf("Too many events %d for server %s", val, server) + } + } + }) + } +} + +func TestConnectionPool_NewWatcher_update(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnectionPool_NewWatcher_update" + const mode = pool.RW + const initCnt = 2 + roles := []bool{true, false, false, true, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + poolOpts := pool.Opts{ + CheckTimeout: 500 * time.Millisecond, + } + pool, err := pool.ConnectWithOpts(ctx, instances, poolOpts) + + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, pool, "conn is nil after Connect") + defer pool.Close() + + events := make(chan tarantool.WatchEvent, 1024) + defer close(events) + + watcher, err := pool.NewWatcher(key, func(event tarantool.WatchEvent) { + require.Equal(t, key, event.Key) + require.Equal(t, nil, event.Value) + events <- event + }, mode) + require.Nilf(t, err, "failed to create a watcher") + defer watcher.Unregister() + + // Wait for all initial events. + testMap := make(map[string]int) + for i := 0; i < initCnt; i++ { + select { + case event := <-events: + require.NotNil(t, event.Conn) + addr := event.Conn.Addr().String() + if val, ok := testMap[addr]; ok { + testMap[addr] = val + 1 + } else { + testMap[addr] = 1 + } + case <-time.After(poolOpts.CheckTimeout * 2): + t.Errorf("Failed to get a watch init event.") + break + } + } + + // Just invert roles for simplify the test. + for i, role := range roles { + roles[i] = !role + } + ctxSetRoles, cancelSetRoles := test_helpers.GetPoolConnectContext() + err = test_helpers.SetClusterRO(ctxSetRoles, dialers, connOpts, roles) + cancelSetRoles() + require.Nilf(t, err, "fail to set roles for cluster") + + // Wait for all updated events. + for i := 0; i < len(servers)-initCnt; i++ { + select { + case event := <-events: + require.NotNil(t, event.Conn) + addr := event.Conn.Addr().String() + if val, ok := testMap[addr]; ok { + testMap[addr] = val + 1 + } else { + testMap[addr] = 1 + } + case <-time.After(time.Second): + t.Errorf("Failed to get a watch update event.") + break + } + } + + // Check that all an event happen for an each connection. + for _, server := range servers { + if val, ok := testMap[server]; !ok { + t.Errorf("Server not found: %s", server) + } else { + require.Equal(t, val, 1, fmt.Sprintf("for server %s", server)) + } + } +} + +func TestWatcher_Unregister(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestWatcher_Unregister" + const mode = pool.RW + const expectedCnt = 2 + roles := []bool{true, false, false, true, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + pool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, pool, "conn is nil after Connect") + defer pool.Close() + + events := make(chan tarantool.WatchEvent, 1024) + defer close(events) + + watcher, err := pool.NewWatcher(key, func(event tarantool.WatchEvent) { + require.Equal(t, key, event.Key) + require.Equal(t, nil, event.Value) + events <- event + }, mode) + require.Nilf(t, err, "failed to create a watcher") + + for i := 0; i < expectedCnt; i++ { + select { + case <-events: + case <-time.After(time.Second): + t.Fatalf("Failed to skip initial events.") + } + } + watcher.Unregister() + + broadcast := tarantool.NewBroadcastRequest(key).Value("foo") + for i := 0; i < expectedCnt; i++ { + _, err := pool.Do(broadcast, mode).Get() + require.Nilf(t, err, "failed to send a broadcast request") + } + + select { + case event := <-events: + t.Fatalf("Get unexpected event: %v", event) + case <-time.After(time.Second): + } + + // Reset to the initial state. + broadcast = tarantool.NewBroadcastRequest(key) + for i := 0; i < expectedCnt; i++ { + _, err := pool.Do(broadcast, mode).Get() + require.Nilf(t, err, "failed to send a broadcast request") + } +} + +func TestConnectionPool_NewWatcher_concurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const testConcurrency = 1000 + const key = "TestConnectionPool_NewWatcher_concurrent" + + roles := []bool{true, false, false, true, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + defer connPool.Close() + + var wg sync.WaitGroup + wg.Add(testConcurrency) + + mode := pool.ANY + callback := func(event tarantool.WatchEvent) {} + for i := 0; i < testConcurrency; i++ { + go func(i int) { + defer wg.Done() + + watcher, err := connPool.NewWatcher(key, callback, mode) + if err != nil { + t.Errorf("Failed to create a watcher: %s", err) + } else { + watcher.Unregister() + } + }(i) + } + wg.Wait() +} + +func TestWatcher_Unregister_concurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const testConcurrency = 1000 + const key = "TestWatcher_Unregister_concurrent" + + roles := []bool{true, false, false, true, true} + + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := pool.Connect(ctx, instances) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + defer connPool.Close() + + mode := pool.ANY + watcher, err := connPool.NewWatcher(key, func(event tarantool.WatchEvent) { + }, mode) + require.Nilf(t, err, "failed to create a watcher") + + var wg sync.WaitGroup + wg.Add(testConcurrency) + + for i := 0; i < testConcurrency; i++ { + go func() { + defer wg.Done() + watcher.Unregister() + }() + } + wg.Wait() +} + +// runTestMain is a body of TestMain function +// (see https://pkg.go.dev/testing#hdr-Main). +// Using defer + os.Exit is not works so TestMain body +// is a separate function, see +// https://stackoverflow.com/questions/27629380/how-to-exit-a-go-program-honoring-deferred-calls +func runTestMain(m *testing.M) int { + initScript := "config.lua" + waitStart := 100 * time.Millisecond + connectRetry := 10 + retryTimeout := 500 * time.Millisecond + + // Tarantool supports streams and interactive transactions since version 2.10.0 + isStreamUnsupported, err := test_helpers.IsTarantoolVersionLess(2, 10, 0) + if err != nil { + log.Fatalf("Could not check the Tarantool version: %s", err) + } + + instsOpts := make([]test_helpers.StartOpts, 0, len(servers)) + for _, serv := range servers { + instsOpts = append(instsOpts, test_helpers.StartOpts{ + Listen: serv, + Dialer: tarantool.NetDialer{ + Address: serv, + User: user, + Password: pass, + }, + InitScript: initScript, + WaitStart: waitStart, + ConnectRetry: connectRetry, + RetryTimeout: retryTimeout, + MemtxUseMvccEngine: !isStreamUnsupported, + }) + } + + helpInstances, err = test_helpers.StartTarantoolInstances(instsOpts) + + if err != nil { + log.Fatalf("Failed to prepare test tarantool: %s", err) + return -1 + } + + defer test_helpers.StopTarantoolInstances(helpInstances) + + return m.Run() +} + +func TestConnectionPool_GetInfo_equal_instance_info(t *testing.T) { + var tCases [][]pool.Instance + + tCases = append(tCases, makeInstances([]string{servers[0], servers[1]}, connOpts)) + tCases = append(tCases, makeInstances([]string{ + servers[0], + servers[1], + servers[3]}, + connOpts)) + tCases = append(tCases, makeInstances([]string{servers[0]}, connOpts)) + + for _, tc := range tCases { + ctx, cancel := test_helpers.GetPoolConnectContext() + connPool, err := pool.Connect(ctx, tc) + cancel() + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + info := connPool.GetInfo() + + var infoInstances []pool.Instance + + for _, infoInst := range info { + infoInstances = append(infoInstances, infoInst.Instance) + } + + require.ElementsMatch(t, tc, infoInstances) + connPool.Close() + } +} + +func TestMain(m *testing.M) { + code := runTestMain(m) + os.Exit(code) +} diff --git a/pool/connector.go b/pool/connector.go new file mode 100644 index 000000000..c984bfd32 --- /dev/null +++ b/pool/connector.go @@ -0,0 +1,86 @@ +package pool + +import ( + "errors" + "fmt" + "time" + + "github.com/tarantool/go-tarantool/v3" +) + +// ConnectorAdapter allows to use Pooler as Connector. +type ConnectorAdapter struct { + pool Pooler + mode Mode +} + +var _ tarantool.Connector = (*ConnectorAdapter)(nil) + +// NewConnectorAdapter creates a new ConnectorAdapter object for a pool +// and with a mode. All requests to the pool will be executed in the +// specified mode. +func NewConnectorAdapter(pool Pooler, mode Mode) *ConnectorAdapter { + return &ConnectorAdapter{pool: pool, mode: mode} +} + +// ConnectedNow reports if connections is established at the moment. +func (c *ConnectorAdapter) ConnectedNow() bool { + ret, err := c.pool.ConnectedNow(c.mode) + if err != nil { + return false + } + return ret +} + +// ClosedNow reports if the connector is closed by user or all connections +// in the specified mode closed. +func (c *ConnectorAdapter) Close() error { + errs := c.pool.Close() + if len(errs) == 0 { + return nil + } + + err := errors.New("failed to close connection pool") + for _, e := range errs { + err = fmt.Errorf("%s: %w", err.Error(), e) + } + return err +} + +// ConfiguredTimeout returns a timeout from connections config. +func (c *ConnectorAdapter) ConfiguredTimeout() time.Duration { + ret, err := c.pool.ConfiguredTimeout(c.mode) + if err != nil { + return 0 * time.Second + } + return ret +} + +// NewPrepared passes a sql statement to Tarantool for preparation +// synchronously. +func (c *ConnectorAdapter) NewPrepared(expr string) (*tarantool.Prepared, error) { + return c.pool.NewPrepared(expr, c.mode) +} + +// NewStream creates new Stream object for connection. +// +// Since v. 2.10.0, Tarantool supports streams and interactive transactions over +// them. To use interactive transactions, memtx_use_mvcc_engine box option +// should be set to true. +// Since 1.7.0 +func (c *ConnectorAdapter) NewStream() (*tarantool.Stream, error) { + return c.pool.NewStream(c.mode) +} + +// NewWatcher creates new Watcher object for the pool +// +// Since 1.10.0 +func (c *ConnectorAdapter) NewWatcher(key string, + callback tarantool.WatchCallback) (tarantool.Watcher, error) { + return c.pool.NewWatcher(key, callback, c.mode) +} + +// Do performs a request asynchronously on the connection. +func (c *ConnectorAdapter) Do(req tarantool.Request) *tarantool.Future { + return c.pool.Do(req, c.mode) +} diff --git a/pool/connector_test.go b/pool/connector_test.go new file mode 100644 index 000000000..3ed8ea81d --- /dev/null +++ b/pool/connector_test.go @@ -0,0 +1,253 @@ +package pool_test + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/tarantool/go-tarantool/v3" + . "github.com/tarantool/go-tarantool/v3/pool" +) + +var testMode Mode = RW + +type connectedNowMock struct { + Pooler + called int + mode Mode + retErr bool +} + +// Tests for different logic. + +func (m *connectedNowMock) ConnectedNow(mode Mode) (bool, error) { + m.called++ + m.mode = mode + + if m.retErr { + return true, errors.New("mock error") + } + return true, nil +} + +func TestConnectorConnectedNow(t *testing.T) { + m := &connectedNowMock{retErr: false} + c := NewConnectorAdapter(m, testMode) + + require.Truef(t, c.ConnectedNow(), "unexpected result") + require.Equalf(t, 1, m.called, "should be called only once") + require.Equalf(t, testMode, m.mode, "unexpected proxy mode") +} + +func TestConnectorConnectedNowWithError(t *testing.T) { + m := &connectedNowMock{retErr: true} + c := NewConnectorAdapter(m, testMode) + + require.Falsef(t, c.ConnectedNow(), "unexpected result") + require.Equalf(t, 1, m.called, "should be called only once") + require.Equalf(t, testMode, m.mode, "unexpected proxy mode") +} + +type closeMock struct { + Pooler + called int + retErr bool +} + +func (m *closeMock) Close() []error { + m.called++ + if m.retErr { + return []error{errors.New("err1"), errors.New("err2")} + } + return nil +} + +func TestConnectorClose(t *testing.T) { + m := &closeMock{retErr: false} + c := NewConnectorAdapter(m, testMode) + + require.Nilf(t, c.Close(), "unexpected result") + require.Equalf(t, 1, m.called, "should be called only once") +} + +func TestConnectorCloseWithError(t *testing.T) { + m := &closeMock{retErr: true} + c := NewConnectorAdapter(m, testMode) + + err := c.Close() + require.NotNilf(t, err, "unexpected result") + require.Equalf(t, 1, m.called, "should be called only once") + require.Equal(t, "failed to close connection pool: err1: err2", err.Error()) +} + +type configuredTimeoutMock struct { + Pooler + called int + timeout time.Duration + mode Mode + retErr bool +} + +func (m *configuredTimeoutMock) ConfiguredTimeout(mode Mode) (time.Duration, error) { + m.called++ + m.mode = mode + m.timeout = 5 * time.Second + if m.retErr { + return m.timeout, fmt.Errorf("err") + } + return m.timeout, nil +} + +func TestConnectorConfiguredTimeout(t *testing.T) { + m := &configuredTimeoutMock{retErr: false} + c := NewConnectorAdapter(m, testMode) + + require.Equalf(t, c.ConfiguredTimeout(), m.timeout, "unexpected result") + require.Equalf(t, 1, m.called, "should be called only once") + require.Equalf(t, testMode, m.mode, "unexpected proxy mode") +} + +func TestConnectorConfiguredTimeoutWithError(t *testing.T) { + m := &configuredTimeoutMock{retErr: true} + c := NewConnectorAdapter(m, testMode) + + ret := c.ConfiguredTimeout() + + require.NotEqualf(t, ret, m.timeout, "unexpected result") + require.Equalf(t, ret, time.Duration(0), "unexpected result") + require.Equalf(t, 1, m.called, "should be called only once") + require.Equalf(t, testMode, m.mode, "unexpected proxy mode") +} + +// Tests for that ConnectorAdapter is just a proxy for requests. + +var errReq error = errors.New("response error") +var reqFuture *tarantool.Future = &tarantool.Future{} + +var reqFunctionName string = "any_name" +var reqPrepared *tarantool.Prepared = &tarantool.Prepared{} + +type newPreparedMock struct { + Pooler + called int + expr string + mode Mode +} + +func (m *newPreparedMock) NewPrepared(expr string, + mode Mode) (*tarantool.Prepared, error) { + m.called++ + m.expr = expr + m.mode = mode + return reqPrepared, errReq +} + +func TestConnectorNewPrepared(t *testing.T) { + m := &newPreparedMock{} + c := NewConnectorAdapter(m, testMode) + + p, err := c.NewPrepared(reqFunctionName) + + require.Equalf(t, reqPrepared, p, "unexpected prepared") + require.Equalf(t, errReq, err, "unexpected error") + require.Equalf(t, 1, m.called, "should be called only once") + require.Equalf(t, reqFunctionName, m.expr, + "unexpected expr was passed") + require.Equalf(t, testMode, m.mode, "unexpected proxy mode") +} + +var reqStream *tarantool.Stream = &tarantool.Stream{} + +type newStreamMock struct { + Pooler + called int + mode Mode +} + +func (m *newStreamMock) NewStream(mode Mode) (*tarantool.Stream, error) { + m.called++ + m.mode = mode + return reqStream, errReq +} + +func TestConnectorNewStream(t *testing.T) { + m := &newStreamMock{} + c := NewConnectorAdapter(m, testMode) + + s, err := c.NewStream() + + require.Equalf(t, reqStream, s, "unexpected stream") + require.Equalf(t, errReq, err, "unexpected error") + require.Equalf(t, 1, m.called, "should be called only once") + require.Equalf(t, testMode, m.mode, "unexpected proxy mode") +} + +type watcherMock struct{} + +func (w *watcherMock) Unregister() {} + +const reqWatchKey = "foo" + +var reqWatcher tarantool.Watcher = &watcherMock{} + +type newWatcherMock struct { + Pooler + key string + callback tarantool.WatchCallback + called int + mode Mode +} + +func (m *newWatcherMock) NewWatcher(key string, + callback tarantool.WatchCallback, mode Mode) (tarantool.Watcher, error) { + m.called++ + m.key = key + m.callback = callback + m.mode = mode + return reqWatcher, errReq +} + +func TestConnectorNewWatcher(t *testing.T) { + m := &newWatcherMock{} + c := NewConnectorAdapter(m, testMode) + + w, err := c.NewWatcher(reqWatchKey, func(event tarantool.WatchEvent) {}) + + require.Equalf(t, reqWatcher, w, "unexpected watcher") + require.Equalf(t, errReq, err, "unexpected error") + require.Equalf(t, 1, m.called, "should be called only once") + require.Equalf(t, reqWatchKey, m.key, "unexpected key") + require.NotNilf(t, m.callback, "callback must be set") + require.Equalf(t, testMode, m.mode, "unexpected proxy mode") +} + +var reqRequest tarantool.Request = tarantool.NewPingRequest() + +type doMock struct { + Pooler + called int + req tarantool.Request + mode Mode +} + +func (m *doMock) Do(req tarantool.Request, mode Mode) *tarantool.Future { + m.called++ + m.req = req + m.mode = mode + return reqFuture +} + +func TestConnectorDo(t *testing.T) { + m := &doMock{} + c := NewConnectorAdapter(m, testMode) + + fut := c.Do(reqRequest) + + require.Equalf(t, reqFuture, fut, "unexpected future") + require.Equalf(t, 1, m.called, "should be called only once") + require.Equalf(t, reqRequest, m.req, "unexpected request") + require.Equalf(t, testMode, m.mode, "unexpected proxy mode") +} diff --git a/pool/const.go b/pool/const.go new file mode 100644 index 000000000..d15490928 --- /dev/null +++ b/pool/const.go @@ -0,0 +1,41 @@ +//go:generate go tool stringer -type Role -linecomment +package pool + +/* +Default mode for each request table: + + Request Default mode + ---------- -------------- + | call | no default | + | eval | no default | + | execute | no default | + | ping | no default | + | insert | RW | + | delete | RW | + | replace | RW | + | update | RW | + | upsert | RW | + | select | ANY | + | get | ANY | +*/ +type Mode uint32 + +const ( + ANY Mode = iota // The request can be executed on any instance (master or replica). + RW // The request can only be executed on master. + RO // The request can only be executed on replica. + PreferRW // If there is one, otherwise fallback to a writeable one (master). + PreferRO // If there is one, otherwise fallback to a read only one (replica). +) + +// Role describes a role of an instance by its mode. +type Role uint32 + +const ( + // UnknownRole - the connection pool was unable to detect the instance mode. + UnknownRole Role = iota // unknown + // MasterRole - the instance is in read-write mode. + MasterRole // master + // ReplicaRole - the instance is in read-only mode. + ReplicaRole // replica +) diff --git a/pool/const_test.go b/pool/const_test.go new file mode 100644 index 000000000..83e2678de --- /dev/null +++ b/pool/const_test.go @@ -0,0 +1,13 @@ +package pool + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRole_String(t *testing.T) { + require.Equal(t, "unknown", UnknownRole.String()) + require.Equal(t, "master", MasterRole.String()) + require.Equal(t, "replica", ReplicaRole.String()) +} diff --git a/pool/example_test.go b/pool/example_test.go new file mode 100644 index 000000000..6cf339baf --- /dev/null +++ b/pool/example_test.go @@ -0,0 +1,477 @@ +package pool_test + +import ( + "fmt" + "time" + + "github.com/tarantool/go-iproto" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/pool" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +type Tuple struct { + // Instruct msgpack to pack this struct as array, so no custom packer + // is needed. + _msgpack struct{} `msgpack:",asArray"` // nolint: structcheck,unused + Key string + Value string +} + +var testRoles = []bool{true, true, false, true, true} + +func examplePool(roles []bool, + connOpts tarantool.Opts) (*pool.ConnectionPool, error) { + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + err := test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + if err != nil { + return nil, fmt.Errorf("ConnectionPool is not established") + } + connPool, err := pool.Connect(ctx, instances) + if err != nil || connPool == nil { + return nil, fmt.Errorf("ConnectionPool is not established") + } + + return connPool, nil +} + +func exampleFeaturesPool(roles []bool, connOpts tarantool.Opts, + requiredProtocol tarantool.ProtocolInfo) (*pool.ConnectionPool, error) { + poolInstances := []pool.Instance{} + poolDialers := []tarantool.Dialer{} + for _, server := range servers { + dialer := tarantool.NetDialer{ + Address: server, + User: user, + Password: pass, + RequiredProtocolInfo: requiredProtocol, + } + poolInstances = append(poolInstances, pool.Instance{ + Name: server, + Dialer: dialer, + Opts: connOpts, + }) + poolDialers = append(poolDialers, dialer) + } + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + err := test_helpers.SetClusterRO(ctx, poolDialers, connOpts, roles) + if err != nil { + return nil, fmt.Errorf("ConnectionPool is not established") + } + connPool, err := pool.Connect(ctx, poolInstances) + if err != nil || connPool == nil { + return nil, fmt.Errorf("ConnectionPool is not established") + } + + return connPool, nil +} + +func ExampleConnectionPool_Do() { + connPool, err := examplePool(testRoles, connOpts) + if err != nil { + fmt.Println(err) + } + defer connPool.Close() + + modes := []pool.Mode{ + pool.ANY, + pool.RW, + pool.RO, + pool.PreferRW, + pool.PreferRO, + } + for _, m := range modes { + // It could be any request object. + req := tarantool.NewPingRequest() + _, err := connPool.Do(req, m).Get() + fmt.Println("Ping Error", err) + } + // Output: + // Ping Error + // Ping Error + // Ping Error + // Ping Error + // Ping Error +} + +func ExampleConnectionPool_NewPrepared() { + connPool, err := examplePool(testRoles, connOpts) + if err != nil { + fmt.Println(err) + } + defer connPool.Close() + + stmt, err := connPool.NewPrepared("SELECT 1", pool.ANY) + if err != nil { + fmt.Println(err) + } + + executeReq := tarantool.NewExecutePreparedRequest(stmt) + unprepareReq := tarantool.NewUnprepareRequest(stmt) + + _, err = connPool.Do(executeReq, pool.ANY).Get() + if err != nil { + fmt.Printf("Failed to execute prepared stmt") + } + _, err = connPool.Do(unprepareReq, pool.ANY).Get() + if err != nil { + fmt.Printf("Failed to prepare") + } +} + +func ExampleConnectionPool_NewWatcher() { + const key = "foo" + const value = "bar" + + connPool, err := examplePool(testRoles, connOpts) + if err != nil { + fmt.Println(err) + } + defer connPool.Close() + + callback := func(event tarantool.WatchEvent) { + fmt.Printf("event connection: %s\n", event.Conn.Addr()) + fmt.Printf("event key: %s\n", event.Key) + fmt.Printf("event value: %v\n", event.Value) + } + mode := pool.ANY + watcher, err := connPool.NewWatcher(key, callback, mode) + if err != nil { + fmt.Printf("Unexpected error: %s\n", err) + return + } + defer watcher.Unregister() + + connPool.Do(tarantool.NewBroadcastRequest(key).Value(value), mode).Get() + time.Sleep(time.Second) +} + +func getTestTxnProtocol() tarantool.ProtocolInfo { + // Assert that server supports expected protocol features + return tarantool.ProtocolInfo{ + Version: tarantool.ProtocolVersion(1), + Features: []iproto.Feature{ + iproto.IPROTO_FEATURE_STREAMS, + iproto.IPROTO_FEATURE_TRANSACTIONS, + }, + } +} + +func ExampleCommitRequest() { + var req tarantool.Request + var err error + + // Tarantool supports streams and interactive transactions since version 2.10.0 + isLess, _ := test_helpers.IsTarantoolVersionLess(2, 10, 0) + if err != nil || isLess { + return + } + + connPool, err := exampleFeaturesPool(testRoles, connOpts, getTestTxnProtocol()) + if err != nil { + fmt.Println(err) + return + } + defer connPool.Close() + + // example pool has only one rw instance + stream, err := connPool.NewStream(pool.RW) + if err != nil { + fmt.Println(err) + return + } + + // Begin transaction + req = tarantool.NewBeginRequest() + data, err := stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Begin: %s", err.Error()) + return + } + fmt.Printf("Begin transaction: response is %#v\n", data) + + // Insert in stream + req = tarantool.NewInsertRequest(spaceName). + Tuple([]interface{}{"example_commit_key", "example_commit_value"}) + data, err = stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Insert: %s", err.Error()) + return + } + fmt.Printf("Insert in stream: response is %#v\n", data) + + // Select not related to the transaction + // while transaction is not committed + // result of select is empty + selectReq := tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(tarantool.IterEq). + Key([]interface{}{"example_commit_key"}) + data, err = connPool.Do(selectReq, pool.RW).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select out of stream before commit: response is %#v\n", data) + + // Select in stream + data, err = stream.Do(selectReq).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select in stream: response is %#v\n", data) + + // Commit transaction + req = tarantool.NewCommitRequest() + data, err = stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Commit: %s", err.Error()) + return + } + fmt.Printf("Commit transaction: response is %#v\n", data) + + // Select outside of transaction + // example pool has only one rw instance + data, err = connPool.Do(selectReq, pool.RW).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select after commit: response is %#v\n", data) +} + +func ExampleRollbackRequest() { + var req tarantool.Request + var err error + + // Tarantool supports streams and interactive transactions since version 2.10.0 + isLess, _ := test_helpers.IsTarantoolVersionLess(2, 10, 0) + if err != nil || isLess { + return + } + + // example pool has only one rw instance + connPool, err := exampleFeaturesPool(testRoles, connOpts, getTestTxnProtocol()) + if err != nil { + fmt.Println(err) + return + } + defer connPool.Close() + + stream, err := connPool.NewStream(pool.RW) + if err != nil { + fmt.Println(err) + return + } + + // Begin transaction + req = tarantool.NewBeginRequest() + data, err := stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Begin: %s", err.Error()) + return + } + fmt.Printf("Begin transaction: response is %#v\n", data) + + // Insert in stream + req = tarantool.NewInsertRequest(spaceName). + Tuple([]interface{}{"example_rollback_key", "example_rollback_value"}) + data, err = stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Insert: %s", err.Error()) + return + } + fmt.Printf("Insert in stream: response is %#v\n", data) + + // Select not related to the transaction + // while transaction is not committed + // result of select is empty + selectReq := tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(tarantool.IterEq). + Key([]interface{}{"example_rollback_key"}) + data, err = connPool.Do(selectReq, pool.RW).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select out of stream: response is %#v\n", data) + + // Select in stream + data, err = stream.Do(selectReq).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select in stream: response is %#v\n", data) + + // Rollback transaction + req = tarantool.NewRollbackRequest() + data, err = stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Rollback: %s", err.Error()) + return + } + fmt.Printf("Rollback transaction: response is %#v\n", data) + + // Select outside of transaction + // example pool has only one rw instance + data, err = connPool.Do(selectReq, pool.RW).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select after Rollback: response is %#v\n", data) +} + +func ExampleBeginRequest_TxnIsolation() { + var req tarantool.Request + var err error + + // Tarantool supports streams and interactive transactions since version 2.10.0 + isLess, _ := test_helpers.IsTarantoolVersionLess(2, 10, 0) + if err != nil || isLess { + return + } + + // example pool has only one rw instance + connPool, err := exampleFeaturesPool(testRoles, connOpts, getTestTxnProtocol()) + if err != nil { + fmt.Println(err) + return + } + defer connPool.Close() + + stream, err := connPool.NewStream(pool.RW) + if err != nil { + fmt.Println(err) + return + } + + // Begin transaction + req = tarantool.NewBeginRequest(). + TxnIsolation(tarantool.ReadConfirmedLevel). + Timeout(500 * time.Millisecond) + data, err := stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Begin: %s", err.Error()) + return + } + fmt.Printf("Begin transaction: response is %#v\n", data) + + // Insert in stream + req = tarantool.NewInsertRequest(spaceName). + Tuple([]interface{}{"isolation_level_key", "isolation_level_value"}) + data, err = stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Insert: %s", err.Error()) + return + } + fmt.Printf("Insert in stream: response is %#v\n", data) + + // Select not related to the transaction + // while transaction is not committed + // result of select is empty + selectReq := tarantool.NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(tarantool.IterEq). + Key([]interface{}{"isolation_level_key"}) + data, err = connPool.Do(selectReq, pool.RW).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select out of stream: response is %#v\n", data) + + // Select in stream + data, err = stream.Do(selectReq).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select in stream: response is %#v\n", data) + + // Rollback transaction + req = tarantool.NewRollbackRequest() + data, err = stream.Do(req).Get() + if err != nil { + fmt.Printf("Failed to Rollback: %s", err.Error()) + return + } + fmt.Printf("Rollback transaction: response is %#v\n", data) + + // Select outside of transaction + // example pool has only one rw instance + data, err = connPool.Do(selectReq, pool.RW).Get() + if err != nil { + fmt.Printf("Failed to Select: %s", err.Error()) + return + } + fmt.Printf("Select after Rollback: response is %#v\n", data) +} + +func ExampleConnectorAdapter() { + connPool, err := examplePool(testRoles, connOpts) + if err != nil { + fmt.Println(err) + } + defer connPool.Close() + + adapter := pool.NewConnectorAdapter(connPool, pool.RW) + var connector tarantool.Connector = adapter + + // Ping an RW instance to check connection. + data, err := connector.Do(tarantool.NewPingRequest()).Get() + fmt.Println("Ping Data", data) + fmt.Println("Ping Error", err) + // Output: + // Ping Data [] + // Ping Error +} + +// ExampleConnectionPool_CloseGraceful_force demonstrates how to force close +// a connection pool with graceful close in progress after a while. +func ExampleConnectionPool_CloseGraceful_force() { + connPool, err := examplePool(testRoles, connOpts) + if err != nil { + fmt.Println(err) + return + } + + eval := `local fiber = require('fiber') + local time = ... + fiber.sleep(time) +` + req := tarantool.NewEvalRequest(eval).Args([]interface{}{10}) + fut := connPool.Do(req, pool.ANY) + + done := make(chan struct{}) + go func() { + connPool.CloseGraceful() + fmt.Println("ConnectionPool.CloseGraceful() done!") + close(done) + }() + + select { + case <-done: + case <-time.After(3 * time.Second): + fmt.Println("Force ConnectionPool.Close()!") + connPool.Close() + } + <-done + + fmt.Println("Result:") + fmt.Println(fut.Get()) + // Output: + // Force ConnectionPool.Close()! + // ConnectionPool.CloseGraceful() done! + // Result: + // [] connection closed by client (0x4001) +} diff --git a/pool/pooler.go b/pool/pooler.go new file mode 100644 index 000000000..1d05d1eb9 --- /dev/null +++ b/pool/pooler.go @@ -0,0 +1,32 @@ +package pool + +import ( + "context" + "time" + + "github.com/tarantool/go-tarantool/v3" +) + +// TopologyEditor is the interface that must be implemented by a connection pool. +// It describes edit topology methods. +type TopologyEditor interface { + Add(ctx context.Context, instance Instance) error + Remove(name string) error +} + +// Pooler is the interface that must be implemented by a connection pool. +type Pooler interface { + TopologyEditor + + ConnectedNow(mode Mode) (bool, error) + Close() []error + // CloseGraceful closes connections in the ConnectionPool gracefully. It waits + // for all requests to complete. + CloseGraceful() []error + ConfiguredTimeout(mode Mode) (time.Duration, error) + NewPrepared(expr string, mode Mode) (*tarantool.Prepared, error) + NewStream(mode Mode) (*tarantool.Stream, error) + NewWatcher(key string, callback tarantool.WatchCallback, + mode Mode) (tarantool.Watcher, error) + Do(req tarantool.Request, mode Mode) (fut *tarantool.Future) +} diff --git a/pool/role_string.go b/pool/role_string.go new file mode 100644 index 000000000..162ebd698 --- /dev/null +++ b/pool/role_string.go @@ -0,0 +1,25 @@ +// Code generated by "stringer -type Role -linecomment"; DO NOT EDIT. + +package pool + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[UnknownRole-0] + _ = x[MasterRole-1] + _ = x[ReplicaRole-2] +} + +const _Role_name = "unknownmasterreplica" + +var _Role_index = [...]uint8{0, 7, 13, 20} + +func (i Role) String() string { + if i >= Role(len(_Role_index)-1) { + return "Role(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _Role_name[_Role_index[i]:_Role_index[i+1]] +} diff --git a/pool/round_robin.go b/pool/round_robin.go new file mode 100644 index 000000000..f3ccb014c --- /dev/null +++ b/pool/round_robin.go @@ -0,0 +1,112 @@ +package pool + +import ( + "sync" + "sync/atomic" + + "github.com/tarantool/go-tarantool/v3" +) + +type roundRobinStrategy struct { + conns []*tarantool.Connection + indexById map[string]uint + mutex sync.RWMutex + size uint64 + current uint64 +} + +func newRoundRobinStrategy(size int) *roundRobinStrategy { + return &roundRobinStrategy{ + conns: make([]*tarantool.Connection, 0, size), + indexById: make(map[string]uint, size), + size: 0, + current: 0, + } +} + +func (r *roundRobinStrategy) GetConnection(id string) *tarantool.Connection { + r.mutex.RLock() + defer r.mutex.RUnlock() + + index, found := r.indexById[id] + if !found { + return nil + } + + return r.conns[index] +} + +func (r *roundRobinStrategy) DeleteConnection(id string) *tarantool.Connection { + r.mutex.Lock() + defer r.mutex.Unlock() + + if r.size == 0 { + return nil + } + + index, found := r.indexById[id] + if !found { + return nil + } + + delete(r.indexById, id) + + conn := r.conns[index] + r.conns = append(r.conns[:index], r.conns[index+1:]...) + r.size -= 1 + + for k, v := range r.indexById { + if v > index { + r.indexById[k] = v - 1 + } + } + + return conn +} + +func (r *roundRobinStrategy) IsEmpty() bool { + r.mutex.RLock() + defer r.mutex.RUnlock() + + return r.size == 0 +} + +func (r *roundRobinStrategy) GetNextConnection() *tarantool.Connection { + r.mutex.RLock() + defer r.mutex.RUnlock() + + if r.size == 0 { + return nil + } + return r.conns[r.nextIndex()] +} + +func (r *roundRobinStrategy) GetConnections() map[string]*tarantool.Connection { + r.mutex.RLock() + defer r.mutex.RUnlock() + + conns := map[string]*tarantool.Connection{} + for id, index := range r.indexById { + conns[id] = r.conns[index] + } + + return conns +} + +func (r *roundRobinStrategy) AddConnection(id string, conn *tarantool.Connection) { + r.mutex.Lock() + defer r.mutex.Unlock() + + if idx, ok := r.indexById[id]; ok { + r.conns[idx] = conn + } else { + r.conns = append(r.conns, conn) + r.indexById[id] = uint(r.size) + r.size += 1 + } +} + +func (r *roundRobinStrategy) nextIndex() uint64 { + next := atomic.AddUint64(&r.current, 1) + return (next - 1) % r.size +} diff --git a/pool/round_robin_test.go b/pool/round_robin_test.go new file mode 100644 index 000000000..dcc219fd4 --- /dev/null +++ b/pool/round_robin_test.go @@ -0,0 +1,90 @@ +package pool + +import ( + "testing" + + "github.com/tarantool/go-tarantool/v3" +) + +const ( + validAddr1 = "x" + validAddr2 = "y" +) + +func TestRoundRobinAddDelete(t *testing.T) { + rr := newRoundRobinStrategy(10) + + addrs := []string{validAddr1, validAddr2} + conns := []*tarantool.Connection{&tarantool.Connection{}, &tarantool.Connection{}} + + for i, addr := range addrs { + rr.AddConnection(addr, conns[i]) + } + + for i, addr := range addrs { + if conn := rr.DeleteConnection(addr); conn != conns[i] { + t.Errorf("Unexpected connection on address %s", addr) + } + } + if !rr.IsEmpty() { + t.Errorf("RoundRobin does not empty") + } +} + +func TestRoundRobinAddDuplicateDelete(t *testing.T) { + rr := newRoundRobinStrategy(10) + + conn1 := &tarantool.Connection{} + conn2 := &tarantool.Connection{} + + rr.AddConnection(validAddr1, conn1) + rr.AddConnection(validAddr1, conn2) + + if rr.DeleteConnection(validAddr1) != conn2 { + t.Errorf("Unexpected deleted connection") + } + if !rr.IsEmpty() { + t.Errorf("RoundRobin does not empty") + } + if rr.DeleteConnection(validAddr1) != nil { + t.Errorf("Unexpected value after second deletion") + } +} + +func TestRoundRobinGetNextConnection(t *testing.T) { + rr := newRoundRobinStrategy(10) + + addrs := []string{validAddr1, validAddr2} + conns := []*tarantool.Connection{&tarantool.Connection{}, &tarantool.Connection{}} + + for i, addr := range addrs { + rr.AddConnection(addr, conns[i]) + } + + expectedConns := []*tarantool.Connection{conns[0], conns[1], conns[0], conns[1]} + for i, expected := range expectedConns { + if rr.GetNextConnection() != expected { + t.Errorf("Unexpected connection on %d call", i) + } + } +} + +func TestRoundRobinStrategy_GetConnections(t *testing.T) { + rr := newRoundRobinStrategy(10) + + addrs := []string{validAddr1, validAddr2} + conns := []*tarantool.Connection{&tarantool.Connection{}, &tarantool.Connection{}} + + for i, addr := range addrs { + rr.AddConnection(addr, conns[i]) + } + + rr.GetConnections()[validAddr2] = conns[0] // GetConnections() returns a copy. + rrConns := rr.GetConnections() + + for i, addr := range addrs { + if conns[i] != rrConns[addr] { + t.Errorf("Unexpected connection on %s addr", addr) + } + } +} diff --git a/pool/state.go b/pool/state.go new file mode 100644 index 000000000..2af093e60 --- /dev/null +++ b/pool/state.go @@ -0,0 +1,27 @@ +package pool + +import ( + "sync/atomic" +) + +// pool state +type state uint32 + +const ( + unknownState state = iota + connectedState + shutdownState + closedState +) + +func (s *state) set(news state) { + atomic.StoreUint32((*uint32)(s), uint32(news)) +} + +func (s *state) cas(olds, news state) bool { + return atomic.CompareAndSwapUint32((*uint32)(s), uint32(olds), uint32(news)) +} + +func (s *state) get() state { + return state(atomic.LoadUint32((*uint32)(s))) +} diff --git a/pool/watcher.go b/pool/watcher.go new file mode 100644 index 000000000..aee3103fd --- /dev/null +++ b/pool/watcher.go @@ -0,0 +1,133 @@ +package pool + +import ( + "sync" + + "github.com/tarantool/go-tarantool/v3" +) + +// watcherContainer is a very simple implementation of a thread-safe container +// for watchers. It is not expected that there will be too many watchers and +// they will registered/unregistered too frequently. +// +// Otherwise, the implementation will need to be optimized. +type watcherContainer struct { + head *poolWatcher + mutex sync.RWMutex +} + +// add adds a watcher to the container. +func (c *watcherContainer) add(watcher *poolWatcher) { + c.mutex.Lock() + defer c.mutex.Unlock() + + watcher.next = c.head + c.head = watcher +} + +// remove removes a watcher from the container. +func (c *watcherContainer) remove(watcher *poolWatcher) bool { + c.mutex.Lock() + defer c.mutex.Unlock() + + if watcher == c.head { + c.head = watcher.next + return true + } else if c.head != nil { + cur := c.head + for cur.next != nil { + if cur.next == watcher { + cur.next = watcher.next + return true + } + cur = cur.next + } + } + return false +} + +// foreach iterates over the container to the end or until the call returns +// false. +func (c *watcherContainer) foreach(call func(watcher *poolWatcher) error) error { + cur := c.head + for cur != nil { + if err := call(cur); err != nil { + return err + } + cur = cur.next + } + return nil +} + +// poolWatcher is an internal implementation of the tarantool.Watcher interface. +type poolWatcher struct { + // The watcher container data. We can split the structure into two parts + // in the future: a watcher data and a watcher container data, but it looks + // simple at now. + + // next item in the watcher container. + next *poolWatcher + // container is the container for all active poolWatcher objects. + container *watcherContainer + + // The watcher data. + // mode of the watcher. + mode Mode + key string + callback tarantool.WatchCallback + // watchers is a map connection -> connection watcher. + watchers map[*tarantool.Connection]tarantool.Watcher + // unregistered is true if the watcher already unregistered. + unregistered bool + // mutex for the pool watcher. + mutex sync.Mutex +} + +// Unregister unregisters the pool watcher. +func (w *poolWatcher) Unregister() { + w.mutex.Lock() + unregistered := w.unregistered + w.mutex.Unlock() + + if !unregistered && w.container.remove(w) { + w.mutex.Lock() + w.unregistered = true + for _, watcher := range w.watchers { + watcher.Unregister() + } + w.mutex.Unlock() + } +} + +// watch adds a watcher for the connection. +func (w *poolWatcher) watch(conn *tarantool.Connection) error { + w.mutex.Lock() + defer w.mutex.Unlock() + + if !w.unregistered { + if _, ok := w.watchers[conn]; ok { + return nil + } + + if watcher, err := conn.NewWatcher(w.key, w.callback); err == nil { + w.watchers[conn] = watcher + return nil + } else { + return err + } + } + return nil +} + +// unwatch removes a watcher for the connection. +func (w *poolWatcher) unwatch(conn *tarantool.Connection) { + w.mutex.Lock() + defer w.mutex.Unlock() + + if !w.unregistered { + if watcher, ok := w.watchers[conn]; ok { + watcher.Unregister() + delete(w.watchers, conn) + } + } +} diff --git a/prepared.go b/prepared.go new file mode 100644 index 000000000..6f7ace911 --- /dev/null +++ b/prepared.go @@ -0,0 +1,208 @@ +package tarantool + +import ( + "context" + "fmt" + "io" + + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" +) + +// PreparedID is a type for Prepared Statement ID +type PreparedID uint64 + +// Prepared is a type for handling prepared statements +// +// Since 1.7.0 +type Prepared struct { + StatementID PreparedID + MetaData []ColumnMetaData + ParamCount uint64 + Conn *Connection +} + +// NewPreparedFromResponse constructs a Prepared object. +func NewPreparedFromResponse(conn *Connection, resp Response) (*Prepared, error) { + if resp == nil { + return nil, fmt.Errorf("passed nil response") + } + data, err := resp.Decode() + if err != nil { + return nil, fmt.Errorf("decode response body error: %s", err.Error()) + } + if data == nil { + return nil, fmt.Errorf("response Data is nil") + } + if len(data) == 0 { + return nil, fmt.Errorf("response Data format is wrong") + } + stmt, ok := data[0].(*Prepared) + if !ok { + return nil, fmt.Errorf("response Data format is wrong") + } + stmt.Conn = conn + return stmt, nil +} + +// PrepareRequest helps you to create a prepare request object for execution +// by a Connection. +type PrepareRequest struct { + baseRequest + expr string +} + +// NewPrepareRequest returns a new empty PrepareRequest. +func NewPrepareRequest(expr string) *PrepareRequest { + req := new(PrepareRequest) + req.rtype = iproto.IPROTO_PREPARE + req.expr = expr + return req +} + +// Body fills an msgpack.Encoder with the execute request body. +func (req *PrepareRequest) Body(_ SchemaResolver, enc *msgpack.Encoder) error { + if err := enc.EncodeMapLen(1); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_SQL_TEXT)); err != nil { + return err + } + + return enc.EncodeString(req.expr) +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *PrepareRequest) Context(ctx context.Context) *PrepareRequest { + req.ctx = ctx + return req +} + +// Response creates a response for the PrepareRequest. +func (req *PrepareRequest) Response(header Header, body io.Reader) (Response, error) { + baseResp, err := createBaseResponse(header, body) + if err != nil { + return nil, err + } + return &PrepareResponse{baseResponse: baseResp}, nil +} + +// UnprepareRequest helps you to create an unprepare request object for +// execution by a Connection. +type UnprepareRequest struct { + baseRequest + stmt *Prepared +} + +// NewUnprepareRequest returns a new empty UnprepareRequest. +func NewUnprepareRequest(stmt *Prepared) *UnprepareRequest { + req := new(UnprepareRequest) + req.rtype = iproto.IPROTO_PREPARE + req.stmt = stmt + return req +} + +// Conn returns the Connection object the request belongs to +func (req *UnprepareRequest) Conn() *Connection { + return req.stmt.Conn +} + +// Body fills an msgpack.Encoder with the execute request body. +func (req *UnprepareRequest) Body(_ SchemaResolver, enc *msgpack.Encoder) error { + if err := enc.EncodeMapLen(1); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_STMT_ID)); err != nil { + return err + } + + return enc.EncodeUint(uint64(req.stmt.StatementID)) +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *UnprepareRequest) Context(ctx context.Context) *UnprepareRequest { + req.ctx = ctx + return req +} + +// ExecutePreparedRequest helps you to create an execute prepared request +// object for execution by a Connection. +type ExecutePreparedRequest struct { + baseRequest + stmt *Prepared + args interface{} +} + +// NewExecutePreparedRequest returns a new empty preparedExecuteRequest. +func NewExecutePreparedRequest(stmt *Prepared) *ExecutePreparedRequest { + req := new(ExecutePreparedRequest) + req.rtype = iproto.IPROTO_EXECUTE + req.stmt = stmt + req.args = []interface{}{} + return req +} + +// Conn returns the Connection object the request belongs to +func (req *ExecutePreparedRequest) Conn() *Connection { + return req.stmt.Conn +} + +// Args sets the args for execute the prepared request. +// Note: default value is empty. +func (req *ExecutePreparedRequest) Args(args interface{}) *ExecutePreparedRequest { + req.args = args + return req +} + +// Body fills an msgpack.Encoder with the execute request body. +func (req *ExecutePreparedRequest) Body(_ SchemaResolver, enc *msgpack.Encoder) error { + if err := enc.EncodeMapLen(2); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_STMT_ID)); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(req.stmt.StatementID)); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_SQL_BIND)); err != nil { + return err + } + + return encodeSQLBind(enc, req.args) +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *ExecutePreparedRequest) Context(ctx context.Context) *ExecutePreparedRequest { + req.ctx = ctx + return req +} + +// Response creates a response for the ExecutePreparedRequest. +func (req *ExecutePreparedRequest) Response(header Header, body io.Reader) (Response, error) { + baseResp, err := createBaseResponse(header, body) + if err != nil { + return nil, err + } + return &ExecuteResponse{baseResponse: baseResp}, nil +} diff --git a/protocol.go b/protocol.go new file mode 100644 index 000000000..e4e93e7c7 --- /dev/null +++ b/protocol.go @@ -0,0 +1,121 @@ +package tarantool + +import ( + "context" + + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" +) + +// ProtocolVersion type stores Tarantool protocol version. +type ProtocolVersion uint64 + +// ProtocolInfo type aggregates Tarantool protocol version and features info. +type ProtocolInfo struct { + // Auth is an authentication method. + Auth Auth + // Version is the supported protocol version. + Version ProtocolVersion + // Features are supported protocol features. + Features []iproto.Feature +} + +// Clone returns an exact copy of the ProtocolInfo object. +// Any changes in copy will not affect the original values. +func (info ProtocolInfo) Clone() ProtocolInfo { + infoCopy := info + + if info.Features != nil { + infoCopy.Features = make([]iproto.Feature, len(info.Features)) + copy(infoCopy.Features, info.Features) + } + + return infoCopy +} + +var clientProtocolInfo ProtocolInfo = ProtocolInfo{ + // Protocol version supported by connector. Version 3 + // was introduced in Tarantool 2.10.0, version 4 was + // introduced in master 948e5cd (possible 2.10.5 or 2.11.0). + // Support of protocol version on connector side was introduced in + // 1.10.0. + Version: ProtocolVersion(6), + // Streams and transactions were introduced in protocol version 1 + // (Tarantool 2.10.0), in connector since 1.7.0. + // Error extension type was introduced in protocol + // version 2 (Tarantool 2.10.0), in connector since 1.10.0. + // Watchers were introduced in protocol version 3 (Tarantool 2.10.0), in + // connector since 1.10.0. + // Pagination were introduced in protocol version 4 (Tarantool 2.11.0), in + // connector since 1.11.0. + // WatchOnce request type was introduces in protocol version 6 + // (Tarantool 3.0.0), in connector since 2.0.0. + Features: []iproto.Feature{ + iproto.IPROTO_FEATURE_STREAMS, + iproto.IPROTO_FEATURE_TRANSACTIONS, + iproto.IPROTO_FEATURE_ERROR_EXTENSION, + iproto.IPROTO_FEATURE_WATCHERS, + iproto.IPROTO_FEATURE_PAGINATION, + iproto.IPROTO_FEATURE_SPACE_AND_INDEX_NAMES, + iproto.IPROTO_FEATURE_WATCH_ONCE, + iproto.IPROTO_FEATURE_IS_SYNC, + iproto.IPROTO_FEATURE_INSERT_ARROW, + }, +} + +// IdRequest informs the server about supported protocol +// version and protocol features. +type IdRequest struct { + baseRequest + protocolInfo ProtocolInfo +} + +// NewIdRequest returns a new IdRequest. +func NewIdRequest(protocolInfo ProtocolInfo) *IdRequest { + req := new(IdRequest) + req.rtype = iproto.IPROTO_ID + req.protocolInfo = protocolInfo.Clone() + return req +} + +// Body fills an msgpack.Encoder with the id request body. +func (req *IdRequest) Body(_ SchemaResolver, enc *msgpack.Encoder) error { + if err := enc.EncodeMapLen(2); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_VERSION)); err != nil { + return err + } + + if err := enc.Encode(req.protocolInfo.Version); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_FEATURES)); err != nil { + return err + } + + if err := enc.EncodeArrayLen(len(req.protocolInfo.Features)); err != nil { + return err + } + + for _, feature := range req.protocolInfo.Features { + if err := enc.Encode(feature); err != nil { + return err + } + } + + return nil +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *IdRequest) Context(ctx context.Context) *IdRequest { + req.ctx = ctx + return req +} diff --git a/protocol_test.go b/protocol_test.go new file mode 100644 index 000000000..c79a8afd3 --- /dev/null +++ b/protocol_test.go @@ -0,0 +1,28 @@ +package tarantool_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/tarantool/go-iproto" + + . "github.com/tarantool/go-tarantool/v3" +) + +func TestProtocolInfoClonePreservesFeatures(t *testing.T) { + original := ProtocolInfo{ + Version: ProtocolVersion(100), + Features: []iproto.Feature{iproto.Feature(99), iproto.Feature(100)}, + } + + origCopy := original.Clone() + + original.Features[1] = iproto.Feature(98) + + require.Equal(t, + origCopy, + ProtocolInfo{ + Version: ProtocolVersion(100), + Features: []iproto.Feature{iproto.Feature(99), iproto.Feature(100)}, + }) +} diff --git a/queue/const.go b/queue/const.go new file mode 100644 index 000000000..0e0eadcc7 --- /dev/null +++ b/queue/const.go @@ -0,0 +1,37 @@ +package queue + +const ( + READY = "r" + TAKEN = "t" + DONE = "-" + BURIED = "!" + DELAYED = "~" +) + +type queueType string + +const ( + FIFO queueType = "fifo" + FIFO_TTL queueType = "fifottl" + UTUBE queueType = "utube" + UTUBE_TTL queueType = "utubettl" +) + +type State int + +const ( + UnknownState State = iota + InitState + StartupState + RunningState + EndingState + WaitingState +) + +var strToState = map[string]State{ + "INIT": InitState, + "STARTUP": StartupState, + "RUNNING": RunningState, + "ENDING": EndingState, + "WAITING": WaitingState, +} diff --git a/queue/example_connection_pool_test.go b/queue/example_connection_pool_test.go new file mode 100644 index 000000000..8b5aab7cb --- /dev/null +++ b/queue/example_connection_pool_test.go @@ -0,0 +1,273 @@ +package queue_test + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/pool" + "github.com/tarantool/go-tarantool/v3/queue" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +// QueueConnectionHandler handles new connections in a ConnectionPool. +type QueueConnectionHandler struct { + name string + cfg queue.Cfg + + uuid uuid.UUID + registered bool + err error + mutex sync.Mutex + updated chan struct{} + masterCnt int32 +} + +// QueueConnectionHandler implements the ConnectionHandler interface. +var _ pool.ConnectionHandler = &QueueConnectionHandler{} + +// NewQueueConnectionHandler creates a QueueConnectionHandler object. +func NewQueueConnectionHandler(name string, cfg queue.Cfg) *QueueConnectionHandler { + return &QueueConnectionHandler{ + name: name, + cfg: cfg, + updated: make(chan struct{}, 10), + } +} + +// Discovered configures a queue for an instance and identifies a shared queue +// session on master instances. +// +// NOTE: the Queue supports only a master-replica cluster configuration. It +// does not support a master-master configuration. +func (h *QueueConnectionHandler) Discovered(name string, conn *tarantool.Connection, + role pool.Role) error { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.err != nil { + return h.err + } + + master := role == pool.MasterRole + + q := queue.New(conn, h.name) + + // Check is queue ready to work. + if state, err := q.State(); err != nil { + h.updated <- struct{}{} + h.err = err + return err + } else if master && state != queue.RunningState { + return fmt.Errorf("queue state is not RUNNING: %d", state) + } else if !master && state != queue.InitState && state != queue.WaitingState { + return fmt.Errorf("queue state is not INIT and not WAITING: %d", state) + } + + defer func() { + h.updated <- struct{}{} + }() + + // Set up a queue module configuration for an instance. Ideally, this + // should be done before box.cfg({}) or you need to wait some time + // before start a work. + // + // See: + // https://github.com/tarantool/queue/issues/206 + opts := queue.CfgOpts{InReplicaset: true, Ttr: 60 * time.Second} + + if h.err = q.Cfg(opts); h.err != nil { + return fmt.Errorf("unable to configure queue: %w", h.err) + } + + // The queue only works with a master instance. + if !master { + return nil + } + + if !h.registered { + // We register a shared session at the first time. + if h.uuid, h.err = q.Identify(nil); h.err != nil { + return h.err + } + h.registered = true + } else { + // We re-identify as the shared session. + if _, h.err = q.Identify(&h.uuid); h.err != nil { + return h.err + } + } + + if h.err = q.Create(h.cfg); h.err != nil { + return h.err + } + + fmt.Printf("Master %s is ready to work!\n", name) + atomic.AddInt32(&h.masterCnt, 1) + + return nil +} + +// Deactivated doesn't do anything useful for the example. +func (h *QueueConnectionHandler) Deactivated(name string, conn *tarantool.Connection, + role pool.Role) error { + if role == pool.MasterRole { + atomic.AddInt32(&h.masterCnt, -1) + } + return nil +} + +// Closes closes a QueueConnectionHandler object. +func (h *QueueConnectionHandler) Close() { + close(h.updated) +} + +// Example demonstrates how to use the queue package with the pool +// package. First of all, you need to create a ConnectionHandler implementation +// for the a ConnectionPool object to process new connections from +// RW-instances. +// +// You need to register a shared session UUID at a first master connection. +// It needs to be used to re-identify as the shared session on new +// RW-instances. See QueueConnectionHandler.Discovered() implementation. +// +// After that, you need to create a ConnectorAdapter object with RW mode for +// the ConnectionPool to send requests into RW-instances. This adapter can +// be used to create a ready-to-work queue object. +func Example_connectionPool() { + // Create a ConnectionHandler object. + cfg := queue.Cfg{ + Temporary: false, + IfNotExists: true, + Kind: queue.FIFO, + Opts: queue.Opts{ + Ttl: 10 * time.Second, + }, + } + h := NewQueueConnectionHandler("test_queue", cfg) + defer h.Close() + + // Create a ConnectionPool object. + poolServers := []string{"127.0.0.1:3014", "127.0.0.1:3015"} + poolDialers := []tarantool.Dialer{} + poolInstances := []pool.Instance{} + + connOpts := tarantool.Opts{ + Timeout: 5 * time.Second, + } + for _, server := range poolServers { + dialer := tarantool.NetDialer{ + Address: server, + User: "test", + Password: "test", + } + poolDialers = append(poolDialers, dialer) + poolInstances = append(poolInstances, pool.Instance{ + Name: server, + Dialer: dialer, + Opts: connOpts, + }) + } + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + poolOpts := pool.Opts{ + CheckTimeout: 5 * time.Second, + ConnectionHandler: h, + } + connPool, err := pool.ConnectWithOpts(ctx, poolInstances, poolOpts) + if err != nil { + fmt.Printf("Unable to connect to the pool: %s", err) + return + } + defer connPool.Close() + + // Wait for a queue initialization and master instance identification in + // the queue. + <-h.updated + <-h.updated + if h.err != nil { + fmt.Printf("Unable to identify in the pool: %s", h.err) + return + } + + // Create a Queue object from the ConnectionPool object via + // a ConnectorAdapter. + rw := pool.NewConnectorAdapter(connPool, pool.RW) + q := queue.New(rw, "test_queue") + fmt.Println("A Queue object is ready to work.") + + testData := "test_data" + fmt.Println("Send data:", testData) + if _, err = q.Put(testData); err != nil { + fmt.Printf("Unable to put data into the queue: %s", err) + return + } + + // Switch a master instance in the pool. + roles := []bool{true, false} + for { + ctx, cancel := test_helpers.GetPoolConnectContext() + err := test_helpers.SetClusterRO(ctx, poolDialers, connOpts, roles) + cancel() + if err == nil { + break + } + } + + // Wait for a replica instance connection and a new master instance + // re-identification. + <-h.updated + <-h.updated + h.mutex.Lock() + err = h.err + h.mutex.Unlock() + + if err != nil { + fmt.Printf("Unable to re-identify in the pool: %s", err) + return + } + + for i := 0; i < 2 && atomic.LoadInt32(&h.masterCnt) != 1; i++ { + // The pool does not immediately detect role switching. It may happen + // that requests will be sent to RO instances. In that case q.Take() + // method will return a nil value. + // + // We need to make the example test output deterministic so we need to + // avoid it here. But in real life, you need to take this into account. + time.Sleep(poolOpts.CheckTimeout) + } + + for { + // Take a data from the new master instance. + task, err := q.Take() + + if err == pool.ErrNoRwInstance { + // It may be not registered yet by the pool. + continue + } else if err != nil { + fmt.Println("Unable to got task:", err) + } else if task == nil { + fmt.Println("task == nil") + } else if task.Data() == nil { + fmt.Println("task.Data() == nil") + } else { + task.Ack() + fmt.Println("Got data:", task.Data()) + } + break + } + + // Output: + // Master 127.0.0.1:3014 is ready to work! + // A Queue object is ready to work. + // Send data: test_data + // Master 127.0.0.1:3015 is ready to work! + // Got data: test_data +} diff --git a/queue/example_msgpack_test.go b/queue/example_msgpack_test.go new file mode 100644 index 000000000..53e54dc72 --- /dev/null +++ b/queue/example_msgpack_test.go @@ -0,0 +1,152 @@ +// Setup queue module and start Tarantool instance before execution: +// Terminal 1: +// $ make deps +// $ TEST_TNT_LISTEN=3013 tarantool queue/config.lua +// +// Terminal 2: +// $ cd queue +// $ go test -v example_msgpack_test.go +package queue_test + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/queue" +) + +type dummyData struct { + Dummy bool +} + +func (c *dummyData) DecodeMsgpack(d *msgpack.Decoder) error { + var err error + if c.Dummy, err = d.DecodeBool(); err != nil { + return err + } + return nil +} + +func (c *dummyData) EncodeMsgpack(e *msgpack.Encoder) error { + return e.EncodeBool(c.Dummy) +} + +// Example demonstrates an operations like Put and Take with queue and custom +// MsgPack structure. +// +// Features of the implementation: +// +// - If you use the connection timeout and call TakeWithTimeout with a +// parameter greater than the connection timeout, the parameter is reduced to +// it. +// +// - If you use the connection timeout and call Take, we return an error if we +// cannot take the task out of the queue within the time corresponding to the +// connection timeout. +func Example_simpleQueueCustomMsgPack() { + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3013", + User: "test", + Password: "test", + } + opts := tarantool.Opts{ + Reconnect: time.Second, + Timeout: 5 * time.Second, + MaxReconnects: 5, + } + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + conn, err := tarantool.Connect(ctx, dialer, opts) + cancel() + if err != nil { + log.Fatalf("connection: %s", err) + return + } + defer conn.Close() + + cfg := queue.Cfg{ + Temporary: true, + IfNotExists: true, + Kind: queue.FIFO, + Opts: queue.Opts{ + Ttl: 20 * time.Second, + Ttr: 10 * time.Second, + Delay: 6 * time.Second, + Pri: 1, + }, + } + + que := queue.New(conn, "test_queue_msgpack") + if err = que.Create(cfg); err != nil { + fmt.Printf("queue create: %s", err) + return + } + + // Put data. + task, err := que.Put("test_data") + if err != nil { + fmt.Printf("put task: %s", err) + return + } + fmt.Println("Task id is", task.Id()) + + // Take data. + task, err = que.Take() // Blocking operation. + if err != nil { + fmt.Printf("take task: %s", err) + return + } + fmt.Println("Data is", task.Data()) + task.Ack() + + // Take typed example. + putData := dummyData{} + // Put data. + task, err = que.Put(&putData) + if err != nil { + fmt.Printf("put typed task: %s", err) + return + } + fmt.Println("Task id is ", task.Id()) + + takeData := dummyData{} + // Take data. + task, err = que.TakeTyped(&takeData) // Blocking operation. + if err != nil { + fmt.Printf("take take typed: %s", err) + return + } + fmt.Println("Data is ", takeData) + // Same data. + fmt.Println("Data is ", task.Data()) + + task, err = que.Put([]int{1, 2, 3}) + if err != nil { + fmt.Printf("Put failed: %s", err) + return + } + task.Bury() + + task, err = que.TakeTimeout(2 * time.Second) + if err != nil { + fmt.Printf("Take with timeout failed: %s", err) + return + } + if task == nil { + fmt.Println("Task is nil") + } + + que.Drop() + + // Unordered output: + // Task id is 0 + // Data is test_data + // Task id is 0 + // Data is {false} + // Data is &{false} + // Task is nil +} diff --git a/queue/example_test.go b/queue/example_test.go new file mode 100644 index 000000000..99efa769b --- /dev/null +++ b/queue/example_test.go @@ -0,0 +1,92 @@ +// Setup queue module and start Tarantool instance before execution: +// Terminal 1: +// $ make deps +// $ TEST_TNT_LISTEN=3013 tarantool queue/config.lua +// +// Terminal 2: +// $ cd queue +// $ go test -v example_test.go +package queue_test + +import ( + "context" + "fmt" + "time" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/queue" +) + +// Example demonstrates an operations like Put and Take with queue. +func Example_simpleQueue() { + cfg := queue.Cfg{ + Temporary: false, + Kind: queue.FIFO, + Opts: queue.Opts{ + Ttl: 10 * time.Second, + }, + } + opts := tarantool.Opts{ + Timeout: 2500 * time.Millisecond, + } + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3013", + User: "test", + Password: "test", + } + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, dialer, opts) + if err != nil { + fmt.Printf("error in prepare is %v", err) + return + } + defer conn.Close() + + q := queue.New(conn, "test_queue") + if err := q.Create(cfg); err != nil { + fmt.Printf("error in queue is %v", err) + return + } + + defer q.Drop() + + testData_1 := "test_data_1" + if _, err = q.Put(testData_1); err != nil { + fmt.Printf("error in put is %v", err) + return + } + + testData_2 := "test_data_2" + task_2, err := q.PutWithOpts(testData_2, queue.Opts{Ttl: 2 * time.Second}) + if err != nil { + fmt.Printf("error in put with config is %v", err) + return + } + + task, err := q.Take() + if err != nil { + fmt.Printf("error in take with is %v", err) + return + } + task.Ack() + fmt.Println("data_1: ", task.Data()) + + err = task_2.Bury() + if err != nil { + fmt.Printf("error in bury with is %v", err) + return + } + + task, err = q.TakeTimeout(2 * time.Second) + if err != nil { + fmt.Printf("error in take with timeout") + } + if task != nil { + fmt.Printf("Task should be nil, but %d", task.Id()) + return + } + + // Output: data_1: test_data_1 +} diff --git a/queue/queue.go b/queue/queue.go new file mode 100644 index 000000000..c8f968dff --- /dev/null +++ b/queue/queue.go @@ -0,0 +1,494 @@ +// Package with implementation of methods for work with a Tarantool's queue +// implementations. +// +// Since: 1.5. +// +// # See also +// +// * Tarantool queue module https://github.com/tarantool/queue +package queue + +import ( + "fmt" + "time" + + "github.com/google/uuid" + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// Queue is a handle to Tarantool queue's tube. +type Queue interface { + // Set queue settings. + Cfg(opts CfgOpts) error + // Exists checks tube for existence. + // Note: it uses Eval, so user needs 'execute universe' privilege. + Exists() (bool, error) + // Identify to a shared session. + // In the queue the session has a unique UUID and many connections may + // share one logical session. Also, the consumer can reconnect to the + // existing session during the ttr time. + // To get the UUID of the current session, call the Queue.Identify(nil). + Identify(u *uuid.UUID) (uuid.UUID, error) + // Create creates new tube with configuration. + // Note: it uses Eval, so user needs 'execute universe' privilege + // Note: you'd better not use this function in your application, cause it is + // administrative task to create or delete queue. + Create(cfg Cfg) error + // Drop destroys tube. + // Note: you'd better not use this function in your application, cause it is + // administrative task to create or delete queue. + Drop() error + // ReleaseAll forcibly returns all taken tasks to a ready state. + ReleaseAll() error + // Put creates new task in a tube. + Put(data interface{}) (*Task, error) + // PutWithOpts creates new task with options different from tube's defaults. + PutWithOpts(data interface{}, cfg Opts) (*Task, error) + // Take takes 'ready' task from a tube and marks it as 'in progress'. + // Note: if connection has a request Timeout, then 0.9 * connection.Timeout is + // used as a timeout. + // If you use a connection timeout and we can not take task from queue in + // a time equal to the connection timeout after calling `Take` then we + // return an error. + Take() (*Task, error) + // TakeTimeout takes 'ready' task from a tube and marks it as "in progress", + // or it is timeouted after "timeout" period. + // Note: if connection has a request Timeout, and conn.Timeout * 0.9 < timeout + // then timeout = conn.Timeout*0.9. + // If you use connection timeout and call `TakeTimeout` with parameter + // greater than the connection timeout then parameter reduced to it. + TakeTimeout(timeout time.Duration) (*Task, error) + // TakeTyped takes 'ready' task from a tube and marks it as 'in progress' + // Note: if connection has a request Timeout, then 0.9 * connection.Timeout is + // used as a timeout. + // Data will be unpacked to result. + TakeTyped(interface{}) (*Task, error) + // TakeTypedTimeout takes 'ready' task from a tube and marks it as "in progress", + // or it is timeouted after "timeout" period. + // Note: if connection has a request Timeout, and conn.Timeout * 0.9 < timeout + // then timeout = conn.Timeout*0.9. + // Data will be unpacked to result. + TakeTypedTimeout(timeout time.Duration, result interface{}) (*Task, error) + // Peek returns task by its id. + Peek(taskId uint64) (*Task, error) + // Kick reverts effect of Task.Bury() for count tasks. + Kick(count uint64) (uint64, error) + // Delete the task identified by its id. + Delete(taskId uint64) error + // State returns a current queue state. + State() (State, error) + // Statistic returns some statistic about queue. + Statistic() (interface{}, error) +} + +type queue struct { + name string + conn tarantool.Connector + cmds cmd +} + +type cmd struct { + put string + take string + drop string + peek string + touch string + ack string + delete string + bury string + kick string + release string + releaseAll string + cfg string + identify string + state string + statistics string +} + +type Cfg struct { + Temporary bool // If true, the contents do not persist on disk. + IfNotExists bool // If true, no error will be returned if the tube already exists. + Kind queueType + Opts +} + +func (cfg Cfg) toMap() map[string]interface{} { + res := cfg.Opts.toMap() + res["temporary"] = cfg.Temporary + res["if_not_exists"] = cfg.IfNotExists + return res +} + +func (cfg Cfg) getType() string { + kind := string(cfg.Kind) + if kind == "" { + kind = string(FIFO) + } + + return kind +} + +// CfgOpts is argument type for the Queue.Cfg() call. +type CfgOpts struct { + // Enable replication mode. Must be true if the queue is used in master and + // replica mode. With replication mode enabled, the potential loss of + // performance can be ~20% compared to single mode. Default value is false. + InReplicaset bool + // Time to release in seconds. The time after which, if there is no active + // connection in the session, it will be released with all its tasks. + Ttr time.Duration +} + +func (opts CfgOpts) toMap() map[string]interface{} { + ret := make(map[string]interface{}) + ret["in_replicaset"] = opts.InReplicaset + if opts.Ttr != 0 { + ret["ttr"] = opts.Ttr.Seconds() + } + return ret +} + +type Opts struct { + Pri int // Task priorities. + Ttl time.Duration // Task time to live. + Ttr time.Duration // Task time to execute. + Delay time.Duration // Delayed execution. + Utube string +} + +func (opts Opts) toMap() map[string]interface{} { + ret := make(map[string]interface{}) + + if opts.Ttl.Seconds() != 0 { + ret["ttl"] = opts.Ttl.Seconds() + } + + if opts.Ttr.Seconds() != 0 { + ret["ttr"] = opts.Ttr.Seconds() + } + + if opts.Delay.Seconds() != 0 { + ret["delay"] = opts.Delay.Seconds() + } + + if opts.Pri != 0 { + ret["pri"] = opts.Pri + } + + if opts.Utube != "" { + ret["utube"] = opts.Utube + } + + return ret +} + +// New creates a queue handle. +func New(conn tarantool.Connector, name string) Queue { + q := &queue{ + name: name, + conn: conn, + } + makeCmd(q) + return q +} + +// Create creates a new queue with config. +func (q *queue) Create(cfg Cfg) error { + cmd := "local name, type, cfg = ... ; queue.create_tube(name, type, cfg)" + _, err := q.conn.Do(tarantool.NewEvalRequest(cmd). + Args([]interface{}{q.name, cfg.getType(), cfg.toMap()}), + ).Get() + return err +} + +// Set queue settings. +func (q *queue) Cfg(opts CfgOpts) error { + req := tarantool.NewCallRequest(q.cmds.cfg).Args([]interface{}{opts.toMap()}) + _, err := q.conn.Do(req).Get() + return err +} + +// Exists checks existence of a tube. +func (q *queue) Exists() (bool, error) { + cmd := "local name = ... ; return queue.tube[name] ~= nil" + data, err := q.conn.Do(tarantool.NewEvalRequest(cmd). + Args([]string{q.name}), + ).Get() + if err != nil { + return false, err + } + + exist := len(data) != 0 && data[0].(bool) + return exist, nil +} + +// Identify to a shared session. +// In the queue the session has a unique UUID and many connections may share +// one logical session. Also, the consumer can reconnect to the existing +// session during the ttr time. +// To get the UUID of the current session, call the Queue.Identify(nil). +func (q *queue) Identify(u *uuid.UUID) (uuid.UUID, error) { + // Unfortunately we can't use go-tarantool/uuid here: + // https://github.com/tarantool/queue/issues/182 + var args []interface{} + if u == nil { + args = []interface{}{} + } else { + if bytes, err := u.MarshalBinary(); err != nil { + return uuid.UUID{}, err + } else { + args = []interface{}{string(bytes)} + } + } + + req := tarantool.NewCallRequest(q.cmds.identify).Args(args) + if data, err := q.conn.Do(req).Get(); err == nil { + if us, ok := data[0].(string); ok { + return uuid.FromBytes([]byte(us)) + } else { + return uuid.UUID{}, fmt.Errorf("unexpected response: %v", data) + } + } else { + return uuid.UUID{}, err + } +} + +// Put data to queue. Returns task. +func (q *queue) Put(data interface{}) (*Task, error) { + return q.put(data) +} + +// Put data with options (ttl/ttr/pri/delay) to queue. Returns task. +func (q *queue) PutWithOpts(data interface{}, cfg Opts) (*Task, error) { + return q.put(data, cfg.toMap()) +} + +func (q *queue) put(params ...interface{}) (*Task, error) { + qd := queueData{ + result: params[0], + q: q, + } + req := tarantool.NewCallRequest(q.cmds.put).Args(params) + if err := q.conn.Do(req).GetTyped(&qd); err != nil { + return nil, err + } + return qd.task, nil +} + +// The take request searches for a task in the queue. +func (q *queue) Take() (*Task, error) { + var params interface{} + timeout := q.conn.ConfiguredTimeout() + if timeout > 0 { + params = (timeout * 9 / 10).Seconds() + } + return q.take(params) +} + +// The take request searches for a task in the queue. Waits until a task +// becomes ready or the timeout expires. +func (q *queue) TakeTimeout(timeout time.Duration) (*Task, error) { + t := q.conn.ConfiguredTimeout() * 9 / 10 + if t > 0 && timeout > t { + timeout = t + } + return q.take(timeout.Seconds()) +} + +// The take request searches for a task in the queue. +func (q *queue) TakeTyped(result interface{}) (*Task, error) { + var params interface{} + timeout := q.conn.ConfiguredTimeout() + if timeout > 0 { + params = (timeout * 9 / 10).Seconds() + } + return q.take(params, result) +} + +// The take request searches for a task in the queue. Waits until a task +// becomes ready or the timeout expires. +func (q *queue) TakeTypedTimeout(timeout time.Duration, result interface{}) (*Task, error) { + t := q.conn.ConfiguredTimeout() * 9 / 10 + if t > 0 && timeout > t { + timeout = t + } + return q.take(timeout.Seconds(), result) +} + +func (q *queue) take(params interface{}, result ...interface{}) (*Task, error) { + qd := queueData{q: q} + if len(result) > 0 { + qd.result = result[0] + } + req := tarantool.NewCallRequest(q.cmds.take).Args([]interface{}{params}) + if err := q.conn.Do(req).GetTyped(&qd); err != nil { + return nil, err + } + return qd.task, nil +} + +// Drop queue. +func (q *queue) Drop() error { + _, err := q.conn.Do(tarantool.NewCallRequest(q.cmds.drop)).Get() + return err +} + +// ReleaseAll forcibly returns all taken tasks to a ready state. +func (q *queue) ReleaseAll() error { + _, err := q.conn.Do(tarantool.NewCallRequest(q.cmds.releaseAll)).Get() + return err +} + +// Look at a task without changing its state. +func (q *queue) Peek(taskId uint64) (*Task, error) { + qd := queueData{q: q} + req := tarantool.NewCallRequest(q.cmds.peek).Args([]interface{}{taskId}) + if err := q.conn.Do(req).GetTyped(&qd); err != nil { + return nil, err + } + return qd.task, nil +} + +func (q *queue) _touch(taskId uint64, increment time.Duration) (string, error) { + return q.produce(q.cmds.touch, taskId, increment.Seconds()) +} + +func (q *queue) _ack(taskId uint64) (string, error) { + return q.produce(q.cmds.ack, taskId) +} + +func (q *queue) _delete(taskId uint64) (string, error) { + return q.produce(q.cmds.delete, taskId) +} + +func (q *queue) _bury(taskId uint64) (string, error) { + return q.produce(q.cmds.bury, taskId) +} + +func (q *queue) _release(taskId uint64, cfg Opts) (string, error) { + return q.produce(q.cmds.release, taskId, cfg.toMap()) +} +func (q *queue) produce(cmd string, params ...interface{}) (string, error) { + qd := queueData{q: q} + req := tarantool.NewCallRequest(cmd).Args(params) + if err := q.conn.Do(req).GetTyped(&qd); err != nil || qd.task == nil { + return "", err + } + return qd.task.status, nil +} + +type kickResult struct { + id uint64 +} + +func (r *kickResult) DecodeMsgpack(d *msgpack.Decoder) (err error) { + var l int + if l, err = d.DecodeArrayLen(); err != nil { + return err + } + if l > 1 { + return fmt.Errorf("array len doesn't match for queue kick data: %d", l) + } + r.id, err = d.DecodeUint64() + return +} + +// Reverse the effect of a bury request on one or more tasks. +func (q *queue) Kick(count uint64) (uint64, error) { + var r kickResult + req := tarantool.NewCallRequest(q.cmds.kick).Args([]interface{}{count}) + err := q.conn.Do(req).GetTyped(&r) + return r.id, err +} + +// Delete the task identified by its id. +func (q *queue) Delete(taskId uint64) error { + _, err := q._delete(taskId) + return err +} + +// State returns a current queue state. +func (q *queue) State() (State, error) { + data, err := q.conn.Do(tarantool.NewCallRequest(q.cmds.state)).Get() + if err != nil { + return UnknownState, err + } + + if respState, ok := data[0].(string); ok { + if state, ok := strToState[respState]; ok { + return state, nil + } + return UnknownState, fmt.Errorf("unknown state: %v", data[0]) + } + return UnknownState, fmt.Errorf("unexpected response: %v", data) +} + +// Return the number of tasks in a queue broken down by task_state, and the +// number of requests broken down by the type of request. +func (q *queue) Statistic() (interface{}, error) { + req := tarantool.NewCallRequest(q.cmds.statistics).Args([]interface{}{q.name}) + data, err := q.conn.Do(req).Get() + if err != nil { + return nil, err + } + + if len(data) != 0 { + return data[0], nil + } + + return nil, nil +} + +func makeCmd(q *queue) { + q.cmds = cmd{ + put: "queue.tube." + q.name + ":put", + take: "queue.tube." + q.name + ":take", + drop: "queue.tube." + q.name + ":drop", + peek: "queue.tube." + q.name + ":peek", + touch: "queue.tube." + q.name + ":touch", + ack: "queue.tube." + q.name + ":ack", + delete: "queue.tube." + q.name + ":delete", + bury: "queue.tube." + q.name + ":bury", + kick: "queue.tube." + q.name + ":kick", + release: "queue.tube." + q.name + ":release", + releaseAll: "queue.tube." + q.name + ":release_all", + cfg: "queue.cfg", + identify: "queue.identify", + state: "queue.state", + statistics: "queue.statistics", + } +} + +type queueData struct { + q *queue + task *Task + result interface{} +} + +func (qd *queueData) DecodeMsgpack(d *msgpack.Decoder) error { + var err error + var l int + if l, err = d.DecodeArrayLen(); err != nil { + return err + } + if l > 1 { + return fmt.Errorf("array len doesn't match for queue data: %d", l) + } + if l == 0 { + return nil + } + + qd.task = &Task{data: qd.result, q: qd.q} + if err = d.Decode(&qd.task); err != nil { + return err + } + + if qd.task.Data() == nil { + // It may happen if the msgpack.Decoder has a code.Nil value inside. As a + // result, the task will not be decoded. + qd.task = nil + } + return nil +} diff --git a/queue/queue_test.go b/queue/queue_test.go new file mode 100644 index 000000000..81f768e18 --- /dev/null +++ b/queue/queue_test.go @@ -0,0 +1,976 @@ +package queue_test + +import ( + "fmt" + "log" + "math" + "os" + "testing" + "time" + + "github.com/vmihailenco/msgpack/v5" + + . "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/queue" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +const ( + user = "test" + pass = "test" +) + +var servers = []string{"127.0.0.1:3014", "127.0.0.1:3015"} +var server = "127.0.0.1:3013" + +var dialer = NetDialer{ + Address: server, + User: user, + Password: pass, +} + +var opts = Opts{ + Timeout: 5 * time.Second, + // Concurrency: 32, + // RateLimit: 4*1024, +} + +func createQueue(t *testing.T, conn *Connection, name string, cfg queue.Cfg) queue.Queue { + t.Helper() + + q := queue.New(conn, name) + if err := q.Create(cfg); err != nil { + t.Fatalf("Failed to create queue: %s", err) + } + + return q +} + +func dropQueue(t *testing.T, q queue.Queue) { + t.Helper() + + if err := q.Drop(); err != nil { + t.Fatalf("Failed to drop queue: %s", err) + } +} + +// ///////QUEUE///////// + +func TestFifoQueue(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + name := "test_queue" + q := createQueue(t, conn, name, queue.Cfg{Temporary: true, Kind: queue.FIFO}) + defer dropQueue(t, q) +} + +func TestQueue_Cfg(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + name := "test_queue" + q := createQueue(t, conn, name, queue.Cfg{Temporary: true, Kind: queue.FIFO}) + defer dropQueue(t, q) + + err := q.Cfg(queue.CfgOpts{InReplicaset: false, Ttr: 5 * time.Second}) + if err != nil { + t.Fatalf("Unexpected q.Cfg() error: %s", err) + } +} + +func TestQueue_Identify(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + name := "test_queue" + q := createQueue(t, conn, name, queue.Cfg{Temporary: true, Kind: queue.FIFO}) + defer dropQueue(t, q) + + uuid, err := q.Identify(nil) + if err != nil { + t.Fatalf("Failed to identify: %s", err) + } + cpy := uuid + + uuid, err = q.Identify(&cpy) + if err != nil { + t.Fatalf("Failed to identify with uuid %s: %s", cpy, err) + } + if cpy.String() != uuid.String() { + t.Fatalf("Unequal UUIDs after re-identify: %s, expected %s", uuid, cpy) + } +} + +func TestQueue_ReIdentify(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer func() { + if conn != nil { + conn.Close() + } + }() + + name := "test_queue" + cfg := queue.Cfg{ + Temporary: true, + Kind: queue.FIFO_TTL, + Opts: queue.Opts{Ttl: 5 * time.Second}, + } + q := createQueue(t, conn, name, cfg) + q.Cfg(queue.CfgOpts{InReplicaset: false, Ttr: 5 * time.Second}) + defer func() { + dropQueue(t, q) + }() + + uuid, err := q.Identify(nil) + if err != nil { + t.Fatalf("Failed to identify: %s", err) + } + newuuid, err := q.Identify(&uuid) + if err != nil { + t.Fatalf("Failed to identify: %s", err) + } + if newuuid.String() != uuid.String() { + t.Fatalf("Unequal UUIDs after re-identify: %s, expected %s", newuuid, uuid) + } + // Put. + putData := "put_data" + task, err := q.Put(putData) + if err != nil { + conn.Close() + t.Fatalf("Failed put to queue: %s", err) + } else if err == nil && task == nil { + t.Fatalf("Task is nil after put") + } else if task.Data() != putData { + t.Errorf("Task data after put not equal with example. %s != %s", task.Data(), putData) + } + + // Take. + task, err = q.TakeTimeout(2 * time.Second) + if err != nil { + t.Fatalf("Failed take from queue: %s", err) + } else if task == nil { + t.Fatalf("Task is nil after take") + } + + conn.Close() + conn = nil + + conn = test_helpers.ConnectWithValidation(t, dialer, opts) + q = queue.New(conn, name) + + // Identify in another connection. + newuuid, err = q.Identify(&uuid) + if err != nil { + t.Fatalf("Failed to identify: %s", err) + } + if newuuid.String() != uuid.String() { + t.Fatalf("Unequal UUIDs after re-identify: %s, expected %s", newuuid, uuid) + } + + // Peek in another connection. + task, err = q.Peek(task.Id()) + if err != nil { + t.Fatalf("Failed take from queue: %s", err) + } else if task == nil { + t.Fatalf("Task is nil after take") + } + + // Ack in another connection. + err = task.Ack() + if err != nil { + t.Errorf("Failed ack %s", err) + } else if !task.IsDone() { + t.Errorf("Task status after take is not done. Status = %s", task.Status()) + } +} + +func TestQueue_State(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + name := "test_queue" + q := createQueue(t, conn, name, queue.Cfg{Temporary: true, Kind: queue.FIFO}) + defer dropQueue(t, q) + + state, err := q.State() + if err != nil { + t.Fatalf("Failed to get queue state: %s", err) + } + if state != queue.InitState && state != queue.RunningState { + t.Fatalf("Unexpected state: %d", state) + } +} + +func TestFifoQueue_GetExist_Statistic(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + name := "test_queue" + q := createQueue(t, conn, name, queue.Cfg{Temporary: true, Kind: queue.FIFO}) + defer dropQueue(t, q) + + ok, err := q.Exists() + if err != nil { + t.Fatalf("Failed to get exist queue: %s", err) + } + if !ok { + t.Fatal("Queue is not found") + } + + putData := "put_data" + _, err = q.Put(putData) + if err != nil { + t.Fatalf("Failed to put queue: %s", err) + } + + stat, err := q.Statistic() + if err != nil { + t.Errorf("Failed to get statistic queue: %s", err) + } else if stat == nil { + t.Error("Statistic is nil") + } +} + +func TestFifoQueue_Put(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + name := "test_queue" + q := createQueue(t, conn, name, queue.Cfg{Temporary: true, Kind: queue.FIFO}) + defer dropQueue(t, q) + + // Put. + putData := "put_data" + task, err := q.Put(putData) + if err != nil { + t.Fatalf("Failed put to queue: %s", err) + } else if err == nil && task == nil { + t.Fatalf("Task is nil after put") + } else if task.Data() != putData { + t.Errorf("Task data after put not equal with example. %s != %s", task.Data(), putData) + } +} + +func TestFifoQueue_Take(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + name := "test_queue" + q := createQueue(t, conn, name, queue.Cfg{Temporary: true, Kind: queue.FIFO}) + defer dropQueue(t, q) + + // Put. + putData := "put_data" + task, err := q.Put(putData) + if err != nil { + t.Fatalf("Failed put to queue: %s", err) + } else if err == nil && task == nil { + t.Fatalf("Task is nil after put") + } else if task.Data() != putData { + t.Errorf("Task data after put not equal with example. %s != %s", task.Data(), putData) + } + + // Take. + task, err = q.TakeTimeout(2 * time.Second) + if err != nil { + t.Errorf("Failed take from queue: %s", err) + } else if task == nil { + t.Errorf("Task is nil after take") + } else { + if task.Data() != putData { + t.Errorf("Task data after take not equal with example. %s != %s", task.Data(), putData) + } + + if !task.IsTaken() { + t.Errorf("Task status after take is not taken. Status = %s", task.Status()) + + } + + err = task.Ack() + if err != nil { + t.Errorf("Failed ack %s", err) + } else if !task.IsDone() { + t.Errorf("Task status after take is not done. Status = %s", task.Status()) + } + } +} + +type customData struct { + customField string +} + +func (c *customData) DecodeMsgpack(d *msgpack.Decoder) error { + var err error + var l int + if l, err = d.DecodeArrayLen(); err != nil { + return err + } + if l != 1 { + return fmt.Errorf("array len doesn't match: %d", l) + } + if c.customField, err = d.DecodeString(); err != nil { + return err + } + return nil +} + +func (c *customData) EncodeMsgpack(e *msgpack.Encoder) error { + if err := e.EncodeArrayLen(1); err != nil { + return err + } + if err := e.EncodeString(c.customField); err != nil { + return err + } + return nil +} + +func TestFifoQueue_TakeTyped(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + name := "test_queue" + q := createQueue(t, conn, name, queue.Cfg{Temporary: true, Kind: queue.FIFO}) + defer dropQueue(t, q) + + // Put. + putData := &customData{customField: "put_data"} + task, err := q.Put(putData) + if err != nil { + t.Fatalf("Failed put to queue: %s", err) + } else if err == nil && task == nil { + t.Fatalf("Task is nil after put") + } else { + typedData, ok := task.Data().(*customData) + if !ok { + t.Errorf("Task data after put has different type. %#v != %#v", + task.Data(), putData) + } + if *typedData != *putData { + t.Errorf("Task data after put not equal with example. %s != %s", + task.Data(), putData) + } + } + + // Take. + takeData := &customData{} + task, err = q.TakeTypedTimeout(2*time.Second, takeData) + if err != nil { + t.Errorf("Failed take from queue: %s", err) + } else if task == nil { + t.Errorf("Task is nil after take") + } else { + typedData, ok := task.Data().(*customData) + if !ok { + t.Errorf("Task data after put has different type. %#v != %#v", + task.Data(), putData) + } + if *typedData != *putData { + t.Errorf("Task data after take not equal with example. %#v != %#v", + task.Data(), putData) + } + if *takeData != *putData { + t.Errorf("Task data after take not equal with example. %#v != %#v", + task.Data(), putData) + } + if !task.IsTaken() { + t.Errorf("Task status after take is not taken. Status = %s", + task.Status()) + } + + err = task.Ack() + if err != nil { + t.Errorf("Failed ack %s", err) + } else if !task.IsDone() { + t.Errorf("Task status after take is not done. Status = %s", + task.Status()) + } + } +} + +func TestFifoQueue_Peek(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + name := "test_queue" + q := createQueue(t, conn, name, queue.Cfg{Temporary: true, Kind: queue.FIFO}) + defer dropQueue(t, q) + + // Put. + putData := "put_data" + task, err := q.Put(putData) + if err != nil { + t.Fatalf("Failed put to queue: %s", err) + } else if err == nil && task == nil { + t.Fatalf("Task is nil after put") + } else if task.Data() != putData { + t.Errorf("Task data after put not equal with example. %s != %s", task.Data(), putData) + } + + // Peek. + task, err = q.Peek(task.Id()) + if err != nil { + t.Errorf("Failed peek from queue: %s", err) + } else if task == nil { + t.Errorf("Task is nil after peek") + } else if task.Data() != putData { + t.Errorf("Task data after peek not equal with example. %s != %s", task.Data(), putData) + } else if !task.IsReady() { + t.Errorf("Task status after peek is not ready. Status = %s", task.Status()) + } +} + +func TestFifoQueue_Bury_Kick(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + name := "test_queue" + q := createQueue(t, conn, name, queue.Cfg{Temporary: true, Kind: queue.FIFO}) + defer dropQueue(t, q) + + // Put. + putData := "put_data" + task, err := q.Put(putData) + if err != nil { + t.Fatalf("Failed put to queue: %s", err) + } else if err == nil && task == nil { + t.Fatalf("Task is nil after put") + } else if task.Data() != putData { + t.Errorf("Task data after put not equal with example. %s != %s", task.Data(), putData) + } + + // Bury. + err = task.Bury() + if err != nil { + t.Fatalf("Failed bury task %s", err) + } else if !task.IsBuried() { + t.Errorf("Task status after bury is not buried. Status = %s", task.Status()) + } + + // Kick. + count, err := q.Kick(1) + if err != nil { + t.Fatalf("Failed kick task %s", err) + } else if count != 1 { + t.Fatalf("Kick result != 1") + } + + // Take. + task, err = q.TakeTimeout(2 * time.Second) + if err != nil { + t.Errorf("Failed take from queue: %s", err) + } else if task == nil { + t.Errorf("Task is nil after take") + } else { + if task.Data() != putData { + t.Errorf("Task data after take not equal with example. %s != %s", task.Data(), putData) + } + + if !task.IsTaken() { + t.Errorf("Task status after take is not taken. Status = %s", task.Status()) + } + + err = task.Ack() + if err != nil { + t.Errorf("Failed ack %s", err) + } else if !task.IsDone() { + t.Errorf("Task status after take is not done. Status = %s", task.Status()) + } + } +} + +func TestFifoQueue_Delete(t *testing.T) { + var err error + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + name := "test_queue" + q := createQueue(t, conn, name, queue.Cfg{Temporary: true, Kind: queue.FIFO}) + defer dropQueue(t, q) + + // Put. + var putData = "put_data" + var tasks = [2]*queue.Task{} + + for i := 0; i < 2; i++ { + tasks[i], err = q.Put(putData) + if err != nil { + t.Fatalf("Failed put to queue: %s", err) + } else if err == nil && tasks[i] == nil { + t.Fatalf("Task is nil after put") + } else if tasks[i].Data() != putData { + t.Errorf( + "Task data after put not equal with example. %s != %s", + tasks[i].Data(), putData) + } + } + + // Delete by task method. + err = tasks[0].Delete() + if err != nil { + t.Fatalf("Failed bury task %s", err) + } else if !tasks[0].IsDone() { + t.Errorf("Task status after delete is not done. Status = %s", tasks[0].Status()) + } + + // Delete by task ID. + err = q.Delete(tasks[1].Id()) + if err != nil { + t.Fatalf("Failed bury task %s", err) + } else if !tasks[0].IsDone() { + t.Errorf("Task status after delete is not done. Status = %s", tasks[0].Status()) + } + + // Take. + for i := 0; i < 2; i++ { + tasks[i], err = q.TakeTimeout(2 * time.Second) + if err != nil { + t.Errorf("Failed take from queue: %s", err) + } else if tasks[i] != nil { + t.Errorf("Task is not nil after take. Task is %d", tasks[i].Id()) + } + } +} + +func TestFifoQueue_Release(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + name := "test_queue" + q := createQueue(t, conn, name, queue.Cfg{Temporary: true, Kind: queue.FIFO}) + defer dropQueue(t, q) + + putData := "put_data" + task, err := q.Put(putData) + if err != nil { + t.Fatalf("Failed put to queue: %s", err) + } else if err == nil && task == nil { + t.Fatalf("Task is nil after put") + } else if task.Data() != putData { + t.Errorf("Task data after put not equal with example. %s != %s", task.Data(), putData) + } + + // Take. + task, err = q.Take() + if err != nil { + t.Fatalf("Failed take from queue: %s", err) + } else if task == nil { + t.Fatal("Task is nil after take") + } + + // Release. + err = task.Release() + if err != nil { + t.Fatalf("Failed release task %s", err) + } + + if !task.IsReady() { + t.Fatalf("Task status is not ready, but %s", task.Status()) + } + + // Take. + task, err = q.Take() + if err != nil { + t.Fatalf("Failed take from queue: %s", err) + } else if task == nil { + t.Fatal("Task is nil after take") + } else { + if task.Data() != putData { + t.Errorf("Task data after take not equal with example. %s != %s", task.Data(), putData) + } + + if !task.IsTaken() { + t.Errorf("Task status after take is not taken. Status = %s", task.Status()) + } + + err = task.Ack() + if err != nil { + t.Errorf("Failed ack %s", err) + } else if !task.IsDone() { + t.Errorf("Task status after take is not done. Status = %s", task.Status()) + } + } +} + +func TestQueue_ReleaseAll(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + name := "test_queue" + q := createQueue(t, conn, name, queue.Cfg{Temporary: true, Kind: queue.FIFO}) + defer dropQueue(t, q) + + putData := "put_data" + task, err := q.Put(putData) + if err != nil { + t.Fatalf("Failed put to queue: %s", err) + } else if err == nil && task == nil { + t.Fatalf("Task is nil after put") + } else if task.Data() != putData { + t.Errorf("Task data after put not equal with example. %s != %s", task.Data(), putData) + } + + // Take. + task, err = q.Take() + if err != nil { + t.Fatalf("Failed take from queue: %s", err) + } else if task == nil { + t.Fatal("Task is nil after take") + } + + // ReleaseAll. + err = q.ReleaseAll() + if err != nil { + t.Fatalf("Failed release task %s", err) + } + + task, err = q.Peek(task.Id()) + if err != nil { + t.Fatalf("Failed to peek task %s", err) + } + if !task.IsReady() { + t.Fatalf("Task status is not ready, but %s", task.Status()) + } + + // Take. + task, err = q.Take() + if err != nil { + t.Fatalf("Failed take from queue: %s", err) + } else if task == nil { + t.Fatal("Task is nil after take") + } else { + if task.Data() != putData { + t.Errorf("Task data after take not equal with example. %s != %s", task.Data(), putData) + } + + if !task.IsTaken() { + t.Errorf("Task status after take is not taken. Status = %s", task.Status()) + } + + err = task.Ack() + if err != nil { + t.Errorf("Failed ack %s", err) + } else if !task.IsDone() { + t.Errorf("Task status after take is not done. Status = %s", task.Status()) + } + } +} + +func TestTtlQueue(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + name := "test_queue" + cfg := queue.Cfg{ + Temporary: true, + Kind: queue.FIFO_TTL, + Opts: queue.Opts{Ttl: 5 * time.Second}, + } + q := createQueue(t, conn, name, cfg) + defer dropQueue(t, q) + + putData := "put_data" + task, err := q.Put(putData) + if err != nil { + t.Fatalf("Failed put to queue: %s", err) + } else if err == nil && task == nil { + t.Fatalf("Task is nil after put") + } else if task.Data() != putData { + t.Errorf("Task data after put not equal with example. %s != %s", task.Data(), putData) + } + + time.Sleep(10 * time.Second) + + // Take. + task, err = q.TakeTimeout(2 * time.Second) + if err != nil { + t.Errorf("Failed take from queue: %s", err) + } else if task != nil { + t.Errorf("Task is not nil after sleep") + } +} + +func TestTtlQueue_Put(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + name := "test_queue" + cfg := queue.Cfg{ + Temporary: true, + Kind: queue.FIFO_TTL, + Opts: queue.Opts{Ttl: 5 * time.Second}, + } + q := createQueue(t, conn, name, cfg) + defer dropQueue(t, q) + + putData := "put_data" + task, err := q.PutWithOpts(putData, queue.Opts{Ttl: 10 * time.Second}) + if err != nil { + t.Fatalf("Failed put to queue: %s", err) + } else if err == nil && task == nil { + t.Fatalf("Task is nil after put") + } else if task.Data() != putData { + t.Errorf("Task data after put not equal with example. %s != %s", task.Data(), putData) + } + + time.Sleep(5 * time.Second) + + // Take. + task, err = q.TakeTimeout(2 * time.Second) + if err != nil { + t.Errorf("Failed take from queue: %s", err) + } else if task == nil { + t.Errorf("Task is nil after sleep") + } else { + if task.Data() != putData { + t.Errorf("Task data after take not equal with example. %s != %s", task.Data(), putData) + } + + if !task.IsTaken() { + t.Errorf("Task status after take is not taken. Status = %s", task.Status()) + } + + err = task.Ack() + if err != nil { + t.Errorf("Failed ack %s", err) + } else if !task.IsDone() { + t.Errorf("Task status after take is not done. Status = %s", task.Status()) + } + } +} + +func TestUtube_Put(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + name := "test_utube" + cfg := queue.Cfg{ + Temporary: true, + Kind: queue.UTUBE, + IfNotExists: true, + } + q := createQueue(t, conn, name, cfg) + defer dropQueue(t, q) + + data1 := &customData{"test-data-0"} + _, err := q.PutWithOpts(data1, queue.Opts{Utube: "test-utube-consumer-key"}) + if err != nil { + t.Fatalf("Failed put task to queue: %s", err) + } + data2 := &customData{"test-data-1"} + _, err = q.PutWithOpts(data2, queue.Opts{Utube: "test-utube-consumer-key"}) + if err != nil { + t.Fatalf("Failed put task to queue: %s", err) + } + + errChan := make(chan struct{}) + go func() { + t1, err := q.TakeTimeout(2 * time.Second) + if err != nil { + t.Errorf("Failed to take task from utube: %s", err) + errChan <- struct{}{} + return + } + + time.Sleep(2 * time.Second) + if err := t1.Ack(); err != nil { + t.Errorf("Failed to ack task: %s", err) + errChan <- struct{}{} + return + } + close(errChan) + }() + + time.Sleep(500 * time.Millisecond) + // the queue should be blocked for ~2 seconds + start := time.Now() + t2, err := q.TakeTimeout(2 * time.Second) + if err != nil { + <-errChan + t.Fatalf("Failed to take task from utube: %s", err) + } + + if t2 == nil { + <-errChan + t.Fatalf("Got nil task") + } + + if err := t2.Ack(); err != nil { + <-errChan + t.Fatalf("Failed to ack task: %s", err) + } + end := time.Now() + if _, ok := <-errChan; ok { + t.Fatalf("One of tasks failed") + } + + timeSpent := math.Abs(float64(end.Sub(start) - 2*time.Second)) + + if timeSpent > float64(700*time.Millisecond) { + t.Fatalf("Blocking time is less than expected: actual = %.2fs, expected = 1s", + end.Sub(start).Seconds()) + } +} + +func TestTask_Touch(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + tests := []struct { + name string + cfg queue.Cfg + ok bool + }{ + {"test_queue", + queue.Cfg{ + Temporary: true, + Kind: queue.FIFO, + }, + false, + }, + {"test_queue_ttl", + queue.Cfg{ + Temporary: true, + Kind: queue.FIFO_TTL, + Opts: queue.Opts{Ttl: 5 * time.Second}, + }, + true, + }, + {"test_utube", + queue.Cfg{ + Temporary: true, + Kind: queue.UTUBE, + }, + false, + }, + {"test_utube_ttl", + queue.Cfg{ + Temporary: true, + Kind: queue.UTUBE_TTL, + Opts: queue.Opts{Ttl: 5 * time.Second}, + }, + true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var task *queue.Task + + q := createQueue(t, conn, tc.name, tc.cfg) + defer func() { + if task != nil { + if err := task.Ack(); err != nil { + t.Fatalf("Failed to Ack: %s", err) + } + } + dropQueue(t, q) + }() + + putData := "put_data" + _, err := q.PutWithOpts(putData, + queue.Opts{ + Ttl: 10 * time.Second, + Utube: "test_utube", + }) + if err != nil { + t.Fatalf("Failed put a task: %s", err) + } + + task, err = q.TakeTimeout(2 * time.Second) + if err != nil { + t.Fatalf("Failed to take task from utube: %s", err) + } + + err = task.Touch(1 * time.Second) + if tc.ok && err != nil { + t.Fatalf("Failed to touch: %s", err) + } else if !tc.ok && err == nil { + t.Fatalf("Unexpected success") + } + }) + } +} + +// runTestMain is a body of TestMain function +// (see https://pkg.go.dev/testing#hdr-Main). +// Using defer + os.Exit is not works so TestMain body +// is a separate function, see +// https://stackoverflow.com/questions/27629380/how-to-exit-a-go-program-honoring-deferred-calls +func runTestMain(m *testing.M) int { + inst, err := test_helpers.StartTarantool(test_helpers.StartOpts{ + Dialer: dialer, + InitScript: "testdata/config.lua", + Listen: server, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 10, + RetryTimeout: 500 * time.Millisecond, + }) + + if err != nil { + log.Fatalf("Failed to prepare test tarantool: %s", err) + } + + defer test_helpers.StopTarantoolWithCleanup(inst) + + poolInstsOpts := make([]test_helpers.StartOpts, 0, len(servers)) + for _, serv := range servers { + poolInstsOpts = append(poolInstsOpts, test_helpers.StartOpts{ + Listen: serv, + Dialer: NetDialer{ + Address: serv, + User: user, + Password: pass, + }, + InitScript: "testdata/pool.lua", + WaitStart: 3 * time.Second, // replication_timeout * 3 + ConnectRetry: -1, + }) + } + + instances, err := test_helpers.StartTarantoolInstances(poolInstsOpts) + + if err != nil { + log.Printf("Failed to prepare test tarantool pool: %s", err) + return 1 + } + + defer test_helpers.StopTarantoolInstances(instances) + + for i := 0; i < 10; i++ { + // We need to skip bootstrap errors and to make sure that cluster is + // configured. + roles := []bool{false, true} + connOpts := Opts{ + Timeout: 500 * time.Millisecond, + } + dialers := make([]Dialer, 0, len(servers)) + for _, serv := range servers { + dialers = append(dialers, NetDialer{ + Address: serv, + User: user, + Password: pass, + }) + } + + ctx, cancel := test_helpers.GetPoolConnectContext() + err = test_helpers.SetClusterRO(ctx, dialers, connOpts, roles) + cancel() + if err == nil { + break + } + time.Sleep(time.Second) + } + + if err != nil { + log.Printf("Failed to set roles in tarantool pool: %s", err) + return 1 + } + return m.Run() +} + +func TestMain(m *testing.M) { + code := runTestMain(m) + os.Exit(code) +} diff --git a/queue/task.go b/queue/task.go new file mode 100644 index 000000000..db970884e --- /dev/null +++ b/queue/task.go @@ -0,0 +1,121 @@ +package queue + +import ( + "fmt" + "time" + + "github.com/vmihailenco/msgpack/v5" +) + +// Task represents a task from Tarantool queue's tube. +type Task struct { + id uint64 + status string + data interface{} + q *queue +} + +func (t *Task) DecodeMsgpack(d *msgpack.Decoder) error { + var err error + var l int + if l, err = d.DecodeArrayLen(); err != nil { + return err + } + if l < 3 { + return fmt.Errorf("array len doesn't match: %d", l) + } + if t.id, err = d.DecodeUint64(); err != nil { + return err + } + if t.status, err = d.DecodeString(); err != nil { + return err + } + if t.data != nil { + d.Decode(t.data) + } else if t.data, err = d.DecodeInterface(); err != nil { + return err + } + return nil +} + +// Id is a getter for task id. +func (t *Task) Id() uint64 { + return t.id +} + +// Data is a getter for task data. +func (t *Task) Data() interface{} { + return t.data +} + +// Status is a getter for task status. +func (t *Task) Status() string { + return t.status +} + +// Touch increases ttr of running task. +func (t *Task) Touch(increment time.Duration) error { + return t.accept(t.q._touch(t.id, increment)) +} + +// Ack signals about task completion. +func (t *Task) Ack() error { + return t.accept(t.q._ack(t.id)) +} + +// Delete task from queue. +func (t *Task) Delete() error { + return t.accept(t.q._delete(t.id)) +} + +// Bury signals that task task cannot be executed in the current circumstances, +// task becomes "buried" - ie neither completed, nor ready, so it could not be +// deleted or taken by other worker. +// To revert "burying" call queue.Kick(numberOfBurried). +func (t *Task) Bury() error { + return t.accept(t.q._bury(t.id)) +} + +// Release returns task back in the queue without making it complete. +// In other words, this worker failed to complete the task, and +// it, so other worker could try to do that again. +func (t *Task) Release() error { + return t.accept(t.q._release(t.id, Opts{})) +} + +// ReleaseCfg returns task to a queue and changes its configuration. +func (t *Task) ReleaseCfg(cfg Opts) error { + return t.accept(t.q._release(t.id, cfg)) +} + +func (t *Task) accept(newStatus string, err error) error { + if err == nil { + t.status = newStatus + } + return err +} + +// IsReady returns if task is ready. +func (t *Task) IsReady() bool { + return t.status == READY +} + +// IsTaken returns if task is taken. +func (t *Task) IsTaken() bool { + return t.status == TAKEN +} + +// IsDone returns if task is done. +func (t *Task) IsDone() bool { + return t.status == DONE +} + +// IsBurred returns if task is buried. +func (t *Task) IsBuried() bool { + return t.status == BURIED +} + +// IsDelayed returns if task is delayed. +func (t *Task) IsDelayed() bool { + return t.status == DELAYED +} diff --git a/queue/testdata/config.lua b/queue/testdata/config.lua new file mode 100644 index 000000000..e0adc069c --- /dev/null +++ b/queue/testdata/config.lua @@ -0,0 +1,79 @@ +-- configure path so that you can run application +-- from outside the root directory +if package.setsearchroot ~= nil then + package.setsearchroot() +else + -- Workaround for rocks loading in tarantool 1.10 + -- It can be removed in tarantool > 2.2 + -- By default, when you do require('mymodule'), tarantool looks into + -- the current working directory and whatever is specified in + -- package.path and package.cpath. If you run your app while in the + -- root directory of that app, everything goes fine, but if you try to + -- start your app with "tarantool myapp/init.lua", it will fail to load + -- its modules, and modules from myapp/.rocks. + local fio = require('fio') + local app_dir = fio.abspath(fio.dirname(arg[0])) + package.path = app_dir .. '/?.lua;' .. package.path + package.path = app_dir .. '/?/init.lua;' .. package.path + package.path = app_dir .. '/.rocks/share/tarantool/?.lua;' .. package.path + package.path = app_dir .. '/.rocks/share/tarantool/?/init.lua;' .. package.path + package.cpath = app_dir .. '/?.so;' .. package.cpath + package.cpath = app_dir .. '/?.dylib;' .. package.cpath + package.cpath = app_dir .. '/.rocks/lib/tarantool/?.so;' .. package.cpath + package.cpath = app_dir .. '/.rocks/lib/tarantool/?.dylib;' .. package.cpath +end + +local queue = require('queue') +rawset(_G, 'queue', queue) + +-- Do not set listen for now so connector won't be +-- able to send requests until everything is configured. +box.cfg{ + work_dir = os.getenv("TEST_TNT_WORK_DIR"), +} + +box.once("init", function() + box.schema.user.create('test', {password = 'test'}) + box.schema.func.create('queue.tube.test_queue:touch') + box.schema.func.create('queue.tube.test_queue:ack') + box.schema.func.create('queue.tube.test_queue:put') + box.schema.func.create('queue.tube.test_queue:drop') + box.schema.func.create('queue.tube.test_queue:peek') + box.schema.func.create('queue.tube.test_queue:kick') + box.schema.func.create('queue.tube.test_queue:take') + box.schema.func.create('queue.tube.test_queue:delete') + box.schema.func.create('queue.tube.test_queue:release') + box.schema.func.create('queue.tube.test_queue:release_all') + box.schema.func.create('queue.tube.test_queue:bury') + box.schema.func.create('queue.identify') + box.schema.func.create('queue.state') + box.schema.func.create('queue.statistics') + box.schema.user.grant('test', 'create,read,write,drop', 'space') + box.schema.user.grant('test', 'read, write', 'space', '_queue_session_ids') + box.schema.user.grant('test', 'execute', 'universe') + box.schema.user.grant('test', 'read,write', 'space', '_queue') + box.schema.user.grant('test', 'read,write', 'space', '_schema') + box.schema.user.grant('test', 'read,write', 'space', '_space_sequence') + box.schema.user.grant('test', 'read,write', 'space', '_space') + box.schema.user.grant('test', 'read,write', 'space', '_index') + box.schema.user.grant('test', 'read,write', 'space', '_priv') + if box.space._trigger ~= nil then + box.schema.user.grant('test', 'read', 'space', '_trigger') + end + if box.space._fk_constraint ~= nil then + box.schema.user.grant('test', 'read', 'space', '_fk_constraint') + end + if box.space._ck_constraint ~= nil then + box.schema.user.grant('test', 'read', 'space', '_ck_constraint') + end + if box.space._func_index ~= nil then + box.schema.user.grant('test', 'read', 'space', '_func_index') + end +end) + +-- Set listen only when every other thing is configured. +box.cfg{ + listen = os.getenv("TEST_TNT_LISTEN"), +} + +require('console').start() diff --git a/queue/testdata/pool.lua b/queue/testdata/pool.lua new file mode 100644 index 000000000..9ca98bbf1 --- /dev/null +++ b/queue/testdata/pool.lua @@ -0,0 +1,84 @@ +-- configure path so that you can run application +-- from outside the root directory +if package.setsearchroot ~= nil then + package.setsearchroot() +else + -- Workaround for rocks loading in tarantool 1.10 + -- It can be removed in tarantool > 2.2 + -- By default, when you do require('mymodule'), tarantool looks into + -- the current working directory and whatever is specified in + -- package.path and package.cpath. If you run your app while in the + -- root directory of that app, everything goes fine, but if you try to + -- start your app with "tarantool myapp/init.lua", it will fail to load + -- its modules, and modules from myapp/.rocks. + local fio = require('fio') + local app_dir = fio.abspath(fio.dirname(arg[0])) + package.path = app_dir .. '/?.lua;' .. package.path + package.path = app_dir .. '/?/init.lua;' .. package.path + package.path = app_dir .. '/.rocks/share/tarantool/?.lua;' .. package.path + package.path = app_dir .. '/.rocks/share/tarantool/?/init.lua;' .. package.path + package.cpath = app_dir .. '/?.so;' .. package.cpath + package.cpath = app_dir .. '/?.dylib;' .. package.cpath + package.cpath = app_dir .. '/.rocks/lib/tarantool/?.so;' .. package.cpath + package.cpath = app_dir .. '/.rocks/lib/tarantool/?.dylib;' .. package.cpath +end + +local queue = require('queue') +rawset(_G, 'queue', queue) +-- queue.cfg({in_replicaset = true}) should be called before box.cfg({}) +-- https://github.com/tarantool/queue/issues/206 +queue.cfg({in_replicaset = true, ttr = 60}) + +local listen = os.getenv("TEST_TNT_LISTEN") +box.cfg{ + work_dir = os.getenv("TEST_TNT_WORK_DIR"), + listen = listen, + replication = { + "test:test@127.0.0.1:3014", + "test:test@127.0.0.1:3015", + }, + read_only = listen == "127.0.0.1:3015" +} + +box.once("schema", function() + box.schema.user.create('test', {password = 'test'}) + box.schema.user.grant('test', 'replication') + + box.schema.func.create('queue.tube.test_queue:touch') + box.schema.func.create('queue.tube.test_queue:ack') + box.schema.func.create('queue.tube.test_queue:put') + box.schema.func.create('queue.tube.test_queue:drop') + box.schema.func.create('queue.tube.test_queue:peek') + box.schema.func.create('queue.tube.test_queue:kick') + box.schema.func.create('queue.tube.test_queue:take') + box.schema.func.create('queue.tube.test_queue:delete') + box.schema.func.create('queue.tube.test_queue:release') + box.schema.func.create('queue.tube.test_queue:release_all') + box.schema.func.create('queue.tube.test_queue:bury') + box.schema.func.create('queue.identify') + box.schema.func.create('queue.state') + box.schema.func.create('queue.statistics') + box.schema.user.grant('test', 'create,read,write,drop', 'space') + box.schema.user.grant('test', 'read, write', 'space', '_queue_session_ids') + box.schema.user.grant('test', 'execute', 'universe') + box.schema.user.grant('test', 'read,write', 'space', '_queue') + box.schema.user.grant('test', 'read,write', 'space', '_schema') + box.schema.user.grant('test', 'read,write', 'space', '_space_sequence') + box.schema.user.grant('test', 'read,write', 'space', '_space') + box.schema.user.grant('test', 'read,write', 'space', '_index') + box.schema.user.grant('test', 'read,write', 'space', '_priv') + if box.space._trigger ~= nil then + box.schema.user.grant('test', 'read', 'space', '_trigger') + end + if box.space._fk_constraint ~= nil then + box.schema.user.grant('test', 'read', 'space', '_fk_constraint') + end + if box.space._ck_constraint ~= nil then + box.schema.user.grant('test', 'read', 'space', '_ck_constraint') + end + if box.space._func_index ~= nil then + box.schema.user.grant('test', 'read', 'space', '_func_index') + end +end) + +require('console').start() diff --git a/request.go b/request.go index a8e721833..c18b3aeb2 100644 --- a/request.go +++ b/request.go @@ -1,159 +1,1603 @@ package tarantool -import( - "github.com/vmihailenco/msgpack" +import ( + "context" "errors" + "fmt" + "io" + "reflect" + "strings" + "sync" + + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" ) -type Request struct { - conn *Connection - requestId uint32 - requestCode int32 - body map[int]interface{} +type spaceEncoder struct { + Id uint32 + Name string + IsId bool } -func (conn *Connection) NewRequest(requestCode int32) (req *Request) { - req = &Request{} - req.conn = conn - req.requestId = conn.nextRequestId() - req.requestCode = requestCode - req.body = make(map[int]interface{}) +func newSpaceEncoder(res SchemaResolver, spaceInfo interface{}) (spaceEncoder, error) { + if res.NamesUseSupported() { + if spaceName, ok := spaceInfo.(string); ok { + return spaceEncoder{ + Id: 0, + Name: spaceName, + IsId: false, + }, nil + } + } + + spaceId, err := res.ResolveSpace(spaceInfo) + return spaceEncoder{ + Id: spaceId, + IsId: true, + }, err +} - return +func (e spaceEncoder) Encode(enc *msgpack.Encoder) error { + if e.IsId { + if err := enc.EncodeUint(uint64(iproto.IPROTO_SPACE_ID)); err != nil { + return err + } + return enc.EncodeUint(uint64(e.Id)) + } + if err := enc.EncodeUint(uint64(iproto.IPROTO_SPACE_NAME)); err != nil { + return err + } + return enc.EncodeString(e.Name) } -func (conn *Connection) Ping() (resp *Response, err error) { - request := conn.NewRequest(PingRequest) - resp, err = request.perform() - return +type indexEncoder struct { + Id uint32 + Name string + IsId bool } -func (conn *Connection) Select(spaceNo, indexNo, offset, limit, iterator uint32, key []interface{}) (resp *Response, err error) { - request := conn.NewRequest(SelectRequest) +func newIndexEncoder(res SchemaResolver, indexInfo interface{}, + spaceNo uint32) (indexEncoder, error) { + if res.NamesUseSupported() { + if indexName, ok := indexInfo.(string); ok { + return indexEncoder{ + Name: indexName, + IsId: false, + }, nil + } + } - request.body[KeySpaceNo] = spaceNo - request.body[KeyIndexNo] = indexNo - request.body[KeyIterator] = iterator - request.body[KeyOffset] = offset - request.body[KeyLimit] = limit - request.body[KeyKey] = key - - resp, err = request.perform() - return + indexId, err := res.ResolveIndex(indexInfo, spaceNo) + return indexEncoder{ + Id: indexId, + IsId: true, + }, err } -func (conn *Connection) Insert(spaceNo uint32, tuple []interface{}) (resp *Response, err error) { - request := conn.NewRequest(InsertRequest) +func (e indexEncoder) Encode(enc *msgpack.Encoder) error { + if e.IsId { + if err := enc.EncodeUint(uint64(iproto.IPROTO_INDEX_ID)); err != nil { + return err + } + return enc.EncodeUint(uint64(e.Id)) + } + if err := enc.EncodeUint(uint64(iproto.IPROTO_INDEX_NAME)); err != nil { + return err + } + return enc.EncodeString(e.Name) +} - request.body[KeySpaceNo] = spaceNo - request.body[KeyTuple] = tuple +func fillSearch(enc *msgpack.Encoder, spaceEnc spaceEncoder, indexEnc indexEncoder, + key interface{}) error { + if err := spaceEnc.Encode(enc); err != nil { + return err + } - resp, err = request.perform() - return + if err := indexEnc.Encode(enc); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_KEY)); err != nil { + return err + } + return enc.Encode(key) } -func (conn *Connection) Replace(spaceNo uint32, tuple []interface{}) (resp *Response, err error) { - request := conn.NewRequest(ReplaceRequest) +// Ping sends empty request to Tarantool to check connection. +// +// Deprecated: the method will be removed in the next major version, +// use a PingRequest object + Do() instead. +func (conn *Connection) Ping() ([]interface{}, error) { + return conn.Do(NewPingRequest()).Get() +} - request.body[KeySpaceNo] = spaceNo - request.body[KeyTuple] = tuple +// Select performs select to box space. +// +// It is equal to conn.SelectAsync(...).Get(). +// +// Deprecated: the method will be removed in the next major version, +// use a SelectRequest object + Do() instead. +func (conn *Connection) Select(space, index interface{}, offset, limit uint32, iterator Iter, + key interface{}) ([]interface{}, error) { + return conn.SelectAsync(space, index, offset, limit, iterator, key).Get() +} - resp, err = request.perform() - return +// Insert performs insertion to box space. +// Tarantool will reject Insert when tuple with same primary key exists. +// +// It is equal to conn.InsertAsync(space, tuple).Get(). +// +// Deprecated: the method will be removed in the next major version, +// use an InsertRequest object + Do() instead. +func (conn *Connection) Insert(space interface{}, tuple interface{}) ([]interface{}, error) { + return conn.InsertAsync(space, tuple).Get() } -func (conn *Connection) Delete(spaceNo, indexNo uint32, key []interface{}) (resp *Response, err error) { - request := conn.NewRequest(DeleteRequest) +// Replace performs "insert or replace" action to box space. +// If tuple with same primary key exists, it will be replaced. +// +// It is equal to conn.ReplaceAsync(space, tuple).Get(). +// +// Deprecated: the method will be removed in the next major version, +// use a ReplaceRequest object + Do() instead. +func (conn *Connection) Replace(space interface{}, tuple interface{}) ([]interface{}, error) { + return conn.ReplaceAsync(space, tuple).Get() +} - request.body[KeySpaceNo] = spaceNo - request.body[KeyIndexNo] = indexNo - request.body[KeyKey] = key +// Delete performs deletion of a tuple by key. +// Result will contain array with deleted tuple. +// +// It is equal to conn.DeleteAsync(space, tuple).Get(). +// +// Deprecated: the method will be removed in the next major version, +// use a DeleteRequest object + Do() instead. +func (conn *Connection) Delete(space, index interface{}, key interface{}) ([]interface{}, error) { + return conn.DeleteAsync(space, index, key).Get() +} - resp, err = request.perform() - return +// Update performs update of a tuple by key. +// Result will contain array with updated tuple. +// +// It is equal to conn.UpdateAsync(space, tuple).Get(). +// +// Deprecated: the method will be removed in the next major version, +// use a UpdateRequest object + Do() instead. +func (conn *Connection) Update(space, index, key interface{}, + ops *Operations) ([]interface{}, error) { + return conn.UpdateAsync(space, index, key, ops).Get() } -func (conn *Connection) Update(spaceNo, indexNo uint32, key, tuple []interface{}) (resp *Response, err error) { - request := conn.NewRequest(UpdateRequest) +// Upsert performs "update or insert" action of a tuple by key. +// Result will not contain any tuple. +// +// It is equal to conn.UpsertAsync(space, tuple, ops).Get(). +// +// Deprecated: the method will be removed in the next major version, +// use a UpsertRequest object + Do() instead. +func (conn *Connection) Upsert(space, tuple interface{}, ops *Operations) ([]interface{}, error) { + return conn.UpsertAsync(space, tuple, ops).Get() +} + +// Call calls registered Tarantool function. +// It uses request code for Tarantool >= 1.7, result is an array. +// +// It is equal to conn.CallAsync(functionName, args).Get(). +// +// Deprecated: the method will be removed in the next major version, +// use a CallRequest object + Do() instead. +func (conn *Connection) Call(functionName string, args interface{}) ([]interface{}, error) { + return conn.CallAsync(functionName, args).Get() +} - request.body[KeySpaceNo] = spaceNo - request.body[KeyIndexNo] = indexNo - request.body[KeyKey] = key - request.body[KeyTuple] = tuple +// Call16 calls registered Tarantool function. +// It uses request code for Tarantool 1.6, result is an array of arrays. +// Deprecated since Tarantool 1.7.2. +// +// It is equal to conn.Call16Async(functionName, args).Get(). +// +// Deprecated: the method will be removed in the next major version, +// use a Call16Request object + Do() instead. +func (conn *Connection) Call16(functionName string, args interface{}) ([]interface{}, error) { + return conn.Call16Async(functionName, args).Get() +} + +// Call17 calls registered Tarantool function. +// It uses request code for Tarantool >= 1.7, result is an array. +// +// It is equal to conn.Call17Async(functionName, args).Get(). +// +// Deprecated: the method will be removed in the next major version, +// use a Call17Request object + Do() instead. +func (conn *Connection) Call17(functionName string, args interface{}) ([]interface{}, error) { + return conn.Call17Async(functionName, args).Get() +} - resp, err = request.perform() - return +// Eval passes Lua expression for evaluation. +// +// It is equal to conn.EvalAsync(space, tuple).Get(). +// +// Deprecated: the method will be removed in the next major version, +// use an EvalRequest object + Do() instead. +func (conn *Connection) Eval(expr string, args interface{}) ([]interface{}, error) { + return conn.EvalAsync(expr, args).Get() } -func (conn *Connection) Call(functionName string, tuple []interface{}) (resp *Response, err error) { - request := conn.NewRequest(CallRequest) +// Execute passes sql expression to Tarantool for execution. +// +// It is equal to conn.ExecuteAsync(expr, args).Get(). +// Since 1.6.0 +// +// Deprecated: the method will be removed in the next major version, +// use an ExecuteRequest object + Do() instead. +func (conn *Connection) Execute(expr string, args interface{}) ([]interface{}, error) { + return conn.ExecuteAsync(expr, args).Get() +} - request.body[KeyFunctionName] = functionName - request.body[KeyTuple] = tuple +// single used for conn.GetTyped for decode one tuple. +type single struct { + res interface{} + found bool +} - resp, err = request.perform() - return +func (s *single) DecodeMsgpack(d *msgpack.Decoder) error { + var err error + var len int + if len, err = d.DecodeArrayLen(); err != nil { + return err + } + if s.found = len >= 1; !s.found { + return nil + } + if len != 1 { + return errors.New("tarantool returns unexpected value for Select(limit=1)") + } + return d.Decode(s.res) } +// GetTyped performs select (with limit = 1 and offset = 0) +// to box space and fills typed result. // -// To be implemented +// It is equal to conn.SelectAsync(space, index, 0, 1, IterEq, key).GetTyped(&result) // -func (conn *Connection) Auth(key, tuple []interface{}) (resp *Response, err error) { - return +// Deprecated: the method will be removed in the next major version, +// use a SelectRequest object + Do() instead. +func (conn *Connection) GetTyped(space, index interface{}, key interface{}, + result interface{}) error { + s := single{res: result} + return conn.SelectAsync(space, index, 0, 1, IterEq, key).GetTyped(&s) } +// SelectTyped performs select to box space and fills typed result. +// +// It is equal to conn.SelectAsync(space, index, offset, limit, iterator, key).GetTyped(&result) +// +// Deprecated: the method will be removed in the next major version, +// use a SelectRequest object + Do() instead. +func (conn *Connection) SelectTyped(space, index interface{}, offset, limit uint32, iterator Iter, + key interface{}, result interface{}) error { + return conn.SelectAsync(space, index, offset, limit, iterator, key).GetTyped(result) +} + +// InsertTyped performs insertion to box space. +// Tarantool will reject Insert when tuple with same primary key exists. +// +// It is equal to conn.InsertAsync(space, tuple).GetTyped(&result). +// +// Deprecated: the method will be removed in the next major version, +// use an InsertRequest object + Do() instead. +func (conn *Connection) InsertTyped(space interface{}, tuple interface{}, + result interface{}) error { + return conn.InsertAsync(space, tuple).GetTyped(result) +} + +// ReplaceTyped performs "insert or replace" action to box space. +// If tuple with same primary key exists, it will be replaced. +// +// It is equal to conn.ReplaceAsync(space, tuple).GetTyped(&result). +// +// Deprecated: the method will be removed in the next major version, +// use a ReplaceRequest object + Do() instead. +func (conn *Connection) ReplaceTyped(space interface{}, tuple interface{}, + result interface{}) error { + return conn.ReplaceAsync(space, tuple).GetTyped(result) +} + +// DeleteTyped performs deletion of a tuple by key and fills result with deleted tuple. +// +// It is equal to conn.DeleteAsync(space, tuple).GetTyped(&result). +// +// Deprecated: the method will be removed in the next major version, +// use a DeleteRequest object + Do() instead. +func (conn *Connection) DeleteTyped(space, index interface{}, key interface{}, + result interface{}) error { + return conn.DeleteAsync(space, index, key).GetTyped(result) +} + +// UpdateTyped performs update of a tuple by key and fills result with updated tuple. +// +// It is equal to conn.UpdateAsync(space, tuple, ops).GetTyped(&result). +// +// Deprecated: the method will be removed in the next major version, +// use a UpdateRequest object + Do() instead. +func (conn *Connection) UpdateTyped(space, index interface{}, key interface{}, + ops *Operations, result interface{}) error { + return conn.UpdateAsync(space, index, key, ops).GetTyped(result) +} + +// CallTyped calls registered function. +// It uses request code for Tarantool >= 1.7, result is an array. +// +// It is equal to conn.Call16Async(functionName, args).GetTyped(&result). +// +// Deprecated: the method will be removed in the next major version, +// use a CallRequest object + Do() instead. +func (conn *Connection) CallTyped(functionName string, args interface{}, + result interface{}) error { + return conn.CallAsync(functionName, args).GetTyped(result) +} + +// Call16Typed calls registered function. +// It uses request code for Tarantool 1.6, result is an array of arrays. +// Deprecated since Tarantool 1.7.2. +// +// It is equal to conn.Call16Async(functionName, args).GetTyped(&result). +// +// Deprecated: the method will be removed in the next major version, +// use a Call16Request object + Do() instead. +func (conn *Connection) Call16Typed(functionName string, args interface{}, + result interface{}) error { + return conn.Call16Async(functionName, args).GetTyped(result) +} + +// Call17Typed calls registered function. +// It uses request code for Tarantool >= 1.7, result is an array. +// +// It is equal to conn.Call17Async(functionName, args).GetTyped(&result). +// +// Deprecated: the method will be removed in the next major version, +// use a Call17Request object + Do() instead. +func (conn *Connection) Call17Typed(functionName string, args interface{}, + result interface{}) error { + return conn.Call17Async(functionName, args).GetTyped(result) +} + +// EvalTyped passes Lua expression for evaluation. +// +// It is equal to conn.EvalAsync(space, tuple).GetTyped(&result). +// +// Deprecated: the method will be removed in the next major version, +// use an EvalRequest object + Do() instead. +func (conn *Connection) EvalTyped(expr string, args interface{}, result interface{}) error { + return conn.EvalAsync(expr, args).GetTyped(result) +} + +// ExecuteTyped passes sql expression to Tarantool for execution. +// +// In addition to error returns sql info and columns meta data +// Since 1.6.0 +// +// Deprecated: the method will be removed in the next major version, +// use an ExecuteRequest object + Do() instead. +func (conn *Connection) ExecuteTyped(expr string, args interface{}, + result interface{}) (SQLInfo, []ColumnMetaData, error) { + var ( + sqlInfo SQLInfo + metaData []ColumnMetaData + ) + + fut := conn.ExecuteAsync(expr, args) + err := fut.GetTyped(&result) + if resp, ok := fut.resp.(*ExecuteResponse); ok { + sqlInfo = resp.sqlInfo + metaData = resp.metaData + } else if err == nil { + err = fmt.Errorf("unexpected response type %T, want: *ExecuteResponse", fut.resp) + } + return sqlInfo, metaData, err +} + +// SelectAsync sends select request to Tarantool and returns Future. +// +// Deprecated: the method will be removed in the next major version, +// use a SelectRequest object + Do() instead. +func (conn *Connection) SelectAsync(space, index interface{}, offset, limit uint32, iterator Iter, + key interface{}) *Future { + req := NewSelectRequest(space). + Index(index). + Offset(offset). + Limit(limit). + Iterator(iterator). + Key(key) + return conn.Do(req) +} + +// InsertAsync sends insert action to Tarantool and returns Future. +// Tarantool will reject Insert when tuple with same primary key exists. +// +// Deprecated: the method will be removed in the next major version, +// use an InsertRequest object + Do() instead. +func (conn *Connection) InsertAsync(space interface{}, tuple interface{}) *Future { + req := NewInsertRequest(space).Tuple(tuple) + return conn.Do(req) +} + +// ReplaceAsync sends "insert or replace" action to Tarantool and returns Future. +// If tuple with same primary key exists, it will be replaced. +// +// Deprecated: the method will be removed in the next major version, +// use a ReplaceRequest object + Do() instead. +func (conn *Connection) ReplaceAsync(space interface{}, tuple interface{}) *Future { + req := NewReplaceRequest(space).Tuple(tuple) + return conn.Do(req) +} + +// DeleteAsync sends deletion action to Tarantool and returns Future. +// Future's result will contain array with deleted tuple. +// +// Deprecated: the method will be removed in the next major version, +// use a DeleteRequest object + Do() instead. +func (conn *Connection) DeleteAsync(space, index interface{}, key interface{}) *Future { + req := NewDeleteRequest(space).Index(index).Key(key) + return conn.Do(req) +} + +// Update sends deletion of a tuple by key and returns Future. +// Future's result will contain array with updated tuple. +// +// Deprecated: the method will be removed in the next major version, +// use a UpdateRequest object + Do() instead. +func (conn *Connection) UpdateAsync(space, index interface{}, key interface{}, + ops *Operations) *Future { + req := NewUpdateRequest(space).Index(index).Key(key) + req.ops = ops + return conn.Do(req) +} + +// UpsertAsync sends "update or insert" action to Tarantool and returns Future. +// Future's sesult will not contain any tuple. +// +// Deprecated: the method will be removed in the next major version, +// use a UpsertRequest object + Do() instead. +func (conn *Connection) UpsertAsync(space, tuple interface{}, ops *Operations) *Future { + req := NewUpsertRequest(space).Tuple(tuple) + req.ops = ops + return conn.Do(req) +} + +// CallAsync sends a call to registered Tarantool function and returns Future. +// It uses request code for Tarantool >= 1.7, so future's result is an array. +// +// Deprecated: the method will be removed in the next major version, +// use a CallRequest object + Do() instead. +func (conn *Connection) CallAsync(functionName string, args interface{}) *Future { + req := NewCallRequest(functionName).Args(args) + return conn.Do(req) +} + +// Call16Async sends a call to registered Tarantool function and returns Future. +// It uses request code for Tarantool 1.6, so future's result is an array of arrays. +// Deprecated since Tarantool 1.7.2. +// +// Deprecated: the method will be removed in the next major version, +// use a Call16Request object + Do() instead. +func (conn *Connection) Call16Async(functionName string, args interface{}) *Future { + req := NewCall16Request(functionName).Args(args) + return conn.Do(req) +} + +// Call17Async sends a call to registered Tarantool function and returns Future. +// It uses request code for Tarantool >= 1.7, so future's result is an array. +// +// Deprecated: the method will be removed in the next major version, +// use a Call17Request object + Do() instead. +func (conn *Connection) Call17Async(functionName string, args interface{}) *Future { + req := NewCall17Request(functionName).Args(args) + return conn.Do(req) +} + +// EvalAsync sends a Lua expression for evaluation and returns Future. +// +// Deprecated: the method will be removed in the next major version, +// use an EvalRequest object + Do() instead. +func (conn *Connection) EvalAsync(expr string, args interface{}) *Future { + req := NewEvalRequest(expr).Args(args) + return conn.Do(req) +} + +// ExecuteAsync sends a sql expression for execution and returns Future. +// Since 1.6.0 +// +// Deprecated: the method will be removed in the next major version, +// use an ExecuteRequest object + Do() instead. +func (conn *Connection) ExecuteAsync(expr string, args interface{}) *Future { + req := NewExecuteRequest(expr).Args(args) + return conn.Do(req) +} + +// KeyValueBind is a type for encoding named SQL parameters +type KeyValueBind struct { + Key string + Value interface{} +} // // private // +// this map is needed for caching names of struct fields in lower case +// to avoid extra allocations in heap by calling strings.ToLower() +var lowerCaseNames sync.Map + +func encodeSQLBind(enc *msgpack.Encoder, from interface{}) error { + // internal function for encoding single map in msgpack + encodeKeyInterface := func(key string, val interface{}) error { + if err := enc.EncodeMapLen(1); err != nil { + return err + } + if err := enc.EncodeString(":" + key); err != nil { + return err + } + if err := enc.Encode(val); err != nil { + return err + } + return nil + } + + encodeKeyValue := func(key string, val reflect.Value) error { + if err := enc.EncodeMapLen(1); err != nil { + return err + } + if err := enc.EncodeString(":" + key); err != nil { + return err + } + if err := enc.EncodeValue(val); err != nil { + return err + } + return nil + } + + encodeNamedFromMap := func(mp map[string]interface{}) error { + if err := enc.EncodeArrayLen(len(mp)); err != nil { + return err + } + for k, v := range mp { + if err := encodeKeyInterface(k, v); err != nil { + return err + } + } + return nil + } + + encodeNamedFromStruct := func(val reflect.Value) error { + if err := enc.EncodeArrayLen(val.NumField()); err != nil { + return err + } + cached, ok := lowerCaseNames.Load(val.Type()) + if !ok { + fields := make([]string, val.NumField()) + for i := 0; i < val.NumField(); i++ { + key := val.Type().Field(i).Name + fields[i] = strings.ToLower(key) + v := val.Field(i) + if err := encodeKeyValue(fields[i], v); err != nil { + return err + } + } + lowerCaseNames.Store(val.Type(), fields) + return nil + } + + fields := cached.([]string) + for i := 0; i < val.NumField(); i++ { + k := fields[i] + v := val.Field(i) + if err := encodeKeyValue(k, v); err != nil { + return err + } + } + return nil + } + + encodeSlice := func(from interface{}) error { + castedSlice, ok := from.([]interface{}) + if !ok { + castedKVSlice := from.([]KeyValueBind) + t := len(castedKVSlice) + if err := enc.EncodeArrayLen(t); err != nil { + return err + } + for _, v := range castedKVSlice { + if err := encodeKeyInterface(v.Key, v.Value); err != nil { + return err + } + } + return nil + } + + if err := enc.EncodeArrayLen(len(castedSlice)); err != nil { + return err + } + for i := 0; i < len(castedSlice); i++ { + if kvb, ok := castedSlice[i].(KeyValueBind); ok { + k := kvb.Key + v := kvb.Value + if err := encodeKeyInterface(k, v); err != nil { + return err + } + } else { + if err := enc.Encode(castedSlice[i]); err != nil { + return err + } + } + } + return nil + } + + val := reflect.ValueOf(from) + switch val.Kind() { + case reflect.Map: + mp, ok := from.(map[string]interface{}) + if !ok { + return errors.New("failed to encode map: wrong format") + } + if err := encodeNamedFromMap(mp); err != nil { + return err + } + case reflect.Struct: + if err := encodeNamedFromStruct(val); err != nil { + return err + } + case reflect.Slice, reflect.Array: + if err := encodeSlice(from); err != nil { + return err + } + } + return nil +} + +// Request is an interface that provides the necessary data to create a request +// that will be sent to a tarantool instance. +type Request interface { + // Type returns a IPROTO type of the request. + Type() iproto.Type + // Body fills an msgpack.Encoder with a request body. + Body(resolver SchemaResolver, enc *msgpack.Encoder) error + // Ctx returns a context of the request. + Ctx() context.Context + // Async returns true if the request does not expect response. + Async() bool + // Response creates a response for current request type. + Response(header Header, body io.Reader) (Response, error) +} + +// ConnectedRequest is an interface that provides the info about a Connection +// the request belongs to. +type ConnectedRequest interface { + Request + // Conn returns a Connection the request belongs to. + Conn() *Connection +} + +type baseRequest struct { + rtype iproto.Type + async bool + ctx context.Context +} + +// Type returns a IPROTO type for the request. +func (req *baseRequest) Type() iproto.Type { + return req.rtype +} + +// Async returns true if the request does not require a response. +func (req *baseRequest) Async() bool { + return req.async +} + +// Ctx returns a context of the request. +func (req *baseRequest) Ctx() context.Context { + return req.ctx +} + +// Response creates a response for the baseRequest. +func (req *baseRequest) Response(header Header, body io.Reader) (Response, error) { + return DecodeBaseResponse(header, body) +} + +type spaceRequest struct { + baseRequest + space interface{} +} + +func (req *spaceRequest) setSpace(space interface{}) { + req.space = space +} + +func EncodeSpace(res SchemaResolver, enc *msgpack.Encoder, space interface{}) error { + spaceEnc, err := newSpaceEncoder(res, space) + if err != nil { + return err + } + if err := spaceEnc.Encode(enc); err != nil { + return err + } + return nil +} + +type spaceIndexRequest struct { + spaceRequest + index interface{} +} + +func (req *spaceIndexRequest) setIndex(index interface{}) { + req.index = index +} + +// authRequest implements IPROTO_AUTH request. +type authRequest struct { + auth Auth + user, pass string +} + +// newChapSha1AuthRequest create a new authRequest with chap-sha1 +// authentication method. +func newChapSha1AuthRequest(user, password, salt string) (authRequest, error) { + req := authRequest{} + scr, err := scramble(salt, password) + if err != nil { + return req, fmt.Errorf("scrambling failure: %w", err) + } + + req.auth = ChapSha1Auth + req.user = user + req.pass = string(scr) + return req, nil +} + +// newPapSha256AuthRequest create a new authRequest with pap-sha256 +// authentication method. +func newPapSha256AuthRequest(user, password string) authRequest { + return authRequest{ + auth: PapSha256Auth, + user: user, + pass: password, + } +} + +// Type returns a IPROTO type for the request. +func (req authRequest) Type() iproto.Type { + return iproto.IPROTO_AUTH +} + +// Async returns true if the request does not require a response. +func (req authRequest) Async() bool { + return false +} + +// Ctx returns a context of the request. +func (req authRequest) Ctx() context.Context { + return nil +} + +// Body fills an encoder with the auth request body. +func (req authRequest) Body(_ SchemaResolver, enc *msgpack.Encoder) error { + if err := enc.EncodeMapLen(2); err != nil { + return err + } + + if err := enc.EncodeUint32(uint32(iproto.IPROTO_USER_NAME)); err != nil { + return err + } + + if err := enc.EncodeString(req.user); err != nil { + return err + } + + if err := enc.EncodeUint32(uint32(iproto.IPROTO_TUPLE)); err != nil { + return err + } + + if err := enc.EncodeArrayLen(2); err != nil { + return err + } + + if err := enc.EncodeString(req.auth.String()); err != nil { + return err + } + + if err := enc.EncodeString(req.pass); err != nil { + return err + } + + return nil +} + +// Response creates a response for the authRequest. +func (req authRequest) Response(header Header, body io.Reader) (Response, error) { + return DecodeBaseResponse(header, body) +} + +// PingRequest helps you to create an execute request object for execution +// by a Connection. +type PingRequest struct { + baseRequest +} + +// NewPingRequest returns a new PingRequest. +func NewPingRequest() *PingRequest { + req := new(PingRequest) + req.rtype = iproto.IPROTO_PING + return req +} + +// Body fills an msgpack.Encoder with the ping request body. +func (req *PingRequest) Body(_ SchemaResolver, enc *msgpack.Encoder) error { + return enc.EncodeMapLen(0) +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *PingRequest) Context(ctx context.Context) *PingRequest { + req.ctx = ctx + return req +} + +// SelectRequest allows you to create a select request object for execution +// by a Connection. +type SelectRequest struct { + spaceIndexRequest + isIteratorSet, fetchPos bool + offset, limit uint32 + iterator Iter + key, after interface{} +} + +// NewSelectRequest returns a new empty SelectRequest. +func NewSelectRequest(space interface{}) *SelectRequest { + req := new(SelectRequest) + req.rtype = iproto.IPROTO_SELECT + req.setSpace(space) + req.isIteratorSet = false + req.fetchPos = false + req.iterator = IterAll + req.key = []interface{}{} + req.after = nil + req.limit = 0xFFFFFFFF + return req +} + +// Index sets the index for the select request. +// Note: default value is 0. +func (req *SelectRequest) Index(index interface{}) *SelectRequest { + req.setIndex(index) + return req +} + +// Offset sets the offset for the select request. +// Note: default value is 0. +func (req *SelectRequest) Offset(offset uint32) *SelectRequest { + req.offset = offset + return req +} + +// Limit sets the limit for the select request. +// Note: default value is 0xFFFFFFFF. +func (req *SelectRequest) Limit(limit uint32) *SelectRequest { + req.limit = limit + return req +} + +// Iterator set the iterator for the select request. +// Note: default value is IterAll if key is not set or IterEq otherwise. +func (req *SelectRequest) Iterator(iterator Iter) *SelectRequest { + req.iterator = iterator + req.isIteratorSet = true + return req +} + +// Key set the key for the select request. +// Note: default value is empty. +func (req *SelectRequest) Key(key interface{}) *SelectRequest { + req.key = key + if !req.isIteratorSet { + req.iterator = IterEq + } + return req +} + +// FetchPos determines whether to fetch positions of the last tuple. A position +// descriptor will be saved in Response.Pos value. +// +// Note: default value is false. +// +// Requires Tarantool >= 2.11. +// Since 1.11.0 +func (req *SelectRequest) FetchPos(fetch bool) *SelectRequest { + req.fetchPos = fetch + return req +} + +// After must contain a tuple from which selection must continue or its +// position (a value from Response.Pos). +// +// Note: default value in nil. +// +// Requires Tarantool >= 2.11. +// Since 1.11.0 +func (req *SelectRequest) After(after interface{}) *SelectRequest { + req.after = after + return req +} -func (req *Request) perform() (resp *Response, err error) { - packet, err := req.pack() +// Body fills an encoder with the select request body. +func (req *SelectRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { + spaceEnc, err := newSpaceEncoder(res, req.space) if err != nil { - return + return err } - responseChan := make(chan *Response) + indexEnc, err := newIndexEncoder(res, req.index, spaceEnc.Id) + if err != nil { + return err + } - req.conn.mutex.Lock() - req.conn.requests[req.requestId] = responseChan - req.conn.mutex.Unlock() + mapLen := 6 + if req.fetchPos { + mapLen++ + } + if req.after != nil { + mapLen++ + } - req.conn.packets <- (packet) - resp = <-responseChan + if err := enc.EncodeMapLen(mapLen); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_ITERATOR)); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(req.iterator)); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_OFFSET)); err != nil { + return err + } - if resp.Error != "" { - err = errors.New(resp.Error) + if err := enc.EncodeUint(uint64(req.offset)); err != nil { + return err } - return + + if err := enc.EncodeUint(uint64(iproto.IPROTO_LIMIT)); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(req.limit)); err != nil { + return err + } + + if err := fillSearch(enc, spaceEnc, indexEnc, req.key); err != nil { + return err + } + + if req.fetchPos { + if err := enc.EncodeUint(uint64(iproto.IPROTO_FETCH_POSITION)); err != nil { + return err + } + + if err := enc.EncodeBool(req.fetchPos); err != nil { + return err + } + } + + if req.after != nil { + if pos, ok := req.after.([]byte); ok { + if err := enc.EncodeUint(uint64(iproto.IPROTO_AFTER_POSITION)); err != nil { + return err + } + + if err := enc.EncodeString(string(pos)); err != nil { + return err + } + } else { + if err := enc.EncodeUint(uint64(iproto.IPROTO_AFTER_TUPLE)); err != nil { + return err + } + + if err := enc.Encode(req.after); err != nil { + return err + } + } + } + + return nil +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *SelectRequest) Context(ctx context.Context) *SelectRequest { + req.ctx = ctx + return req +} + +// Response creates a response for the SelectRequest. +func (req *SelectRequest) Response(header Header, body io.Reader) (Response, error) { + baseResp, err := createBaseResponse(header, body) + if err != nil { + return nil, err + } + return &SelectResponse{baseResponse: baseResp}, nil } -func (req *Request) pack() (packet []byte, err error) { - var header, body, packetLength []byte +// InsertRequest helps you to create an insert request object for execution +// by a Connection. +type InsertRequest struct { + spaceRequest + tuple interface{} +} + +// NewInsertRequest returns a new empty InsertRequest. +func NewInsertRequest(space interface{}) *InsertRequest { + req := new(InsertRequest) + req.rtype = iproto.IPROTO_INSERT + req.setSpace(space) + req.tuple = []interface{}{} + return req +} - msg_header := make(map[int]interface{}) - msg_header[KeyCode] = req.requestCode - msg_header[KeySync] = req.requestId +// Tuple sets the tuple for insertion the insert request. +// Note: default value is nil. +func (req *InsertRequest) Tuple(tuple interface{}) *InsertRequest { + req.tuple = tuple + return req +} - header, err = msgpack.Marshal(msg_header) +// Body fills an msgpack.Encoder with the insert request body. +func (req *InsertRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { + spaceEnc, err := newSpaceEncoder(res, req.space) if err != nil { - return + return err + } + + if err := enc.EncodeMapLen(2); err != nil { + return err + } + + if err := spaceEnc.Encode(enc); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_TUPLE)); err != nil { + return err } - body, err = msgpack.Marshal(req.body) + return enc.Encode(req.tuple) +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *InsertRequest) Context(ctx context.Context) *InsertRequest { + req.ctx = ctx + return req +} + +// ReplaceRequest helps you to create a replace request object for execution +// by a Connection. +type ReplaceRequest struct { + spaceRequest + tuple interface{} +} + +// NewReplaceRequest returns a new empty ReplaceRequest. +func NewReplaceRequest(space interface{}) *ReplaceRequest { + req := new(ReplaceRequest) + req.rtype = iproto.IPROTO_REPLACE + req.setSpace(space) + req.tuple = []interface{}{} + return req +} + +// Tuple sets the tuple for replace by the replace request. +// Note: default value is nil. +func (req *ReplaceRequest) Tuple(tuple interface{}) *ReplaceRequest { + req.tuple = tuple + return req +} + +// Body fills an msgpack.Encoder with the replace request body. +func (req *ReplaceRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { + spaceEnc, err := newSpaceEncoder(res, req.space) + if err != nil { + return err + } + + if err := enc.EncodeMapLen(2); err != nil { + return err + } + + if err := spaceEnc.Encode(enc); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_TUPLE)); err != nil { + return err + } + + return enc.Encode(req.tuple) +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *ReplaceRequest) Context(ctx context.Context) *ReplaceRequest { + req.ctx = ctx + return req +} + +// DeleteRequest helps you to create a delete request object for execution +// by a Connection. +type DeleteRequest struct { + spaceIndexRequest + key interface{} +} + +// NewDeleteRequest returns a new empty DeleteRequest. +func NewDeleteRequest(space interface{}) *DeleteRequest { + req := new(DeleteRequest) + req.rtype = iproto.IPROTO_DELETE + req.setSpace(space) + req.key = []interface{}{} + return req +} + +// Index sets the index for the delete request. +// Note: default value is 0. +func (req *DeleteRequest) Index(index interface{}) *DeleteRequest { + req.setIndex(index) + return req +} + +// Key sets the key of tuple for the delete request. +// Note: default value is empty. +func (req *DeleteRequest) Key(key interface{}) *DeleteRequest { + req.key = key + return req +} + +// Body fills an msgpack.Encoder with the delete request body. +func (req *DeleteRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { + spaceEnc, err := newSpaceEncoder(res, req.space) + if err != nil { + return err + } + + indexEnc, err := newIndexEncoder(res, req.index, spaceEnc.Id) + if err != nil { + return err + } + + if err := enc.EncodeMapLen(3); err != nil { + return err + } + + return fillSearch(enc, spaceEnc, indexEnc, req.key) +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *DeleteRequest) Context(ctx context.Context) *DeleteRequest { + req.ctx = ctx + return req +} + +// UpdateRequest helps you to create an update request object for execution +// by a Connection. +type UpdateRequest struct { + spaceIndexRequest + key interface{} + ops *Operations +} + +// NewUpdateRequest returns a new empty UpdateRequest. +func NewUpdateRequest(space interface{}) *UpdateRequest { + req := new(UpdateRequest) + req.rtype = iproto.IPROTO_UPDATE + req.setSpace(space) + req.key = []interface{}{} + return req +} + +// Index sets the index for the update request. +// Note: default value is 0. +func (req *UpdateRequest) Index(index interface{}) *UpdateRequest { + req.setIndex(index) + return req +} + +// Key sets the key of tuple for the update request. +// Note: default value is empty. +func (req *UpdateRequest) Key(key interface{}) *UpdateRequest { + req.key = key + return req +} + +// Operations sets operations to be performed on update. +// Note: default value is empty. +func (req *UpdateRequest) Operations(ops *Operations) *UpdateRequest { + req.ops = ops + return req +} + +// Body fills an msgpack.Encoder with the update request body. +func (req *UpdateRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { + spaceEnc, err := newSpaceEncoder(res, req.space) if err != nil { - return + return err } - length := uint32(len(header) + len(body)) - packetLength, err = msgpack.Marshal(length) + indexEnc, err := newIndexEncoder(res, req.index, spaceEnc.Id) if err != nil { - return + return err + } + + if err := enc.EncodeMapLen(4); err != nil { + return err + } + + if err := fillSearch(enc, spaceEnc, indexEnc, req.key); err != nil { + return err } - packet = append(packet, packetLength...) - packet = append(packet, header...) - packet = append(packet, body...) - return + if err := enc.EncodeUint(uint64(iproto.IPROTO_TUPLE)); err != nil { + return err + } + + if req.ops == nil { + return enc.EncodeArrayLen(0) + } else { + return enc.Encode(req.ops) + } +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *UpdateRequest) Context(ctx context.Context) *UpdateRequest { + req.ctx = ctx + return req +} + +// UpsertRequest helps you to create an upsert request object for execution +// by a Connection. +type UpsertRequest struct { + spaceRequest + tuple interface{} + ops *Operations +} + +// NewUpsertRequest returns a new empty UpsertRequest. +func NewUpsertRequest(space interface{}) *UpsertRequest { + req := new(UpsertRequest) + req.rtype = iproto.IPROTO_UPSERT + req.setSpace(space) + req.tuple = []interface{}{} + return req +} + +// Tuple sets the tuple for insertion or update by the upsert request. +// Note: default value is empty. +func (req *UpsertRequest) Tuple(tuple interface{}) *UpsertRequest { + req.tuple = tuple + return req +} + +// Operations sets operations to be performed on update case by the upsert request. +// Note: default value is empty. +func (req *UpsertRequest) Operations(ops *Operations) *UpsertRequest { + req.ops = ops + return req +} + +// Body fills an msgpack.Encoder with the upsert request body. +func (req *UpsertRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { + spaceEnc, err := newSpaceEncoder(res, req.space) + if err != nil { + return err + } + + if err := enc.EncodeMapLen(3); err != nil { + return err + } + + if err := spaceEnc.Encode(enc); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_TUPLE)); err != nil { + return err + } + + if err := enc.Encode(req.tuple); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_OPS)); err != nil { + return err + } + + if req.ops == nil { + return enc.EncodeArrayLen(0) + } else { + return enc.Encode(req.ops) + } +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *UpsertRequest) Context(ctx context.Context) *UpsertRequest { + req.ctx = ctx + return req +} + +// CallRequest helps you to create a call request object for execution +// by a Connection. +type CallRequest struct { + baseRequest + function string + args interface{} +} + +// NewCallRequest returns a new empty CallRequest. It uses request code for +// Tarantool >= 1.7. +func NewCallRequest(function string) *CallRequest { + req := new(CallRequest) + req.rtype = iproto.IPROTO_CALL + req.function = function + return req +} + +// Args sets the args for the call request. +// Note: default value is empty. +func (req *CallRequest) Args(args interface{}) *CallRequest { + req.args = args + return req +} + +// Body fills an encoder with the call request body. +func (req *CallRequest) Body(_ SchemaResolver, enc *msgpack.Encoder) error { + if err := enc.EncodeMapLen(2); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_FUNCTION_NAME)); err != nil { + return err + } + + if err := enc.EncodeString(req.function); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_TUPLE)); err != nil { + return err + } + + if req.args == nil { + return enc.EncodeArrayLen(0) + } else { + return enc.Encode(req.args) + } +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *CallRequest) Context(ctx context.Context) *CallRequest { + req.ctx = ctx + return req +} + +// NewCall16Request returns a new empty Call16Request. It uses request code for +// Tarantool 1.6. +// Deprecated since Tarantool 1.7.2. +func NewCall16Request(function string) *CallRequest { + req := NewCallRequest(function) + req.rtype = iproto.IPROTO_CALL_16 + return req +} + +// NewCall17Request returns a new empty CallRequest. It uses request code for +// Tarantool >= 1.7. +func NewCall17Request(function string) *CallRequest { + req := NewCallRequest(function) + req.rtype = iproto.IPROTO_CALL + return req +} + +// EvalRequest helps you to create an eval request object for execution +// by a Connection. +type EvalRequest struct { + baseRequest + expr string + args interface{} +} + +// NewEvalRequest returns a new empty EvalRequest. +func NewEvalRequest(expr string) *EvalRequest { + req := new(EvalRequest) + req.rtype = iproto.IPROTO_EVAL + req.expr = expr + req.args = []interface{}{} + return req +} + +// Args sets the args for the eval request. +// Note: default value is empty. +func (req *EvalRequest) Args(args interface{}) *EvalRequest { + req.args = args + return req +} + +// Body fills an msgpack.Encoder with the eval request body. +func (req *EvalRequest) Body(_ SchemaResolver, enc *msgpack.Encoder) error { + if err := enc.EncodeMapLen(2); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_EXPR)); err != nil { + return err + } + + if err := enc.EncodeString(req.expr); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_TUPLE)); err != nil { + return err + } + + if req.args == nil { + return enc.EncodeArrayLen(0) + } else { + return enc.Encode(req.args) + } +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *EvalRequest) Context(ctx context.Context) *EvalRequest { + req.ctx = ctx + return req +} + +// ExecuteRequest helps you to create an execute request object for execution +// by a Connection. +type ExecuteRequest struct { + baseRequest + expr string + args interface{} +} + +// NewExecuteRequest returns a new empty ExecuteRequest. +func NewExecuteRequest(expr string) *ExecuteRequest { + req := new(ExecuteRequest) + req.rtype = iproto.IPROTO_EXECUTE + req.expr = expr + req.args = []interface{}{} + return req +} + +// Args sets the args for the execute request. +// Note: default value is empty. +func (req *ExecuteRequest) Args(args interface{}) *ExecuteRequest { + req.args = args + return req +} + +// Body fills an msgpack.Encoder with the execute request body. +func (req *ExecuteRequest) Body(_ SchemaResolver, enc *msgpack.Encoder) error { + if err := enc.EncodeMapLen(2); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_SQL_TEXT)); err != nil { + return err + } + + if err := enc.EncodeString(req.expr); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_SQL_BIND)); err != nil { + return err + } + + return encodeSQLBind(enc, req.args) +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *ExecuteRequest) Context(ctx context.Context) *ExecuteRequest { + req.ctx = ctx + return req +} + +// Response creates a response for the ExecuteRequest. +func (req *ExecuteRequest) Response(header Header, body io.Reader) (Response, error) { + baseResp, err := createBaseResponse(header, body) + if err != nil { + return nil, err + } + return &ExecuteResponse{baseResponse: baseResp}, nil +} + +// WatchOnceRequest synchronously fetches the value currently associated with a +// specified notification key without subscribing to changes. +type WatchOnceRequest struct { + baseRequest + key string +} + +// NewWatchOnceRequest returns a new watchOnceRequest. +func NewWatchOnceRequest(key string) *WatchOnceRequest { + req := new(WatchOnceRequest) + req.rtype = iproto.IPROTO_WATCH_ONCE + req.key = key + return req +} + +// Body fills an msgpack.Encoder with the watchOnce request body. +func (req *WatchOnceRequest) Body(_ SchemaResolver, enc *msgpack.Encoder) error { + if err := enc.EncodeMapLen(1); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_EVENT_KEY)); err != nil { + return err + } + + return enc.EncodeString(req.key) +} + +// Context sets a passed context to the request. +func (req *WatchOnceRequest) Context(ctx context.Context) *WatchOnceRequest { + req.ctx = ctx + return req } diff --git a/request_test.go b/request_test.go new file mode 100644 index 000000000..fb4290299 --- /dev/null +++ b/request_test.go @@ -0,0 +1,469 @@ +package tarantool_test + +import ( + "bytes" + "context" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" + + . "github.com/tarantool/go-tarantool/v3" +) + +const invalidSpaceMsg = "invalid space" +const invalidIndexMsg = "invalid index" + +const invalidSpace uint32 = 2 +const invalidIndex uint32 = 2 +const validSpace uint32 = 1 // Any valid value != default. +const validIndex uint32 = 3 // Any valid value != default. +const validExpr = "any string" // We don't check the value here. +const validKey = "foo" // Any string. +const defaultSpace uint32 = 0 // And valid too. +const defaultIndex uint32 = 0 // And valid too. + +var ( + validStmt = &Prepared{StatementID: 1, Conn: &Connection{}} + validProtocolInfo = ProtocolInfo{ + Version: ProtocolVersion(3), + Features: []iproto.Feature{iproto.IPROTO_FEATURE_STREAMS}, + } +) + +type ValidSchemeResolver struct { + nameUseSupported bool + spaceResolverCalls int + indexResolverCalls int +} + +func (r *ValidSchemeResolver) ResolveSpace(s interface{}) (uint32, error) { + r.spaceResolverCalls++ + + var spaceNo uint32 + if no, ok := s.(uint32); ok { + spaceNo = no + } else { + spaceNo = defaultSpace + } + if spaceNo == invalidSpace { + return 0, errors.New(invalidSpaceMsg) + } + return spaceNo, nil +} + +func (r *ValidSchemeResolver) ResolveIndex(i interface{}, spaceNo uint32) (uint32, error) { + r.indexResolverCalls++ + + var indexNo uint32 + if no, ok := i.(uint32); ok { + indexNo = no + } else { + indexNo = defaultIndex + } + if indexNo == invalidIndex { + return 0, errors.New(invalidIndexMsg) + } + return indexNo, nil +} + +func (r *ValidSchemeResolver) NamesUseSupported() bool { + return r.nameUseSupported +} + +var resolver ValidSchemeResolver + +func assertBodyCall(t testing.TB, requests []Request, errorMsg string) { + t.Helper() + + const errBegin = "An unexpected Request.Body() " + for _, req := range requests { + var reqBuf bytes.Buffer + enc := msgpack.NewEncoder(&reqBuf) + + err := req.Body(&resolver, enc) + if err != nil && errorMsg != "" && err.Error() != errorMsg { + t.Errorf(errBegin+"error %q expected %q", err.Error(), errorMsg) + } + if err != nil && errorMsg == "" { + t.Errorf(errBegin+"error %q", err.Error()) + } + if err == nil && errorMsg != "" { + t.Errorf(errBegin+"result, expected error %q", errorMsg) + } + } +} + +func TestRequestsValidSpaceAndIndex(t *testing.T) { + requests := []Request{ + NewSelectRequest(validSpace), + NewSelectRequest(validSpace).Index(validIndex), + NewUpdateRequest(validSpace), + NewUpdateRequest(validSpace).Index(validIndex), + NewUpsertRequest(validSpace), + NewInsertRequest(validSpace), + NewReplaceRequest(validSpace), + NewDeleteRequest(validSpace), + NewDeleteRequest(validSpace).Index(validIndex), + } + + assertBodyCall(t, requests, "") +} + +func TestRequestsInvalidSpace(t *testing.T) { + requests := []Request{ + NewSelectRequest(invalidSpace).Index(validIndex), + NewSelectRequest(invalidSpace), + NewUpdateRequest(invalidSpace).Index(validIndex), + NewUpdateRequest(invalidSpace), + NewUpsertRequest(invalidSpace), + NewInsertRequest(invalidSpace), + NewReplaceRequest(invalidSpace), + NewDeleteRequest(invalidSpace).Index(validIndex), + NewDeleteRequest(invalidSpace), + } + + assertBodyCall(t, requests, invalidSpaceMsg) +} + +func TestRequestsInvalidIndex(t *testing.T) { + requests := []Request{ + NewSelectRequest(validSpace).Index(invalidIndex), + NewUpdateRequest(validSpace).Index(invalidIndex), + NewDeleteRequest(validSpace).Index(invalidIndex), + } + + assertBodyCall(t, requests, invalidIndexMsg) +} + +func TestRequestsTypes(t *testing.T) { + tests := []struct { + req Request + rtype iproto.Type + }{ + {req: NewSelectRequest(validSpace), rtype: iproto.IPROTO_SELECT}, + {req: NewUpdateRequest(validSpace), rtype: iproto.IPROTO_UPDATE}, + {req: NewUpsertRequest(validSpace), rtype: iproto.IPROTO_UPSERT}, + {req: NewInsertRequest(validSpace), rtype: iproto.IPROTO_INSERT}, + {req: NewReplaceRequest(validSpace), rtype: iproto.IPROTO_REPLACE}, + {req: NewDeleteRequest(validSpace), rtype: iproto.IPROTO_DELETE}, + {req: NewCallRequest(validExpr), rtype: iproto.IPROTO_CALL}, + {req: NewCall16Request(validExpr), rtype: iproto.IPROTO_CALL_16}, + {req: NewCall17Request(validExpr), rtype: iproto.IPROTO_CALL}, + {req: NewEvalRequest(validExpr), rtype: iproto.IPROTO_EVAL}, + {req: NewExecuteRequest(validExpr), rtype: iproto.IPROTO_EXECUTE}, + {req: NewPingRequest(), rtype: iproto.IPROTO_PING}, + {req: NewPrepareRequest(validExpr), rtype: iproto.IPROTO_PREPARE}, + {req: NewUnprepareRequest(validStmt), rtype: iproto.IPROTO_PREPARE}, + {req: NewExecutePreparedRequest(validStmt), rtype: iproto.IPROTO_EXECUTE}, + {req: NewBeginRequest(), rtype: iproto.IPROTO_BEGIN}, + {req: NewCommitRequest(), rtype: iproto.IPROTO_COMMIT}, + {req: NewRollbackRequest(), rtype: iproto.IPROTO_ROLLBACK}, + {req: NewIdRequest(validProtocolInfo), rtype: iproto.IPROTO_ID}, + {req: NewBroadcastRequest(validKey), rtype: iproto.IPROTO_CALL}, + {req: NewWatchOnceRequest(validKey), rtype: iproto.IPROTO_WATCH_ONCE}, + } + + for _, test := range tests { + if rtype := test.req.Type(); rtype != test.rtype { + t.Errorf("An invalid request type 0x%x, expected 0x%x", + rtype, test.rtype) + } + } +} + +func TestRequestsAsync(t *testing.T) { + tests := []struct { + req Request + async bool + }{ + {req: NewSelectRequest(validSpace), async: false}, + {req: NewUpdateRequest(validSpace), async: false}, + {req: NewUpsertRequest(validSpace), async: false}, + {req: NewInsertRequest(validSpace), async: false}, + {req: NewReplaceRequest(validSpace), async: false}, + {req: NewDeleteRequest(validSpace), async: false}, + {req: NewCallRequest(validExpr), async: false}, + {req: NewCall16Request(validExpr), async: false}, + {req: NewCall17Request(validExpr), async: false}, + {req: NewEvalRequest(validExpr), async: false}, + {req: NewExecuteRequest(validExpr), async: false}, + {req: NewPingRequest(), async: false}, + {req: NewPrepareRequest(validExpr), async: false}, + {req: NewUnprepareRequest(validStmt), async: false}, + {req: NewExecutePreparedRequest(validStmt), async: false}, + {req: NewBeginRequest(), async: false}, + {req: NewCommitRequest(), async: false}, + {req: NewRollbackRequest(), async: false}, + {req: NewIdRequest(validProtocolInfo), async: false}, + {req: NewBroadcastRequest(validKey), async: false}, + {req: NewWatchOnceRequest(validKey), async: false}, + } + + for _, test := range tests { + if async := test.req.Async(); async != test.async { + t.Errorf("An invalid async %t, expected %t", async, test.async) + } + } +} + +func TestRequestsCtx_default(t *testing.T) { + tests := []struct { + req Request + expected context.Context + }{ + {req: NewSelectRequest(validSpace), expected: nil}, + {req: NewUpdateRequest(validSpace), expected: nil}, + {req: NewUpsertRequest(validSpace), expected: nil}, + {req: NewInsertRequest(validSpace), expected: nil}, + {req: NewReplaceRequest(validSpace), expected: nil}, + {req: NewDeleteRequest(validSpace), expected: nil}, + {req: NewCallRequest(validExpr), expected: nil}, + {req: NewCall16Request(validExpr), expected: nil}, + {req: NewCall17Request(validExpr), expected: nil}, + {req: NewEvalRequest(validExpr), expected: nil}, + {req: NewExecuteRequest(validExpr), expected: nil}, + {req: NewPingRequest(), expected: nil}, + {req: NewPrepareRequest(validExpr), expected: nil}, + {req: NewUnprepareRequest(validStmt), expected: nil}, + {req: NewExecutePreparedRequest(validStmt), expected: nil}, + {req: NewBeginRequest(), expected: nil}, + {req: NewCommitRequest(), expected: nil}, + {req: NewRollbackRequest(), expected: nil}, + {req: NewIdRequest(validProtocolInfo), expected: nil}, + {req: NewBroadcastRequest(validKey), expected: nil}, + {req: NewWatchOnceRequest(validKey), expected: nil}, + } + + for _, test := range tests { + if ctx := test.req.Ctx(); ctx != test.expected { + t.Errorf("An invalid ctx %t, expected %t", ctx, test.expected) + } + } +} + +func TestRequestsCtx_setter(t *testing.T) { + ctx := context.Background() + tests := []struct { + req Request + expected context.Context + }{ + {req: NewSelectRequest(validSpace).Context(ctx), expected: ctx}, + {req: NewUpdateRequest(validSpace).Context(ctx), expected: ctx}, + {req: NewUpsertRequest(validSpace).Context(ctx), expected: ctx}, + {req: NewInsertRequest(validSpace).Context(ctx), expected: ctx}, + {req: NewReplaceRequest(validSpace).Context(ctx), expected: ctx}, + {req: NewDeleteRequest(validSpace).Context(ctx), expected: ctx}, + {req: NewCallRequest(validExpr).Context(ctx), expected: ctx}, + {req: NewCall16Request(validExpr).Context(ctx), expected: ctx}, + {req: NewCall17Request(validExpr).Context(ctx), expected: ctx}, + {req: NewEvalRequest(validExpr).Context(ctx), expected: ctx}, + {req: NewExecuteRequest(validExpr).Context(ctx), expected: ctx}, + {req: NewPingRequest().Context(ctx), expected: ctx}, + {req: NewPrepareRequest(validExpr).Context(ctx), expected: ctx}, + {req: NewUnprepareRequest(validStmt).Context(ctx), expected: ctx}, + {req: NewExecutePreparedRequest(validStmt).Context(ctx), expected: ctx}, + {req: NewBeginRequest().Context(ctx), expected: ctx}, + {req: NewCommitRequest().Context(ctx), expected: ctx}, + {req: NewRollbackRequest().Context(ctx), expected: ctx}, + {req: NewIdRequest(validProtocolInfo).Context(ctx), expected: ctx}, + {req: NewBroadcastRequest(validKey).Context(ctx), expected: ctx}, + {req: NewWatchOnceRequest(validKey).Context(ctx), expected: ctx}, + } + + for _, test := range tests { + if ctx := test.req.Ctx(); ctx != test.expected { + t.Errorf("An invalid ctx %t, expected %t", ctx, test.expected) + } + } +} + +func TestResponseDecode(t *testing.T) { + header := Header{} + data := bytes.NewBuffer([]byte{'v', '2'}) + baseExample, err := NewPingRequest().Response(header, data) + assert.NoError(t, err) + + tests := []struct { + req Request + expected Response + }{ + {req: NewSelectRequest(validSpace), expected: &SelectResponse{}}, + {req: NewUpdateRequest(validSpace), expected: baseExample}, + {req: NewUpsertRequest(validSpace), expected: baseExample}, + {req: NewInsertRequest(validSpace), expected: baseExample}, + {req: NewReplaceRequest(validSpace), expected: baseExample}, + {req: NewDeleteRequest(validSpace), expected: baseExample}, + {req: NewCallRequest(validExpr), expected: baseExample}, + {req: NewCall16Request(validExpr), expected: baseExample}, + {req: NewCall17Request(validExpr), expected: baseExample}, + {req: NewEvalRequest(validExpr), expected: baseExample}, + {req: NewExecuteRequest(validExpr), expected: &ExecuteResponse{}}, + {req: NewPingRequest(), expected: baseExample}, + {req: NewPrepareRequest(validExpr), expected: &PrepareResponse{}}, + {req: NewUnprepareRequest(validStmt), expected: baseExample}, + {req: NewExecutePreparedRequest(validStmt), expected: &ExecuteResponse{}}, + {req: NewBeginRequest(), expected: baseExample}, + {req: NewCommitRequest(), expected: baseExample}, + {req: NewRollbackRequest(), expected: baseExample}, + {req: NewIdRequest(validProtocolInfo), expected: baseExample}, + {req: NewBroadcastRequest(validKey), expected: baseExample}, + {req: NewWatchOnceRequest(validKey), expected: baseExample}, + } + + for _, test := range tests { + buf := bytes.NewBuffer([]byte{}) + enc := msgpack.NewEncoder(buf) + + enc.EncodeMapLen(1) + enc.EncodeUint8(uint8(iproto.IPROTO_DATA)) + enc.Encode([]interface{}{'v', '2'}) + + resp, err := test.req.Response(header, bytes.NewBuffer(buf.Bytes())) + assert.NoError(t, err) + assert.True(t, fmt.Sprintf("%T", resp) == + fmt.Sprintf("%T", test.expected)) + assert.Equal(t, header, resp.Header()) + + decodedInterface, err := resp.Decode() + assert.NoError(t, err) + assert.Equal(t, []interface{}{'v', '2'}, decodedInterface) + } +} + +func TestResponseDecodeTyped(t *testing.T) { + header := Header{} + data := bytes.NewBuffer([]byte{'v', '2'}) + baseExample, err := NewPingRequest().Response(header, data) + assert.NoError(t, err) + + tests := []struct { + req Request + expected Response + }{ + {req: NewSelectRequest(validSpace), expected: &SelectResponse{}}, + {req: NewUpdateRequest(validSpace), expected: baseExample}, + {req: NewUpsertRequest(validSpace), expected: baseExample}, + {req: NewInsertRequest(validSpace), expected: baseExample}, + {req: NewReplaceRequest(validSpace), expected: baseExample}, + {req: NewDeleteRequest(validSpace), expected: baseExample}, + {req: NewCallRequest(validExpr), expected: baseExample}, + {req: NewCall16Request(validExpr), expected: baseExample}, + {req: NewCall17Request(validExpr), expected: baseExample}, + {req: NewEvalRequest(validExpr), expected: baseExample}, + {req: NewExecuteRequest(validExpr), expected: &ExecuteResponse{}}, + {req: NewPingRequest(), expected: baseExample}, + {req: NewPrepareRequest(validExpr), expected: &PrepareResponse{}}, + {req: NewUnprepareRequest(validStmt), expected: baseExample}, + {req: NewExecutePreparedRequest(validStmt), expected: &ExecuteResponse{}}, + {req: NewBeginRequest(), expected: baseExample}, + {req: NewCommitRequest(), expected: baseExample}, + {req: NewRollbackRequest(), expected: baseExample}, + {req: NewIdRequest(validProtocolInfo), expected: baseExample}, + {req: NewBroadcastRequest(validKey), expected: baseExample}, + {req: NewWatchOnceRequest(validKey), expected: baseExample}, + } + + for _, test := range tests { + buf := bytes.NewBuffer([]byte{}) + enc := msgpack.NewEncoder(buf) + + enc.EncodeMapLen(1) + enc.EncodeUint8(uint8(iproto.IPROTO_DATA)) + enc.EncodeBytes([]byte{'v', '2'}) + + resp, err := test.req.Response(header, bytes.NewBuffer(buf.Bytes())) + assert.NoError(t, err) + assert.True(t, fmt.Sprintf("%T", resp) == + fmt.Sprintf("%T", test.expected)) + assert.Equal(t, header, resp.Header()) + + var decoded []byte + err = resp.DecodeTyped(&decoded) + assert.NoError(t, err) + assert.Equal(t, []byte{'v', '2'}, decoded) + } +} + +type stubSchemeResolver struct { + space interface{} +} + +func (r stubSchemeResolver) ResolveSpace(s interface{}) (uint32, error) { + if id, ok := r.space.(uint32); ok { + return id, nil + } + if _, ok := r.space.(string); ok { + return 0, nil + } + return 0, fmt.Errorf("stub error message: %v", r.space) +} + +func (stubSchemeResolver) ResolveIndex(i interface{}, spaceNo uint32) (uint32, error) { + return 0, nil +} + +func (r stubSchemeResolver) NamesUseSupported() bool { + _, ok := r.space.(string) + return ok +} + +func TestEncodeSpace(t *testing.T) { + tests := []struct { + name string + res stubSchemeResolver + err string + out []byte + }{ + { + name: "string space", + res: stubSchemeResolver{"test"}, + out: []byte{0x5E, 0xA4, 0x74, 0x65, 0x73, 0x74}, + }, + { + name: "empty string", + res: stubSchemeResolver{""}, + out: []byte{0x5E, 0xA0}, + }, + { + name: "numeric 524", + res: stubSchemeResolver{uint32(524)}, + out: []byte{0x10, 0xCD, 0x02, 0x0C}, + }, + { + name: "numeric zero", + res: stubSchemeResolver{uint32(0)}, + out: []byte{0x10, 0x00}, + }, + { + name: "numeric max value", + res: stubSchemeResolver{^uint32(0)}, + out: []byte{0x10, 0xCE, 0xFF, 0xFF, 0xFF, 0xFF}, + }, + { + name: "resolve error", + res: stubSchemeResolver{false}, + err: "stub error message", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + + err := EncodeSpace(tt.res, enc, tt.res.space) + if tt.err != "" { + require.ErrorContains(t, err, tt.err) + return + } else { + require.NoError(t, err) + } + + require.Equal(t, tt.out, buf.Bytes()) + }) + } +} diff --git a/response.go b/response.go index 1c19b4f57..36aad66a0 100644 --- a/response.go +++ b/response.go @@ -1,34 +1,693 @@ package tarantool -import( - "github.com/vmihailenco/msgpack" +import ( + "fmt" + "io" + + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" ) -type Response struct { - RequestId uint32 - Code uint32 - Error string - Data []interface{} +// Response is an interface with operations for the server responses. +type Response interface { + // Header returns a response header. + Header() Header + // Decode decodes a response. + Decode() ([]interface{}, error) + // DecodeTyped decodes a response into a given container res. + DecodeTyped(res interface{}) error +} + +type baseResponse struct { + // header is a response header. + header Header + // data contains deserialized data for untyped requests. + data []interface{} + buf smallBuf + // Was the Decode() func called for this response. + decoded bool + // Was the DecodeTyped() func called for this response. + decodedTyped bool + err error +} + +func createBaseResponse(header Header, body io.Reader) (baseResponse, error) { + if body == nil { + return baseResponse{header: header}, nil + } + if buf, ok := body.(*smallBuf); ok { + return baseResponse{header: header, buf: *buf}, nil + } + data, err := io.ReadAll(body) + if err != nil { + return baseResponse{}, err + } + return baseResponse{header: header, buf: smallBuf{b: data}}, nil +} + +// DecodeBaseResponse parse response header and body. +func DecodeBaseResponse(header Header, body io.Reader) (Response, error) { + resp, err := createBaseResponse(header, body) + return &resp, err +} + +// SelectResponse is used for the select requests. +// It might contain a position descriptor of the last selected tuple. +// +// You need to cast to SelectResponse a response from SelectRequest. +type SelectResponse struct { + baseResponse + // pos contains a position descriptor of last selected tuple. + pos []byte +} + +// PrepareResponse is used for the prepare requests. +// It might contain meta-data and sql info. +// +// Be careful: now this is an alias for `ExecuteResponse`, +// but it could be changed in the future. +// You need to cast to PrepareResponse a response from PrepareRequest. +type PrepareResponse ExecuteResponse + +// ExecuteResponse is used for the execute requests. +// It might contain meta-data and sql info. +// +// You need to cast to ExecuteResponse a response from ExecuteRequest. +type ExecuteResponse struct { + baseResponse + metaData []ColumnMetaData + sqlInfo SQLInfo +} + +type ColumnMetaData struct { + FieldName string + FieldType string + FieldCollation string + FieldIsNullable bool + FieldIsAutoincrement bool + FieldSpan string +} + +type SQLInfo struct { + AffectedCount uint64 + InfoAutoincrementIds []uint64 +} + +func (meta *ColumnMetaData) DecodeMsgpack(d *msgpack.Decoder) error { + var err error + var l int + if l, err = d.DecodeMapLen(); err != nil { + return err + } + if l == 0 { + return fmt.Errorf("map len doesn't match: %d", l) + } + for i := 0; i < l; i++ { + var mk uint64 + var mv interface{} + if mk, err = d.DecodeUint64(); err != nil { + return fmt.Errorf("failed to decode meta data") + } + if mv, err = d.DecodeInterface(); err != nil { + return fmt.Errorf("failed to decode meta data") + } + switch iproto.MetadataKey(mk) { + case iproto.IPROTO_FIELD_NAME: + meta.FieldName = mv.(string) + case iproto.IPROTO_FIELD_TYPE: + meta.FieldType = mv.(string) + case iproto.IPROTO_FIELD_COLL: + meta.FieldCollation = mv.(string) + case iproto.IPROTO_FIELD_IS_NULLABLE: + meta.FieldIsNullable = mv.(bool) + case iproto.IPROTO_FIELD_IS_AUTOINCREMENT: + meta.FieldIsAutoincrement = mv.(bool) + case iproto.IPROTO_FIELD_SPAN: + meta.FieldSpan = mv.(string) + default: + return fmt.Errorf("failed to decode meta data") + } + } + return nil +} + +func (info *SQLInfo) DecodeMsgpack(d *msgpack.Decoder) error { + var err error + var l int + if l, err = d.DecodeMapLen(); err != nil { + return err + } + if l == 0 { + return fmt.Errorf("map len doesn't match") + } + for i := 0; i < l; i++ { + var mk uint64 + if mk, err = d.DecodeUint64(); err != nil { + return fmt.Errorf("failed to decode meta data") + } + switch iproto.SqlInfoKey(mk) { + case iproto.SQL_INFO_ROW_COUNT: + if info.AffectedCount, err = d.DecodeUint64(); err != nil { + return fmt.Errorf("failed to decode meta data") + } + case iproto.SQL_INFO_AUTOINCREMENT_IDS: + if err = d.Decode(&info.InfoAutoincrementIds); err != nil { + return fmt.Errorf("failed to decode meta data") + } + default: + return fmt.Errorf("failed to decode meta data") + } + } + return nil +} + +func smallInt(d *msgpack.Decoder, buf *smallBuf) (i int, err error) { + b, err := buf.ReadByte() + if err != nil { + return + } + if b <= 127 { + return int(b), nil + } + buf.UnreadByte() + return d.DecodeInt() +} + +func decodeHeader(d *msgpack.Decoder, buf *smallBuf) (Header, iproto.Type, error) { + var l int + var code int + var err error + d.Reset(buf) + if l, err = d.DecodeMapLen(); err != nil { + return Header{}, 0, err + } + decodedHeader := Header{Error: ErrorNo} + for ; l > 0; l-- { + var cd int + if cd, err = smallInt(d, buf); err != nil { + return Header{}, 0, err + } + switch iproto.Key(cd) { + case iproto.IPROTO_SYNC: + var rid uint64 + if rid, err = d.DecodeUint64(); err != nil { + return Header{}, 0, err + } + decodedHeader.RequestId = uint32(rid) + case iproto.IPROTO_REQUEST_TYPE: + if code, err = d.DecodeInt(); err != nil { + return Header{}, 0, err + } + if code&int(iproto.IPROTO_TYPE_ERROR) != 0 { + decodedHeader.Error = iproto.Error(code &^ int(iproto.IPROTO_TYPE_ERROR)) + } else { + decodedHeader.Error = ErrorNo + } + default: + if err = d.Skip(); err != nil { + return Header{}, 0, err + } + } + } + return decodedHeader, iproto.Type(code), nil +} + +type decodeInfo struct { + stmtID uint64 + bindCount uint64 + serverProtocolInfo ProtocolInfo + errorExtendedInfo *BoxError + + decodedError string +} + +func (info *decodeInfo) parseData(resp *baseResponse) error { + if info.stmtID != 0 { + stmt := &Prepared{ + StatementID: PreparedID(info.stmtID), + ParamCount: info.bindCount, + } + resp.data = []interface{}{stmt} + return nil + } + + // Tarantool may send only version >= 1. + if info.serverProtocolInfo.Version != ProtocolVersion(0) || + info.serverProtocolInfo.Features != nil { + if info.serverProtocolInfo.Version == ProtocolVersion(0) { + return fmt.Errorf("no protocol version provided in Id response") + } + if info.serverProtocolInfo.Features == nil { + return fmt.Errorf("no features provided in Id response") + } + resp.data = []interface{}{info.serverProtocolInfo} + return nil + } + return nil +} + +func decodeCommonField(d *msgpack.Decoder, cd int, data *[]interface{}, + info *decodeInfo) (bool, error) { + var feature iproto.Feature + var err error + + switch iproto.Key(cd) { + case iproto.IPROTO_DATA: + var res interface{} + var ok bool + if res, err = d.DecodeInterface(); err != nil { + return false, err + } + if *data, ok = res.([]interface{}); !ok { + return false, fmt.Errorf("result is not array: %v", res) + } + case iproto.IPROTO_ERROR: + if info.errorExtendedInfo, err = decodeBoxError(d); err != nil { + return false, err + } + case iproto.IPROTO_ERROR_24: + if info.decodedError, err = d.DecodeString(); err != nil { + return false, err + } + case iproto.IPROTO_STMT_ID: + if info.stmtID, err = d.DecodeUint64(); err != nil { + return false, err + } + case iproto.IPROTO_BIND_COUNT: + if info.bindCount, err = d.DecodeUint64(); err != nil { + return false, err + } + case iproto.IPROTO_VERSION: + if err = d.Decode(&info.serverProtocolInfo.Version); err != nil { + return false, err + } + case iproto.IPROTO_FEATURES: + var larr int + if larr, err = d.DecodeArrayLen(); err != nil { + return false, err + } + + info.serverProtocolInfo.Features = make([]iproto.Feature, larr) + for i := 0; i < larr; i++ { + if err = d.Decode(&feature); err != nil { + return false, err + } + info.serverProtocolInfo.Features[i] = feature + } + case iproto.IPROTO_AUTH_TYPE: + var auth string + if auth, err = d.DecodeString(); err != nil { + return false, err + } + found := false + for _, a := range [...]Auth{ChapSha1Auth, PapSha256Auth} { + if auth == a.String() { + info.serverProtocolInfo.Auth = a + found = true + } + } + if !found { + return false, fmt.Errorf("unknown auth type %s", auth) + } + default: + return false, nil + } + return true, nil +} + +func (resp *baseResponse) Decode() ([]interface{}, error) { + if resp.decoded { + return resp.data, resp.err + } + + resp.decoded = true + var err error + if resp.buf.Len() > 2 { + offset := resp.buf.Offset() + defer resp.buf.Seek(offset) + + var l int + info := &decodeInfo{} + + d := getDecoder(&resp.buf) + defer putDecoder(d) + + if l, err = d.DecodeMapLen(); err != nil { + resp.err = err + return nil, resp.err + } + for ; l > 0; l-- { + var cd int + if cd, err = smallInt(d, &resp.buf); err != nil { + resp.err = err + return nil, resp.err + } + decoded, err := decodeCommonField(d, cd, &resp.data, info) + if err != nil { + resp.err = err + return nil, resp.err + } + if !decoded { + if err = d.Skip(); err != nil { + resp.err = err + return nil, resp.err + } + } + } + err = info.parseData(resp) + if err != nil { + resp.err = err + return nil, resp.err + } + + if info.decodedError != "" { + resp.err = Error{resp.header.Error, info.decodedError, + info.errorExtendedInfo} + } + } + return resp.data, resp.err } -func NewResponse(bytes []byte) (resp *Response) { - var header, body map[int32]interface{} - resp = &Response{} +func (resp *SelectResponse) Decode() ([]interface{}, error) { + if resp.decoded { + return resp.data, resp.err + } + + resp.decoded = true + var err error + if resp.buf.Len() > 2 { + offset := resp.buf.Offset() + defer resp.buf.Seek(offset) - msgpack.Unmarshal(bytes, &header, &body) - resp.RequestId = uint32(header[KeySync].(uint64)) - resp.Code = uint32(header[KeyCode].(uint64)) - if body[KeyData] != nil { - data := body[KeyData].([]interface{}) - resp.Data = make([]interface{}, len(data)) - for i, v := range(data) { - resp.Data[i] = v.([]interface{}) + var l int + info := &decodeInfo{} + + d := getDecoder(&resp.buf) + defer putDecoder(d) + + if l, err = d.DecodeMapLen(); err != nil { + resp.err = err + return nil, resp.err + } + for ; l > 0; l-- { + var cd int + if cd, err = smallInt(d, &resp.buf); err != nil { + resp.err = err + return nil, resp.err + } + decoded, err := decodeCommonField(d, cd, &resp.data, info) + if err != nil { + resp.err = err + return nil, err + } + if !decoded { + switch iproto.Key(cd) { + case iproto.IPROTO_POSITION: + if resp.pos, err = d.DecodeBytes(); err != nil { + resp.err = err + return nil, fmt.Errorf("unable to decode a position: %w", resp.err) + } + default: + if err = d.Skip(); err != nil { + resp.err = err + return nil, resp.err + } + } + } + } + err = info.parseData(&resp.baseResponse) + if err != nil { + resp.err = err + return nil, resp.err + } + + if info.decodedError != "" { + resp.err = Error{resp.header.Error, info.decodedError, + info.errorExtendedInfo} } } + return resp.data, resp.err +} - if resp.Code != OkCode { - resp.Error = body[KeyError].(string) +func (resp *ExecuteResponse) Decode() ([]interface{}, error) { + if resp.decoded { + return resp.data, resp.err } - return + resp.decoded = true + var err error + if resp.buf.Len() > 2 { + offset := resp.buf.Offset() + defer resp.buf.Seek(offset) + + var l int + info := &decodeInfo{} + + d := getDecoder(&resp.buf) + defer putDecoder(d) + + if l, err = d.DecodeMapLen(); err != nil { + resp.err = err + return nil, resp.err + } + for ; l > 0; l-- { + var cd int + if cd, err = smallInt(d, &resp.buf); err != nil { + resp.err = err + return nil, resp.err + } + decoded, err := decodeCommonField(d, cd, &resp.data, info) + if err != nil { + resp.err = err + return nil, resp.err + } + if !decoded { + switch iproto.Key(cd) { + case iproto.IPROTO_SQL_INFO: + if err = d.Decode(&resp.sqlInfo); err != nil { + resp.err = err + return nil, resp.err + } + case iproto.IPROTO_METADATA: + if err = d.Decode(&resp.metaData); err != nil { + resp.err = err + return nil, resp.err + } + default: + if err = d.Skip(); err != nil { + resp.err = err + return nil, resp.err + } + } + } + } + err = info.parseData(&resp.baseResponse) + if err != nil { + resp.err = err + return nil, resp.err + } + + if info.decodedError != "" { + resp.err = Error{resp.header.Error, info.decodedError, + info.errorExtendedInfo} + } + } + return resp.data, resp.err +} + +func decodeTypedCommonField(d *msgpack.Decoder, res interface{}, cd int, + info *decodeInfo) (bool, error) { + var err error + + switch iproto.Key(cd) { + case iproto.IPROTO_DATA: + if err = d.Decode(res); err != nil { + return false, err + } + case iproto.IPROTO_ERROR: + if info.errorExtendedInfo, err = decodeBoxError(d); err != nil { + return false, err + } + case iproto.IPROTO_ERROR_24: + if info.decodedError, err = d.DecodeString(); err != nil { + return false, err + } + default: + return false, nil + } + return true, nil +} + +func (resp *baseResponse) DecodeTyped(res interface{}) error { + resp.decodedTyped = true + + var err error + if resp.buf.Len() > 0 { + offset := resp.buf.Offset() + defer resp.buf.Seek(offset) + + info := &decodeInfo{} + var l int + + d := getDecoder(&resp.buf) + defer putDecoder(d) + + if l, err = d.DecodeMapLen(); err != nil { + return err + } + for ; l > 0; l-- { + var cd int + if cd, err = smallInt(d, &resp.buf); err != nil { + return err + } + decoded, err := decodeTypedCommonField(d, res, cd, info) + if err != nil { + return err + } + if !decoded { + if err = d.Skip(); err != nil { + return err + } + } + } + if info.decodedError != "" { + err = Error{resp.header.Error, info.decodedError, info.errorExtendedInfo} + } + } + return err +} + +func (resp *SelectResponse) DecodeTyped(res interface{}) error { + resp.decodedTyped = true + + var err error + if resp.buf.Len() > 0 { + offset := resp.buf.Offset() + defer resp.buf.Seek(offset) + + info := &decodeInfo{} + var l int + + d := getDecoder(&resp.buf) + defer putDecoder(d) + + if l, err = d.DecodeMapLen(); err != nil { + return err + } + for ; l > 0; l-- { + var cd int + if cd, err = smallInt(d, &resp.buf); err != nil { + return err + } + decoded, err := decodeTypedCommonField(d, res, cd, info) + if err != nil { + return err + } + if !decoded { + switch iproto.Key(cd) { + case iproto.IPROTO_POSITION: + if resp.pos, err = d.DecodeBytes(); err != nil { + return fmt.Errorf("unable to decode a position: %w", err) + } + default: + if err = d.Skip(); err != nil { + return err + } + } + } + } + if info.decodedError != "" { + err = Error{resp.header.Error, info.decodedError, info.errorExtendedInfo} + } + } + return err +} + +func (resp *ExecuteResponse) DecodeTyped(res interface{}) error { + resp.decodedTyped = true + + var err error + if resp.buf.Len() > 0 { + offset := resp.buf.Offset() + defer resp.buf.Seek(offset) + + info := &decodeInfo{} + var l int + + d := getDecoder(&resp.buf) + defer putDecoder(d) + + if l, err = d.DecodeMapLen(); err != nil { + return err + } + for ; l > 0; l-- { + var cd int + if cd, err = smallInt(d, &resp.buf); err != nil { + return err + } + decoded, err := decodeTypedCommonField(d, res, cd, info) + if err != nil { + return err + } + if !decoded { + switch iproto.Key(cd) { + case iproto.IPROTO_SQL_INFO: + if err = d.Decode(&resp.sqlInfo); err != nil { + return err + } + case iproto.IPROTO_METADATA: + if err = d.Decode(&resp.metaData); err != nil { + return err + } + default: + if err = d.Skip(); err != nil { + return err + } + } + } + } + if info.decodedError != "" { + err = Error{resp.header.Error, info.decodedError, info.errorExtendedInfo} + } + } + return err +} + +func (resp *baseResponse) Header() Header { + return resp.header +} + +// Pos returns a position descriptor of the last selected tuple for the SelectResponse. +// If the response was not decoded, this method will call Decode(). +func (resp *SelectResponse) Pos() ([]byte, error) { + if !resp.decoded && !resp.decodedTyped { + resp.Decode() + } + return resp.pos, resp.err +} + +// MetaData returns ExecuteResponse meta-data. +// If the response was not decoded, this method will call Decode(). +func (resp *ExecuteResponse) MetaData() ([]ColumnMetaData, error) { + if !resp.decoded && !resp.decodedTyped { + resp.Decode() + } + return resp.metaData, resp.err +} + +// SQLInfo returns ExecuteResponse sql info. +// If the response was not decoded, this method will call Decode(). +func (resp *ExecuteResponse) SQLInfo() (SQLInfo, error) { + if !resp.decoded && !resp.decodedTyped { + resp.Decode() + } + return resp.sqlInfo, resp.err +} + +// String implements Stringer interface. +func (resp *baseResponse) String() (str string) { + if resp.header.Error == ErrorNo { + return fmt.Sprintf("<%d OK %v>", resp.header.RequestId, resp.data) + } + return fmt.Sprintf("<%d ERR %s %v>", resp.header.RequestId, resp.header.Error, resp.err) } diff --git a/response_test.go b/response_test.go new file mode 100644 index 000000000..e58b4d47c --- /dev/null +++ b/response_test.go @@ -0,0 +1,56 @@ +package tarantool_test + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +func encodeResponseData(t *testing.T, data interface{}) io.Reader { + t.Helper() + + buf := bytes.NewBuffer([]byte{}) + enc := msgpack.NewEncoder(buf) + + enc.EncodeMapLen(1) + enc.EncodeUint8(uint8(iproto.IPROTO_DATA)) + enc.Encode([]interface{}{data}) + return buf + +} + +func TestDecodeBaseResponse(t *testing.T) { + tests := []struct { + name string + header tarantool.Header + body interface{} + }{ + { + "test1", + tarantool.Header{}, + nil, + }, + { + "test2", + tarantool.Header{RequestId: 123}, + []byte{'v', '2'}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + res, err := tarantool.DecodeBaseResponse(tt.header, encodeResponseData(t, tt.body)) + require.NoError(t, err) + require.Equal(t, tt.header, res.Header()) + + got, err := res.Decode() + require.NoError(t, err) + require.Equal(t, []interface{}{tt.body}, got) + }) + } +} diff --git a/schema.go b/schema.go new file mode 100644 index 000000000..e7c09f80e --- /dev/null +++ b/schema.go @@ -0,0 +1,568 @@ +package tarantool + +import ( + "errors" + "fmt" + + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" +) + +// nolint: varcheck,deadcode +const ( + maxSchemas = 10000 + spaceSpId = 280 + vspaceSpId = 281 + indexSpId = 288 + vindexSpId = 289 + vspaceSpTypeFieldNum = 6 + vspaceSpFormatFieldNum = 7 +) + +var ( + ErrConcurrentSchemaUpdate = errors.New("concurrent schema update") +) + +func msgpackIsUint(code byte) bool { + return code == msgpcode.Uint8 || code == msgpcode.Uint16 || + code == msgpcode.Uint32 || code == msgpcode.Uint64 || + msgpcode.IsFixedNum(code) +} + +func msgpackIsMap(code byte) bool { + return code == msgpcode.Map16 || code == msgpcode.Map32 || msgpcode.IsFixedMap(code) +} + +func msgpackIsArray(code byte) bool { + return code == msgpcode.Array16 || code == msgpcode.Array32 || + msgpcode.IsFixedArray(code) +} + +func msgpackIsString(code byte) bool { + return msgpcode.IsFixedString(code) || code == msgpcode.Str8 || + code == msgpcode.Str16 || code == msgpcode.Str32 +} + +// SchemaResolver is an interface for resolving schema details. +type SchemaResolver interface { + // ResolveSpace returns resolved space number or an error + // if it cannot be resolved. + ResolveSpace(s interface{}) (spaceNo uint32, err error) + // ResolveIndex returns resolved index number or an error + // if it cannot be resolved. + ResolveIndex(i interface{}, spaceNo uint32) (indexNo uint32, err error) + // NamesUseSupported shows if usage of space and index names, instead of + // IDs, is supported. It must return true if + // iproto.IPROTO_FEATURE_SPACE_AND_INDEX_NAMES is supported. + NamesUseSupported() bool +} + +// Schema contains information about spaces and indexes. +type Schema struct { + Version uint + // Spaces is map from space names to spaces. + Spaces map[string]Space + // SpacesById is map from space numbers to spaces. + SpacesById map[uint32]Space +} + +func (schema *Schema) copy() Schema { + schemaCopy := *schema + schemaCopy.Spaces = make(map[string]Space, len(schema.Spaces)) + for name, space := range schema.Spaces { + schemaCopy.Spaces[name] = space.copy() + } + schemaCopy.SpacesById = make(map[uint32]Space, len(schema.SpacesById)) + for id, space := range schema.SpacesById { + schemaCopy.SpacesById[id] = space.copy() + } + return schemaCopy +} + +// Space contains information about Tarantool's space. +type Space struct { + Id uint32 + Name string + // Could be "memtx" or "vinyl". + Engine string + Temporary bool // Is this space temporary? + // Field configuration is not mandatory and not checked by Tarantool. + FieldsCount uint32 + Fields map[string]Field + FieldsById map[uint32]Field + // Indexes is map from index names to indexes. + Indexes map[string]Index + // IndexesById is map from index numbers to indexes. + IndexesById map[uint32]Index +} + +func (space *Space) copy() Space { + spaceCopy := *space + spaceCopy.Fields = make(map[string]Field, len(space.Fields)) + for name, field := range space.Fields { + spaceCopy.Fields[name] = field + } + spaceCopy.FieldsById = make(map[uint32]Field, len(space.FieldsById)) + for id, field := range space.FieldsById { + spaceCopy.FieldsById[id] = field + } + spaceCopy.Indexes = make(map[string]Index, len(space.Indexes)) + for name, index := range space.Indexes { + spaceCopy.Indexes[name] = index.copy() + } + spaceCopy.IndexesById = make(map[uint32]Index, len(space.IndexesById)) + for id, index := range space.IndexesById { + spaceCopy.IndexesById[id] = index.copy() + } + return spaceCopy +} + +func (space *Space) DecodeMsgpack(d *msgpack.Decoder) error { + arrayLen, err := d.DecodeArrayLen() + if err != nil { + return err + } + if space.Id, err = d.DecodeUint32(); err != nil { + return err + } + if err := d.Skip(); err != nil { + return err + } + if space.Name, err = d.DecodeString(); err != nil { + return err + } + if space.Engine, err = d.DecodeString(); err != nil { + return err + } + if space.FieldsCount, err = d.DecodeUint32(); err != nil { + return err + } + if arrayLen >= vspaceSpTypeFieldNum { + code, err := d.PeekCode() + if err != nil { + return err + } + if msgpackIsString(code) { + val, err := d.DecodeString() + if err != nil { + return err + } + space.Temporary = val == "temporary" + } else if msgpackIsMap(code) { + mapLen, err := d.DecodeMapLen() + if err != nil { + return err + } + for i := 0; i < mapLen; i++ { + key, err := d.DecodeString() + if err != nil { + return err + } + if key == "temporary" { + if space.Temporary, err = d.DecodeBool(); err != nil { + return err + } + } else { + if err = d.Skip(); err != nil { + return err + } + } + } + } else { + return errors.New("unexpected schema format (space flags)") + } + } + space.FieldsById = make(map[uint32]Field) + space.Fields = make(map[string]Field) + space.IndexesById = make(map[uint32]Index) + space.Indexes = make(map[string]Index) + if arrayLen >= vspaceSpFormatFieldNum { + fieldCount, err := d.DecodeArrayLen() + if err != nil { + return err + } + for i := 0; i < fieldCount; i++ { + field := Field{} + if err := field.DecodeMsgpack(d); err != nil { + return err + } + field.Id = uint32(i) + space.FieldsById[field.Id] = field + if field.Name != "" { + space.Fields[field.Name] = field + } + } + } + return nil +} + +// Field is a schema field. +type Field struct { + Id uint32 + Name string + Type string + IsNullable bool +} + +func (field *Field) DecodeMsgpack(d *msgpack.Decoder) error { + l, err := d.DecodeMapLen() + if err != nil { + return err + } + for i := 0; i < l; i++ { + key, err := d.DecodeString() + if err != nil { + return err + } + switch key { + case "name": + if field.Name, err = d.DecodeString(); err != nil { + return err + } + case "type": + if field.Type, err = d.DecodeString(); err != nil { + return err + } + case "is_nullable": + if field.IsNullable, err = d.DecodeBool(); err != nil { + return err + } + default: + if err := d.Skip(); err != nil { + return err + } + } + } + return nil +} + +// Index contains information about index. +type Index struct { + Id uint32 + SpaceId uint32 + Name string + Type string + Unique bool + Fields []IndexField +} + +func (index *Index) copy() Index { + indexCopy := *index + indexCopy.Fields = make([]IndexField, len(index.Fields)) + copy(indexCopy.Fields, index.Fields) + return indexCopy +} + +func (index *Index) DecodeMsgpack(d *msgpack.Decoder) error { + _, err := d.DecodeArrayLen() + if err != nil { + return err + } + + if index.SpaceId, err = d.DecodeUint32(); err != nil { + return err + } + if index.Id, err = d.DecodeUint32(); err != nil { + return err + } + if index.Name, err = d.DecodeString(); err != nil { + return err + } + if index.Type, err = d.DecodeString(); err != nil { + return err + } + + var code byte + if code, err = d.PeekCode(); err != nil { + return err + } + + if msgpackIsUint(code) { + optsUint64, err := d.DecodeUint64() + if err != nil { + return nil + } + index.Unique = optsUint64 > 0 + } else { + var optsMap map[string]interface{} + if err := d.Decode(&optsMap); err != nil { + return fmt.Errorf("unexpected schema format (index flags): %w", err) + } + + var ok bool + if index.Unique, ok = optsMap["unique"].(bool); !ok { + /* see bug https://github.com/tarantool/tarantool/issues/2060 */ + index.Unique = true + } + } + + if code, err = d.PeekCode(); err != nil { + return err + } + + if msgpackIsUint(code) { + fieldCount, err := d.DecodeUint64() + if err != nil { + return err + } + index.Fields = make([]IndexField, fieldCount) + for i := 0; i < int(fieldCount); i++ { + index.Fields[i] = IndexField{} + if index.Fields[i].Id, err = d.DecodeUint32(); err != nil { + return err + } + if index.Fields[i].Type, err = d.DecodeString(); err != nil { + return err + } + } + } else { + if err := d.Decode(&index.Fields); err != nil { + return fmt.Errorf("unexpected schema format (index flags): %w", err) + } + } + + return nil +} + +// IndexFields is an index field. +type IndexField struct { + Id uint32 + Type string +} + +func (indexField *IndexField) DecodeMsgpack(d *msgpack.Decoder) error { + code, err := d.PeekCode() + if err != nil { + return err + } + + if msgpackIsMap(code) { + mapLen, err := d.DecodeMapLen() + if err != nil { + return err + } + for i := 0; i < mapLen; i++ { + key, err := d.DecodeString() + if err != nil { + return err + } + switch key { + case "field": + if indexField.Id, err = d.DecodeUint32(); err != nil { + return err + } + case "type": + if indexField.Type, err = d.DecodeString(); err != nil { + return err + } + default: + if err := d.Skip(); err != nil { + return err + } + } + } + return nil + } else if msgpackIsArray(code) { + arrayLen, err := d.DecodeArrayLen() + if err != nil { + return err + } + if indexField.Id, err = d.DecodeUint32(); err != nil { + return err + } + if indexField.Type, err = d.DecodeString(); err != nil { + return err + } + for i := 2; i < arrayLen; i++ { + if err := d.Skip(); err != nil { + return err + } + } + return nil + } + + return errors.New("unexpected schema format (index fields)") +} + +// GetSchema returns the actual schema for the Doer. +func GetSchema(doer Doer) (Schema, error) { + schema := Schema{} + schema.SpacesById = make(map[uint32]Space) + schema.Spaces = make(map[string]Space) + + // Reload spaces. + var spaces []Space + req := NewSelectRequest(vspaceSpId). + Index(0). + Limit(maxSchemas) + err := doer.Do(req).GetTyped(&spaces) + if err != nil { + return Schema{}, err + } + for _, space := range spaces { + schema.SpacesById[space.Id] = space + schema.Spaces[space.Name] = space + } + + // Reload indexes. + var indexes []Index + req = NewSelectRequest(vindexSpId). + Index(0). + Limit(maxSchemas) + err = doer.Do(req).GetTyped(&indexes) + if err != nil { + return Schema{}, err + } + for _, index := range indexes { + spaceId := index.SpaceId + if _, ok := schema.SpacesById[spaceId]; ok { + schema.SpacesById[spaceId].IndexesById[index.Id] = index + schema.SpacesById[spaceId].Indexes[index.Name] = index + } else { + return Schema{}, ErrConcurrentSchemaUpdate + } + } + + return schema, nil +} + +// resolveSpaceNumber tries to resolve a space number. +// Note: at this point, s can be a number, or an object of Space type. +func resolveSpaceNumber(s interface{}) (uint32, error) { + var spaceNo uint32 + + switch s := s.(type) { + case uint: + spaceNo = uint32(s) + case uint64: + spaceNo = uint32(s) + case uint32: + spaceNo = s + case uint16: + spaceNo = uint32(s) + case uint8: + spaceNo = uint32(s) + case int: + spaceNo = uint32(s) + case int64: + spaceNo = uint32(s) + case int32: + spaceNo = uint32(s) + case int16: + spaceNo = uint32(s) + case int8: + spaceNo = uint32(s) + case Space: + spaceNo = s.Id + case *Space: + spaceNo = s.Id + default: + panic("unexpected type of space param") + } + + return spaceNo, nil +} + +// resolveIndexNumber tries to resolve an index number. +// Note: at this point, i can be a number, or an object of Index type. +func resolveIndexNumber(i interface{}) (uint32, error) { + var indexNo uint32 + + switch i := i.(type) { + case uint: + indexNo = uint32(i) + case uint64: + indexNo = uint32(i) + case uint32: + indexNo = i + case uint16: + indexNo = uint32(i) + case uint8: + indexNo = uint32(i) + case int: + indexNo = uint32(i) + case int64: + indexNo = uint32(i) + case int32: + indexNo = uint32(i) + case int16: + indexNo = uint32(i) + case int8: + indexNo = uint32(i) + case Index: + indexNo = i.Id + case *Index: + indexNo = i.Id + default: + panic("unexpected type of index param") + } + + return indexNo, nil +} + +type loadedSchemaResolver struct { + Schema Schema + // SpaceAndIndexNamesSupported shows if a current Tarantool version supports + // iproto.IPROTO_FEATURE_SPACE_AND_INDEX_NAMES. + SpaceAndIndexNamesSupported bool +} + +func (r *loadedSchemaResolver) ResolveSpace(s interface{}) (uint32, error) { + if str, ok := s.(string); ok { + space, ok := r.Schema.Spaces[str] + if !ok { + return 0, fmt.Errorf("there is no space with name %s", s) + } + return space.Id, nil + } + return resolveSpaceNumber(s) +} + +func (r *loadedSchemaResolver) ResolveIndex(i interface{}, spaceNo uint32) (uint32, error) { + if i == nil { + return 0, nil + } + if str, ok := i.(string); ok { + space, ok := r.Schema.SpacesById[spaceNo] + if !ok { + return 0, fmt.Errorf("there is no space with id %d", spaceNo) + } + index, ok := space.Indexes[str] + if !ok { + err := fmt.Errorf("space %s has not index with name %s", space.Name, i) + return 0, err + } + return index.Id, nil + } + return resolveIndexNumber(i) +} + +func (r *loadedSchemaResolver) NamesUseSupported() bool { + return r.SpaceAndIndexNamesSupported +} + +type noSchemaResolver struct { + // SpaceAndIndexNamesSupported shows if a current Tarantool version supports + // iproto.IPROTO_FEATURE_SPACE_AND_INDEX_NAMES. + SpaceAndIndexNamesSupported bool +} + +func (*noSchemaResolver) ResolveSpace(s interface{}) (uint32, error) { + if _, ok := s.(string); ok { + return 0, fmt.Errorf("unable to use an index name " + + "because schema is not loaded") + } + return resolveSpaceNumber(s) +} + +func (*noSchemaResolver) ResolveIndex(i interface{}, spaceNo uint32) (uint32, error) { + if _, ok := i.(string); ok { + return 0, fmt.Errorf("unable to use an index name " + + "because schema is not loaded") + } + return resolveIndexNumber(i) +} + +func (r *noSchemaResolver) NamesUseSupported() bool { + return r.SpaceAndIndexNamesSupported +} diff --git a/schema_test.go b/schema_test.go new file mode 100644 index 000000000..e30d52690 --- /dev/null +++ b/schema_test.go @@ -0,0 +1,167 @@ +package tarantool_test + +import ( + "bytes" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +func TestGetSchema_ok(t *testing.T) { + space1 := tarantool.Space{ + Id: 1, + Name: "name1", + Indexes: make(map[string]tarantool.Index), + IndexesById: make(map[uint32]tarantool.Index), + Fields: make(map[string]tarantool.Field), + FieldsById: make(map[uint32]tarantool.Field), + } + index := tarantool.Index{ + Id: 1, + SpaceId: 2, + Name: "index_name", + Type: "index_type", + Unique: true, + Fields: make([]tarantool.IndexField, 0), + } + space2 := tarantool.Space{ + Id: 2, + Name: "name2", + Indexes: map[string]tarantool.Index{ + "index_name": index, + }, + IndexesById: map[uint32]tarantool.Index{ + 1: index, + }, + Fields: make(map[string]tarantool.Field), + FieldsById: make(map[uint32]tarantool.Field), + } + + mockDoer := test_helpers.NewMockDoer(t, + test_helpers.NewMockResponse(t, [][]interface{}{ + { + uint32(1), + "skip", + "name1", + "", + 0, + }, + { + uint32(2), + "skip", + "name2", + "", + 0, + }, + }), + test_helpers.NewMockResponse(t, [][]interface{}{ + { + uint32(2), + uint32(1), + "index_name", + "index_type", + uint8(1), + uint8(0), + }, + }), + ) + + expectedSchema := tarantool.Schema{ + SpacesById: map[uint32]tarantool.Space{ + 1: space1, + 2: space2, + }, + Spaces: map[string]tarantool.Space{ + "name1": space1, + "name2": space2, + }, + } + + schema, err := tarantool.GetSchema(&mockDoer) + require.NoError(t, err) + require.Equal(t, expectedSchema, schema) +} + +func TestGetSchema_spaces_select_error(t *testing.T) { + mockDoer := test_helpers.NewMockDoer(t, fmt.Errorf("some error")) + + schema, err := tarantool.GetSchema(&mockDoer) + require.EqualError(t, err, "some error") + require.Equal(t, tarantool.Schema{}, schema) +} + +func TestGetSchema_index_select_error(t *testing.T) { + mockDoer := test_helpers.NewMockDoer(t, + test_helpers.NewMockResponse(t, [][]interface{}{ + { + uint32(1), + "skip", + "name1", + "", + 0, + }, + }), + fmt.Errorf("some error")) + + schema, err := tarantool.GetSchema(&mockDoer) + require.EqualError(t, err, "some error") + require.Equal(t, tarantool.Schema{}, schema) +} + +func TestResolverCalledWithoutNameSupport(t *testing.T) { + resolver := ValidSchemeResolver{nameUseSupported: false} + + req := tarantool.NewSelectRequest("valid") + req.Index("valid") + + var reqBuf bytes.Buffer + reqEnc := msgpack.NewEncoder(&reqBuf) + + err := req.Body(&resolver, reqEnc) + if err != nil { + t.Errorf("An unexpected Response.Body() error: %q", err.Error()) + } + + if resolver.spaceResolverCalls != 1 { + t.Errorf("ResolveSpace was called %d times instead of 1.", + resolver.spaceResolverCalls) + } + if resolver.indexResolverCalls != 1 { + t.Errorf("ResolveIndex was called %d times instead of 1.", + resolver.indexResolverCalls) + } +} + +func TestResolverNotCalledWithNameSupport(t *testing.T) { + resolver := ValidSchemeResolver{nameUseSupported: true} + + req := tarantool.NewSelectRequest("valid") + req.Index("valid") + + var reqBuf bytes.Buffer + reqEnc := msgpack.NewEncoder(&reqBuf) + + err := req.Body(&resolver, reqEnc) + if err != nil { + t.Errorf("An unexpected Response.Body() error: %q", err.Error()) + } + + if resolver.spaceResolverCalls != 0 { + t.Errorf("ResolveSpace was called %d times instead of 0.", + resolver.spaceResolverCalls) + } + if resolver.indexResolverCalls != 0 { + t.Errorf("ResolveIndex was called %d times instead of 0.", + resolver.indexResolverCalls) + } +} + +func TestErrConcurrentSchemaUpdate(t *testing.T) { + assert.EqualError(t, tarantool.ErrConcurrentSchemaUpdate, "concurrent schema update") +} diff --git a/settings/const.go b/settings/const.go new file mode 100644 index 000000000..cc980cd7a --- /dev/null +++ b/settings/const.go @@ -0,0 +1,21 @@ +package settings + +const sessionSettingsSpace string = "_session_settings" + +// In Go and IPROTO_UPDATE count starts with 0. +const sessionSettingValueField int = 1 + +const ( + errorMarshalingEnabled string = "error_marshaling_enabled" + sqlDefaultEngine string = "sql_default_engine" + sqlDeferForeignKeys string = "sql_defer_foreign_keys" + sqlFullColumnNames string = "sql_full_column_names" + sqlFullMetadata string = "sql_full_metadata" + sqlParserDebug string = "sql_parser_debug" + sqlRecursiveTriggers string = "sql_recursive_triggers" + sqlReverseUnorderedSelects string = "sql_reverse_unordered_selects" + sqlSelectDebug string = "sql_select_debug" + sqlVDBEDebug string = "sql_vdbe_debug" +) + +const selectAllLimit uint32 = 1000 diff --git a/settings/example_test.go b/settings/example_test.go new file mode 100644 index 000000000..fdad495f3 --- /dev/null +++ b/settings/example_test.go @@ -0,0 +1,117 @@ +package settings_test + +import ( + "context" + "fmt" + "time" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/settings" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +var exampleDialer = tarantool.NetDialer{ + Address: "127.0.0.1", + User: "test", + Password: "test", +} + +var exampleOpts = tarantool.Opts{ + Timeout: 5 * time.Second, +} + +func example_connect(dialer tarantool.Dialer, opts tarantool.Opts) *tarantool.Connection { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, dialer, opts) + if err != nil { + panic("Connection is not established: " + err.Error()) + } + return conn +} + +func Example_sqlFullColumnNames() { + var resp tarantool.Response + var err error + var isLess bool + + conn := example_connect(exampleDialer, exampleOpts) + defer conn.Close() + + // Tarantool supports session settings since version 2.3.1 + isLess, err = test_helpers.IsTarantoolVersionLess(2, 3, 1) + if err != nil || isLess { + return + } + + // Create a space. + req := tarantool.NewExecuteRequest("CREATE TABLE example(id INT PRIMARY KEY, x INT);") + _, err = conn.Do(req).Get() + if err != nil { + fmt.Printf("error in create table: %v\n", err) + return + } + + // Insert some tuple into space. + req = tarantool.NewExecuteRequest("INSERT INTO example VALUES (1, 1);") + _, err = conn.Do(req).Get() + if err != nil { + fmt.Printf("error on insert: %v\n", err) + return + } + + // Enable showing full column names in SQL responses. + _, err = conn.Do(settings.NewSQLFullColumnNamesSetRequest(true)).Get() + if err != nil { + fmt.Printf("error on setting setup: %v\n", err) + return + } + + // Get some data with SQL query. + req = tarantool.NewExecuteRequest("SELECT x FROM example WHERE id = 1;") + resp, err = conn.Do(req).GetResponse() + if err != nil { + fmt.Printf("error on select: %v\n", err) + return + } + + exResp, ok := resp.(*tarantool.ExecuteResponse) + if !ok { + fmt.Printf("wrong response type") + return + } + + metaData, err := exResp.MetaData() + if err != nil { + fmt.Printf("error on getting MetaData: %v\n", err) + return + } + // Show response metadata. + fmt.Printf("full column name: %v\n", metaData[0].FieldName) + + // Disable showing full column names in SQL responses. + _, err = conn.Do(settings.NewSQLFullColumnNamesSetRequest(false)).Get() + if err != nil { + fmt.Printf("error on setting setup: %v\n", err) + return + } + + // Get some data with SQL query. + resp, err = conn.Do(req).GetResponse() + if err != nil { + fmt.Printf("error on select: %v\n", err) + return + } + exResp, ok = resp.(*tarantool.ExecuteResponse) + if !ok { + fmt.Printf("wrong response type") + return + } + metaData, err = exResp.MetaData() + if err != nil { + fmt.Printf("error on getting MetaData: %v\n", err) + return + } + // Show response metadata. + fmt.Printf("short column name: %v\n", metaData[0].FieldName) +} diff --git a/settings/request.go b/settings/request.go new file mode 100644 index 000000000..10c6cac25 --- /dev/null +++ b/settings/request.go @@ -0,0 +1,290 @@ +// Package settings is a collection of requests to set a connection session setting +// or get current session configuration. +// +// +============================+=========================+=========+===========================+ +// | Setting | Meaning | Default | Supported in | +// | | | | Tarantool versions | +// +============================+=========================+=========+===========================+ +// | ErrorMarshalingEnabled | Defines whether error | false | Since 2.4.1 till 2.10.0, | +// | | objectshave a special | | replaced with IPROTO_ID | +// | | structure. | | feature flag. | +// +----------------------------+-------------------------+---------+---------------------------+ +// | SQLDefaultEngine | Defines default storage | "memtx" | Since 2.3.1. | +// | | engine for new SQL | | | +// | | tables. | | | +// +----------------------------+-------------------------+---------+---------------------------+ +// | SQLDeferForeignKeys | Defines whether | false | Since 2.3.1 till master | +// | | foreign-key checks can | | commit 14618c4 (possible | +// | | wait till commit. | | 2.10.5 or 2.11.0) | +// +----------------------------+-------------------------+---------+---------------------------+ +// | SQLFullColumnNames | Defines whether full | false | Since 2.3.1. | +// | | column names is | | | +// | | displayed in SQL result | | | +// | | set metadata. | | | +// +----------------------------+-------------------------+---------+---------------------------+ +// | SQLFullMetadata | Defines whether SQL | false | Since 2.3.1. | +// | | result set metadata | | | +// | | will have more than | | | +// | | just name and type. | | | +// +----------------------------+-------------------------+---------+---------------------------+ +// | SQLParserDebug | Defines whether to show | false | Since 2.3.1 (only if | +// | | parser steps for | | built with | +// | | following statements. | | -DCMAKE_BUILD_TYPE=Debug) | +// +----------------------------+-------------------------+---------+---------------------------+ +// | SQLRecursiveTriggers | Defines whether a | true | Since 2.3.1. | +// | | triggered statement can | | | +// | | activate a trigger. | | | +// +----------------------------+-------------------------+---------+---------------------------+ +// | SQLReverseUnorderedSelects | Defines defines whether | false | Since 2.3.1. | +// | | result rows are usually | | | +// | | in reverse order if | | | +// | | there is no ORDER BY | | | +// | | clause. | | | +// +----------------------------+-------------------------+---------+---------------------------+ +// | SQLSelectDebug | Defines whether to show | false | Since 2.3.1 (only if | +// | | to show execution steps | | built with | +// | | during SELECT. | | -DCMAKE_BUILD_TYPE=Debug) | +// +----------------------------+-------------------------+---------+---------------------------+ +// | SQLVDBEDebug | Defines whether VDBE | false | Since 2.3.1 (only if | +// | | debug mode is enabled. | | built with | +// | | | | -DCMAKE_BUILD_TYPE=Debug) | +// +----------------------------+-------------------------+---------+---------------------------+ +// +// Since: 1.10.0 +// +// See also: +// +// - Session settings: +// https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_space/_session_settings/ +package settings + +import ( + "context" + "io" + + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// SetRequest helps to set session settings. +type SetRequest struct { + impl *tarantool.UpdateRequest +} + +func newSetRequest(setting string, value interface{}) *SetRequest { + return &SetRequest{ + impl: tarantool.NewUpdateRequest(sessionSettingsSpace). + Key(tarantool.StringKey{S: setting}). + Operations(tarantool.NewOperations().Assign(sessionSettingValueField, value)), + } +} + +// Context sets a passed context to set session settings request. +func (req *SetRequest) Context(ctx context.Context) *SetRequest { + req.impl = req.impl.Context(ctx) + + return req +} + +// Type returns IPROTO type for set session settings request. +func (req *SetRequest) Type() iproto.Type { + return req.impl.Type() +} + +// Body fills an encoder with set session settings request body. +func (req *SetRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + return req.impl.Body(res, enc) +} + +// Ctx returns a context of set session settings request. +func (req *SetRequest) Ctx() context.Context { + return req.impl.Ctx() +} + +// Async returns is set session settings request expects a response. +func (req *SetRequest) Async() bool { + return req.impl.Async() +} + +// Response creates a response for the SetRequest. +func (req *SetRequest) Response(header tarantool.Header, + body io.Reader) (tarantool.Response, error) { + return req.impl.Response(header, body) +} + +// GetRequest helps to get session settings. +type GetRequest struct { + impl *tarantool.SelectRequest +} + +func newGetRequest(setting string) *GetRequest { + return &GetRequest{ + impl: tarantool.NewSelectRequest(sessionSettingsSpace). + Key(tarantool.StringKey{S: setting}). + Limit(1), + } +} + +// Context sets a passed context to get session settings request. +func (req *GetRequest) Context(ctx context.Context) *GetRequest { + req.impl = req.impl.Context(ctx) + + return req +} + +// Type returns IPROTO type for get session settings request. +func (req *GetRequest) Type() iproto.Type { + return req.impl.Type() +} + +// Body fills an encoder with get session settings request body. +func (req *GetRequest) Body(res tarantool.SchemaResolver, enc *msgpack.Encoder) error { + return req.impl.Body(res, enc) +} + +// Ctx returns a context of get session settings request. +func (req *GetRequest) Ctx() context.Context { + return req.impl.Ctx() +} + +// Async returns is get session settings request expects a response. +func (req *GetRequest) Async() bool { + return req.impl.Async() +} + +// Response creates a response for the GetRequest. +func (req *GetRequest) Response(header tarantool.Header, + body io.Reader) (tarantool.Response, error) { + return req.impl.Response(header, body) +} + +// NewErrorMarshalingEnabledSetRequest creates a request to +// update current session ErrorMarshalingEnabled setting. +func NewErrorMarshalingEnabledSetRequest(value bool) *SetRequest { + return newSetRequest(errorMarshalingEnabled, value) +} + +// NewErrorMarshalingEnabledGetRequest creates a request to get +// current session ErrorMarshalingEnabled setting in tuple format. +func NewErrorMarshalingEnabledGetRequest() *GetRequest { + return newGetRequest(errorMarshalingEnabled) +} + +// NewSQLDefaultEngineSetRequest creates a request to +// update current session SQLDefaultEngine setting. +func NewSQLDefaultEngineSetRequest(value string) *SetRequest { + return newSetRequest(sqlDefaultEngine, value) +} + +// NewSQLDefaultEngineGetRequest creates a request to get +// current session SQLDefaultEngine setting in tuple format. +func NewSQLDefaultEngineGetRequest() *GetRequest { + return newGetRequest(sqlDefaultEngine) +} + +// NewSQLDeferForeignKeysSetRequest creates a request to +// update current session SQLDeferForeignKeys setting. +func NewSQLDeferForeignKeysSetRequest(value bool) *SetRequest { + return newSetRequest(sqlDeferForeignKeys, value) +} + +// NewSQLDeferForeignKeysGetRequest creates a request to get +// current session SQLDeferForeignKeys setting in tuple format. +func NewSQLDeferForeignKeysGetRequest() *GetRequest { + return newGetRequest(sqlDeferForeignKeys) +} + +// NewSQLFullColumnNamesSetRequest creates a request to +// update current session SQLFullColumnNames setting. +func NewSQLFullColumnNamesSetRequest(value bool) *SetRequest { + return newSetRequest(sqlFullColumnNames, value) +} + +// NewSQLFullColumnNamesGetRequest creates a request to get +// current session SQLFullColumnNames setting in tuple format. +func NewSQLFullColumnNamesGetRequest() *GetRequest { + return newGetRequest(sqlFullColumnNames) +} + +// NewSQLFullMetadataSetRequest creates a request to +// update current session SQLFullMetadata setting. +func NewSQLFullMetadataSetRequest(value bool) *SetRequest { + return newSetRequest(sqlFullMetadata, value) +} + +// NewSQLFullMetadataGetRequest creates a request to get +// current session SQLFullMetadata setting in tuple format. +func NewSQLFullMetadataGetRequest() *GetRequest { + return newGetRequest(sqlFullMetadata) +} + +// NewSQLParserDebugSetRequest creates a request to +// update current session SQLParserDebug setting. +func NewSQLParserDebugSetRequest(value bool) *SetRequest { + return newSetRequest(sqlParserDebug, value) +} + +// NewSQLParserDebugGetRequest creates a request to get +// current session SQLParserDebug setting in tuple format. +func NewSQLParserDebugGetRequest() *GetRequest { + return newGetRequest(sqlParserDebug) +} + +// NewSQLRecursiveTriggersSetRequest creates a request to +// update current session SQLRecursiveTriggers setting. +func NewSQLRecursiveTriggersSetRequest(value bool) *SetRequest { + return newSetRequest(sqlRecursiveTriggers, value) +} + +// NewSQLRecursiveTriggersGetRequest creates a request to get +// current session SQLRecursiveTriggers setting in tuple format. +func NewSQLRecursiveTriggersGetRequest() *GetRequest { + return newGetRequest(sqlRecursiveTriggers) +} + +// NewSQLReverseUnorderedSelectsSetRequest creates a request to +// update current session SQLReverseUnorderedSelects setting. +func NewSQLReverseUnorderedSelectsSetRequest(value bool) *SetRequest { + return newSetRequest(sqlReverseUnorderedSelects, value) +} + +// NewSQLReverseUnorderedSelectsGetRequest creates a request to get +// current session SQLReverseUnorderedSelects setting in tuple format. +func NewSQLReverseUnorderedSelectsGetRequest() *GetRequest { + return newGetRequest(sqlReverseUnorderedSelects) +} + +// NewSQLSelectDebugSetRequest creates a request to +// update current session SQLSelectDebug setting. +func NewSQLSelectDebugSetRequest(value bool) *SetRequest { + return newSetRequest(sqlSelectDebug, value) +} + +// NewSQLSelectDebugGetRequest creates a request to get +// current session SQLSelectDebug setting in tuple format. +func NewSQLSelectDebugGetRequest() *GetRequest { + return newGetRequest(sqlSelectDebug) +} + +// NewSQLVDBEDebugSetRequest creates a request to +// update current session SQLVDBEDebug setting. +func NewSQLVDBEDebugSetRequest(value bool) *SetRequest { + return newSetRequest(sqlVDBEDebug, value) +} + +// NewSQLVDBEDebugGetRequest creates a request to get +// current session SQLVDBEDebug setting in tuple format. +func NewSQLVDBEDebugGetRequest() *GetRequest { + return newGetRequest(sqlVDBEDebug) +} + +// NewSessionSettingsGetRequest creates a request to get all +// current session settings in tuple format. +func NewSessionSettingsGetRequest() *GetRequest { + return &GetRequest{ + impl: tarantool.NewSelectRequest(sessionSettingsSpace). + Limit(selectAllLimit), + } +} diff --git a/settings/request_test.go b/settings/request_test.go new file mode 100644 index 000000000..bb6c9e5c5 --- /dev/null +++ b/settings/request_test.go @@ -0,0 +1,117 @@ +package settings_test + +import ( + "bytes" + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" + . "github.com/tarantool/go-tarantool/v3/settings" +) + +type ValidSchemeResolver struct { +} + +func (*ValidSchemeResolver) ResolveSpace(s interface{}) (uint32, error) { + return 0, nil +} + +func (*ValidSchemeResolver) ResolveIndex(i interface{}, spaceNo uint32) (uint32, error) { + return 0, nil +} + +func (r *ValidSchemeResolver) NamesUseSupported() bool { + return false +} + +var resolver ValidSchemeResolver + +func TestRequestsAPI(t *testing.T) { + tests := []struct { + req tarantool.Request + async bool + rtype iproto.Type + }{ + {req: NewErrorMarshalingEnabledSetRequest(false), async: false, + rtype: iproto.IPROTO_UPDATE}, + {req: NewErrorMarshalingEnabledGetRequest(), async: false, rtype: iproto.IPROTO_SELECT}, + {req: NewSQLDefaultEngineSetRequest("memtx"), async: false, rtype: iproto.IPROTO_UPDATE}, + {req: NewSQLDefaultEngineGetRequest(), async: false, rtype: iproto.IPROTO_SELECT}, + {req: NewSQLDeferForeignKeysSetRequest(false), async: false, rtype: iproto.IPROTO_UPDATE}, + {req: NewSQLDeferForeignKeysGetRequest(), async: false, rtype: iproto.IPROTO_SELECT}, + {req: NewSQLFullColumnNamesSetRequest(false), async: false, rtype: iproto.IPROTO_UPDATE}, + {req: NewSQLFullColumnNamesGetRequest(), async: false, rtype: iproto.IPROTO_SELECT}, + {req: NewSQLFullMetadataSetRequest(false), async: false, rtype: iproto.IPROTO_UPDATE}, + {req: NewSQLFullMetadataGetRequest(), async: false, rtype: iproto.IPROTO_SELECT}, + {req: NewSQLParserDebugSetRequest(false), async: false, rtype: iproto.IPROTO_UPDATE}, + {req: NewSQLParserDebugGetRequest(), async: false, rtype: iproto.IPROTO_SELECT}, + {req: NewSQLRecursiveTriggersSetRequest(false), async: false, rtype: iproto.IPROTO_UPDATE}, + {req: NewSQLRecursiveTriggersGetRequest(), async: false, rtype: iproto.IPROTO_SELECT}, + {req: NewSQLReverseUnorderedSelectsSetRequest(false), async: false, + rtype: iproto.IPROTO_UPDATE}, + {req: NewSQLReverseUnorderedSelectsGetRequest(), async: false, + rtype: iproto.IPROTO_SELECT}, + {req: NewSQLSelectDebugSetRequest(false), async: false, rtype: iproto.IPROTO_UPDATE}, + {req: NewSQLSelectDebugGetRequest(), async: false, rtype: iproto.IPROTO_SELECT}, + {req: NewSQLVDBEDebugSetRequest(false), async: false, rtype: iproto.IPROTO_UPDATE}, + {req: NewSQLVDBEDebugGetRequest(), async: false, rtype: iproto.IPROTO_SELECT}, + {req: NewSessionSettingsGetRequest(), async: false, rtype: iproto.IPROTO_SELECT}, + } + + for _, test := range tests { + require.Equal(t, test.async, test.req.Async()) + require.Equal(t, test.rtype, test.req.Type()) + + var reqBuf bytes.Buffer + enc := msgpack.NewEncoder(&reqBuf) + require.Nilf(t, test.req.Body(&resolver, enc), "No errors on fill") + } +} + +func TestRequestsCtx(t *testing.T) { + // tarantool.Request interface doesn't have Context() + getTests := []struct { + req *GetRequest + }{ + {req: NewErrorMarshalingEnabledGetRequest()}, + {req: NewSQLDefaultEngineGetRequest()}, + {req: NewSQLDeferForeignKeysGetRequest()}, + {req: NewSQLFullColumnNamesGetRequest()}, + {req: NewSQLFullMetadataGetRequest()}, + {req: NewSQLParserDebugGetRequest()}, + {req: NewSQLRecursiveTriggersGetRequest()}, + {req: NewSQLReverseUnorderedSelectsGetRequest()}, + {req: NewSQLSelectDebugGetRequest()}, + {req: NewSQLVDBEDebugGetRequest()}, + {req: NewSessionSettingsGetRequest()}, + } + + for _, test := range getTests { + var ctx context.Context + require.Equal(t, ctx, test.req.Context(ctx).Ctx()) + } + + setTests := []struct { + req *SetRequest + }{ + {req: NewErrorMarshalingEnabledSetRequest(false)}, + {req: NewSQLDefaultEngineSetRequest("memtx")}, + {req: NewSQLDeferForeignKeysSetRequest(false)}, + {req: NewSQLFullColumnNamesSetRequest(false)}, + {req: NewSQLFullMetadataSetRequest(false)}, + {req: NewSQLParserDebugSetRequest(false)}, + {req: NewSQLRecursiveTriggersSetRequest(false)}, + {req: NewSQLReverseUnorderedSelectsSetRequest(false)}, + {req: NewSQLSelectDebugSetRequest(false)}, + {req: NewSQLVDBEDebugSetRequest(false)}, + } + + for _, test := range setTests { + var ctx context.Context + require.Equal(t, ctx, test.req.Context(ctx).Ctx()) + } +} diff --git a/settings/tarantool_test.go b/settings/tarantool_test.go new file mode 100644 index 000000000..891959397 --- /dev/null +++ b/settings/tarantool_test.go @@ -0,0 +1,716 @@ +package settings_test + +import ( + "log" + "os" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/tarantool/go-tarantool/v3" + . "github.com/tarantool/go-tarantool/v3/settings" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +// There is no way to skip tests in testing.M, +// so we use this variable to pass info +// to each testing.T that it should skip. +var isSettingsSupported = false + +var server = "127.0.0.1:3013" +var dialer = tarantool.NetDialer{ + Address: server, + User: "test", + Password: "test", +} +var opts = tarantool.Opts{ + Timeout: 5 * time.Second, +} + +func skipIfSettingsUnsupported(t *testing.T) { + t.Helper() + + if isSettingsSupported == false { + t.Skip("Skipping test for Tarantool without session settings support") + } +} + +func skipIfErrorMarshalingEnabledSettingUnsupported(t *testing.T) { + t.Helper() + + test_helpers.SkipIfFeatureUnsupported(t, "error_marshaling_enabled session setting", 2, 4, 1) + test_helpers.SkipIfFeatureDropped(t, "error_marshaling_enabled session setting", 2, 10, 0) +} + +func skipIfSQLDeferForeignKeysSettingUnsupported(t *testing.T) { + t.Helper() + + test_helpers.SkipIfFeatureUnsupported(t, "sql_defer_foreign_keys session setting", 2, 3, 1) + test_helpers.SkipIfFeatureDropped(t, "sql_defer_foreign_keys session setting", 2, 10, 5) +} + +func TestErrorMarshalingEnabledSetting(t *testing.T) { + skipIfErrorMarshalingEnabledSettingUnsupported(t) + + var err error + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // Disable receiving box.error as MP_EXT 3. + data, err := conn.Do(NewErrorMarshalingEnabledSetRequest(false)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"error_marshaling_enabled", false}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewErrorMarshalingEnabledGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"error_marshaling_enabled", false}}, data) + + // Get a box.Error value. + eval := tarantool.NewEvalRequest("return box.error.new(box.error.UNKNOWN)") + data, err = conn.Do(eval).Get() + require.Nil(t, err) + require.IsType(t, "string", data[0]) + + // Enable receiving box.error as MP_EXT 3. + data, err = conn.Do(NewErrorMarshalingEnabledSetRequest(true)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"error_marshaling_enabled", true}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewErrorMarshalingEnabledGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"error_marshaling_enabled", true}}, data) + + // Get a box.Error value. + data, err = conn.Do(eval).Get() + require.Nil(t, err) + _, ok := data[0].(*tarantool.BoxError) + require.True(t, ok) +} + +func TestSQLDefaultEngineSetting(t *testing.T) { + // https://github.com/tarantool/tarantool/blob/680990a082374e4790539215f69d9e9ee39c3307/test/sql/engine.test.lua + skipIfSettingsUnsupported(t) + + var err error + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // Set default SQL "CREATE TABLE" engine to "vinyl". + data, err := conn.Do(NewSQLDefaultEngineSetRequest("vinyl")).Get() + require.Nil(t, err) + require.EqualValues(t, []interface{}{[]interface{}{"sql_default_engine", "vinyl"}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLDefaultEngineGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_default_engine", "vinyl"}}, data) + + // Create a space with "CREATE TABLE". + exec := tarantool.NewExecuteRequest("CREATE TABLE T1_VINYL(a INT PRIMARY KEY, b INT, c INT);") + resp, err := conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok := resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + sqlInfo, err := exResp.SQLInfo() + require.Nil(t, err) + require.Equal(t, uint64(1), sqlInfo.AffectedCount) + + // Check new space engine. + eval := tarantool.NewEvalRequest("return box.space['T1_VINYL'].engine") + data, err = conn.Do(eval).Get() + require.Nil(t, err) + require.Equal(t, "vinyl", data[0]) + + // Set default SQL "CREATE TABLE" engine to "memtx". + data, err = conn.Do(NewSQLDefaultEngineSetRequest("memtx")).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_default_engine", "memtx"}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLDefaultEngineGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_default_engine", "memtx"}}, data) + + // Create a space with "CREATE TABLE". + exec = tarantool.NewExecuteRequest("CREATE TABLE T2_MEMTX(a INT PRIMARY KEY, b INT, c INT);") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok = resp.(*tarantool.ExecuteResponse) + sqlInfo, err = exResp.SQLInfo() + require.Nil(t, err) + require.True(t, ok, "wrong response type") + require.Equal(t, uint64(1), sqlInfo.AffectedCount) + + // Check new space engine. + eval = tarantool.NewEvalRequest("return box.space['T2_MEMTX'].engine") + data, err = conn.Do(eval).Get() + require.Nil(t, err) + require.Equal(t, "memtx", data[0]) +} + +func TestSQLDeferForeignKeysSetting(t *testing.T) { + // https://github.com/tarantool/tarantool/blob/eafadc13425f14446d7aaa49dea67dfc1d5f45e9/test/sql/transitive-transactions.result + skipIfSQLDeferForeignKeysSettingUnsupported(t) + + var resp tarantool.Response + var err error + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // Create a parent space. + exec := tarantool.NewExecuteRequest("CREATE TABLE parent(id INT PRIMARY KEY, y INT UNIQUE);") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok := resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + sqlInfo, err := exResp.SQLInfo() + require.Nil(t, err) + require.Equal(t, uint64(1), sqlInfo.AffectedCount) + + // Create a space with reference to the parent space. + exec = tarantool.NewExecuteRequest( + "CREATE TABLE child(id INT PRIMARY KEY, x INT REFERENCES parent(y));") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok = resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + sqlInfo, err = exResp.SQLInfo() + require.Nil(t, err) + require.Equal(t, uint64(1), sqlInfo.AffectedCount) + + deferEval := ` + box.begin() + local _, err = box.execute('INSERT INTO child VALUES (2, 2);') + if err ~= nil then + box.rollback() + error(err) + end + box.execute('INSERT INTO parent VALUES (2, 2);') + box.commit() + return true + ` + + // Disable foreign key constraint checks before commit. + data, err := conn.Do(NewSQLDeferForeignKeysSetRequest(false)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_defer_foreign_keys", false}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLDeferForeignKeysGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_defer_foreign_keys", false}}, data) + + // Evaluate a scenario when foreign key not exists + // on INSERT, but exists on commit. + _, err = conn.Do(tarantool.NewEvalRequest(deferEval)).Get() + require.NotNil(t, err) + require.ErrorContains(t, err, "Failed to execute SQL statement: FOREIGN KEY constraint failed") + + data, err = conn.Do(NewSQLDeferForeignKeysSetRequest(true)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_defer_foreign_keys", true}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLDeferForeignKeysGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_defer_foreign_keys", true}}, data) + + // Evaluate a scenario when foreign key not exists + // on INSERT, but exists on commit. + data, err = conn.Do(tarantool.NewEvalRequest(deferEval)).Get() + require.Nil(t, err) + require.Equal(t, true, data[0]) +} + +func TestSQLFullColumnNamesSetting(t *testing.T) { + skipIfSettingsUnsupported(t) + + var resp tarantool.Response + var err error + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // Create a space. + exec := tarantool.NewExecuteRequest("CREATE TABLE FKNAME(ID INT PRIMARY KEY, X INT);") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok := resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + sqlInfo, err := exResp.SQLInfo() + require.Nil(t, err) + require.Equal(t, uint64(1), sqlInfo.AffectedCount) + + // Fill it with some data. + exec = tarantool.NewExecuteRequest("INSERT INTO FKNAME VALUES (1, 1);") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok = resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + sqlInfo, err = exResp.SQLInfo() + require.Nil(t, err) + require.Equal(t, uint64(1), sqlInfo.AffectedCount) + + // Disable displaying full column names in metadata. + data, err := conn.Do(NewSQLFullColumnNamesSetRequest(false)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_full_column_names", false}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLFullColumnNamesGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_full_column_names", false}}, data) + + // Get a data with short column names in metadata. + exec = tarantool.NewExecuteRequest("SELECT X FROM FKNAME WHERE ID = 1;") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok = resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + metaData, err := exResp.MetaData() + require.Nil(t, err) + require.Equal(t, "X", metaData[0].FieldName) + + // Enable displaying full column names in metadata. + data, err = conn.Do(NewSQLFullColumnNamesSetRequest(true)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_full_column_names", true}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLFullColumnNamesGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_full_column_names", true}}, data) + + // Get a data with full column names in metadata. + exec = tarantool.NewExecuteRequest("SELECT X FROM FKNAME WHERE ID = 1;") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok = resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + metaData, err = exResp.MetaData() + require.Nil(t, err) + require.Equal(t, "FKNAME.X", metaData[0].FieldName) +} + +func TestSQLFullMetadataSetting(t *testing.T) { + skipIfSettingsUnsupported(t) + + var resp tarantool.Response + var err error + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // Create a space. + exec := tarantool.NewExecuteRequest("CREATE TABLE fmt(id INT PRIMARY KEY, x INT);") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok := resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + sqlInfo, err := exResp.SQLInfo() + require.Nil(t, err) + require.Equal(t, uint64(1), sqlInfo.AffectedCount) + + // Fill it with some data. + exec = tarantool.NewExecuteRequest("INSERT INTO fmt VALUES (1, 1);") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok = resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + sqlInfo, err = exResp.SQLInfo() + require.Nil(t, err) + require.Equal(t, uint64(1), sqlInfo.AffectedCount) + + // Disable displaying additional fields in metadata. + data, err := conn.Do(NewSQLFullMetadataSetRequest(false)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_full_metadata", false}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLFullMetadataGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_full_metadata", false}}, data) + + // Get a data without additional fields in metadata. + exec = tarantool.NewExecuteRequest("SELECT x FROM fmt WHERE id = 1;") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok = resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + metaData, err := exResp.MetaData() + require.Nil(t, err) + require.Equal(t, "", metaData[0].FieldSpan) + + // Enable displaying full column names in metadata. + data, err = conn.Do(NewSQLFullMetadataSetRequest(true)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_full_metadata", true}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLFullMetadataGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_full_metadata", true}}, data) + + // Get a data with additional fields in metadata. + exec = tarantool.NewExecuteRequest("SELECT x FROM fmt WHERE id = 1;") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok = resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + metaData, err = exResp.MetaData() + require.Nil(t, err) + require.Equal(t, "x", metaData[0].FieldSpan) +} + +func TestSQLParserDebugSetting(t *testing.T) { + skipIfSettingsUnsupported(t) + + var err error + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // Disable parser debug mode. + data, err := conn.Do(NewSQLParserDebugSetRequest(false)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_parser_debug", false}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLParserDebugGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_parser_debug", false}}, data) + + // Enable parser debug mode. + data, err = conn.Do(NewSQLParserDebugSetRequest(true)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_parser_debug", true}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLParserDebugGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_parser_debug", true}}, data) + + // To test real effect we need a Tarantool instance built with + // `-DCMAKE_BUILD_TYPE=Debug`. +} + +func TestSQLRecursiveTriggersSetting(t *testing.T) { + // https://github.com/tarantool/tarantool/blob/d11fb3061e15faf4e0eb5375fb8056b4e64348ae/test/sql-tap/triggerC.test.lua + skipIfSettingsUnsupported(t) + + var resp tarantool.Response + var err error + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // Create a space. + exec := tarantool.NewExecuteRequest("CREATE TABLE rec(id INTEGER PRIMARY KEY, a INT, b INT);") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok := resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + sqlInfo, err := exResp.SQLInfo() + require.Nil(t, err) + require.Equal(t, uint64(1), sqlInfo.AffectedCount) + + // Fill it with some data. + exec = tarantool.NewExecuteRequest("INSERT INTO rec VALUES(1, 1, 2);") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok = resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + sqlInfo, err = exResp.SQLInfo() + require.Nil(t, err) + require.Equal(t, uint64(1), sqlInfo.AffectedCount) + + // Create a recursive trigger (with infinite depth). + exec = tarantool.NewExecuteRequest(` + CREATE TRIGGER tr12 AFTER UPDATE ON rec FOR EACH ROW BEGIN + UPDATE rec SET a=new.a+1, b=new.b+1; + END;`) + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok = resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + sqlInfo, err = exResp.SQLInfo() + require.Nil(t, err) + require.Equal(t, uint64(1), sqlInfo.AffectedCount) + + // Enable SQL recursive triggers. + data, err := conn.Do(NewSQLRecursiveTriggersSetRequest(true)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_recursive_triggers", true}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLRecursiveTriggersGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_recursive_triggers", true}}, data) + + // Trigger the recursion. + exec = tarantool.NewExecuteRequest("UPDATE rec SET a=a+1, b=b+1;") + _, err = conn.Do(exec).Get() + require.NotNil(t, err) + require.ErrorContains(t, err, + "Failed to execute SQL statement: too many levels of trigger recursion") + + // Disable SQL recursive triggers. + data, err = conn.Do(NewSQLRecursiveTriggersSetRequest(false)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_recursive_triggers", false}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLRecursiveTriggersGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_recursive_triggers", false}}, data) + + // Trigger the recursion. + exec = tarantool.NewExecuteRequest("UPDATE rec SET a=a+1, b=b+1;") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok = resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + sqlInfo, err = exResp.SQLInfo() + require.Nil(t, err) + require.Equal(t, uint64(1), sqlInfo.AffectedCount) +} + +func TestSQLReverseUnorderedSelectsSetting(t *testing.T) { + skipIfSettingsUnsupported(t) + + var resp tarantool.Response + var err error + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // Create a space. + exec := tarantool.NewExecuteRequest("CREATE TABLE data(id STRING PRIMARY KEY);") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok := resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + sqlInfo, err := exResp.SQLInfo() + require.Nil(t, err) + require.Equal(t, uint64(1), sqlInfo.AffectedCount) + + // Fill it with some data. + exec = tarantool.NewExecuteRequest("INSERT INTO data VALUES('1');") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok = resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + sqlInfo, err = exResp.SQLInfo() + require.Nil(t, err) + require.Equal(t, uint64(1), sqlInfo.AffectedCount) + + exec = tarantool.NewExecuteRequest("INSERT INTO data VALUES('2');") + resp, err = conn.Do(exec).GetResponse() + require.Nil(t, err) + require.NotNil(t, resp) + exResp, ok = resp.(*tarantool.ExecuteResponse) + require.True(t, ok, "wrong response type") + sqlInfo, err = exResp.SQLInfo() + require.Nil(t, err) + require.Equal(t, uint64(1), sqlInfo.AffectedCount) + + // Disable reverse order in unordered selects. + data, err := conn.Do(NewSQLReverseUnorderedSelectsSetRequest(false)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_reverse_unordered_selects", false}}, + data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLReverseUnorderedSelectsGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_reverse_unordered_selects", false}}, + data) + + // Select multiple records. + query := "SELECT * FROM seqscan data;" + if isSeqScanOld, err := test_helpers.IsTarantoolVersionLess(3, 0, 0); err != nil { + t.Fatalf("Could not check the Tarantool version: %s", err) + } else if isSeqScanOld { + query = "SELECT * FROM data;" + } + + data, err = conn.Do(tarantool.NewExecuteRequest(query)).Get() + require.Nil(t, err) + require.EqualValues(t, []interface{}{"1"}, data[0]) + require.EqualValues(t, []interface{}{"2"}, data[1]) + + // Enable reverse order in unordered selects. + data, err = conn.Do(NewSQLReverseUnorderedSelectsSetRequest(true)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_reverse_unordered_selects", true}}, + data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLReverseUnorderedSelectsGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_reverse_unordered_selects", true}}, + data) + + // Select multiple records. + data, err = conn.Do(tarantool.NewExecuteRequest(query)).Get() + require.Nil(t, err) + require.EqualValues(t, []interface{}{"2"}, data[0]) + require.EqualValues(t, []interface{}{"1"}, data[1]) +} + +func TestSQLSelectDebugSetting(t *testing.T) { + skipIfSettingsUnsupported(t) + + var err error + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // Disable select debug mode. + data, err := conn.Do(NewSQLSelectDebugSetRequest(false)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_select_debug", false}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLSelectDebugGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_select_debug", false}}, data) + + // Enable select debug mode. + data, err = conn.Do(NewSQLSelectDebugSetRequest(true)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_select_debug", true}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLSelectDebugGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_select_debug", true}}, data) + + // To test real effect we need a Tarantool instance built with + // `-DCMAKE_BUILD_TYPE=Debug`. +} + +func TestSQLVDBEDebugSetting(t *testing.T) { + skipIfSettingsUnsupported(t) + + var err error + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // Disable VDBE debug mode. + data, err := conn.Do(NewSQLVDBEDebugSetRequest(false)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_vdbe_debug", false}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLVDBEDebugGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_vdbe_debug", false}}, data) + + // Enable VDBE debug mode. + data, err = conn.Do(NewSQLVDBEDebugSetRequest(true)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_vdbe_debug", true}}, data) + + // Fetch current setting value. + data, err = conn.Do(NewSQLVDBEDebugGetRequest()).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_vdbe_debug", true}}, data) + + // To test real effect we need a Tarantool instance built with + // `-DCMAKE_BUILD_TYPE=Debug`. +} + +func TestSessionSettings(t *testing.T) { + skipIfSettingsUnsupported(t) + + var err error + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // Set some settings values. + data, err := conn.Do(NewSQLDefaultEngineSetRequest("memtx")).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_default_engine", "memtx"}}, data) + + data, err = conn.Do(NewSQLFullColumnNamesSetRequest(true)).Get() + require.Nil(t, err) + require.Equal(t, []interface{}{[]interface{}{"sql_full_column_names", true}}, data) + + // Fetch current settings values. + data, err = conn.Do(NewSessionSettingsGetRequest()).Get() + require.Nil(t, err) + require.Subset(t, data, + []interface{}{ + []interface{}{"sql_default_engine", "memtx"}, + []interface{}{"sql_full_column_names", true}, + }) +} + +// runTestMain is a body of TestMain function +// (see https://pkg.go.dev/testing#hdr-Main). +// Using defer + os.Exit is not works so TestMain body +// is a separate function, see +// https://stackoverflow.com/questions/27629380/how-to-exit-a-go-program-honoring-deferred-calls +func runTestMain(m *testing.M) int { + isLess, err := test_helpers.IsTarantoolVersionLess(2, 3, 1) + if err != nil { + log.Fatalf("Failed to extract tarantool version: %s", err) + } + + if isLess { + log.Println("Skipping session settings tests...") + isSettingsSupported = false + return m.Run() + } + + isSettingsSupported = true + + inst, err := test_helpers.StartTarantool(test_helpers.StartOpts{ + Dialer: dialer, + InitScript: "testdata/config.lua", + Listen: server, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 10, + RetryTimeout: 500 * time.Millisecond, + }) + + if err != nil { + log.Fatalf("Failed to prepare test tarantool: %s", err) + } + + defer test_helpers.StopTarantoolWithCleanup(inst) + + return m.Run() +} + +func TestMain(m *testing.M) { + code := runTestMain(m) + os.Exit(code) +} diff --git a/settings/testdata/config.lua b/settings/testdata/config.lua new file mode 100644 index 000000000..7f6af1db2 --- /dev/null +++ b/settings/testdata/config.lua @@ -0,0 +1,15 @@ +-- Do not set listen for now so connector won't be +-- able to send requests until everything is configured. +box.cfg{ + work_dir = os.getenv("TEST_TNT_WORK_DIR"), +} + +box.schema.user.create('test', { password = 'test' , if_not_exists = true }) +box.schema.user.grant('test', 'execute', 'universe', nil, { if_not_exists = true }) +box.schema.user.grant('test', 'create,read,write,drop,alter', 'space', nil, { if_not_exists = true }) +box.schema.user.grant('test', 'create', 'sequence', nil, { if_not_exists = true }) + +-- Set listen only when every other thing is configured. +box.cfg{ + listen = os.getenv("TEST_TNT_LISTEN"), +} diff --git a/shutdown_test.go b/shutdown_test.go new file mode 100644 index 000000000..434600824 --- /dev/null +++ b/shutdown_test.go @@ -0,0 +1,574 @@ +//go:build linux || (darwin && !cgo) +// +build linux darwin,!cgo + +// Use OS build flags since signals are system-dependent. +package tarantool_test + +import ( + "fmt" + "sync" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + . "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +var shtdnServer = "127.0.0.1:3014" +var shtdnDialer = NetDialer{ + Address: shtdnServer, + User: dialer.User, + Password: dialer.Password, +} + +var shtdnClntOpts = Opts{ + Timeout: 20 * time.Second, + Reconnect: 500 * time.Millisecond, + MaxReconnects: 10, +} +var shtdnSrvOpts = test_helpers.StartOpts{ + Dialer: shtdnDialer, + InitScript: "config.lua", + Listen: shtdnServer, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 10, + RetryTimeout: 500 * time.Millisecond, +} + +var evalMsg = "got enough sleep" +var evalBody = ` + local fiber = require('fiber') + local time, msg = ... + fiber.sleep(time) + return msg +` + +func testGracefulShutdown(t *testing.T, conn *Connection, inst *test_helpers.TarantoolInstance) { + var err error + + // Set a big timeout so it would be easy to differ + // if server went down on timeout or after all connections were terminated. + serverShutdownTimeout := 60 // in seconds + _, err = conn.Call("box.ctl.set_on_shutdown_timeout", []interface{}{serverShutdownTimeout}) + require.Nil(t, err) + + // Send request with sleep. + evalSleep := 1 // in seconds + require.Lessf(t, + time.Duration(evalSleep)*time.Second, + shtdnClntOpts.Timeout, + "test request won't be failed by timeout") + + // Create a helper watcher to ensure that async + // shutdown is set up. + helperCh := make(chan WatchEvent, 10) + helperW, herr := conn.NewWatcher("box.shutdown", func(event WatchEvent) { + helperCh <- event + }) + require.Nil(t, herr) + defer helperW.Unregister() + <-helperCh + + req := NewEvalRequest(evalBody).Args([]interface{}{evalSleep, evalMsg}) + + fut := conn.Do(req) + + // SIGTERM the server. + shutdownStart := time.Now() + require.Nil(t, inst.Signal(syscall.SIGTERM)) + + // Check that we can't send new requests after shutdown starts. + // Retry helps to wait a bit until server starts to shutdown + // and send us the shutdown event. + shutdownWaitRetries := 5 + shutdownWaitTimeout := 100 * time.Millisecond + + err = test_helpers.Retry(func(interface{}) error { + _, err = conn.Do(NewPingRequest()).Get() + if err == nil { + return fmt.Errorf("expected error for requests sent on shutdown") + } + + if err.Error() != "server shutdown in progress (0x4005)" { + return err + } + + return nil + }, nil, shutdownWaitRetries, shutdownWaitTimeout) + require.Nil(t, err) + + // Check that requests started before the shutdown finish successfully. + data, err := fut.Get() + require.Nil(t, err) + require.Equal(t, data, []interface{}{evalMsg}) + + // Wait until server go down. + // Server will go down only when it process all requests from our connection + // (or on timeout). + err = inst.Wait() + require.Nil(t, err) + shutdownFinish := time.Now() + shutdownTime := shutdownFinish.Sub(shutdownStart) + + // Check that it wasn't a timeout. + require.Lessf(t, + shutdownTime, + time.Duration(serverShutdownTimeout/2)*time.Second, + "server went down not by timeout") + + // Connection is unavailable when server is down. + require.Equal(t, false, conn.ConnectedNow()) +} + +func TestGracefulShutdown(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + var conn *Connection + + inst, err := test_helpers.StartTarantool(shtdnSrvOpts) + require.Nil(t, err) + defer test_helpers.StopTarantoolWithCleanup(inst) + + conn = test_helpers.ConnectWithValidation(t, shtdnDialer, shtdnClntOpts) + defer conn.Close() + + testGracefulShutdown(t, conn, inst) +} + +func TestCloseGraceful(t *testing.T) { + opts := Opts{ + Timeout: shtdnClntOpts.Timeout, + } + testDialer := shtdnDialer + testDialer.RequiredProtocolInfo = ProtocolInfo{} + testSrvOpts := shtdnSrvOpts + testSrvOpts.Dialer = testDialer + + inst, err := test_helpers.StartTarantool(testSrvOpts) + require.Nil(t, err) + defer test_helpers.StopTarantoolWithCleanup(inst) + + conn := test_helpers.ConnectWithValidation(t, testDialer, opts) + defer conn.Close() + + // Send request with sleep. + evalSleep := 3 // In seconds. + require.Lessf(t, + time.Duration(evalSleep)*time.Second, + shtdnClntOpts.Timeout, + "test request won't be failed by timeout") + + req := NewEvalRequest(evalBody).Args([]interface{}{evalSleep, evalMsg}) + fut := conn.Do(req) + + go func() { + // CloseGraceful closes the connection gracefully. + conn.CloseGraceful() + // Connection is closed. + assert.Equal(t, true, conn.ClosedNow()) + }() + + // Check that a request rejected if graceful shutdown in progress. + time.Sleep((time.Duration(evalSleep) * time.Second) / 2) + _, err = conn.Do(NewPingRequest()).Get() + assert.ErrorContains(t, err, "server shutdown in progress") + + // Check that a previous request was successful. + resp, err := fut.Get() + assert.Nilf(t, err, "sleep request no error") + assert.NotNilf(t, resp, "sleep response exists") +} + +func TestGracefulShutdownWithReconnect(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + inst, err := test_helpers.StartTarantool(shtdnSrvOpts) + require.Nil(t, err) + defer test_helpers.StopTarantoolWithCleanup(inst) + + conn := test_helpers.ConnectWithValidation(t, shtdnDialer, shtdnClntOpts) + defer conn.Close() + + testGracefulShutdown(t, conn, inst) + + err = test_helpers.RestartTarantool(inst) + require.Nilf(t, err, "Failed to restart tarantool") + + connected := test_helpers.WaitUntilReconnected(conn, shtdnClntOpts.MaxReconnects, + shtdnClntOpts.Reconnect) + require.Truef(t, connected, "Reconnect success") + + testGracefulShutdown(t, conn, inst) +} + +func TestNoGracefulShutdown(t *testing.T) { + // No watchers = no graceful shutdown. + noSthdClntOpts := opts + noShtdDialer := shtdnDialer + noShtdDialer.RequiredProtocolInfo = ProtocolInfo{} + test_helpers.SkipIfWatchersSupported(t) + + var conn *Connection + + testSrvOpts := shtdnSrvOpts + testSrvOpts.Dialer = noShtdDialer + + inst, err := test_helpers.StartTarantool(testSrvOpts) + require.Nil(t, err) + defer test_helpers.StopTarantoolWithCleanup(inst) + + conn = test_helpers.ConnectWithValidation(t, noShtdDialer, noSthdClntOpts) + defer conn.Close() + + evalSleep := 10 // in seconds + serverShutdownTimeout := 60 // in seconds + require.Less(t, evalSleep, serverShutdownTimeout) + + // Send request with sleep. + require.Lessf(t, + time.Duration(evalSleep)*time.Second, + shtdnClntOpts.Timeout, + "test request won't be failed by timeout") + + req := NewEvalRequest(evalBody).Args([]interface{}{evalSleep, evalMsg}) + + fut := conn.Do(req) + + // SIGTERM the server. + shutdownStart := time.Now() + require.Nil(t, inst.Signal(syscall.SIGTERM)) + + // Check that request was interrupted. + _, err = fut.Get() + require.NotNilf(t, err, "sleep request error") + + // Wait until server go down. + err = inst.Wait() + require.Nil(t, err) + shutdownFinish := time.Now() + shutdownTime := shutdownFinish.Sub(shutdownStart) + + // Check that server finished without waiting for eval to finish. + require.Lessf(t, + shutdownTime, + time.Duration(evalSleep/2)*time.Second, + "server went down without any additional waiting") +} + +func TestGracefulShutdownRespectsClose(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + var conn *Connection + + inst, err := test_helpers.StartTarantool(shtdnSrvOpts) + require.Nil(t, err) + defer test_helpers.StopTarantoolWithCleanup(inst) + + conn = test_helpers.ConnectWithValidation(t, shtdnDialer, shtdnClntOpts) + defer conn.Close() + + // Create a helper watcher to ensure that async + // shutdown is set up. + helperCh := make(chan WatchEvent, 10) + helperW, herr := conn.NewWatcher("box.shutdown", func(event WatchEvent) { + helperCh <- event + }) + require.Nil(t, herr) + defer helperW.Unregister() + <-helperCh + + // Set a big timeout so it would be easy to differ + // if server went down on timeout or after all connections were terminated. + serverShutdownTimeout := 60 // in seconds + _, err = conn.Call("box.ctl.set_on_shutdown_timeout", []interface{}{serverShutdownTimeout}) + require.Nil(t, err) + + // Send request with sleep. + evalSleep := 10 // in seconds + require.Lessf(t, + time.Duration(evalSleep)*time.Second, + shtdnClntOpts.Timeout, + "test request won't be failed by timeout") + + req := NewEvalRequest(evalBody).Args([]interface{}{evalSleep, evalMsg}) + + fut := conn.Do(req) + + // SIGTERM the server. + shutdownStart := time.Now() + require.Nil(t, inst.Signal(syscall.SIGTERM)) + + // Close the connection. + conn.Close() + + // Connection is closed. + require.Equal(t, true, conn.ClosedNow()) + + // Check that request was interrupted. + _, err = fut.Get() + require.NotNilf(t, err, "sleep request error") + + // Wait until server go down. + err = inst.Wait() + require.Nil(t, err) + shutdownFinish := time.Now() + shutdownTime := shutdownFinish.Sub(shutdownStart) + + // Check that server finished without waiting for eval to finish. + require.Lessf(t, + shutdownTime, + time.Duration(evalSleep/2)*time.Second, + "server went down without any additional waiting") + + // Check that it wasn't a timeout. + require.Lessf(t, + shutdownTime, + time.Duration(serverShutdownTimeout/2)*time.Second, + "server went down not by timeout") + + // Connection is still closed. + require.Equal(t, true, conn.ClosedNow()) +} + +func TestGracefulShutdownNotRacesWithRequestReconnect(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + var conn *Connection + + inst, err := test_helpers.StartTarantool(shtdnSrvOpts) + require.Nil(t, err) + defer test_helpers.StopTarantoolWithCleanup(inst) + + conn = test_helpers.ConnectWithValidation(t, shtdnDialer, shtdnClntOpts) + defer conn.Close() + + // Create a helper watcher to ensure that async + // shutdown is set up. + helperCh := make(chan WatchEvent, 10) + helperW, herr := conn.NewWatcher("box.shutdown", func(event WatchEvent) { + helperCh <- event + }) + require.Nil(t, herr) + defer helperW.Unregister() + <-helperCh + + // Set a small timeout so server will shutdown before requesst finishes. + serverShutdownTimeout := 1 // in seconds + _, err = conn.Call("box.ctl.set_on_shutdown_timeout", []interface{}{serverShutdownTimeout}) + require.Nil(t, err) + + // Send request with sleep. + evalSleep := 5 // in seconds + require.Lessf(t, + serverShutdownTimeout, + evalSleep, + "test request will be failed by timeout") + require.Lessf(t, + time.Duration(serverShutdownTimeout)*time.Second, + shtdnClntOpts.Timeout, + "test request will be failed by timeout") + + req := NewEvalRequest(evalBody).Args([]interface{}{evalSleep, evalMsg}) + + evalStart := time.Now() + fut := conn.Do(req) + + // SIGTERM the server. + require.Nil(t, inst.Signal(syscall.SIGTERM)) + + // Wait until server go down. + // Server is expected to go down on timeout. + err = inst.Wait() + require.Nil(t, err) + + // Check that request failed by server disconnect, not a client timeout. + _, err = fut.Get() + require.NotNil(t, err) + require.NotContains(t, err.Error(), "client timeout for request") + + evalFinish := time.Now() + evalTime := evalFinish.Sub(evalStart) + + // Check that it wasn't a client timeout. + require.Lessf(t, + evalTime, + shtdnClntOpts.Timeout, + "server went down not by timeout") +} + +func TestGracefulShutdownCloseConcurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + var srvShtdnStart, srvShtdnFinish time.Time + + inst, err := test_helpers.StartTarantool(shtdnSrvOpts) + require.Nil(t, err) + defer test_helpers.StopTarantoolWithCleanup(inst) + + conn := test_helpers.ConnectWithValidation(t, shtdnDialer, shtdnClntOpts) + defer conn.Close() + + // Create a helper watcher to ensure that async + // shutdown is set up. + helperCh := make(chan WatchEvent, 10) + helperW, herr := conn.NewWatcher("box.shutdown", func(event WatchEvent) { + helperCh <- event + }) + require.Nil(t, herr) + defer helperW.Unregister() + <-helperCh + + // Set a big timeout so it would be easy to differ + // if server went down on timeout or after all connections were terminated. + serverShutdownTimeout := 60 // in seconds + _, err = conn.Call("box.ctl.set_on_shutdown_timeout", []interface{}{serverShutdownTimeout}) + require.Nil(t, err) + conn.Close() + + const testConcurrency = 50 + + var caseWg, srvToStop, srvStop sync.WaitGroup + caseWg.Add(testConcurrency) + srvToStop.Add(testConcurrency) + srvStop.Add(1) + + // Create many connections. + for i := 0; i < testConcurrency; i++ { + go func(i int) { + defer caseWg.Done() + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + // Do not wait till Tarantool register out watcher, + // test everything is ok even on async. + conn, err := Connect(ctx, shtdnDialer, shtdnClntOpts) + if err != nil { + t.Errorf("Failed to connect: %s", err) + } else { + defer conn.Close() + } + + // Wait till all connections created. + srvToStop.Done() + srvStop.Wait() + }(i) + } + + var sret error + go func(inst *test_helpers.TarantoolInstance) { + srvToStop.Wait() + srvShtdnStart = time.Now() + cerr := inst.Signal(syscall.SIGTERM) + if cerr != nil { + sret = cerr + } + srvStop.Done() + }(inst) + + srvStop.Wait() + require.Nil(t, sret, "No errors on server SIGTERM") + + err = inst.Wait() + require.Nil(t, err) + + srvShtdnFinish = time.Now() + srvShtdnTime := srvShtdnFinish.Sub(srvShtdnStart) + + require.Less(t, + srvShtdnTime, + time.Duration(serverShutdownTimeout/2)*time.Second, + "server went down not by timeout") +} + +func TestGracefulShutdownConcurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + var srvShtdnStart, srvShtdnFinish time.Time + + inst, err := test_helpers.StartTarantool(shtdnSrvOpts) + require.Nil(t, err) + defer test_helpers.StopTarantoolWithCleanup(inst) + + conn := test_helpers.ConnectWithValidation(t, shtdnDialer, shtdnClntOpts) + defer conn.Close() + + // Set a big timeout so it would be easy to differ + // if server went down on timeout or after all connections were terminated. + serverShutdownTimeout := 60 // in seconds + _, err = conn.Call("box.ctl.set_on_shutdown_timeout", []interface{}{serverShutdownTimeout}) + require.Nil(t, err) + conn.Close() + + const testConcurrency = 50 + + var caseWg, srvToStop, srvStop sync.WaitGroup + caseWg.Add(testConcurrency) + srvToStop.Add(testConcurrency) + srvStop.Add(1) + + // Create many connections. + var ret error + for i := 0; i < testConcurrency; i++ { + go func(i int) { + defer caseWg.Done() + + conn := test_helpers.ConnectWithValidation(t, shtdnDialer, shtdnClntOpts) + defer conn.Close() + + // Create a helper watcher to ensure that async + // shutdown is set up. + helperCh := make(chan WatchEvent, 10) + helperW, _ := conn.NewWatcher("box.shutdown", func(event WatchEvent) { + helperCh <- event + }) + defer helperW.Unregister() + <-helperCh + + evalSleep := 1 // in seconds + req := NewEvalRequest(evalBody).Args([]interface{}{evalSleep, evalMsg}) + fut := conn.Do(req) + + // Wait till all connections had started sleeping. + srvToStop.Done() + srvStop.Wait() + + _, gerr := fut.Get() + if gerr != nil { + ret = gerr + } + }(i) + } + + var sret error + go func(inst *test_helpers.TarantoolInstance) { + srvToStop.Wait() + srvShtdnStart = time.Now() + cerr := inst.Signal(syscall.SIGTERM) + if cerr != nil { + sret = cerr + } + srvStop.Done() + }(inst) + + srvStop.Wait() + require.Nil(t, sret, "No errors on server SIGTERM") + + caseWg.Wait() + require.Nil(t, ret, "No errors on concurrent wait") + + err = inst.Wait() + require.Nil(t, err) + + srvShtdnFinish = time.Now() + srvShtdnTime := srvShtdnFinish.Sub(srvShtdnStart) + + require.Less(t, + srvShtdnTime, + time.Duration(serverShutdownTimeout/2)*time.Second, + "server went down not by timeout") +} diff --git a/smallbuf.go b/smallbuf.go new file mode 100644 index 000000000..a6590b409 --- /dev/null +++ b/smallbuf.go @@ -0,0 +1,112 @@ +package tarantool + +import ( + "errors" + "io" +) + +type smallBuf struct { + b []byte + p int +} + +func (s *smallBuf) Read(d []byte) (l int, err error) { + l = len(s.b) - s.p + if l == 0 && len(d) > 0 { + return 0, io.EOF + } + if l > len(d) { + l = len(d) + } + copy(d, s.b[s.p:]) + s.p += l + return l, nil +} + +func (s *smallBuf) ReadByte() (b byte, err error) { + if s.p == len(s.b) { + return 0, io.EOF + } + b = s.b[s.p] + s.p++ + return b, nil +} + +func (s *smallBuf) UnreadByte() error { + if s.p == 0 { + return errors.New("could not unread") + } + s.p-- + return nil +} + +func (s *smallBuf) Len() int { + return len(s.b) - s.p +} + +func (s *smallBuf) Bytes() []byte { + if len(s.b) > s.p { + return s.b[s.p:] + } + return nil +} + +func (s *smallBuf) Offset() int { + return s.p +} + +func (s *smallBuf) Seek(offset int) error { + if offset < 0 { + return errors.New("too small offset") + } + if offset > len(s.b) { + return errors.New("too big offset") + } + s.p = offset + return nil +} + +type smallWBuf struct { + b []byte + sum uint + n uint +} + +func (s *smallWBuf) Write(b []byte) (int, error) { + s.b = append(s.b, b...) + return len(s.b), nil +} + +func (s *smallWBuf) WriteByte(b byte) error { + s.b = append(s.b, b) + return nil +} + +func (s *smallWBuf) WriteString(ss string) (int, error) { + s.b = append(s.b, ss...) + return len(ss), nil +} + +func (s smallWBuf) Len() int { + return len(s.b) +} + +func (s smallWBuf) Cap() int { + return cap(s.b) +} + +func (s *smallWBuf) Trunc(n int) { + s.b = s.b[:n] +} + +func (s *smallWBuf) Reset() { + s.sum = uint(uint64(s.sum)*15/16) + uint(len(s.b)) + if s.n < 16 { + s.n++ + } + if cap(s.b) > 1024 && s.sum/s.n < uint(cap(s.b))/4 { + s.b = make([]byte, 0, s.sum/s.n) + } else { + s.b = s.b[:0] + } +} diff --git a/stream.go b/stream.go new file mode 100644 index 000000000..cdabf12fa --- /dev/null +++ b/stream.go @@ -0,0 +1,250 @@ +package tarantool + +import ( + "context" + "errors" + "time" + + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" +) + +type TxnIsolationLevel uint + +const ( + // By default, the isolation level of Tarantool is serializable. + DefaultIsolationLevel TxnIsolationLevel = 0 + // The ReadCommittedLevel isolation level makes visible all transactions + // that started commit (stream.Do(NewCommitRequest()) was called). + ReadCommittedLevel TxnIsolationLevel = 1 + // The ReadConfirmedLevel isolation level makes visible all transactions + // that finished the commit (stream.Do(NewCommitRequest()) was returned). + ReadConfirmedLevel TxnIsolationLevel = 2 + // If the BestEffortLevel (serializable) isolation level becomes unreachable, + // the transaction is marked as «conflicted» and can no longer be committed. + BestEffortLevel TxnIsolationLevel = 3 +) + +var ( + errUnknownStreamRequest = errors.New("the passed connected request doesn't belong " + + "to the current connection or connection pool") +) + +type Stream struct { + Id uint64 + Conn *Connection +} + +// BeginRequest helps you to create a begin request object for execution +// by a Stream. +// Begin request can not be processed out of stream. +type BeginRequest struct { + baseRequest + txnIsolation TxnIsolationLevel + timeout time.Duration + isSync bool + isSyncSet bool +} + +// NewBeginRequest returns a new BeginRequest. +func NewBeginRequest() *BeginRequest { + req := new(BeginRequest) + req.rtype = iproto.IPROTO_BEGIN + req.txnIsolation = DefaultIsolationLevel + return req +} + +// TxnIsolation sets the the transaction isolation level for transaction manager. +// By default, the isolation level of Tarantool is serializable. +func (req *BeginRequest) TxnIsolation(txnIsolation TxnIsolationLevel) *BeginRequest { + req.txnIsolation = txnIsolation + return req +} + +// Timeout allows to set up a timeout for call BeginRequest. +func (req *BeginRequest) Timeout(timeout time.Duration) *BeginRequest { + req.timeout = timeout + return req +} + +// IsSync allows to set up a IsSync flag for call BeginRequest. +func (req *BeginRequest) IsSync(isSync bool) *BeginRequest { + req.isSync = isSync + req.isSyncSet = true + return req +} + +// Body fills an msgpack.Encoder with the begin request body. +func (req *BeginRequest) Body(_ SchemaResolver, enc *msgpack.Encoder) error { + var ( + mapLen = 0 + hasTimeout = req.timeout > 0 + hasIsolationLevel = req.txnIsolation != DefaultIsolationLevel + ) + + if hasTimeout { + mapLen++ + } + + if hasIsolationLevel { + mapLen++ + } + + if req.isSyncSet { + mapLen++ + } + + if err := enc.EncodeMapLen(mapLen); err != nil { + return err + } + + if hasTimeout { + if err := enc.EncodeUint(uint64(iproto.IPROTO_TIMEOUT)); err != nil { + return err + } + + if err := enc.Encode(req.timeout.Seconds()); err != nil { + return err + } + } + + if hasIsolationLevel { + if err := enc.EncodeUint(uint64(iproto.IPROTO_TXN_ISOLATION)); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(req.txnIsolation)); err != nil { + return err + } + } + + if req.isSyncSet { + if err := enc.EncodeUint(uint64(iproto.IPROTO_IS_SYNC)); err != nil { + return err + } + + if err := enc.EncodeBool(req.isSync); err != nil { + return err + } + } + + return nil +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *BeginRequest) Context(ctx context.Context) *BeginRequest { + req.ctx = ctx + return req +} + +// CommitRequest helps you to create a commit request object for execution +// by a Stream. +// Commit request can not be processed out of stream. +type CommitRequest struct { + baseRequest + + isSync bool + isSyncSet bool +} + +// NewCommitRequest returns a new CommitRequest. +func NewCommitRequest() *CommitRequest { + req := new(CommitRequest) + req.rtype = iproto.IPROTO_COMMIT + return req +} + +// IsSync allows to set up a IsSync flag for call BeginRequest. +func (req *CommitRequest) IsSync(isSync bool) *CommitRequest { + req.isSync = isSync + req.isSyncSet = true + return req +} + +// Body fills an msgpack.Encoder with the commit request body. +func (req *CommitRequest) Body(_ SchemaResolver, enc *msgpack.Encoder) error { + var ( + mapLen = 0 + ) + + if req.isSyncSet { + mapLen++ + } + + if err := enc.EncodeMapLen(mapLen); err != nil { + return err + } + + if req.isSyncSet { + if err := enc.EncodeUint(uint64(iproto.IPROTO_IS_SYNC)); err != nil { + return err + } + + if err := enc.EncodeBool(req.isSync); err != nil { + return err + } + } + + return nil +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *CommitRequest) Context(ctx context.Context) *CommitRequest { + req.ctx = ctx + return req +} + +// RollbackRequest helps you to create a rollback request object for execution +// by a Stream. +// Rollback request can not be processed out of stream. +type RollbackRequest struct { + baseRequest +} + +// NewRollbackRequest returns a new RollbackRequest. +func NewRollbackRequest() *RollbackRequest { + req := new(RollbackRequest) + req.rtype = iproto.IPROTO_ROLLBACK + return req +} + +// Body fills an msgpack.Encoder with the rollback request body. +func (req *RollbackRequest) Body(_ SchemaResolver, enc *msgpack.Encoder) error { + return enc.EncodeMapLen(0) +} + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *RollbackRequest) Context(ctx context.Context) *RollbackRequest { + req.ctx = ctx + return req +} + +// Do verifies, sends the request and returns a future. +// +// An error is returned if the request was formed incorrectly, or failure to +// create the future. +func (s *Stream) Do(req Request) *Future { + if connectedReq, ok := req.(ConnectedRequest); ok { + if connectedReq.Conn() != s.Conn { + fut := NewFuture(req) + fut.SetError(errUnknownStreamRequest) + return fut + } + } + return s.Conn.send(req, s.Id) +} diff --git a/tarantool_test.go b/tarantool_test.go index 96ddc8d3c..5bf790ced 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -1,110 +1,3946 @@ -package tarantool +package tarantool_test -import( - "testing" +import ( + "bytes" + "context" + "encoding/binary" + "errors" "fmt" + "io" + "log" + "math" + "os" + "os/exec" + "path/filepath" + "reflect" + "regexp" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" + + . "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/test_helpers" ) +var startOpts test_helpers.StartOpts = test_helpers.StartOpts{ + Dialer: dialer, + InitScript: "config.lua", + Listen: server, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 10, + RetryTimeout: 500 * time.Millisecond, +} + +var dialer = NetDialer{ + Address: server, + User: "test", + Password: "test", +} + +type Member struct { + Name string + Nonce string + Val uint +} + +var contextDoneErrRegexp = regexp.MustCompile( + `^context is done \(request ID [0-9]+\): context canceled$`) + +func (m *Member) EncodeMsgpack(e *msgpack.Encoder) error { + if err := e.EncodeArrayLen(2); err != nil { + return err + } + if err := e.EncodeString(m.Name); err != nil { + return err + } + if err := e.EncodeUint(uint64(m.Val)); err != nil { + return err + } + return nil +} + +func (m *Member) DecodeMsgpack(d *msgpack.Decoder) error { + var err error + var l int + if l, err = d.DecodeArrayLen(); err != nil { + return err + } + if l != 2 { + return fmt.Errorf("array len doesn't match: %d", l) + } + if m.Name, err = d.DecodeString(); err != nil { + return err + } + if m.Val, err = d.DecodeUint(); err != nil { + return err + } + return nil +} + +var server = "127.0.0.1:3013" +var fdDialerTestServer = "127.0.0.1:3014" +var spaceNo = uint32(617) +var spaceName = "test" +var indexNo = uint32(0) +var indexName = "primary" +var opts = Opts{ + Timeout: 5 * time.Second, + // Concurrency: 32, + // RateLimit: 4*1024, +} + +const N = 500 + +func BenchmarkSync_naive(b *testing.B) { + var err error + + conn := test_helpers.ConnectWithValidation(b, dialer, opts) + defer conn.Close() + + _, err = conn.Do( + NewReplaceRequest(spaceNo). + Tuple([]interface{}{uint(1111), "hello", "world"}), + ).Get() + if err != nil { + b.Fatalf("failed to initialize database: %s", err) + } + + b.ResetTimer() + + for b.Loop() { + req := NewSelectRequest(spaceNo). + Index(indexNo). + Iterator(IterEq). + Key([]interface{}{uint(1111)}) + data, err := conn.Do(req).Get() + if err != nil { + b.Errorf("request error: %s", err) + } + + tuple := data[0].([]any) + if tuple[0].(uint16) != uint16(1111) { + b.Errorf("invalid result") + } + } +} + +func BenchmarkSync_naive_with_single_request(b *testing.B) { + var err error + + conn := test_helpers.ConnectWithValidation(b, dialer, opts) + defer conn.Close() + + _, err = conn.Do( + NewReplaceRequest(spaceNo). + Tuple([]interface{}{uint(1111), "hello", "world"}), + ).Get() + if err != nil { + b.Fatalf("failed to initialize database: %s", err) + } + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Iterator(IterEq). + Key(UintKey{I: 1111}) + + b.ResetTimer() + + for b.Loop() { + data, err := conn.Do(req).Get() + if err != nil { + b.Errorf("request error: %s", err) + } + + tuple := data[0].([]any) + if tuple[0].(uint16) != uint16(1111) { + b.Errorf("invalid result") + } + } +} + +type benchTuple struct { + id uint +} + +func (t *benchTuple) DecodeMsgpack(dec *msgpack.Decoder) error { + l, err := dec.DecodeArrayLen() + if err != nil { + return fmt.Errorf("failed to decode tuples array: %w", err) + } + + if l != 1 { + return fmt.Errorf("unexpected tuples array with len %d", l) + } + + l, err = dec.DecodeArrayLen() + if err != nil { + return fmt.Errorf("failed to decode tuple array: %w", err) + } + + if l < 1 { + return fmt.Errorf("too small tuple have 0 fields") + } + + t.id, err = dec.DecodeUint() + if err != nil { + return fmt.Errorf("failed to decode id: %w", err) + } + + return nil +} + +func BenchmarkSync_naive_with_custom_type(b *testing.B) { + var err error + + conn := test_helpers.ConnectWithValidation(b, dialer, opts) + defer conn.Close() + + _, err = conn.Do( + NewReplaceRequest(spaceNo). + Tuple([]interface{}{uint(1111), "hello", "world"}), + ).Get() + if err != nil { + b.Fatalf("failed to initialize database: %s", err) + } + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Iterator(IterEq). + Key(UintKey{I: 1111}) + + var tuple benchTuple + + b.ResetTimer() + + for b.Loop() { + err := conn.Do(req).GetTyped(&tuple) + if err != nil { + b.Errorf("request error: %s", err) + } + + if tuple.id != 1111 { + b.Errorf("invalid result") + } + } +} + +func BenchmarkSync_multithread(b *testing.B) { + var err error + + conn := test_helpers.ConnectWithValidation(b, dialer, opts) + defer conn.Close() + + _, err = conn.Do( + NewReplaceRequest(spaceNo). + Tuple([]interface{}{uint(1111), "hello", "world"}), + ).Get() + if err != nil { + b.Fatalf("failed to initialize database: %s", err) + } + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Iterator(IterEq). + Key(UintKey{I: 1111}) + + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + var tuple benchTuple + + for pb.Next() { + err := conn.Do(req).GetTyped(&tuple) + if err != nil { + b.Errorf("request error: %s", err) + } + + if tuple.id != 1111 { + b.Errorf("invalid result") + } + } + }) +} + +func BenchmarkAsync_multithread_parallelism(b *testing.B) { + var err error + + conn := test_helpers.ConnectWithValidation(b, dialer, opts) + defer conn.Close() + + _, err = conn.Do( + NewReplaceRequest(spaceNo). + Tuple([]interface{}{uint(1111), "hello", "world"}), + ).Get() + if err != nil { + b.Fatalf("failed to initialize database: %s", err) + } + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Iterator(IterEq). + Key(UintKey{I: 1111}) + + b.ResetTimer() + + for p := 1; p <= 1024; p *= 2 { + b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { + b.SetParallelism(p) + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + var tuple benchTuple + + for pb.Next() { + err := conn.Do(req).GetTyped(&tuple) + if err != nil { + b.Errorf("request error: %s", err) + } + + if tuple.id != 1111 { + b.Errorf("invalid result") + } + } + }) + }) + } +} + +// TestBenchmarkAsync is a benchmark for the async API that is unable to +// implement with a Go-benchmark. It can be used to test performance with +// different numbers of connections and processing goroutines. +func TestBenchmarkAsync(t *testing.T) { + t.Skip() + + requests := int64(10_000_000) + connections := 16 + + ops := opts + // ops.Concurrency = 2 // 4 max. // 2 max. + + conns := make([]*Connection, 0, connections) + for range connections { + conn := test_helpers.ConnectWithValidation(t, dialer, ops) + defer conn.Close() + + conns = append(conns, conn) + } + + _, err := conns[0].Do( + NewReplaceRequest(spaceNo). + Tuple([]interface{}{uint(1111), "hello", "world"}), + ).Get() + if err != nil { + t.Fatalf("failed to initialize database: %s", err) + } + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Iterator(IterEq). + Key(UintKey{I: 1111}) + + maxRps := float64(0) + maxConnections := 0 + maxConcurrency := 0 + + for cn := 1; cn <= connections; cn *= 2 { + for cc := 1; cc <= 512; cc *= 2 { + var wg sync.WaitGroup + + curRequests := requests + + start := time.Now() + + for i := range cc { + wg.Add(1) + + ch := make(chan *Future, 1024) + + go func(i int) { + defer close(ch) + + for atomic.AddInt64(&curRequests, -1) >= 0 { + ch <- conns[i%cn].Do(req) + } + }(i) + + go func() { + defer wg.Done() + + var tuple benchTuple + + for fut := range ch { + err := fut.GetTyped(&tuple) + if err != nil { + t.Errorf("request error: %s", err) + } + + if tuple.id != 1111 { + t.Errorf("invalid result") + } + } + }() + } + + wg.Wait() + + duration := time.Since(start) + + rps := float64(requests) / duration.Seconds() + t.Log("requests :", requests) + t.Log("concurrency:", cc) + t.Log("connections:", cn) + t.Logf("duration : %.2f\n", duration.Seconds()) + t.Logf("requests/s : %.2f\n", rps) + t.Log("============") + + if maxRps < rps { + maxRps = rps + maxConnections = cn + maxConcurrency = cc + } + } + } + + t.Log("max connections:", maxConnections) + t.Log("max concurrency:", maxConcurrency) + t.Logf("max requests/s : %.2f\n", maxRps) +} + +type mockRequest struct { + conn *Connection +} + +func (req *mockRequest) Type() iproto.Type { + return iproto.Type(0) +} + +func (req *mockRequest) Async() bool { + return false +} + +func (req *mockRequest) Body(resolver SchemaResolver, enc *msgpack.Encoder) error { + return nil +} + +func (req *mockRequest) Conn() *Connection { + return req.conn +} + +func (req *mockRequest) Ctx() context.Context { + return nil +} + +func (req *mockRequest) Response(header Header, + body io.Reader) (Response, error) { + return nil, fmt.Errorf("some error") +} + +func TestNetDialer(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := dialer.Dial(ctx, DialOpts{}) + require.Nil(err) + require.NotNil(conn) + defer conn.Close() + + assert.Equal(server, conn.Addr().String()) + assert.NotEqual("", conn.Greeting().Version) + + // Write IPROTO_PING. + ping := []byte{ + 0xce, 0x00, 0x00, 0x00, 0xa, // Length. + 0x82, // Header map. + 0x00, 0x40, + 0x01, 0xce, 0x00, 0x00, 0x00, 0x02, + 0x80, // Empty map. + } + ret, err := conn.Write(ping) + require.Equal(len(ping), ret) + require.Nil(err) + require.Nil(conn.Flush()) + + // Read IPROTO_PING response length. + lenbuf := make([]byte, 5) + ret, err = io.ReadFull(conn, lenbuf) + require.Nil(err) + require.Equal(len(lenbuf), ret) + length := int(binary.BigEndian.Uint32(lenbuf[1:])) + require.Greater(length, 0) + + // Read IPROTO_PING response. + buf := make([]byte, length) + ret, err = io.ReadFull(conn, buf) + require.Nil(err) + require.Equal(len(buf), ret) + // Check that it is IPROTO_OK. + assert.Equal([]byte{0x83, 0x00, 0xce, 0x00, 0x00, 0x00, 0x00}, buf[:7]) +} + +func TestNetDialer_BadUser(t *testing.T) { + badDialer := NetDialer{ + Address: server, + User: "Cpt Smollett", + Password: "none", + } + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + conn, err := Connect(ctx, badDialer, opts) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "failed to authenticate") + if conn != nil { + conn.Close() + t.Errorf("connection is not nil") + } +} + +// NetDialer does not work with PapSha256Auth, no matter the Tarantool version +// and edition. +func TestNetDialer_PapSha256Auth(t *testing.T) { + authDialer := AuthDialer{ + Dialer: dialer, + Username: "test", + Password: "test", + Auth: PapSha256Auth, + } + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := authDialer.Dial(ctx, DialOpts{}) + if conn != nil { + conn.Close() + t.Fatalf("Connection created successfully") + } + + assert.ErrorContains(t, err, "failed to authenticate") +} + +func TestFutureMultipleGetGetTyped(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + fut := conn.Call17Async("simple_concat", []interface{}{"1"}) + + for i := 0; i < 30; i++ { + // [0, 10) fut.Get() + // [10, 20) fut.GetTyped() + // [20, 30) Mix + get := false + if (i < 10) || (i >= 20 && i%2 == 0) { + get = true + } + + if get { + data, err := fut.Get() + if err != nil { + t.Errorf("Failed to call Get(): %s", err) + } + if val, ok := data[0].(string); !ok || val != "11" { + t.Errorf("Wrong Get() result: %v", data) + } + } else { + tpl := struct { + Val string + }{} + err := fut.GetTyped(&tpl) + if err != nil { + t.Errorf("Failed to call GetTyped(): %s", err) + } + if tpl.Val != "11" { + t.Errorf("Wrong GetTyped() result: %v", tpl) + } + } + } +} + +func TestFutureMultipleGetWithError(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + fut := conn.Call17Async("non_exist", []interface{}{"1"}) + + for i := 0; i < 2; i++ { + if _, err := fut.Get(); err == nil { + t.Fatalf("An error expected") + } + } +} + +func TestFutureMultipleGetTypedWithError(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + fut := conn.Call17Async("simple_concat", []interface{}{"1"}) + + wrongTpl := struct { + Val int + }{} + goodTpl := struct { + Val string + }{} + + if err := fut.GetTyped(&wrongTpl); err == nil { + t.Fatalf("An error expected") + } + if err := fut.GetTyped(&goodTpl); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + if goodTpl.Val != "11" { + t.Fatalf("Wrong result: %s", goodTpl.Val) + } +} + +// ///////////////// + func TestClient(t *testing.T) { - server := "127.0.0.1:3013" - spaceNo := uint32(512) - indexNo := uint32(0) - limit := uint32(10) - offset := uint32(0) - iterator := IterAll - key := []interface{}{ 12 } - tuple1 := []interface{}{ 12, "Hello World", "Olga" } - tuple2 := []interface{}{ 12, "Hello Mars", "Anna" } - upd_tuple := []interface{}{ []interface{}{ "=", 1, "Hello Moon" }, []interface{}{ "#", 2, 1 } } - - functionName := "box.cfg()" - functionTuple := []interface{}{ "box.schema.SPACE_ID" } - - - client, err := Connect(server) - if err != nil { - t.Errorf("No connection available") - } - - var resp *Response - - resp, err = client.Ping() - fmt.Println("Ping") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - - resp, err = client.Insert(spaceNo, tuple1) - fmt.Println("Insert") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - - resp, err = client.Select(spaceNo, indexNo, offset, limit, iterator, key) - fmt.Println("Select") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - - resp, err = client.Replace(spaceNo, tuple2) - fmt.Println("Replace") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - - resp, err = client.Select(spaceNo, indexNo, offset, limit, iterator, key) - fmt.Println("Select") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - - resp, err = client.Update(spaceNo, indexNo, key, upd_tuple) - fmt.Println("Update") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - - resp, err = client.Select(spaceNo, indexNo, offset, limit, iterator, key) - fmt.Println("Select") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - - responses := make(chan *Response) - cnt1 := 50 - cnt2 := 500 - for j := 0; j < cnt1; j++ { - for i := 0; i < cnt2; i++ { - go func(){ - resp, err = client.Select(spaceNo, indexNo, offset, limit, iterator, key) - responses <- resp - }() - } - for i := 0; i < cnt2; i++ { - resp = <-responses - // fmt.Println(resp) - } - } - - resp, err = client.Delete(spaceNo, indexNo, key) - fmt.Println("Delete") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - - resp, err = client.Call(functionName, functionTuple) - fmt.Println("Call") - fmt.Println("ERROR", err) - fmt.Println("Code", resp.Code) - fmt.Println("Data", resp.Data) - fmt.Println("----") - -} \ No newline at end of file + var err error + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // Ping + data, err := conn.Ping() + if err != nil { + t.Fatalf("Failed to Ping: %s", err) + } + if data != nil { + t.Fatalf("Response data is not nil after Ping") + } + + // Insert + data, err = conn.Insert(spaceNo, []interface{}{uint(1), "hello", "world"}) + if err != nil { + t.Fatalf("Failed to Insert: %s", err) + } + if len(data) != 1 { + t.Errorf("Response Body len != 1") + } + if tpl, ok := data[0].([]interface{}); !ok { + t.Errorf("Unexpected body of Insert") + } else { + if len(tpl) != 3 { + t.Errorf("Unexpected body of Insert (tuple len)") + } + if id, err := test_helpers.ConvertUint64(tpl[0]); err != nil || id != 1 { + t.Errorf("Unexpected body of Insert (0)") + } + if h, ok := tpl[1].(string); !ok || h != "hello" { + t.Errorf("Unexpected body of Insert (1)") + } + } + data, err = conn.Insert(spaceNo, &Tuple{Id: 1, Msg: "hello", Name: "world"}) + if tntErr, ok := err.(Error); !ok || tntErr.Code != iproto.ER_TUPLE_FOUND { + t.Errorf("Expected %s but got: %v", iproto.ER_TUPLE_FOUND, err) + } + if len(data) != 0 { + t.Errorf("Response Body len != 0") + } + + // Delete + data, err = conn.Delete(spaceNo, indexNo, []interface{}{uint(1)}) + if err != nil { + t.Fatalf("Failed to Delete: %s", err) + } + if len(data) != 1 { + t.Errorf("Response Body len != 1") + } + if tpl, ok := data[0].([]interface{}); !ok { + t.Errorf("Unexpected body of Delete") + } else { + if len(tpl) != 3 { + t.Errorf("Unexpected body of Delete (tuple len)") + } + if id, err := test_helpers.ConvertUint64(tpl[0]); err != nil || id != 1 { + t.Errorf("Unexpected body of Delete (0)") + } + if h, ok := tpl[1].(string); !ok || h != "hello" { + t.Errorf("Unexpected body of Delete (1)") + } + } + data, err = conn.Delete(spaceNo, indexNo, []interface{}{uint(101)}) + if err != nil { + t.Fatalf("Failed to Delete: %s", err) + } + if len(data) != 0 { + t.Errorf("Response Data len != 0") + } + + // Replace + data, err = conn.Replace(spaceNo, []interface{}{uint(2), "hello", "world"}) + if err != nil { + t.Fatalf("Failed to Replace: %s", err) + } + if data == nil { + t.Fatalf("Response is nil after Replace") + } + data, err = conn.Replace(spaceNo, []interface{}{uint(2), "hi", "planet"}) + if err != nil { + t.Fatalf("Failed to Replace (duplicate): %s", err) + } + if len(data) != 1 { + t.Errorf("Response Data len != 1") + } + if tpl, ok := data[0].([]interface{}); !ok { + t.Errorf("Unexpected body of Replace") + } else { + if len(tpl) != 3 { + t.Errorf("Unexpected body of Replace (tuple len)") + } + if id, err := test_helpers.ConvertUint64(tpl[0]); err != nil || id != 2 { + t.Errorf("Unexpected body of Replace (0)") + } + if h, ok := tpl[1].(string); !ok || h != "hi" { + t.Errorf("Unexpected body of Replace (1)") + } + } + + // Update + data, err = conn.Update(spaceNo, indexNo, []interface{}{uint(2)}, + NewOperations().Assign(1, "bye").Delete(2, 1)) + if err != nil { + t.Fatalf("Failed to Update: %s", err) + } + if len(data) != 1 { + t.Errorf("Response Data len != 1") + } + if tpl, ok := data[0].([]interface{}); !ok { + t.Errorf("Unexpected body of Update") + } else { + if len(tpl) != 2 { + t.Errorf("Unexpected body of Update (tuple len)") + } + if id, err := test_helpers.ConvertUint64(tpl[0]); err != nil || id != 2 { + t.Errorf("Unexpected body of Update (0)") + } + if h, ok := tpl[1].(string); !ok || h != "bye" { + t.Errorf("Unexpected body of Update (1)") + } + } + + // Upsert + data, err = conn.Upsert(spaceNo, []interface{}{uint(3), 1}, + NewOperations().Add(1, 1)) + if err != nil { + t.Fatalf("Failed to Upsert (insert): %s", err) + } + if data == nil { + t.Fatalf("Response is nil after Upsert (insert)") + } + data, err = conn.Upsert(spaceNo, []interface{}{uint(3), 1}, + NewOperations().Add(1, 1)) + if err != nil { + t.Fatalf("Failed to Upsert (update): %s", err) + } + if data == nil { + t.Errorf("Response is nil after Upsert (update)") + } + + // Select + for i := 10; i < 20; i++ { + data, err = conn.Replace(spaceNo, []interface{}{uint(i), fmt.Sprintf("val %d", i), "bla"}) + if err != nil { + t.Fatalf("Failed to Replace: %s", err) + } + if data == nil { + t.Errorf("Response is nil after Replace") + } + } + data, err = conn.Select(spaceNo, indexNo, 0, 1, IterEq, []interface{}{uint(10)}) + if err != nil { + t.Fatalf("Failed to Select: %s", err) + } + if len(data) != 1 { + t.Fatalf("Response Data len != 1") + } + if tpl, ok := data[0].([]interface{}); !ok { + t.Errorf("Unexpected body of Select") + } else { + if id, err := test_helpers.ConvertUint64(tpl[0]); err != nil || id != 10 { + t.Errorf("Unexpected body of Select (0)") + } + if h, ok := tpl[1].(string); !ok || h != "val 10" { + t.Errorf("Unexpected body of Select (1)") + } + } + + // Select empty + data, err = conn.Select(spaceNo, indexNo, 0, 1, IterEq, []interface{}{uint(30)}) + if err != nil { + t.Fatalf("Failed to Select: %s", err) + } + if len(data) != 0 { + t.Errorf("Response Data len != 0") + } + + // Select Typed + var tpl []Tuple + err = conn.SelectTyped(spaceNo, indexNo, 0, 1, IterEq, []interface{}{uint(10)}, &tpl) + if err != nil { + t.Fatalf("Failed to SelectTyped: %s", err) + } + if len(tpl) != 1 { + t.Errorf("Result len of SelectTyped != 1") + } else if tpl[0].Id != 10 { + t.Errorf("Bad value loaded from SelectTyped") + } + + // Get Typed + var singleTpl = Tuple{} + err = conn.GetTyped(spaceNo, indexNo, []interface{}{uint(10)}, &singleTpl) + if err != nil { + t.Fatalf("Failed to GetTyped: %s", err) + } + if singleTpl.Id != 10 { + t.Errorf("Bad value loaded from GetTyped") + } + + // Select Typed for one tuple + var tpl1 [1]Tuple + err = conn.SelectTyped(spaceNo, indexNo, 0, 1, IterEq, []interface{}{uint(10)}, &tpl1) + if err != nil { + t.Fatalf("Failed to SelectTyped: %s", err) + } + if len(tpl) != 1 { + t.Errorf("Result len of SelectTyped != 1") + } else if tpl[0].Id != 10 { + t.Errorf("Bad value loaded from SelectTyped") + } + + // Get Typed Empty + var singleTpl2 Tuple + err = conn.GetTyped(spaceNo, indexNo, []interface{}{uint(30)}, &singleTpl2) + if err != nil { + t.Fatalf("Failed to GetTyped: %s", err) + } + if singleTpl2.Id != 0 { + t.Errorf("Bad value loaded from GetTyped") + } + + // Select Typed Empty + var tpl2 []Tuple + err = conn.SelectTyped(spaceNo, indexNo, 0, 1, IterEq, []interface{}{uint(30)}, &tpl2) + if err != nil { + t.Fatalf("Failed to SelectTyped: %s", err) + } + if len(tpl2) != 0 { + t.Errorf("Result len of SelectTyped != 1") + } + + // Call16 + data, err = conn.Call16("box.info", []interface{}{"box.schema.SPACE_ID"}) + if err != nil { + t.Fatalf("Failed to Call16: %s", err) + } + if len(data) < 1 { + t.Errorf("Response.Data is empty after Eval") + } + + // Call16 vs Call17 + data, err = conn.Call16("simple_concat", []interface{}{"1"}) + if err != nil { + t.Errorf("Failed to use Call16") + } + if val, ok := data[0].([]interface{})[0].(string); !ok || val != "11" { + t.Errorf("result is not {{1}} : %v", data) + } + + data, err = conn.Call17("simple_concat", []interface{}{"1"}) + if err != nil { + t.Errorf("Failed to use Call") + } + if val, ok := data[0].(string); !ok || val != "11" { + t.Errorf("result is not {{1}} : %v", data) + } + + // Eval + data, err = conn.Eval("return 5 + 6", []interface{}{}) + if err != nil { + t.Fatalf("Failed to Eval: %s", err) + } + if len(data) < 1 { + t.Errorf("Response.Data is empty after Eval") + } + if val, err := test_helpers.ConvertUint64(data[0]); err != nil || val != 11 { + t.Errorf("5 + 6 == 11, but got %v", val) + } +} + +const ( + createTableQuery = "CREATE TABLE SQL_SPACE (ID STRING PRIMARY KEY, NAME " + + "STRING COLLATE \"unicode\" DEFAULT NULL);" + insertQuery = "INSERT INTO SQL_SPACE VALUES (?, ?);" + selectNamedQuery = "SELECT ID, NAME FROM SQL_SPACE WHERE ID=:ID AND NAME=:NAME;" + selectPosQuery = "SELECT ID, NAME FROM SQL_SPACE WHERE ID=? AND NAME=?;" + updateQuery = "UPDATE SQL_SPACE SET NAME=? WHERE ID=?;" + enableFullMetaDataQuery = "SET SESSION \"sql_full_metadata\" = true;" + selectSpanDifQueryNew = "SELECT ID||ID, NAME, ID FROM seqscan SQL_SPACE WHERE NAME=?;" + selectSpanDifQueryOld = "SELECT ID||ID, NAME, ID FROM SQL_SPACE WHERE NAME=?;" + alterTableQuery = "ALTER TABLE SQL_SPACE RENAME TO SQL_SPACE2;" + insertIncrQuery = "INSERT INTO SQL_SPACE2 VALUES (?, ?);" + deleteQuery = "DELETE FROM SQL_SPACE2 WHERE NAME=?;" + dropQuery = "DROP TABLE SQL_SPACE2;" + dropQuery2 = "DROP TABLE SQL_SPACE;" + disableFullMetaDataQuery = "SET SESSION \"sql_full_metadata\" = false;" + + selectTypedQuery = "SELECT NAME1, NAME0 FROM SQL_TEST WHERE NAME0=?" + selectNamedQuery2 = "SELECT NAME0, NAME1 FROM SQL_TEST WHERE NAME0=:id AND NAME1=:name;" + selectPosQuery2 = "SELECT NAME0, NAME1 FROM SQL_TEST WHERE NAME0=? AND NAME1=?;" + mixedQuery = "SELECT NAME0, NAME1 FROM SQL_TEST WHERE NAME0=:name0 AND NAME1=?;" +) + +func TestSQL(t *testing.T) { + test_helpers.SkipIfSQLUnsupported(t) + + type testCase struct { + Query string + Args interface{} + sqlInfo SQLInfo + data []interface{} + metaData []ColumnMetaData + } + + selectSpanDifQuery := selectSpanDifQueryNew + if isSeqScanOld, err := test_helpers.IsTarantoolVersionLess(3, 0, 0); err != nil { + t.Fatalf("Could not check the Tarantool version: %s", err) + } else if isSeqScanOld { + selectSpanDifQuery = selectSpanDifQueryOld + } + + testCases := []testCase{ + { + createTableQuery, + []interface{}{}, + SQLInfo{AffectedCount: 1}, + []interface{}{}, + nil, + }, + { + insertQuery, + []interface{}{"1", "test"}, + SQLInfo{AffectedCount: 1}, + []interface{}{}, + nil, + }, + { + selectNamedQuery, + map[string]interface{}{ + "ID": "1", + "NAME": "test", + }, + SQLInfo{AffectedCount: 0}, + []interface{}{[]interface{}{"1", "test"}}, + []ColumnMetaData{ + {FieldType: "string", FieldName: "ID"}, + {FieldType: "string", FieldName: "NAME"}}, + }, + { + selectPosQuery, + []interface{}{"1", "test"}, + SQLInfo{AffectedCount: 0}, + []interface{}{[]interface{}{"1", "test"}}, + []ColumnMetaData{ + {FieldType: "string", FieldName: "ID"}, + {FieldType: "string", FieldName: "NAME"}}, + }, + { + updateQuery, + []interface{}{"test_test", "1"}, + SQLInfo{AffectedCount: 1}, + []interface{}{}, + nil, + }, + { + enableFullMetaDataQuery, + []interface{}{}, + SQLInfo{AffectedCount: 1}, + []interface{}{}, + nil, + }, + { + selectSpanDifQuery, + []interface{}{"test_test"}, + SQLInfo{AffectedCount: 0}, + []interface{}{[]interface{}{"11", "test_test", "1"}}, + []ColumnMetaData{ + { + FieldType: "string", + FieldName: "COLUMN_1", + FieldIsNullable: false, + FieldIsAutoincrement: false, + FieldSpan: "ID||ID", + }, + { + FieldType: "string", + FieldName: "NAME", + FieldIsNullable: true, + FieldIsAutoincrement: false, + FieldSpan: "NAME", + FieldCollation: "unicode", + }, + { + FieldType: "string", + FieldName: "ID", + FieldIsNullable: false, + FieldIsAutoincrement: false, + FieldSpan: "ID", + FieldCollation: "", + }, + }, + }, + { + alterTableQuery, + []interface{}{}, + SQLInfo{AffectedCount: 0}, + []interface{}{}, + nil, + }, + { + insertIncrQuery, + []interface{}{"2", "test_2"}, + SQLInfo{AffectedCount: 1, InfoAutoincrementIds: []uint64{1}}, + []interface{}{}, + nil, + }, + { + deleteQuery, + []interface{}{"test_2"}, + SQLInfo{AffectedCount: 1}, + []interface{}{}, + nil, + }, + { + dropQuery, + []interface{}{}, + SQLInfo{AffectedCount: 1}, + []interface{}{}, + nil, + }, + { + disableFullMetaDataQuery, + []interface{}{}, + SQLInfo{AffectedCount: 1}, + []interface{}{}, + nil, + }, + } + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + for i, test := range testCases { + req := NewExecuteRequest(test.Query).Args(test.Args) + resp, err := conn.Do(req).GetResponse() + assert.NoError(t, err, "Failed to Execute, query: %s", test.Query) + assert.NotNil(t, resp, "Response is nil after Execute\nQuery number: %d", i) + data, err := resp.Decode() + assert.NoError(t, err, "Failed to Decode") + for j := range data { + assert.Equal(t, data[j], test.data[j], "Response data is wrong") + } + exResp, ok := resp.(*ExecuteResponse) + assert.True(t, ok, "Got wrong response type") + sqlInfo, err := exResp.SQLInfo() + assert.NoError(t, err, "Error while getting SQLInfo") + assert.Equal(t, sqlInfo.AffectedCount, test.sqlInfo.AffectedCount, + "Affected count is wrong") + + errorMsg := "Response Metadata is wrong" + metaData, err := exResp.MetaData() + assert.NoError(t, err, "Error while getting MetaData") + for j := range metaData { + assert.Equal(t, metaData[j], test.metaData[j], errorMsg) + } + } +} + +func TestSQLTyped(t *testing.T) { + test_helpers.SkipIfSQLUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + mem := []Member{} + info, meta, err := conn.ExecuteTyped(selectTypedQuery, []interface{}{1}, &mem) + if info.AffectedCount != 0 { + t.Errorf("Rows affected count must be 0") + } + if len(meta) != 2 { + t.Errorf("Meta data is not full") + } + if len(mem) != 1 { + t.Errorf("Wrong length of result") + } + if err != nil { + t.Error(err) + } +} + +func TestSQLBindings(t *testing.T) { + test_helpers.SkipIfSQLUnsupported(t) + + // Data for test table + testData := map[int]string{ + 1: "test", + } + + var resp Response + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // test all types of supported bindings + // prepare named sql bind + sqlBind := map[string]interface{}{ + "id": 1, + "name": "test", + } + + sqlBind2 := struct { + Id int + Name string + }{1, "test"} + + sqlBind3 := []KeyValueBind{ + {"id", 1}, + {"name", "test"}, + } + + sqlBind4 := []interface{}{ + KeyValueBind{Key: "id", Value: 1}, + KeyValueBind{Key: "name", Value: "test"}, + } + + namedSQLBinds := []interface{}{ + sqlBind, + sqlBind2, + sqlBind3, + sqlBind4, + } + + // positioned sql bind + sqlBind5 := []interface{}{ + 1, "test", + } + + // mixed sql bind + sqlBind6 := []interface{}{ + KeyValueBind{Key: "name0", Value: 1}, + "test", + } + + for _, bind := range namedSQLBinds { + req := NewExecuteRequest(selectNamedQuery2).Args(bind) + resp, err := conn.Do(req).GetResponse() + if err != nil { + t.Fatalf("Failed to Execute: %s", err) + } + if resp == nil { + t.Fatal("Response is nil after Execute") + } + data, err := resp.Decode() + if err != nil { + t.Errorf("Failed to Decode: %s", err) + } + if reflect.DeepEqual(data[0], []interface{}{1, testData[1]}) { + t.Error("Select with named arguments failed") + } + exResp, ok := resp.(*ExecuteResponse) + assert.True(t, ok, "Got wrong response type") + metaData, err := exResp.MetaData() + assert.NoError(t, err, "Error while getting MetaData") + if metaData[0].FieldType != "unsigned" || + metaData[0].FieldName != "NAME0" || + metaData[1].FieldType != "string" || + metaData[1].FieldName != "NAME1" { + t.Error("Wrong metadata") + } + } + + req := NewExecuteRequest(selectPosQuery2).Args(sqlBind5) + resp, err := conn.Do(req).GetResponse() + if err != nil { + t.Fatalf("Failed to Execute: %s", err) + } + if resp == nil { + t.Fatal("Response is nil after Execute") + } + data, err := resp.Decode() + if err != nil { + t.Errorf("Failed to Decode: %s", err) + } + if reflect.DeepEqual(data[0], []interface{}{1, testData[1]}) { + t.Error("Select with positioned arguments failed") + } + exResp, ok := resp.(*ExecuteResponse) + assert.True(t, ok, "Got wrong response type") + metaData, err := exResp.MetaData() + assert.NoError(t, err, "Error while getting MetaData") + if metaData[0].FieldType != "unsigned" || + metaData[0].FieldName != "NAME0" || + metaData[1].FieldType != "string" || + metaData[1].FieldName != "NAME1" { + t.Error("Wrong metadata") + } + + req = NewExecuteRequest(mixedQuery).Args(sqlBind6) + resp, err = conn.Do(req).GetResponse() + if err != nil { + t.Fatalf("Failed to Execute: %s", err) + } + if resp == nil { + t.Fatal("Response is nil after Execute") + } + data, err = resp.Decode() + if err != nil { + t.Errorf("Failed to Decode: %s", err) + } + if reflect.DeepEqual(data[0], []interface{}{1, testData[1]}) { + t.Error("Select with positioned arguments failed") + } + exResp, ok = resp.(*ExecuteResponse) + assert.True(t, ok, "Got wrong response type") + metaData, err = exResp.MetaData() + assert.NoError(t, err, "Error while getting MetaData") + if metaData[0].FieldType != "unsigned" || + metaData[0].FieldName != "NAME0" || + metaData[1].FieldType != "string" || + metaData[1].FieldName != "NAME1" { + t.Error("Wrong metadata") + } +} + +func TestStressSQL(t *testing.T) { + test_helpers.SkipIfSQLUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + req := NewExecuteRequest(createTableQuery) + resp, err := conn.Do(req).GetResponse() + if err != nil { + t.Fatalf("Failed to create an Execute: %s", err) + } + if resp == nil { + t.Fatal("Response is nil after Execute") + } + exResp, ok := resp.(*ExecuteResponse) + assert.True(t, ok, "Got wrong response type") + sqlInfo, err := exResp.SQLInfo() + assert.NoError(t, err, "Error while getting SQLInfo") + if sqlInfo.AffectedCount != 1 { + t.Errorf("Incorrect count of created spaces: %d", sqlInfo.AffectedCount) + } + + // create table with the same name + req = NewExecuteRequest(createTableQuery) + resp, err = conn.Do(req).GetResponse() + if err != nil { + t.Fatalf("Failed to create an Execute: %s", err) + } + if resp == nil { + t.Fatal("Response is nil after Execute") + } + _, err = resp.Decode() + assert.NotNil(t, err, "Expected error while decoding") + + tntErr, ok := err.(Error) + assert.True(t, ok) + assert.Equal(t, iproto.ER_SPACE_EXISTS, tntErr.Code) + if resp.Header().Error != iproto.ER_SPACE_EXISTS { + t.Fatalf("Unexpected response error: %d", resp.Header().Error) + } + prevErr := err + + exResp, ok = resp.(*ExecuteResponse) + assert.True(t, ok, "Got wrong response type") + sqlInfo, err = exResp.SQLInfo() + assert.Equal(t, prevErr, err) + if sqlInfo.AffectedCount != 0 { + t.Errorf("Incorrect count of created spaces: %d", sqlInfo.AffectedCount) + } + + // execute with nil argument + req = NewExecuteRequest(createTableQuery).Args(nil) + resp, err = conn.Do(req).GetResponse() + if err != nil { + t.Fatalf("Failed to create an Execute: %s", err) + } + if resp == nil { + t.Fatal("Response is nil after Execute") + } + if resp.Header().Error == ErrorNo { + t.Fatal("Unexpected successful Execute") + } + exResp, ok = resp.(*ExecuteResponse) + assert.True(t, ok, "Got wrong response type") + sqlInfo, err = exResp.SQLInfo() + assert.NotNil(t, err, "Expected an error") + if sqlInfo.AffectedCount != 0 { + t.Errorf("Incorrect count of created spaces: %d", sqlInfo.AffectedCount) + } + + // execute with zero string + req = NewExecuteRequest("") + resp, err = conn.Do(req).GetResponse() + if err != nil { + t.Fatalf("Failed to create an Execute: %s", err) + } + if resp == nil { + t.Fatal("Response is nil after Execute") + } + if resp.Header().Error == ErrorNo { + t.Fatal("Unexpected successful Execute") + } + exResp, ok = resp.(*ExecuteResponse) + assert.True(t, ok, "Got wrong response type") + sqlInfo, err = exResp.SQLInfo() + assert.NotNil(t, err, "Expected an error") + if sqlInfo.AffectedCount != 0 { + t.Errorf("Incorrect count of created spaces: %d", sqlInfo.AffectedCount) + } + + // drop table query + req = NewExecuteRequest(dropQuery2) + resp, err = conn.Do(req).GetResponse() + if err != nil { + t.Fatalf("Failed to Execute: %s", err) + } + if resp == nil { + t.Fatal("Response is nil after Execute") + } + exResp, ok = resp.(*ExecuteResponse) + assert.True(t, ok, "Got wrong response type") + sqlInfo, err = exResp.SQLInfo() + assert.NoError(t, err, "Error while getting SQLInfo") + if sqlInfo.AffectedCount != 1 { + t.Errorf("Incorrect count of dropped spaces: %d", sqlInfo.AffectedCount) + } + + // drop the same table + req = NewExecuteRequest(dropQuery2) + resp, err = conn.Do(req).GetResponse() + if err != nil { + t.Fatalf("Failed to create an Execute: %s", err) + } + if resp == nil { + t.Fatal("Response is nil after Execute") + } + if resp.Header().Error == ErrorNo { + t.Fatal("Unexpected successful Execute") + } + _, err = resp.Decode() + if err == nil { + t.Fatal("Unexpected lack of error") + } + exResp, ok = resp.(*ExecuteResponse) + assert.True(t, ok, "Got wrong response type") + sqlInfo, err = exResp.SQLInfo() + if err == nil { + t.Fatal("Unexpected lack of error") + } + if sqlInfo.AffectedCount != 0 { + t.Errorf("Incorrect count of created spaces: %d", sqlInfo.AffectedCount) + } +} + +func TestNewPrepared(t *testing.T) { + test_helpers.SkipIfSQLUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + stmt, err := conn.NewPrepared(selectNamedQuery2) + if err != nil { + t.Errorf("failed to prepare: %v", err) + } + + executeReq := NewExecutePreparedRequest(stmt) + unprepareReq := NewUnprepareRequest(stmt) + + resp, err := conn.Do(executeReq.Args([]interface{}{1, "test"})).GetResponse() + if err != nil { + t.Errorf("failed to execute prepared: %v", err) + } + data, err := resp.Decode() + if err != nil { + t.Errorf("Failed to Decode: %s", err) + } + if reflect.DeepEqual(data[0], []interface{}{1, "test"}) { + t.Error("Select with named arguments failed") + } + prepResp, ok := resp.(*ExecuteResponse) + assert.True(t, ok, "Got wrong response type") + metaData, err := prepResp.MetaData() + assert.NoError(t, err, "Error while getting MetaData") + if metaData[0].FieldType != "unsigned" || + metaData[0].FieldName != "NAME0" || + metaData[1].FieldType != "string" || + metaData[1].FieldName != "NAME1" { + t.Error("Wrong metadata") + } + + _, err = conn.Do(unprepareReq).Get() + if err != nil { + t.Errorf("failed to unprepare prepared statement: %v", err) + } + + _, err = conn.Do(unprepareReq).Get() + if err == nil { + t.Errorf("the statement must be already unprepared") + } + require.Contains(t, err.Error(), "Prepared statement with id") + + _, err = conn.Do(executeReq).Get() + if err == nil { + t.Errorf("the statement must be already unprepared") + } + require.Contains(t, err.Error(), "Prepared statement with id") + + prepareReq := NewPrepareRequest(selectNamedQuery2) + data, err = conn.Do(prepareReq).Get() + if err != nil { + t.Errorf("failed to prepare: %v", err) + } + if data == nil { + t.Errorf("failed to prepare: Data is nil") + } + + if len(data) == 0 { + t.Errorf("failed to prepare: response Data has no elements") + } + stmt, ok = data[0].(*Prepared) + if !ok { + t.Errorf("failed to prepare: failed to cast the response Data to Prepared object") + } + if stmt.StatementID == 0 { + t.Errorf("failed to prepare: statement id is 0") + } +} + +func TestConnection_DoWithStrangerConn(t *testing.T) { + expectedErr := fmt.Errorf("the passed connected request doesn't belong to the current" + + " connection or connection pool") + + conn1 := &Connection{} + req := test_helpers.NewMockRequest() + + _, err := conn1.Do(req).Get() + if err == nil { + t.Fatalf("nil error caught") + } + if err.Error() != expectedErr.Error() { + t.Fatalf("Unexpected error caught") + } +} + +func TestConnection_SetResponse_failed(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + req := mockRequest{conn} + fut := conn.Do(&req) + + data, err := fut.Get() + assert.EqualError(t, err, "failed to set response: some error") + assert.Nil(t, data) +} + +func TestGetSchema(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + s, err := GetSchema(conn) + if err != nil { + t.Errorf("unexpected error: %s", err.Error()) + } + if s.Version != 0 || s.Spaces[spaceName].Id != spaceNo { + t.Errorf("GetSchema() returns incorrect schema") + } +} + +func TestConnection_SetSchema_Changes(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + req := NewInsertRequest(spaceName) + req.Tuple([]interface{}{uint(1010), "Tarantool"}) + _, err := conn.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Insert: %s", err) + } + + s, err := GetSchema(conn) + if err != nil { + t.Errorf("unexpected error: %s", err.Error()) + } + conn.SetSchema(s) + + // Check if changes of the SetSchema result will do nothing to the + // connection schema. + s.Spaces[spaceName] = Space{} + + reqS := NewSelectRequest(spaceName) + reqS.Key([]interface{}{uint(1010)}) + data, err := conn.Do(reqS).Get() + if err != nil { + t.Fatalf("failed to Select: %s", err) + } + if data[0].([]interface{})[1] != "Tarantool" { + t.Errorf("wrong Select body: %v", data) + } +} + +func TestSchema(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // Schema + schema, err := GetSchema(conn) + if err != nil { + t.Errorf("unexpected error: %s", err.Error()) + } + if schema.SpacesById == nil { + t.Errorf("schema.SpacesById is nil") + } + if schema.Spaces == nil { + t.Errorf("schema.Spaces is nil") + } + var space, space2 Space + var ok bool + if space, ok = schema.SpacesById[616]; !ok { + t.Errorf("space with id = 616 was not found in schema.SpacesById") + } + if space2, ok = schema.Spaces["schematest"]; !ok { + t.Errorf("space with name 'schematest' was not found in schema.SpacesById") + } + assert.Equal(t, space, space2, + "space with id = 616 and space with name schematest are different") + if space.Id != 616 { + t.Errorf("space 616 has incorrect Id") + } + if space.Name != "schematest" { + t.Errorf("space 616 has incorrect Name") + } + if !space.Temporary { + t.Errorf("space 616 should be temporary") + } + if space.Engine != "memtx" { + t.Errorf("space 616 engine should be memtx") + } + if space.FieldsCount != 8 { + t.Errorf("space 616 has incorrect fields count") + } + + if space.FieldsById == nil { + t.Errorf("space.FieldsById is nill") + } + if space.Fields == nil { + t.Errorf("space.Fields is nill") + } + if len(space.FieldsById) != 7 { + t.Errorf("space.FieldsById len is incorrect") + } + if len(space.Fields) != 7 { + t.Errorf("space.Fields len is incorrect") + } + + var field1, field2, field5, field1n, field5n Field + if field1, ok = space.FieldsById[1]; !ok { + t.Errorf("field id = 1 was not found") + } + if field2, ok = space.FieldsById[2]; !ok { + t.Errorf("field id = 2 was not found") + } + if field5, ok = space.FieldsById[5]; !ok { + t.Errorf("field id = 5 was not found") + } + + if field1n, ok = space.Fields["name1"]; !ok { + t.Errorf("field name = name1 was not found") + } + if field5n, ok = space.Fields["name5"]; !ok { + t.Errorf("field name = name5 was not found") + } + if field1 != field1n || field5 != field5n { + t.Errorf("field with id = 1 and field with name 'name1' are different") + } + if field1.Name != "name1" { + t.Errorf("field 1 has incorrect Name") + } + if field1.Type != "unsigned" { + t.Errorf("field 1 has incorrect Type") + } + if field2.Name != "name2" { + t.Errorf("field 2 has incorrect Name") + } + if field2.Type != "string" { + t.Errorf("field 2 has incorrect Type") + } + + if space.IndexesById == nil { + t.Errorf("space.IndexesById is nill") + } + if space.Indexes == nil { + t.Errorf("space.Indexes is nill") + } + if len(space.IndexesById) != 2 { + t.Errorf("space.IndexesById len is incorrect") + } + if len(space.Indexes) != 2 { + t.Errorf("space.Indexes len is incorrect") + } + + var index0, index3, index0n, index3n Index + if index0, ok = space.IndexesById[0]; !ok { + t.Errorf("index id = 0 was not found") + } + if index3, ok = space.IndexesById[3]; !ok { + t.Errorf("index id = 3 was not found") + } + if index0n, ok = space.Indexes["primary"]; !ok { + t.Errorf("index name = primary was not found") + } + if index3n, ok = space.Indexes["secondary"]; !ok { + t.Errorf("index name = secondary was not found") + } + assert.Equal(t, index0, index0n, + "index with id = 0 and index with name 'primary' are different") + assert.Equal(t, index3, index3n, + "index with id = 3 and index with name 'secondary' are different") + if index3.Id != 3 { + t.Errorf("index has incorrect Id") + } + if index0.Name != "primary" { + t.Errorf("index has incorrect Name") + } + if index0.Type != "hash" || index3.Type != "tree" { + t.Errorf("index has incorrect Type") + } + if !index0.Unique || index3.Unique { + t.Errorf("index has incorrect Unique") + } + if index3.Fields == nil { + t.Errorf("index.Fields is nil") + } + if len(index3.Fields) != 2 { + t.Errorf("index.Fields len is incorrect") + } + + ifield1 := index3.Fields[0] + ifield2 := index3.Fields[1] + if (ifield1 == IndexField{}) || (ifield2 == IndexField{}) { + t.Fatalf("index field is nil") + } + if ifield1.Id != 1 || ifield2.Id != 2 { + t.Errorf("index field has incorrect Id") + } + if (ifield1.Type != "num" && ifield1.Type != "unsigned") || + (ifield2.Type != "STR" && ifield2.Type != "string") { + t.Errorf("index field has incorrect Type '%s'", ifield2.Type) + } +} + +func TestSchema_IsNullable(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + schema, err := GetSchema(conn) + if err != nil { + t.Errorf("unexpected error: %s", err.Error()) + } + if schema.Spaces == nil { + t.Errorf("schema.Spaces is nil") + } + + var space Space + var ok bool + if space, ok = schema.SpacesById[616]; !ok { + t.Errorf("space with id = 616 was not found in schema.SpacesById") + } + + var field, field_nullable Field + for i := 0; i <= 5; i++ { + name := fmt.Sprintf("name%d", i) + if field, ok = space.Fields[name]; !ok { + t.Errorf("field name = %s was not found", name) + } + if field.IsNullable { + t.Errorf("field %s has incorrect IsNullable", name) + } + } + if field_nullable, ok = space.Fields["nullable"]; !ok { + t.Errorf("field name = nullable was not found") + } + if !field_nullable.IsNullable { + t.Errorf("field nullable has incorrect IsNullable") + } +} + +func TestNewPreparedFromResponse(t *testing.T) { + var ( + ErrNilResponsePassed = fmt.Errorf("passed nil response") + ErrNilResponseData = fmt.Errorf("response Data is nil") + ErrWrongDataFormat = fmt.Errorf("response Data format is wrong") + ) + + testConn := &Connection{} + testCases := []struct { + name string + resp Response + expectedError error + }{ + {"ErrNilResponsePassed", nil, ErrNilResponsePassed}, + {"ErrNilResponseData", test_helpers.NewMockResponse(t, nil), + ErrNilResponseData}, + {"ErrWrongDataFormat", test_helpers.NewMockResponse(t, []interface{}{}), + ErrWrongDataFormat}, + {"ErrWrongDataFormat", test_helpers.NewMockResponse(t, []interface{}{"test"}), + ErrWrongDataFormat}, + } + for _, testCase := range testCases { + t.Run("Expecting error "+testCase.name, func(t *testing.T) { + _, err := NewPreparedFromResponse(testConn, testCase.resp) + assert.Equal(t, testCase.expectedError, err) + }) + } +} + +func TestClientNamed(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // Insert + data, err := conn.Insert(spaceName, []interface{}{uint(1001), "hello2", "world2"}) + if err != nil { + t.Fatalf("Failed to Insert: %s", err) + } + if data == nil { + t.Errorf("Response is nil after Insert") + } + + // Delete + data, err = conn.Delete(spaceName, indexName, []interface{}{uint(1001)}) + if err != nil { + t.Fatalf("Failed to Delete: %s", err) + } + if data == nil { + t.Errorf("Response is nil after Delete") + } + + // Replace + data, err = conn.Replace(spaceName, []interface{}{uint(1002), "hello", "world"}) + if err != nil { + t.Fatalf("Failed to Replace: %s", err) + } + if data == nil { + t.Errorf("Response is nil after Replace") + } + + // Update + data, err = conn.Update(spaceName, indexName, + []interface{}{ + uint(1002)}, + NewOperations().Assign(1, "buy").Delete(2, 1)) + if err != nil { + t.Fatalf("Failed to Update: %s", err) + } + if data == nil { + t.Errorf("Response is nil after Update") + } + + // Upsert + data, err = conn.Upsert(spaceName, + []interface{}{uint(1003), 1}, NewOperations().Add(1, 1)) + if err != nil { + t.Fatalf("Failed to Upsert (insert): %s", err) + } + if data == nil { + t.Errorf("Response is nil after Upsert (insert)") + } + data, err = conn.Upsert(spaceName, + []interface{}{uint(1003), 1}, NewOperations().Add(1, 1)) + if err != nil { + t.Fatalf("Failed to Upsert (update): %s", err) + } + if data == nil { + t.Errorf("Response is nil after Upsert (update)") + } + + // Select + for i := 1010; i < 1020; i++ { + data, err = conn.Replace(spaceName, + []interface{}{uint(i), fmt.Sprintf("val %d", i), "bla"}) + if err != nil { + t.Fatalf("Failed to Replace: %s", err) + } + if data == nil { + t.Errorf("Response is nil after Replace") + } + } + data, err = conn.Select(spaceName, indexName, 0, 1, IterEq, []interface{}{uint(1010)}) + if err != nil { + t.Fatalf("Failed to Select: %s", err) + } + if data == nil { + t.Errorf("Response is nil after Select") + } + + // Select Typed + var tpl []Tuple + err = conn.SelectTyped(spaceName, indexName, 0, 1, IterEq, []interface{}{uint(1010)}, &tpl) + if err != nil { + t.Fatalf("Failed to SelectTyped: %s", err) + } + if len(tpl) != 1 { + t.Errorf("Result len of SelectTyped != 1") + } +} + +func TestClientRequestObjects(t *testing.T) { + var ( + req Request + err error + ) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // Ping + req = NewPingRequest() + data, err := conn.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Ping: %s", err) + } + if len(data) != 0 { + t.Errorf("Response Body len != 0") + } + + // The code prepares data. + for i := 1010; i < 1020; i++ { + conn.Delete(spaceName, nil, []interface{}{uint(i)}) + } + + // Insert + for i := 1010; i < 1020; i++ { + req = NewInsertRequest(spaceName). + Tuple([]interface{}{uint(i), fmt.Sprintf("val %d", i), "bla"}) + data, err = conn.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Insert: %s", err) + } + if len(data) != 1 { + t.Fatalf("Response Body len != 1") + } + if tpl, ok := data[0].([]interface{}); !ok { + t.Errorf("Unexpected body of Insert") + } else { + if len(tpl) != 3 { + t.Errorf("Unexpected body of Insert (tuple len)") + } + if id, err := test_helpers.ConvertUint64(tpl[0]); err != nil || id != uint64(i) { + t.Errorf("Unexpected body of Insert (0)") + } + if h, ok := tpl[1].(string); !ok || h != fmt.Sprintf("val %d", i) { + t.Errorf("Unexpected body of Insert (1)") + } + if h, ok := tpl[2].(string); !ok || h != "bla" { + t.Errorf("Unexpected body of Insert (2)") + } + } + } + + // Replace + for i := 1015; i < 1020; i++ { + req = NewReplaceRequest(spaceName). + Tuple([]interface{}{uint(i), fmt.Sprintf("val %d", i), "blar"}) + data, err = conn.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Decode: %s", err) + } + if len(data) != 1 { + t.Fatalf("Response Body len != 1") + } + if tpl, ok := data[0].([]interface{}); !ok { + t.Errorf("Unexpected body of Replace") + } else { + if len(tpl) != 3 { + t.Errorf("Unexpected body of Replace (tuple len)") + } + if id, err := test_helpers.ConvertUint64(tpl[0]); err != nil || id != uint64(i) { + t.Errorf("Unexpected body of Replace (0)") + } + if h, ok := tpl[1].(string); !ok || h != fmt.Sprintf("val %d", i) { + t.Errorf("Unexpected body of Replace (1)") + } + if h, ok := tpl[2].(string); !ok || h != "blar" { + t.Errorf("Unexpected body of Replace (2)") + } + } + } + + // Delete + req = NewDeleteRequest(spaceName). + Key([]interface{}{uint(1016)}) + data, err = conn.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Delete: %s", err) + } + if data == nil { + t.Fatalf("Response data is nil after Delete") + } + if len(data) != 1 { + t.Fatalf("Response Body len != 1") + } + if tpl, ok := data[0].([]interface{}); !ok { + t.Errorf("Unexpected body of Delete") + } else { + if len(tpl) != 3 { + t.Errorf("Unexpected body of Delete (tuple len)") + } + if id, err := test_helpers.ConvertUint64(tpl[0]); err != nil || id != uint64(1016) { + t.Errorf("Unexpected body of Delete (0)") + } + if h, ok := tpl[1].(string); !ok || h != "val 1016" { + t.Errorf("Unexpected body of Delete (1)") + } + if h, ok := tpl[2].(string); !ok || h != "blar" { + t.Errorf("Unexpected body of Delete (2)") + } + } + + // Update without operations. + req = NewUpdateRequest(spaceName). + Index(indexName). + Key([]interface{}{uint(1010)}) + data, err = conn.Do(req).Get() + if err != nil { + t.Errorf("Failed to Update: %s", err) + } + if data == nil { + t.Fatalf("Response data is nil after Update") + } + if len(data) != 1 { + t.Fatalf("Response Data len != 1") + } + if tpl, ok := data[0].([]interface{}); !ok { + t.Errorf("Unexpected body of Update") + } else { + if id, err := test_helpers.ConvertUint64(tpl[0]); err != nil || id != uint64(1010) { + t.Errorf("Unexpected body of Update (0)") + } + if h, ok := tpl[1].(string); !ok || h != "val 1010" { + t.Errorf("Unexpected body of Update (1)") + } + if h, ok := tpl[2].(string); !ok || h != "bla" { + t.Errorf("Unexpected body of Update (2)") + } + } + + // Update. + req = NewUpdateRequest(spaceName). + Index(indexName). + Key([]interface{}{uint(1010)}). + Operations(NewOperations().Assign(1, "bye").Insert(2, 1)) + data, err = conn.Do(req).Get() + if err != nil { + t.Errorf("Failed to Update: %s", err) + } + if len(data) != 1 { + t.Fatalf("Response Data len != 1") + } + if tpl, ok := data[0].([]interface{}); !ok { + t.Errorf("Unexpected body of Select") + } else { + if id, err := test_helpers.ConvertUint64(tpl[0]); err != nil || id != 1010 { + t.Errorf("Unexpected body of Update (0)") + } + if h, ok := tpl[1].(string); !ok || h != "bye" { + t.Errorf("Unexpected body of Update (1)") + } + if h, err := test_helpers.ConvertUint64(tpl[2]); err != nil || h != 1 { + t.Errorf("Unexpected body of Update (2)") + } + } + + // Upsert without operations. + req = NewUpsertRequest(spaceNo). + Tuple([]interface{}{uint(1010), "hi", "hi"}) + data, err = conn.Do(req).Get() + if err != nil { + t.Errorf("Failed to Upsert (update): %s", err) + } + if len(data) != 0 { + t.Fatalf("Response Data len != 0") + } + + // Upsert. + req = NewUpsertRequest(spaceNo). + Tuple([]interface{}{uint(1010), "hi", "hi"}). + Operations(NewOperations().Assign(2, "bye")) + data, err = conn.Do(req).Get() + if err != nil { + t.Errorf("Failed to Upsert (update): %s", err) + } + if len(data) != 0 { + t.Fatalf("Response Data len != 0") + } + + // Call16 vs Call17 + req = NewCall16Request("simple_concat").Args([]interface{}{"1"}) + data, err = conn.Do(req).Get() + if err != nil { + t.Errorf("Failed to use Call") + } + if val, ok := data[0].([]interface{})[0].(string); !ok || val != "11" { + t.Errorf("result is not {{1}} : %v", data) + } + + // Call17 + req = NewCall17Request("simple_concat").Args([]interface{}{"1"}) + data, err = conn.Do(req).Get() + if err != nil { + t.Errorf("Failed to use Call17") + } + if val, ok := data[0].(string); !ok || val != "11" { + t.Errorf("result is not {{1}} : %v", data) + } + + // Eval + req = NewEvalRequest("return 5 + 6") + data, err = conn.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Eval: %s", err) + } + if len(data) < 1 { + t.Errorf("Response.Data is empty after Eval") + } + if val, err := test_helpers.ConvertUint64(data[0]); err != nil || val != 11 { + t.Errorf("5 + 6 == 11, but got %v", val) + } + + // Tarantool supports SQL since version 2.0.0 + isLess, err := test_helpers.IsTarantoolVersionLess(2, 0, 0) + if err != nil { + t.Fatalf("Could not check the Tarantool version: %s", err) + } + if isLess { + return + } + + req = NewExecuteRequest(createTableQuery) + resp, err := conn.Do(req).GetResponse() + if err != nil { + t.Fatalf("Failed to Execute: %s", err) + } + if resp == nil { + t.Fatal("Response is nil after Execute") + } + data, err = resp.Decode() + if err != nil { + t.Fatalf("Failed to Decode: %s", err) + } + if len(data) != 0 { + t.Fatalf("Response Body len != 0") + } + exResp, ok := resp.(*ExecuteResponse) + assert.True(t, ok, "Got wrong response type") + sqlInfo, err := exResp.SQLInfo() + assert.NoError(t, err, "Error while getting SQLInfo") + if sqlInfo.AffectedCount != 1 { + t.Errorf("Incorrect count of created spaces: %d", sqlInfo.AffectedCount) + } + + req = NewExecuteRequest(dropQuery2) + resp, err = conn.Do(req).GetResponse() + if err != nil { + t.Fatalf("Failed to Execute: %s", err) + } + if resp == nil { + t.Fatal("Response is nil after Execute") + } + data, err = resp.Decode() + if err != nil { + t.Fatalf("Failed to Decode: %s", err) + } + if len(data) != 0 { + t.Fatalf("Response Body len != 0") + } + exResp, ok = resp.(*ExecuteResponse) + assert.True(t, ok, "Got wrong response type") + sqlInfo, err = exResp.SQLInfo() + assert.NoError(t, err, "Error while getting SQLInfo") + if sqlInfo.AffectedCount != 1 { + t.Errorf("Incorrect count of dropped spaces: %d", sqlInfo.AffectedCount) + } +} + +func testConnectionDoSelectRequestPrepare(t *testing.T, conn Connector) { + t.Helper() + + for i := 1010; i < 1020; i++ { + req := NewReplaceRequest(spaceName).Tuple( + []interface{}{uint(i), fmt.Sprintf("val %d", i), "bla"}) + if _, err := conn.Do(req).Get(); err != nil { + t.Fatalf("Unable to prepare tuples: %s", err) + } + } +} + +func testConnectionDoSelectRequestCheck(t *testing.T, + resp *SelectResponse, err error, pos bool, dataLen int, firstKey uint64) { + t.Helper() + + if err != nil { + t.Fatalf("Failed to Select: %s", err) + } + if resp == nil { + t.Fatalf("Response is nil after Select") + } + respPos, err := resp.Pos() + if err != nil { + t.Errorf("Error while getting Pos: %s", err) + } + if !pos && respPos != nil { + t.Errorf("Response should not have a position descriptor") + } + if pos && respPos == nil { + t.Fatalf("A response must have a position descriptor") + } + data, err := resp.Decode() + if err != nil { + t.Fatalf("Failed to Decode: %s", err) + } + if len(data) != dataLen { + t.Fatalf("Response Data len %d != %d", len(data), dataLen) + } + for i := 0; i < dataLen; i++ { + key := firstKey + uint64(i) + if tpl, ok := data[i].([]interface{}); !ok { + t.Errorf("Unexpected body of Select") + } else { + if id, err := test_helpers.ConvertUint64(tpl[0]); err != nil || id != key { + t.Errorf("Unexpected body of Select (0) %v, expected %d", + tpl[0], key) + } + expectedSecond := fmt.Sprintf("val %d", key) + if h, ok := tpl[1].(string); !ok || h != expectedSecond { + t.Errorf("Unexpected body of Select (1) %q, expected %q", + tpl[1].(string), expectedSecond) + } + if h, ok := tpl[2].(string); !ok || h != "bla" { + t.Errorf("Unexpected body of Select (2) %q, expected %q", + tpl[2].(string), "bla") + } + } + } +} + +func TestConnectionDoSelectRequest(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + testConnectionDoSelectRequestPrepare(t, conn) + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Limit(10). + Iterator(IterGe). + Key([]interface{}{uint(1010)}) + resp, err := conn.Do(req).GetResponse() + + selResp, ok := resp.(*SelectResponse) + assert.True(t, ok, "Got wrong response type") + + testConnectionDoSelectRequestCheck(t, selResp, err, false, 10, 1010) +} + +func TestConnectionDoWatchOnceRequest(t *testing.T) { + test_helpers.SkipIfWatchOnceUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + _, err := conn.Do(NewBroadcastRequest("hello").Value("world")).Get() + if err != nil { + t.Fatalf("Failed to create a broadcast : %s", err.Error()) + } + + data, err := conn.Do(NewWatchOnceRequest("hello")).Get() + if err != nil { + t.Fatalf("Failed to WatchOnce: %s", err.Error()) + } + if len(data) < 1 || data[0] != "world" { + t.Errorf("Failed to WatchOnce: wrong value returned %v", data) + } +} + +func TestConnectionDoWatchOnceOnEmptyKey(t *testing.T) { + test_helpers.SkipIfWatchOnceUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + data, err := conn.Do(NewWatchOnceRequest("notexists!")).Get() + if err != nil { + t.Fatalf("Failed to WatchOnce: %s", err.Error()) + } + if len(data) > 0 { + t.Errorf("Failed to WatchOnce: wrong value returned %v", data) + } +} + +func TestConnectionDoSelectRequest_fetch_pos(t *testing.T) { + test_helpers.SkipIfPaginationUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + testConnectionDoSelectRequestPrepare(t, conn) + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Limit(2). + Iterator(IterGe). + FetchPos(true). + Key([]interface{}{uint(1010)}) + resp, err := conn.Do(req).GetResponse() + + selResp, ok := resp.(*SelectResponse) + assert.True(t, ok, "Got wrong response type") + + testConnectionDoSelectRequestCheck(t, selResp, err, true, 2, 1010) +} + +func TestConnectDoSelectRequest_after_tuple(t *testing.T) { + test_helpers.SkipIfPaginationUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + testConnectionDoSelectRequestPrepare(t, conn) + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Limit(2). + Iterator(IterGe). + FetchPos(true). + Key([]interface{}{uint(1010)}). + After([]interface{}{uint(1012)}) + resp, err := conn.Do(req).GetResponse() + + selResp, ok := resp.(*SelectResponse) + assert.True(t, ok, "Got wrong response type") + + testConnectionDoSelectRequestCheck(t, selResp, err, true, 2, 1013) +} + +func TestConnectionDoSelectRequest_pagination_pos(t *testing.T) { + test_helpers.SkipIfPaginationUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + testConnectionDoSelectRequestPrepare(t, conn) + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Limit(2). + Iterator(IterGe). + FetchPos(true). + Key([]interface{}{uint(1010)}) + resp, err := conn.Do(req).GetResponse() + + selResp, ok := resp.(*SelectResponse) + assert.True(t, ok, "Got wrong response type") + + testConnectionDoSelectRequestCheck(t, selResp, err, true, 2, 1010) + + selPos, err := selResp.Pos() + assert.NoError(t, err, "Error while getting Pos") + + resp, err = conn.Do(req.After(selPos)).GetResponse() + selResp, ok = resp.(*SelectResponse) + assert.True(t, ok, "Got wrong response type") + + testConnectionDoSelectRequestCheck(t, selResp, err, true, 2, 1012) +} + +func TestConnection_Call(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + data, err := conn.Call("simple_concat", []interface{}{"1"}) + if err != nil { + t.Errorf("Failed to use Call") + } + if val, ok := data[0].(string); !ok || val != "11" { + t.Errorf("result is not {{1}} : %v", data) + } +} + +func TestCallRequest(t *testing.T) { + var err error + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + req := NewCallRequest("simple_concat").Args([]interface{}{"1"}) + data, err := conn.Do(req).Get() + if err != nil { + t.Errorf("Failed to use Call") + } + if val, ok := data[0].(string); !ok || val != "11" { + t.Errorf("result is not {{1}} : %v", data) + } +} + +func TestClientRequestObjectsWithNilContext(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + req := NewPingRequest().Context(nil) // nolint + data, err := conn.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Ping: %s", err) + } + if len(data) != 0 { + t.Errorf("Response Body len != 0") + } +} + +func TestClientRequestObjectsWithPassedCanceledContext(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + req := NewPingRequest().Context(ctx) + cancel() + resp, err := conn.Do(req).Get() + if !contextDoneErrRegexp.Match([]byte(err.Error())) { + t.Fatalf("Failed to catch an error from done context") + } + if resp != nil { + t.Fatalf("Response is not nil after the occurred error") + } +} + +// Checking comparable with simple context.WithCancel. +func TestComparableErrorsCanceledContext(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + req := NewPingRequest().Context(ctx) + cancel() + _, err := conn.Do(req).Get() + require.True(t, errors.Is(err, context.Canceled), err.Error()) +} + +// Checking comparable with simple context.WithTimeout. +func TestComparableErrorsTimeoutContext(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + timeout := time.Nanosecond + ctx, cancel := context.WithTimeout(context.Background(), timeout) + req := NewPingRequest().Context(ctx) + defer cancel() + _, err := conn.Do(req).Get() + require.True(t, errors.Is(err, context.DeadlineExceeded), err.Error()) +} + +// Checking comparable with context.WithCancelCause. +// Shows ability to compare with custom errors (also with ClientError). +func TestComparableErrorsCancelCauseContext(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + ctxCause, cancelCause := context.WithCancelCause(context.Background()) + req := NewPingRequest().Context(ctxCause) + cancelCause(ClientError{ErrConnectionClosed, "something went wrong"}) + _, err := conn.Do(req).Get() + var tmpErr ClientError + require.True(t, errors.As(err, &tmpErr), tmpErr.Error()) +} + +// waitCtxRequest waits for the WaitGroup in Body() call and returns +// the context from Ctx() call. The request helps us to make sure that +// the context's cancel() call is called before a response received. +type waitCtxRequest struct { + ctx context.Context + wg sync.WaitGroup +} + +func (req *waitCtxRequest) Type() iproto.Type { + return NewPingRequest().Type() +} + +func (req *waitCtxRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { + req.wg.Wait() + return NewPingRequest().Body(res, enc) +} + +func (req *waitCtxRequest) Ctx() context.Context { + return req.ctx +} + +func (req *waitCtxRequest) Async() bool { + return NewPingRequest().Async() +} + +func (req *waitCtxRequest) Response(header Header, body io.Reader) (Response, error) { + resp, err := test_helpers.CreateMockResponse(header, body) + return resp, err +} + +func TestClientRequestObjectsWithContext(t *testing.T) { + var err error + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + req := &waitCtxRequest{ctx: ctx} + req.wg.Add(1) + + var futWg sync.WaitGroup + var fut *Future + + futWg.Add(1) + go func() { + defer futWg.Done() + fut = conn.Do(req) + }() + + cancel() + req.wg.Done() + + futWg.Wait() + if fut == nil { + t.Fatalf("fut must be not nil") + } + + resp, err := fut.Get() + if resp != nil { + t.Fatalf("response must be nil") + } + if err == nil { + t.Fatalf("caught nil error") + } + if !contextDoneErrRegexp.Match([]byte(err.Error())) { + t.Fatalf("wrong error caught: %v", err) + } +} + +func TestComplexStructs(t *testing.T) { + var err error + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + tuple := Tuple2{Cid: 777, Orig: "orig", Members: []Member{{"lol", "", 1}, {"wut", "", 3}}} + _, err = conn.Replace(spaceNo, &tuple) + if err != nil { + t.Fatalf("Failed to insert: %s", err) + } + + var tuples [1]Tuple2 + err = conn.SelectTyped(spaceNo, indexNo, 0, 1, IterEq, []interface{}{777}, &tuples) + if err != nil { + t.Fatalf("Failed to selectTyped: %s", err) + } + + if len(tuples) != 1 { + t.Errorf("Failed to selectTyped: unexpected array length %d", len(tuples)) + return + } + + if tuple.Cid != tuples[0].Cid || + len(tuple.Members) != len(tuples[0].Members) || + tuple.Members[1].Name != tuples[0].Members[1].Name { + t.Errorf("Failed to selectTyped: incorrect data") + return + } +} + +func TestStream_IdValues(t *testing.T) { + test_helpers.SkipIfStreamsUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + cases := []uint64{ + 1, + 128, + math.MaxUint8, + math.MaxUint8 + 1, + math.MaxUint16, + math.MaxUint16 + 1, + math.MaxUint32, + math.MaxUint32 + 1, + math.MaxUint64, + } + + stream, _ := conn.NewStream() + req := NewPingRequest() + + for _, id := range cases { + t.Run(fmt.Sprintf("%d", id), func(t *testing.T) { + stream.Id = id + _, err := stream.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Ping: %s", err) + } + }) + } +} + +func TestStream_Commit(t *testing.T) { + var req Request + var err error + var conn *Connection + + test_helpers.SkipIfStreamsUnsupported(t) + + conn = test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + stream, _ := conn.NewStream() + + // Begin transaction + req = NewBeginRequest() + _, err = stream.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Begin: %s", err) + } + + // Insert in stream + req = NewInsertRequest(spaceName). + Tuple([]interface{}{uint(1001), "hello2", "world2"}) + _, err = stream.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Insert: %s", err) + } + defer test_helpers.DeleteRecordByKey(t, conn, spaceNo, indexNo, []interface{}{uint(1001)}) + + // Select not related to the transaction + // while transaction is not committed + // result of select is empty + selectReq := NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(IterEq). + Key([]interface{}{uint(1001)}) + data, err := conn.Do(selectReq).Get() + if err != nil { + t.Fatalf("Failed to Select: %s", err) + } + if len(data) != 0 { + t.Fatalf("Response Data len != 0") + } + + // Select in stream + data, err = stream.Do(selectReq).Get() + if err != nil { + t.Fatalf("Failed to Select: %s", err) + } + if len(data) != 1 { + t.Fatalf("Response Data len != 1") + } + if tpl, ok := data[0].([]interface{}); !ok { + t.Fatalf("Unexpected body of Select") + } else { + if id, err := test_helpers.ConvertUint64(tpl[0]); err != nil || id != 1001 { + t.Fatalf("Unexpected body of Select (0)") + } + if h, ok := tpl[1].(string); !ok || h != "hello2" { + t.Fatalf("Unexpected body of Select (1)") + } + if h, ok := tpl[2].(string); !ok || h != "world2" { + t.Fatalf("Unexpected body of Select (2)") + } + } + + // Commit transaction + req = NewCommitRequest() + _, err = stream.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Commit: %s", err) + } + + // Select outside of transaction + data, err = conn.Do(selectReq).Get() + if err != nil { + t.Fatalf("Failed to Select: %s", err) + } + if len(data) != 1 { + t.Fatalf("Response Data len != 1") + } + if tpl, ok := data[0].([]interface{}); !ok { + t.Fatalf("Unexpected body of Select") + } else { + if id, err := test_helpers.ConvertUint64(tpl[0]); err != nil || id != 1001 { + t.Fatalf("Unexpected body of Select (0)") + } + if h, ok := tpl[1].(string); !ok || h != "hello2" { + t.Fatalf("Unexpected body of Select (1)") + } + if h, ok := tpl[2].(string); !ok || h != "world2" { + t.Fatalf("Unexpected body of Select (2)") + } + } +} + +func TestStream_Rollback(t *testing.T) { + var req Request + var err error + var conn *Connection + + test_helpers.SkipIfStreamsUnsupported(t) + + conn = test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + stream, _ := conn.NewStream() + + // Begin transaction + req = NewBeginRequest() + _, err = stream.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Begin: %s", err) + } + + // Insert in stream + req = NewInsertRequest(spaceName). + Tuple([]interface{}{uint(1001), "hello2", "world2"}) + _, err = stream.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Insert: %s", err) + } + defer test_helpers.DeleteRecordByKey(t, conn, spaceNo, indexNo, []interface{}{uint(1001)}) + + // Select not related to the transaction + // while transaction is not committed + // result of select is empty + selectReq := NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(IterEq). + Key([]interface{}{uint(1001)}) + data, err := conn.Do(selectReq).Get() + if err != nil { + t.Fatalf("Failed to Select: %s", err) + } + if len(data) != 0 { + t.Fatalf("Response Data len != 0") + } + + // Select in stream + data, err = stream.Do(selectReq).Get() + if err != nil { + t.Fatalf("Failed to Select: %s", err) + } + if len(data) != 1 { + t.Fatalf("Response Data len != 1") + } + if tpl, ok := data[0].([]interface{}); !ok { + t.Fatalf("Unexpected body of Select") + } else { + if id, err := test_helpers.ConvertUint64(tpl[0]); err != nil || id != 1001 { + t.Fatalf("Unexpected body of Select (0)") + } + if h, ok := tpl[1].(string); !ok || h != "hello2" { + t.Fatalf("Unexpected body of Select (1)") + } + if h, ok := tpl[2].(string); !ok || h != "world2" { + t.Fatalf("Unexpected body of Select (2)") + } + } + + // Rollback transaction + req = NewRollbackRequest() + _, err = stream.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Rollback: %s", err) + } + + // Select outside of transaction + data, err = conn.Do(selectReq).Get() + if err != nil { + t.Fatalf("Failed to Select: %s", err) + } + if len(data) != 0 { + t.Fatalf("Response Data len != 0") + } + + // Select inside of stream after rollback + _, err = stream.Do(selectReq).Get() + if err != nil { + t.Fatalf("Failed to Select: %s", err) + } + if len(data) != 0 { + t.Fatalf("Response Data len != 0") + } +} + +func TestStream_TxnIsolationLevel(t *testing.T) { + var req Request + var err error + var conn *Connection + + txnIsolationLevels := []TxnIsolationLevel{ + DefaultIsolationLevel, + ReadCommittedLevel, + ReadConfirmedLevel, + BestEffortLevel, + } + + test_helpers.SkipIfStreamsUnsupported(t) + + conn = test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + stream, _ := conn.NewStream() + + for _, level := range txnIsolationLevels { + // Begin transaction + req = NewBeginRequest().TxnIsolation(level).Timeout(500 * time.Millisecond) + _, err = stream.Do(req).Get() + require.Nilf(t, err, "failed to Begin") + + // Insert in stream + req = NewInsertRequest(spaceName). + Tuple([]interface{}{uint(1001), "hello2", "world2"}) + _, err = stream.Do(req).Get() + require.Nilf(t, err, "failed to Insert") + + // Select not related to the transaction + // while transaction is not committed + // result of select is empty + selectReq := NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(IterEq). + Key([]interface{}{uint(1001)}) + data, err := conn.Do(selectReq).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, 0, len(data), "response Data len != 0") + + // Select in stream + data, err = stream.Do(selectReq).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, 1, len(data), "response Body len != 1 after Select") + + tpl, ok := data[0].([]interface{}) + require.Truef(t, ok, "unexpected body of Select") + require.Equalf(t, 3, len(tpl), "unexpected body of Select") + + key, err := test_helpers.ConvertUint64(tpl[0]) + require.Nilf(t, err, "unexpected body of Select (0)") + require.Equalf(t, uint64(1001), key, "unexpected body of Select (0)") + + value1, ok := tpl[1].(string) + require.Truef(t, ok, "unexpected body of Select (1)") + require.Equalf(t, "hello2", value1, "unexpected body of Select (1)") + + value2, ok := tpl[2].(string) + require.Truef(t, ok, "unexpected body of Select (2)") + require.Equalf(t, "world2", value2, "unexpected body of Select (2)") + + // Rollback transaction + req = NewRollbackRequest() + _, err = stream.Do(req).Get() + require.Nilf(t, err, "failed to Rollback") + + // Select outside of transaction + data, err = conn.Do(selectReq).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, 0, len(data), "response Data len != 0") + + // Select inside of stream after rollback + data, err = stream.Do(selectReq).Get() + require.Nilf(t, err, "failed to Select") + require.Equalf(t, 0, len(data), "response Data len != 0") + + test_helpers.DeleteRecordByKey(t, conn, spaceNo, indexNo, []interface{}{uint(1001)}) + } +} + +func TestStream_DoWithStrangerConn(t *testing.T) { + expectedErr := fmt.Errorf("the passed connected request " + + "doesn't belong to the current connection or connection pool") + + conn := &Connection{} + stream, _ := conn.NewStream() + req := test_helpers.NewMockRequest() + + _, err := stream.Do(req).Get() + if err == nil { + t.Fatalf("nil error has been caught") + } + if err.Error() != expectedErr.Error() { + t.Fatalf("Unexpected error has been caught: %s", err.Error()) + } +} + +func TestStream_DoWithClosedConn(t *testing.T) { + expectedErr := fmt.Errorf("using closed connection") + + test_helpers.SkipIfStreamsUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + + stream, _ := conn.NewStream() + conn.Close() + + // Begin transaction + req := NewBeginRequest() + _, err := stream.Do(req).Get() + if err == nil { + t.Fatalf("nil error has been caught") + } + if !strings.Contains(err.Error(), expectedErr.Error()) { + t.Fatalf("Unexpected error has been caught: %s", err.Error()) + } +} + +func TestConnectionBoxSessionPushUnsupported(t *testing.T) { + old := log.Writer() + defer log.SetOutput(old) + + var buf bytes.Buffer + log.SetOutput(&buf) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + _, err := conn.Do(NewCallRequest("push_func").Args([]interface{}{1})).Get() + require.NoError(t, err) + + actualLog := buf.String() + expectedLog := "unsupported box.session.push()" + require.Contains(t, actualLog, expectedLog) +} + +func TestConnectionProtocolInfoSupported(t *testing.T) { + test_helpers.SkipIfIdUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + // First Tarantool protocol version (1, IPROTO_FEATURE_STREAMS and + // IPROTO_FEATURE_TRANSACTIONS) was introduced between 2.10.0-beta1 and + // 2.10.0-beta2. Versions 2 (IPROTO_FEATURE_ERROR_EXTENSION) and + // 3 (IPROTO_FEATURE_WATCHERS) were also introduced between 2.10.0-beta1 and + // 2.10.0-beta2. Version 4 (IPROTO_FEATURE_PAGINATION) was introduced in + // master 948e5cd (possible 2.10.5 or 2.11.0). So each release + // Tarantool >= 2.10 (same as each Tarantool with id support) has protocol + // version >= 3 and first four features. + tarantool210ProtocolInfo := ProtocolInfo{ + Version: ProtocolVersion(3), + Features: []iproto.Feature{ + iproto.IPROTO_FEATURE_STREAMS, + iproto.IPROTO_FEATURE_TRANSACTIONS, + iproto.IPROTO_FEATURE_ERROR_EXTENSION, + iproto.IPROTO_FEATURE_WATCHERS, + }, + } + + serverProtocolInfo := conn.ProtocolInfo() + require.GreaterOrEqual(t, + serverProtocolInfo.Version, + tarantool210ProtocolInfo.Version) + require.Subset(t, + serverProtocolInfo.Features, + tarantool210ProtocolInfo.Features) +} + +func TestClientIdRequestObject(t *testing.T) { + test_helpers.SkipIfIdUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + tarantool210ProtocolInfo := ProtocolInfo{ + Version: ProtocolVersion(3), + Features: []iproto.Feature{ + iproto.IPROTO_FEATURE_STREAMS, + iproto.IPROTO_FEATURE_TRANSACTIONS, + iproto.IPROTO_FEATURE_ERROR_EXTENSION, + iproto.IPROTO_FEATURE_WATCHERS, + }, + } + + req := NewIdRequest(ProtocolInfo{ + Version: ProtocolVersion(1), + Features: []iproto.Feature{iproto.IPROTO_FEATURE_STREAMS}, + }) + data, err := conn.Do(req).Get() + require.Nilf(t, err, "No errors on Id request execution") + require.NotNilf(t, data, "Response data not empty") + require.Equal(t, len(data), 1, "Response data contains exactly one object") + + serverProtocolInfo, ok := data[0].(ProtocolInfo) + require.Truef(t, ok, "Response Data object is an ProtocolInfo object") + require.GreaterOrEqual(t, + serverProtocolInfo.Version, + tarantool210ProtocolInfo.Version) + require.Subset(t, + serverProtocolInfo.Features, + tarantool210ProtocolInfo.Features) +} + +func TestClientIdRequestObjectWithNilContext(t *testing.T) { + test_helpers.SkipIfIdUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + tarantool210ProtocolInfo := ProtocolInfo{ + Version: ProtocolVersion(3), + Features: []iproto.Feature{ + iproto.IPROTO_FEATURE_STREAMS, + iproto.IPROTO_FEATURE_TRANSACTIONS, + iproto.IPROTO_FEATURE_ERROR_EXTENSION, + iproto.IPROTO_FEATURE_WATCHERS, + }, + } + + req := NewIdRequest(ProtocolInfo{ + Version: ProtocolVersion(1), + Features: []iproto.Feature{iproto.IPROTO_FEATURE_STREAMS}, + }).Context(nil) // nolint + data, err := conn.Do(req).Get() + require.Nilf(t, err, "No errors on Id request execution") + require.NotNilf(t, data, "Response data not empty") + require.Equal(t, len(data), 1, "Response data contains exactly one object") + + serverProtocolInfo, ok := data[0].(ProtocolInfo) + require.Truef(t, ok, "Response Data object is an ProtocolInfo object") + require.GreaterOrEqual(t, + serverProtocolInfo.Version, + tarantool210ProtocolInfo.Version) + require.Subset(t, + serverProtocolInfo.Features, + tarantool210ProtocolInfo.Features) +} + +func TestClientIdRequestObjectWithPassedCanceledContext(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + req := NewIdRequest(ProtocolInfo{ + Version: ProtocolVersion(1), + Features: []iproto.Feature{iproto.IPROTO_FEATURE_STREAMS}, + }).Context(ctx) // nolint + cancel() + resp, err := conn.Do(req).Get() + require.Nilf(t, resp, "Response is empty") + require.NotNilf(t, err, "Error is not empty") + require.Regexp(t, contextDoneErrRegexp, err.Error()) +} + +func TestConnectionProtocolInfoUnsupported(t *testing.T) { + test_helpers.SkipIfIdSupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + serverProtocolInfo := conn.ProtocolInfo() + expected := ProtocolInfo{} + require.Equal(t, expected, serverProtocolInfo) +} + +func TestConnectionServerFeaturesImmutable(t *testing.T) { + test_helpers.SkipIfIdUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + info := conn.ProtocolInfo() + infoOrig := info.Clone() + info.Features[0] = iproto.Feature(15532) + + require.Equal(t, conn.ProtocolInfo(), infoOrig) + require.NotEqual(t, conn.ProtocolInfo(), info) +} + +func TestConnectionProtocolVersionRequirementSuccess(t *testing.T) { + test_helpers.SkipIfIdUnsupported(t) + + testDialer := dialer + testDialer.RequiredProtocolInfo = ProtocolInfo{ + Version: ProtocolVersion(3), + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := Connect(ctx, testDialer, opts) + + require.Nilf(t, err, "No errors on connect") + require.NotNilf(t, conn, "Connect success") + + conn.Close() +} + +func TestConnectionProtocolVersionRequirementFail(t *testing.T) { + test_helpers.SkipIfIdSupported(t) + + testDialer := dialer + testDialer.RequiredProtocolInfo = ProtocolInfo{ + Version: ProtocolVersion(3), + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := Connect(ctx, testDialer, opts) + + require.Nilf(t, conn, "Connect fail") + require.NotNilf(t, err, "Got error on connect") + require.Contains(t, err.Error(), "invalid server protocol: protocol version 3 is not supported") +} + +func TestConnectionProtocolFeatureRequirementSuccess(t *testing.T) { + test_helpers.SkipIfIdUnsupported(t) + + testDialer := dialer + testDialer.RequiredProtocolInfo = ProtocolInfo{ + Features: []iproto.Feature{iproto.IPROTO_FEATURE_TRANSACTIONS}, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := Connect(ctx, testDialer, opts) + + require.NotNilf(t, conn, "Connect success") + require.Nilf(t, err, "No errors on connect") + + conn.Close() +} + +func TestConnectionProtocolFeatureRequirementFail(t *testing.T) { + test_helpers.SkipIfIdSupported(t) + + testDialer := dialer + testDialer.RequiredProtocolInfo = ProtocolInfo{ + Features: []iproto.Feature{iproto.IPROTO_FEATURE_TRANSACTIONS}, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := Connect(ctx, testDialer, opts) + + require.Nilf(t, conn, "Connect fail") + require.NotNilf(t, err, "Got error on connect") + require.Contains(t, err.Error(), + "invalid server protocol: protocol feature "+ + "IPROTO_FEATURE_TRANSACTIONS is not supported") +} + +func TestConnectionProtocolFeatureRequirementManyFail(t *testing.T) { + test_helpers.SkipIfIdSupported(t) + + testDialer := dialer + testDialer.RequiredProtocolInfo = ProtocolInfo{ + Features: []iproto.Feature{iproto.IPROTO_FEATURE_TRANSACTIONS, + iproto.Feature(15532)}, + } + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := Connect(ctx, testDialer, opts) + + require.Nilf(t, conn, "Connect fail") + require.NotNilf(t, err, "Got error on connect") + require.Contains(t, + err.Error(), + "invalid server protocol: protocol features IPROTO_FEATURE_TRANSACTIONS, "+ + "Feature(15532) are not supported") +} + +func TestErrorExtendedInfoBasic(t *testing.T) { + test_helpers.SkipIfErrorExtendedInfoUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + _, err := conn.Eval("not a Lua code", []interface{}{}) + require.NotNilf(t, err, "expected error on invalid Lua code") + + ttErr, ok := err.(Error) + require.Equalf(t, ok, true, "error is built from a Tarantool error") + + expected := BoxError{ + Type: "LuajitError", + File: "eval", + Line: uint64(1), + Msg: "eval:1: unexpected symbol near 'not'", + Errno: uint64(0), + Code: uint64(32), + } + + // In fact, CheckEqualBoxErrors does not check than File and Line + // of connector BoxError are equal to the Tarantool ones + // since they may differ between different Tarantool versions + // and editions. + test_helpers.CheckEqualBoxErrors(t, expected, *ttErr.ExtendedInfo) +} + +func TestErrorExtendedInfoStack(t *testing.T) { + test_helpers.SkipIfErrorExtendedInfoUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + _, err := conn.Eval("error(chained_error)", []interface{}{}) + require.NotNilf(t, err, "expected error on explicit error raise") + + ttErr, ok := err.(Error) + require.Equalf(t, ok, true, "error is built from a Tarantool error") + + expected := BoxError{ + Type: "ClientError", + File: "config.lua", + Line: uint64(214), + Msg: "Timeout exceeded", + Errno: uint64(0), + Code: uint64(78), + Prev: &BoxError{ + Type: "ClientError", + File: "config.lua", + Line: uint64(213), + Msg: "Unknown error", + Errno: uint64(0), + Code: uint64(0), + }, + } + + // In fact, CheckEqualBoxErrors does not check than File and Line + // of connector BoxError are equal to the Tarantool ones + // since they may differ between different Tarantool versions + // and editions. + test_helpers.CheckEqualBoxErrors(t, expected, *ttErr.ExtendedInfo) +} + +func TestErrorExtendedInfoFields(t *testing.T) { + test_helpers.SkipIfErrorExtendedInfoUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + _, err := conn.Eval("error(access_denied_error)", []interface{}{}) + require.NotNilf(t, err, "expected error on forbidden action") + + ttErr, ok := err.(Error) + require.Equalf(t, ok, true, "error is built from a Tarantool error") + + expected := BoxError{ + Type: "AccessDeniedError", + File: "/__w/sdk/sdk/tarantool-2.10/tarantool/src/box/func.c", + Line: uint64(535), + Msg: "Execute access to function 'forbidden_function' is denied for user 'no_grants'", + Errno: uint64(0), + Code: uint64(42), + Fields: map[string]interface{}{ + "object_type": "function", + "object_name": "forbidden_function", + "access_type": "Execute", + }, + } + + // In fact, CheckEqualBoxErrors does not check than File and Line + // of connector BoxError are equal to the Tarantool ones + // since they may differ between different Tarantool versions + // and editions. + test_helpers.CheckEqualBoxErrors(t, expected, *ttErr.ExtendedInfo) +} + +func TestConnection_NewWatcher(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnection_NewWatcher" + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + events := make(chan WatchEvent) + defer close(events) + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + + select { + case event := <-events: + if event.Conn != conn { + t.Errorf("Unexpected event connection: %v", event.Conn) + } + if event.Key != key { + t.Errorf("Unexpected event key: %s", event.Key) + } + if event.Value != nil { + t.Errorf("Unexpected event value: %v", event.Value) + } + case <-time.After(time.Second): + t.Fatalf("Failed to get watch event.") + } +} + +func newWatcherReconnectionPrepareTestConnection(t *testing.T) (*Connection, context.CancelFunc) { + t.Helper() + + const server = "127.0.0.1:3015" + testDialer := dialer + testDialer.Address = server + + inst, err := test_helpers.StartTarantool(test_helpers.StartOpts{ + Dialer: testDialer, + InitScript: "config.lua", + Listen: server, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 10, + RetryTimeout: 500 * time.Millisecond, + }) + t.Cleanup(func() { test_helpers.StopTarantoolWithCleanup(inst) }) + if err != nil { + t.Fatalf("Unable to start Tarantool: %s", err) + } + + ctx, cancel := test_helpers.GetConnectContext() + + reconnectOpts := opts + reconnectOpts.Reconnect = 100 * time.Millisecond + reconnectOpts.MaxReconnects = 0 + reconnectOpts.Notify = make(chan ConnEvent) + conn, err := Connect(ctx, testDialer, reconnectOpts) + if err != nil { + t.Fatalf("Connection was not established: %v", err) + } + + test_helpers.StopTarantool(inst) + + // Wait for reconnection process to be started. + for conn.ConnectedNow() { + time.Sleep(100 * time.Millisecond) + } + + return conn, cancel +} + +func TestNewWatcherDuringReconnect(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + conn, cancel := newWatcherReconnectionPrepareTestConnection(t) + defer conn.Close() + defer cancel() + + _, err := conn.NewWatcher("one", func(event WatchEvent) {}) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "client connection is not ready") +} + +func TestNewWatcherAfterClose(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + conn, cancel := newWatcherReconnectionPrepareTestConnection(t) + defer cancel() + + _ = conn.Close() + + _, err := conn.NewWatcher("one", func(event WatchEvent) {}) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "using closed connection") +} + +func TestConnection_NewWatcher_noWatchersFeature(t *testing.T) { + test_helpers.SkipIfWatchersSupported(t) + + const key = "TestConnection_NewWatcher_noWatchersFeature" + testDialer := dialer + testDialer.RequiredProtocolInfo = ProtocolInfo{Features: []iproto.Feature{}} + conn := test_helpers.ConnectWithValidation(t, testDialer, opts) + defer conn.Close() + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) {}) + require.Nilf(t, watcher, "watcher must not be created") + require.NotNilf(t, err, "an error is expected") + expected := "the feature IPROTO_FEATURE_WATCHERS must be supported by " + + "connection to create a watcher" + require.Equal(t, expected, err.Error()) +} + +func TestConnection_NewWatcher_reconnect(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnection_NewWatcher_reconnect" + const server = "127.0.0.1:3014" + + testDialer := dialer + testDialer.Address = server + + inst, err := test_helpers.StartTarantool(test_helpers.StartOpts{ + Dialer: testDialer, + InitScript: "config.lua", + Listen: server, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 10, + RetryTimeout: 500 * time.Millisecond, + }) + defer test_helpers.StopTarantoolWithCleanup(inst) + if err != nil { + t.Fatalf("Unable to start Tarantool: %s", err) + } + + reconnectOpts := opts + reconnectOpts.Reconnect = 100 * time.Millisecond + reconnectOpts.MaxReconnects = 10 + + conn := test_helpers.ConnectWithValidation(t, testDialer, reconnectOpts) + defer conn.Close() + + events := make(chan WatchEvent) + defer close(events) + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + + <-events + + test_helpers.StopTarantool(inst) + if err := test_helpers.RestartTarantool(inst); err != nil { + t.Fatalf("Unable to restart Tarantool: %s", err) + } + + maxTime := reconnectOpts.Reconnect * time.Duration(reconnectOpts.MaxReconnects) + select { + case <-events: + case <-time.After(maxTime): + t.Fatalf("Failed to get watch event.") + } +} + +func TestBroadcastRequest(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestBroadcastRequest" + const value = "bar" + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + data, err := conn.Do(NewBroadcastRequest(key).Value(value)).Get() + if err != nil { + t.Fatalf("Got broadcast error: %s", err) + } + if !reflect.DeepEqual(data, []interface{}{}) { + t.Errorf("Got unexpected broadcast response data: %v", data) + } + + events := make(chan WatchEvent) + defer close(events) + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + + select { + case event := <-events: + if event.Conn != conn { + t.Errorf("Unexpected event connection: %v", event.Conn) + } + if event.Key != key { + t.Errorf("Unexpected event key: %s", event.Key) + } + if event.Value != value { + t.Errorf("Unexpected event value: %v", event.Value) + } + case <-time.After(time.Second): + t.Fatalf("Failed to get watch event.") + } +} + +func TestBroadcastRequest_multi(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestBroadcastRequest_multi" + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + events := make(chan WatchEvent) + defer close(events) + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + + <-events // Skip an initial event. + for i := 0; i < 10; i++ { + val := fmt.Sprintf("%d", i) + _, err := conn.Do(NewBroadcastRequest(key).Value(val)).Get() + if err != nil { + t.Fatalf("Failed to send a broadcast request: %s", err) + } + select { + case event := <-events: + if event.Conn != conn { + t.Errorf("Unexpected event connection: %v", event.Conn) + } + if event.Key != key { + t.Errorf("Unexpected event key: %s", event.Key) + } + if event.Value.(string) != val { + t.Errorf("Unexpected event value: %v", event.Value) + } + case <-time.After(time.Second): + t.Fatalf("Failed to get watch event %d", i) + } + } +} + +func TestConnection_NewWatcher_multiOnKey(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnection_NewWatcher_multiOnKey" + const value = "bar" + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + events := []chan WatchEvent{ + make(chan WatchEvent), + make(chan WatchEvent), + } + for _, ch := range events { + defer close(ch) + } + + for _, ch := range events { + channel := ch + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + channel <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + } + + for i, ch := range events { + select { + case <-ch: // Skip an initial event. + case <-time.After(2 * time.Second): + t.Fatalf("Failed to skip watch event for %d callback", i) + } + } + + _, err := conn.Do(NewBroadcastRequest(key).Value(value)).Get() + if err != nil { + t.Fatalf("Failed to send a broadcast request: %s", err) + } + + for i, ch := range events { + select { + case event := <-ch: + if event.Conn != conn { + t.Errorf("Unexpected event connection: %v", event.Conn) + } + if event.Key != key { + t.Errorf("Unexpected event key: %s", event.Key) + } + if event.Value.(string) != value { + t.Errorf("Unexpected event value: %v", event.Value) + } + case <-time.After(2 * time.Second): + t.Fatalf("Failed to get watch event from callback %d", i) + } + } +} + +func TestWatcher_Unregister(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestWatcher_Unregister" + const value = "bar" + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + events := make(chan WatchEvent) + defer close(events) + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + + <-events + watcher.Unregister() + + _, err = conn.Do(NewBroadcastRequest(key).Value(value)).Get() + if err != nil { + t.Fatalf("Got broadcast error: %s", err) + } + + select { + case event := <-events: + t.Fatalf("Get unexpected events: %v", event) + case <-time.After(time.Second): + } +} + +func TestConnection_NewWatcher_concurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const testConcurrency = 1000 + const key = "TestConnection_NewWatcher_concurrent" + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + var wg sync.WaitGroup + wg.Add(testConcurrency) + + errors := make(chan error, testConcurrency) + for i := 0; i < testConcurrency; i++ { + go func(i int) { + defer wg.Done() + + events := make(chan struct{}) + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + close(events) + }) + if err != nil { + errors <- err + } else { + select { + case <-events: + case <-time.After(time.Second): + errors <- fmt.Errorf("Unable to get an event %d", i) + } + watcher.Unregister() + } + }(i) + } + wg.Wait() + close(errors) + + for err := range errors { + t.Errorf("An error found: %s", err) + } +} + +func TestWatcher_Unregister_concurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const testConcurrency = 1000 + const key = "TestWatcher_Unregister_concurrent" + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) {}) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + + var wg sync.WaitGroup + wg.Add(testConcurrency) + + for i := 0; i < testConcurrency; i++ { + go func() { + defer wg.Done() + watcher.Unregister() + }() + } + wg.Wait() +} + +func TestConnection_named_index_after_reconnect(t *testing.T) { + const server = "127.0.0.1:3015" + + testDialer := dialer + testDialer.Address = server + + inst, err := test_helpers.StartTarantool(test_helpers.StartOpts{ + Dialer: testDialer, + InitScript: "config.lua", + Listen: server, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 10, + RetryTimeout: 500 * time.Millisecond, + }) + defer test_helpers.StopTarantoolWithCleanup(inst) + if err != nil { + t.Fatalf("Unable to start Tarantool: %s", err) + } + + reconnectOpts := opts + reconnectOpts.Reconnect = 100 * time.Millisecond + reconnectOpts.MaxReconnects = 10 + + conn := test_helpers.ConnectWithValidation(t, testDialer, reconnectOpts) + defer conn.Close() + + test_helpers.StopTarantool(inst) + + request := NewSelectRequest("test").Index("primary").Limit(1) + _, err = conn.Do(request).Get() + if err == nil { + t.Fatalf("An error expected.") + } + + if err := test_helpers.RestartTarantool(inst); err != nil { + t.Fatalf("Unable to restart Tarantool: %s", err) + } + + maxTime := reconnectOpts.Reconnect * time.Duration(reconnectOpts.MaxReconnects) + timeout := time.After(maxTime) + + for { + select { + case <-timeout: + t.Fatalf("Failed to execute request without an error, last error: %s", err) + default: + } + + _, err = conn.Do(request).Get() + if err == nil { + return + } + } +} + +func TestConnect_schema_update(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + for i := 0; i < 100; i++ { + fut := conn.Do(NewCallRequest("create_spaces")) + + switch conn, err := Connect(ctx, dialer, opts); { + case err != nil: + assert.ErrorIs(t, err, ErrConcurrentSchemaUpdate) + case conn == nil: + assert.Fail(t, "conn is nil") + default: + _ = conn.Close() + } + + if _, err := fut.Get(); err != nil { + t.Errorf("Failed to call create_spaces: %s", err) + } + } +} + +func TestConnect_context_cancel(t *testing.T) { + var connLongReconnectOpts = Opts{ + Timeout: 5 * time.Second, + Reconnect: time.Second, + MaxReconnects: 100, + } + + ctx, cancel := context.WithCancel(context.Background()) + + var conn *Connection + var err error + + cancel() + conn, err = Connect(ctx, dialer, connLongReconnectOpts) + + if conn != nil || err == nil { + t.Fatalf("Connection was created after cancel") + } + if !strings.Contains(err.Error(), "operation was canceled") { + t.Fatalf("Unexpected error, expected to contain %s, got %v", + "operation was canceled", err) + } +} + +// A dialer that rejects the first few connection requests. +type mockSlowDialer struct { + counter *int + original NetDialer +} + +func (m mockSlowDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + *m.counter++ + if *m.counter < 10 { + return nil, fmt.Errorf("Too early: %v", *m.counter) + } + return m.original.Dial(ctx, opts) +} + +func TestConnectIsBlocked(t *testing.T) { + const server = "127.0.0.1:3015" + testDialer := dialer + testDialer.Address = server + + inst, err := test_helpers.StartTarantool(test_helpers.StartOpts{ + Dialer: testDialer, + InitScript: "config.lua", + Listen: server, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 10, + RetryTimeout: 500 * time.Millisecond, + }) + defer test_helpers.StopTarantoolWithCleanup(inst) + if err != nil { + t.Fatalf("Unable to start Tarantool: %s", err) + } + + var counter int + mockDialer := mockSlowDialer{original: testDialer, counter: &counter} + ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second) + defer cancel() + + reconnectOpts := opts + reconnectOpts.Reconnect = 100 * time.Millisecond + reconnectOpts.MaxReconnects = 100 + conn, err := Connect(ctx, mockDialer, reconnectOpts) + assert.Nil(t, err) + conn.Close() + assert.GreaterOrEqual(t, counter, 10) +} + +func TestConnectIsBlockedUntilContextExpires(t *testing.T) { + const server = "127.0.0.1:3015" + + testDialer := dialer + testDialer.Address = server + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + reconnectOpts := opts + reconnectOpts.Reconnect = 100 * time.Millisecond + reconnectOpts.MaxReconnects = 100 + _, err := Connect(ctx, testDialer, reconnectOpts) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "failed to dial: dial tcp 127.0.0.1:3015: i/o timeout") +} + +func TestConnectIsUnblockedAfterMaxAttempts(t *testing.T) { + const server = "127.0.0.1:3015" + + testDialer := dialer + testDialer.Address = server + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + reconnectOpts := opts + reconnectOpts.Reconnect = 100 * time.Millisecond + reconnectOpts.MaxReconnects = 1 + _, err := Connect(ctx, testDialer, reconnectOpts) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "last reconnect failed") +} + +func buildSidecar(dir string) error { + goPath, err := exec.LookPath("go") + if err != nil { + return err + } + cmd := exec.Command(goPath, "build", "main.go") + cmd.Dir = filepath.Join(dir, "testdata", "sidecar") + return cmd.Run() +} + +func TestFdDialer(t *testing.T) { + isLess, err := test_helpers.IsTarantoolVersionLess(3, 0, 0) + if err != nil || isLess { + t.Skip("box.session.new present in Tarantool since version 3.0") + } + + wd, err := os.Getwd() + require.NoError(t, err) + + err = buildSidecar(wd) + require.NoErrorf(t, err, "failed to build sidecar: %v", err) + + instOpts := startOpts + instOpts.Listen = fdDialerTestServer + instOpts.Dialer = NetDialer{ + Address: fdDialerTestServer, + User: "test", + Password: "test", + } + + inst, err := test_helpers.StartTarantool(instOpts) + require.NoError(t, err) + defer test_helpers.StopTarantoolWithCleanup(inst) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + sidecarExe := filepath.Join(wd, "testdata", "sidecar", "main") + + evalBody := fmt.Sprintf(` + local socket = require('socket') + local popen = require('popen') + local os = require('os') + local s1, s2 = socket.socketpair('AF_UNIX', 'SOCK_STREAM', 0) + + --[[ Tell sidecar which fd use to connect. --]] + os.setenv('SOCKET_FD', tostring(s2:fd())) + + box.session.new({ + type = 'binary', + fd = s1:fd(), + user = 'test', + }) + s1:detach() + + local ph, err = popen.new({'%s'}, { + stdout = popen.opts.PIPE, + stderr = popen.opts.PIPE, + inherit_fds = {s2:fd()}, + }) + + if err ~= nil then + return 1, err + end + + ph:wait() + + local status_code = ph:info().status.exit_code + local stderr = ph:read({stderr=true}):rstrip() + local stdout = ph:read({stdout=true}):rstrip() + return status_code, stderr, stdout + `, sidecarExe) + + var resp []interface{} + err = conn.EvalTyped(evalBody, []interface{}{}, &resp) + require.NoError(t, err) + require.Equal(t, "", resp[1], resp[1]) + require.Equal(t, "", resp[2], resp[2]) + require.Equal(t, int8(0), resp[0]) +} + +const ( + errNoSyncTransactionQueue = "The synchronous transaction queue doesn't belong to any instance" +) + +func TestDoBeginRequest_IsSync(t *testing.T) { + test_helpers.SkipIfIsSyncUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + stream, err := conn.NewStream() + require.NoError(t, err) + + _, err = stream.Do(NewBeginRequest().IsSync(true)).Get() + assert.Nil(t, err) + + _, err = stream.Do( + NewReplaceRequest("test").Tuple([]interface{}{1, "foo"}), + ).Get() + require.Nil(t, err) + + _, err = stream.Do(NewCommitRequest()).Get() + require.NotNil(t, err) + assert.Contains(t, err.Error(), errNoSyncTransactionQueue) +} + +func TestDoCommitRequest_IsSync(t *testing.T) { + test_helpers.SkipIfIsSyncUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + stream, err := conn.NewStream() + require.NoError(t, err) + + _, err = stream.Do(NewBeginRequest()).Get() + require.Nil(t, err) + + _, err = stream.Do( + NewReplaceRequest("test").Tuple([]interface{}{1, "foo"}), + ).Get() + require.Nil(t, err) + + _, err = stream.Do(NewCommitRequest().IsSync(true)).Get() + require.NotNil(t, err) + assert.Contains(t, err.Error(), errNoSyncTransactionQueue) +} + +func TestDoCommitRequest_NoSync(t *testing.T) { + test_helpers.SkipIfIsSyncUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + stream, err := conn.NewStream() + require.NoError(t, err) + + _, err = stream.Do(NewBeginRequest()).Get() + require.Nil(t, err) + + _, err = stream.Do( + NewReplaceRequest("test").Tuple([]interface{}{1, "foo"}), + ).Get() + require.Nil(t, err) + + _, err = stream.Do(NewCommitRequest()).Get() + assert.Nil(t, err) +} + +// runTestMain is a body of TestMain function +// (see https://pkg.go.dev/testing#hdr-Main). +// Using defer + os.Exit is not works so TestMain body +// is a separate function, see +// https://stackoverflow.com/questions/27629380/how-to-exit-a-go-program-honoring-deferred-calls +func runTestMain(m *testing.M) int { + // Tarantool supports streams and interactive transactions since version 2.10.0 + isStreamUnsupported, err := test_helpers.IsTarantoolVersionLess(2, 10, 0) + if err != nil { + log.Fatalf("Could not check the Tarantool version: %s", err) + } + + startOpts.MemtxUseMvccEngine = !isStreamUnsupported + + inst, err := test_helpers.StartTarantool(startOpts) + if err != nil { + log.Printf("Failed to prepare test tarantool: %s", err) + return 1 + } + + defer test_helpers.StopTarantoolWithCleanup(inst) + + return m.Run() +} + +func TestMain(m *testing.M) { + code := runTestMain(m) + os.Exit(code) +} diff --git a/test_helpers/doer.go b/test_helpers/doer.go new file mode 100644 index 000000000..b61692c43 --- /dev/null +++ b/test_helpers/doer.go @@ -0,0 +1,69 @@ +package test_helpers + +import ( + "bytes" + "testing" + + "github.com/tarantool/go-tarantool/v3" +) + +type doerResponse struct { + resp *MockResponse + err error +} + +// MockDoer is an implementation of the Doer interface +// used for testing purposes. +type MockDoer struct { + // Requests is a slice of received requests. + // It could be used to compare incoming requests with expected. + Requests []tarantool.Request + responses []doerResponse + t *testing.T +} + +// NewMockDoer creates a MockDoer by given responses. +// Each response could be one of two types: MockResponse or error. +func NewMockDoer(t *testing.T, responses ...interface{}) MockDoer { + t.Helper() + + mockDoer := MockDoer{t: t} + for _, response := range responses { + doerResp := doerResponse{} + + switch resp := response.(type) { + case *MockResponse: + doerResp.resp = resp + case error: + doerResp.err = resp + default: + t.Fatalf("unsupported type: %T", response) + } + + mockDoer.responses = append(mockDoer.responses, doerResp) + } + return mockDoer +} + +// Do returns a future with the current response or an error. +// It saves the current request into MockDoer.Requests. +func (doer *MockDoer) Do(req tarantool.Request) *tarantool.Future { + doer.Requests = append(doer.Requests, req) + + mockReq := NewMockRequest() + fut := tarantool.NewFuture(mockReq) + + if len(doer.responses) == 0 { + doer.t.Fatalf("list of responses is empty") + } + response := doer.responses[0] + + if response.err != nil { + fut.SetError(response.err) + } else { + fut.SetResponse(response.resp.header, bytes.NewBuffer(response.resp.data)) + } + doer.responses = doer.responses[1:] + + return fut +} diff --git a/test_helpers/example_test.go b/test_helpers/example_test.go new file mode 100644 index 000000000..3b1ed5d64 --- /dev/null +++ b/test_helpers/example_test.go @@ -0,0 +1,37 @@ +package test_helpers_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +func TestExampleMockDoer(t *testing.T) { + mockDoer := test_helpers.NewMockDoer(t, + test_helpers.NewMockResponse(t, []interface{}{"some data"}), + fmt.Errorf("some error"), + test_helpers.NewMockResponse(t, "some typed data"), + fmt.Errorf("some error"), + ) + + data, err := mockDoer.Do(tarantool.NewPingRequest()).Get() + assert.NoError(t, err) + assert.Equal(t, []interface{}{"some data"}, data) + + data, err = mockDoer.Do(tarantool.NewSelectRequest("foo")).Get() + assert.EqualError(t, err, "some error") + assert.Nil(t, data) + + var stringData string + err = mockDoer.Do(tarantool.NewInsertRequest("space")).GetTyped(&stringData) + assert.NoError(t, err) + assert.Equal(t, "some typed data", stringData) + + err = mockDoer.Do(tarantool.NewPrepareRequest("expr")).GetTyped(&stringData) + assert.EqualError(t, err, "some error") + assert.Nil(t, data) +} diff --git a/test_helpers/main.go b/test_helpers/main.go new file mode 100644 index 000000000..c80683d94 --- /dev/null +++ b/test_helpers/main.go @@ -0,0 +1,573 @@ +// Helpers for managing Tarantool process for testing purposes. +// +// Package introduces go helpers for starting a tarantool process and +// validating Tarantool version. Helpers are based on os/exec calls. +// Retries to connect test tarantool instance handled explicitly, +// see tarantool/go-tarantool/#136. +// +// Tarantool's instance Lua scripts use environment variables to configure +// box.cfg. Listen port is set in the end of script so it is possible to +// connect only if every other thing was set up already. +package test_helpers + +import ( + "context" + "errors" + "fmt" + "io" + "io/fs" + "log" + "os" + "os/exec" + "path/filepath" + "regexp" + "strconv" + "time" + + "github.com/tarantool/go-tarantool/v3" +) + +type StartOpts struct { + // Auth is an authentication method for a Tarantool instance. + Auth tarantool.Auth + + // InitScript is a Lua script for tarantool to run on start. + InitScript string + + // ConfigFile is a path to a configuration file for a Tarantool instance. + // Required in pair with InstanceName. + ConfigFile string + + // InstanceName is a name of an instance to run. + // Required in pair with ConfigFile. + InstanceName string + + // Listen is box.cfg listen parameter for tarantool. + // Use this address to connect to tarantool after configuration. + // https://www.tarantool.io/en/doc/latest/reference/configuration/#cfg-basic-listen + Listen string + + // WorkDir is box.cfg work_dir parameter for a Tarantool instance: + // a folder to store data files. If not specified, helpers create a + // new temporary directory. + // Folder must be unique for each Tarantool process used simultaneously. + // https://www.tarantool.io/en/doc/latest/reference/configuration/#confval-work_dir + WorkDir string + + // SslCertsDir is a path to a directory with SSL certificates. It will be + // copied to the working directory. + SslCertsDir string + + // WaitStart is a time to wait before starting to ping tarantool. + WaitStart time.Duration + + // ConnectRetry is a count of retry attempts to ping tarantool. If the + // value < 0 then there will be no ping tarantool at all. + ConnectRetry int + + // RetryTimeout is a time between tarantool ping retries. + RetryTimeout time.Duration + + // MemtxUseMvccEngine is flag to enable transactional + // manager if set to true. + MemtxUseMvccEngine bool + + // Dialer to check that connection established. + Dialer tarantool.Dialer +} + +type state struct { + done chan struct{} + ret error + stopped bool +} + +// TarantoolInstance is a data for instance graceful shutdown and cleanup. +type TarantoolInstance struct { + // Cmd is a Tarantool command. Used to kill Tarantool process. + // + // Deprecated: Cmd field will be removed in the next major version. + // Use `Wait()` and `Stop()` methods, instead of calling `Cmd.Process.Wait()` or + // `Cmd.Process.Kill()` directly. + Cmd *exec.Cmd + + // Options for restarting a tarantool instance. + Opts StartOpts + + // Dialer to check that connection established. + Dialer tarantool.Dialer + + st chan state +} + +// IsExit checks if Tarantool process was terminated. +func (t *TarantoolInstance) IsExit() bool { + st := <-t.st + t.st <- st + + select { + case <-st.done: + default: + return false + } + + return st.ret != nil +} + +func (t *TarantoolInstance) result() error { + st := <-t.st + t.st <- st + + select { + case <-st.done: + default: + return nil + } + + return st.ret +} + +func (t *TarantoolInstance) checkDone() { + ret := t.Cmd.Wait() + + st := <-t.st + + st.ret = ret + close(st.done) + + t.st <- st + + if !st.stopped { + log.Printf("Tarantool %q was unexpectedly terminated: %v", t.Opts.Listen, t.result()) + } +} + +// Wait waits until Tarantool process is terminated. +// Returns error as process result status. +func (t *TarantoolInstance) Wait() error { + st := <-t.st + t.st <- st + + <-st.done + t.Cmd.Process = nil + + st = <-t.st + t.st <- st + + return st.ret +} + +// Stop terminates Tarantool process and waits until it exit. +func (t *TarantoolInstance) Stop() error { + st := <-t.st + st.stopped = true + t.st <- st + + if t.IsExit() { + return nil + } + if t.Cmd != nil && t.Cmd.Process != nil { + if err := t.Cmd.Process.Kill(); err != nil && !t.IsExit() { + return fmt.Errorf("failed to kill tarantool %q (pid %d), got %s", + t.Opts.Listen, t.Cmd.Process.Pid, err) + } + t.Wait() + } + return nil +} + +// Signal sends a signal to the Tarantool instance. +func (t *TarantoolInstance) Signal(sig os.Signal) error { + return t.Cmd.Process.Signal(sig) +} + +func isReady(dialer tarantool.Dialer, opts *tarantool.Opts) error { + var err error + var conn *tarantool.Connection + + ctx, cancel := GetConnectContext() + defer cancel() + conn, err = tarantool.Connect(ctx, dialer, *opts) + if err != nil { + return err + } + if conn == nil { + return errors.New("connection is nil after connect") + } + defer conn.Close() + + _, err = conn.Do(tarantool.NewPingRequest()).Get() + if err != nil { + return err + } + + return nil +} + +var ( + // Used to extract Tarantool version (major.minor.patch). + tarantoolVersionRegexp *regexp.Regexp +) + +func init() { + tarantoolVersionRegexp = regexp.MustCompile(`Tarantool (Enterprise )?(\d+)\.(\d+)\.(\d+).*`) +} + +// atoiUint64 parses string to uint64. +func atoiUint64(str string) (uint64, error) { + res, err := strconv.ParseUint(str, 10, 64) + if err != nil { + return 0, fmt.Errorf("cast to number error (%s)", err) + } + return res, nil +} + +func getTarantoolExec() string { + if tar_bin := os.Getenv("TARANTOOL_BIN"); tar_bin != "" { + return tar_bin + } + return "tarantool" +} + +// IsTarantoolVersionLess checks if tarantool version is less +// than passed . Returns error if failed +// to extract version. +func IsTarantoolVersionLess(majorMin uint64, minorMin uint64, patchMin uint64) (bool, error) { + var major, minor, patch uint64 + + out, err := exec.Command(getTarantoolExec(), "--version").Output() + + if err != nil { + return true, err + } + + parsed := tarantoolVersionRegexp.FindStringSubmatch(string(out)) + + if parsed == nil { + return true, fmt.Errorf("failed to parse output %q", out) + } + + if major, err = atoiUint64(parsed[2]); err != nil { + return true, fmt.Errorf("failed to parse major from output %q: %w", out, err) + } + + if minor, err = atoiUint64(parsed[3]); err != nil { + return true, fmt.Errorf("failed to parse minor from output %q: %w", out, err) + } + + if patch, err = atoiUint64(parsed[4]); err != nil { + return true, fmt.Errorf("failed to parse patch from output %q: %w", out, err) + } + + if major != majorMin { + return major < majorMin, nil + } else if minor != minorMin { + return minor < minorMin, nil + } else { + return patch < patchMin, nil + } +} + +// IsTarantoolEE checks if Tarantool is Enterprise edition. +func IsTarantoolEE() (bool, error) { + out, err := exec.Command(getTarantoolExec(), "--version").Output() + if err != nil { + return true, err + } + + parsed := tarantoolVersionRegexp.FindStringSubmatch(string(out)) + if parsed == nil { + return true, fmt.Errorf("failed to parse output %q", out) + } + + return parsed[1] != "", nil +} + +// RestartTarantool restarts a tarantool instance for tests +// with specifies parameters (refer to StartOpts) +// which were specified in inst parameter. +// inst is a tarantool instance that was started by +// StartTarantool. Rewrites inst.Cmd.Process to stop +// instance with StopTarantool. +// Process must be stopped with StopTarantool. +func RestartTarantool(inst *TarantoolInstance) error { + startedInst, err := StartTarantool(inst.Opts) + + inst.Cmd.Process = startedInst.Cmd.Process + inst.st = startedInst.st + + return err +} + +func removeByMask(dir string, masks ...string) error { + err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if d.IsDir() { + return nil + } + + for _, mask := range masks { + if ok, err := filepath.Match(mask, d.Name()); err != nil { + return err + } else if ok { + if err = os.Remove(path); err != nil { + return err + } + } + } + return nil + }) + + if err != nil { + return err + } + return nil +} + +func prepareDir(workDir string) (string, error) { + if workDir == "" { + dir, err := os.MkdirTemp("", "work_dir") + if err != nil { + return "", err + } + return dir, nil + } + // Create work_dir. + err := os.MkdirAll(workDir, 0755) + if err != nil { + return "", err + } + + // Clean up existing work_dir. + err = removeByMask(workDir, "*.snap", "*.xlog") + if err != nil { + return "", err + } + return workDir, nil +} + +// StartTarantool starts a tarantool instance for tests +// with specifies parameters (refer to StartOpts). +// Process must be stopped with StopTarantool. +func StartTarantool(startOpts StartOpts) (*TarantoolInstance, error) { + // Prepare tarantool command. + inst := &TarantoolInstance{ + st: make(chan state, 1), + } + init := state{ + done: make(chan struct{}), + } + inst.st <- init + + var err error + inst.Dialer = startOpts.Dialer + startOpts.WorkDir, err = prepareDir(startOpts.WorkDir) + if err != nil { + return inst, fmt.Errorf("failed to prepare working dir %q: %w", startOpts.WorkDir, err) + } + + args := []string{} + if startOpts.InitScript != "" { + if !filepath.IsAbs(startOpts.InitScript) { + cwd, err := os.Getwd() + if err != nil { + return inst, fmt.Errorf("failed to get current working directory: %w", err) + } + startOpts.InitScript = filepath.Join(cwd, startOpts.InitScript) + } + args = append(args, startOpts.InitScript) + } + if startOpts.ConfigFile != "" && startOpts.InstanceName != "" { + args = append(args, "--config", startOpts.ConfigFile) + args = append(args, "--name", startOpts.InstanceName) + } + inst.Cmd = exec.Command(getTarantoolExec(), args...) + inst.Cmd.Dir = startOpts.WorkDir + + inst.Cmd.Env = append( + os.Environ(), + fmt.Sprintf("TEST_TNT_WORK_DIR=%s", startOpts.WorkDir), + fmt.Sprintf("TEST_TNT_LISTEN=%s", startOpts.Listen), + fmt.Sprintf("TEST_TNT_MEMTX_USE_MVCC_ENGINE=%t", startOpts.MemtxUseMvccEngine), + fmt.Sprintf("TEST_TNT_AUTH_TYPE=%s", startOpts.Auth), + ) + + // Copy SSL certificates. + if startOpts.SslCertsDir != "" { + err = copySslCerts(startOpts.WorkDir, startOpts.SslCertsDir) + if err != nil { + return inst, err + } + } + + // Options for restarting tarantool instance. + inst.Opts = startOpts + + // Start tarantool. + err = inst.Cmd.Start() + if err != nil { + return inst, err + } + + // Try to connect and ping tarantool. + // Using reconnect opts do not help on Connect, + // see https://github.com/tarantool/go-tarantool/issues/136 + time.Sleep(startOpts.WaitStart) + + go inst.checkDone() + + opts := tarantool.Opts{ + Timeout: 500 * time.Millisecond, + SkipSchema: true, + } + + var i int + for i = 0; i <= startOpts.ConnectRetry; i++ { + err = isReady(inst.Dialer, &opts) + + // Both connect and ping is ok. + if err == nil { + break + } + + if i != startOpts.ConnectRetry { + time.Sleep(startOpts.RetryTimeout) + } + } + + if inst.IsExit() && inst.result() != nil { + StopTarantool(inst) + return nil, fmt.Errorf("unexpected terminated Tarantool %q: %w", + inst.Opts.Listen, inst.result()) + } + + if err != nil { + StopTarantool(inst) + return nil, fmt.Errorf("failed to connect Tarantool %q: %w", + inst.Opts.Listen, err) + } + + return inst, nil +} + +// StopTarantool stops a tarantool instance started +// with StartTarantool. Waits until any resources +// associated with the process is released. If something went wrong, fails. +func StopTarantool(inst *TarantoolInstance) { + err := inst.Stop() + if err != nil { + log.Fatal(err) + } +} + +// StopTarantoolWithCleanup stops a tarantool instance started +// with StartTarantool. Waits until any resources +// associated with the process is released. +// Cleans work directory after stop. If something went wrong, fails. +func StopTarantoolWithCleanup(inst *TarantoolInstance) { + StopTarantool(inst) + + if inst.Opts.WorkDir != "" { + if err := os.RemoveAll(inst.Opts.WorkDir); err != nil { + log.Fatalf("Failed to clean work directory, got %s", err) + } + } +} + +func copySslCerts(dst string, sslCertsDir string) (err error) { + dstCertPath := filepath.Join(dst, sslCertsDir) + if err = os.Mkdir(dstCertPath, 0755); err != nil { + return + } + if err = copyDirectoryFiles(sslCertsDir, dstCertPath); err != nil { + return + } + return +} + +func copyDirectoryFiles(scrDir, dest string) error { + entries, err := os.ReadDir(scrDir) + if err != nil { + return err + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + sourcePath := filepath.Join(scrDir, entry.Name()) + destPath := filepath.Join(dest, entry.Name()) + _, err := os.Stat(sourcePath) + if err != nil { + return err + } + + if err := copyFile(sourcePath, destPath); err != nil { + return err + } + + info, err := entry.Info() + if err != nil { + return err + } + + if err := os.Chmod(destPath, info.Mode()); err != nil { + return err + } + } + return nil +} + +func copyFile(srcFile, dstFile string) error { + out, err := os.Create(dstFile) + if err != nil { + return err + } + + defer out.Close() + + in, err := os.Open(srcFile) + if err != nil { + return err + } + defer in.Close() + + _, err = io.Copy(out, in) + if err != nil { + return err + } + + return nil +} + +// msgpack.v5 decodes different uint types depending on value. The +// function helps to unify a result. +func ConvertUint64(v interface{}) (result uint64, err error) { + switch v := v.(type) { + case uint: + result = uint64(v) + case uint8: + result = uint64(v) + case uint16: + result = uint64(v) + case uint32: + result = uint64(v) + case uint64: + result = v + case int: + result = uint64(v) + case int8: + result = uint64(v) + case int16: + result = uint64(v) + case int32: + result = uint64(v) + case int64: + result = uint64(v) + default: + err = fmt.Errorf("non-number value %T", v) + } + return +} + +func GetConnectContext() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Second) +} diff --git a/test_helpers/pool_helper.go b/test_helpers/pool_helper.go new file mode 100644 index 000000000..81eb8e2f3 --- /dev/null +++ b/test_helpers/pool_helper.go @@ -0,0 +1,308 @@ +package test_helpers + +import ( + "context" + "errors" + "fmt" + "reflect" + "sync" + "time" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/pool" +) + +type ListenOnInstanceArgs struct { + ConnPool *pool.ConnectionPool + Mode pool.Mode + ServersNumber int + ExpectedPorts map[string]bool +} + +type CheckStatusesArgs struct { + ConnPool *pool.ConnectionPool + Servers []string + Mode pool.Mode + ExpectedPoolStatus bool + ExpectedStatuses map[string]bool +} + +func compareTuples(expectedTpl []interface{}, actualTpl []interface{}) error { + if len(actualTpl) != len(expectedTpl) { + return fmt.Errorf("unexpected body of Insert (tuple len)") + } + + for i, field := range actualTpl { + if field != expectedTpl[i] { + return fmt.Errorf("unexpected field, expected: %v actual: %v", expectedTpl[i], field) + } + } + + return nil +} + +func CheckPoolStatuses(args interface{}) error { + checkArgs, ok := args.(CheckStatusesArgs) + if !ok { + return fmt.Errorf("incorrect args") + } + + connected, _ := checkArgs.ConnPool.ConnectedNow(checkArgs.Mode) + if connected != checkArgs.ExpectedPoolStatus { + return fmt.Errorf( + "incorrect connection pool status: expected status %t actual status %t", + checkArgs.ExpectedPoolStatus, connected) + } + + poolInfo := checkArgs.ConnPool.GetInfo() + for _, server := range checkArgs.Servers { + status := poolInfo[server].ConnectedNow + if checkArgs.ExpectedStatuses[server] != status { + return fmt.Errorf( + "incorrect conn status: addr %s expected status %t actual status %t", + server, checkArgs.ExpectedStatuses[server], status) + } + } + + return nil +} + +// ProcessListenOnInstance helper calls "return box.cfg.listen" +// as many times as there are servers in the connection pool +// with specified mode. +// For RO mode expected received ports equals to replica ports. +// For RW mode expected received ports equals to master ports. +// For PreferRO mode expected received ports equals to replica +// ports or to all ports. +// For PreferRW mode expected received ports equals to master ports +// or to all ports. +func ProcessListenOnInstance(args interface{}) error { + actualPorts := map[string]bool{} + + listenArgs, ok := args.(ListenOnInstanceArgs) + if !ok { + return fmt.Errorf("incorrect args") + } + + for i := 0; i < listenArgs.ServersNumber; i++ { + req := tarantool.NewEvalRequest("return box.cfg.listen") + data, err := listenArgs.ConnPool.Do(req, listenArgs.Mode).Get() + if err != nil { + return fmt.Errorf("fail to Eval: %s", err.Error()) + } + if len(data) < 1 { + return fmt.Errorf("response.Data is empty after Eval") + } + + port, ok := data[0].(string) + if !ok { + return fmt.Errorf("response.Data is incorrect after Eval") + } + + actualPorts[port] = true + } + + equal := reflect.DeepEqual(actualPorts, listenArgs.ExpectedPorts) + if !equal { + return fmt.Errorf("expected ports: %v, actual ports: %v", + listenArgs.ExpectedPorts, actualPorts) + } + + return nil +} + +func Retry(f func(interface{}) error, args interface{}, count int, timeout time.Duration) error { + var err error + + for i := 0; ; i++ { + err = f(args) + if err == nil { + return err + } + + if i >= (count - 1) { + break + } + + time.Sleep(timeout) + } + + return err +} + +func InsertOnInstance(ctx context.Context, dialer tarantool.Dialer, connOpts tarantool.Opts, + space interface{}, tuple interface{}) error { + conn, err := tarantool.Connect(ctx, dialer, connOpts) + if err != nil { + return fmt.Errorf("fail to connect: %s", err.Error()) + } + if conn == nil { + return fmt.Errorf("conn is nil after Connect") + } + defer conn.Close() + + data, err := conn.Do(tarantool.NewInsertRequest(space).Tuple(tuple)).Get() + if err != nil { + return fmt.Errorf("failed to Insert: %s", err.Error()) + } + if len(data) != 1 { + return fmt.Errorf("response Body len != 1") + } + if tpl, ok := data[0].([]interface{}); !ok { + return fmt.Errorf("unexpected body of Insert") + } else { + expectedTpl, ok := tuple.([]interface{}) + if !ok { + return fmt.Errorf("failed to cast") + } + + err = compareTuples(expectedTpl, tpl) + if err != nil { + return err + } + } + + return nil +} + +func InsertOnInstances( + ctx context.Context, + dialers []tarantool.Dialer, + connOpts tarantool.Opts, + space interface{}, + tuple interface{}) error { + serversNumber := len(dialers) + roles := make([]bool, serversNumber) + for i := 0; i < serversNumber; i++ { + roles[i] = false + } + + err := SetClusterRO(ctx, dialers, connOpts, roles) + if err != nil { + return fmt.Errorf("fail to set roles for cluster: %s", err.Error()) + } + + errs := make([]error, len(dialers)) + var wg sync.WaitGroup + wg.Add(len(dialers)) + for i, dialer := range dialers { + // Pass loop variable(s) to avoid its capturing by reference (not needed since Go 1.22). + go func(i int, dialer tarantool.Dialer) { + defer wg.Done() + errs[i] = InsertOnInstance(ctx, dialer, connOpts, space, tuple) + }(i, dialer) + } + wg.Wait() + + return errors.Join(errs...) +} + +func SetInstanceRO(ctx context.Context, dialer tarantool.Dialer, connOpts tarantool.Opts, + isReplica bool) error { + conn, err := tarantool.Connect(ctx, dialer, connOpts) + if err != nil { + return err + } + + defer conn.Close() + + req := tarantool.NewCallRequest("box.cfg"). + Args([]interface{}{map[string]bool{"read_only": isReplica}}) + if _, err := conn.Do(req).Get(); err != nil { + return err + } + + checkRole := func(conn *tarantool.Connection, isReplica bool) string { + data, err := conn.Do(tarantool.NewCallRequest("box.info")).Get() + switch { + case err != nil: + return fmt.Sprintf("failed to get box.info: %s", err) + case len(data) < 1: + return "box.info is empty" + } + + boxInfo, ok := data[0].(map[interface{}]interface{}) + if !ok { + return "unexpected type in box.info response" + } + + status, statusFound := boxInfo["status"] + readonly, readonlyFound := boxInfo["ro"] + switch { + case !statusFound: + return "box.info.status is missing" + case status != "running": + return fmt.Sprintf("box.info.status='%s' (waiting for 'running')", status) + case !readonlyFound: + return "box.info.ro is missing" + case readonly != isReplica: + return fmt.Sprintf("box.info.ro='%v' (waiting for '%v')", readonly, isReplica) + default: + return "" + } + } + + problem := "not checked yet" + + // Wait for the role to be applied. + for len(problem) != 0 { + select { + case <-time.After(10 * time.Millisecond): + case <-ctx.Done(): + return fmt.Errorf("%w: failed to apply role, the last problem: %s", + ctx.Err(), problem) + } + + problem = checkRole(conn, isReplica) + } + + return nil +} + +func SetClusterRO(ctx context.Context, dialers []tarantool.Dialer, connOpts tarantool.Opts, + roles []bool) error { + if len(dialers) != len(roles) { + return fmt.Errorf("number of servers should be equal to number of roles") + } + + // Apply roles in parallel. + errs := make([]error, len(dialers)) + var wg sync.WaitGroup + wg.Add(len(dialers)) + for i, dialer := range dialers { + // Pass loop variable(s) to avoid its capturing by reference (not needed since Go 1.22). + go func(i int, dialer tarantool.Dialer) { + defer wg.Done() + errs[i] = SetInstanceRO(ctx, dialer, connOpts, roles[i]) + }(i, dialer) + } + wg.Wait() + + return errors.Join(errs...) +} + +func StartTarantoolInstances(instsOpts []StartOpts) ([]*TarantoolInstance, error) { + instances := make([]*TarantoolInstance, 0, len(instsOpts)) + + for _, opts := range instsOpts { + instance, err := StartTarantool(opts) + if err != nil { + StopTarantoolInstances(instances) + return nil, err + } + + instances = append(instances, instance) + } + + return instances, nil +} + +func StopTarantoolInstances(instances []*TarantoolInstance) { + for _, instance := range instances { + StopTarantoolWithCleanup(instance) + } +} + +func GetPoolConnectContext() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Second) +} diff --git a/test_helpers/request.go b/test_helpers/request.go new file mode 100644 index 000000000..003a97ab3 --- /dev/null +++ b/test_helpers/request.go @@ -0,0 +1,52 @@ +package test_helpers + +import ( + "context" + "io" + + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// MockRequest is an empty mock request used for testing purposes. +type MockRequest struct { +} + +// NewMockRequest creates an empty MockRequest. +func NewMockRequest() *MockRequest { + return &MockRequest{} +} + +// Type returns an iproto type for MockRequest. +func (req *MockRequest) Type() iproto.Type { + return iproto.Type(0) +} + +// Async returns if MockRequest expects a response. +func (req *MockRequest) Async() bool { + return false +} + +// Body fills an msgpack.Encoder with the watch request body. +func (req *MockRequest) Body(resolver tarantool.SchemaResolver, enc *msgpack.Encoder) error { + return nil +} + +// Conn returns the Connection object the request belongs to. +func (req *MockRequest) Conn() *tarantool.Connection { + return &tarantool.Connection{} +} + +// Ctx returns a context of the MockRequest. +func (req *MockRequest) Ctx() context.Context { + return nil +} + +// Response creates a response for the MockRequest. +func (req *MockRequest) Response(header tarantool.Header, + body io.Reader) (tarantool.Response, error) { + resp, err := CreateMockResponse(header, body) + return resp, err +} diff --git a/test_helpers/response.go b/test_helpers/response.go new file mode 100644 index 000000000..630ac7726 --- /dev/null +++ b/test_helpers/response.go @@ -0,0 +1,73 @@ +package test_helpers + +import ( + "bytes" + "io" + "testing" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/tarantool/go-tarantool/v3" +) + +// MockResponse is a mock response used for testing purposes. +type MockResponse struct { + // header contains response header + header tarantool.Header + // data contains data inside a response. + data []byte +} + +// NewMockResponse creates a new MockResponse with an empty header and the given data. +// body should be passed as a structure to be encoded. +// The encoded body is served as response data and will be decoded once the +// response is decoded. +func NewMockResponse(t *testing.T, body interface{}) *MockResponse { + t.Helper() + + buf := bytes.NewBuffer([]byte{}) + enc := msgpack.NewEncoder(buf) + + err := enc.Encode(body) + if err != nil { + t.Errorf("unexpected error while encoding: %s", err) + } + + return &MockResponse{data: buf.Bytes()} +} + +// CreateMockResponse creates a MockResponse from the header and a data, +// packed inside an io.Reader. +func CreateMockResponse(header tarantool.Header, body io.Reader) (*MockResponse, error) { + if body == nil { + return &MockResponse{header: header, data: nil}, nil + } + data, err := io.ReadAll(body) + if err != nil { + return nil, err + } + return &MockResponse{header: header, data: data}, nil +} + +// Header returns a header for the MockResponse. +func (resp *MockResponse) Header() tarantool.Header { + return resp.header +} + +// Decode returns the result of decoding the response data as slice. +func (resp *MockResponse) Decode() ([]interface{}, error) { + if resp.data == nil { + return nil, nil + } + dec := msgpack.NewDecoder(bytes.NewBuffer(resp.data)) + return dec.DecodeSlice() +} + +// DecodeTyped returns the result of decoding the response data. +func (resp *MockResponse) DecodeTyped(res interface{}) error { + if resp.data == nil { + return nil + } + dec := msgpack.NewDecoder(bytes.NewBuffer(resp.data)) + return dec.Decode(res) +} diff --git a/test_helpers/tcs/prepare.go b/test_helpers/tcs/prepare.go new file mode 100644 index 000000000..92a13afb1 --- /dev/null +++ b/test_helpers/tcs/prepare.go @@ -0,0 +1,66 @@ +package tcs + +import ( + _ "embed" + "fmt" + "os" + "path/filepath" + "text/template" + "time" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +const ( + waitTimeout = 500 * time.Millisecond + connectRetry = 3 + tcsUser = "client" + tcsPassword = "secret" +) + +//go:embed testdata/config.yaml +var tcsConfig []byte + +func writeConfig(name string, port int) error { + cfg, err := os.Create(name) + if err != nil { + return err + } + defer cfg.Close() + + cfg.Chmod(0644) + + t := template.Must(template.New("config").Parse(string(tcsConfig))) + return t.Execute(cfg, map[string]interface{}{ + "host": "localhost", + "port": port, + }) +} + +func makeOpts(port int) (test_helpers.StartOpts, error) { + opts := test_helpers.StartOpts{} + var err error + opts.WorkDir, err = os.MkdirTemp("", "tcs_dir") + if err != nil { + return opts, err + } + + opts.ConfigFile = filepath.Join(opts.WorkDir, "config.yaml") + err = writeConfig(opts.ConfigFile, port) + if err != nil { + return opts, fmt.Errorf("can't save file %q: %w", opts.ConfigFile, err) + } + + opts.Listen = fmt.Sprintf("localhost:%d", port) + opts.WaitStart = waitTimeout + opts.ConnectRetry = connectRetry + opts.RetryTimeout = waitTimeout + opts.InstanceName = "master" + opts.Dialer = tarantool.NetDialer{ + Address: opts.Listen, + User: tcsUser, + Password: tcsPassword, + } + return opts, nil +} diff --git a/test_helpers/tcs/tcs.go b/test_helpers/tcs/tcs.go new file mode 100644 index 000000000..a54ba5fda --- /dev/null +++ b/test_helpers/tcs/tcs.go @@ -0,0 +1,180 @@ +package tcs + +import ( + "context" + "errors" + "fmt" + "net" + "testing" + + "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/test_helpers" +) + +// ErrNotSupported identifies result of `Start()` why storage was not started. +var ErrNotSupported = errors.New("required Tarantool EE 3.3+") + +// ErrNoValue used to show that `Get()` was successful, but no values were found. +var ErrNoValue = errors.New("required value not found") + +// TCS is a Tarantool centralized configuration storage connection. +type TCS struct { + inst *test_helpers.TarantoolInstance + conn *tarantool.Connection + tb testing.TB + port int +} + +// dataResponse content of TcS response in data array. +type dataResponse struct { + Path string `msgpack:"path"` + Value string `msgpack:"value"` + ModRevision int64 `msgpack:"mod_revision"` +} + +// findEmptyPort returns some random unused port if @port is passed with zero. +func findEmptyPort(port int) (int, error) { + listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", port)) + if err != nil { + return 0, err + } + defer listener.Close() + + addr := listener.Addr().(*net.TCPAddr) + return addr.Port, nil +} + +// Start starts a Tarantool centralized configuration storage. +// Use `port = 0` to use any unused port. +// Returns a Tcs instance and a cleanup function. +func Start(port int) (TCS, error) { + tcs := TCS{} + if ok, err := test_helpers.IsTcsSupported(); !ok || err != nil { + return tcs, errors.Join(ErrNotSupported, err) + } + var err error + tcs.port, err = findEmptyPort(port) + if err != nil { + if port == 0 { + return tcs, fmt.Errorf("failed to detect an empty port: %w", err) + } else { + return tcs, fmt.Errorf("port %d can't be used: %w", port, err) + } + } + + opts, err := makeOpts(tcs.port) + if err != nil { + return tcs, err + } + + tcs.inst, err = test_helpers.StartTarantool(opts) + if err != nil { + return tcs, fmt.Errorf("failed to start Tarantool config storage: %w", err) + } + + tcs.conn, err = tarantool.Connect(context.Background(), tcs.inst.Dialer, tarantool.Opts{}) + if err != nil { + return tcs, fmt.Errorf("failed to connect to Tarantool config storage: %w", err) + } + + return tcs, nil +} + +// Start starts a Tarantool centralized configuration storage. +// Returns a Tcs instance and a cleanup function. +func StartTesting(tb testing.TB, port int) TCS { + tcs, err := Start(port) + if err != nil { + tb.Fatal(err) + } + return tcs +} + +// Doer returns interface for interacting with Tarantool. +func (t *TCS) Doer() tarantool.Doer { + return t.conn +} + +// Dialer returns a dialer to connect to Tarantool. +func (t *TCS) Dialer() tarantool.Dialer { + return t.inst.Dialer +} + +// Endpoints returns a list of addresses to connect. +func (t *TCS) Endpoints() []string { + return []string{fmt.Sprintf("127.0.0.1:%d", t.port)} +} + +// Credentials returns a user name and password to connect. +func (t *TCS) Credentials() (string, string) { + return tcsUser, tcsPassword +} + +// Stop stops the Tarantool centralized configuration storage. +func (t *TCS) Stop() { + if t.tb != nil { + t.tb.Helper() + } + if t.conn != nil { + t.conn.Close() + } + test_helpers.StopTarantoolWithCleanup(t.inst) +} + +// Put implements "config.storage.put" method. +func (t *TCS) Put(ctx context.Context, path string, value string) error { + if t.tb != nil { + t.tb.Helper() + } + req := tarantool.NewCallRequest("config.storage.put"). + Args([]any{path, value}). + Context(ctx) + if _, err := t.conn.Do(req).GetResponse(); err != nil { + return fmt.Errorf("failed to save data to tarantool: %w", err) + } + return nil +} + +// Delete implements "config.storage.delete" method. +func (t *TCS) Delete(ctx context.Context, path string) error { + if t.tb != nil { + t.tb.Helper() + } + req := tarantool.NewCallRequest("config.storage.delete"). + Args([]any{path}). + Context(ctx) + if _, err := t.conn.Do(req).GetResponse(); err != nil { + return fmt.Errorf("failed to delete data from tarantool: %w", err) + } + return nil +} + +// Get implements "config.storage.get" method. +func (t *TCS) Get(ctx context.Context, path string) (string, error) { + if t.tb != nil { + t.tb.Helper() + } + req := tarantool.NewCallRequest("config.storage.get"). + Args([]any{path}). + Context(ctx) + + resp := []struct { + Data []dataResponse `msgpack:"data"` + }{} + + err := t.conn.Do(req).GetTyped(&resp) + if err != nil { + return "", fmt.Errorf("failed to fetch data from tarantool: %w", err) + } + if len(resp) != 1 { + return "", errors.New("unexpected response from tarantool") + } + if len(resp[0].Data) == 0 { + return "", ErrNoValue + } + if len(resp[0].Data) != 1 { + return "", errors.New("too much data in response from tarantool") + } + + return resp[0].Data[0].Value, nil +} diff --git a/test_helpers/tcs/testdata/config.yaml b/test_helpers/tcs/testdata/config.yaml new file mode 100644 index 000000000..5d42e32aa --- /dev/null +++ b/test_helpers/tcs/testdata/config.yaml @@ -0,0 +1,39 @@ +credentials: + users: + replicator: + password: "topsecret" + roles: [replication] + client: + password: "secret" + privileges: + - permissions: [execute] + universe: true + - permissions: [read, write] + spaces: [config_storage, config_storage_meta] + +iproto: + advertise: + peer: + login: replicator + +replication: + failover: election + +database: + use_mvcc_engine: true + +groups: + group-001: + replicasets: + replicaset-001: + roles: [config.storage] + roles_cfg: + config_storage: + status_check_interval: 3 + instances: + master: + iproto: + listen: + - uri: "{{.host}}:{{.port}}" + params: + transport: plain diff --git a/test_helpers/utils.go b/test_helpers/utils.go new file mode 100644 index 000000000..d844a822a --- /dev/null +++ b/test_helpers/utils.go @@ -0,0 +1,302 @@ +package test_helpers + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/tarantool/go-tarantool/v3" +) + +// ConnectWithValidation tries to connect to a Tarantool instance. +// It returns a valid connection if it is successful, otherwise finishes a test +// with an error. +func ConnectWithValidation(t testing.TB, + dialer tarantool.Dialer, + opts tarantool.Opts) *tarantool.Connection { + t.Helper() + + ctx, cancel := GetConnectContext() + defer cancel() + conn, err := tarantool.Connect(ctx, dialer, opts) + if err != nil { + t.Fatalf("Failed to connect: %s", err.Error()) + } + if conn == nil { + t.Fatalf("conn is nil after Connect") + } + return conn +} + +func DeleteRecordByKey(t *testing.T, conn tarantool.Connector, + space interface{}, index interface{}, key []interface{}) { + t.Helper() + + req := tarantool.NewDeleteRequest(space). + Index(index). + Key(key) + resp, err := conn.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Delete: %s", err.Error()) + } + if resp == nil { + t.Fatalf("Response is nil after Select") + } +} + +// WaitUntilReconnected waits until connection is reestablished. +// Returns false in case of connection is not in the connected state +// after specified retries count, true otherwise. +func WaitUntilReconnected(conn *tarantool.Connection, retries uint, timeout time.Duration) bool { + err := Retry(func(arg interface{}) error { + conn := arg.(*tarantool.Connection) + connected := conn.ConnectedNow() + if !connected { + return fmt.Errorf("not connected") + } + return nil + }, conn, int(retries), timeout) + + return err == nil +} + +func SkipIfSQLUnsupported(t testing.TB) { + t.Helper() + + // Tarantool supports SQL since version 2.0.0 + isLess, err := IsTarantoolVersionLess(2, 0, 0) + if err != nil { + t.Fatalf("Could not check the Tarantool version: %s", err) + } + if isLess { + t.Skip() + } +} + +// SkipIfLess skips test run if Tarantool version is less than expected. +func SkipIfLess(t *testing.T, reason string, major, minor, patch uint64) { + t.Helper() + + isLess, err := IsTarantoolVersionLess(major, minor, patch) + if err != nil { + t.Fatalf("Could not check the Tarantool version: %s", err) + } + + if isLess { + t.Skipf("Skipping test for Tarantool %s", reason) + } +} + +// SkipIfGreaterOrEqual skips test run if Tarantool version is greater or equal +// than expected. +func SkipIfGreaterOrEqual(t *testing.T, reason string, major, minor, patch uint64) { + t.Helper() + + isLess, err := IsTarantoolVersionLess(major, minor, patch) + if err != nil { + t.Fatalf("Could not check the Tarantool version: %s", err) + } + + if !isLess { + t.Skipf("Skipping test for Tarantool %s", reason) + } +} + +// SkipIfFeatureUnsupported skips test run if Tarantool does not yet support a feature. +func SkipIfFeatureUnsupported(t *testing.T, feature string, major, minor, patch uint64) { + t.Helper() + + SkipIfLess(t, fmt.Sprintf("without %s support", feature), major, minor, patch) +} + +// SkipIfFeatureSupported skips test run if Tarantool supports a feature. +// Helper if useful when we want to test if everything is alright +// on older versions. +func SkipIfFeatureSupported(t *testing.T, feature string, major, minor, patch uint64) { + t.Helper() + + SkipIfGreaterOrEqual(t, fmt.Sprintf("with %s support", feature), major, minor, patch) +} + +// SkipIfFeatureDropped skips test run if Tarantool had dropped +// support of a feature. +func SkipIfFeatureDropped(t *testing.T, feature string, major, minor, patch uint64) { + t.Helper() + + SkipIfGreaterOrEqual(t, fmt.Sprintf("with %s support dropped", feature), major, minor, patch) +} + +// SkipOfStreamsUnsupported skips test run if Tarantool without streams +// support is used. +func SkipIfStreamsUnsupported(t *testing.T) { + t.Helper() + + SkipIfFeatureUnsupported(t, "streams", 2, 10, 0) +} + +// SkipOfStreamsUnsupported skips test run if Tarantool without watchers +// support is used. +func SkipIfWatchersUnsupported(t *testing.T) { + t.Helper() + + SkipIfFeatureUnsupported(t, "watchers", 2, 10, 0) +} + +// SkipIfWatchersSupported skips test run if Tarantool with watchers +// support is used. +func SkipIfWatchersSupported(t *testing.T) { + t.Helper() + + SkipIfFeatureSupported(t, "watchers", 2, 10, 0) +} + +// SkipIfIdUnsupported skips test run if Tarantool without +// IPROTO_ID support is used. +func SkipIfIdUnsupported(t *testing.T) { + t.Helper() + + SkipIfFeatureUnsupported(t, "id requests", 2, 10, 0) +} + +// SkipIfIdSupported skips test run if Tarantool with +// IPROTO_ID support is used. Skip is useful for tests validating +// that protocol info is processed as expected even for pre-IPROTO_ID instances. +func SkipIfIdSupported(t *testing.T) { + t.Helper() + + SkipIfFeatureSupported(t, "id requests", 2, 10, 0) +} + +// SkipIfErrorExtendedInfoUnsupported skips test run if Tarantool without +// IPROTO_ERROR (0x52) support is used. +func SkipIfErrorExtendedInfoUnsupported(t *testing.T) { + t.Helper() + + SkipIfFeatureUnsupported(t, "error extended info", 2, 4, 1) +} + +// SkipIfErrorMessagePackTypeUnsupported skips test run if Tarantool without +// MP_ERROR type over iproto support is used. +func SkipIfErrorMessagePackTypeUnsupported(t *testing.T) { + t.Helper() + + SkipIfFeatureUnsupported(t, "error type in MessagePack", 2, 10, 0) +} + +// SkipIfPaginationUnsupported skips test run if Tarantool without +// pagination is used. +func SkipIfPaginationUnsupported(t *testing.T) { + t.Helper() + + SkipIfFeatureUnsupported(t, "pagination", 2, 11, 0) +} + +// SkipIfWatchOnceUnsupported skips test run if Tarantool without WatchOnce +// request type is used. +func SkipIfWatchOnceUnsupported(t *testing.T) { + t.Helper() + + SkipIfFeatureUnsupported(t, "watch once", 3, 0, 0) +} + +// SkipIfWatchOnceSupported skips test run if Tarantool with WatchOnce +// request type is used. +func SkipIfWatchOnceSupported(t *testing.T) { + t.Helper() + + SkipIfFeatureSupported(t, "watch once", 3, 0, 0) +} + +// SkipIfCrudSpliceBroken skips test run if splice operation is broken +// on the crud side. +// https://github.com/tarantool/crud/issues/397 +func SkipIfCrudSpliceBroken(t *testing.T) { + t.Helper() + + SkipIfFeatureUnsupported(t, "crud update splice", 2, 0, 0) +} + +// SkipIfIsSyncUnsupported skips test run if Tarantool without +// IS_SYNC support is used. +func SkipIfIsSyncUnsupported(t *testing.T) { + t.Helper() + + SkipIfFeatureUnsupported(t, "is sync", 3, 1, 0) +} + +// IsTcsSupported checks if Tarantool supports centralized storage. +// Tarantool supports centralized storage with Enterprise since 3.3.0 version. +func IsTcsSupported() (bool, error) { + + if isEe, err := IsTarantoolEE(); !isEe || err != nil { + return false, err + } + if isLess, err := IsTarantoolVersionLess(3, 3, 0); isLess || err != nil { + return false, err + } + return true, nil +} + +// SkipIfTCSUnsupported skips test if no centralized storage support. +func SkipIfTcsUnsupported(t testing.TB) { + t.Helper() + + ok, err := IsTcsSupported() + if err != nil { + t.Fatalf("Could not check the Tarantool version: %s", err) + } + if !ok { + t.Skip("not found Tarantool EE 3.3+") + } +} + +// CheckEqualBoxErrors checks equivalence of tarantool.BoxError objects. +// +// Tarantool errors are not comparable by nature: +// +// tarantool> msgpack.decode(mp_error_repr) == msgpack.decode(mp_error_repr) +// --- +// - false +// ... +// +// Tarantool error file and line could differ even between +// different patches. +// +// So we check equivalence of all attributes except for Line and File. +// For Line and File, we check that they are filled with some non-default values +// (lines are counted starting with 1 and empty file path is not expected too). +func CheckEqualBoxErrors(t *testing.T, expected tarantool.BoxError, actual tarantool.BoxError) { + t.Helper() + + require.Equalf(t, expected.Depth(), actual.Depth(), "Error stack depth is the same") + + for { + require.Equal(t, expected.Type, actual.Type) + require.Greater(t, len(expected.File), 0) + require.Greater(t, expected.Line, uint64(0)) + require.Equal(t, expected.Msg, actual.Msg) + require.Equal(t, expected.Errno, actual.Errno) + require.Equal(t, expected.Code, actual.Code) + require.Subset(t, actual.Fields, expected.Fields) + + if expected.Prev != nil { + // Stack depth is the same + expected = *expected.Prev + actual = *actual.Prev + } else { + break + } + } +} + +// Ptr returns a pointer to an existing value. +// +// Example: +// +// func NewInt() int { return 1 } +// var b *int = Ptr(NewInt()) +func Ptr[T any](val T) *T { + return &val +} diff --git a/testdata/requests/begin-with-txn-isolation-is-sync-timeout.msgpack b/testdata/requests/begin-with-txn-isolation-is-sync-timeout.msgpack new file mode 100644 index 000000000..dc04ebcd0 Binary files /dev/null and b/testdata/requests/begin-with-txn-isolation-is-sync-timeout.msgpack differ diff --git a/testdata/requests/begin-with-txn-isolation-is-sync.msgpack b/testdata/requests/begin-with-txn-isolation-is-sync.msgpack new file mode 100644 index 000000000..962d8b233 --- /dev/null +++ b/testdata/requests/begin-with-txn-isolation-is-sync.msgpack @@ -0,0 +1 @@ +�Ya� \ No newline at end of file diff --git a/testdata/requests/begin-with-txn-isolation.msgpack b/testdata/requests/begin-with-txn-isolation.msgpack new file mode 100644 index 000000000..bce57fe8c --- /dev/null +++ b/testdata/requests/begin-with-txn-isolation.msgpack @@ -0,0 +1 @@ +�Y \ No newline at end of file diff --git a/testdata/requests/begin.msgpack b/testdata/requests/begin.msgpack new file mode 100644 index 000000000..5416677bc --- /dev/null +++ b/testdata/requests/begin.msgpack @@ -0,0 +1 @@ +� \ No newline at end of file diff --git a/testdata/requests/call-no-args.msgpack b/testdata/requests/call-no-args.msgpack new file mode 100644 index 000000000..17e7abbb0 --- /dev/null +++ b/testdata/requests/call-no-args.msgpack @@ -0,0 +1 @@ +�"�function.name!� \ No newline at end of file diff --git a/testdata/requests/call-with-args-empty-array.msgpack b/testdata/requests/call-with-args-empty-array.msgpack new file mode 100644 index 000000000..17e7abbb0 --- /dev/null +++ b/testdata/requests/call-with-args-empty-array.msgpack @@ -0,0 +1 @@ +�"�function.name!� \ No newline at end of file diff --git a/testdata/requests/call-with-args-mixed.msgpack b/testdata/requests/call-with-args-mixed.msgpack new file mode 100644 index 000000000..8a10ddcfa Binary files /dev/null and b/testdata/requests/call-with-args-mixed.msgpack differ diff --git a/testdata/requests/call-with-args-nil.msgpack b/testdata/requests/call-with-args-nil.msgpack new file mode 100644 index 000000000..17e7abbb0 --- /dev/null +++ b/testdata/requests/call-with-args-nil.msgpack @@ -0,0 +1 @@ +�"�function.name!� \ No newline at end of file diff --git a/testdata/requests/call-with-args.msgpack b/testdata/requests/call-with-args.msgpack new file mode 100644 index 000000000..15eeb65db --- /dev/null +++ b/testdata/requests/call-with-args.msgpack @@ -0,0 +1 @@ +�"�function.name!� \ No newline at end of file diff --git a/testdata/requests/call16-with-args-nil.msgpack b/testdata/requests/call16-with-args-nil.msgpack new file mode 100644 index 000000000..17e7abbb0 --- /dev/null +++ b/testdata/requests/call16-with-args-nil.msgpack @@ -0,0 +1 @@ +�"�function.name!� \ No newline at end of file diff --git a/testdata/requests/call16-with-args.msgpack b/testdata/requests/call16-with-args.msgpack new file mode 100644 index 000000000..8a10ddcfa Binary files /dev/null and b/testdata/requests/call16-with-args.msgpack differ diff --git a/testdata/requests/call16.msgpack b/testdata/requests/call16.msgpack new file mode 100644 index 000000000..17e7abbb0 --- /dev/null +++ b/testdata/requests/call16.msgpack @@ -0,0 +1 @@ +�"�function.name!� \ No newline at end of file diff --git a/testdata/requests/call17-with-args-nil.msgpack b/testdata/requests/call17-with-args-nil.msgpack new file mode 100644 index 000000000..17e7abbb0 --- /dev/null +++ b/testdata/requests/call17-with-args-nil.msgpack @@ -0,0 +1 @@ +�"�function.name!� \ No newline at end of file diff --git a/testdata/requests/call17-with-args.msgpack b/testdata/requests/call17-with-args.msgpack new file mode 100644 index 000000000..8a10ddcfa Binary files /dev/null and b/testdata/requests/call17-with-args.msgpack differ diff --git a/testdata/requests/call17.msgpack b/testdata/requests/call17.msgpack new file mode 100644 index 000000000..17e7abbb0 --- /dev/null +++ b/testdata/requests/call17.msgpack @@ -0,0 +1 @@ +�"�function.name!� \ No newline at end of file diff --git a/testdata/requests/commit-raw.msgpack b/testdata/requests/commit-raw.msgpack new file mode 100644 index 000000000..5416677bc --- /dev/null +++ b/testdata/requests/commit-raw.msgpack @@ -0,0 +1 @@ +� \ No newline at end of file diff --git a/testdata/requests/commit-with-sync-false.msgpack b/testdata/requests/commit-with-sync-false.msgpack new file mode 100644 index 000000000..1c1f8b153 --- /dev/null +++ b/testdata/requests/commit-with-sync-false.msgpack @@ -0,0 +1 @@ +�a� \ No newline at end of file diff --git a/testdata/requests/commit-with-sync.msgpack b/testdata/requests/commit-with-sync.msgpack new file mode 100644 index 000000000..118311ccd --- /dev/null +++ b/testdata/requests/commit-with-sync.msgpack @@ -0,0 +1 @@ +�a� \ No newline at end of file diff --git a/testdata/requests/delete-raw.msgpack b/testdata/requests/delete-raw.msgpack new file mode 100644 index 000000000..8d001f87c Binary files /dev/null and b/testdata/requests/delete-raw.msgpack differ diff --git a/testdata/requests/delete-sname-iname.msgpack b/testdata/requests/delete-sname-iname.msgpack new file mode 100644 index 000000000..49acbef2f --- /dev/null +++ b/testdata/requests/delete-sname-iname.msgpack @@ -0,0 +1 @@ +�^�table_name_�index_name �{ \ No newline at end of file diff --git a/testdata/requests/delete-sname-inumber.msgpack b/testdata/requests/delete-sname-inumber.msgpack new file mode 100644 index 000000000..9390f6e9b --- /dev/null +++ b/testdata/requests/delete-sname-inumber.msgpack @@ -0,0 +1 @@ +�^�table_name{ �{ \ No newline at end of file diff --git a/testdata/requests/delete-snumber-iname.msgpack b/testdata/requests/delete-snumber-iname.msgpack new file mode 100644 index 000000000..5805ea28e --- /dev/null +++ b/testdata/requests/delete-snumber-iname.msgpack @@ -0,0 +1 @@ +���_�index_name �{ \ No newline at end of file diff --git a/testdata/requests/delete-snumber-inumber.msgpack b/testdata/requests/delete-snumber-inumber.msgpack new file mode 100644 index 000000000..39c85022e --- /dev/null +++ b/testdata/requests/delete-snumber-inumber.msgpack @@ -0,0 +1 @@ +���{ �{ \ No newline at end of file diff --git a/testdata/requests/delete.msgpack b/testdata/requests/delete.msgpack new file mode 100644 index 000000000..8d001f87c Binary files /dev/null and b/testdata/requests/delete.msgpack differ diff --git a/testdata/requests/eval-with-args.msgpack b/testdata/requests/eval-with-args.msgpack new file mode 100644 index 000000000..6ab7b9953 Binary files /dev/null and b/testdata/requests/eval-with-args.msgpack differ diff --git a/testdata/requests/eval-with-empty-array.msgpack b/testdata/requests/eval-with-empty-array.msgpack new file mode 100644 index 000000000..152e20ef1 --- /dev/null +++ b/testdata/requests/eval-with-empty-array.msgpack @@ -0,0 +1 @@ +�'�function_name()!� \ No newline at end of file diff --git a/testdata/requests/eval-with-nil.msgpack b/testdata/requests/eval-with-nil.msgpack new file mode 100644 index 000000000..152e20ef1 --- /dev/null +++ b/testdata/requests/eval-with-nil.msgpack @@ -0,0 +1 @@ +�'�function_name()!� \ No newline at end of file diff --git a/testdata/requests/eval-with-single-number.msgpack b/testdata/requests/eval-with-single-number.msgpack new file mode 100644 index 000000000..0d75cc10e --- /dev/null +++ b/testdata/requests/eval-with-single-number.msgpack @@ -0,0 +1 @@ +�'�function_name()! \ No newline at end of file diff --git a/testdata/requests/eval.msgpack b/testdata/requests/eval.msgpack new file mode 100644 index 000000000..152e20ef1 --- /dev/null +++ b/testdata/requests/eval.msgpack @@ -0,0 +1 @@ +�'�function_name()!� \ No newline at end of file diff --git a/testdata/requests/insert-sname.msgpack b/testdata/requests/insert-sname.msgpack new file mode 100644 index 000000000..6e5f221a7 Binary files /dev/null and b/testdata/requests/insert-sname.msgpack differ diff --git a/testdata/requests/insert-snumber.msgpack b/testdata/requests/insert-snumber.msgpack new file mode 100644 index 000000000..2fd7d3af3 Binary files /dev/null and b/testdata/requests/insert-snumber.msgpack differ diff --git a/testdata/requests/ping.msgpack b/testdata/requests/ping.msgpack new file mode 100644 index 000000000..5416677bc --- /dev/null +++ b/testdata/requests/ping.msgpack @@ -0,0 +1 @@ +� \ No newline at end of file diff --git a/testdata/requests/replace-sname.msgpack b/testdata/requests/replace-sname.msgpack new file mode 100644 index 000000000..6e5f221a7 Binary files /dev/null and b/testdata/requests/replace-sname.msgpack differ diff --git a/testdata/requests/replace-snumber.msgpack b/testdata/requests/replace-snumber.msgpack new file mode 100644 index 000000000..2fd7d3af3 Binary files /dev/null and b/testdata/requests/replace-snumber.msgpack differ diff --git a/testdata/requests/rollback.msgpack b/testdata/requests/rollback.msgpack new file mode 100644 index 000000000..5416677bc --- /dev/null +++ b/testdata/requests/rollback.msgpack @@ -0,0 +1 @@ +� \ No newline at end of file diff --git a/testdata/requests/select b/testdata/requests/select new file mode 100644 index 000000000..ebed86cad Binary files /dev/null and b/testdata/requests/select differ diff --git a/testdata/requests/select-key-sname-iname.msgpack b/testdata/requests/select-key-sname-iname.msgpack new file mode 100644 index 000000000..7b9bbe70f Binary files /dev/null and b/testdata/requests/select-key-sname-iname.msgpack differ diff --git a/testdata/requests/select-key-sname-inumber.msgpack b/testdata/requests/select-key-sname-inumber.msgpack new file mode 100644 index 000000000..e8cdb0bd9 Binary files /dev/null and b/testdata/requests/select-key-sname-inumber.msgpack differ diff --git a/testdata/requests/select-key-snumber-iname.msgpack b/testdata/requests/select-key-snumber-iname.msgpack new file mode 100644 index 000000000..f686f9afc Binary files /dev/null and b/testdata/requests/select-key-snumber-iname.msgpack differ diff --git a/testdata/requests/select-key-snumber-inumber.msgpack b/testdata/requests/select-key-snumber-inumber.msgpack new file mode 100644 index 000000000..29ea666d4 Binary files /dev/null and b/testdata/requests/select-key-snumber-inumber.msgpack differ diff --git a/testdata/requests/select-sname-iname.msgpack b/testdata/requests/select-sname-iname.msgpack new file mode 100644 index 000000000..cfd15c45f Binary files /dev/null and b/testdata/requests/select-sname-iname.msgpack differ diff --git a/testdata/requests/select-sname-inumber.msgpack b/testdata/requests/select-sname-inumber.msgpack new file mode 100644 index 000000000..eee6666fe Binary files /dev/null and b/testdata/requests/select-sname-inumber.msgpack differ diff --git a/testdata/requests/select-snumber-iname.msgpack b/testdata/requests/select-snumber-iname.msgpack new file mode 100644 index 000000000..928e75426 Binary files /dev/null and b/testdata/requests/select-snumber-iname.msgpack differ diff --git a/testdata/requests/select-snumber-inumber.msgpack b/testdata/requests/select-snumber-inumber.msgpack new file mode 100644 index 000000000..8470a2045 Binary files /dev/null and b/testdata/requests/select-snumber-inumber.msgpack differ diff --git a/testdata/requests/select-with-after.msgpack b/testdata/requests/select-with-after.msgpack new file mode 100644 index 000000000..9de267c51 Binary files /dev/null and b/testdata/requests/select-with-after.msgpack differ diff --git a/testdata/requests/select-with-key.msgpack b/testdata/requests/select-with-key.msgpack new file mode 100644 index 000000000..f0976da10 Binary files /dev/null and b/testdata/requests/select-with-key.msgpack differ diff --git a/testdata/requests/select-with-optionals.msgpack b/testdata/requests/select-with-optionals.msgpack new file mode 100644 index 000000000..4915ece97 Binary files /dev/null and b/testdata/requests/select-with-optionals.msgpack differ diff --git a/testdata/requests/update-sname-iname.msgpack b/testdata/requests/update-sname-iname.msgpack new file mode 100644 index 000000000..b085c995b --- /dev/null +++ b/testdata/requests/update-sname-iname.msgpack @@ -0,0 +1,2 @@ +�^�table_name_�index_name �{!���+�test��=�fest��#���!�insert��:�splice��-�subtract��& +��| ��^ \ No newline at end of file diff --git a/testdata/requests/update-sname-inumber.msgpack b/testdata/requests/update-sname-inumber.msgpack new file mode 100644 index 000000000..f603ddf34 --- /dev/null +++ b/testdata/requests/update-sname-inumber.msgpack @@ -0,0 +1,2 @@ +�^�table_name{ �{!���+�test��=�fest��#���!�insert��:�splice��-�subtract��& +��| ��^ \ No newline at end of file diff --git a/testdata/requests/update-snumber-iname.msgpack b/testdata/requests/update-snumber-iname.msgpack new file mode 100644 index 000000000..cd3ec4317 --- /dev/null +++ b/testdata/requests/update-snumber-iname.msgpack @@ -0,0 +1,2 @@ +�{_�index_name �{!���+�test��=�fest��#���!�insert��:�splice��-�subtract��& +��| ��^ \ No newline at end of file diff --git a/testdata/requests/update.msgpack b/testdata/requests/update.msgpack new file mode 100644 index 000000000..3749bbcad Binary files /dev/null and b/testdata/requests/update.msgpack differ diff --git a/testdata/requests/upsert-sname.msgpack b/testdata/requests/upsert-sname.msgpack new file mode 100644 index 000000000..3620cbc1e Binary files /dev/null and b/testdata/requests/upsert-sname.msgpack differ diff --git a/testdata/requests/upsert-snumber.msgpack b/testdata/requests/upsert-snumber.msgpack new file mode 100644 index 000000000..6e34b0c2f Binary files /dev/null and b/testdata/requests/upsert-snumber.msgpack differ diff --git a/testdata/requests/upsert.msgpack b/testdata/requests/upsert.msgpack new file mode 100644 index 000000000..06a3800f8 --- /dev/null +++ b/testdata/requests/upsert.msgpack @@ -0,0 +1 @@ +�^�table_name!�(� \ No newline at end of file diff --git a/testdata/sidecar/main.go b/testdata/sidecar/main.go new file mode 100644 index 000000000..a2c571b87 --- /dev/null +++ b/testdata/sidecar/main.go @@ -0,0 +1,37 @@ +package main + +import ( + "context" + "os" + "strconv" + + "github.com/tarantool/go-tarantool/v3" +) + +func main() { + fd, err := strconv.Atoi(os.Getenv("SOCKET_FD")) + if err != nil { + panic(err) + } + dialer := tarantool.FdDialer{ + Fd: uintptr(fd), + } + conn, err := tarantool.Connect(context.Background(), dialer, tarantool.Opts{}) + if err != nil { + panic(err) + } + if _, err := conn.Do(tarantool.NewPingRequest()).Get(); err != nil { + panic(err) + } + // Insert new tuple. + if _, err := conn.Do(tarantool.NewInsertRequest("test"). + Tuple([]interface{}{239})).Get(); err != nil { + panic(err) + } + // Delete inserted tuple. + if _, err := conn.Do(tarantool.NewDeleteRequest("test"). + Index("primary"). + Key([]interface{}{239})).Get(); err != nil { + panic(err) + } +} diff --git a/uuid/config.lua b/uuid/config.lua new file mode 100644 index 000000000..b8fe1fe08 --- /dev/null +++ b/uuid/config.lua @@ -0,0 +1,36 @@ +local uuid = require('uuid') +local msgpack = require('msgpack') + +-- Do not set listen for now so connector won't be +-- able to send requests until everything is configured. +box.cfg{ + work_dir = os.getenv("TEST_TNT_WORK_DIR"), +} + +box.schema.user.create('test', { password = 'test' , if_not_exists = true }) +box.schema.user.grant('test', 'execute', 'universe', nil, { if_not_exists = true }) + +local uuid_msgpack_supported = pcall(msgpack.encode, uuid.new()) +if not uuid_msgpack_supported then + error('UUID unsupported, use Tarantool 2.4.1 or newer') +end + +local s = box.schema.space.create('testUUID', { + id = 524, + if_not_exists = true, +}) +s:create_index('primary', { + type = 'tree', + parts = {{ field = 1, type = 'uuid' }}, + if_not_exists = true +}) +s:truncate() + +box.schema.user.grant('test', 'read,write', 'space', 'testUUID', { if_not_exists = true }) + +s:insert({ uuid.fromstr("c8f0fa1f-da29-438c-a040-393f1126ad39") }) + +-- Set listen only when every other thing is configured. +box.cfg{ + listen = os.getenv("TEST_TNT_LISTEN"), +} diff --git a/uuid/example_test.go b/uuid/example_test.go new file mode 100644 index 000000000..8673d390c --- /dev/null +++ b/uuid/example_test.go @@ -0,0 +1,57 @@ +// Run Tarantool instance before example execution: +// Terminal 1: +// $ cd uuid +// $ TEST_TNT_LISTEN=3013 TEST_TNT_WORK_DIR=$(mktemp -d -t 'tarantool.XXX') tarantool config.lua +// +// Terminal 2: +// $ cd uuid +// $ go test -v example_test.go +package uuid_test + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/google/uuid" + + "github.com/tarantool/go-tarantool/v3" + _ "github.com/tarantool/go-tarantool/v3/uuid" +) + +var exampleOpts = tarantool.Opts{ + Timeout: 5 * time.Second, +} + +// Example demonstrates how to use tuples with UUID. To enable UUID support +// in msgpack with google/uuid (https://github.com/google/uuid), import +// tarantool/uuid submodule. +func Example() { + dialer := tarantool.NetDialer{ + Address: "127.0.0.1:3013", + User: "test", + Password: "test", + } + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + client, err := tarantool.Connect(ctx, dialer, exampleOpts) + cancel() + if err != nil { + log.Fatalf("Failed to connect: %s", err.Error()) + } + + spaceNo := uint32(524) + + id, uuidErr := uuid.Parse("c8f0fa1f-da29-438c-a040-393f1126ad39") + if uuidErr != nil { + log.Fatalf("Failed to prepare uuid: %s", uuidErr) + } + + data, err := client.Do(tarantool.NewReplaceRequest(spaceNo). + Tuple([]interface{}{id}), + ).Get() + + fmt.Println("UUID tuple replace") + fmt.Println("Error", err) + fmt.Println("Data", data) +} diff --git a/uuid/uuid.go b/uuid/uuid.go new file mode 100644 index 000000000..ca7b0ad05 --- /dev/null +++ b/uuid/uuid.go @@ -0,0 +1,90 @@ +// Package uuid with support of Tarantool's UUID data type. +// +// UUID data type supported in Tarantool since 2.4.1. +// +// Since: 1.6.0. +// +// # See also +// +// - Tarantool commit with UUID support: +// https://github.com/tarantool/tarantool/commit/d68fc29246714eee505bc9bbcd84a02de17972c5 +// +// - Tarantool data model: +// https://www.tarantool.io/en/doc/latest/book/box/data_model/ +// +// - Module UUID: +// https://www.tarantool.io/en/doc/latest/reference/reference_lua/uuid/ +package uuid + +import ( + "fmt" + "reflect" + + "github.com/google/uuid" + "github.com/vmihailenco/msgpack/v5" +) + +// UUID external type. +const uuid_extID = 2 + +//go:generate go tool gentypes -ext-code 2 -marshal-func marshalUUID -unmarshal-func unmarshalUUID -imports "github.com/google/uuid" uuid.UUID + +func marshalUUID(id uuid.UUID) ([]byte, error) { + return id.MarshalBinary() +} + +func unmarshalUUID(uuid *uuid.UUID, data []byte) error { + return uuid.UnmarshalBinary(data) +} + +// encodeUUID encodes a uuid.UUID value into the msgpack format. +func encodeUUID(e *msgpack.Encoder, v reflect.Value) error { + id := v.Interface().(uuid.UUID) + + bytes, err := id.MarshalBinary() + if err != nil { + return fmt.Errorf("msgpack: can't marshal binary uuid: %w", err) + } + + _, err = e.Writer().Write(bytes) + if err != nil { + return fmt.Errorf("msgpack: can't write bytes to msgpack.Encoder writer: %w", err) + } + + return nil +} + +// decodeUUID decodes a uuid.UUID value from the msgpack format. +func decodeUUID(d *msgpack.Decoder, v reflect.Value) error { + var bytesCount = 16 + bytes := make([]byte, bytesCount) + + n, err := d.Buffered().Read(bytes) + if err != nil { + return fmt.Errorf("msgpack: can't read bytes on uuid decode: %w", err) + } + if n < bytesCount { + return fmt.Errorf("msgpack: unexpected end of stream after %d uuid bytes", n) + } + + id, err := uuid.FromBytes(bytes) + if err != nil { + return fmt.Errorf("msgpack: can't create uuid from bytes: %w", err) + } + + v.Set(reflect.ValueOf(id)) + return nil +} + +func init() { + msgpack.Register(reflect.TypeOf((*uuid.UUID)(nil)).Elem(), encodeUUID, decodeUUID) + msgpack.RegisterExtEncoder(uuid_extID, uuid.UUID{}, + func(e *msgpack.Encoder, v reflect.Value) ([]byte, error) { + uuid := v.Interface().(uuid.UUID) + return uuid.MarshalBinary() + }) + msgpack.RegisterExtDecoder(uuid_extID, uuid.UUID{}, + func(d *msgpack.Decoder, v reflect.Value, extLen int) error { + return decodeUUID(d, v) + }) +} diff --git a/uuid/uuid_gen.go b/uuid/uuid_gen.go new file mode 100644 index 000000000..f1b1992ce --- /dev/null +++ b/uuid/uuid_gen.go @@ -0,0 +1,243 @@ +// Code generated by github.com/tarantool/go-option; DO NOT EDIT. + +package uuid + +import ( + "github.com/google/uuid" + + "fmt" + + "github.com/vmihailenco/msgpack/v5" + "github.com/vmihailenco/msgpack/v5/msgpcode" + + "github.com/tarantool/go-option" +) + +// OptionalUUID represents an optional value of type uuid.UUID. +// It can either hold a valid uuid.UUID (IsSome == true) or be empty (IsZero == true). +type OptionalUUID struct { + value uuid.UUID + exists bool +} + +// SomeOptionalUUID creates an optional OptionalUUID with the given uuid.UUID value. +// The returned OptionalUUID will have IsSome() == true and IsZero() == false. +func SomeOptionalUUID(value uuid.UUID) OptionalUUID { + return OptionalUUID{ + value: value, + exists: true, + } +} + +// NoneOptionalUUID creates an empty optional OptionalUUID value. +// The returned OptionalUUID will have IsSome() == false and IsZero() == true. +// +// Example: +// +// o := NoneOptionalUUID() +// if o.IsZero() { +// fmt.Println("value is absent") +// } +func NoneOptionalUUID() OptionalUUID { + return OptionalUUID{} +} + +func (o OptionalUUID) newEncodeError(err error) error { + if err == nil { + return nil + } + return &option.EncodeError{ + Type: "OptionalUUID", + Parent: err, + } +} + +func (o OptionalUUID) newDecodeError(err error) error { + if err == nil { + return nil + } + + return &option.DecodeError{ + Type: "OptionalUUID", + Parent: err, + } +} + +// IsSome returns true if the OptionalUUID contains a value. +// This indicates the value is explicitly set (not None). +func (o OptionalUUID) IsSome() bool { + return o.exists +} + +// IsZero returns true if the OptionalUUID does not contain a value. +// Equivalent to !IsSome(). Useful for consistency with types where +// zero value (e.g. 0, false, zero struct) is valid and needs to be distinguished. +func (o OptionalUUID) IsZero() bool { + return !o.exists +} + +// IsNil is an alias for IsZero. +// +// This method is provided for compatibility with the msgpack Encoder interface. +func (o OptionalUUID) IsNil() bool { + return o.IsZero() +} + +// Get returns the stored value and a boolean flag indicating its presence. +// If the value is present, returns (value, true). +// If the value is absent, returns (zero value of uuid.UUID, false). +// +// Recommended usage: +// +// if value, ok := o.Get(); ok { +// // use value +// } +func (o OptionalUUID) Get() (uuid.UUID, bool) { + return o.value, o.exists +} + +// MustGet returns the stored value if it is present. +// Panics if the value is absent (i.e., IsZero() == true). +// +// Use with caution — only when you are certain the value exists. +// +// Panics with: "optional value is not set" if no value is set. +func (o OptionalUUID) MustGet() uuid.UUID { + if !o.exists { + panic("optional value is not set") + } + + return o.value +} + +// Unwrap returns the stored value regardless of presence. +// If no value is set, returns the zero value for uuid.UUID. +// +// Warning: Does not check presence. Use IsSome() before calling if you need +// to distinguish between absent value and explicit zero value. +func (o OptionalUUID) Unwrap() uuid.UUID { + return o.value +} + +// UnwrapOr returns the stored value if present. +// Otherwise, returns the provided default value. +// +// Example: +// +// o := NoneOptionalUUID() +// v := o.UnwrapOr(someDefaultOptionalUUID) +func (o OptionalUUID) UnwrapOr(defaultValue uuid.UUID) uuid.UUID { + if o.exists { + return o.value + } + + return defaultValue +} + +// UnwrapOrElse returns the stored value if present. +// Otherwise, calls the provided function and returns its result. +// Useful when the default value requires computation or side effects. +// +// Example: +// +// o := NoneOptionalUUID() +// v := o.UnwrapOrElse(func() uuid.UUID { return computeDefault() }) +func (o OptionalUUID) UnwrapOrElse(defaultValue func() uuid.UUID) uuid.UUID { + if o.exists { + return o.value + } + + return defaultValue() +} + +func (o OptionalUUID) encodeValue(encoder *msgpack.Encoder) error { + value, err := marshalUUID(o.value) + if err != nil { + return err + } + + err = encoder.EncodeExtHeader(2, len(value)) + if err != nil { + return err + } + + _, err = encoder.Writer().Write(value) + if err != nil { + return err + } + + return nil +} + +// EncodeMsgpack encodes the OptionalUUID value using MessagePack format. +// - If the value is present, it is encoded as uuid.UUID. +// - If the value is absent (None), it is encoded as nil. +// +// Returns an error if encoding fails. +func (o OptionalUUID) EncodeMsgpack(encoder *msgpack.Encoder) error { + if o.exists { + return o.newEncodeError(o.encodeValue(encoder)) + } + + return o.newEncodeError(encoder.EncodeNil()) +} + +func (o *OptionalUUID) decodeValue(decoder *msgpack.Decoder) error { + tp, length, err := decoder.DecodeExtHeader() + switch { + case err != nil: + return o.newDecodeError(err) + case tp != 2: + return o.newDecodeError(fmt.Errorf("invalid extension code: %d", tp)) + } + + a := make([]byte, length) + if err := decoder.ReadFull(a); err != nil { + return o.newDecodeError(err) + } + + if err := unmarshalUUID(&o.value, a); err != nil { + return o.newDecodeError(err) + } + + o.exists = true + return nil +} + +func (o *OptionalUUID) checkCode(code byte) bool { + return msgpcode.IsExt(code) +} + +// DecodeMsgpack decodes a OptionalUUID value from MessagePack format. +// Supports two input types: +// - nil: interpreted as no value (NoneOptionalUUID) +// - uuid.UUID: interpreted as a present value (SomeOptionalUUID) +// +// Returns an error if the input type is unsupported or decoding fails. +// +// After successful decoding: +// - on nil: exists = false, value = default zero value +// - on uuid.UUID: exists = true, value = decoded value +func (o *OptionalUUID) DecodeMsgpack(decoder *msgpack.Decoder) error { + code, err := decoder.PeekCode() + if err != nil { + return o.newDecodeError(err) + } + + switch { + case code == msgpcode.Nil: + o.exists = false + + return o.newDecodeError(decoder.Skip()) + case o.checkCode(code): + err := o.decodeValue(decoder) + if err != nil { + return o.newDecodeError(err) + } + o.exists = true + + return err + default: + return o.newDecodeError(fmt.Errorf("unexpected code: %d", code)) + } +} diff --git a/uuid/uuid_gen_test.go b/uuid/uuid_gen_test.go new file mode 100644 index 000000000..616bb2314 --- /dev/null +++ b/uuid/uuid_gen_test.go @@ -0,0 +1,117 @@ +package uuid + +import ( + "bytes" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/vmihailenco/msgpack/v5" +) + +func TestSomeOptionalUUID(t *testing.T) { + val := uuid.New() + opt := SomeOptionalUUID(val) + + assert.True(t, opt.IsSome()) + assert.False(t, opt.IsZero()) + + v, ok := opt.Get() + assert.True(t, ok) + assert.Equal(t, val, v) +} + +func TestNoneOptionalUUID(t *testing.T) { + opt := NoneOptionalUUID() + + assert.False(t, opt.IsSome()) + assert.True(t, opt.IsZero()) + + _, ok := opt.Get() + assert.False(t, ok) +} + +func TestOptionalUUID_MustGet(t *testing.T) { + val := uuid.New() + optSome := SomeOptionalUUID(val) + optNone := NoneOptionalUUID() + + assert.Equal(t, val, optSome.MustGet()) + assert.Panics(t, func() { optNone.MustGet() }) +} + +func TestOptionalUUID_Unwrap(t *testing.T) { + val := uuid.New() + optSome := SomeOptionalUUID(val) + optNone := NoneOptionalUUID() + + assert.Equal(t, val, optSome.Unwrap()) + assert.Equal(t, uuid.Nil, optNone.Unwrap()) +} + +func TestOptionalUUID_UnwrapOr(t *testing.T) { + val := uuid.New() + def := uuid.New() + optSome := SomeOptionalUUID(val) + optNone := NoneOptionalUUID() + + assert.Equal(t, val, optSome.UnwrapOr(def)) + assert.Equal(t, def, optNone.UnwrapOr(def)) +} + +func TestOptionalUUID_UnwrapOrElse(t *testing.T) { + val := uuid.New() + def := uuid.New() + optSome := SomeOptionalUUID(val) + optNone := NoneOptionalUUID() + + assert.Equal(t, val, optSome.UnwrapOrElse(func() uuid.UUID { return def })) + assert.Equal(t, def, optNone.UnwrapOrElse(func() uuid.UUID { return def })) +} + +func TestOptionalUUID_EncodeDecodeMsgpack_Some(t *testing.T) { + val := uuid.New() + some := SomeOptionalUUID(val) + + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err := enc.Encode(some) + assert.NoError(t, err) + + var decodedSome OptionalUUID + err = dec.Decode(&decodedSome) + assert.NoError(t, err) + assert.True(t, decodedSome.IsSome()) + assert.Equal(t, val, decodedSome.Unwrap()) +} + +func TestOptionalUUID_EncodeDecodeMsgpack_None(t *testing.T) { + none := NoneOptionalUUID() + + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err := enc.Encode(none) + assert.NoError(t, err) + + var decodedNone OptionalUUID + err = dec.Decode(&decodedNone) + assert.NoError(t, err) + assert.True(t, decodedNone.IsZero()) +} + +func TestOptionalUUID_EncodeDecodeMsgpack_InvalidType(t *testing.T) { + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + dec := msgpack.NewDecoder(&buf) + + err := enc.Encode(123) + assert.NoError(t, err) + + var decodedInvalid OptionalUUID + err = dec.Decode(&decodedInvalid) + assert.Error(t, err) +} diff --git a/uuid/uuid_test.go b/uuid/uuid_test.go new file mode 100644 index 000000000..22ffd7eb5 --- /dev/null +++ b/uuid/uuid_test.go @@ -0,0 +1,183 @@ +package uuid_test + +import ( + "fmt" + "log" + "os" + "testing" + "time" + + "github.com/google/uuid" + "github.com/vmihailenco/msgpack/v5" + + . "github.com/tarantool/go-tarantool/v3" + "github.com/tarantool/go-tarantool/v3/test_helpers" + _ "github.com/tarantool/go-tarantool/v3/uuid" +) + +// There is no way to skip tests in testing.M, +// so we use this variable to pass info +// to each testing.T that it should skip. +var isUUIDSupported = false + +var server = "127.0.0.1:3013" +var opts = Opts{ + Timeout: 5 * time.Second, +} +var dialer = NetDialer{ + Address: server, + User: "test", + Password: "test", +} + +var space = "testUUID" +var index = "primary" + +type TupleUUID struct { + id uuid.UUID +} + +func (t *TupleUUID) DecodeMsgpack(d *msgpack.Decoder) error { + var err error + var l int + if l, err = d.DecodeArrayLen(); err != nil { + return err + } + if l != 1 { + return fmt.Errorf("array len doesn't match: %d", l) + } + res, err := d.DecodeInterface() + if err != nil { + return err + } + t.id = res.(uuid.UUID) + return nil +} + +func tupleValueIsId(t *testing.T, tuples []interface{}, id uuid.UUID) { + if len(tuples) != 1 { + t.Fatalf("Response Data len != 1") + } + + if tpl, ok := tuples[0].([]interface{}); !ok { + t.Errorf("Unexpected return value body") + } else { + if len(tpl) != 1 { + t.Errorf("Unexpected return value body (tuple len)") + } + if val, ok := tpl[0].(uuid.UUID); !ok || val != id { + t.Errorf("Unexpected return value body (tuple 0 field)") + } + } +} + +func TestSelect(t *testing.T) { + if isUUIDSupported == false { + t.Skip("Skipping test for Tarantool without UUID support in msgpack") + } + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + id, uuidErr := uuid.Parse("c8f0fa1f-da29-438c-a040-393f1126ad39") + if uuidErr != nil { + t.Fatalf("Failed to prepare test uuid: %s", uuidErr) + } + + sel := NewSelectRequest(space). + Index(index). + Limit(1). + Iterator(IterEq). + Key([]interface{}{id}) + data, errSel := conn.Do(sel).Get() + if errSel != nil { + t.Fatalf("UUID select failed: %s", errSel.Error()) + } + tupleValueIsId(t, data, id) + + var tuples []TupleUUID + errTyp := conn.Do(sel).GetTyped(&tuples) + if errTyp != nil { + t.Fatalf("Failed to SelectTyped: %s", errTyp.Error()) + } + if len(tuples) != 1 { + t.Errorf("Result len of SelectTyped != 1") + } + if tuples[0].id != id { + t.Errorf("Bad value loaded from SelectTyped: %s", tuples[0].id) + } +} + +func TestReplace(t *testing.T) { + if isUUIDSupported == false { + t.Skip("Skipping test for Tarantool without UUID support in msgpack") + } + + conn := test_helpers.ConnectWithValidation(t, dialer, opts) + defer conn.Close() + + id, uuidErr := uuid.Parse("64d22e4d-ac92-4a23-899a-e59f34af5479") + if uuidErr != nil { + t.Errorf("Failed to prepare test uuid: %s", uuidErr) + } + + rep := NewReplaceRequest(space).Tuple([]interface{}{id}) + dataRep, errRep := conn.Do(rep).Get() + if errRep != nil { + t.Errorf("UUID replace failed: %s", errRep) + } + tupleValueIsId(t, dataRep, id) + + sel := NewSelectRequest(space). + Index(index). + Limit(1). + Iterator(IterEq). + Key([]interface{}{id}) + dataSel, errSel := conn.Do(sel).Get() + if errSel != nil { + t.Errorf("UUID select failed: %s", errSel) + } + tupleValueIsId(t, dataSel, id) +} + +// runTestMain is a body of TestMain function +// (see https://pkg.go.dev/testing#hdr-Main). +// Using defer + os.Exit is not works so TestMain body +// is a separate function, see +// https://stackoverflow.com/questions/27629380/how-to-exit-a-go-program-honoring-deferred-calls +func runTestMain(m *testing.M) int { + isLess, err := test_helpers.IsTarantoolVersionLess(2, 4, 1) + if err != nil { + log.Fatalf("Failed to extract tarantool version: %s", err) + } + + if isLess { + log.Println("Skipping UUID tests...") + isUUIDSupported = false + return m.Run() + } else { + isUUIDSupported = true + } + + inst, err := test_helpers.StartTarantool(test_helpers.StartOpts{ + Dialer: dialer, + InitScript: "config.lua", + Listen: server, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 10, + RetryTimeout: 500 * time.Millisecond, + }) + defer test_helpers.StopTarantoolWithCleanup(inst) + + if err != nil { + log.Printf("Failed to prepare test tarantool: %s", err) + return 1 + } + + return m.Run() +} + +func TestMain(m *testing.M) { + code := runTestMain(m) + os.Exit(code) +} diff --git a/watch.go b/watch.go new file mode 100644 index 000000000..9628b96ac --- /dev/null +++ b/watch.go @@ -0,0 +1,151 @@ +package tarantool + +import ( + "context" + "io" + + "github.com/tarantool/go-iproto" + "github.com/vmihailenco/msgpack/v5" +) + +// BroadcastRequest helps to send broadcast messages. See: +// https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_events/broadcast/ +type BroadcastRequest struct { + call *CallRequest + key string +} + +// NewBroadcastRequest returns a new broadcast request for a specified key. +func NewBroadcastRequest(key string) *BroadcastRequest { + req := new(BroadcastRequest) + req.key = key + req.call = NewCallRequest("box.broadcast").Args([]interface{}{key}) + return req +} + +// Value sets the value for the broadcast request. +// Note: default value is nil. +func (req *BroadcastRequest) Value(value interface{}) *BroadcastRequest { + req.call = req.call.Args([]interface{}{req.key, value}) + return req +} + +// Context sets a passed context to the broadcast request. +func (req *BroadcastRequest) Context(ctx context.Context) *BroadcastRequest { + req.call = req.call.Context(ctx) + return req +} + +// Code returns IPROTO code for the broadcast request. +func (req *BroadcastRequest) Type() iproto.Type { + return req.call.Type() +} + +// Body fills an msgpack.Encoder with the broadcast request body. +func (req *BroadcastRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { + return req.call.Body(res, enc) +} + +// Ctx returns a context of the broadcast request. +func (req *BroadcastRequest) Ctx() context.Context { + return req.call.Ctx() +} + +// Async returns is the broadcast request expects a response. +func (req *BroadcastRequest) Async() bool { + return req.call.Async() +} + +// Response creates a response for a BroadcastRequest. +func (req *BroadcastRequest) Response(header Header, body io.Reader) (Response, error) { + return DecodeBaseResponse(header, body) +} + +// watchRequest subscribes to the updates of a specified key defined on the +// server. After receiving the notification, you should send a new +// watchRequest to acknowledge the notification. +type watchRequest struct { + baseRequest + key string + ctx context.Context +} + +// newWatchRequest returns a new watchRequest. +func newWatchRequest(key string) *watchRequest { + req := new(watchRequest) + req.rtype = iproto.IPROTO_WATCH + req.async = true + req.key = key + return req +} + +// Body fills an msgpack.Encoder with the watch request body. +func (req *watchRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { + if err := enc.EncodeMapLen(1); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_EVENT_KEY)); err != nil { + return err + } + + return enc.EncodeString(req.key) +} + +// Context sets a passed context to the request. +func (req *watchRequest) Context(ctx context.Context) *watchRequest { + req.ctx = ctx + return req +} + +// unwatchRequest unregisters a watcher subscribed to the given notification +// key. +type unwatchRequest struct { + baseRequest + key string + ctx context.Context +} + +// newUnwatchRequest returns a new unwatchRequest. +func newUnwatchRequest(key string) *unwatchRequest { + req := new(unwatchRequest) + req.rtype = iproto.IPROTO_UNWATCH + req.async = true + req.key = key + return req +} + +// Body fills an msgpack.Encoder with the unwatch request body. +func (req *unwatchRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { + if err := enc.EncodeMapLen(1); err != nil { + return err + } + + if err := enc.EncodeUint(uint64(iproto.IPROTO_EVENT_KEY)); err != nil { + return err + } + + return enc.EncodeString(req.key) +} + +// Context sets a passed context to the request. +func (req *unwatchRequest) Context(ctx context.Context) *unwatchRequest { + req.ctx = ctx + return req +} + +// WatchEvent is a watch notification event received from a server. +type WatchEvent struct { + Conn *Connection // A source connection. + Key string // A key. + Value interface{} // A value. +} + +// Watcher is a subscription to broadcast events. +type Watcher interface { + // Unregister unregisters the watcher. + Unregister() +} + +// WatchCallback is a callback to invoke when the key value is updated. +type WatchCallback func(event WatchEvent)