diff --git a/.asf.yaml b/.asf.yaml
index dd4975435cf0..36f01b88a724 100644
--- a/.asf.yaml
+++ b/.asf.yaml
@@ -46,6 +46,9 @@ github:
         strict: true
         # don't require any jobs to pass
         contexts: []
+  pull_requests:
+    # enable updating head branches of pull requests
+    allow_update_branch: true
 
 # publishes the content of the `asf-site` branch to
 # https://arrow.apache.org/rust/
diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml
index 20da777ec0e5..f73f7abf9b82 100644
--- a/.github/actions/setup-builder/action.yaml
+++ b/.github/actions/setup-builder/action.yaml
@@ -17,15 +17,6 @@
 name: Prepare Rust Builder
 description: 'Prepare Rust Build Environment'
 
-inputs:
-  rust-version:
-    description: 'version of rust to install (e.g. stable)'
-    required: false
-    default: 'stable'
-  target:
-    description: 'target architecture(s)'
-    required: false
-    default: 'x86_64-unknown-linux-gnu'
 runs:
   using: "composite"
   steps:
@@ -43,6 +34,9 @@ runs:
           /usr/local/cargo/git/db/
         key: cargo-cache3-${{ hashFiles('**/Cargo.toml') }}
         restore-keys: cargo-cache3-
+    - name: Setup Rust toolchain
+      shell: bash
+      run: rustup install
     - name: Generate lockfile
       shell: bash
       run: cargo fetch
@@ -51,12 +45,6 @@ runs:
       run: |
         apt-get update
         apt-get install -y protobuf-compiler
-    - name: Setup Rust toolchain
-      shell: bash
-      run: |
-        echo "Installing ${{ inputs.rust-version }}"
-        rustup toolchain install ${{ inputs.rust-version }} --target ${{ inputs.target }}
-        rustup default ${{ inputs.rust-version }}
     - name: Disable debuginfo generation
       # Disable full debug symbol generation to speed up CI build and keep memory down
       # "1" means line tables only, which is useful for panic tracebacks.
diff --git a/.github/workflows/arrow.yml b/.github/workflows/arrow.yml
index 9d2d7761725b..9b8147326186 100644
--- a/.github/workflows/arrow.yml
+++ b/.github/workflows/arrow.yml
@@ -56,7 +56,7 @@ jobs:
     container:
       image: amd64/rust
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
         with:
           submodules: true
       - name: Setup Rust toolchain
@@ -115,7 +115,7 @@ jobs:
     container:
       image: amd64/rust
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
         with:
           submodules: true
       - name: Setup Rust toolchain
@@ -143,13 +143,15 @@ jobs:
     container:
       image: amd64/rust
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
         with:
           submodules: true
       - name: Setup Rust toolchain
         uses: ./.github/actions/setup-builder
-        with:
-          target: wasm32-unknown-unknown,wasm32-wasip1
+      - name: Install wasm32 targets
+        run: |
+          rustup target add wasm32-unknown-unknown
+          rustup target add wasm32-wasip1
       - name: Build wasm32-unknown-unknown
         run: cargo build -p arrow --no-default-features --features=json,csv,ipc,ffi --target wasm32-unknown-unknown
       - name: Build wasm32-wasip1
@@ -161,7 +163,7 @@ jobs:
     container:
       image: amd64/rust
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
       - name: Setup Rust toolchain
         uses: ./.github/actions/setup-builder
       - name: Setup Clippy
diff --git a/.github/workflows/arrow_flight.yml b/.github/workflows/arrow_flight.yml
index a76d721b4948..e6aba901aa22 100644
--- a/.github/workflows/arrow_flight.yml
+++ b/.github/workflows/arrow_flight.yml
@@ -47,7 +47,7 @@ jobs:
     container:
       image: amd64/rust
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
         with:
           submodules: true
       - name: Setup Rust toolchain
@@ -68,7 +68,7 @@ jobs:
     container:
       image: amd64/rust
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
       - name: Setup Rust toolchain
         uses: ./.github/actions/setup-builder
       - name: Run gen
@@ -82,7 +82,7 @@ jobs:
     container:
       image: amd64/rust
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
       - name: Setup Rust toolchain
         uses: ./.github/actions/setup-builder
       - name: Setup Clippy
diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml
index e6254ea24a58..a5646ea508aa 100644
--- a/.github/workflows/audit.yml
+++ b/.github/workflows/audit.yml
@@ -36,7 +36,7 @@ jobs:
     name: Audit
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
       - name: Install cargo-audit
         run: cargo install cargo-audit
       - name: Run audit check
diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml
index b28e8c20cfe7..32a582af04de 100644
--- a/.github/workflows/dev.yml
+++ b/.github/workflows/dev.yml
@@ -38,9 +38,9 @@ jobs:
     name: Release Audit Tool (RAT)
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
       - name: Setup Python
-        uses: actions/setup-python@v5
+        uses: actions/setup-python@v6
         with:
           python-version: 3.8
       - name: Audit licenses
@@ -50,8 +50,8 @@ jobs:
     name: Markdown format
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-node@v4
+      - uses: actions/checkout@v5
+      - uses: actions/setup-node@v5
         with:
           node-version: "14"
       - name: Prettier check
diff --git a/.github/workflows/dev_pr.yml b/.github/workflows/dev_pr.yml
index 0d60ae006796..4d81716395b3 100644
--- a/.github/workflows/dev_pr.yml
+++ b/.github/workflows/dev_pr.yml
@@ -37,14 +37,14 @@ jobs:
       contents: read
       pull-requests: write
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
 
       - name: Assign GitHub labels
         if: |
           github.event_name == 'pull_request_target' &&
           (github.event.action == 'opened' ||
            github.event.action == 'synchronize')
-        uses: actions/labeler@v5.0.0
+        uses: actions/labeler@v6.0.1
         with:
           repo-token: ${{ secrets.GITHUB_TOKEN }}
           configuration-path: .github/workflows/dev_pr/labeler.yml
diff --git a/.github/workflows/dev_pr/labeler.yml b/.github/workflows/dev_pr/labeler.yml
index 64299bd507d3..edb6d036174c 100644
--- a/.github/workflows/dev_pr/labeler.yml
+++ b/.github/workflows/dev_pr/labeler.yml
@@ -37,6 +37,11 @@ arrow:
           - 'arrow-string/**/*'
           - 'arrow/**/*'
 
+arrow-avro:
+  - changed-files:
+      - any-glob-to-any-file:
+          - 'arrow-avro/**/*'
+
 arrow-flight:
   - changed-files:
       - any-glob-to-any-file:
@@ -46,7 +51,13 @@ parquet:
   - changed-files:
       - any-glob-to-any-file:
           - 'parquet/**/*'
-          - 'parquet-variant/**/*'
+
+parquet-variant:
+  - changed-files:
+      - any-glob-to-any-file:
+          - 'parquet-variant/**/*'
+          - 'parquet-variant-compute/**/*'
+          - 'parquet-variant-json/**/*'
 
 parquet-derive:
   - changed-files:
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 354a77b76634..4eaf62d95de2 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -34,28 +34,20 @@ jobs:
   docs:
     name: Rustdocs are clean
     runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        arch: [ amd64 ]
-        rust: [ nightly ]
     container:
-      image: ${{ matrix.arch }}/rust
+      image: amd64/rust
     env:
       RUSTDOCFLAGS: "-Dwarnings --enable-index-page -Zunstable-options"
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
         with:
           submodules: true
-      - name: Install python dev
-        run: |
-          apt update
-          apt install -y libpython3.11-dev
       - name: Setup Rust toolchain
         uses: ./.github/actions/setup-builder
-        with:
-          rust-version: ${{ matrix.rust }}
+      - name: Install Nightly Rust
+        run: rustup install nightly
       - name: Run cargo doc
-        run: cargo doc --document-private-items
--no-deps --workspace --all-features + run: cargo +nightly doc --document-private-items --no-deps --workspace --all-features - name: Fix file permissions shell: sh run: | @@ -64,7 +56,7 @@ jobs: echo "::warning title=Invalid file permissions automatically fixed::$line" done - name: Upload artifacts - uses: actions/upload-pages-artifact@v3 + uses: actions/upload-pages-artifact@v4 with: name: crate-docs path: target/doc @@ -77,7 +69,7 @@ jobs: contents: write runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Download crate docs uses: actions/download-artifact@v5 with: diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 09711719296c..923da88eb580 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -63,6 +63,7 @@ jobs: ARROW_INTEGRATION_CPP: ON ARROW_INTEGRATION_CSHARP: ON ARCHERY_INTEGRATION_TARGET_IMPLEMENTATIONS: "rust" + ARCHERY_INTEGRATION_WITH_DOTNET: "1" ARCHERY_INTEGRATION_WITH_GO: "1" ARCHERY_INTEGRATION_WITH_JAVA: "1" ARCHERY_INTEGRATION_WITH_JS: "1" @@ -88,33 +89,38 @@ jobs: - name: Check cmake run: which cmake - name: Checkout Arrow - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: repository: apache/arrow submodules: true fetch-depth: 0 - name: Checkout Arrow Rust - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: path: rust fetch-depth: 0 + - name: Checkout Arrow .NET + uses: actions/checkout@v5 + with: + repository: apache/arrow-dotnet + path: dotnet - name: Checkout Arrow Go - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: repository: apache/arrow-go path: go - name: Checkout Arrow Java - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: repository: apache/arrow-java path: java - name: Checkout Arrow JavaScript - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: repository: apache/arrow-js path: js - name: Checkout Arrow nanoarrow - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: repository: apache/arrow-nanoarrow path: nanoarrow @@ -133,7 +139,7 @@ jobs: # PyArrow 15 was the first version to introduce StringView/BinaryView support pyarrow: ["15", "16", "17"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: submodules: true - name: Setup Rust toolchain @@ -152,7 +158,7 @@ jobs: path: /home/runner/target # this key is not equal because maturin uses different compilation flags. 
key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ matrix.rust }}- - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 with: python-version: '3.8' - name: Upgrade pip and setuptools diff --git a/.github/workflows/miri.yaml b/.github/workflows/miri.yaml index ce67546a104b..92c432dc893b 100644 --- a/.github/workflows/miri.yaml +++ b/.github/workflows/miri.yaml @@ -47,7 +47,7 @@ jobs: name: MIRI runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: submodules: true - name: Setup Rust toolchain diff --git a/.github/workflows/parquet-variant.yml b/.github/workflows/parquet-variant.yml index 9e4003f3645f..26cd73ea24e5 100644 --- a/.github/workflows/parquet-variant.yml +++ b/.github/workflows/parquet-variant.yml @@ -43,7 +43,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: submodules: true - name: Setup Rust toolchain @@ -62,7 +62,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: submodules: true - name: Setup Rust toolchain @@ -80,7 +80,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Setup Clippy diff --git a/.github/workflows/parquet.yml b/.github/workflows/parquet.yml index 946aef75db19..09fc18e351d9 100644 --- a/.github/workflows/parquet.yml +++ b/.github/workflows/parquet.yml @@ -42,6 +42,9 @@ on: - arrow-json/** - arrow-avro/** - parquet/** + - parquet-variant/** + - parquet-variant-compute/** + - parquet-variant-json/** - .github/** jobs: @@ -52,7 +55,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: submodules: true - name: Setup Rust toolchain @@ -75,7 +78,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: submodules: true - name: Setup Rust toolchain @@ -119,7 +122,9 @@ jobs: run: cargo check -p parquet --no-default-features --features flate2 --features flate2-rust_backened - name: Check compilation --no-default-features --features flate2 --features flate2-zlib-rs run: cargo check -p parquet --no-default-features --features flate2 --features flate2-zlib-rs - + - name: Check compilation --no-default-features --features variant_experimental + run: cargo check -p parquet --no-default-features --features variant_experimental + # test the parquet crate builds against wasm32 in stable rust wasm32-build: @@ -128,13 +133,15 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - target: wasm32-unknown-unknown,wasm32-wasip1 + - name: Install wasm32 targets + run: | + rustup target add wasm32-unknown-unknown + rustup target add wasm32-wasip1 - name: Install clang # Needed for zlib compilation run: apt-get update && apt-get install -y clang gcc-multilib - name: Build wasm32-unknown-unknown @@ -149,9 +156,9 @@ jobs: matrix: rust: [ stable ] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.10" cache: "pip" @@ -182,7 +189,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Setup Rust toolchain uses: 
./.github/actions/setup-builder - name: Setup Clippy diff --git a/.github/workflows/parquet_derive.yml b/.github/workflows/parquet_derive.yml index 17aec724a820..98c3168cc1be 100644 --- a/.github/workflows/parquet_derive.yml +++ b/.github/workflows/parquet_derive.yml @@ -43,7 +43,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: submodules: true - name: Setup Rust toolchain @@ -57,7 +57,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Setup Clippy diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8f87c50649d3..c3295d78d48b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -33,7 +33,7 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 5 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Create GitHub Releases run: | version=${GITHUB_REF_NAME} diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 38cccdec3c70..9cd33b296da1 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -36,7 +36,7 @@ jobs: name: Test on Mac runs-on: macos-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: submodules: true - name: Install protoc with brew @@ -59,7 +59,7 @@ jobs: name: Test on Windows runs-on: windows-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: submodules: true - name: Install protobuf compiler in /d/protoc @@ -91,7 +91,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Setup rustfmt @@ -113,11 +113,12 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - - name: Install cargo-msrv - run: cargo install cargo-msrv + - name: Install cargo-msrv (if needed) + # cargo-msrv binary may be cached by the cargo cache step in setup-builder, and cargo install will error if it is already installed + run: if which cargo-msrv ; then echo "using existing cargo-msrv binary" ; else cargo install cargo-msrv ; fi - name: Check all packages run: | # run `cargo msrv verify --manifest-path "path/to/Cargo.toml"` to see problematic dependencies diff --git a/.github/workflows/take.yml b/.github/workflows/take.yml index dd21c794960e..94a95f6e31a2 100644 --- a/.github/workflows/take.yml +++ b/.github/workflows/take.yml @@ -28,7 +28,7 @@ jobs: if: (!github.event.issue.pull_request) && github.event.comment.body == 'take' runs-on: ubuntu-latest steps: - - uses: actions/github-script@v7 + - uses: actions/github-script@v8 with: script: | github.rest.issues.addAssignees({ diff --git a/CHANGELOG-old.md b/CHANGELOG-old.md index 5e9e568115c7..e69e2fd596f0 100644 --- a/CHANGELOG-old.md +++ b/CHANGELOG-old.md @@ -19,6 +19,281 @@ # Historical Changelog +## [56.0.0](https://github.com/apache/arrow-rs/tree/56.0.0) (2025-07-29) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/55.2.0...56.0.0) + +**Breaking changes:** + +- arrow-schema: Remove dict\_id from being required equal for merging [\#7968](https://github.com/apache/arrow-rs/pull/7968) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- \[Parquet\] Use `u64` for `SerializedPageReaderState.offset` & 
`remaining_bytes`, instead of `usize` [\#7918](https://github.com/apache/arrow-rs/pull/7918) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([JigaoLuo](https://github.com/JigaoLuo)) +- Upgrade tonic dependencies to 0.13.0 version \(try 2\) [\#7839](https://github.com/apache/arrow-rs/pull/7839) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Remove deprecated Arrow functions [\#7830](https://github.com/apache/arrow-rs/pull/7830) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([etseidl](https://github.com/etseidl)) +- Remove deprecated temporal functions [\#7813](https://github.com/apache/arrow-rs/pull/7813) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([etseidl](https://github.com/etseidl)) +- Remove functions from parquet crate deprecated in or before 54.0.0 [\#7811](https://github.com/apache/arrow-rs/pull/7811) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- GH-7686: \[Parquet\] Fix int96 min/max stats [\#7687](https://github.com/apache/arrow-rs/pull/7687) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([rahulketch](https://github.com/rahulketch)) + +**Implemented enhancements:** + +- \[parquet\] Relax type restriction to allow writing dictionary/native batches for same column [\#8004](https://github.com/apache/arrow-rs/issues/8004) +- Support casting int64 to interval [\#7988](https://github.com/apache/arrow-rs/issues/7988) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Add `ListBuilder::with_value` for convenience [\#7951](https://github.com/apache/arrow-rs/issues/7951) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Add `ObjectBuilder::with_field` for convenience [\#7949](https://github.com/apache/arrow-rs/issues/7949) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Impl PartialEq for VariantObject \#7943 [\#7948](https://github.com/apache/arrow-rs/issues/7948) +- \[Variant\] Offer `simdutf8` as an optional dependency when validating metadata [\#7902](https://github.com/apache/arrow-rs/issues/7902) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Avoid collecting offset iterator [\#7901](https://github.com/apache/arrow-rs/issues/7901) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Remove superfluous check when validating monotonic offsets [\#7900](https://github.com/apache/arrow-rs/issues/7900) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Avoid extra allocation in `ObjectBuilder` [\#7899](https://github.com/apache/arrow-rs/issues/7899) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]\[Compute\] `variant_get` kernel [\#7893](https://github.com/apache/arrow-rs/issues/7893) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]\[Compute\] Add batch processing for Variant-JSON String conversion [\#7883](https://github.com/apache/arrow-rs/issues/7883) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support `MapArray` in lexsort [\#7881](https://github.com/apache/arrow-rs/issues/7881) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Add testing for invalid 
variants \(fuzz testing??\) [\#7842](https://github.com/apache/arrow-rs/issues/7842) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] VariantMetadata, VariantList and VariantObject are too big for Copy [\#7831](https://github.com/apache/arrow-rs/issues/7831) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Allow choosing flate2 backend [\#7826](https://github.com/apache/arrow-rs/issues/7826) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Tests for creating "large" `VariantObjects`s [\#7821](https://github.com/apache/arrow-rs/issues/7821) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Tests for creating "large" `VariantList`s [\#7820](https://github.com/apache/arrow-rs/issues/7820) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Support VariantBuilder to write to buffers owned by the caller [\#7805](https://github.com/apache/arrow-rs/issues/7805) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Move JSON related functionality to different crate. [\#7800](https://github.com/apache/arrow-rs/issues/7800) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Add flag in `ObjectBuilder` to control validation behavior on duplicate field write [\#7777](https://github.com/apache/arrow-rs/issues/7777) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] make `serde_json` an optional dependency of `parquet-variant` [\#7775](https://github.com/apache/arrow-rs/issues/7775) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[coalesce\] Implement specialized `BatchCoalescer::push_batch` for `PrimitiveArray` [\#7763](https://github.com/apache/arrow-rs/issues/7763) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add sort\_kernel benchmark for StringViewArray case [\#7758](https://github.com/apache/arrow-rs/issues/7758) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Improved API for accessing Variant Objects and lists [\#7756](https://github.com/apache/arrow-rs/issues/7756) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Buildable reproducible release builds [\#7751](https://github.com/apache/arrow-rs/issues/7751) +- Allow per-column parquet dictionary page size limit [\#7723](https://github.com/apache/arrow-rs/issues/7723) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Test and implement efficient building for "large" Arrays [\#7699](https://github.com/apache/arrow-rs/issues/7699) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Improve VariantBuilder when creating field name dictionaries / sorted dictionaries [\#7698](https://github.com/apache/arrow-rs/issues/7698) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Add input validation in `VariantBuilder` [\#7697](https://github.com/apache/arrow-rs/issues/7697) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Support Nested Data in `VariantBuilder` [\#7696](https://github.com/apache/arrow-rs/issues/7696) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet: Incorrect min/max stats for int96 columns [\#7686](https://github.com/apache/arrow-rs/issues/7686) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add `DictionaryArray::gc` method [\#7683](https://github.com/apache/arrow-rs/issues/7683) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Add negative tests for reading invalid primitive variant values [\#7645](https://github.com/apache/arrow-rs/issues/7645) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- \[Variant\] Panic when appending nested objects to VariantBuilder [\#7907](https://github.com/apache/arrow-rs/issues/7907) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Panic when casting large Decimal256 to f64 due to unchecked `unwrap()` [\#7886](https://github.com/apache/arrow-rs/issues/7886) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect inlined string view comparison after " Add prefix compare for inlined" [\#7874](https://github.com/apache/arrow-rs/issues/7874) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] `test_json_to_variant_object_very_large` takes over 20s [\#7872](https://github.com/apache/arrow-rs/issues/7872) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] If `ObjectBuilder::finalize` is not called, the resulting Variant object is malformed. [\#7863](https://github.com/apache/arrow-rs/issues/7863) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- CSV error message has values transposed [\#7848](https://github.com/apache/arrow-rs/issues/7848) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Concating struct arrays with no fields unnecessarily errors [\#7828](https://github.com/apache/arrow-rs/issues/7828) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Clippy CI is failing on main after Rust `1.88` upgrade [\#7796](https://github.com/apache/arrow-rs/issues/7796) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- \[Variant\] Field lookup with out of bounds index causes unwanted behavior [\#7784](https://github.com/apache/arrow-rs/issues/7784) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Error verifying `parquet-variant` crate on 55.2.0 with `verify-release-candidate.sh` [\#7746](https://github.com/apache/arrow-rs/issues/7746) +- `test_to_pyarrow` tests fail during release verification [\#7736](https://github.com/apache/arrow-rs/issues/7736) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[parquet\_derive\] Example for ParquetRecordWriter is broken. 
[\#7732](https://github.com/apache/arrow-rs/issues/7732) +- \[Variant\] `Variant::Object` can contain two fields with the same field name [\#7730](https://github.com/apache/arrow-rs/issues/7730) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Panic when appending Object or List to VariantBuilder [\#7701](https://github.com/apache/arrow-rs/issues/7701) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Slicing a single-field dense union array creates an array with incorrect `logical_nulls` length [\#7647](https://github.com/apache/arrow-rs/issues/7647) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Ensure page encoding statistics are written to Parquet file [\#7643](https://github.com/apache/arrow-rs/pull/7643) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) + +**Documentation updates:** + +- Minor: Upate `cast_with_options` docs about casting integers --\> intervals [\#8002](https://github.com/apache/arrow-rs/pull/8002) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- docs: More docs to `BatchCoalescer` [\#7891](https://github.com/apache/arrow-rs/pull/7891) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([2010YOUY01](https://github.com/2010YOUY01)) +- chore: fix a typo in `ExtensionType::supports_data_type` docs [\#7682](https://github.com/apache/arrow-rs/pull/7682) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- \[Variant\] Add variant docs and examples [\#7661](https://github.com/apache/arrow-rs/pull/7661) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Minor: Add version to deprecation notice for `ParquetMetaDataReader::decode_footer` [\#7639](https://github.com/apache/arrow-rs/pull/7639) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) + +**Performance improvements:** + +- `RowConverter` on list should only encode the sliced list values and not the entire data [\#7993](https://github.com/apache/arrow-rs/issues/7993) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Avoid extra allocation in list builder [\#7977](https://github.com/apache/arrow-rs/issues/7977) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Convert JSON to Variant with fewer copies [\#7964](https://github.com/apache/arrow-rs/issues/7964) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Optimize sort kernels partition\_validity method [\#7936](https://github.com/apache/arrow-rs/issues/7936) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Speedup sorting for inline views [\#7857](https://github.com/apache/arrow-rs/issues/7857) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Perf: Investigate and improve parquet writing performance [\#7822](https://github.com/apache/arrow-rs/issues/7822) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Perf: optimize sort string\_view performance [\#7790](https://github.com/apache/arrow-rs/issues/7790) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Clickbench microbenchmark spends significant time in memcmp for not\_empty predicate [\#7766](https://github.com/apache/arrow-rs/issues/7766) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use prefix first for 
comparisons, resort to data buffer for remaining data on equal values [\#7744](https://github.com/apache/arrow-rs/issues/7744) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Change use of `inline_value` to inline it to a u128 [\#7743](https://github.com/apache/arrow-rs/issues/7743) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add efficient way to upgrade keys for additional dictionary builders [\#7654](https://github.com/apache/arrow-rs/issues/7654) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Perf: Make sort string view fast\(1.5X ~ 3X faster\) [\#7792](https://github.com/apache/arrow-rs/pull/7792) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- Add specialized coalesce path for PrimitiveArrays [\#7772](https://github.com/apache/arrow-rs/pull/7772) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Closed issues:** + +- Implement full-range `i256::to_f64` to replace current ±∞ saturation for Decimal256 → Float64 [\#7985](https://github.com/apache/arrow-rs/issues/7985) +- \[Variant\] `impl FromIterator` fpr `VariantPath` [\#7955](https://github.com/apache/arrow-rs/issues/7955) +- `validated` and `is_fully_validated` flags doesn't need to be part of PartialEq [\#7952](https://github.com/apache/arrow-rs/issues/7952) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] remove VariantMetadata::dictionary\_size [\#7947](https://github.com/apache/arrow-rs/issues/7947) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Improve `VariantArray` performance by storing the index of the metadata and value arrays [\#7920](https://github.com/apache/arrow-rs/issues/7920) +- \[Variant\] Converting variant to JSON string seems slow [\#7869](https://github.com/apache/arrow-rs/issues/7869) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Present Variant at Iceberg Summit NYC July 10, 2025 [\#7858](https://github.com/apache/arrow-rs/issues/7858) +- \[Variant\] Avoid second copy of field name in MetadataBuilder [\#7814](https://github.com/apache/arrow-rs/issues/7814) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Remove APIs deprecated in or before 54.0.0 [\#7810](https://github.com/apache/arrow-rs/issues/7810) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- \[Variant\] Make it harder to forget to finish a pending parent i n ObjectBuilder [\#7798](https://github.com/apache/arrow-rs/issues/7798) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Remove explicit ObjectBuilder::finish\(\) and ListBuilder::finish and move to `Drop` impl [\#7780](https://github.com/apache/arrow-rs/issues/7780) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Reduce repetition in tests for arrow-row/src/run.rs [\#7692](https://github.com/apache/arrow-rs/issues/7692) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Add tests for invalid variant values \(aka verify invalid inputs\) [\#7681](https://github.com/apache/arrow-rs/issues/7681) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Introduce structs for Variant::Decimal types [\#7660](https://github.com/apache/arrow-rs/issues/7660) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Merged pull requests:** + +- Add benchmark for converting StringViewArray with mixed short and long strings [\#8015](https://github.com/apache/arrow-rs/pull/8015) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ding-young](https://github.com/ding-young)) +- \[Variant\] impl FromIterator for VariantPath [\#8011](https://github.com/apache/arrow-rs/pull/8011) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([sdf-jkl](https://github.com/sdf-jkl)) +- Create empty buffer for a buffer specified in the C Data Interface with length zero [\#8009](https://github.com/apache/arrow-rs/pull/8009) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- bench: add benchmark for converting list and sliced list to row format [\#8008](https://github.com/apache/arrow-rs/pull/8008) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- bench: benchmark interleave structs [\#8007](https://github.com/apache/arrow-rs/pull/8007) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- \[Parquet\] Allow writing compatible DictionaryArrays to parquet writer [\#8005](https://github.com/apache/arrow-rs/pull/8005) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([albertlockett](https://github.com/albertlockett)) +- doc: remove outdated info from CONTRIBUTING doc in project root dir. [\#7998](https://github.com/apache/arrow-rs/pull/7998) ([sonhmai](https://github.com/sonhmai)) +- perf: only encode actual list values in `RowConverter` \(16-26 times faster for small sliced list\) [\#7996](https://github.com/apache/arrow-rs/pull/7996) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- test: add tests for converting sliced list to row based [\#7994](https://github.com/apache/arrow-rs/pull/7994) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- perf: Improve `interleave` performance for struct \(3-6 times faster\) [\#7991](https://github.com/apache/arrow-rs/pull/7991) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- \[Variant\] Avoid extra buffer allocation in ListBuilder [\#7987](https://github.com/apache/arrow-rs/pull/7987) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([klion26](https://github.com/klion26)) +- Implement full-range `i256::to_f64` to eliminate ±∞ saturation for Decimal256 → Float64 casts [\#7986](https://github.com/apache/arrow-rs/pull/7986) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kosiew](https://github.com/kosiew)) +- Minor: Restore warning comment on Int96 statistics read [\#7975](https://github.com/apache/arrow-rs/pull/7975) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add additional integration tests to arrow-avro [\#7974](https://github.com/apache/arrow-rs/pull/7974) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nathaniel-d-ef](https://github.com/nathaniel-d-ef)) +- Perf: optimize actual\_buffer\_size to use only data buffer capacity for coalesce [\#7967](https://github.com/apache/arrow-rs/pull/7967) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- Implement Improved arrow-avro Reader Zero-Byte Record Handling 
[\#7966](https://github.com/apache/arrow-rs/pull/7966) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Perf: improve sort via `partition_validity` to use fast path for bit map scan \(up to 30% faster\) [\#7962](https://github.com/apache/arrow-rs/pull/7962) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- \[Variant\] Revisit VariantMetadata and Object equality [\#7961](https://github.com/apache/arrow-rs/pull/7961) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Add ListBuilder::with\_value for convenience [\#7959](https://github.com/apache/arrow-rs/pull/7959) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([codephage2020](https://github.com/codephage2020)) +- \[Variant\] remove VariantMetadata::dictionary\_size [\#7958](https://github.com/apache/arrow-rs/pull/7958) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([codephage2020](https://github.com/codephage2020)) +- \[Variant\] VariantMetadata is allowed to contain the empty string [\#7956](https://github.com/apache/arrow-rs/pull/7956) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Add arrow-avro support for Impala Nullability [\#7954](https://github.com/apache/arrow-rs/pull/7954) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([veronica-m-ef](https://github.com/veronica-m-ef)) +- \[Test\] Add tests for VariantList equality [\#7953](https://github.com/apache/arrow-rs/pull/7953) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\] Add ObjectBuilder::with\_field for convenience [\#7950](https://github.com/apache/arrow-rs/pull/7950) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\] Adding code to store metadata and value references in VariantArray [\#7945](https://github.com/apache/arrow-rs/pull/7945) ([abacef](https://github.com/abacef)) +- \[Variant\] Add `variant_kernels` benchmark [\#7944](https://github.com/apache/arrow-rs/pull/7944) ([alamb](https://github.com/alamb)) +- \[Variant\] Impl `PartialEq` for VariantObject [\#7943](https://github.com/apache/arrow-rs/pull/7943) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Add documentation, tests and cleaner api for Variant::get\_path [\#7942](https://github.com/apache/arrow-rs/pull/7942) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- arrow-ipc: Remove all abilities to preserve dict IDs [\#7940](https://github.com/apache/arrow-rs/pull/7940) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([brancz](https://github.com/brancz)) +- Optimize partition\_validity function used in sort kernels [\#7937](https://github.com/apache/arrow-rs/pull/7937) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- \[Variant\] Avoid extra allocation in object builder [\#7935](https://github.com/apache/arrow-rs/pull/7935) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([klion26](https://github.com/klion26)) +- \[Variant\] Avoid collecting 
offset iterator [\#7934](https://github.com/apache/arrow-rs/pull/7934) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([codephage2020](https://github.com/codephage2020)) +- Minor: Support BinaryView and StringView builders in `make_builder` [\#7931](https://github.com/apache/arrow-rs/pull/7931) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kylebarron](https://github.com/kylebarron)) +- chore: bump MSRV to 1.84 [\#7926](https://github.com/apache/arrow-rs/pull/7926) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([mbrobbel](https://github.com/mbrobbel)) +- Update bzip2 requirement from 0.4.4 to 0.6.0 [\#7924](https://github.com/apache/arrow-rs/pull/7924) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- \[Variant\] Reserve capacity beforehand during large object building [\#7922](https://github.com/apache/arrow-rs/pull/7922) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Add `variant_get` compute kernel [\#7919](https://github.com/apache/arrow-rs/pull/7919) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Samyak2](https://github.com/Samyak2)) +- Improve memory usage for `arrow-row -> String/BinaryView` when utf8 validation disabled [\#7917](https://github.com/apache/arrow-rs/pull/7917) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ding-young](https://github.com/ding-young)) +- Restructure compare\_greater function used in parquet statistics for better performance [\#7916](https://github.com/apache/arrow-rs/pull/7916) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jhorstmann](https://github.com/jhorstmann)) +- \[Variant\] Support appending complex variants in `VariantBuilder` [\#7914](https://github.com/apache/arrow-rs/pull/7914) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Add `VariantBuilder::new_with_buffers` to write to existing buffers [\#7912](https://github.com/apache/arrow-rs/pull/7912) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Convert JSON to VariantArray without copying \(8 - 32% faster\) [\#7911](https://github.com/apache/arrow-rs/pull/7911) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\] Use simdutf8 for UTF-8 validation [\#7908](https://github.com/apache/arrow-rs/pull/7908) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([codephage2020](https://github.com/codephage2020)) +- \[Variant\] Avoid superflous validation checks [\#7906](https://github.com/apache/arrow-rs/pull/7906) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Add `VariantArray` and `VariantArrayBuilder` for constructing Arrow Arrays of Variants [\#7905](https://github.com/apache/arrow-rs/pull/7905) ([alamb](https://github.com/alamb)) +- Update sysinfo requirement from 0.35.0 to 0.36.0 [\#7904](https://github.com/apache/arrow-rs/pull/7904) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix current CI failure 
[\#7898](https://github.com/apache/arrow-rs/pull/7898) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Remove redundant is\_err checks in Variant tests [\#7897](https://github.com/apache/arrow-rs/pull/7897) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- \[Variant\] test: add variant object tests with different sizes [\#7896](https://github.com/apache/arrow-rs/pull/7896) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([odysa](https://github.com/odysa)) +- \[Variant\] Define basic convenience methods for variant pathing [\#7894](https://github.com/apache/arrow-rs/pull/7894) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- fix: `view_types` benchmark slice should follow by correct len array [\#7892](https://github.com/apache/arrow-rs/pull/7892) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- Add arrow-avro support for bzip2 and xz compression [\#7890](https://github.com/apache/arrow-rs/pull/7890) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Add arrow-avro support for Duration type and minor fixes for UUID decoding [\#7889](https://github.com/apache/arrow-rs/pull/7889) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- \[Variant\] Reduce variant-related struct sizes [\#7888](https://github.com/apache/arrow-rs/pull/7888) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Fix panic on lossy decimal to float casting: round to saturation for overflows [\#7887](https://github.com/apache/arrow-rs/pull/7887) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kosiew](https://github.com/kosiew)) +- Add tests for invalid variant metadata and value [\#7885](https://github.com/apache/arrow-rs/pull/7885) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- \[Variant\] Introduce parquet-variant-compute crate to transform batches of JSON strings to and from Variants [\#7884](https://github.com/apache/arrow-rs/pull/7884) ([harshmotw-db](https://github.com/harshmotw-db)) +- feat: support `MapArray` in lexsort [\#7882](https://github.com/apache/arrow-rs/pull/7882) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- fix: mark `DataType::Map` as unsupported in `RowConverter` [\#7880](https://github.com/apache/arrow-rs/pull/7880) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- \[Variant\] Speedup validation [\#7878](https://github.com/apache/arrow-rs/pull/7878) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- benchmark: Add StringViewArray gc benchmark with not null cases [\#7877](https://github.com/apache/arrow-rs/pull/7877) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- \[ARROW-RS-7820\]\[Variant\] Add tests for large variant lists [\#7876](https://github.com/apache/arrow-rs/pull/7876) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([klion26](https://github.com/klion26)) +- fix: Incorrect inlined string view comparison after Add prefix compar… 
[\#7875](https://github.com/apache/arrow-rs/pull/7875) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- perf: speed up StringViewArray gc 1.4 ~5.x faster [\#7873](https://github.com/apache/arrow-rs/pull/7873) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- \[Variant\] Remove superflous validate call and rename methods [\#7871](https://github.com/apache/arrow-rs/pull/7871) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Benchmark: Add rich testing cases for sort string\(utf8\) [\#7867](https://github.com/apache/arrow-rs/pull/7867) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- chore: update link for `row_filter.rs` [\#7866](https://github.com/apache/arrow-rs/pull/7866) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([haohuaijin](https://github.com/haohuaijin)) +- \[Variant\] List and object builders have no effect until finalized [\#7865](https://github.com/apache/arrow-rs/pull/7865) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Added number to string benches for json\_writer [\#7864](https://github.com/apache/arrow-rs/pull/7864) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([abacef](https://github.com/abacef)) +- \[Variant\] Introduce `parquet-variant-json` crate [\#7862](https://github.com/apache/arrow-rs/pull/7862) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\] Remove dead code, add comments [\#7861](https://github.com/apache/arrow-rs/pull/7861) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Speedup sorting for inline views: 1.4x - 1.7x improvement [\#7856](https://github.com/apache/arrow-rs/pull/7856) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Fix union slice logical\_nulls length [\#7855](https://github.com/apache/arrow-rs/pull/7855) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([codephage2020](https://github.com/codephage2020)) +- Add `get_ref/get_mut` to JSON Writer [\#7854](https://github.com/apache/arrow-rs/pull/7854) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([cetra3](https://github.com/cetra3)) +- \[Minor\] Add Benchmark for RowConverter::append [\#7853](https://github.com/apache/arrow-rs/pull/7853) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Add Enum type support to arrow-avro and Minor Decimal type fix [\#7852](https://github.com/apache/arrow-rs/pull/7852) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- CSV error message has values transposed [\#7851](https://github.com/apache/arrow-rs/pull/7851) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Omega359](https://github.com/Omega359)) +- \[Variant\] Fuzz testing and benchmarks for vaildation [\#7849](https://github.com/apache/arrow-rs/pull/7849) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([carpecodeum](https://github.com/carpecodeum)) +- \[Variant\] Follow up nits and uncomment test cases [\#7846](https://github.com/apache/arrow-rs/pull/7846) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] 
([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Make sure ObjectBuilder and ListBuilder to be finalized before its parent builder [\#7843](https://github.com/apache/arrow-rs/pull/7843) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Add decimal32 and decimal64 support to Parquet, JSON and CSV readers and writers [\#7841](https://github.com/apache/arrow-rs/pull/7841) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([CurtHagenlocher](https://github.com/CurtHagenlocher)) +- Implement arrow-avro Reader and ReaderBuilder [\#7834](https://github.com/apache/arrow-rs/pull/7834) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- \[Variant\] Support creating sorted dictionaries [\#7833](https://github.com/apache/arrow-rs/pull/7833) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Add Decimal type support to arrow-avro [\#7832](https://github.com/apache/arrow-rs/pull/7832) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Allow concating struct arrays with no fields [\#7829](https://github.com/apache/arrow-rs/pull/7829) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([AdamGS](https://github.com/AdamGS)) +- Add features to configure flate2 [\#7827](https://github.com/apache/arrow-rs/pull/7827) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zeevm](https://github.com/zeevm)) +- make builder public under experimental [\#7825](https://github.com/apache/arrow-rs/pull/7825) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Improvements for parquet writing performance \(25%-44%\) [\#7824](https://github.com/apache/arrow-rs/pull/7824) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Use in-memory buffer for arrow\_writer benchmark [\#7823](https://github.com/apache/arrow-rs/pull/7823) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jhorstmann](https://github.com/jhorstmann)) +- \[Variant\] impl \[Try\]From for VariantDecimalXX types [\#7809](https://github.com/apache/arrow-rs/pull/7809) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- \[Variant\] Speedup `ObjectBuilder` \(62x faster\) [\#7808](https://github.com/apache/arrow-rs/pull/7808) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[VARIANT\] Support both fallible and infallible access to variants [\#7807](https://github.com/apache/arrow-rs/pull/7807) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Minor: fix clippy in parquet-variant after logical conflict [\#7803](https://github.com/apache/arrow-rs/pull/7803) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\] Add flag in `ObjectBuilder` to control validation behavior on duplicate field write [\#7801](https://github.com/apache/arrow-rs/pull/7801) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([micoo227](https://github.com/micoo227)) +- Fix clippy for Rust 1.88 release 
[\#7797](https://github.com/apache/arrow-rs/pull/7797) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- \[Variant\] Simplify `Builder` buffer operations [\#7795](https://github.com/apache/arrow-rs/pull/7795) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- fix: Change panic to error in`take` kernel for StringArrary/BinaryArray on overflow [\#7793](https://github.com/apache/arrow-rs/pull/7793) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([chenkovsky](https://github.com/chenkovsky)) +- Update base64 requirement from 0.21 to 0.22 [\#7791](https://github.com/apache/arrow-rs/pull/7791) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix RowConverter when FixedSizeList is not the last [\#7789](https://github.com/apache/arrow-rs/pull/7789) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Add schema with only primitive arrays to `coalesce_kernel` benchmark [\#7788](https://github.com/apache/arrow-rs/pull/7788) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add sort\_kernel benchmark for StringViewArray case [\#7787](https://github.com/apache/arrow-rs/pull/7787) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- \[Variant\] Check pending before `VariantObject::insert` [\#7786](https://github.com/apache/arrow-rs/pull/7786) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[VARIANT\] impl Display for VariantDecimalXX [\#7785](https://github.com/apache/arrow-rs/pull/7785) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([scovich](https://github.com/scovich)) +- \[VARIANT\] Add support for the json\_to\_variant API [\#7783](https://github.com/apache/arrow-rs/pull/7783) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([harshmotw-db](https://github.com/harshmotw-db)) +- \[Variant\] Consolidate examples for json writing [\#7782](https://github.com/apache/arrow-rs/pull/7782) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add benchmark for about view array slice [\#7781](https://github.com/apache/arrow-rs/pull/7781) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ctsk](https://github.com/ctsk)) +- \[Variant\] Add negative tests for reading invalid primitive variant values [\#7779](https://github.com/apache/arrow-rs/pull/7779) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([superserious-dev](https://github.com/superserious-dev)) +- \[Variant\] Support creating nested objects and object with lists [\#7778](https://github.com/apache/arrow-rs/pull/7778) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[VARIANT\] Validate precision in VariantDecimalXX structs and add missing tests [\#7776](https://github.com/apache/arrow-rs/pull/7776) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Add tests for `BatchCoalescer::push_batch_with_filter`, fix 
bug [\#7774](https://github.com/apache/arrow-rs/pull/7774) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- \[Variant\] Minor: make fields in `VariantDecimal*` private, add examples [\#7770](https://github.com/apache/arrow-rs/pull/7770) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Extend the fast path in GenericByteViewArray::is\_eq for comparing against empty strings [\#7767](https://github.com/apache/arrow-rs/pull/7767) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- \[Variant\] Improve getter API for `VariantList` and `VariantObject` [\#7757](https://github.com/apache/arrow-rs/pull/7757) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Add Variant::as\_object and Variant::as\_list [\#7755](https://github.com/apache/arrow-rs/pull/7755) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\] Fix several overflow panic risks for 32-bit arch [\#7752](https://github.com/apache/arrow-rs/pull/7752) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Add testing section to pull request template [\#7749](https://github.com/apache/arrow-rs/pull/7749) ([alamb](https://github.com/alamb)) +- Perf: Add prefix compare for inlined compare and change use of inline\_value to inline it to a u128 [\#7748](https://github.com/apache/arrow-rs/pull/7748) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- Move arrow-pyarrow tests that require `pyarrow` to be installed into `arrow-pyarrow-testing` crate [\#7742](https://github.com/apache/arrow-rs/pull/7742) ([alamb](https://github.com/alamb)) +- \[Variant\] Improve write API in `Variant::Object` [\#7741](https://github.com/apache/arrow-rs/pull/7741) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Support nested lists and object lists [\#7740](https://github.com/apache/arrow-rs/pull/7740) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- feat: \[Variant\] Add Validation for Variant Decimal [\#7738](https://github.com/apache/arrow-rs/pull/7738) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Weijun-H](https://github.com/Weijun-H)) +- Add fallible versions of temporal functions that may panic [\#7737](https://github.com/apache/arrow-rs/pull/7737) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([adriangb](https://github.com/adriangb)) +- fix: Implement support for appending Object and List variants in VariantBuilder [\#7735](https://github.com/apache/arrow-rs/pull/7735) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Weijun-H](https://github.com/Weijun-H)) +- parquet\_derive: update in working example for ParquetRecordWriter [\#7733](https://github.com/apache/arrow-rs/pull/7733) ([LanHikari22](https://github.com/LanHikari22)) +- Perf: Optimize comparison kernels for inlined views [\#7731](https://github.com/apache/arrow-rs/pull/7731) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- arrow-row: Refactor arrow-row REE roundtrip tests [\#7729](https://github.com/apache/arrow-rs/pull/7729)
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- arrow-array: Implement PartialEq for RunArray [\#7727](https://github.com/apache/arrow-rs/pull/7727) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- fix: Do not add null buffer for `NullArray` in MutableArrayData [\#7726](https://github.com/apache/arrow-rs/pull/7726) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Allow per-column parquet dictionary page size limit [\#7724](https://github.com/apache/arrow-rs/pull/7724) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- fix JSON decoder error checking for UTF16 / surrogate parsing panic [\#7721](https://github.com/apache/arrow-rs/pull/7721) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nicklan](https://github.com/nicklan)) +- \[Variant\] Use `BTreeMap` for `VariantBuilder.dict` and `ObjectBuilder.fields` to maintain invariants upon entry writes [\#7720](https://github.com/apache/arrow-rs/pull/7720) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Introduce `MAX_INLINE_VIEW_LEN` constant for string/byte views [\#7719](https://github.com/apache/arrow-rs/pull/7719) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- \[Variant\] Introduce new type over &str for ShortString [\#7718](https://github.com/apache/arrow-rs/pull/7718) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Split out variant code into several new sub-modules [\#7717](https://github.com/apache/arrow-rs/pull/7717) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- add `garbage_collect_dictionary` to `arrow-select` [\#7716](https://github.com/apache/arrow-rs/pull/7716) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([davidhewitt](https://github.com/davidhewitt)) +- Support write to buffer api for SerializedFileWriter [\#7714](https://github.com/apache/arrow-rs/pull/7714) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- Support `FixedSizeList` RowConverter [\#7705](https://github.com/apache/arrow-rs/pull/7705) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Make variant iterators safely infallible [\#7704](https://github.com/apache/arrow-rs/pull/7704) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Speedup `interleave_views` \(4-7x faster\) [\#7695](https://github.com/apache/arrow-rs/pull/7695) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Define a "arrow-pyrarrow" crate to implement the "pyarrow" feature. 
[\#7694](https://github.com/apache/arrow-rs/pull/7694) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brunal](https://github.com/brunal)) +- feat: add constructor to efficiently upgrade dict key type to remaining builders [\#7689](https://github.com/apache/arrow-rs/pull/7689) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([albertlockett](https://github.com/albertlockett)) +- Document REE row format and add some more tests [\#7680](https://github.com/apache/arrow-rs/pull/7680) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- feat: add min max aggregate support for FixedSizeBinary [\#7675](https://github.com/apache/arrow-rs/pull/7675) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexwilcoxson-rel](https://github.com/alexwilcoxson-rel)) +- arrow-data: Add REE support for `build_extend` and `build_extend_nulls` [\#7671](https://github.com/apache/arrow-rs/pull/7671) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- Variant: Write Variant Values as JSON [\#7670](https://github.com/apache/arrow-rs/pull/7670) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([carpecodeum](https://github.com/carpecodeum)) +- Remove `lazy_static` dependency [\#7669](https://github.com/apache/arrow-rs/pull/7669) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Expyron](https://github.com/Expyron)) +- Finish implementing Variant::Object and Variant::List [\#7666](https://github.com/apache/arrow-rs/pull/7666) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Add `RecordBatch::schema_metadata_mut` and `Field::metadata_mut` [\#7664](https://github.com/apache/arrow-rs/pull/7664) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([emilk](https://github.com/emilk)) +- \[Variant\] Simplify creation of Variants from metadata and value [\#7663](https://github.com/apache/arrow-rs/pull/7663) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- chore: group prost dependabot updates [\#7659](https://github.com/apache/arrow-rs/pull/7659) ([mbrobbel](https://github.com/mbrobbel)) +- Initial Builder API for Creating Variant Values [\#7653](https://github.com/apache/arrow-rs/pull/7653) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([PinkCrow007](https://github.com/PinkCrow007)) +- Add `BatchCoalescer::push_filtered_batch` and docs [\#7652](https://github.com/apache/arrow-rs/pull/7652) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Optimize coalesce kernel for StringView \(10-50% faster\) [\#7650](https://github.com/apache/arrow-rs/pull/7650) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- arrow-row: Add support for REE [\#7649](https://github.com/apache/arrow-rs/pull/7649) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- Use approximate comparisons for pow tests [\#7646](https://github.com/apache/arrow-rs/pull/7646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([adamreeve](https://github.com/adamreeve)) +- \[Variant\] Implement read support for remaining primitive types [\#7644](https://github.com/apache/arrow-rs/pull/7644) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([superserious-dev](https://github.com/superserious-dev)) +- Add 
`pretty_format_batches_with_schema` function [\#7642](https://github.com/apache/arrow-rs/pull/7642) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lewiszlw](https://github.com/lewiszlw)) +- Deprecate old Parquet page index parsing functions [\#7640](https://github.com/apache/arrow-rs/pull/7640) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Update FlightSQL `GetDbSchemas` and `GetTables` schemas to fully match the protocol [\#7638](https://github.com/apache/arrow-rs/pull/7638) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([sgrebnov](https://github.com/sgrebnov)) +- Minor: Remove outdated FIXME from `ParquetMetaDataReader` [\#7635](https://github.com/apache/arrow-rs/pull/7635) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Fix the error info of `StructArray::try_new` [\#7634](https://github.com/apache/arrow-rs/pull/7634) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xudong963](https://github.com/xudong963)) +- Fix reading encrypted Parquet pages when using the page index [\#7633](https://github.com/apache/arrow-rs/pull/7633) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([adamreeve](https://github.com/adamreeve)) +- \[Variant\] Add commented out primitive test casees [\#7631](https://github.com/apache/arrow-rs/pull/7631) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) ## [55.2.0](https://github.com/apache/arrow-rs/tree/55.2.0) (2025-06-22) - Add a `strong_count` method to `Buffer` [\#7568](https://github.com/apache/arrow-rs/issues/7568) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b707d30a3db..b35d9b28a747 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,281 +19,138 @@ # Changelog -## [56.0.0](https://github.com/apache/arrow-rs/tree/56.0.0) (2025-07-29) +## [56.1.0](https://github.com/apache/arrow-rs/tree/56.1.0) (2025-08-21) -[Full Changelog](https://github.com/apache/arrow-rs/compare/55.2.0...56.0.0) - -**Breaking changes:** - -- arrow-schema: Remove dict\_id from being required equal for merging [\#7968](https://github.com/apache/arrow-rs/pull/7968) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) -- \[Parquet\] Use `u64` for `SerializedPageReaderState.offset` & `remaining_bytes`, instead of `usize` [\#7918](https://github.com/apache/arrow-rs/pull/7918) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([JigaoLuo](https://github.com/JigaoLuo)) -- Upgrade tonic dependencies to 0.13.0 version \(try 2\) [\#7839](https://github.com/apache/arrow-rs/pull/7839) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) -- Remove deprecated Arrow functions [\#7830](https://github.com/apache/arrow-rs/pull/7830) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([etseidl](https://github.com/etseidl)) -- Remove deprecated temporal functions [\#7813](https://github.com/apache/arrow-rs/pull/7813) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([etseidl](https://github.com/etseidl)) -- Remove functions from parquet crate deprecated in or before 54.0.0 
[\#7811](https://github.com/apache/arrow-rs/pull/7811) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) -- GH-7686: \[Parquet\] Fix int96 min/max stats [\#7687](https://github.com/apache/arrow-rs/pull/7687) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([rahulketch](https://github.com/rahulketch)) +[Full Changelog](https://github.com/apache/arrow-rs/compare/56.0.0...56.1.0) **Implemented enhancements:** -- \[parquet\] Relax type restriction to allow writing dictionary/native batches for same column [\#8004](https://github.com/apache/arrow-rs/issues/8004) -- Support casting int64 to interval [\#7988](https://github.com/apache/arrow-rs/issues/7988) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[Variant\] Add `ListBuilder::with_value` for convenience [\#7951](https://github.com/apache/arrow-rs/issues/7951) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Add `ObjectBuilder::with_field` for convenience [\#7949](https://github.com/apache/arrow-rs/issues/7949) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Impl PartialEq for VariantObject \#7943 [\#7948](https://github.com/apache/arrow-rs/issues/7948) -- \[Variant\] Offer `simdutf8` as an optional dependency when validating metadata [\#7902](https://github.com/apache/arrow-rs/issues/7902) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[Variant\] Avoid collecting offset iterator [\#7901](https://github.com/apache/arrow-rs/issues/7901) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Remove superfluous check when validating monotonic offsets [\#7900](https://github.com/apache/arrow-rs/issues/7900) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Avoid extra allocation in `ObjectBuilder` [\#7899](https://github.com/apache/arrow-rs/issues/7899) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\]\[Compute\] `variant_get` kernel [\#7893](https://github.com/apache/arrow-rs/issues/7893) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\]\[Compute\] Add batch processing for Variant-JSON String conversion [\#7883](https://github.com/apache/arrow-rs/issues/7883) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Support `MapArray` in lexsort [\#7881](https://github.com/apache/arrow-rs/issues/7881) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[Variant\] Add testing for invalid variants \(fuzz testing??\) [\#7842](https://github.com/apache/arrow-rs/issues/7842) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] VariantMetadata, VariantList and VariantObject are too big for Copy [\#7831](https://github.com/apache/arrow-rs/issues/7831) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Allow choosing flate2 backend [\#7826](https://github.com/apache/arrow-rs/issues/7826) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Tests for creating "large" `VariantObjects`s [\#7821](https://github.com/apache/arrow-rs/issues/7821) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Tests for creating "large" `VariantList`s [\#7820](https://github.com/apache/arrow-rs/issues/7820) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Support VariantBuilder to write to buffers owned by the 
caller [\#7805](https://github.com/apache/arrow-rs/issues/7805) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Move JSON related functionality to different crate. [\#7800](https://github.com/apache/arrow-rs/issues/7800) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Add flag in `ObjectBuilder` to control validation behavior on duplicate field write [\#7777](https://github.com/apache/arrow-rs/issues/7777) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] make `serde_json` an optional dependency of `parquet-variant` [\#7775](https://github.com/apache/arrow-rs/issues/7775) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[coalesce\] Implement specialized `BatchCoalescer::push_batch` for `PrimitiveArray` [\#7763](https://github.com/apache/arrow-rs/issues/7763) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add sort\_kernel benchmark for StringViewArray case [\#7758](https://github.com/apache/arrow-rs/issues/7758) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[Variant\] Improved API for accessing Variant Objects and lists [\#7756](https://github.com/apache/arrow-rs/issues/7756) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Buildable reproducible release builds [\#7751](https://github.com/apache/arrow-rs/issues/7751) -- Allow per-column parquet dictionary page size limit [\#7723](https://github.com/apache/arrow-rs/issues/7723) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Test and implement efficient building for "large" Arrays [\#7699](https://github.com/apache/arrow-rs/issues/7699) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Improve VariantBuilder when creating field name dictionaries / sorted dictionaries [\#7698](https://github.com/apache/arrow-rs/issues/7698) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Add input validation in `VariantBuilder` [\#7697](https://github.com/apache/arrow-rs/issues/7697) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Support Nested Data in `VariantBuilder` [\#7696](https://github.com/apache/arrow-rs/issues/7696) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Parquet: Incorrect min/max stats for int96 columns [\#7686](https://github.com/apache/arrow-rs/issues/7686) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Add `DictionaryArray::gc` method [\#7683](https://github.com/apache/arrow-rs/issues/7683) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[Variant\] Add negative tests for reading invalid primitive variant values [\#7645](https://github.com/apache/arrow-rs/issues/7645) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Implement cast and other operations on decimal32 and decimal64 \#7815 [\#8204](https://github.com/apache/arrow-rs/issues/8204) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Speed up Parquet filter pushdown with predicate cache [\#8203](https://github.com/apache/arrow-rs/issues/8203) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Optionally read parquet page indexes [\#8070](https://github.com/apache/arrow-rs/issues/8070) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet reader: add method for sync reader read bloom filter [\#8023](https://github.com/apache/arrow-rs/issues/8023) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[parquet\] Support writing logically equivalent types to `ArrowWriter` [\#8012](https://github.com/apache/arrow-rs/issues/8012) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Improve StringArray\(Utf8\) sort performance [\#7847](https://github.com/apache/arrow-rs/issues/7847) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- feat: arrow-ipc delta dictionary support [\#8001](https://github.com/apache/arrow-rs/pull/8001) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([JakeDern](https://github.com/JakeDern)) **Fixed bugs:** -- \[Variant\] Panic when appending nested objects to VariantBuilder [\#7907](https://github.com/apache/arrow-rs/issues/7907) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Panic when casting large Decimal256 to f64 due to unchecked `unwrap()` [\#7886](https://github.com/apache/arrow-rs/issues/7886) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Incorrect inlined string view comparison after " Add prefix compare for inlined" [\#7874](https://github.com/apache/arrow-rs/issues/7874) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[Variant\] `test_json_to_variant_object_very_large` takes over 20s [\#7872](https://github.com/apache/arrow-rs/issues/7872) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] If `ObjectBuilder::finalize` is not called, the resulting Variant object is malformed. [\#7863](https://github.com/apache/arrow-rs/issues/7863) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- CSV error message has values transposed [\#7848](https://github.com/apache/arrow-rs/issues/7848) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Concating struct arrays with no fields unnecessarily errors [\#7828](https://github.com/apache/arrow-rs/issues/7828) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Clippy CI is failing on main after Rust `1.88` upgrade [\#7796](https://github.com/apache/arrow-rs/issues/7796) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] -- \[Variant\] Field lookup with out of bounds index causes unwanted behavior [\#7784](https://github.com/apache/arrow-rs/issues/7784) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Error verifying `parquet-variant` crate on 55.2.0 with `verify-release-candidate.sh` [\#7746](https://github.com/apache/arrow-rs/issues/7746) -- `test_to_pyarrow` tests fail during release verification [\#7736](https://github.com/apache/arrow-rs/issues/7736) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[parquet\_derive\] Example for ParquetRecordWriter is broken. 
[\#7732](https://github.com/apache/arrow-rs/issues/7732) -- \[Variant\] `Variant::Object` can contain two fields with the same field name [\#7730](https://github.com/apache/arrow-rs/issues/7730) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Panic when appending Object or List to VariantBuilder [\#7701](https://github.com/apache/arrow-rs/issues/7701) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Slicing a single-field dense union array creates an array with incorrect `logical_nulls` length [\#7647](https://github.com/apache/arrow-rs/issues/7647) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Ensure page encoding statistics are written to Parquet file [\#7643](https://github.com/apache/arrow-rs/pull/7643) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- The Rustdocs are clean CI job is failing [\#8175](https://github.com/apache/arrow-rs/issues/8175) +- \[avro\] Bug in resolving avro schema with named type [\#8045](https://github.com/apache/arrow-rs/issues/8045) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Doc test failure \(test arrow-avro/src/lib.rs - reader\) when verifying avro 56.0.0 RC1 release [\#8018](https://github.com/apache/arrow-rs/issues/8018) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] **Documentation updates:** -- Minor: Upate `cast_with_options` docs about casting integers --\> intervals [\#8002](https://github.com/apache/arrow-rs/pull/8002) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- docs: More docs to `BatchCoalescer` [\#7891](https://github.com/apache/arrow-rs/pull/7891) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([2010YOUY01](https://github.com/2010YOUY01)) -- chore: fix a typo in `ExtensionType::supports_data_type` docs [\#7682](https://github.com/apache/arrow-rs/pull/7682) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) -- \[Variant\] Add variant docs and examples [\#7661](https://github.com/apache/arrow-rs/pull/7661) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Minor: Add version to deprecation notice for `ParquetMetaDataReader::decode_footer` [\#7639](https://github.com/apache/arrow-rs/pull/7639) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- arrow-row: Document dictionary handling [\#8168](https://github.com/apache/arrow-rs/pull/8168) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Docs: Clarify that Array::value does not check for nulls [\#8065](https://github.com/apache/arrow-rs/pull/8065) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- docs: Fix a typo in README [\#8036](https://github.com/apache/arrow-rs/pull/8036) ([EricccTaiwan](https://github.com/EricccTaiwan)) +- Add more comments to the internal parquet reader [\#7932](https://github.com/apache/arrow-rs/pull/7932) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) **Performance improvements:** -- `RowConverter` on list should only encode the sliced list values and not the entire data [\#7993](https://github.com/apache/arrow-rs/issues/7993) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[Variant\] Avoid extra allocation in list builder 
[\#7977](https://github.com/apache/arrow-rs/issues/7977) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Convert JSON to Variant with fewer copies [\#7964](https://github.com/apache/arrow-rs/issues/7964) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Optimize sort kernels partition\_validity method [\#7936](https://github.com/apache/arrow-rs/issues/7936) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Speedup sorting for inline views [\#7857](https://github.com/apache/arrow-rs/issues/7857) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Perf: Investigate and improve parquet writing performance [\#7822](https://github.com/apache/arrow-rs/issues/7822) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Perf: optimize sort string\_view performance [\#7790](https://github.com/apache/arrow-rs/issues/7790) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Clickbench microbenchmark spends significant time in memcmp for not\_empty predicate [\#7766](https://github.com/apache/arrow-rs/issues/7766) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Use prefix first for comparisons, resort to data buffer for remaining data on equal values [\#7744](https://github.com/apache/arrow-rs/issues/7744) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Change use of `inline_value` to inline it to a u128 [\#7743](https://github.com/apache/arrow-rs/issues/7743) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add efficient way to upgrade keys for additional dictionary builders [\#7654](https://github.com/apache/arrow-rs/issues/7654) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Perf: Make sort string view fast\(1.5X ~ 3X faster\) [\#7792](https://github.com/apache/arrow-rs/pull/7792) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) -- Add specialized coalesce path for PrimitiveArrays [\#7772](https://github.com/apache/arrow-rs/pull/7772) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- perf\(arrow-ipc\): avoid counting nulls in `RecordBatchDecoder` [\#8127](https://github.com/apache/arrow-rs/pull/8127) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- Use `Vec` directly in builders [\#7984](https://github.com/apache/arrow-rs/pull/7984) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liamzwbao](https://github.com/liamzwbao)) +- Improve StringArray\(Utf8\) sort performance \(~2-4x faster\) [\#7860](https://github.com/apache/arrow-rs/pull/7860) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) **Closed issues:** -- Implement full-range `i256::to_f64` to replace current ±∞ saturation for Decimal256 → Float64 [\#7985](https://github.com/apache/arrow-rs/issues/7985) -- \[Variant\] `impl FromIterator` fpr `VariantPath` [\#7955](https://github.com/apache/arrow-rs/issues/7955) -- `validated` and `is_fully_validated` flags doesn't need to be part of PartialEq [\#7952](https://github.com/apache/arrow-rs/issues/7952) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] remove VariantMetadata::dictionary\_size [\#7947](https://github.com/apache/arrow-rs/issues/7947) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Improve 
`VariantArray` performance by storing the index of the metadata and value arrays [\#7920](https://github.com/apache/arrow-rs/issues/7920) -- \[Variant\] Converting variant to JSON string seems slow [\#7869](https://github.com/apache/arrow-rs/issues/7869) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Present Variant at Iceberg Summit NYC July 10, 2025 [\#7858](https://github.com/apache/arrow-rs/issues/7858) -- \[Variant\] Avoid second copy of field name in MetadataBuilder [\#7814](https://github.com/apache/arrow-rs/issues/7814) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Remove APIs deprecated in or before 54.0.0 [\#7810](https://github.com/apache/arrow-rs/issues/7810) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] -- \[Variant\] Make it harder to forget to finish a pending parent in ObjectBuilder [\#7798](https://github.com/apache/arrow-rs/issues/7798) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Remove explicit ObjectBuilder::finish\(\) and ListBuilder::finish and move to `Drop` impl [\#7780](https://github.com/apache/arrow-rs/issues/7780) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Reduce repetition in tests for arrow-row/src/run.rs [\#7692](https://github.com/apache/arrow-rs/issues/7692) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[Variant\] Add tests for invalid variant values \(aka verify invalid inputs\) [\#7681](https://github.com/apache/arrow-rs/issues/7681) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Introduce structs for Variant::Decimal types [\#7660](https://github.com/apache/arrow-rs/issues/7660) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Improve fuzz test for Variant [\#8199](https://github.com/apache/arrow-rs/issues/8199) +- \[Variant\] Improve fuzz test for Variant [\#8198](https://github.com/apache/arrow-rs/issues/8198) +- `VariantArrayBuilder` tracks starting offsets instead of \(offset, len\) pairs [\#8192](https://github.com/apache/arrow-rs/issues/8192) +- Rework `ValueBuilder` API to work with `ParentState` for reliable nested rollbacks [\#8188](https://github.com/apache/arrow-rs/issues/8188) +- \[Variant\] Rename `ValueBuffer` as `ValueBuilder` [\#8186](https://github.com/apache/arrow-rs/issues/8186) +- \[Variant\] Refactor `ParentState` to track and rollback state on behalf of its owning builder [\#8182](https://github.com/apache/arrow-rs/issues/8182) +- \[Variant\] `ObjectBuilder` should detect duplicates at insertion time, not at finish [\#8180](https://github.com/apache/arrow-rs/issues/8180) +- \[Variant\] ObjectBuilder does not reliably check for duplicates [\#8170](https://github.com/apache/arrow-rs/issues/8170) +- [Variant] Support `StringView` and `LargeString` in `batch_json_string_to_variant` [\#8145](https://github.com/apache/arrow-rs/issues/8145) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Rename `batch_json_string_to_variant` and `batch_variant_to_json_string` json\_to\_variant [\#8144](https://github.com/apache/arrow-rs/issues/8144) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[avro\] Use `tempfile` crate rather than custom temporary file generator in tests [\#8143](https://github.com/apache/arrow-rs/issues/8143)
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Avro\] Use `Write` rather than `dyn Write` in Decoder [\#8142](https://github.com/apache/arrow-rs/issues/8142) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Nested builder rollback is broken [\#8136](https://github.com/apache/arrow-rs/issues/8136) +- \[Variant\] Add support for the remaining primitive types \(timestamp\_nanos/timestampntz\_nanos/uuid\) for parquet variant [\#8126](https://github.com/apache/arrow-rs/issues/8126) +- Meta: Implement missing Arrow 56.0 lint rules - Sequential workflow [\#8121](https://github.com/apache/arrow-rs/issues/8121) +- ARROW-012-015: Add linter rules for remaining Arrow 56.0 breaking changes [\#8120](https://github.com/apache/arrow-rs/issues/8120) +- ARROW-010 & ARROW-011: Add linter rules for Parquet Statistics and Metadata API removals [\#8119](https://github.com/apache/arrow-rs/issues/8119) +- ARROW-009: Add linter rules for IPC Dictionary API removals in Arrow 56.0 [\#8118](https://github.com/apache/arrow-rs/issues/8118) +- ARROW-008: Add linter rule for SerializedPageReaderState usize→u64 breaking change [\#8117](https://github.com/apache/arrow-rs/issues/8117) +- ARROW-007: Add linter rule for Schema.all\_fields\(\) removal in Arrow 56.0 [\#8116](https://github.com/apache/arrow-rs/issues/8116) +- \[Variant\] Implement `ShreddingState::AllNull` variant [\#8088](https://github.com/apache/arrow-rs/issues/8088) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Support Shredded Objects in `variant_get` [\#8083](https://github.com/apache/arrow-rs/issues/8083) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::RunEndEncoded` support for `cast_to_variant` kernel [\#8064](https://github.com/apache/arrow-rs/issues/8064) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Dictionary` support for `cast_to_variant` kernel [\#8062](https://github.com/apache/arrow-rs/issues/8062) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Struct` support for `cast_to_variant` kernel [\#8061](https://github.com/apache/arrow-rs/issues/8061) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Decimal32/Decimal64/Decimal128/Decimal256` support for `cast_to_variant` kernel [\#8059](https://github.com/apache/arrow-rs/issues/8059) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Timestamp(..)` support for `cast_to_variant` kernel [\#8058](https://github.com/apache/arrow-rs/issues/8058) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Float16` support for `cast_to_variant` kernel [\#8057](https://github.com/apache/arrow-rs/issues/8057) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Interval` support for `cast_to_variant` kernel [\#8056](https://github.com/apache/arrow-rs/issues/8056) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Time32/Time64` support for `cast_to_variant` kernel [\#8055](https://github.com/apache/arrow-rs/issues/8055) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Date32 / DataType::Date64` support for `cast_to_variant` kernel [\#8054](https://github.com/apache/arrow-rs/issues/8054)
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Null` support for `cast_to_variant` kernel [\#8053](https://github.com/apache/arrow-rs/issues/8053) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Boolean` support for `cast_to_variant` kernel [\#8052](https://github.com/apache/arrow-rs/issues/8052) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::FixedSizeBinary` support for `cast_to_variant` kernel [\#8051](https://github.com/apache/arrow-rs/issues/8051) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Binary/LargeBinary/BinaryView` support for `cast_to_variant` kernel [\#8050](https://github.com/apache/arrow-rs/issues/8050) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Utf8/LargeUtf8/Utf8View` support for `cast_to_variant` kernel [\#8049](https://github.com/apache/arrow-rs/issues/8049) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Implement `cast_to_variant` kernel [\#8043](https://github.com/apache/arrow-rs/issues/8043) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Support `variant_get` kernel for shredded variants [\#7941](https://github.com/apache/arrow-rs/issues/7941) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add test for casting `Decimal128` \(`i128::MIN` and `i128::MAX`\) to `f64` with overflow handling [\#7939](https://github.com/apache/arrow-rs/issues/7939) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] **Merged pull requests:** -- Add benchmark for converting StringViewArray with mixed short and long strings [\#8015](https://github.com/apache/arrow-rs/pull/8015) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ding-young](https://github.com/ding-young)) -- \[Variant\] impl FromIterator for VariantPath [\#8011](https://github.com/apache/arrow-rs/pull/8011) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([sdf-jkl](https://github.com/sdf-jkl)) -- Create empty buffer for a buffer specified in the C Data Interface with length zero [\#8009](https://github.com/apache/arrow-rs/pull/8009) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- bench: add benchmark for converting list and sliced list to row format [\#8008](https://github.com/apache/arrow-rs/pull/8008) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) -- bench: benchmark interleave structs [\#8007](https://github.com/apache/arrow-rs/pull/8007) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) -- \[Parquet\] Allow writing compatible DictionaryArrays to parquet writer [\#8005](https://github.com/apache/arrow-rs/pull/8005) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([albertlockett](https://github.com/albertlockett)) -- doc: remove outdated info from CONTRIBUTING doc in project root dir. 
[\#7998](https://github.com/apache/arrow-rs/pull/7998) ([sonhmai](https://github.com/sonhmai)) -- perf: only encode actual list values in `RowConverter` \(16-26 times faster for small sliced list\) [\#7996](https://github.com/apache/arrow-rs/pull/7996) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) -- test: add tests for converting sliced list to row based [\#7994](https://github.com/apache/arrow-rs/pull/7994) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) -- perf: Improve `interleave` performance for struct \(3-6 times faster\) [\#7991](https://github.com/apache/arrow-rs/pull/7991) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) -- \[Variant\] Avoid extra buffer allocation in ListBuilder [\#7987](https://github.com/apache/arrow-rs/pull/7987) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([klion26](https://github.com/klion26)) -- Implement full-range `i256::to_f64` to eliminate ±∞ saturation for Decimal256 → Float64 casts [\#7986](https://github.com/apache/arrow-rs/pull/7986) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kosiew](https://github.com/kosiew)) -- Minor: Restore warning comment on Int96 statistics read [\#7975](https://github.com/apache/arrow-rs/pull/7975) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Add additional integration tests to arrow-avro [\#7974](https://github.com/apache/arrow-rs/pull/7974) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nathaniel-d-ef](https://github.com/nathaniel-d-ef)) -- Perf: optimize actual\_buffer\_size to use only data buffer capacity for coalesce [\#7967](https://github.com/apache/arrow-rs/pull/7967) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) -- Implement Improved arrow-avro Reader Zero-Byte Record Handling [\#7966](https://github.com/apache/arrow-rs/pull/7966) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) -- Perf: improve sort via `partition_validity` to use fast path for bit map scan \(up to 30% faster\) [\#7962](https://github.com/apache/arrow-rs/pull/7962) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) -- \[Variant\] Revisit VariantMetadata and Object equality [\#7961](https://github.com/apache/arrow-rs/pull/7961) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- \[Variant\] Add ListBuilder::with\_value for convenience [\#7959](https://github.com/apache/arrow-rs/pull/7959) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([codephage2020](https://github.com/codephage2020)) -- \[Variant\] remove VariantMetadata::dictionary\_size [\#7958](https://github.com/apache/arrow-rs/pull/7958) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([codephage2020](https://github.com/codephage2020)) -- \[Variant\] VariantMetadata is allowed to contain the empty string [\#7956](https://github.com/apache/arrow-rs/pull/7956) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) -- Add arrow-avro support for Impala Nullability [\#7954](https://github.com/apache/arrow-rs/pull/7954) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([veronica-m-ef](https://github.com/veronica-m-ef)) -- \[Test\] Add tests for VariantList equality [\#7953](https://github.com/apache/arrow-rs/pull/7953) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- \[Variant\] Add ObjectBuilder::with\_field for convenience [\#7950](https://github.com/apache/arrow-rs/pull/7950) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- \[Variant\] Adding code to store metadata and value references in VariantArray [\#7945](https://github.com/apache/arrow-rs/pull/7945) ([abacef](https://github.com/abacef)) -- \[Variant\] Add `variant_kernels` benchmark [\#7944](https://github.com/apache/arrow-rs/pull/7944) ([alamb](https://github.com/alamb)) -- \[Variant\] Impl `PartialEq` for VariantObject [\#7943](https://github.com/apache/arrow-rs/pull/7943) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- \[Variant\] Add documentation, tests and cleaner api for Variant::get\_path [\#7942](https://github.com/apache/arrow-rs/pull/7942) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- arrow-ipc: Remove all abilities to preserve dict IDs [\#7940](https://github.com/apache/arrow-rs/pull/7940) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([brancz](https://github.com/brancz)) -- Optimize partition\_validity function used in sort kernels [\#7937](https://github.com/apache/arrow-rs/pull/7937) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) -- \[Variant\] Avoid extra allocation in object builder [\#7935](https://github.com/apache/arrow-rs/pull/7935) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([klion26](https://github.com/klion26)) -- \[Variant\] Avoid collecting offset iterator [\#7934](https://github.com/apache/arrow-rs/pull/7934) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([codephage2020](https://github.com/codephage2020)) -- Minor: Support BinaryView and StringView builders in `make_builder` [\#7931](https://github.com/apache/arrow-rs/pull/7931) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kylebarron](https://github.com/kylebarron)) -- chore: bump MSRV to 1.84 [\#7926](https://github.com/apache/arrow-rs/pull/7926) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([mbrobbel](https://github.com/mbrobbel)) -- Update bzip2 requirement from 0.4.4 to 0.6.0 [\#7924](https://github.com/apache/arrow-rs/pull/7924) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) -- \[Variant\] Reserve capacity beforehand during large object building [\#7922](https://github.com/apache/arrow-rs/pull/7922) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- \[Variant\] Add `variant_get` compute kernel [\#7919](https://github.com/apache/arrow-rs/pull/7919) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Samyak2](https://github.com/Samyak2)) -- Improve memory usage for `arrow-row -> String/BinaryView` when utf8 validation disabled 
[\#7917](https://github.com/apache/arrow-rs/pull/7917) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ding-young](https://github.com/ding-young)) -- Restructure compare\_greater function used in parquet statistics for better performance [\#7916](https://github.com/apache/arrow-rs/pull/7916) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jhorstmann](https://github.com/jhorstmann)) -- \[Variant\] Support appending complex variants in `VariantBuilder` [\#7914](https://github.com/apache/arrow-rs/pull/7914) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- \[Variant\] Add `VariantBuilder::new_with_buffers` to write to existing buffers [\#7912](https://github.com/apache/arrow-rs/pull/7912) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Convert JSON to VariantArray without copying \(8 - 32% faster\) [\#7911](https://github.com/apache/arrow-rs/pull/7911) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- \[Variant\] Use simdutf8 for UTF-8 validation [\#7908](https://github.com/apache/arrow-rs/pull/7908) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([codephage2020](https://github.com/codephage2020)) -- \[Variant\] Avoid superflous validation checks [\#7906](https://github.com/apache/arrow-rs/pull/7906) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- Add `VariantArray` and `VariantArrayBuilder` for constructing Arrow Arrays of Variants [\#7905](https://github.com/apache/arrow-rs/pull/7905) ([alamb](https://github.com/alamb)) -- Update sysinfo requirement from 0.35.0 to 0.36.0 [\#7904](https://github.com/apache/arrow-rs/pull/7904) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) -- Fix current CI failure [\#7898](https://github.com/apache/arrow-rs/pull/7898) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Remove redundant is\_err checks in Variant tests [\#7897](https://github.com/apache/arrow-rs/pull/7897) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) -- \[Variant\] test: add variant object tests with different sizes [\#7896](https://github.com/apache/arrow-rs/pull/7896) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([odysa](https://github.com/odysa)) -- \[Variant\] Define basic convenience methods for variant pathing [\#7894](https://github.com/apache/arrow-rs/pull/7894) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) -- fix: `view_types` benchmark slice should follow by correct len array [\#7892](https://github.com/apache/arrow-rs/pull/7892) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) -- Add arrow-avro support for bzip2 and xz compression [\#7890](https://github.com/apache/arrow-rs/pull/7890) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) -- Add arrow-avro support for Duration type and minor fixes for UUID decoding [\#7889](https://github.com/apache/arrow-rs/pull/7889) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) -- 
\[Variant\] Reduce variant-related struct sizes [\#7888](https://github.com/apache/arrow-rs/pull/7888) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) -- Fix panic on lossy decimal to float casting: round to saturation for overflows [\#7887](https://github.com/apache/arrow-rs/pull/7887) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kosiew](https://github.com/kosiew)) -- Add tests for invalid variant metadata and value [\#7885](https://github.com/apache/arrow-rs/pull/7885) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) -- \[Variant\] Introduce parquet-variant-compute crate to transform batches of JSON strings to and from Variants [\#7884](https://github.com/apache/arrow-rs/pull/7884) ([harshmotw-db](https://github.com/harshmotw-db)) -- feat: support `MapArray` in lexsort [\#7882](https://github.com/apache/arrow-rs/pull/7882) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) -- fix: mark `DataType::Map` as unsupported in `RowConverter` [\#7880](https://github.com/apache/arrow-rs/pull/7880) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) -- \[Variant\] Speedup validation [\#7878](https://github.com/apache/arrow-rs/pull/7878) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- benchmark: Add StringViewArray gc benchmark with not null cases [\#7877](https://github.com/apache/arrow-rs/pull/7877) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) -- \[ARROW-RS-7820\]\[Variant\] Add tests for large variant lists [\#7876](https://github.com/apache/arrow-rs/pull/7876) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([klion26](https://github.com/klion26)) -- fix: Incorrect inlined string view comparison after Add prefix compar… [\#7875](https://github.com/apache/arrow-rs/pull/7875) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) -- perf: speed up StringViewArray gc 1.4 ~5.x faster [\#7873](https://github.com/apache/arrow-rs/pull/7873) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) -- \[Variant\] Remove superflous validate call and rename methods [\#7871](https://github.com/apache/arrow-rs/pull/7871) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- Benchmark: Add rich testing cases for sort string\(utf8\) [\#7867](https://github.com/apache/arrow-rs/pull/7867) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) -- chore: update link for `row_filter.rs` [\#7866](https://github.com/apache/arrow-rs/pull/7866) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([haohuaijin](https://github.com/haohuaijin)) -- \[Variant\] List and object builders have no effect until finalized [\#7865](https://github.com/apache/arrow-rs/pull/7865) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) -- Added number to string benches for json\_writer [\#7864](https://github.com/apache/arrow-rs/pull/7864) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([abacef](https://github.com/abacef)) -- \[Variant\] Introduce `parquet-variant-json` crate 
[\#7862](https://github.com/apache/arrow-rs/pull/7862) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- \[Variant\] Remove dead code, add comments [\#7861](https://github.com/apache/arrow-rs/pull/7861) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Speedup sorting for inline views: 1.4x - 1.7x improvement [\#7856](https://github.com/apache/arrow-rs/pull/7856) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) -- Fix union slice logical\_nulls length [\#7855](https://github.com/apache/arrow-rs/pull/7855) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([codephage2020](https://github.com/codephage2020)) -- Add `get_ref/get_mut` to JSON Writer [\#7854](https://github.com/apache/arrow-rs/pull/7854) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([cetra3](https://github.com/cetra3)) -- \[Minor\] Add Benchmark for RowConverter::append [\#7853](https://github.com/apache/arrow-rs/pull/7853) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) -- Add Enum type support to arrow-avro and Minor Decimal type fix [\#7852](https://github.com/apache/arrow-rs/pull/7852) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) -- CSV error message has values transposed [\#7851](https://github.com/apache/arrow-rs/pull/7851) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Omega359](https://github.com/Omega359)) -- \[Variant\] Fuzz testing and benchmarks for vaildation [\#7849](https://github.com/apache/arrow-rs/pull/7849) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([carpecodeum](https://github.com/carpecodeum)) -- \[Variant\] Follow up nits and uncomment test cases [\#7846](https://github.com/apache/arrow-rs/pull/7846) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- \[Variant\] Make sure ObjectBuilder and ListBuilder to be finalized before its parent builder [\#7843](https://github.com/apache/arrow-rs/pull/7843) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) -- Add decimal32 and decimal64 support to Parquet, JSON and CSV readers and writers [\#7841](https://github.com/apache/arrow-rs/pull/7841) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([CurtHagenlocher](https://github.com/CurtHagenlocher)) -- Implement arrow-avro Reader and ReaderBuilder [\#7834](https://github.com/apache/arrow-rs/pull/7834) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) -- \[Variant\] Support creating sorted dictionaries [\#7833](https://github.com/apache/arrow-rs/pull/7833) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- Add Decimal type support to arrow-avro [\#7832](https://github.com/apache/arrow-rs/pull/7832) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) -- Allow concating struct arrays with no fields [\#7829](https://github.com/apache/arrow-rs/pull/7829) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([AdamGS](https://github.com/AdamGS)) -- Add features to configure flate2 
[\#7827](https://github.com/apache/arrow-rs/pull/7827) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zeevm](https://github.com/zeevm)) -- make builder public under experimental [\#7825](https://github.com/apache/arrow-rs/pull/7825) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) -- Improvements for parquet writing performance \(25%-44%\) [\#7824](https://github.com/apache/arrow-rs/pull/7824) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) -- Use in-memory buffer for arrow\_writer benchmark [\#7823](https://github.com/apache/arrow-rs/pull/7823) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jhorstmann](https://github.com/jhorstmann)) -- \[Variant\] impl \[Try\]From for VariantDecimalXX types [\#7809](https://github.com/apache/arrow-rs/pull/7809) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) -- \[Variant\] Speedup `ObjectBuilder` \(62x faster\) [\#7808](https://github.com/apache/arrow-rs/pull/7808) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- \[VARIANT\] Support both fallible and infallible access to variants [\#7807](https://github.com/apache/arrow-rs/pull/7807) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) -- Minor: fix clippy in parquet-variant after logical conflict [\#7803](https://github.com/apache/arrow-rs/pull/7803) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- \[Variant\] Add flag in `ObjectBuilder` to control validation behavior on duplicate field write [\#7801](https://github.com/apache/arrow-rs/pull/7801) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([micoo227](https://github.com/micoo227)) -- Fix clippy for Rust 1.88 release [\#7797](https://github.com/apache/arrow-rs/pull/7797) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) -- \[Variant\] Simplify `Builder` buffer operations [\#7795](https://github.com/apache/arrow-rs/pull/7795) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- fix: Change panic to error in`take` kernel for StringArrary/BinaryArray on overflow [\#7793](https://github.com/apache/arrow-rs/pull/7793) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([chenkovsky](https://github.com/chenkovsky)) -- Update base64 requirement from 0.21 to 0.22 [\#7791](https://github.com/apache/arrow-rs/pull/7791) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) -- Fix RowConverter when FixedSizeList is not the last [\#7789](https://github.com/apache/arrow-rs/pull/7789) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) -- Add schema with only primitive arrays to `coalesce_kernel` benchmark [\#7788](https://github.com/apache/arrow-rs/pull/7788) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- Add sort\_kernel benchmark for StringViewArray case 
[\#7787](https://github.com/apache/arrow-rs/pull/7787) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) -- \[Variant\] Check pending before `VariantObject::insert` [\#7786](https://github.com/apache/arrow-rs/pull/7786) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- \[VARIANT\] impl Display for VariantDecimalXX [\#7785](https://github.com/apache/arrow-rs/pull/7785) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([scovich](https://github.com/scovich)) -- \[VARIANT\] Add support for the json\_to\_variant API [\#7783](https://github.com/apache/arrow-rs/pull/7783) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([harshmotw-db](https://github.com/harshmotw-db)) -- \[Variant\] Consolidate examples for json writing [\#7782](https://github.com/apache/arrow-rs/pull/7782) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Add benchmark for about view array slice [\#7781](https://github.com/apache/arrow-rs/pull/7781) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ctsk](https://github.com/ctsk)) -- \[Variant\] Add negative tests for reading invalid primitive variant values [\#7779](https://github.com/apache/arrow-rs/pull/7779) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([superserious-dev](https://github.com/superserious-dev)) -- \[Variant\] Support creating nested objects and object with lists [\#7778](https://github.com/apache/arrow-rs/pull/7778) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- \[VARIANT\] Validate precision in VariantDecimalXX structs and add missing tests [\#7776](https://github.com/apache/arrow-rs/pull/7776) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) -- Add tests for `BatchCoalescer::push_batch_with_filter`, fix bug [\#7774](https://github.com/apache/arrow-rs/pull/7774) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- \[Variant\] Minor: make fields in `VariantDecimal*` private, add examples [\#7770](https://github.com/apache/arrow-rs/pull/7770) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Extend the fast path in GenericByteViewArray::is\_eq for comparing against empty strings [\#7767](https://github.com/apache/arrow-rs/pull/7767) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) -- \[Variant\] Improve getter API for `VariantList` and `VariantObject` [\#7757](https://github.com/apache/arrow-rs/pull/7757) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- \[Variant\] Add Variant::as\_object and Variant::as\_list [\#7755](https://github.com/apache/arrow-rs/pull/7755) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- \[Variant\] Fix several overflow panic risks for 32-bit arch [\#7752](https://github.com/apache/arrow-rs/pull/7752) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) -- Add testing section to pull request template [\#7749](https://github.com/apache/arrow-rs/pull/7749) ([alamb](https://github.com/alamb)) -- 
Perf: Add prefix compare for inlined compare and change use of inline\_value to inline it to a u128 [\#7748](https://github.com/apache/arrow-rs/pull/7748) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) -- Move arrow-pyarrow tests that require `pyarrow` to be installed into `arrow-pyarrow-testing` crate [\#7742](https://github.com/apache/arrow-rs/pull/7742) ([alamb](https://github.com/alamb)) -- \[Variant\] Improve write API in `Variant::Object` [\#7741](https://github.com/apache/arrow-rs/pull/7741) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- \[Variant\] Support nested lists and object lists [\#7740](https://github.com/apache/arrow-rs/pull/7740) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- feat: \[Variant\] Add Validation for Variant Deciaml [\#7738](https://github.com/apache/arrow-rs/pull/7738) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Weijun-H](https://github.com/Weijun-H)) -- Add fallible versions of temporal functions that may panic [\#7737](https://github.com/apache/arrow-rs/pull/7737) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([adriangb](https://github.com/adriangb)) -- fix: Implement support for appending Object and List variants in VariantBuilder [\#7735](https://github.com/apache/arrow-rs/pull/7735) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Weijun-H](https://github.com/Weijun-H)) -- parquet\_derive: update in working example for ParquetRecordWriter [\#7733](https://github.com/apache/arrow-rs/pull/7733) ([LanHikari22](https://github.com/LanHikari22)) -- Perf: Optimize comparison kernels for inlined views [\#7731](https://github.com/apache/arrow-rs/pull/7731) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) -- arrow-row: Refactor arrow-row REE roundtrip tests [\#7729](https://github.com/apache/arrow-rs/pull/7729) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) -- arrow-array: Implement PartialEq for RunArray [\#7727](https://github.com/apache/arrow-rs/pull/7727) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) -- fix: Do not add null buffer for `NullArray` in MutableArrayData [\#7726](https://github.com/apache/arrow-rs/pull/7726) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) -- Allow per-column parquet dictionary page size limit [\#7724](https://github.com/apache/arrow-rs/pull/7724) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) -- fix JSON decoder error checking for UTF16 / surrogate parsing panic [\#7721](https://github.com/apache/arrow-rs/pull/7721) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nicklan](https://github.com/nicklan)) -- \[Variant\] Use `BTreeMap` for `VariantBuilder.dict` and `ObjectBuilder.fields` to maintain invariants upon entry writes [\#7720](https://github.com/apache/arrow-rs/pull/7720) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- Introduce `MAX_INLINE_VIEW_LEN` constant for string/byte views [\#7719](https://github.com/apache/arrow-rs/pull/7719) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([alamb](https://github.com/alamb)) -- \[Variant\] Introduce new type over &str for ShortString [\#7718](https://github.com/apache/arrow-rs/pull/7718) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- Split out variant code into several new sub-modules [\#7717](https://github.com/apache/arrow-rs/pull/7717) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) -- add `garbage_collect_dictionary` to `arrow-select` [\#7716](https://github.com/apache/arrow-rs/pull/7716) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([davidhewitt](https://github.com/davidhewitt)) -- Support write to buffer api for SerializedFileWriter [\#7714](https://github.com/apache/arrow-rs/pull/7714) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) -- Support `FixedSizeList` RowConverter [\#7705](https://github.com/apache/arrow-rs/pull/7705) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) -- Make variant iterators safely infallible [\#7704](https://github.com/apache/arrow-rs/pull/7704) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) -- Speedup `interleave_views` \(4-7x faster\) [\#7695](https://github.com/apache/arrow-rs/pull/7695) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) -- Define a "arrow-pyrarrow" crate to implement the "pyarrow" feature. [\#7694](https://github.com/apache/arrow-rs/pull/7694) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brunal](https://github.com/brunal)) -- feat: add constructor to efficiently upgrade dict key type to remaining builders [\#7689](https://github.com/apache/arrow-rs/pull/7689) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([albertlockett](https://github.com/albertlockett)) -- Document REE row format and add some more tests [\#7680](https://github.com/apache/arrow-rs/pull/7680) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- feat: add min max aggregate support for FixedSizeBinary [\#7675](https://github.com/apache/arrow-rs/pull/7675) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexwilcoxson-rel](https://github.com/alexwilcoxson-rel)) -- arrow-data: Add REE support for `build_extend` and `build_extend_nulls` [\#7671](https://github.com/apache/arrow-rs/pull/7671) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) -- Variant: Write Variant Values as JSON [\#7670](https://github.com/apache/arrow-rs/pull/7670) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([carpecodeum](https://github.com/carpecodeum)) -- Remove `lazy_static` dependency [\#7669](https://github.com/apache/arrow-rs/pull/7669) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Expyron](https://github.com/Expyron)) -- Finish implementing Variant::Object and Variant::List [\#7666](https://github.com/apache/arrow-rs/pull/7666) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) -- Add `RecordBatch::schema_metadata_mut` and `Field::metadata_mut` [\#7664](https://github.com/apache/arrow-rs/pull/7664) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([emilk](https://github.com/emilk)) -- \[Variant\] Simplify creation of Variants from metadata and value 
[\#7663](https://github.com/apache/arrow-rs/pull/7663) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- chore: group prost dependabot updates [\#7659](https://github.com/apache/arrow-rs/pull/7659) ([mbrobbel](https://github.com/mbrobbel)) -- Initial Builder API for Creating Variant Values [\#7653](https://github.com/apache/arrow-rs/pull/7653) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([PinkCrow007](https://github.com/PinkCrow007)) -- Add `BatchCoalescer::push_filtered_batch` and docs [\#7652](https://github.com/apache/arrow-rs/pull/7652) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- Optimize coalesce kernel for StringView \(10-50% faster\) [\#7650](https://github.com/apache/arrow-rs/pull/7650) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- arrow-row: Add support for REE [\#7649](https://github.com/apache/arrow-rs/pull/7649) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) -- Use approximate comparisons for pow tests [\#7646](https://github.com/apache/arrow-rs/pull/7646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([adamreeve](https://github.com/adamreeve)) -- \[Variant\] Implement read support for remaining primitive types [\#7644](https://github.com/apache/arrow-rs/pull/7644) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([superserious-dev](https://github.com/superserious-dev)) -- Add `pretty_format_batches_with_schema` function [\#7642](https://github.com/apache/arrow-rs/pull/7642) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lewiszlw](https://github.com/lewiszlw)) -- Deprecate old Parquet page index parsing functions [\#7640](https://github.com/apache/arrow-rs/pull/7640) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) -- Update FlightSQL `GetDbSchemas` and `GetTables` schemas to fully match the protocol [\#7638](https://github.com/apache/arrow-rs/pull/7638) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([sgrebnov](https://github.com/sgrebnov)) -- Minor: Remove outdated FIXME from `ParquetMetaDataReader` [\#7635](https://github.com/apache/arrow-rs/pull/7635) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) -- Fix the error info of `StructArray::try_new` [\#7634](https://github.com/apache/arrow-rs/pull/7634) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xudong963](https://github.com/xudong963)) -- Fix reading encrypted Parquet pages when using the page index [\#7633](https://github.com/apache/arrow-rs/pull/7633) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([adamreeve](https://github.com/adamreeve)) -- \[Variant\] Add commented out primitive test casees [\#7631](https://github.com/apache/arrow-rs/pull/7631) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\] Enhance the variant fuz test to cover time/timestamp/uuid primitive type [\#8200](https://github.com/apache/arrow-rs/pull/8200) ([klion26](https://github.com/klion26)) +- \[Variant\] VariantArrayBuilder tracks only offsets [\#8193](https://github.com/apache/arrow-rs/pull/8193) ([scovich](https://github.com/scovich)) +- \[Variant\] Caller provides ParentState to ValueBuilder methods 
[\#8189](https://github.com/apache/arrow-rs/pull/8189) ([scovich](https://github.com/scovich)) +- \[Variant\] Rename ValueBuffer as ValueBuilder [\#8187](https://github.com/apache/arrow-rs/pull/8187) ([scovich](https://github.com/scovich)) +- \[Variant\] ParentState handles finish/rollback for builders [\#8185](https://github.com/apache/arrow-rs/pull/8185) ([scovich](https://github.com/scovich)) +- \[Variant\]: Implement `DataType::RunEndEncoded` support for `cast_to_variant` kernel [\#8174](https://github.com/apache/arrow-rs/pull/8174) ([liamzwbao](https://github.com/liamzwbao)) +- \[Variant\]: Implement `DataType::Dictionary` support for `cast_to_variant` kernel [\#8173](https://github.com/apache/arrow-rs/pull/8173) ([liamzwbao](https://github.com/liamzwbao)) +- Implement `ArrayBuilder` for `UnionBuilder` [\#8169](https://github.com/apache/arrow-rs/pull/8169) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([grtlr](https://github.com/grtlr)) +- \[Variant\] Support `LargeString` and `StringView` in `batch_json_string_to_variant` [\#8163](https://github.com/apache/arrow-rs/pull/8163) ([liamzwbao](https://github.com/liamzwbao)) +- \[Variant\] Rename `batch_json_string_to_variant` and `batch_variant_to_json_string` [\#8161](https://github.com/apache/arrow-rs/pull/8161) ([liamzwbao](https://github.com/liamzwbao)) +- \[Variant\] Add primitive type timestamp\_nanos\(with&without timezone\) and uuid [\#8149](https://github.com/apache/arrow-rs/pull/8149) ([klion26](https://github.com/klion26)) +- refactor\(avro\): Use impl Write instead of dyn Write in encoder [\#8148](https://github.com/apache/arrow-rs/pull/8148) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Xuanwo](https://github.com/Xuanwo)) +- chore: Use tempfile to replace hand-written utils functions [\#8147](https://github.com/apache/arrow-rs/pull/8147) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Xuanwo](https://github.com/Xuanwo)) +- feat: support push batch direct to completed and add biggest coalesce batch support [\#8146](https://github.com/apache/arrow-rs/pull/8146) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- \[Variant\] Add human-readable impl Debug for Variant [\#8140](https://github.com/apache/arrow-rs/pull/8140) ([scovich](https://github.com/scovich)) +- \[Variant\] Fix broken metadata builder rollback [\#8135](https://github.com/apache/arrow-rs/pull/8135) ([scovich](https://github.com/scovich)) +- \[Variant\]: Implement DataType::Interval support for cast\_to\_variant kernel [\#8125](https://github.com/apache/arrow-rs/pull/8125) ([codephage2020](https://github.com/codephage2020)) +- Add schema resolution and type promotion support to arrow-avro Decoder [\#8124](https://github.com/apache/arrow-rs/pull/8124) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Add Initial `arrow-avro` writer implementation with basic type support [\#8123](https://github.com/apache/arrow-rs/pull/8123) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- \[Variant\] Add Variant::Time primitive and cast logic [\#8114](https://github.com/apache/arrow-rs/pull/8114) ([klion26](https://github.com/klion26)) +- \[Variant\] Support Timestamp to variant for `cast_to_variant` kernel [\#8113](https://github.com/apache/arrow-rs/pull/8113) ([abacef](https://github.com/abacef)) +- Bump actions/checkout from 4 to 5 
[\#8110](https://github.com/apache/arrow-rs/pull/8110) ([dependabot[bot]](https://github.com/apps/dependabot)) +- \[Varaint\]: add `DataType::Null` support to cast\_to\_variant [\#8107](https://github.com/apache/arrow-rs/pull/8107) ([feniljain](https://github.com/feniljain)) +- \[Variant\] Adding fixed size byte array to variant and test [\#8106](https://github.com/apache/arrow-rs/pull/8106) ([abacef](https://github.com/abacef)) +- \[VARIANT\] Initial integration tests for variant reads [\#8104](https://github.com/apache/arrow-rs/pull/8104) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([carpecodeum](https://github.com/carpecodeum)) +- \[Variant\]: Implement `DataType::Decimal32/Decimal64/Decimal128/Decimal256` support for `cast_to_variant` kernel [\#8101](https://github.com/apache/arrow-rs/pull/8101) ([liamzwbao](https://github.com/liamzwbao)) +- Refactor arrow-avro `Decoder` to support partial decoding [\#8100](https://github.com/apache/arrow-rs/pull/8100) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- fix: Validate metadata len in IPC reader [\#8097](https://github.com/apache/arrow-rs/pull/8097) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([JakeDern](https://github.com/JakeDern)) +- \[parquet\] further improve logical type compatibility in ArrowWriter [\#8095](https://github.com/apache/arrow-rs/pull/8095) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([albertlockett](https://github.com/albertlockett)) +- \[Varint\] Implement ShreddingState::AllNull variant [\#8093](https://github.com/apache/arrow-rs/pull/8093) ([codephage2020](https://github.com/codephage2020)) +- \[Variant\] Minor: Add comments to tickets for follow on items [\#8092](https://github.com/apache/arrow-rs/pull/8092) ([alamb](https://github.com/alamb)) +- \[VARIANT\] Add support for DataType::Struct for cast\_to\_variant [\#8090](https://github.com/apache/arrow-rs/pull/8090) ([carpecodeum](https://github.com/carpecodeum)) +- \[VARIANT\] Add support for DataType::Utf8/LargeUtf8/Utf8View for cast\_to\_variant [\#8089](https://github.com/apache/arrow-rs/pull/8089) ([carpecodeum](https://github.com/carpecodeum)) +- \[Variant\] Implement `DataType::Boolean` support for `cast_to_variant` kernel [\#8085](https://github.com/apache/arrow-rs/pull/8085) ([sdf-jkl](https://github.com/sdf-jkl)) +- \[Variant\] Implement `DataType::{Date32,Date64}` =\> `Variant::Date` [\#8081](https://github.com/apache/arrow-rs/pull/8081) ([superserious-dev](https://github.com/superserious-dev)) +- Fix new clippy lints from Rust 1.89 [\#8078](https://github.com/apache/arrow-rs/pull/8078) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Implement ArrowSchema to AvroSchema conversion logic in arrow-avro [\#8075](https://github.com/apache/arrow-rs/pull/8075) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Implement `DataType::{Binary, LargeBinary, BinaryView}` =\> `Variant::Binary` [\#8074](https://github.com/apache/arrow-rs/pull/8074) ([superserious-dev](https://github.com/superserious-dev)) +- \[Variant\] Implement `DataType::Float16` =\> `Variant::Float` [\#8073](https://github.com/apache/arrow-rs/pull/8073) ([superserious-dev](https://github.com/superserious-dev)) +- create PageIndexPolicy to allow 
optional indexes [\#8071](https://github.com/apache/arrow-rs/pull/8071) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([kczimm](https://github.com/kczimm)) +- \[Variant\] Minor: use From impl to make conversion infallable [\#8068](https://github.com/apache/arrow-rs/pull/8068) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Bump actions/download-artifact from 4 to 5 [\#8066](https://github.com/apache/arrow-rs/pull/8066) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Added arrow-avro schema resolution foundations and type promotion [\#8047](https://github.com/apache/arrow-rs/pull/8047) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Fix arrow-avro type resolver register bug [\#8046](https://github.com/apache/arrow-rs/pull/8046) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yongkyunlee](https://github.com/yongkyunlee)) +- implement `cast_to_variant` kernel to cast native types to `VariantArray` [\#8044](https://github.com/apache/arrow-rs/pull/8044) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add arrow-avro `SchemaStore` and fingerprinting [\#8039](https://github.com/apache/arrow-rs/pull/8039) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Add more benchmarks for Parquet thrift decoding [\#8037](https://github.com/apache/arrow-rs/pull/8037) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Support multi-threaded writing of Parquet files with modular encryption [\#8029](https://github.com/apache/arrow-rs/pull/8029) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([rok](https://github.com/rok)) +- Add arrow-avro Decoder Benchmarks [\#8025](https://github.com/apache/arrow-rs/pull/8025) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- feat: add method for sync Parquet reader read bloom filter [\#8024](https://github.com/apache/arrow-rs/pull/8024) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mapleFU](https://github.com/mapleFU)) +- \[Variant\] Add `variant_get` and Shredded `VariantArray` [\#8021](https://github.com/apache/arrow-rs/pull/8021) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Implement arrow-avro SchemaStore and Fingerprinting To Enable Schema Resolution [\#8006](https://github.com/apache/arrow-rs/pull/8006) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- \[Parquet\] Add tests for IO/CPU access in parquet reader [\#7971](https://github.com/apache/arrow-rs/pull/7971) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Speed up Parquet filter pushdown v4 \(Predicate evaluation cache for async\_reader\) [\#7850](https://github.com/apache/arrow-rs/pull/7850) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Implement cast and other operations on decimal32 and decimal64 [\#7815](https://github.com/apache/arrow-rs/pull/7815) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([CurtHagenlocher](https://github.com/CurtHagenlocher)) diff --git a/Cargo.toml b/Cargo.toml index 9d1ad6d03b5e..bf0efc37d30a 100644 --- a/Cargo.toml +++ b/Cargo.toml 
@@ -67,7 +67,7 @@ exclude = [ ] [workspace.package] -version = "56.0.0" +version = "56.1.0" homepage = "https://github.com/apache/arrow-rs" repository = "https://github.com/apache/arrow-rs" authors = ["Apache Arrow "] @@ -84,27 +84,27 @@ edition = "2021" rust-version = "1.84" [workspace.dependencies] -arrow = { version = "56.0.0", path = "./arrow", default-features = false } -arrow-arith = { version = "56.0.0", path = "./arrow-arith" } -arrow-array = { version = "56.0.0", path = "./arrow-array" } -arrow-buffer = { version = "56.0.0", path = "./arrow-buffer" } -arrow-cast = { version = "56.0.0", path = "./arrow-cast" } -arrow-csv = { version = "56.0.0", path = "./arrow-csv" } -arrow-data = { version = "56.0.0", path = "./arrow-data" } -arrow-ipc = { version = "56.0.0", path = "./arrow-ipc" } -arrow-json = { version = "56.0.0", path = "./arrow-json" } -arrow-ord = { version = "56.0.0", path = "./arrow-ord" } -arrow-pyarrow = { version = "56.0.0", path = "./arrow-pyarrow" } -arrow-row = { version = "56.0.0", path = "./arrow-row" } -arrow-schema = { version = "56.0.0", path = "./arrow-schema" } -arrow-select = { version = "56.0.0", path = "./arrow-select" } -arrow-string = { version = "56.0.0", path = "./arrow-string" } -parquet = { version = "56.0.0", path = "./parquet", default-features = false } +arrow = { version = "56.1.0", path = "./arrow", default-features = false } +arrow-arith = { version = "56.1.0", path = "./arrow-arith" } +arrow-array = { version = "56.1.0", path = "./arrow-array" } +arrow-buffer = { version = "56.1.0", path = "./arrow-buffer" } +arrow-cast = { version = "56.1.0", path = "./arrow-cast" } +arrow-csv = { version = "56.1.0", path = "./arrow-csv" } +arrow-data = { version = "56.1.0", path = "./arrow-data" } +arrow-ipc = { version = "56.1.0", path = "./arrow-ipc" } +arrow-json = { version = "56.1.0", path = "./arrow-json" } +arrow-ord = { version = "56.1.0", path = "./arrow-ord" } +arrow-pyarrow = { version = "56.1.0", path = "./arrow-pyarrow" } +arrow-row = { version = "56.1.0", path = "./arrow-row" } +arrow-schema = { version = "56.1.0", path = "./arrow-schema" } +arrow-select = { version = "56.1.0", path = "./arrow-select" } +arrow-string = { version = "56.1.0", path = "./arrow-string" } +parquet = { version = "56.1.0", path = "./parquet", default-features = false } # These crates have not yet been released and thus do not use the workspace version parquet-variant = { version = "0.1.0", path = "./parquet-variant" } parquet-variant-json = { version = "0.1.0", path = "./parquet-variant-json" } -parquet-variant-compute = { version = "0.1.0", path = "./parquet-variant-json" } +parquet-variant-compute = { version = "0.1.0", path = "./parquet-variant-compute" } chrono = { version = "0.4.40", default-features = false, features = ["clock"] } diff --git a/arrow-array/Cargo.toml b/arrow-array/Cargo.toml index 8ebe21c70772..9fffe3b6bbe2 100644 --- a/arrow-array/Cargo.toml +++ b/arrow-array/Cargo.toml @@ -46,7 +46,7 @@ chrono = { workspace = true } chrono-tz = { version = "0.10", optional = true } num = { version = "0.4.1", default-features = false, features = ["std"] } half = { version = "2.1", default-features = false, features = ["num-traits"] } -hashbrown = { version = "0.15.1", default-features = false } +hashbrown = { version = "0.16.0", default-features = false } [package.metadata.docs.rs] all-features = true diff --git a/arrow-array/src/array/boolean_array.rs b/arrow-array/src/array/boolean_array.rs index fcebf5a0f718..fe7ad85b7a05 100644 --- 
a/arrow-array/src/array/boolean_array.rs +++ b/arrow-array/src/array/boolean_array.rs @@ -178,6 +178,9 @@ impl BooleanArray { /// Returns the boolean value at index `i`. /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Safety /// This doesn't check bounds, the caller must ensure that index < self.len() pub unsafe fn value_unchecked(&self, i: usize) -> bool { @@ -185,6 +188,10 @@ impl BooleanArray { } /// Returns the boolean value at index `i`. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Panics /// Panics if index `i` is out of bounds pub fn value(&self, i: usize) -> bool { diff --git a/arrow-array/src/array/byte_array.rs b/arrow-array/src/array/byte_array.rs index 192c9654b055..2ff9e9f4f658 100644 --- a/arrow-array/src/array/byte_array.rs +++ b/arrow-array/src/array/byte_array.rs @@ -276,6 +276,10 @@ impl GenericByteArray { } /// Returns the element at index `i` + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Safety /// Caller is responsible for ensuring that the index is within the bounds of the array pub unsafe fn value_unchecked(&self, i: usize) -> &T::Native { @@ -304,6 +308,10 @@ impl GenericByteArray { } /// Returns the element at index `i` + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. + /// /// # Panics /// Panics if index `i` is out of bounds. pub fn value(&self, i: usize) -> &T::Native { diff --git a/arrow-array/src/array/byte_view_array.rs b/arrow-array/src/array/byte_view_array.rs index 43ff3f76369f..7c8993d6028e 100644 --- a/arrow-array/src/array/byte_view_array.rs +++ b/arrow-array/src/array/byte_view_array.rs @@ -296,6 +296,10 @@ impl GenericByteViewArray { } /// Returns the element at index `i` + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. + /// /// # Panics /// Panics if index `i` is out of bounds. pub fn value(&self, i: usize) -> &T::Native { @@ -312,6 +316,9 @@ impl GenericByteViewArray { /// Returns the element at index `i` without bounds checking /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Safety /// /// Caller is responsible for ensuring that the index is within the bounds diff --git a/arrow-array/src/array/fixed_size_binary_array.rs b/arrow-array/src/array/fixed_size_binary_array.rs index 55973a58f2cb..76d9db04704e 100644 --- a/arrow-array/src/array/fixed_size_binary_array.rs +++ b/arrow-array/src/array/fixed_size_binary_array.rs @@ -135,6 +135,10 @@ impl FixedSizeBinaryArray { } /// Returns the element at index `i` as a byte slice. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. + /// /// # Panics /// Panics if index `i` is out of bounds. pub fn value(&self, i: usize) -> &[u8] { @@ -155,8 +159,14 @@ impl FixedSizeBinaryArray { } /// Returns the element at index `i` as a byte slice. 
+ /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Safety - /// Caller is responsible for ensuring that the index is within the bounds of the array + /// + /// Caller is responsible for ensuring that the index is within the bounds + /// of the array pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { let offset = i + self.offset(); let pos = self.value_offset_at(offset); diff --git a/arrow-array/src/array/fixed_size_list_array.rs b/arrow-array/src/array/fixed_size_list_array.rs index f807cc88fbca..4a338591e5aa 100644 --- a/arrow-array/src/array/fixed_size_list_array.rs +++ b/arrow-array/src/array/fixed_size_list_array.rs @@ -243,6 +243,12 @@ impl FixedSizeListArray { } /// Returns ith value of this list array. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. + /// + /// # Panics + /// Panics if index `i` is out of bounds pub fn value(&self, i: usize) -> ArrayRef { self.values .slice(self.value_offset_at(i), self.value_length() as usize) diff --git a/arrow-array/src/array/list_array.rs b/arrow-array/src/array/list_array.rs index 832a1c0a9ad8..8836b5b0f73d 100644 --- a/arrow-array/src/array/list_array.rs +++ b/arrow-array/src/array/list_array.rs @@ -327,6 +327,10 @@ impl GenericListArray { } /// Returns ith value of this list array. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Safety /// Caller must ensure that the index is within the array bounds pub unsafe fn value_unchecked(&self, i: usize) -> ArrayRef { @@ -336,6 +340,12 @@ impl GenericListArray { } /// Returns ith value of this list array. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. + /// + /// # Panics + /// Panics if index `i` is out of bounds pub fn value(&self, i: usize) -> ArrayRef { let end = self.value_offsets()[i + 1].as_usize(); let start = self.value_offsets()[i].as_usize(); diff --git a/arrow-array/src/array/list_view_array.rs b/arrow-array/src/array/list_view_array.rs index a239ea1e5e73..7d66d10d263c 100644 --- a/arrow-array/src/array/list_view_array.rs +++ b/arrow-array/src/array/list_view_array.rs @@ -283,6 +283,10 @@ impl GenericListViewArray { } /// Returns ith value of this list view array. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Safety /// Caller must ensure that the index is within the array bounds pub unsafe fn value_unchecked(&self, i: usize) -> ArrayRef { @@ -292,6 +296,10 @@ impl GenericListViewArray { } /// Returns ith value of this list view array. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. + /// /// # Panics /// Panics if the index is out of bounds pub fn value(&self, i: usize) -> ArrayRef { diff --git a/arrow-array/src/array/map_array.rs b/arrow-array/src/array/map_array.rs index 18a7c491aa16..9a1e04c7f1c0 100644 --- a/arrow-array/src/array/map_array.rs +++ b/arrow-array/src/array/map_array.rs @@ -185,6 +185,9 @@ impl MapArray { /// Returns ith value of this map array. 
/// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Safety /// Caller must ensure that the index is within the array bounds pub unsafe fn value_unchecked(&self, i: usize) -> StructArray { @@ -197,6 +200,12 @@ impl MapArray { /// Returns ith value of this map array. /// /// This is a [`StructArray`] containing two fields + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. + /// + /// # Panics + /// Panics if index `i` is out of bounds pub fn value(&self, i: usize) -> StructArray { let end = self.value_offsets()[i + 1] as usize; let start = self.value_offsets()[i] as usize; diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 9327668824f8..42594e7a129d 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -720,6 +720,9 @@ impl PrimitiveArray { /// Returns the primitive value at index `i`. /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Safety /// /// caller must ensure that the passed in offset is less than the array len() @@ -729,6 +732,10 @@ impl PrimitiveArray { } /// Returns the primitive value at index `i`. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Panics /// Panics if index `i` is out of bounds #[inline] @@ -1235,6 +1242,8 @@ where /// /// If a data type cannot be converted to `NaiveDateTime`, a `None` is returned. /// A valid value is expected, thus the user should first check for validity. 
+ /// + /// See notes on [`PrimitiveArray::value`] regarding nulls and panics pub fn value_as_datetime(&self, i: usize) -> Option { as_datetime::(i64::from(self.value(i))) } @@ -1243,6 +1252,8 @@ where /// /// functionally it is same as `value_as_datetime`, however it adds /// the passed tz to the to-be-returned NaiveDateTime + /// + /// See notes on [`PrimitiveArray::value`] regarding nulls and panics pub fn value_as_datetime_with_tz(&self, i: usize, tz: Tz) -> Option> { as_datetime_with_timezone::(i64::from(self.value(i)), tz) } @@ -1250,6 +1261,8 @@ where /// Returns value as a chrono `NaiveDate` by using `Self::datetime()` /// /// If a data type cannot be converted to `NaiveDate`, a `None` is returned + /// + /// See notes on [`PrimitiveArray::value`] regarding nulls and panics pub fn value_as_date(&self, i: usize) -> Option { self.value_as_datetime(i).map(|datetime| datetime.date()) } @@ -1257,6 +1270,8 @@ where /// Returns a value as a chrono `NaiveTime` /// /// `Date32` and `Date64` return UTC midnight as they do not have time resolution + /// + /// See notes on [`PrimitiveArray::value`] regarding nulls and panics pub fn value_as_time(&self, i: usize) -> Option { as_time::(i64::from(self.value(i))) } @@ -1264,6 +1279,8 @@ where /// Returns a value as a chrono `Duration` /// /// If a data type cannot be converted to `Duration`, a `None` is returned + /// + /// See notes on [`PrimitiveArray::value`] regarding nulls and panics pub fn value_as_duration(&self, i: usize) -> Option { as_duration::(i64::from(self.value(i))) } diff --git a/arrow-array/src/array/union_array.rs b/arrow-array/src/array/union_array.rs index 1350cae3a38b..d105876723da 100644 --- a/arrow-array/src/array/union_array.rs +++ b/arrow-array/src/array/union_array.rs @@ -287,6 +287,10 @@ impl UnionArray { } /// Returns the array's value at index `i`. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. + /// /// # Panics /// Panics if index `i` is out of bounds pub fn value(&self, i: usize) -> ArrayRef { diff --git a/arrow-array/src/builder/fixed_size_binary_builder.rs b/arrow-array/src/builder/fixed_size_binary_builder.rs index b5f268917c92..8fd6b72c053b 100644 --- a/arrow-array/src/builder/fixed_size_binary_builder.rs +++ b/arrow-array/src/builder/fixed_size_binary_builder.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
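The doc notes added above all make the same point: `value` and `value_unchecked` do not consult the null mask, so the result for a null slot is arbitrary (though still well-defined). A minimal sketch of the guarded access pattern this implies, using `Int32Array` purely as an illustration (the example is not part of the patch):

use arrow_array::{Array, Int32Array};

/// Sums the non-null values, checking validity before calling `value`.
fn sum_valid(a: &Int32Array) -> i64 {
    let mut total = 0i64;
    for i in 0..a.len() {
        // `value(i)` is arbitrary when slot `i` is null, so check the null mask first.
        if !a.is_null(i) {
            total += i64::from(a.value(i));
        }
    }
    total
}

/// Equivalent formulation that lets the iterator surface nulls as `Option<i32>`.
fn sum_valid_iter(a: &Int32Array) -> i64 {
    a.iter().flatten().map(i64::from).sum()
}

fn main() {
    let a = Int32Array::from(vec![Some(1), None, Some(3)]);
    assert_eq!(sum_valid(&a), 4);
    assert_eq!(sum_valid_iter(&a), 4);
}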
-use crate::builder::{ArrayBuilder, UInt8BufferBuilder}; +use crate::builder::ArrayBuilder; use crate::{ArrayRef, FixedSizeBinaryArray}; use arrow_buffer::Buffer; use arrow_buffer::NullBufferBuilder; @@ -42,7 +42,7 @@ use std::sync::Arc; /// ``` #[derive(Debug)] pub struct FixedSizeBinaryBuilder { - values_builder: UInt8BufferBuilder, + values_builder: Vec, null_buffer_builder: NullBufferBuilder, value_length: i32, } @@ -61,7 +61,7 @@ impl FixedSizeBinaryBuilder { "value length ({byte_width}) of the array must >= 0" ); Self { - values_builder: UInt8BufferBuilder::new(capacity * byte_width as usize), + values_builder: Vec::with_capacity(capacity * byte_width as usize), null_buffer_builder: NullBufferBuilder::new(capacity), value_length: byte_width, } @@ -79,7 +79,7 @@ impl FixedSizeBinaryBuilder { .to_string(), )) } else { - self.values_builder.append_slice(value.as_ref()); + self.values_builder.extend_from_slice(value.as_ref()); self.null_buffer_builder.append_non_null(); Ok(()) } @@ -89,7 +89,7 @@ impl FixedSizeBinaryBuilder { #[inline] pub fn append_null(&mut self) { self.values_builder - .append_slice(&vec![0u8; self.value_length as usize][..]); + .extend(std::iter::repeat_n(0u8, self.value_length as usize)); self.null_buffer_builder.append_null(); } @@ -97,7 +97,7 @@ impl FixedSizeBinaryBuilder { #[inline] pub fn append_nulls(&mut self, n: usize) { self.values_builder - .append_slice(&vec![0u8; self.value_length as usize * n][..]); + .extend(std::iter::repeat_n(0u8, self.value_length as usize * n)); self.null_buffer_builder.append_n_nulls(n); } @@ -110,7 +110,7 @@ impl FixedSizeBinaryBuilder { pub fn finish(&mut self) -> FixedSizeBinaryArray { let array_length = self.len(); let array_data_builder = ArrayData::builder(DataType::FixedSizeBinary(self.value_length)) - .add_buffer(self.values_builder.finish()) + .add_buffer(std::mem::take(&mut self.values_builder).into()) .nulls(self.null_buffer_builder.finish()) .len(array_length); let array_data = unsafe { array_data_builder.build_unchecked() }; diff --git a/arrow-array/src/builder/fixed_size_binary_dictionary_builder.rs b/arrow-array/src/builder/fixed_size_binary_dictionary_builder.rs index 21e842723b4a..852ba680227f 100644 --- a/arrow-array/src/builder/fixed_size_binary_dictionary_builder.rs +++ b/arrow-array/src/builder/fixed_size_binary_dictionary_builder.rs @@ -311,6 +311,41 @@ where DictionaryArray::from(unsafe { builder.build_unchecked() }) } + + /// Builds the `DictionaryArray` without resetting the values builder or + /// the internal de-duplication map. + /// + /// The advantage of doing this is that the values will represent the entire + /// set of what has been built so-far by this builder and ensures + /// consistency in the assignment of keys to values across multiple calls + /// to `finish_preserve_values`. This enables ipc writers to efficiently + /// emit delta dictionaries. + /// + /// The downside to this is that building the record requires creating a + /// copy of the values, which can become slowly more expensive if the + /// dictionary grows. + /// + /// Additionally, if record batches from multiple different dictionary + /// builders for the same column are fed into a single ipc writer, beware + /// that entire dictionaries are likely to be re-sent frequently even when + /// the majority of the values are not used by the current record batch. 
+ pub fn finish_preserve_values(&mut self) -> DictionaryArray { + let values = self.values_builder.finish_cloned(); + let keys = self.keys_builder.finish(); + + let data_type = DataType::Dictionary( + Box::new(K::DATA_TYPE), + Box::new(FixedSizeBinary(self.byte_width)), + ); + + let builder = keys + .into_data() + .into_builder() + .data_type(data_type) + .child_data(vec![values.into_data()]); + + DictionaryArray::from(unsafe { builder.build_unchecked() }) + } } fn get_bytes(values: &FixedSizeBinaryBuilder, byte_width: i32, idx: usize) -> &[u8] { @@ -508,4 +543,62 @@ mod tests { ); } } + + #[test] + fn test_finish_preserve_values() { + // Create the first dictionary + let mut builder = FixedSizeBinaryDictionaryBuilder::::new(3); + builder.append_value("aaa"); + builder.append_value("bbb"); + builder.append_value("ccc"); + let dict = builder.finish_preserve_values(); + assert_eq!(dict.keys().values(), &[0, 1, 2]); + let values = dict + .downcast_dict::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + values, + vec![ + Some("aaa".as_bytes()), + Some("bbb".as_bytes()), + Some("ccc".as_bytes()) + ] + ); + + // Create a new dictionary + builder.append_value("ddd"); + builder.append_value("eee"); + let dict2 = builder.finish_preserve_values(); + + // Make sure the keys are assigned after the old ones and we have the + // right values + assert_eq!(dict2.keys().values(), &[3, 4]); + let values = dict2 + .downcast_dict::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(values, [Some("ddd".as_bytes()), Some("eee".as_bytes())]); + + // Check that we have all of the expected values + let all_values = dict2 + .values() + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + all_values, + [ + Some("aaa".as_bytes()), + Some("bbb".as_bytes()), + Some("ccc".as_bytes()), + Some("ddd".as_bytes()), + Some("eee".as_bytes()) + ] + ); + } } diff --git a/arrow-array/src/builder/generic_bytes_builder.rs b/arrow-array/src/builder/generic_bytes_builder.rs index 91ac2a483ef4..c2c743e3ab27 100644 --- a/arrow-array/src/builder/generic_bytes_builder.rs +++ b/arrow-array/src/builder/generic_bytes_builder.rs @@ -15,11 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::builder::{ArrayBuilder, BufferBuilder, UInt8BufferBuilder}; +use crate::builder::ArrayBuilder; use crate::types::{ByteArrayType, GenericBinaryType, GenericStringType}; use crate::{Array, ArrayRef, GenericByteArray, OffsetSizeTrait}; -use arrow_buffer::NullBufferBuilder; -use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; +use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer, NullBufferBuilder, ScalarBuffer}; use arrow_data::ArrayDataBuilder; use std::any::Any; use std::sync::Arc; @@ -29,8 +28,8 @@ use std::sync::Arc; /// For building strings, see docs on [`GenericStringBuilder`]. /// For building binary, see docs on [`GenericBinaryBuilder`]. pub struct GenericByteBuilder { - value_builder: UInt8BufferBuilder, - offsets_builder: BufferBuilder, + value_builder: Vec, + offsets_builder: Vec, null_buffer_builder: NullBufferBuilder, } @@ -47,10 +46,10 @@ impl GenericByteBuilder { /// - `data_capacity` is the total number of bytes of data to pre-allocate /// (for all items, not per item). 
pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { - let mut offsets_builder = BufferBuilder::::new(item_capacity + 1); - offsets_builder.append(T::Offset::from_usize(0).unwrap()); + let mut offsets_builder = Vec::with_capacity(item_capacity + 1); + offsets_builder.push(T::Offset::from_usize(0).unwrap()); Self { - value_builder: UInt8BufferBuilder::new(data_capacity), + value_builder: Vec::with_capacity(data_capacity), offsets_builder, null_buffer_builder: NullBufferBuilder::new(item_capacity), } @@ -67,8 +66,9 @@ impl GenericByteBuilder { value_buffer: MutableBuffer, null_buffer: Option, ) -> Self { - let offsets_builder = BufferBuilder::::new_from_buffer(offsets_buffer); - let value_builder = BufferBuilder::::new_from_buffer(value_buffer); + let offsets_builder: Vec = + ScalarBuffer::::from(offsets_buffer).into(); + let value_builder: Vec = ScalarBuffer::::from(value_buffer).into(); let null_buffer_builder = null_buffer .map(|buffer| NullBufferBuilder::new_from_buffer(buffer, offsets_builder.len() - 1)) @@ -103,9 +103,10 @@ impl GenericByteBuilder { /// [`BinaryArray`]: crate::BinaryArray #[inline] pub fn append_value(&mut self, value: impl AsRef) { - self.value_builder.append_slice(value.as_ref().as_ref()); + self.value_builder + .extend_from_slice(value.as_ref().as_ref()); self.null_buffer_builder.append(true); - self.offsets_builder.append(self.next_offset()); + self.offsets_builder.push(self.next_offset()); } /// Append an `Option` value into the builder. @@ -126,7 +127,7 @@ impl GenericByteBuilder { #[inline] pub fn append_null(&mut self) { self.null_buffer_builder.append(false); - self.offsets_builder.append(self.next_offset()); + self.offsets_builder.push(self.next_offset()); } /// Appends `n` `null`s into the builder. 
@@ -134,7 +135,8 @@ impl GenericByteBuilder { pub fn append_nulls(&mut self, n: usize) { self.null_buffer_builder.append_n_nulls(n); let next_offset = self.next_offset(); - self.offsets_builder.append_n(n, next_offset); + self.offsets_builder + .extend(std::iter::repeat_n(next_offset, n)); } /// Appends array values and null to this builder as is @@ -150,7 +152,7 @@ impl GenericByteBuilder { // If the offsets are contiguous, we can append them directly avoiding the need to align // for example, when the first appended array is not sliced (starts at offset 0) if self.next_offset() == offsets[0] { - self.offsets_builder.append_slice(&offsets[1..]); + self.offsets_builder.extend_from_slice(&offsets[1..]); } else { // Shifting all the offsets let shift: T::Offset = self.next_offset() - offsets[0]; @@ -164,11 +166,11 @@ impl GenericByteBuilder { intermediate.push(offset + shift) } - self.offsets_builder.append_slice(&intermediate); + self.offsets_builder.extend_from_slice(&intermediate); } // Append underlying values, starting from the first offset and ending at the last offset - self.value_builder.append_slice( + self.value_builder.extend_from_slice( &array.values().as_slice()[offsets[0].as_usize()..offsets[array.len()].as_usize()], ); @@ -184,11 +186,11 @@ impl GenericByteBuilder { let array_type = T::DATA_TYPE; let array_builder = ArrayDataBuilder::new(array_type) .len(self.len()) - .add_buffer(self.offsets_builder.finish()) - .add_buffer(self.value_builder.finish()) + .add_buffer(std::mem::take(&mut self.offsets_builder).into()) + .add_buffer(std::mem::take(&mut self.value_builder).into()) .nulls(self.null_buffer_builder.finish()); - self.offsets_builder.append(self.next_offset()); + self.offsets_builder.push(self.next_offset()); let array_data = unsafe { array_builder.build_unchecked() }; GenericByteArray::from(array_data) } @@ -340,7 +342,7 @@ pub type GenericStringBuilder = GenericByteBuilder>; impl std::fmt::Write for GenericStringBuilder { fn write_str(&mut self, s: &str) -> std::fmt::Result { - self.value_builder.append_slice(s.as_bytes()); + self.value_builder.extend_from_slice(s.as_bytes()); Ok(()) } } @@ -394,7 +396,7 @@ pub type GenericBinaryBuilder = GenericByteBuilder>; impl std::io::Write for GenericBinaryBuilder { fn write(&mut self, bs: &[u8]) -> std::io::Result { - self.value_builder.append_slice(bs); + self.value_builder.extend_from_slice(bs); Ok(bs.len()) } diff --git a/arrow-array/src/builder/generic_bytes_dictionary_builder.rs b/arrow-array/src/builder/generic_bytes_dictionary_builder.rs index a2ed91ac905d..1c7d8bedbcf1 100644 --- a/arrow-array/src/builder/generic_bytes_dictionary_builder.rs +++ b/arrow-array/src/builder/generic_bytes_dictionary_builder.rs @@ -463,6 +463,38 @@ where DictionaryArray::from(unsafe { builder.build_unchecked() }) } + /// Builds the `DictionaryArray` without resetting the values builder or + /// the internal de-duplication map. + /// + /// The advantage of doing this is that the values will represent the entire + /// set of what has been built so-far by this builder and ensures + /// consistency in the assignment of keys to values across multiple calls + /// to `finish_preserve_values`. This enables ipc writers to efficiently + /// emit delta dictionaries. + /// + /// The downside to this is that building the record requires creating a + /// copy of the values, which can become slowly more expensive if the + /// dictionary grows. 
+ /// + /// Additionally, if record batches from multiple different dictionary + /// builders for the same column are fed into a single ipc writer, beware + /// that entire dictionaries are likely to be re-sent frequently even when + /// the majority of the values are not used by the current record batch. + pub fn finish_preserve_values(&mut self) -> DictionaryArray { + let values = self.values_builder.finish_cloned(); + let keys = self.keys_builder.finish(); + + let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(T::DATA_TYPE)); + + let builder = keys + .into_data() + .into_builder() + .data_type(data_type) + .child_data(vec![values.into_data()]); + + DictionaryArray::from(unsafe { builder.build_unchecked() }) + } + /// Returns the current null buffer as a slice pub fn validity_slice(&self) -> Option<&[u8]> { self.keys_builder.validity_slice() @@ -1006,4 +1038,51 @@ mod tests { assert_eq!(values, [None, None]); } + + #[test] + fn test_finish_preserve_values() { + // Create the first dictionary + let mut builder = GenericByteDictionaryBuilder::::new(); + builder.append("a").unwrap(); + builder.append("b").unwrap(); + builder.append("c").unwrap(); + let dict = builder.finish_preserve_values(); + assert_eq!(dict.keys().values(), &[0, 1, 2]); + assert_eq!(dict.values().len(), 3); + let values = dict + .downcast_dict::>() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(values, [Some("a"), Some("b"), Some("c")]); + + // Create a new dictionary + builder.append("d").unwrap(); + builder.append("e").unwrap(); + let dict2 = builder.finish_preserve_values(); + + // Make sure the keys are assigned after the old ones and we have the + // right values + assert_eq!(dict2.keys().values(), &[3, 4]); + let values = dict2 + .downcast_dict::>() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(values, [Some("d"), Some("e")]); + + // Check that we have all of the expected values + assert_eq!(dict2.values().len(), 5); + let all_values = dict2 + .values() + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + all_values, + [Some("a"), Some("b"), Some("c"), Some("d"), Some("e"),] + ); + } } diff --git a/arrow-array/src/builder/generic_list_builder.rs b/arrow-array/src/builder/generic_list_builder.rs index 463b498c55ba..4d044ca35e2a 100644 --- a/arrow-array/src/builder/generic_list_builder.rs +++ b/arrow-array/src/builder/generic_list_builder.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
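Both `finish_preserve_values` additions (on the fixed-size binary and generic byte dictionary builders) behave the same way: keys keep accumulating across calls because the de-duplication map is retained, which is what lets an IPC writer emit delta dictionaries. A rough usage sketch mirroring the tests above, written against the `StringDictionaryBuilder` alias (illustrative only, not part of the patch):

use arrow_array::builder::StringDictionaryBuilder;
use arrow_array::types::Int32Type;

fn main() {
    let mut builder = StringDictionaryBuilder::<Int32Type>::new();
    builder.append("a").unwrap();
    builder.append("b").unwrap();
    // First batch: keys [0, 1], dictionary ["a", "b"].
    let first = builder.finish_preserve_values();
    assert_eq!(first.keys().values(), &[0, 1]);

    builder.append("b").unwrap();
    builder.append("c").unwrap();
    // Second batch: "b" still maps to key 1 and "c" extends the dictionary,
    // so the values array is the cumulative ["a", "b", "c"].
    let second = builder.finish_preserve_values();
    assert_eq!(second.keys().values(), &[1, 2]);
    assert_eq!(second.values().len(), 3);
}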
-use crate::builder::{ArrayBuilder, BufferBuilder}; +use crate::builder::ArrayBuilder; use crate::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow_buffer::NullBufferBuilder; use arrow_buffer::{Buffer, OffsetBuffer}; @@ -86,7 +86,7 @@ use std::sync::Arc; /// [`LargeListArray`]: crate::array::LargeListArray #[derive(Debug)] pub struct GenericListBuilder { - offsets_builder: BufferBuilder, + offsets_builder: Vec, null_buffer_builder: NullBufferBuilder, values_builder: T, field: Option, @@ -108,8 +108,8 @@ impl GenericListBuilder Self { - let mut offsets_builder = BufferBuilder::::new(capacity + 1); - offsets_builder.append(OffsetSize::zero()); + let mut offsets_builder = Vec::with_capacity(capacity + 1); + offsets_builder.push(OffsetSize::zero()); Self { offsets_builder, null_buffer_builder: NullBufferBuilder::new(capacity), @@ -192,7 +192,7 @@ where /// Panics if the length of [`Self::values`] exceeds `OffsetSize::MAX` #[inline] pub fn append(&mut self, is_valid: bool) { - self.offsets_builder.append(self.next_offset()); + self.offsets_builder.push(self.next_offset()); self.null_buffer_builder.append(is_valid); } @@ -266,7 +266,7 @@ where /// See [`Self::append_value`] for an example use. #[inline] pub fn append_null(&mut self) { - self.offsets_builder.append(self.next_offset()); + self.offsets_builder.push(self.next_offset()); self.null_buffer_builder.append_null(); } @@ -274,7 +274,8 @@ where #[inline] pub fn append_nulls(&mut self, n: usize) { let next_offset = self.next_offset(); - self.offsets_builder.append_n(n, next_offset); + self.offsets_builder + .extend(std::iter::repeat_n(next_offset, n)); self.null_buffer_builder.append_n_nulls(n); } @@ -298,10 +299,10 @@ where let values = self.values_builder.finish(); let nulls = self.null_buffer_builder.finish(); - let offsets = self.offsets_builder.finish(); + let offsets = Buffer::from_vec(std::mem::take(&mut self.offsets_builder)); // Safety: Safe by construction let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; - self.offsets_builder.append(OffsetSize::zero()); + self.offsets_builder.push(OffsetSize::zero()); let field = match &self.field { Some(f) => f.clone(), diff --git a/arrow-array/src/builder/generic_list_view_builder.rs b/arrow-array/src/builder/generic_list_view_builder.rs index 5aaf9efefe24..23204fca31b8 100644 --- a/arrow-array/src/builder/generic_list_view_builder.rs +++ b/arrow-array/src/builder/generic_list_view_builder.rs @@ -17,7 +17,7 @@ use crate::builder::ArrayBuilder; use crate::{ArrayRef, GenericListViewArray, OffsetSizeTrait}; -use arrow_buffer::{Buffer, BufferBuilder, NullBufferBuilder, ScalarBuffer}; +use arrow_buffer::{Buffer, NullBufferBuilder, ScalarBuffer}; use arrow_schema::{Field, FieldRef}; use std::any::Any; use std::sync::Arc; @@ -25,8 +25,8 @@ use std::sync::Arc; /// Builder for [`GenericListViewArray`] #[derive(Debug)] pub struct GenericListViewBuilder { - offsets_builder: BufferBuilder, - sizes_builder: BufferBuilder, + offsets_builder: Vec, + sizes_builder: Vec, null_buffer_builder: NullBufferBuilder, values_builder: T, field: Option, @@ -83,8 +83,8 @@ impl GenericListViewBuilder Self { - let offsets_builder = BufferBuilder::::new(capacity); - let sizes_builder = BufferBuilder::::new(capacity); + let offsets_builder = Vec::with_capacity(capacity); + let sizes_builder = Vec::with_capacity(capacity); Self { offsets_builder, null_buffer_builder: NullBufferBuilder::new(capacity), @@ -132,8 +132,8 @@ where /// Panics if the length of [`Self::values`] exceeds 
`OffsetSize::MAX` #[inline] pub fn append(&mut self, is_valid: bool) { - self.offsets_builder.append(self.current_offset); - self.sizes_builder.append( + self.offsets_builder.push(self.current_offset); + self.sizes_builder.push( OffsetSize::from_usize( self.values_builder.len() - self.current_offset.to_usize().unwrap(), ) @@ -158,9 +158,8 @@ where /// See [`Self::append_value`] for an example use. #[inline] pub fn append_null(&mut self) { - self.offsets_builder.append(self.current_offset); - self.sizes_builder - .append(OffsetSize::from_usize(0).unwrap()); + self.offsets_builder.push(self.current_offset); + self.sizes_builder.push(OffsetSize::from_usize(0).unwrap()); self.null_buffer_builder.append_null(); } @@ -183,12 +182,12 @@ where pub fn finish(&mut self) -> GenericListViewArray { let values = self.values_builder.finish(); let nulls = self.null_buffer_builder.finish(); - let offsets = self.offsets_builder.finish(); + let offsets = Buffer::from_vec(std::mem::take(&mut self.offsets_builder)); self.current_offset = OffsetSize::zero(); // Safety: Safe by construction let offsets = ScalarBuffer::from(offsets); - let sizes = self.sizes_builder.finish(); + let sizes = Buffer::from_vec(std::mem::take(&mut self.sizes_builder)); let sizes = ScalarBuffer::from(sizes); let field = match &self.field { Some(f) => f.clone(), diff --git a/arrow-array/src/builder/map_builder.rs b/arrow-array/src/builder/map_builder.rs index 012a454e76c9..a9895eabed32 100644 --- a/arrow-array/src/builder/map_builder.rs +++ b/arrow-array/src/builder/map_builder.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::builder::{ArrayBuilder, BufferBuilder}; +use crate::builder::ArrayBuilder; use crate::{Array, ArrayRef, MapArray, StructArray}; use arrow_buffer::Buffer; use arrow_buffer::{NullBuffer, NullBufferBuilder}; @@ -56,7 +56,7 @@ use std::sync::Arc; /// ``` #[derive(Debug)] pub struct MapBuilder { - offsets_builder: BufferBuilder, + offsets_builder: Vec, null_buffer_builder: NullBufferBuilder, field_names: MapFieldNames, key_builder: K, @@ -100,8 +100,8 @@ impl MapBuilder { value_builder: V, capacity: usize, ) -> Self { - let mut offsets_builder = BufferBuilder::::new(capacity + 1); - offsets_builder.append(0); + let mut offsets_builder = Vec::with_capacity(capacity + 1); + offsets_builder.push(0); Self { offsets_builder, null_buffer_builder: NullBufferBuilder::new(capacity), @@ -166,7 +166,7 @@ impl MapBuilder { self.value_builder.len() ))); } - self.offsets_builder.append(self.key_builder.len() as i32); + self.offsets_builder.push(self.key_builder.len() as i32); self.null_buffer_builder.append(is_valid); Ok(()) } @@ -177,8 +177,8 @@ impl MapBuilder { // Build the keys let keys_arr = self.key_builder.finish(); let values_arr = self.value_builder.finish(); - let offset_buffer = self.offsets_builder.finish(); - self.offsets_builder.append(0); + let offset_buffer = Buffer::from_vec(std::mem::take(&mut self.offsets_builder)); + self.offsets_builder.push(0); let null_bit_buffer = self.null_buffer_builder.finish(); self.finish_helper(keys_arr, values_arr, offset_buffer, null_bit_buffer, len) diff --git a/arrow-array/src/builder/primitive_builder.rs b/arrow-array/src/builder/primitive_builder.rs index 7aca730ce192..049cef241c83 100644 --- a/arrow-array/src/builder/primitive_builder.rs +++ b/arrow-array/src/builder/primitive_builder.rs @@ -15,11 +15,10 @@ // specific language governing permissions and limitations // under the License. 
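Editorial note (not part of the diff): the recurring change in these builders is to accumulate offsets in a plain `Vec` and hand its allocation to a `Buffer` when finishing. A small sketch of that conversion, assuming only the arrow-buffer crate:

use arrow_buffer::{Buffer, ScalarBuffer};

fn main() {
    let mut offsets: Vec<i32> = vec![0];
    offsets.extend_from_slice(&[3, 5, 9]);

    // Buffer::from_vec takes ownership of the Vec's allocation rather than
    // copying it, which is why `finish` can hand the offsets over cheaply.
    let buffer = Buffer::from_vec(std::mem::take(&mut offsets));
    assert!(offsets.is_empty()); // the real builders then push the leading zero back

    // ScalarBuffer reinterprets the raw bytes as typed values again.
    let typed: ScalarBuffer<i32> = ScalarBuffer::from(buffer);
    assert_eq!(&typed[..], &[0, 3, 5, 9]);
}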
-use crate::builder::{ArrayBuilder, BufferBuilder}; +use crate::builder::ArrayBuilder; use crate::types::*; use crate::{Array, ArrayRef, PrimitiveArray}; -use arrow_buffer::NullBufferBuilder; -use arrow_buffer::{Buffer, MutableBuffer}; +use arrow_buffer::{Buffer, MutableBuffer, NullBufferBuilder, ScalarBuffer}; use arrow_data::ArrayData; use arrow_schema::{ArrowError, DataType}; use std::any::Any; @@ -99,7 +98,7 @@ pub type Decimal256Builder = PrimitiveBuilder; /// Builder for [`PrimitiveArray`] #[derive(Debug)] pub struct PrimitiveBuilder { - values_builder: BufferBuilder, + values_builder: Vec, null_buffer_builder: NullBufferBuilder, data_type: DataType, } @@ -151,7 +150,7 @@ impl PrimitiveBuilder { /// Creates a new primitive array builder with capacity no of items pub fn with_capacity(capacity: usize) -> Self { Self { - values_builder: BufferBuilder::::new(capacity), + values_builder: Vec::with_capacity(capacity), null_buffer_builder: NullBufferBuilder::new(capacity), data_type: T::DATA_TYPE, } @@ -162,7 +161,7 @@ impl PrimitiveBuilder { values_buffer: MutableBuffer, null_buffer: Option, ) -> Self { - let values_builder = BufferBuilder::::new_from_buffer(values_buffer); + let values_builder: Vec = ScalarBuffer::::from(values_buffer).into(); let null_buffer_builder = null_buffer .map(|buffer| NullBufferBuilder::new_from_buffer(buffer, values_builder.len())) @@ -204,28 +203,29 @@ impl PrimitiveBuilder { #[inline] pub fn append_value(&mut self, v: T::Native) { self.null_buffer_builder.append_non_null(); - self.values_builder.append(v); + self.values_builder.push(v); } /// Appends a value of type `T` into the builder `n` times #[inline] pub fn append_value_n(&mut self, v: T::Native, n: usize) { self.null_buffer_builder.append_n_non_nulls(n); - self.values_builder.append_n(n, v); + self.values_builder.extend(std::iter::repeat_n(v, n)); } /// Appends a null slot into the builder #[inline] pub fn append_null(&mut self) { self.null_buffer_builder.append_null(); - self.values_builder.advance(1); + self.values_builder.push(T::Native::default()); } /// Appends `n` no. 
of null's into the builder #[inline] pub fn append_nulls(&mut self, n: usize) { self.null_buffer_builder.append_n_nulls(n); - self.values_builder.advance(n); + self.values_builder + .extend(std::iter::repeat_n(T::Native::default(), n)); } /// Appends an `Option` into the builder @@ -241,7 +241,7 @@ impl PrimitiveBuilder { #[inline] pub fn append_slice(&mut self, v: &[T::Native]) { self.null_buffer_builder.append_n_non_nulls(v.len()); - self.values_builder.append_slice(v); + self.values_builder.extend_from_slice(v); } /// Appends values from a slice of type `T` and a validity boolean slice @@ -257,7 +257,7 @@ impl PrimitiveBuilder { "Value and validity lengths must be equal" ); self.null_buffer_builder.append_slice(is_valid); - self.values_builder.append_slice(values); + self.values_builder.extend_from_slice(values); } /// Appends array values and null to this builder as is @@ -274,7 +274,7 @@ impl PrimitiveBuilder { "array data type mismatch" ); - self.values_builder.append_slice(array.values()); + self.values_builder.extend_from_slice(array.values()); if let Some(null_buffer) = array.nulls() { self.null_buffer_builder.append_buffer(null_buffer); } else { @@ -296,7 +296,7 @@ impl PrimitiveBuilder { .expect("append_trusted_len_iter requires an upper bound"); self.null_buffer_builder.append_n_non_nulls(len); - self.values_builder.append_trusted_len_iter(iter); + self.values_builder.extend(iter); } /// Builds the [`PrimitiveArray`] and reset this builder. @@ -305,7 +305,7 @@ impl PrimitiveBuilder { let nulls = self.null_buffer_builder.finish(); let builder = ArrayData::builder(self.data_type.clone()) .len(len) - .add_buffer(self.values_builder.finish()) + .add_buffer(std::mem::take(&mut self.values_builder).into()) .nulls(nulls); let array_data = unsafe { builder.build_unchecked() }; @@ -333,7 +333,7 @@ impl PrimitiveBuilder { /// Returns the current values buffer as a mutable slice pub fn values_slice_mut(&mut self) -> &mut [T::Native] { - self.values_builder.as_slice_mut() + self.values_builder.as_mut_slice() } /// Returns the current null buffer as a slice @@ -349,7 +349,7 @@ impl PrimitiveBuilder { /// Returns the current values buffer and null buffer as a slice pub fn slices_mut(&mut self) -> (&mut [T::Native], Option<&mut [u8]>) { ( - self.values_builder.as_slice_mut(), + self.values_builder.as_mut_slice(), self.null_buffer_builder.as_slice_mut(), ) } diff --git a/arrow-array/src/builder/primitive_dictionary_builder.rs b/arrow-array/src/builder/primitive_dictionary_builder.rs index 1d921c6df097..acef8446ad4b 100644 --- a/arrow-array/src/builder/primitive_dictionary_builder.rs +++ b/arrow-array/src/builder/primitive_dictionary_builder.rs @@ -460,6 +460,38 @@ where DictionaryArray::from(unsafe { builder.build_unchecked() }) } + /// Builds the `DictionaryArray` without resetting the values builder or + /// the internal de-duplication map. + /// + /// The advantage of doing this is that the values will represent the entire + /// set of what has been built so-far by this builder and ensures + /// consistency in the assignment of keys to values across multiple calls + /// to `finish_preserve_values`. This enables ipc writers to efficiently + /// emit delta dictionaries. + /// + /// The downside to this is that building the record requires creating a + /// copy of the values, which can become slowly more expensive if the + /// dictionary grows. 
+ /// + /// Additionally, if record batches from multiple different dictionary + /// builders for the same column are fed into a single ipc writer, beware + /// that entire dictionaries are likely to be re-sent frequently even when + /// the majority of the values are not used by the current record batch. + pub fn finish_preserve_values(&mut self) -> DictionaryArray { + let values = self.values_builder.finish_cloned(); + let keys = self.keys_builder.finish(); + + let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(V::DATA_TYPE)); + + let builder = keys + .into_data() + .into_builder() + .data_type(data_type) + .child_data(vec![values.into_data()]); + + DictionaryArray::from(unsafe { builder.build_unchecked() }) + } + /// Returns the current dictionary values buffer as a slice pub fn values_slice(&self) -> &[V::Native] { self.values_builder.values_slice() @@ -817,4 +849,45 @@ mod tests { ); } } + + #[test] + fn test_finish_preserve_values() { + // Create the first dictionary + let mut builder = PrimitiveDictionaryBuilder::::new(); + builder.append(10).unwrap(); + builder.append(20).unwrap(); + let array = builder.finish_preserve_values(); + assert_eq!(array.keys(), &UInt8Array::from(vec![Some(0), Some(1)])); + let values: &[u32] = array + .values() + .as_any() + .downcast_ref::() + .unwrap() + .values(); + assert_eq!(values, &[10, 20]); + + // Create a new dictionary + builder.append(30).unwrap(); + builder.append(40).unwrap(); + let array2 = builder.finish_preserve_values(); + + // Make sure the keys are assigned after the old ones + // and that we have the right values + assert_eq!(array2.keys(), &UInt8Array::from(vec![Some(2), Some(3)])); + let values = array2 + .downcast_dict::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(values, vec![Some(30), Some(40)]); + + // Check that we have all of the expected values + let all_values: &[u32] = array2 + .values() + .as_any() + .downcast_ref::() + .unwrap() + .values(); + assert_eq!(all_values, &[10, 20, 30, 40]); + } } diff --git a/arrow-array/src/builder/union_builder.rs b/arrow-array/src/builder/union_builder.rs index e6184f4ac6d2..1e7ddedf523f 100644 --- a/arrow-array/src/builder/union_builder.rs +++ b/arrow-array/src/builder/union_builder.rs @@ -16,10 +16,10 @@ // under the License. 
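Editorial note (not part of the diff): the same `finish_preserve_values` behaviour applies to `PrimitiveDictionaryBuilder`; a short sketch mirroring the test above, with arbitrarily chosen key and value types:

use arrow_array::builder::PrimitiveDictionaryBuilder;
use arrow_array::types::{UInt32Type, UInt8Type};

fn main() {
    let mut builder = PrimitiveDictionaryBuilder::<UInt8Type, UInt32Type>::new();
    builder.append(10).unwrap();
    builder.append(20).unwrap();
    let first = builder.finish_preserve_values();
    assert_eq!(first.values().len(), 2);

    // 20 is already in the preserved map and re-uses key 1; 30 becomes key 2.
    builder.append(20).unwrap();
    builder.append(30).unwrap();
    let second = builder.finish_preserve_values();
    assert_eq!(second.keys().values(), &[1, 2]);
    assert_eq!(second.values().len(), 3);
}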
use crate::builder::buffer_builder::{Int32BufferBuilder, Int8BufferBuilder}; -use crate::builder::BufferBuilder; -use crate::{make_array, ArrowPrimitiveType, UnionArray}; +use crate::builder::{ArrayBuilder, BufferBuilder}; +use crate::{make_array, ArrayRef, ArrowPrimitiveType, UnionArray}; use arrow_buffer::NullBufferBuilder; -use arrow_buffer::{ArrowNativeType, Buffer}; +use arrow_buffer::{ArrowNativeType, Buffer, ScalarBuffer}; use arrow_data::ArrayDataBuilder; use arrow_schema::{ArrowError, DataType, Field}; use std::any::Any; @@ -42,12 +42,14 @@ struct FieldData { } /// A type-erased [`BufferBuilder`] used by [`FieldData`] -trait FieldDataValues: std::fmt::Debug { +trait FieldDataValues: std::fmt::Debug + Send + Sync { fn as_mut_any(&mut self) -> &mut dyn Any; fn append_null(&mut self); fn finish(&mut self) -> Buffer; + + fn finish_cloned(&self) -> Buffer; } impl FieldDataValues for BufferBuilder { @@ -62,6 +64,10 @@ impl FieldDataValues for BufferBuilder { fn finish(&mut self) -> Buffer { self.finish() } + + fn finish_cloned(&self) -> Buffer { + Buffer::from_slice_ref(self.as_slice()) + } } impl FieldData { @@ -138,7 +144,7 @@ impl FieldData { /// assert_eq!(union.value_offset(1), 1); /// assert_eq!(union.value_offset(2), 2); /// ``` -#[derive(Debug)] +#[derive(Debug, Default)] pub struct UnionBuilder { /// The current number of slots in the array len: usize, @@ -310,4 +316,172 @@ impl UnionBuilder { children, ) } + + /// Builds this builder creating a new `UnionArray` without consuming the builder. + /// + /// This is used for the `finish_cloned` implementation in `ArrayBuilder`. + fn build_cloned(&self) -> Result { + let mut children = Vec::with_capacity(self.fields.len()); + let union_fields: Vec<_> = self + .fields + .iter() + .map(|(name, field_data)| { + let FieldData { + type_id, + data_type, + values_buffer, + slots, + null_buffer_builder, + } = field_data; + + let array_ref = make_array(unsafe { + ArrayDataBuilder::new(data_type.clone()) + .add_buffer(values_buffer.finish_cloned()) + .len(*slots) + .nulls(null_buffer_builder.finish_cloned()) + .build_unchecked() + }); + children.push(array_ref); + ( + *type_id, + Arc::new(Field::new(name.clone(), data_type.clone(), false)), + ) + }) + .collect(); + UnionArray::try_new( + union_fields.into_iter().collect(), + ScalarBuffer::from(self.type_id_builder.as_slice().to_vec()), + self.value_offset_builder + .as_ref() + .map(|builder| ScalarBuffer::from(builder.as_slice().to_vec())), + children, + ) + } +} + +impl ArrayBuilder for UnionBuilder { + /// Returns the number of array slots in the builder + fn len(&self) -> usize { + self.len + } + + /// Builds the array + fn finish(&mut self) -> ArrayRef { + // Even simpler - just move the builder using mem::take and replace with default + let builder = std::mem::take(self); + + // Since UnionBuilder controls all invariants, this should never fail + Arc::new(builder.build().unwrap()) + } + + /// Builds the array without resetting the underlying builder + fn finish_cloned(&self) -> ArrayRef { + // We construct the UnionArray carefully to ensure try_new cannot fail. + // Since UnionBuilder controls all the invariants, this should never panic. 
+ Arc::new(self.build_cloned().unwrap_or_else(|err| { + panic!("UnionBuilder::build_cloned failed unexpectedly: {}", err) + })) + } + + /// Returns the builder as a non-mutable `Any` reference + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as a mutable `Any` reference + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any` + fn into_box_any(self: Box) -> Box { + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::array::Array; + use crate::cast::AsArray; + use crate::types::{Float64Type, Int32Type}; + + #[test] + fn test_union_builder_array_builder_trait() { + // Test that UnionBuilder implements ArrayBuilder trait + let mut builder = UnionBuilder::new_dense(); + + // Add some data + builder.append::("a", 1).unwrap(); + builder.append::("b", 3.0).unwrap(); + builder.append::("a", 4).unwrap(); + + assert_eq!(builder.len(), 3); + + // Test finish_cloned (non-destructive) + let array1 = builder.finish_cloned(); + assert_eq!(array1.len(), 3); + + // Verify values in cloned array + let union1 = array1.as_any().downcast_ref::().unwrap(); + assert_eq!(union1.type_ids(), &[0, 1, 0]); + assert_eq!(union1.offsets().unwrap().as_ref(), &[0, 0, 1]); + let int_array1 = union1.child(0).as_primitive::(); + let float_array1 = union1.child(1).as_primitive::(); + assert_eq!(int_array1.value(0), 1); + assert_eq!(int_array1.value(1), 4); + assert_eq!(float_array1.value(0), 3.0); + + // Builder should still be usable after finish_cloned + builder.append::("b", 5.0).unwrap(); + assert_eq!(builder.len(), 4); + + // Test finish (destructive) + let array2 = builder.finish(); + assert_eq!(array2.len(), 4); + + // Verify values in final array + let union2 = array2.as_any().downcast_ref::().unwrap(); + assert_eq!(union2.type_ids(), &[0, 1, 0, 1]); + assert_eq!(union2.offsets().unwrap().as_ref(), &[0, 0, 1, 1]); + let int_array2 = union2.child(0).as_primitive::(); + let float_array2 = union2.child(1).as_primitive::(); + assert_eq!(int_array2.value(0), 1); + assert_eq!(int_array2.value(1), 4); + assert_eq!(float_array2.value(0), 3.0); + assert_eq!(float_array2.value(1), 5.0); + } + + #[test] + fn test_union_builder_type_erased() { + // Test type-erased usage with Box + let mut builders: Vec> = vec![Box::new(UnionBuilder::new_sparse())]; + + // Downcast and use + let union_builder = builders[0] + .as_any_mut() + .downcast_mut::() + .unwrap(); + union_builder.append::("x", 10).unwrap(); + union_builder.append::("y", 20.0).unwrap(); + + assert_eq!(builders[0].len(), 2); + + let result = builders + .into_iter() + .map(|mut b| b.finish()) + .collect::>(); + assert_eq!(result[0].len(), 2); + + // Verify sparse union values + let union = result[0].as_any().downcast_ref::().unwrap(); + assert_eq!(union.type_ids(), &[0, 1]); + assert!(union.offsets().is_none()); // Sparse union has no offsets + let int_array = union.child(0).as_primitive::(); + let float_array = union.child(1).as_primitive::(); + assert_eq!(int_array.value(0), 10); + assert!(int_array.is_null(1)); // Null in sparse layout + assert!(float_array.is_null(0)); // Null in sparse layout + assert_eq!(float_array.value(1), 20.0); + } } diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs index 41fffc4bc80c..de590ff87c77 100644 --- a/arrow-array/src/cast.rs +++ b/arrow-array/src/cast.rs @@ -1132,6 +1132,18 @@ mod tests { assert!(!as_string_array(&array).is_empty()) } + #[test] + fn test_decimal32array() { + let a = Decimal32Array::from_iter_values([1, 2, 4, 5]); + 
assert!(!as_primitive_array::(&a).is_empty()); + } + + #[test] + fn test_decimal64array() { + let a = Decimal64Array::from_iter_values([1, 2, 4, 5]); + assert!(!as_primitive_array::(&a).is_empty()); + } + #[test] fn test_decimal128array() { let a = Decimal128Array::from_iter_values([1, 2, 4, 5]); diff --git a/arrow-array/src/timezone.rs b/arrow-array/src/timezone.rs index b4df77deb4f5..bcf582152146 100644 --- a/arrow-array/src/timezone.rs +++ b/arrow-array/src/timezone.rs @@ -53,6 +53,7 @@ mod private { use super::*; use chrono::offset::TimeZone; use chrono::{LocalResult, NaiveDate, NaiveDateTime, Offset}; + use std::fmt::Display; use std::str::FromStr; /// An [`Offset`] for [`Tz`] @@ -97,6 +98,15 @@ mod private { } } + impl Display for Tz { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + TzInner::Timezone(tz) => tz.fmt(f), + TzInner::Offset(offset) => offset.fmt(f), + } + } + } + macro_rules! tz { ($s:ident, $tz:ident, $b:block) => { match $s.0 { @@ -228,6 +238,15 @@ mod private { sydney_offset_with_dst ); } + + #[test] + fn test_timezone_display() { + let test_cases = ["UTC", "America/Los_Angeles", "-08:00", "+05:30"]; + for &case in &test_cases { + let tz: Tz = case.parse().unwrap(); + assert_eq!(tz.to_string(), case); + } + } } } diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs index 96c496a536bb..144de8dbecbd 100644 --- a/arrow-array/src/types.rs +++ b/arrow-array/src/types.rs @@ -1820,6 +1820,8 @@ mod tests { test_layout::(); test_layout::(); test_layout::(); + test_layout::(); + test_layout::(); test_layout::(); test_layout::(); test_layout::(); diff --git a/arrow-avro/Cargo.toml b/arrow-avro/Cargo.toml index 1a1fc2f066ea..30c23e1932ae 100644 --- a/arrow-avro/Cargo.toml +++ b/arrow-avro/Cargo.toml @@ -40,6 +40,9 @@ default = ["deflate", "snappy", "zstd", "bzip2", "xz"] deflate = ["flate2"] snappy = ["snap", "crc"] canonical_extension_types = ["arrow-schema/canonical_extension_types"] +md5 = ["dep:md5"] +sha256 = ["dep:sha2"] +small_decimals = [] [dependencies] arrow-schema = { workspace = true } @@ -58,7 +61,9 @@ crc = { version = "3.0", optional = true } strum_macros = "0.27" uuid = "1.17" indexmap = "2.10" - +rand = "0.9" +md5 = { version = "0.8", optional = true } +sha2 = { version = "0.10", optional = true } [dev-dependencies] arrow-data = { workspace = true } @@ -73,7 +78,7 @@ arrow = { workspace = true } futures = "0.3.31" bytes = "1.10.1" async-stream = "0.3.6" -apache-avro = "0.14.0" +apache-avro = "0.20.0" num-bigint = "0.4" once_cell = "1.21.3" @@ -83,4 +88,8 @@ harness = false [[bench]] name = "decoder" +harness = false + +[[bench]] +name = "avro_writer" harness = false \ No newline at end of file diff --git a/arrow-avro/benches/avro_writer.rs b/arrow-avro/benches/avro_writer.rs new file mode 100644 index 000000000000..aeb9edbac82a --- /dev/null +++ b/arrow-avro/benches/avro_writer.rs @@ -0,0 +1,766 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks for `arrow-avro` Writer (Avro Object Container File) + +extern crate arrow_avro; +extern crate criterion; +extern crate once_cell; + +use arrow_array::{ + builder::{ListBuilder, StringBuilder}, + types::{Int32Type, Int64Type, IntervalMonthDayNanoType, TimestampMicrosecondType}, + ArrayRef, BinaryArray, BooleanArray, Decimal128Array, Decimal256Array, Decimal32Array, + Decimal64Array, FixedSizeBinaryArray, Float32Array, Float64Array, ListArray, PrimitiveArray, + RecordBatch, StringArray, StructArray, +}; +use arrow_avro::writer::AvroWriter; +use arrow_buffer::i256; +use arrow_schema::{DataType, Field, IntervalUnit, Schema, TimeUnit}; +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput}; +use once_cell::sync::Lazy; +use rand::{ + distr::uniform::{SampleRange, SampleUniform}, + rngs::StdRng, + Rng, SeedableRng, +}; +use std::collections::HashMap; +use std::io::Cursor; +use std::sync::Arc; +use std::time::Duration; +use tempfile::tempfile; + +const SIZES: [usize; 4] = [4_096, 8_192, 100_000, 1_000_000]; +const BASE_SEED: u64 = 0x5EED_1234_ABCD_EF01; +const MIX_CONST_1: u64 = 0x9E37_79B1_85EB_CA87; +const MIX_CONST_2: u64 = 0xC2B2_AE3D_27D4_EB4F; + +#[inline] +fn rng_for(tag: u64, n: usize) -> StdRng { + let seed = BASE_SEED ^ tag.wrapping_mul(MIX_CONST_1) ^ (n as u64).wrapping_mul(MIX_CONST_2); + StdRng::seed_from_u64(seed) +} + +#[inline] +fn sample_in(rng: &mut StdRng, range: Rg) -> T +where + T: SampleUniform, + Rg: SampleRange, +{ + rng.random_range(range) +} + +#[inline] +fn make_bool_array_with_tag(n: usize, tag: u64) -> BooleanArray { + let mut rng = rng_for(tag, n); + // Can't use SampleUniform for bool; use the RNG's boolean helper + let values = (0..n).map(|_| rng.random_bool(0.5)); + // This repo exposes `from_iter`, not `from_iter_values` for BooleanArray + BooleanArray::from_iter(values.map(Some)) +} + +#[inline] +fn make_i32_array_with_tag(n: usize, tag: u64) -> PrimitiveArray { + let mut rng = rng_for(tag, n); + let values = (0..n).map(|_| rng.random::()); + PrimitiveArray::::from_iter_values(values) +} + +#[inline] +fn make_i64_array_with_tag(n: usize, tag: u64) -> PrimitiveArray { + let mut rng = rng_for(tag, n); + let values = (0..n).map(|_| rng.random::()); + PrimitiveArray::::from_iter_values(values) +} + +#[inline] +fn rand_ascii_string(rng: &mut StdRng, min_len: usize, max_len: usize) -> String { + let len = rng.random_range(min_len..=max_len); + (0..len) + .map(|_| (rng.random_range(b'a'..=b'z') as char)) + .collect() +} + +#[inline] +fn make_utf8_array_with_tag(n: usize, tag: u64) -> StringArray { + let mut rng = rng_for(tag, n); + let data: Vec = (0..n).map(|_| rand_ascii_string(&mut rng, 3, 16)).collect(); + StringArray::from_iter_values(data) +} + +#[inline] +fn make_f32_array_with_tag(n: usize, tag: u64) -> Float32Array { + let mut rng = rng_for(tag, n); + let values = (0..n).map(|_| rng.random::()); + Float32Array::from_iter_values(values) +} + +#[inline] +fn make_f64_array_with_tag(n: usize, tag: u64) -> Float64Array { + let mut rng = rng_for(tag, n); + let values = (0..n).map(|_| 
rng.random::()); + Float64Array::from_iter_values(values) +} + +#[inline] +fn make_binary_array_with_tag(n: usize, tag: u64) -> BinaryArray { + let mut rng = rng_for(tag, n); + let mut payloads: Vec> = Vec::with_capacity(n); + for _ in 0..n { + let len = rng.random_range(1..=16); + let mut p = vec![0u8; len]; + rng.fill(&mut p[..]); + payloads.push(p); + } + let views: Vec<&[u8]> = payloads.iter().map(|p| &p[..]).collect(); + // This repo exposes a simple `from_vec` for BinaryArray + BinaryArray::from_vec(views) +} + +#[inline] +fn make_fixed16_array_with_tag(n: usize, tag: u64) -> FixedSizeBinaryArray { + let mut rng = rng_for(tag, n); + let payloads = (0..n) + .map(|_| { + let mut b = [0u8; 16]; + rng.fill(&mut b); + b + }) + .collect::>(); + // Fixed-size constructor available in this repo + FixedSizeBinaryArray::try_from_iter(payloads.into_iter()).expect("build FixedSizeBinaryArray") +} + +/// Make an Arrow `Interval(IntervalUnit::MonthDayNano)` array with **non-negative** +/// (months, days, nanos) values, and nanos as **multiples of 1_000_000** (whole ms), +/// per Avro `duration` constraints used by the writer. +#[inline] +fn make_interval_mdn_array_with_tag( + n: usize, + tag: u64, +) -> PrimitiveArray { + let mut rng = rng_for(tag, n); + let values = (0..n).map(|_| { + let months: i32 = rng.random_range(0..=120); + let days: i32 = rng.random_range(0..=31); + // pick millis within a day (safe within u32::MAX and realistic) + let millis: u32 = rng.random_range(0..=86_400_000); + let nanos: i64 = (millis as i64) * 1_000_000; + IntervalMonthDayNanoType::make_value(months, days, nanos) + }); + PrimitiveArray::::from_iter_values(values) +} + +#[inline] +fn make_ts_micros_array_with_tag(n: usize, tag: u64) -> PrimitiveArray { + let mut rng = rng_for(tag, n); + let base: i64 = 1_600_000_000_000_000; + let year_us: i64 = 31_536_000_000_000; + let values = (0..n).map(|_| base + sample_in::(&mut rng, 0..year_us)); + PrimitiveArray::::from_iter_values(values) +} + +// === Decimal helpers & generators === + +#[inline] +fn pow10_i32(p: u8) -> i32 { + (0..p).fold(1i32, |acc, _| acc.saturating_mul(10)) +} + +#[inline] +fn pow10_i64(p: u8) -> i64 { + (0..p).fold(1i64, |acc, _| acc.saturating_mul(10)) +} + +#[inline] +fn pow10_i128(p: u8) -> i128 { + (0..p).fold(1i128, |acc, _| acc.saturating_mul(10)) +} + +#[inline] +fn make_decimal32_array_with_tag(n: usize, tag: u64, precision: u8, scale: i8) -> Decimal32Array { + let mut rng = rng_for(tag, n); + let max = pow10_i32(precision).saturating_sub(1); + let values = (0..n).map(|_| rng.random_range(-max..=max)); + Decimal32Array::from_iter_values(values) + .with_precision_and_scale(precision, scale) + .expect("set precision/scale on Decimal32Array") +} + +#[inline] +fn make_decimal64_array_with_tag(n: usize, tag: u64, precision: u8, scale: i8) -> Decimal64Array { + let mut rng = rng_for(tag, n); + let max = pow10_i64(precision).saturating_sub(1); + let values = (0..n).map(|_| rng.random_range(-max..=max)); + Decimal64Array::from_iter_values(values) + .with_precision_and_scale(precision, scale) + .expect("set precision/scale on Decimal64Array") +} + +#[inline] +fn make_decimal128_array_with_tag(n: usize, tag: u64, precision: u8, scale: i8) -> Decimal128Array { + let mut rng = rng_for(tag, n); + let max = pow10_i128(precision).saturating_sub(1); + let values = (0..n).map(|_| rng.random_range(-max..=max)); + Decimal128Array::from_iter_values(values) + .with_precision_and_scale(precision, scale) + .expect("set precision/scale on Decimal128Array") +} + 
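Editorial note (not part of the diff): the generators above bound their random values by `10^precision - 1` because precision caps the number of decimal digits, while scale fixes where the decimal point sits. A small illustration using `Decimal128Array` with arbitrarily chosen precision and scale:

use arrow_array::Decimal128Array;

fn main() {
    // Precision 5 allows magnitudes up to 10^5 - 1 = 99_999; scale 2 means the
    // stored integer 12_345 reads back as 123.45.
    let arr = Decimal128Array::from_iter_values([12_345, -99_999])
        .with_precision_and_scale(5, 2)
        .expect("precision/scale are within Decimal128 limits");
    assert_eq!(arr.value_as_string(0), "123.45");
    assert_eq!(arr.value_as_string(1), "-999.99");
}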
+#[inline] +fn make_decimal256_array_with_tag(n: usize, tag: u64, precision: u8, scale: i8) -> Decimal256Array { + // Generate within i128 range and widen to i256 to keep generation cheap and portable + let mut rng = rng_for(tag, n); + let max128 = pow10_i128(30).saturating_sub(1); + let values = (0..n).map(|_| { + let v: i128 = rng.random_range(-max128..=max128); + i256::from_i128(v) + }); + Decimal256Array::from_iter_values(values) + .with_precision_and_scale(precision, scale) + .expect("set precision/scale on Decimal256Array") +} + +#[inline] +fn make_fixed16_array(n: usize) -> FixedSizeBinaryArray { + make_fixed16_array_with_tag(n, 0xF15E_D016) +} + +#[inline] +fn make_interval_mdn_array(n: usize) -> PrimitiveArray { + make_interval_mdn_array_with_tag(n, 0xD0_1E_AD) +} + +#[inline] +fn make_bool_array(n: usize) -> BooleanArray { + make_bool_array_with_tag(n, 0xB001) +} +#[inline] +fn make_i32_array(n: usize) -> PrimitiveArray { + make_i32_array_with_tag(n, 0x1337_0032) +} +#[inline] +fn make_i64_array(n: usize) -> PrimitiveArray { + make_i64_array_with_tag(n, 0x1337_0064) +} +#[inline] +fn make_f32_array(n: usize) -> Float32Array { + make_f32_array_with_tag(n, 0xF0_0032) +} +#[inline] +fn make_f64_array(n: usize) -> Float64Array { + make_f64_array_with_tag(n, 0xF0_0064) +} +#[inline] +fn make_binary_array(n: usize) -> BinaryArray { + make_binary_array_with_tag(n, 0xB1_0001) +} +#[inline] +fn make_ts_micros_array(n: usize) -> PrimitiveArray { + make_ts_micros_array_with_tag(n, 0x7157_0001) +} +#[inline] +fn make_utf8_array(n: usize) -> StringArray { + make_utf8_array_with_tag(n, 0x5712_07F8) +} +#[inline] +fn make_list_utf8_array(n: usize) -> ListArray { + make_list_utf8_array_with_tag(n, 0x0A11_57ED) +} +#[inline] +fn make_struct_array(n: usize) -> StructArray { + make_struct_array_with_tag(n, 0x57_AB_C7) +} + +#[inline] +fn make_list_utf8_array_with_tag(n: usize, tag: u64) -> ListArray { + let mut rng = rng_for(tag, n); + let mut builder = ListBuilder::new(StringBuilder::new()); + for _ in 0..n { + let items = rng.random_range(0..=5); + for _ in 0..items { + let s = rand_ascii_string(&mut rng, 1, 12); + builder.values().append_value(s.as_str()); + } + builder.append(true); + } + builder.finish() +} + +#[inline] +fn make_struct_array_with_tag(n: usize, tag: u64) -> StructArray { + let s_tag = tag ^ 0x5u64; + let i_tag = tag ^ 0x6u64; + let f_tag = tag ^ 0x7u64; + let s_col: ArrayRef = Arc::new(make_utf8_array_with_tag(n, s_tag)); + let i_col: ArrayRef = Arc::new(make_i32_array_with_tag(n, i_tag)); + let f_col: ArrayRef = Arc::new(make_f64_array_with_tag(n, f_tag)); + StructArray::from(vec![ + ( + Arc::new(Field::new("s1", DataType::Utf8, false)), + s_col.clone(), + ), + ( + Arc::new(Field::new("s2", DataType::Int32, false)), + i_col.clone(), + ), + ( + Arc::new(Field::new("s3", DataType::Float64, false)), + f_col.clone(), + ), + ]) +} + +#[inline] +fn schema_single(name: &str, dt: DataType) -> Arc { + Arc::new(Schema::new(vec![Field::new(name, dt, false)])) +} + +#[inline] +fn schema_mixed() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Int64, false), + Field::new("f3", DataType::Binary, false), + Field::new("f4", DataType::Float64, false), + ])) +} + +#[inline] +fn schema_fixed16() -> Arc { + schema_single("field1", DataType::FixedSizeBinary(16)) +} + +#[inline] +fn schema_uuid16() -> Arc { + let mut md = HashMap::new(); + md.insert("logicalType".to_string(), "uuid".to_string()); + let field = Field::new("uuid", 
DataType::FixedSizeBinary(16), false).with_metadata(md); + Arc::new(Schema::new(vec![field])) +} + +#[inline] +fn schema_interval_mdn() -> Arc { + schema_single("duration", DataType::Interval(IntervalUnit::MonthDayNano)) +} + +#[inline] +fn schema_decimal_with_size(name: &str, dt: DataType, size_meta: Option) -> Arc { + let field = if let Some(size) = size_meta { + let mut md = HashMap::new(); + md.insert("size".to_string(), size.to_string()); + Field::new(name, dt, false).with_metadata(md) + } else { + Field::new(name, dt, false) + }; + Arc::new(Schema::new(vec![field])) +} + +static BOOLEAN_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Boolean); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_bool_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static INT32_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Int32); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_i32_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static INT64_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Int64); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_i64_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static FLOAT32_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Float32); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_f32_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static FLOAT64_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Float64); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_f64_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static BINARY_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Binary); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_binary_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static FIXED16_DATA: Lazy> = Lazy::new(|| { + let schema = schema_fixed16(); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_fixed16_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static UUID16_DATA: Lazy> = Lazy::new(|| { + let schema = schema_uuid16(); + SIZES + .iter() + .map(|&n| { + // Same values as Fixed16; writer path differs because of field metadata + let col: ArrayRef = Arc::new(make_fixed16_array_with_tag(n, 0x7575_6964_7575_6964)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static INTERVAL_MDN_DATA: Lazy> = Lazy::new(|| { + let schema = schema_interval_mdn(); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_interval_mdn_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static TIMESTAMP_US_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Timestamp(TimeUnit::Microsecond, None)); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_ts_micros_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static MIXED_DATA: Lazy> = Lazy::new(|| { + let schema = schema_mixed(); + SIZES + .iter() + .map(|&n| { + let f1: ArrayRef = Arc::new(make_i32_array_with_tag(n, 0xA1)); 
+ let f2: ArrayRef = Arc::new(make_i64_array_with_tag(n, 0xA2)); + let f3: ArrayRef = Arc::new(make_binary_array_with_tag(n, 0xA3)); + let f4: ArrayRef = Arc::new(make_f64_array_with_tag(n, 0xA4)); + RecordBatch::try_new(schema.clone(), vec![f1, f2, f3, f4]).unwrap() + }) + .collect() +}); + +static UTF8_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Utf8); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_utf8_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static LIST_UTF8_DATA: Lazy> = Lazy::new(|| { + // IMPORTANT: ListBuilder creates a child field named "item" that is nullable by default. + // Make the schema's list item nullable to match the array we construct. + let item_field = Arc::new(Field::new("item", DataType::Utf8, true)); + let schema = schema_single("field1", DataType::List(item_field)); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_list_utf8_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static STRUCT_DATA: Lazy> = Lazy::new(|| { + let struct_dt = DataType::Struct( + vec![ + Field::new("s1", DataType::Utf8, false), + Field::new("s2", DataType::Int32, false), + Field::new("s3", DataType::Float64, false), + ] + .into(), + ); + let schema = schema_single("field1", struct_dt); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_struct_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static DECIMAL32_DATA: Lazy> = Lazy::new(|| { + // Choose a representative precision/scale within Decimal32 limits + let precision: u8 = 7; + let scale: i8 = 2; + let schema = schema_single("amount", DataType::Decimal32(precision, scale)); + SIZES + .iter() + .map(|&n| { + let arr = make_decimal32_array_with_tag(n, 0xDEC_0032, precision, scale); + let col: ArrayRef = Arc::new(arr); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static DECIMAL64_DATA: Lazy> = Lazy::new(|| { + let precision: u8 = 13; + let scale: i8 = 3; + let schema = schema_single("amount", DataType::Decimal64(precision, scale)); + SIZES + .iter() + .map(|&n| { + let arr = make_decimal64_array_with_tag(n, 0xDEC_0064, precision, scale); + let col: ArrayRef = Arc::new(arr); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static DECIMAL128_BYTES_DATA: Lazy> = Lazy::new(|| { + let precision: u8 = 25; + let scale: i8 = 6; + let schema = schema_single("amount", DataType::Decimal128(precision, scale)); + SIZES + .iter() + .map(|&n| { + let arr = make_decimal128_array_with_tag(n, 0xDEC_0128, precision, scale); + let col: ArrayRef = Arc::new(arr); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static DECIMAL128_FIXED16_DATA: Lazy> = Lazy::new(|| { + // Same logical type as above but force Avro fixed(16) via metadata "size": "16" + let precision: u8 = 25; + let scale: i8 = 6; + let schema = + schema_decimal_with_size("amount", DataType::Decimal128(precision, scale), Some(16)); + SIZES + .iter() + .map(|&n| { + let arr = make_decimal128_array_with_tag(n, 0xDEC_F128, precision, scale); + let col: ArrayRef = Arc::new(arr); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static DECIMAL256_DATA: Lazy> = Lazy::new(|| { + // Use a higher precision typical of 256-bit decimals + let precision: u8 = 50; + let scale: i8 = 10; + let schema = schema_single("amount", 
DataType::Decimal256(precision, scale)); + SIZES + .iter() + .map(|&n| { + let arr = make_decimal256_array_with_tag(n, 0xDEC_0256, precision, scale); + let col: ArrayRef = Arc::new(arr); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static MAP_DATA: Lazy> = Lazy::new(|| { + use arrow_array::builder::{MapBuilder, StringBuilder}; + + let key_field = Arc::new(Field::new("keys", DataType::Utf8, false)); + let value_field = Arc::new(Field::new("values", DataType::Utf8, true)); + let entry_struct = Field::new( + "entries", + DataType::Struct(vec![key_field.as_ref().clone(), value_field.as_ref().clone()].into()), + false, + ); + let map_dt = DataType::Map(Arc::new(entry_struct), false); + let schema = schema_single("field1", map_dt); + + SIZES + .iter() + .map(|&n| { + // Build a MapArray with n rows + let mut builder = MapBuilder::new(None, StringBuilder::new(), StringBuilder::new()); + let mut rng = rng_for(0x00D0_0D1A, n); + for _ in 0..n { + let entries = rng.random_range(0..=5); + for _ in 0..entries { + let k = rand_ascii_string(&mut rng, 3, 10); + let v = rand_ascii_string(&mut rng, 0, 12); + // keys non-nullable, values nullable allowed but we provide non-null here + builder.keys().append_value(k); + builder.values().append_value(v); + } + builder.append(true).expect("Error building MapArray"); + } + let col: ArrayRef = Arc::new(builder.finish()); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static ENUM_DATA: Lazy> = Lazy::new(|| { + // To represent an Avro enum, the Arrow writer expects a Dictionary + // field with metadata specifying the enum symbols. + let enum_symbols = r#"["RED", "GREEN", "BLUE"]"#; + let mut metadata = HashMap::new(); + metadata.insert("avro.enum.symbols".to_string(), enum_symbols.to_string()); + + let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let field = Field::new("color_enum", dict_type, false).with_metadata(metadata); + let schema = Arc::new(Schema::new(vec![field])); + + let dict_values: ArrayRef = Arc::new(StringArray::from(vec!["RED", "GREEN", "BLUE"])); + + SIZES + .iter() + .map(|&n| { + use arrow_array::DictionaryArray; + let mut rng = rng_for(0x3A7A, n); + let keys_vec: Vec = (0..n).map(|_| rng.random_range(0..=2)).collect(); + let keys = PrimitiveArray::::from(keys_vec); + + let dict_array = + DictionaryArray::::try_new(keys, dict_values.clone()).unwrap(); + let col: ArrayRef = Arc::new(dict_array); + + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +fn ocf_size_for_batch(batch: &RecordBatch) -> usize { + let schema_owned: Schema = (*batch.schema()).clone(); + let cursor = Cursor::new(Vec::::with_capacity(1024)); + let mut writer = AvroWriter::new(cursor, schema_owned).expect("create writer"); + writer.write(batch).expect("write batch"); + writer.finish().expect("finish writer"); + let inner = writer.into_inner(); + inner.into_inner().len() +} + +fn bench_writer_scenario(c: &mut Criterion, name: &str, data_sets: &[RecordBatch]) { + let mut group = c.benchmark_group(name); + let schema_owned: Schema = (*data_sets[0].schema()).clone(); + for (idx, &rows) in SIZES.iter().enumerate() { + let batch = &data_sets[idx]; + let bytes = ocf_size_for_batch(batch); + group.throughput(Throughput::Bytes(bytes as u64)); + match rows { + 4_096 | 8_192 => { + group + .sample_size(40) + .measurement_time(Duration::from_secs(10)) + .warm_up_time(Duration::from_secs(3)); + } + 100_000 => { + group + .sample_size(20) 
+ .measurement_time(Duration::from_secs(10)) + .warm_up_time(Duration::from_secs(3)); + } + 1_000_000 => { + group + .sample_size(10) + .measurement_time(Duration::from_secs(10)) + .warm_up_time(Duration::from_secs(3)); + } + _ => {} + } + group.bench_function(BenchmarkId::from_parameter(rows), |b| { + b.iter_batched_ref( + || { + let file = tempfile().expect("create temp file"); + AvroWriter::new(file, schema_owned.clone()).expect("create writer") + }, + |writer| { + writer.write(batch).unwrap(); + writer.finish().unwrap(); + }, + BatchSize::SmallInput, + ) + }); + } + group.finish(); +} + +fn criterion_benches(c: &mut Criterion) { + bench_writer_scenario(c, "write-Boolean", &BOOLEAN_DATA); + bench_writer_scenario(c, "write-Int32", &INT32_DATA); + bench_writer_scenario(c, "write-Int64", &INT64_DATA); + bench_writer_scenario(c, "write-Float32", &FLOAT32_DATA); + bench_writer_scenario(c, "write-Float64", &FLOAT64_DATA); + bench_writer_scenario(c, "write-Binary(Bytes)", &BINARY_DATA); + bench_writer_scenario(c, "write-TimestampMicros", &TIMESTAMP_US_DATA); + bench_writer_scenario(c, "write-Mixed", &MIXED_DATA); + bench_writer_scenario(c, "write-Utf8", &UTF8_DATA); + bench_writer_scenario(c, "write-List", &LIST_UTF8_DATA); + bench_writer_scenario(c, "write-Struct", &STRUCT_DATA); + bench_writer_scenario(c, "write-FixedSizeBinary16", &FIXED16_DATA); + bench_writer_scenario(c, "write-UUID(logicalType)", &UUID16_DATA); + bench_writer_scenario(c, "write-IntervalMonthDayNanoDuration", &INTERVAL_MDN_DATA); + bench_writer_scenario(c, "write-Decimal32(bytes)", &DECIMAL32_DATA); + bench_writer_scenario(c, "write-Decimal64(bytes)", &DECIMAL64_DATA); + bench_writer_scenario(c, "write-Decimal128(bytes)", &DECIMAL128_BYTES_DATA); + bench_writer_scenario(c, "write-Decimal128(fixed16)", &DECIMAL128_FIXED16_DATA); + bench_writer_scenario(c, "write-Decimal256(bytes)", &DECIMAL256_DATA); + bench_writer_scenario(c, "write-Map", &MAP_DATA); + bench_writer_scenario(c, "write-Enum", &ENUM_DATA); +} + +criterion_group! 
{ + name = avro_writer; + config = Criterion::default().configure_from_args(); + targets = criterion_benches +} +criterion_main!(avro_writer); diff --git a/arrow-avro/benches/decoder.rs b/arrow-avro/benches/decoder.rs index df802daea154..0ca240d12fc9 100644 --- a/arrow-avro/benches/decoder.rs +++ b/arrow-avro/benches/decoder.rs @@ -27,19 +27,42 @@ extern crate uuid; use apache_avro::types::Value; use apache_avro::{to_avro_datum, Decimal, Schema as ApacheSchema}; -use arrow_avro::schema::{Fingerprint, SINGLE_OBJECT_MAGIC}; +use arrow_avro::schema::{Fingerprint, FingerprintAlgorithm, CONFLUENT_MAGIC, SINGLE_OBJECT_MAGIC}; use arrow_avro::{reader::ReaderBuilder, schema::AvroSchema}; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput}; use once_cell::sync::Lazy; use std::{hint::black_box, time::Duration}; use uuid::Uuid; -fn make_prefix(fp: Fingerprint) -> [u8; 10] { - let Fingerprint::Rabin(val) = fp; - let mut buf = [0u8; 10]; - buf[..2].copy_from_slice(&SINGLE_OBJECT_MAGIC); // C3 01 - buf[2..].copy_from_slice(&val.to_le_bytes()); // little‑endian 64‑bit - buf +fn make_prefix(fp: Fingerprint) -> Vec { + match fp { + Fingerprint::Rabin(val) => { + let mut buf = Vec::with_capacity(SINGLE_OBJECT_MAGIC.len() + size_of::()); + buf.extend_from_slice(&SINGLE_OBJECT_MAGIC); // C3 01 + buf.extend_from_slice(&val.to_le_bytes()); // little-endian + buf + } + Fingerprint::Id(id) => { + let mut buf = Vec::with_capacity(CONFLUENT_MAGIC.len() + size_of::()); + buf.extend_from_slice(&CONFLUENT_MAGIC); // 00 + buf.extend_from_slice(&id.to_be_bytes()); // big-endian + buf + } + #[cfg(feature = "md5")] + Fingerprint::MD5(val) => { + let mut buf = Vec::with_capacity(SINGLE_OBJECT_MAGIC.len() + size_of_val(&val)); + buf.extend_from_slice(&SINGLE_OBJECT_MAGIC); // C3 01 + buf.extend_from_slice(&val); + buf + } + #[cfg(feature = "sha256")] + Fingerprint::SHA256(val) => { + let mut buf = Vec::with_capacity(SINGLE_OBJECT_MAGIC.len() + size_of_val(&val)); + buf.extend_from_slice(&SINGLE_OBJECT_MAGIC); // C3 01 + buf.extend_from_slice(&val); + buf + } + } } fn encode_records_with_prefix( @@ -336,6 +359,27 @@ fn new_decoder( .expect("failed to build decoder") } +fn new_decoder_id( + schema_json: &'static str, + batch_size: usize, + utf8view: bool, + id: u32, +) -> arrow_avro::reader::Decoder { + let schema = AvroSchema::new(schema_json.parse().unwrap()); + let mut store = arrow_avro::schema::SchemaStore::new_with_type(FingerprintAlgorithm::None); + // Register the schema with a provided Confluent-style ID + store + .set(Fingerprint::Id(id), schema.clone()) + .expect("failed to set schema with id"); + ReaderBuilder::new() + .with_writer_schema_store(store) + .with_active_fingerprint(Fingerprint::Id(id)) + .with_batch_size(batch_size) + .with_utf8_view(utf8view) + .build_decoder() + .expect("failed to build decoder for id") +} + const SIZES: [usize; 3] = [100, 10_000, 1_000_000]; const INT_SCHEMA: &str = @@ -373,7 +417,7 @@ macro_rules! dataset { static $name: Lazy>> = Lazy::new(|| { let schema = ApacheSchema::parse_str($schema_json).expect("invalid schema for generator"); - let arrow_schema = AvroSchema::new($schema_json.to_string()); + let arrow_schema = AvroSchema::new($schema_json.parse().unwrap()); let fingerprint = arrow_schema.fingerprint().expect("fingerprint failed"); let prefix = make_prefix(fingerprint); SIZES @@ -384,6 +428,24 @@ macro_rules! dataset { }; } +/// Additional helper for Confluent's ID-based wire format (00 + BE u32). +macro_rules! 
dataset_id { + ($name:ident, $schema_json:expr, $gen_fn:ident, $id:expr) => { + static $name: Lazy>> = Lazy::new(|| { + let schema = + ApacheSchema::parse_str($schema_json).expect("invalid schema for generator"); + let prefix = make_prefix(Fingerprint::Id($id)); + SIZES + .iter() + .map(|&n| $gen_fn(&schema, n, &prefix)) + .collect() + }); + }; +} + +const ID_BENCH_ID: u32 = 7; + +dataset_id!(INT_DATA_ID, INT_SCHEMA, gen_int, ID_BENCH_ID); dataset!(INT_DATA, INT_SCHEMA, gen_int); dataset!(LONG_DATA, LONG_SCHEMA, gen_long); dataset!(FLOAT_DATA, FLOAT_SCHEMA, gen_float); @@ -406,19 +468,20 @@ dataset!(ENUM_DATA, ENUM_SCHEMA, gen_enum); dataset!(MIX_DATA, MIX_SCHEMA, gen_mixed); dataset!(NEST_DATA, NEST_SCHEMA, gen_nested); -fn bench_scenario( +fn bench_with_decoder( c: &mut Criterion, name: &str, - schema_json: &'static str, data_sets: &[Vec], - utf8view: bool, - batch_size: usize, -) { + rows: &[usize], + mut new_decoder: F, +) where + F: FnMut() -> arrow_avro::reader::Decoder, +{ let mut group = c.benchmark_group(name); - for (idx, &rows) in SIZES.iter().enumerate() { + for (idx, &row_count) in rows.iter().enumerate() { let datum = &data_sets[idx]; group.throughput(Throughput::Bytes(datum.len() as u64)); - match rows { + match row_count { 10_000 => { group .sample_size(25) @@ -433,9 +496,9 @@ fn bench_scenario( } _ => {} } - group.bench_function(BenchmarkId::from_parameter(rows), |b| { + group.bench_function(BenchmarkId::from_parameter(row_count), |b| { b.iter_batched_ref( - || new_decoder(schema_json, batch_size, utf8view), + &mut new_decoder, |decoder| { black_box(decoder.decode(datum).unwrap()); black_box(decoder.flush().unwrap().unwrap()); @@ -449,105 +512,75 @@ fn bench_scenario( fn criterion_benches(c: &mut Criterion) { for &batch_size in &[SMALL_BATCH, LARGE_BATCH] { - bench_scenario( - c, - "Interval", - INTERVAL_SCHEMA, - &INTERVAL_DATA, - false, - batch_size, - ); - bench_scenario(c, "Int32", INT_SCHEMA, &INT_DATA, false, batch_size); - bench_scenario(c, "Int64", LONG_SCHEMA, &LONG_DATA, false, batch_size); - bench_scenario(c, "Float32", FLOAT_SCHEMA, &FLOAT_DATA, false, batch_size); - bench_scenario(c, "Boolean", BOOL_SCHEMA, &BOOL_DATA, false, batch_size); - bench_scenario(c, "Float64", DOUBLE_SCHEMA, &DOUBLE_DATA, false, batch_size); - bench_scenario( - c, - "Binary(Bytes)", - BYTES_SCHEMA, - &BYTES_DATA, - false, - batch_size, - ); - bench_scenario(c, "String", STRING_SCHEMA, &STRING_DATA, false, batch_size); - bench_scenario( - c, - "StringView", - STRING_SCHEMA, - &STRING_DATA, - true, - batch_size, - ); - bench_scenario(c, "Date32", DATE_SCHEMA, &DATE_DATA, false, batch_size); - bench_scenario( - c, - "TimeMillis", - TMILLIS_SCHEMA, - &TMILLIS_DATA, - false, - batch_size, - ); - bench_scenario( - c, - "TimeMicros", - TMICROS_SCHEMA, - &TMICROS_DATA, - false, - batch_size, - ); - bench_scenario( - c, - "TimestampMillis", - TSMILLIS_SCHEMA, - &TSMILLIS_DATA, - false, - batch_size, - ); - bench_scenario( - c, - "TimestampMicros", - TSMICROS_SCHEMA, - &TSMICROS_DATA, - false, - batch_size, - ); - bench_scenario(c, "Map", MAP_SCHEMA, &MAP_DATA, false, batch_size); - bench_scenario(c, "Array", ARRAY_SCHEMA, &ARRAY_DATA, false, batch_size); - bench_scenario( - c, - "Decimal128", - DECIMAL_SCHEMA, - &DECIMAL_DATA, - false, - batch_size, - ); - bench_scenario(c, "UUID", UUID_SCHEMA, &UUID_DATA, false, batch_size); - bench_scenario( - c, - "FixedSizeBinary", - FIXED_SCHEMA, - &FIXED_DATA, - false, - batch_size, - ); - bench_scenario( - c, - "Enum(Dictionary)", - ENUM_SCHEMA, - 
&ENUM_DATA, - false, - batch_size, - ); - bench_scenario(c, "Mixed", MIX_SCHEMA, &MIX_DATA, false, batch_size); - bench_scenario( - c, - "Nested(Struct)", - NEST_SCHEMA, - &NEST_DATA, - false, - batch_size, - ); + bench_with_decoder(c, "Interval", &INTERVAL_DATA, &SIZES, || { + new_decoder(INTERVAL_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Int32", &INT_DATA, &SIZES, || { + new_decoder(INT_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Int32_Id", &INT_DATA_ID, &SIZES, || { + new_decoder_id(INT_SCHEMA, batch_size, false, ID_BENCH_ID) + }); + bench_with_decoder(c, "Int64", &LONG_DATA, &SIZES, || { + new_decoder(LONG_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Float32", &FLOAT_DATA, &SIZES, || { + new_decoder(FLOAT_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Boolean", &BOOL_DATA, &SIZES, || { + new_decoder(BOOL_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Float64", &DOUBLE_DATA, &SIZES, || { + new_decoder(DOUBLE_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Binary(Bytes)", &BYTES_DATA, &SIZES, || { + new_decoder(BYTES_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "String", &STRING_DATA, &SIZES, || { + new_decoder(STRING_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "StringView", &STRING_DATA, &SIZES, || { + new_decoder(STRING_SCHEMA, batch_size, true) + }); + bench_with_decoder(c, "Date32", &DATE_DATA, &SIZES, || { + new_decoder(DATE_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "TimeMillis", &TMILLIS_DATA, &SIZES, || { + new_decoder(TMILLIS_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "TimeMicros", &TMICROS_DATA, &SIZES, || { + new_decoder(TMICROS_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "TimestampMillis", &TSMILLIS_DATA, &SIZES, || { + new_decoder(TSMILLIS_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "TimestampMicros", &TSMICROS_DATA, &SIZES, || { + new_decoder(TSMICROS_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Map", &MAP_DATA, &SIZES, || { + new_decoder(MAP_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Array", &ARRAY_DATA, &SIZES, || { + new_decoder(ARRAY_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Decimal128", &DECIMAL_DATA, &SIZES, || { + new_decoder(DECIMAL_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "UUID", &UUID_DATA, &SIZES, || { + new_decoder(UUID_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "FixedSizeBinary", &FIXED_DATA, &SIZES, || { + new_decoder(FIXED_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Enum(Dictionary)", &ENUM_DATA, &SIZES, || { + new_decoder(ENUM_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Mixed", &MIX_DATA, &SIZES, || { + new_decoder(MIX_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Nested(Struct)", &NEST_DATA, &SIZES, || { + new_decoder(NEST_SCHEMA, batch_size, false) + }); } } diff --git a/arrow-avro/examples/decode_kafka_stream.rs b/arrow-avro/examples/decode_kafka_stream.rs new file mode 100644 index 000000000000..f5b0f2e6575b --- /dev/null +++ b/arrow-avro/examples/decode_kafka_stream.rs @@ -0,0 +1,233 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Decode **Confluent Schema Registry**-framed Avro messages into Arrow [`RecordBatch`]es,
+//! resolving **older writer schemas** against a **current reader schema** without adding
+//! any new reader-only fields.
+//!
+//! What this example shows:
+//! * A **reader schema** for the current topic version with fields: `{ id: long, name: string }`.
+//! * Two older **writer schemas** (Confluent IDs **0** and **1**):
+//!   - v0: `{ id: int, name: string }` (older type for `id`)
+//!   - v1: `{ id: long, name: string, email: ["null","string"] }` (extra writer field `email`)
+//! * Streaming decode with `ReaderBuilder::with_reader_schema(...)` so that:
+//!   - v0's `id:int` is **promoted** to `long` for the reader
+//!   - v1's extra `email` field is **ignored** by the reader (projection)
+//!
+//! Wire format reminder (message value bytes):
+//! `0x00` magic byte + 4-byte **big-endian** schema ID + Avro **binary** body.
+
+use arrow_array::{Int64Array, RecordBatch, StringArray};
+use arrow_avro::reader::ReaderBuilder;
+use arrow_avro::schema::{
+    AvroSchema, Fingerprint, FingerprintAlgorithm, SchemaStore, CONFLUENT_MAGIC,
+};
+use arrow_schema::ArrowError;
+
+fn encode_long(value: i64, out: &mut Vec<u8>) {
+    let mut n = ((value << 1) ^ (value >> 63)) as u64;
+    while (n & !0x7F) != 0 {
+        out.push(((n as u8) & 0x7F) | 0x80);
+        n >>= 7;
+    }
+    out.push(n as u8);
+}
+
+fn encode_len(len: usize, out: &mut Vec<u8>) {
+    encode_long(len as i64, out)
+}
+
+fn encode_string(s: &str, out: &mut Vec<u8>) {
+    encode_len(s.len(), out);
+    out.extend_from_slice(s.as_bytes());
+}
+
+fn encode_union_index(index: i64, out: &mut Vec<u8>) {
+    encode_long(index, out);
+}
+
+// Writer v0 (ID=0):
+// {"type":"record","name":"User","fields":[
+//   {"name":"id","type":"int"},
+//   {"name":"name","type":"string"}]}
+fn encode_user_v0_body(id: i32, name: &str) -> Vec<u8> {
+    let mut v = Vec::with_capacity(16 + name.len());
+    encode_long(id as i64, &mut v);
+    encode_string(name, &mut v);
+    v
+}
+
+// Writer v1 (ID=1):
+// {"type":"record","name":"User","fields":[
+//   {"name":"id","type":"long"},
+//   {"name":"name","type":"string"},
+//   {"name":"email","type":["null","string"],"default":null}]}
+fn encode_user_v1_body(id: i64, name: &str, email: Option<&str>) -> Vec<u8> {
+    let mut v = Vec::with_capacity(24 + name.len() + email.map(|s| s.len()).unwrap_or(0));
+    encode_long(id, &mut v); // id: long
+    encode_string(name, &mut v); // name: string
+    match email {
+        None => {
+            // union index 0 => null
+            encode_union_index(0, &mut v);
+            // no value bytes follow for null
+        }
+        Some(s) => {
+            // union index 1 => string, followed by the string payload
+            encode_union_index(1, &mut v);
+            encode_string(s, &mut v);
+        }
+    }
+    v
+}
+
+fn frame_confluent(id_be: u32, body: &[u8]) -> Vec<u8> {
+    let mut out = Vec::with_capacity(5 + body.len());
+    out.extend_from_slice(&CONFLUENT_MAGIC); // 0x00
+    out.extend_from_slice(&id_be.to_be_bytes());
+    out.extend_from_slice(body);
+    out
+}
+
+fn print_arrow_schema(schema: &arrow_schema::Schema) {
+    println!("Resolved Arrow schema (via reader schema):");
+    for (i, f) in schema.fields().iter().enumerate() {
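        // Aside (hand-checked illustration of the framing built by `frame_confluent` above):
        // with schema id 1, the 5-byte prefix is 0x00 0x00 0x00 0x00 0x01 (magic byte,
        // then the id in big-endian), and the Avro body for the (2001, "v1-dave", Some(..))
        // row begins 0xA2 0x1F (zig-zag varint of 2001), 0x0E (string length 7), then the
        // UTF-8 bytes of "v1-dave".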
+        println!(
+            "  {i:>2}: {}: {:?} (nullable: {})",
+            f.name(),
+            f.data_type(),
+            f.is_nullable()
+        );
+    }
+    if !schema.metadata.is_empty() {
+        println!("  metadata: {:?}", schema.metadata());
+    }
+}
+
+fn print_rows(batch: &RecordBatch) -> Result<(), ArrowError> {
+    let ids = batch
+        .column(0)
+        .as_any()
+        .downcast_ref::<Int64Array>()
+        .ok_or_else(|| ArrowError::ComputeError("col 0 not Int64".into()))?;
+    let names = batch
+        .column(1)
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| ArrowError::ComputeError("col 1 not Utf8".into()))?;
+    for row in 0..batch.num_rows() {
+        let id = ids.value(row);
+        let name = names.value(row);
+        println!("  row {row}: id={id}, name={name}");
+    }
+    Ok(())
+}
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // The current topic schema as a READER schema
+    let reader_schema = AvroSchema::new(
+        r#"{
+          "type":"record","name":"User","fields":[
+            {"name":"id","type":"long"},
+            {"name":"name","type":"string"}
+        ]}"#
+        .to_string(),
+    );
+
+    // Two prior WRITER schema versions under Confluent IDs 0 and 1
+    let writer_v0 = AvroSchema::new(
+        r#"{
+          "type":"record","name":"User","fields":[
+            {"name":"id","type":"int"},
+            {"name":"name","type":"string"}
+        ]}"#
+        .to_string(),
+    );
+    let writer_v1 = AvroSchema::new(
+        r#"{
+          "type":"record","name":"User","fields":[
+            {"name":"id","type":"long"},
+            {"name":"name","type":"string"},
+            {"name":"email","type":["null","string"],"default":null}
+        ]}"#
+        .to_string(),
+    );
+
+    let id_v0: u32 = 0;
+    let id_v1: u32 = 1;
+
+    // Confluent SchemaStore keyed by integer IDs (FingerprintAlgorithm::None)
+    let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::None);
+    store.set(Fingerprint::Id(id_v0), writer_v0.clone())?;
+    store.set(Fingerprint::Id(id_v1), writer_v1.clone())?;
+
+    // Build a streaming Decoder with the READER schema
+    let mut decoder = ReaderBuilder::new()
+        .with_reader_schema(reader_schema)
+        .with_writer_schema_store(store)
+        .with_batch_size(8) // small batches for demo output
+        .build_decoder()?;
+
+    // Print the resolved Arrow schema (derived from reader and writer)
+    let resolved = decoder.schema();
+    print_arrow_schema(resolved.as_ref());
+    println!();
+
+    // Simulate an interleaved Kafka stream (IDs 0 and 1)
+    //  - v0: {id:int, name:string}             --> reader: id promoted to long
+    //  - v1: {id:long, name:string, email: ..} --> reader ignores extra field
+    let mut frames: Vec<(u32, Vec<u8>)> = Vec::new();
+
+    // Some v0 messages
+    for (i, name) in ["v0-alice", "v0-bob", "v0-carol"].iter().enumerate() {
+        let body = encode_user_v0_body(1000 + i as i32, name);
+        frames.push((id_v0, frame_confluent(id_v0, &body)));
+    }
+
+    // Some v1 messages (may include optional email on the writer side)
+    let v1_rows = [
+        (2001_i64, "v1-dave", Some("dave@example.com")),
+        (2002_i64, "v1-erin", None),
+        (2003_i64, "v1-frank", Some("frank@example.com")),
+    ];
+    for (id, name, email) in v1_rows {
+        let body = encode_user_v1_body(id, name, email);
+        frames.push((id_v1, frame_confluent(id_v1, &body)));
+    }
+
+    // Interleave to show mid-stream schema ID changes (0,1,0,1, ...)
+    frames.swap(1, 3); // crude interleave for demo
+
+    // Decode frames as if they were Kafka record values
+    for (schema_id, frame) in frames {
+        println!("Decoding record framed with Confluent schema id = {schema_id}");
+        let _consumed = decoder.decode(&frame)?;
+        while let Some(batch) = decoder.flush()? {
+            println!(
+                " -> Emitted batch: rows = {}, cols = {}",
+                batch.num_rows(),
+                batch.num_columns()
+            );
+            print_rows(&batch)?;
+        }
+        println!();
+    }
+
+    println!("Done decoding Kafka-style stream with schema resolution (no reader-added fields).");
+    Ok(())
+}
diff --git a/arrow-avro/examples/decode_stream.rs b/arrow-avro/examples/decode_stream.rs
new file mode 100644
index 000000000000..fe13382d2991
--- /dev/null
+++ b/arrow-avro/examples/decode_stream.rs
@@ -0,0 +1,104 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Decode Avro **stream-framed** bytes into Arrow [`RecordBatch`]es.
+//!
+//! This example demonstrates how to:
+//! * Build a streaming `Decoder` via `ReaderBuilder::build_decoder`
+//! * Register a writer schema keyed by a **Single-Object** Rabin fingerprint
+//! * Generate a few **Single-Object** frames in-memory and decode them
+
+use arrow_avro::reader::ReaderBuilder;
+use arrow_avro::schema::{AvroSchema, Fingerprint, SchemaStore, SINGLE_OBJECT_MAGIC};
+
+fn encode_long(value: i64, out: &mut Vec<u8>) {
+    let mut n = ((value << 1) ^ (value >> 63)) as u64;
+    while (n & !0x7F) != 0 {
+        out.push(((n as u8) & 0x7F) | 0x80);
+        n >>= 7;
+    }
+    out.push(n as u8);
+}
+
+fn encode_len(len: usize, out: &mut Vec<u8>) {
+    encode_long(len as i64, out)
+}
+
+fn encode_string(s: &str, out: &mut Vec<u8>) {
+    encode_len(s.len(), out);
+    out.extend_from_slice(s.as_bytes());
+}
+
+fn encode_user_body(id: i64, name: &str) -> Vec<u8> {
+    let mut v = Vec::with_capacity(16 + name.len());
+    encode_long(id, &mut v);
+    encode_string(name, &mut v);
+    v
+}
+
+// Frame a body as Avro Single-Object: magic + 8-byte little-endian fingerprint + body
+fn frame_single_object(fp_rabin: u64, body: &[u8]) -> Vec<u8> {
+    let mut out = Vec::with_capacity(2 + 8 + body.len());
+    out.extend_from_slice(&SINGLE_OBJECT_MAGIC);
+    out.extend_from_slice(&fp_rabin.to_le_bytes());
+    out.extend_from_slice(body);
+    out
+}
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // A tiny Avro writer schema used to generate a few messages
+    let avro = AvroSchema::new(
+        r#"{"type":"record","name":"User","fields":[
+            {"name":"id","type":"long"},
+            {"name":"name","type":"string"}]}"#
+        .to_string(),
+    );
+
+    // Register the writer schema in a store (keyed by Rabin fingerprint).
+    // Keep the fingerprint to seed the decoder and to frame generated messages.
+    let mut store = SchemaStore::new();
+    let fp = store.register(avro.clone())?;
+    let rabin = match fp {
+        Fingerprint::Rabin(v) => v,
+        _ => unreachable!("Single-Object framing uses Rabin fingerprints"),
+    };
+
+    // Build a streaming decoder configured for Single-Object framing.
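    // Per the Avro single-object encoding spec, `SINGLE_OBJECT_MAGIC` is the two-byte
    // marker 0xC3 0x01, so each frame produced by `frame_single_object` has the shape
    //   C3 01 | 8-byte little-endian Rabin fingerprint | Avro binary body.
    // Seeding the decoder with `with_active_fingerprint(fp)` tells it which registered
    // writer schema to start from before the first prefix has been observed.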
+ let mut decoder = ReaderBuilder::new() + .with_writer_schema_store(store) + .with_active_fingerprint(fp) + .build_decoder()?; + + // Generate 5 Single‑Object frames for the "User" schema. + let mut bytes = Vec::new(); + for i in 0..5 { + let body = encode_user_body(i as i64, &format!("user-{i}")); + bytes.extend_from_slice(&frame_single_object(rabin, &body)); + } + + // Feed all bytes at once, then flush completed batches. + let _consumed = decoder.decode(&bytes)?; + while let Some(batch) = decoder.flush()? { + println!( + "Batch: rows = {:>3}, cols = {}", + batch.num_rows(), + batch.num_columns() + ); + } + + Ok(()) +} diff --git a/arrow-avro/examples/read_avro_ocf.rs b/arrow-avro/examples/read_avro_ocf.rs new file mode 100644 index 000000000000..bf17ed572bfe --- /dev/null +++ b/arrow-avro/examples/read_avro_ocf.rs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Read an Avro **Object Container File (OCF)** into Arrow [`RecordBatch`] values. +//! +//! This example demonstrates how to: +//! * Construct a [`Reader`] using [`ReaderBuilder::build`] +//! * Iterate `RecordBatch`es and print a brief summary + +use std::fs::File; +use std::io::BufReader; +use std::path::PathBuf; + +use arrow_array::RecordBatch; +use arrow_avro::reader::ReaderBuilder; + +fn main() -> Result<(), Box> { + let ocf_path: PathBuf = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("test") + .join("data") + .join("skippable_types.avro"); + + let reader = BufReader::new(File::open(&ocf_path)?); + // Build a high-level OCF Reader with default settings + let avro_reader = ReaderBuilder::new().build(reader)?; + let schema = avro_reader.schema(); + println!( + "Discovered Arrow schema with {} fields", + schema.fields().len() + ); + + let mut total_batches = 0usize; + let mut total_rows = 0usize; + let mut total_columns = schema.fields().len(); + + for result in avro_reader { + let batch: RecordBatch = result?; + total_batches += 1; + total_rows += batch.num_rows(); + total_columns = batch.num_columns(); + + println!( + "Batch {:>3}: rows = {:>6}, cols = {:>3}", + total_batches, + batch.num_rows(), + batch.num_columns() + ); + } + + println!(); + println!("Done."); + println!(" Batches : {total_batches}"); + println!(" Rows : {total_rows}"); + println!(" Columns : {total_columns}"); + + Ok(()) +} diff --git a/arrow-avro/examples/read_ocf_with_resolution.rs b/arrow-avro/examples/read_ocf_with_resolution.rs new file mode 100644 index 000000000000..7367ba3cd5b0 --- /dev/null +++ b/arrow-avro/examples/read_ocf_with_resolution.rs @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Read an Avro **Object Container File (OCF)** using an inline **reader schema** +//! that differs from the writer schema, demonstrating Avro **schema resolution** +//! (field projection and legal type promotion) without ever fetching the writer +//! schema from the file. +//! +//! What this example does: +//! 1. Locates `/test/data/skippable_types.avro` (portable path). +//! 2. Defines an inline **reader schema** JSON: +//! * Projects a subset of fields from the writer schema, and +//! * Promotes `"int"` to `"long"` where applicable. +//! 3. Builds a `Reader` with `ReaderBuilder::with_reader_schema(...)` and prints batches. + +use std::fs::File; +use std::io::BufReader; +use std::path::PathBuf; + +use arrow_array::RecordBatch; +use arrow_avro::reader::ReaderBuilder; +use arrow_avro::schema::AvroSchema; + +fn default_ocf_path() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("test") + .join("data") + .join("skippable_types.avro") +} + +// A minimal reader schema compatible with the provided writer schema +const READER_SCHEMA_JSON: &str = r#" +{ + "type": "record", + "name": "SkippableTypesRecord", + "fields": [ + { "name": "boolean_field", "type": "boolean" }, + { "name": "int_field", "type": "long" }, + { "name": "long_field", "type": "long" }, + { "name": "string_field", "type": "string" }, + { "name": "nullable_nullfirst_field", "type": ["null", "long"] } + ] +} +"#; + +fn main() -> Result<(), Box> { + let ocf_path = default_ocf_path(); + let file = File::open(&ocf_path)?; + let reader_schema = AvroSchema::new(READER_SCHEMA_JSON.to_string()); + + let reader = ReaderBuilder::new() + .with_reader_schema(reader_schema) + .build(BufReader::new(file))?; + + let resolved_schema = reader.schema(); + println!( + "Reader-based decode: resolved Arrow schema with {} fields", + resolved_schema.fields().len() + ); + + // Iterate batches and print a brief summary + let mut total_batches = 0usize; + let mut total_rows = 0usize; + for next in reader { + let batch: RecordBatch = next?; + total_batches += 1; + total_rows += batch.num_rows(); + println!( + " Batch {:>3}: rows = {:>6}, cols = {:>2}", + total_batches, + batch.num_rows(), + batch.num_columns() + ); + } + + println!(); + println!("Done (with reader/writer schema resolution)."); + println!(" Batches : {total_batches}"); + println!(" Rows : {total_rows}"); + + Ok(()) +} diff --git a/arrow-avro/examples/write_avro_ocf.rs b/arrow-avro/examples/write_avro_ocf.rs new file mode 100644 index 000000000000..5bdca0de7a3d --- /dev/null +++ b/arrow-avro/examples/write_avro_ocf.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Write an Avro Object Container File (OCF) from an Arrow `RecordBatch` +//! +//! This example builds a small Arrow `RecordBatch` and persists it to an +//! **Avro Object Container File (OCF)** using +//! `arrow_avro::writer::{Writer, WriterBuilder}`. +//! +//! ## What this example does +//! - Define an Arrow schema with supported types (`Int64`, `Utf8`, `Boolean`, +//! `Float64`, `Binary`, and `Timestamp (Microsecond, "UTC")`). +//! - Constructs arrays and a `RecordBatch`, ensuring each column’s data type +//! **exactly** matches the schema (timestamps include the `"UTC"` timezone). +//! - Writes a single batch to `target/write_avro_ocf_example.avro` as an OCF, +//! using Snappy block compression (you can disable or change the codec). +//! - Prints the file’s 16‑byte sync marker (used by OCF to delimit blocks). + +use std::fs::File; +use std::io::BufWriter; +use std::sync::Arc; + +use arrow_array::{ + ArrayRef, BinaryArray, BooleanArray, Float64Array, Int64Array, RecordBatch, StringArray, + TimestampMicrosecondArray, +}; +use arrow_avro::compression::CompressionCodec; +use arrow_avro::writer::format::AvroOcfFormat; +use arrow_avro::writer::{Writer, WriterBuilder}; +use arrow_schema::{DataType, Field, Schema, TimeUnit}; + +fn main() -> Result<(), Box> { + // Arrow schema + // id: Int64 (non-null) + // name: Utf8 (nullable) + // active: Boolean (non-null) + // score: Float64 (nullable) + // payload: Binary (nullable) + // created_at: Timestamp(Microsecond, Some("UTC")) (non-null) + let schema = Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, true), + Field::new("active", DataType::Boolean, false), + Field::new("score", DataType::Float64, true), + Field::new("payload", DataType::Binary, true), + Field::new( + "created_at", + DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC".to_string()))), + false, + ), + ]); + + let schema_ref = Arc::new(schema.clone()); + let ids = Int64Array::from(vec![1_i64, 2, 3]); + let names = StringArray::from(vec![Some("alpha"), None, Some("gamma")]); + let active = BooleanArray::from(vec![true, false, true]); + let scores = Float64Array::from(vec![Some(1.5_f64), None, Some(3.0)]); + + // BinaryArray: include a null + let payload = BinaryArray::from_opt_vec(vec![Some(&b"abc"[..]), None, Some(&[0u8, 1, 2][..])]); + + // Timestamp in microseconds since UNIX epoch + let created_at = TimestampMicrosecondArray::from(vec![ + Some(1_722_000_000_000_000_i64), + Some(1_722_000_123_456_000_i64), + Some(1_722_000_999_999_000_i64), + ]) + .with_timezone("UTC".to_string()); + + let columns: Vec = vec![ + Arc::new(ids), + Arc::new(names), + Arc::new(active), + Arc::new(scores), + Arc::new(payload), + Arc::new(created_at), + ]; + + let batch = RecordBatch::try_new(schema_ref, columns)?; + + // Build an OCF writer with optional compression + let out_path = "target/write_avro_ocf_example.avro"; + let file = File::create(out_path)?; 
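    // For context (per the Avro OCF spec): the writer below emits a header consisting of
    // the 4-byte magic `Obj\x01`, a metadata map carrying `avro.schema` and `avro.codec`
    // ("snappy" here), and a random 16-byte sync marker; every data block that follows is
    // terminated by that same sync marker, which is what `writer.sync_marker()` exposes.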
+ let mut writer: Writer<_, AvroOcfFormat> = WriterBuilder::new(schema) + .with_compression(Some(CompressionCodec::Snappy)) + .build(BufWriter::new(file))?; + + // Write a single batch (use `write_batches` for multiple) + writer.write(&batch)?; + writer.finish()?; // flush and finalize + + if let Some(sync) = writer.sync_marker() { + println!("Wrote OCF to {out_path} (sync marker: {:02x?})", &sync[..]); + } else { + println!("Wrote OCF to {out_path}"); + } + + Ok(()) +} diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index dcd39845014f..cf0276f0a25d 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -15,26 +15,113 @@ // specific language governing permissions and limitations // under the License. -use crate::schema::{Attributes, AvroSchema, ComplexType, PrimitiveType, Record, Schema, TypeName}; +use crate::schema::{ + Array, Attributes, AvroSchema, ComplexType, Enum, Fixed, Map, Nullability, PrimitiveType, + Record, Schema, Type, TypeName, AVRO_ENUM_SYMBOLS_METADATA_KEY, + AVRO_FIELD_DEFAULT_METADATA_KEY, AVRO_ROOT_RECORD_DEFAULT_NAME, +}; use arrow_schema::{ ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, - DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, }; -use std::borrow::Cow; +#[cfg(feature = "small_decimals")] +use arrow_schema::{DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION}; +use indexmap::IndexMap; +use serde_json::Value; use std::collections::HashMap; use std::sync::Arc; -/// Avro types are not nullable, with nullability instead encoded as a union -/// where one of the variants is the null type. +/// Contains information about how to resolve differences between a writer's and a reader's schema. +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum ResolutionInfo { + /// Indicates that the writer's type should be promoted to the reader's type. + Promotion(Promotion), + /// Indicates that a default value should be used for a field. + DefaultValue(AvroLiteral), + /// Provides mapping information for resolving enums. + EnumMapping(EnumMapping), + /// Provides resolution information for record fields. + Record(ResolvedRecord), +} + +/// Represents a literal Avro value. +/// +/// This is used to represent default values in an Avro schema. +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum AvroLiteral { + /// Represents a null value. + Null, + /// Represents a boolean value. + Boolean(bool), + /// Represents an integer value. + Int(i32), + /// Represents a long value. + Long(i64), + /// Represents a float value. + Float(f32), + /// Represents a double value. + Double(f64), + /// Represents a bytes value. + Bytes(Vec), + /// Represents a string value. + String(String), + /// Represents an enum symbol. + Enum(String), + /// Represents a JSON array default for an Avro array, containing element literals. + Array(Vec), + /// Represents a JSON object default for an Avro map/struct, mapping string keys to value literals. + Map(IndexMap), + /// Represents an unsupported literal type. + Unsupported, +} + +/// Contains the necessary information to resolve a writer's record against a reader's record schema. +#[derive(Debug, Clone, PartialEq)] +pub struct ResolvedRecord { + /// Maps a writer's field index to the corresponding reader's field index. + /// `None` if the writer's field is not present in the reader's schema. + pub(crate) writer_to_reader: Arc<[Option]>, + /// A list of indices in the reader's schema for fields that have a default value. 
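    // Illustrative example (not from the crate docs): if the writer record has fields
    // {a, b, c} and the reader has {a, c, d} where only `d` declares a default, then
    // `writer_to_reader` would be [Some(0), None, Some(1)] (writer `b` has no reader slot),
    // `default_fields` would contain d's reader index [2], and `skip_fields` would record
    // b's writer type so the decoder can skip its bytes.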
+ pub(crate) default_fields: Arc<[usize]>, + /// For fields present in the writer's schema but not the reader's, this stores their data type. + /// This is needed to correctly skip over these fields during deserialization. + pub(crate) skip_fields: Arc<[Option]>, +} + +/// Defines the type of promotion to be applied during schema resolution. +/// +/// Schema resolution may require promoting a writer's data type to a reader's data type. +/// For example, an `int` can be promoted to a `long`, `float`, or `double`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum Promotion { + /// Promotes an `int` to a `long`. + IntToLong, + /// Promotes an `int` to a `float`. + IntToFloat, + /// Promotes an `int` to a `double`. + IntToDouble, + /// Promotes a `long` to a `float`. + LongToFloat, + /// Promotes a `long` to a `double`. + LongToDouble, + /// Promotes a `float` to a `double`. + FloatToDouble, + /// Promotes a `string` to `bytes`. + StringToBytes, + /// Promotes `bytes` to a `string`. + BytesToString, +} + +/// Holds the mapping information for resolving Avro enums. /// -/// To accommodate this we special case two-variant unions where one of the -/// variants is the null type, and use this to derive arrow's notion of nullability -#[derive(Debug, Copy, Clone)] -pub enum Nullability { - /// The nulls are encoded as the first union variant - NullFirst, - /// The nulls are encoded as the second union variant - NullSecond, +/// When resolving schemas, the writer's enum symbols must be mapped to the reader's symbols. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EnumMapping { + /// A mapping from the writer's symbol index to the reader's symbol index. + pub(crate) mapping: Arc<[i32]>, + /// The index to use for a writer's symbol that is not present in the reader's enum + /// and a default value is specified in the reader's schema. 
+ pub(crate) default_index: i32, } #[cfg(feature = "canonical_extension_types")] @@ -46,11 +133,12 @@ fn with_extension_type(codec: &Codec, field: Field) -> Field { } /// An Avro datatype mapped to the arrow data model -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct AvroDataType { nullability: Option, metadata: HashMap, codec: Codec, + pub(crate) resolution: Option, } impl AvroDataType { @@ -64,6 +152,22 @@ impl AvroDataType { codec, metadata, nullability, + resolution: None, + } + } + + #[inline] + fn new_with_resolution( + codec: Codec, + metadata: HashMap, + nullability: Option, + resolution: Option, + ) -> Self { + Self { + codec, + metadata, + nullability, + resolution, } } @@ -96,10 +200,229 @@ impl AvroDataType { pub fn nullability(&self) -> Option { self.nullability } + + #[inline] + fn parse_default_literal(&self, default_json: &Value) -> Result { + fn expect_string<'v>( + default_json: &'v Value, + data_type: &str, + ) -> Result<&'v str, ArrowError> { + match default_json { + Value::String(s) => Ok(s.as_str()), + _ => Err(ArrowError::SchemaError(format!( + "Default value must be a JSON string for {data_type}" + ))), + } + } + + fn parse_bytes_default( + default_json: &Value, + expected_len: Option, + ) -> Result, ArrowError> { + let s = expect_string(default_json, "bytes/fixed logical types")?; + let mut out = Vec::with_capacity(s.len()); + for ch in s.chars() { + let cp = ch as u32; + if cp > 0xFF { + return Err(ArrowError::SchemaError(format!( + "Invalid codepoint U+{cp:04X} in bytes/fixed default; must be ≤ 0xFF" + ))); + } + out.push(cp as u8); + } + if let Some(len) = expected_len { + if out.len() != len { + return Err(ArrowError::SchemaError(format!( + "Default length {} does not match expected fixed size {len}", + out.len(), + ))); + } + } + Ok(out) + } + + fn parse_json_i64(default_json: &Value, data_type: &str) -> Result { + match default_json { + Value::Number(n) => n.as_i64().ok_or_else(|| { + ArrowError::SchemaError(format!("Default {data_type} must be an integer")) + }), + _ => Err(ArrowError::SchemaError(format!( + "Default {data_type} must be a JSON integer" + ))), + } + } + + fn parse_json_f64(default_json: &Value, data_type: &str) -> Result { + match default_json { + Value::Number(n) => n.as_f64().ok_or_else(|| { + ArrowError::SchemaError(format!("Default {data_type} must be a number")) + }), + _ => Err(ArrowError::SchemaError(format!( + "Default {data_type} must be a JSON number" + ))), + } + } + + // Handle JSON nulls per-spec: allowed only for `null` type or unions with null FIRST + if default_json.is_null() { + return match self.codec() { + Codec::Null => Ok(AvroLiteral::Null), + _ if self.nullability() == Some(Nullability::NullFirst) => Ok(AvroLiteral::Null), + _ => Err(ArrowError::SchemaError( + "JSON null default is only valid for `null` type or for a union whose first branch is `null`" + .to_string(), + )), + }; + } + let lit = match self.codec() { + Codec::Null => { + return Err(ArrowError::SchemaError( + "Default for `null` type must be JSON null".to_string(), + )) + } + Codec::Boolean => match default_json { + Value::Bool(b) => AvroLiteral::Boolean(*b), + _ => { + return Err(ArrowError::SchemaError( + "Boolean default must be a JSON boolean".to_string(), + )) + } + }, + Codec::Int32 | Codec::Date32 | Codec::TimeMillis => { + let i = parse_json_i64(default_json, "int")?; + if i < i32::MIN as i64 || i > i32::MAX as i64 { + return Err(ArrowError::SchemaError(format!( + "Default int {i} out of i32 range" + ))); + } + 
AvroLiteral::Int(i as i32) + } + Codec::Int64 + | Codec::TimeMicros + | Codec::TimestampMillis(_) + | Codec::TimestampMicros(_) => AvroLiteral::Long(parse_json_i64(default_json, "long")?), + Codec::Float32 => { + let f = parse_json_f64(default_json, "float")?; + if !f.is_finite() || f < f32::MIN as f64 || f > f32::MAX as f64 { + return Err(ArrowError::SchemaError(format!( + "Default float {f} out of f32 range or not finite" + ))); + } + AvroLiteral::Float(f as f32) + } + Codec::Float64 => AvroLiteral::Double(parse_json_f64(default_json, "double")?), + Codec::Utf8 | Codec::Utf8View | Codec::Uuid => { + AvroLiteral::String(expect_string(default_json, "string/uuid")?.to_string()) + } + Codec::Binary => AvroLiteral::Bytes(parse_bytes_default(default_json, None)?), + Codec::Fixed(sz) => { + AvroLiteral::Bytes(parse_bytes_default(default_json, Some(*sz as usize))?) + } + Codec::Decimal(_, _, fixed_size) => { + AvroLiteral::Bytes(parse_bytes_default(default_json, *fixed_size)?) + } + Codec::Enum(symbols) => { + let s = expect_string(default_json, "enum")?; + if symbols.iter().any(|sym| sym == s) { + AvroLiteral::Enum(s.to_string()) + } else { + return Err(ArrowError::SchemaError(format!( + "Default enum symbol {s:?} not found in reader enum symbols" + ))); + } + } + Codec::Interval => AvroLiteral::Bytes(parse_bytes_default(default_json, Some(12))?), + Codec::List(item_dt) => match default_json { + Value::Array(items) => AvroLiteral::Array( + items + .iter() + .map(|v| item_dt.parse_default_literal(v)) + .collect::>()?, + ), + _ => { + return Err(ArrowError::SchemaError( + "Default value must be a JSON array for Avro array type".to_string(), + )) + } + }, + Codec::Map(val_dt) => match default_json { + Value::Object(map) => { + let mut out = IndexMap::with_capacity(map.len()); + for (k, v) in map { + out.insert(k.clone(), val_dt.parse_default_literal(v)?); + } + AvroLiteral::Map(out) + } + _ => { + return Err(ArrowError::SchemaError( + "Default value must be a JSON object for Avro map type".to_string(), + )) + } + }, + Codec::Struct(fields) => match default_json { + Value::Object(obj) => { + let mut out: IndexMap = + IndexMap::with_capacity(fields.len()); + for f in fields.as_ref() { + let name = f.name().to_string(); + if let Some(sub) = obj.get(&name) { + out.insert(name, f.data_type().parse_default_literal(sub)?); + } else { + // Cache metadata lookup once + let stored_default = + f.data_type().metadata.get(AVRO_FIELD_DEFAULT_METADATA_KEY); + if stored_default.is_none() + && f.data_type().nullability() == Some(Nullability::default()) + { + out.insert(name, AvroLiteral::Null); + } else if let Some(default_json) = stored_default { + let v: Value = + serde_json::from_str(default_json).map_err(|e| { + ArrowError::SchemaError(format!( + "Failed to parse stored subfield default JSON for '{}': {e}", + f.name(), + )) + })?; + out.insert(name, f.data_type().parse_default_literal(&v)?); + } else { + return Err(ArrowError::SchemaError(format!( + "Record default missing required subfield '{}' with non-nullable type {:?}", + f.name(), + f.data_type().codec() + ))); + } + } + } + AvroLiteral::Map(out) + } + _ => { + return Err(ArrowError::SchemaError( + "Default value for record/struct must be a JSON object".to_string(), + )) + } + }, + }; + Ok(lit) + } + + fn store_default(&mut self, default_json: &Value) -> Result<(), ArrowError> { + let json_text = serde_json::to_string(default_json).map_err(|e| { + ArrowError::ParseError(format!("Failed to serialize default to JSON: {e}")) + })?; + self.metadata + 
.insert(AVRO_FIELD_DEFAULT_METADATA_KEY.to_string(), json_text); + Ok(()) + } + + fn parse_and_store_default(&mut self, default_json: &Value) -> Result { + let lit = self.parse_default_literal(default_json)?; + self.store_default(default_json)?; + Ok(lit) + } } /// A named [`AvroDataType`] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct AvroField { name: String, data_type: AvroDataType, @@ -151,9 +474,16 @@ impl AvroField { use_utf8view: bool, strict_mode: bool, ) -> Result { - Err(ArrowError::NotYetImplemented( - "Resolving schema from a writer and reader schema is not yet implemented".to_string(), - )) + let top_name = match reader_schema { + Schema::Complex(ComplexType::Record(r)) => r.name.to_string(), + _ => AVRO_ROOT_RECORD_DEFAULT_NAME.to_string(), + }; + let mut resolver = Maker::new(use_utf8view, strict_mode); + let data_type = resolver.make_data_type(writer_schema, Some(reader_schema), None)?; + Ok(Self { + name: top_name, + data_type, + }) } } @@ -163,8 +493,8 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField { fn try_from(schema: &Schema<'a>) -> Result { match schema { Schema::Complex(ComplexType::Record(r)) => { - let mut resolver = Resolver::default(); - let data_type = make_data_type(schema, None, &mut resolver, false, false)?; + let mut resolver = Maker::new(false, false); + let data_type = resolver.make_data_type(schema, None, None)?; Ok(AvroField { data_type, name: r.name.to_string(), @@ -181,7 +511,7 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField { #[derive(Debug)] pub struct AvroFieldBuilder<'a> { writer_schema: &'a Schema<'a>, - reader_schema: Option, + reader_schema: Option<&'a Schema<'a>>, use_utf8view: bool, strict_mode: bool, } @@ -202,7 +532,7 @@ impl<'a> AvroFieldBuilder<'a> { /// If a reader schema is provided, the builder will produce a resolved `AvroField` /// that can handle differences between the writer's and reader's schemas. #[inline] - pub fn with_reader_schema(mut self, reader_schema: AvroSchema) -> Self { + pub fn with_reader_schema(mut self, reader_schema: &'a Schema<'a>) -> Self { self.reader_schema = Some(reader_schema); self } @@ -223,14 +553,9 @@ impl<'a> AvroFieldBuilder<'a> { pub fn build(self) -> Result { match self.writer_schema { Schema::Complex(ComplexType::Record(r)) => { - let mut resolver = Resolver::default(); - let data_type = make_data_type( - self.writer_schema, - None, - &mut resolver, - self.use_utf8view, - self.strict_mode, - )?; + let mut resolver = Maker::new(self.use_utf8view, self.strict_mode); + let data_type = + resolver.make_data_type(self.writer_schema, self.reader_schema, None)?; Ok(AvroField { name: r.name.to_string(), data_type, @@ -247,7 +572,7 @@ impl<'a> AvroFieldBuilder<'a> { /// An Avro encoding /// /// -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum Codec { /// Represents Avro null type, maps to Arrow's Null data type Null, @@ -289,7 +614,7 @@ pub enum Codec { /// Represents Avro fixed type, maps to Arrow's FixedSizeBinary data type /// The i32 parameter indicates the fixed binary size Fixed(i32), - /// Represents Avro decimal type, maps to Arrow's Decimal128 or Decimal256 data types + /// Represents Avro decimal type, maps to Arrow's Decimal32, Decimal64, Decimal128, or Decimal256 data types /// /// The fields are `(precision, scale, fixed_size)`. /// - `precision` (`usize`): Total number of digits. 
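The hunk below maps an Avro decimal's precision onto the narrowest Arrow decimal type. A minimal sketch of that selection, assuming the `small_decimals` feature is enabled and using the arrow-schema precision constants (the helper name here is illustrative, not part of the crate):

```rust
use arrow_schema::{
    DataType, DECIMAL128_MAX_PRECISION, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
};

/// Pick the narrowest Arrow decimal that can represent `precision` total digits.
/// Without `small_decimals`, everything up to 38 digits stays Decimal128 and wider
/// values fall back to Decimal256.
fn narrowest_decimal(precision: usize, scale: i8) -> DataType {
    let p = precision as u8;
    if precision <= DECIMAL32_MAX_PRECISION as usize {
        DataType::Decimal32(p, scale) // <= 9 digits
    } else if precision <= DECIMAL64_MAX_PRECISION as usize {
        DataType::Decimal64(p, scale) // <= 18 digits
    } else if precision <= DECIMAL128_MAX_PRECISION as usize {
        DataType::Decimal128(p, scale) // <= 38 digits
    } else {
        DataType::Decimal256(p, scale) // <= 76 digits
    }
}
```

Fixed-backed decimals are additionally capped by `max_precision_for_fixed_bytes` further down; for example, a 16-byte fixed can hold at most 38 decimal digits.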
@@ -335,20 +660,28 @@ impl Codec { } Self::Interval => DataType::Interval(IntervalUnit::MonthDayNano), Self::Fixed(size) => DataType::FixedSizeBinary(*size), - Self::Decimal(precision, scale, size) => { + Self::Decimal(precision, scale, _size) => { let p = *precision as u8; let s = scale.unwrap_or(0) as i8; - let too_large_for_128 = match *size { - Some(sz) => sz > 16, - None => { - (p as usize) > DECIMAL128_MAX_PRECISION as usize - || (s as usize) > DECIMAL128_MAX_SCALE as usize + #[cfg(feature = "small_decimals")] + { + if *precision <= DECIMAL32_MAX_PRECISION as usize { + DataType::Decimal32(p, s) + } else if *precision <= DECIMAL64_MAX_PRECISION as usize { + DataType::Decimal64(p, s) + } else if *precision <= DECIMAL128_MAX_PRECISION as usize { + DataType::Decimal128(p, s) + } else { + DataType::Decimal256(p, s) + } + } + #[cfg(not(feature = "small_decimals"))] + { + if *precision <= DECIMAL128_MAX_PRECISION as usize { + DataType::Decimal128(p, s) + } else { + DataType::Decimal256(p, s) } - }; - if too_large_for_128 { - DataType::Decimal256(p, s) - } else { - DataType::Decimal128(p, s) } } Self::Uuid => DataType::FixedSizeBinary(16), @@ -394,6 +727,29 @@ impl From for Codec { } } +/// Compute the exact maximum base‑10 precision that fits in `n` bytes for Avro +/// `fixed` decimals stored as two's‑complement unscaled integers (big‑endian). +/// +/// Per Avro spec (Decimal logical type), for a fixed length `n`: +/// max precision = ⌊log₁₀(2^(8n − 1) − 1)⌋. +/// +/// This function returns `None` if `n` is 0 or greater than 32 (Arrow supports +/// Decimal256, which is 32 bytes and has max precision 76). +const fn max_precision_for_fixed_bytes(n: usize) -> Option { + // Precomputed exact table for n = 1..=32 + // 1:2, 2:4, 3:6, 4:9, 5:11, 6:14, 7:16, 8:18, 9:21, 10:23, 11:26, 12:28, + // 13:31, 14:33, 15:35, 16:38, 17:40, 18:43, 19:45, 20:47, 21:50, 22:52, + // 23:55, 24:57, 25:59, 26:62, 27:64, 28:67, 29:69, 30:71, 31:74, 32:76 + const MAX_P: [usize; 32] = [ + 2, 4, 6, 9, 11, 14, 16, 18, 21, 23, 26, 28, 31, 33, 35, 38, 40, 43, 45, 47, 50, 52, 55, 57, + 59, 62, 64, 67, 69, 71, 74, 76, + ]; + match n { + 1..=32 => Some(MAX_P[n - 1]), + _ => None, + } +} + fn parse_decimal_attributes( attributes: &Attributes, fallback_size: Option, @@ -417,6 +773,34 @@ fn parse_decimal_attributes( .and_then(|v| v.as_u64()) .map(|s| s as usize) .or(fallback_size); + if precision == 0 { + return Err(ArrowError::ParseError( + "Decimal requires precision > 0".to_string(), + )); + } + if scale > precision { + return Err(ArrowError::ParseError(format!( + "Decimal has invalid scale > precision: scale={scale}, precision={precision}" + ))); + } + if precision > DECIMAL256_MAX_PRECISION as usize { + return Err(ArrowError::ParseError(format!( + "Decimal precision {precision} exceeds maximum supported by Arrow ({})", + DECIMAL256_MAX_PRECISION + ))); + } + if let Some(sz) = size { + let max_p = max_precision_for_fixed_bytes(sz).ok_or_else(|| { + ArrowError::ParseError(format!( + "Invalid fixed size for decimal: {sz}, must be between 1 and 32 bytes" + )) + })?; + if precision > max_p { + return Err(ArrowError::ParseError(format!( + "Decimal precision {precision} exceeds capacity of fixed size {sz} bytes (max {max_p})" + ))); + } + } Ok((precision, scale, size)) } @@ -467,7 +851,6 @@ impl<'a> Resolver<'a> { let (namespace, name) = name .rsplit_once('.') .unwrap_or_else(|| (namespace.unwrap_or(""), name)); - self.map .get(&(namespace, name)) .ok_or_else(|| ArrowError::ParseError(format!("Failed to resolve 
{namespace}.{name}"))) @@ -475,221 +858,664 @@ impl<'a> Resolver<'a> { } } -/// Parses a [`AvroDataType`] from the provided `schema` and the given `name` and `namespace` -/// -/// `name`: is name used to refer to `schema` in its parent -/// `namespace`: an optional qualifier used as part of a type hierarchy -/// If the data type is a string, convert to use Utf8View if requested -/// -/// This function is used during the schema conversion process to determine whether -/// string data should be represented as StringArray (default) or StringViewArray. -/// -/// `use_utf8view`: if true, use Utf8View instead of Utf8 for string types +fn names_match( + writer_name: &str, + writer_aliases: &[&str], + reader_name: &str, + reader_aliases: &[&str], +) -> bool { + writer_name == reader_name + || reader_aliases.contains(&writer_name) + || writer_aliases.contains(&reader_name) +} + +fn ensure_names_match( + data_type: &str, + writer_name: &str, + writer_aliases: &[&str], + reader_name: &str, + reader_aliases: &[&str], +) -> Result<(), ArrowError> { + if names_match(writer_name, writer_aliases, reader_name, reader_aliases) { + Ok(()) + } else { + Err(ArrowError::ParseError(format!( + "{data_type} name mismatch writer={writer_name}, reader={reader_name}" + ))) + } +} + +fn primitive_of(schema: &Schema) -> Option { + match schema { + Schema::TypeName(TypeName::Primitive(primitive)) => Some(*primitive), + Schema::Type(Type { + r#type: TypeName::Primitive(primitive), + .. + }) => Some(*primitive), + _ => None, + } +} + +fn nullable_union_variants<'x, 'y>( + variant: &'y [Schema<'x>], +) -> Option<(Nullability, &'y Schema<'x>)> { + if variant.len() != 2 { + return None; + } + let is_null = |schema: &Schema<'x>| { + matches!( + schema, + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)) + ) + }; + match (is_null(&variant[0]), is_null(&variant[1])) { + (true, false) => Some((Nullability::NullFirst, &variant[1])), + (false, true) => Some((Nullability::NullSecond, &variant[0])), + _ => None, + } +} + +/// Resolves Avro type names to [`AvroDataType`] /// -/// See [`Resolver`] for more information -fn make_data_type<'a>( - schema: &Schema<'a>, - namespace: Option<&'a str>, - resolver: &mut Resolver<'a>, +/// See +struct Maker<'a> { + resolver: Resolver<'a>, use_utf8view: bool, strict_mode: bool, -) -> Result { - match schema { - Schema::TypeName(TypeName::Primitive(p)) => { - let codec: Codec = (*p).into(); - let codec = codec.with_utf8view(use_utf8view); - Ok(AvroDataType { - nullability: None, - metadata: Default::default(), - codec, - }) +} + +impl<'a> Maker<'a> { + fn new(use_utf8view: bool, strict_mode: bool) -> Self { + Self { + resolver: Default::default(), + use_utf8view, + strict_mode, } - Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace), - Schema::Union(f) => { - // Special case the common case of nullable primitives - let null = f - .iter() - .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); - match (f.len() == 2, null) { - (true, Some(0)) => { - let mut field = - make_data_type(&f[1], namespace, resolver, use_utf8view, strict_mode)?; - field.nullability = Some(Nullability::NullFirst); - Ok(field) - } - (true, Some(1)) => { - if strict_mode { - return Err(ArrowError::SchemaError( - "Found Avro union of the form ['T','null'], which is disallowed in strict_mode" - .to_string(), - )); + } + fn make_data_type<'s>( + &mut self, + writer_schema: &'s Schema<'a>, + reader_schema: Option<&'s Schema<'a>>, + namespace: Option<&'a str>, + ) -> Result 
{ + match reader_schema { + Some(reader_schema) => self.resolve_type(writer_schema, reader_schema, namespace), + None => self.parse_type(writer_schema, namespace), + } + } + + /// Parses a [`AvroDataType`] from the provided [`Schema`] and the given `name` and `namespace` + /// + /// `name`: is the name used to refer to `schema` in its parent + /// `namespace`: an optional qualifier used as part of a type hierarchy + /// If the data type is a string, convert to use Utf8View if requested + /// + /// This function is used during the schema conversion process to determine whether + /// string data should be represented as StringArray (default) or StringViewArray. + /// + /// `use_utf8view`: if true, use Utf8View instead of Utf8 for string types + /// + /// See [`Resolver`] for more information + fn parse_type<'s>( + &mut self, + schema: &'s Schema<'a>, + namespace: Option<&'a str>, + ) -> Result { + match schema { + Schema::TypeName(TypeName::Primitive(p)) => Ok(AvroDataType::new( + Codec::from(*p).with_utf8view(self.use_utf8view), + Default::default(), + None, + )), + Schema::TypeName(TypeName::Ref(name)) => self.resolver.resolve(name, namespace), + Schema::Union(f) => { + // Special case the common case of nullable primitives + let null = f + .iter() + .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); + match (f.len() == 2, null) { + (true, Some(0)) => { + let mut field = self.parse_type(&f[1], namespace)?; + field.nullability = Some(Nullability::NullFirst); + Ok(field) } - let mut field = - make_data_type(&f[0], namespace, resolver, use_utf8view, strict_mode)?; - field.nullability = Some(Nullability::NullSecond); - Ok(field) + (true, Some(1)) => { + if self.strict_mode { + return Err(ArrowError::SchemaError( + "Found Avro union of the form ['T','null'], which is disallowed in strict_mode" + .to_string(), + )); + } + let mut field = self.parse_type(&f[0], namespace)?; + field.nullability = Some(Nullability::NullSecond); + Ok(field) + } + _ => Err(ArrowError::NotYetImplemented(format!( + "Union of {f:?} not currently supported" + ))), } - _ => Err(ArrowError::NotYetImplemented(format!( - "Union of {f:?} not currently supported" - ))), } - } - Schema::Complex(c) => match c { - ComplexType::Record(r) => { - let namespace = r.namespace.or(namespace); - let fields = r - .fields - .iter() - .map(|field| { - Ok(AvroField { - name: field.name.to_string(), - data_type: make_data_type( - &field.r#type, - namespace, - resolver, - use_utf8view, - strict_mode, - )?, + Schema::Complex(c) => match c { + ComplexType::Record(r) => { + let namespace = r.namespace.or(namespace); + let fields = r + .fields + .iter() + .map(|field| { + Ok(AvroField { + name: field.name.to_string(), + data_type: self.parse_type(&field.r#type, namespace)?, + }) }) + .collect::>()?; + let field = AvroDataType { + nullability: None, + codec: Codec::Struct(fields), + metadata: r.attributes.field_metadata(), + resolution: None, + }; + self.resolver.register(r.name, namespace, field.clone()); + Ok(field) + } + ComplexType::Array(a) => { + let field = self.parse_type(a.items.as_ref(), namespace)?; + Ok(AvroDataType { + nullability: None, + metadata: a.attributes.field_metadata(), + codec: Codec::List(Arc::new(field)), + resolution: None, }) - .collect::>()?; - let field = AvroDataType { - nullability: None, - codec: Codec::Struct(fields), - metadata: r.attributes.field_metadata(), - }; - resolver.register(r.name, namespace, field.clone()); - Ok(field) - } - ComplexType::Array(a) => { - let mut field = 
make_data_type( - a.items.as_ref(), - namespace, - resolver, - use_utf8view, - strict_mode, - )?; - Ok(AvroDataType { - nullability: None, - metadata: a.attributes.field_metadata(), - codec: Codec::List(Arc::new(field)), - }) - } - ComplexType::Fixed(f) => { - let size = f.size.try_into().map_err(|e| { - ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) - })?; - let md = f.attributes.field_metadata(); - let field = match f.attributes.logical_type { - Some("decimal") => { - let (precision, scale, _) = - parse_decimal_attributes(&f.attributes, Some(size as usize), true)?; - AvroDataType { - nullability: None, - metadata: md, - codec: Codec::Decimal(precision, Some(scale), Some(size as usize)), + } + ComplexType::Fixed(f) => { + let size = f.size.try_into().map_err(|e| { + ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) + })?; + let md = f.attributes.field_metadata(); + let field = match f.attributes.logical_type { + Some("decimal") => { + let (precision, scale, _) = + parse_decimal_attributes(&f.attributes, Some(size as usize), true)?; + AvroDataType { + nullability: None, + metadata: md, + codec: Codec::Decimal(precision, Some(scale), Some(size as usize)), + resolution: None, + } } - } - Some("duration") => { - if size != 12 { - return Err(ArrowError::ParseError(format!( - "Invalid fixed size for Duration: {size}, must be 12" - ))); - }; - AvroDataType { + Some("duration") => { + if size != 12 { + return Err(ArrowError::ParseError(format!( + "Invalid fixed size for Duration: {size}, must be 12" + ))); + }; + AvroDataType { + nullability: None, + metadata: md, + codec: Codec::Interval, + resolution: None, + } + } + _ => AvroDataType { nullability: None, metadata: md, - codec: Codec::Interval, - } - } - _ => AvroDataType { + codec: Codec::Fixed(size), + resolution: None, + }, + }; + self.resolver.register(f.name, namespace, field.clone()); + Ok(field) + } + ComplexType::Enum(e) => { + let namespace = e.namespace.or(namespace); + let symbols = e + .symbols + .iter() + .map(|s| s.to_string()) + .collect::>(); + + let mut metadata = e.attributes.field_metadata(); + let symbols_json = serde_json::to_string(&e.symbols).map_err(|e| { + ArrowError::ParseError(format!("Failed to serialize enum symbols: {e}")) + })?; + metadata.insert(AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), symbols_json); + let field = AvroDataType { nullability: None, - metadata: md, - codec: Codec::Fixed(size), - }, - }; - resolver.register(f.name, namespace, field.clone()); - Ok(field) - } - ComplexType::Enum(e) => { - let namespace = e.namespace.or(namespace); - let symbols = e - .symbols - .iter() - .map(|s| s.to_string()) - .collect::>(); - - let mut metadata = e.attributes.field_metadata(); - let symbols_json = serde_json::to_string(&e.symbols).map_err(|e| { - ArrowError::ParseError(format!("Failed to serialize enum symbols: {e}")) - })?; - metadata.insert("avro.enum.symbols".to_string(), symbols_json); - let field = AvroDataType { - nullability: None, - metadata, - codec: Codec::Enum(symbols), - }; - resolver.register(e.name, namespace, field.clone()); - Ok(field) - } - ComplexType::Map(m) => { - let val = - make_data_type(&m.values, namespace, resolver, use_utf8view, strict_mode)?; - Ok(AvroDataType { - nullability: None, - metadata: m.attributes.field_metadata(), - codec: Codec::Map(Arc::new(val)), - }) - } - }, - Schema::Type(t) => { - let mut field = make_data_type( - &Schema::TypeName(t.r#type.clone()), - namespace, - resolver, - use_utf8view, - strict_mode, - )?; - - // 
https://avro.apache.org/docs/1.11.1/specification/#logical-types - match (t.attributes.logical_type, &mut field.codec) { - (Some("decimal"), c @ Codec::Binary) => { - let (prec, sc, _) = parse_decimal_attributes(&t.attributes, None, false)?; - *c = Codec::Decimal(prec, Some(sc), None); + metadata, + codec: Codec::Enum(symbols), + resolution: None, + }; + self.resolver.register(e.name, namespace, field.clone()); + Ok(field) } - (Some("date"), c @ Codec::Int32) => *c = Codec::Date32, - (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis, - (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros, - (Some("timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(true), - (Some("timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(true), - (Some("local-timestamp-millis"), c @ Codec::Int64) => { - *c = Codec::TimestampMillis(false) + ComplexType::Map(m) => { + let val = self.parse_type(&m.values, namespace)?; + Ok(AvroDataType { + nullability: None, + metadata: m.attributes.field_metadata(), + codec: Codec::Map(Arc::new(val)), + resolution: None, + }) } - (Some("local-timestamp-micros"), c @ Codec::Int64) => { - *c = Codec::TimestampMicros(false) + }, + Schema::Type(t) => { + let mut field = self.parse_type(&Schema::TypeName(t.r#type.clone()), namespace)?; + // https://avro.apache.org/docs/1.11.1/specification/#logical-types + match (t.attributes.logical_type, &mut field.codec) { + (Some("decimal"), c @ Codec::Binary) => { + let (prec, sc, _) = parse_decimal_attributes(&t.attributes, None, false)?; + *c = Codec::Decimal(prec, Some(sc), None); + } + (Some("date"), c @ Codec::Int32) => *c = Codec::Date32, + (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis, + (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros, + (Some("timestamp-millis"), c @ Codec::Int64) => { + *c = Codec::TimestampMillis(true) + } + (Some("timestamp-micros"), c @ Codec::Int64) => { + *c = Codec::TimestampMicros(true) + } + (Some("local-timestamp-millis"), c @ Codec::Int64) => { + *c = Codec::TimestampMillis(false) + } + (Some("local-timestamp-micros"), c @ Codec::Int64) => { + *c = Codec::TimestampMicros(false) + } + (Some("uuid"), c @ Codec::Utf8) => *c = Codec::Uuid, + (Some(logical), _) => { + // Insert unrecognized logical type into metadata map + field.metadata.insert("logicalType".into(), logical.into()); + } + (None, _) => {} } - (Some("uuid"), c @ Codec::Utf8) => *c = Codec::Uuid, - (Some(logical), _) => { - // Insert unrecognized logical type into metadata map - field.metadata.insert("logicalType".into(), logical.into()); + if !t.attributes.additional.is_empty() { + for (k, v) in &t.attributes.additional { + field.metadata.insert(k.to_string(), v.to_string()); + } } - (None, _) => {} + Ok(field) } + } + } - if !t.attributes.additional.is_empty() { - for (k, v) in &t.attributes.additional { - field.metadata.insert(k.to_string(), v.to_string()); - } + fn resolve_type<'s>( + &mut self, + writer_schema: &'s Schema<'a>, + reader_schema: &'s Schema<'a>, + namespace: Option<&'a str>, + ) -> Result { + if let (Some(write_primitive), Some(read_primitive)) = + (primitive_of(writer_schema), primitive_of(reader_schema)) + { + return self.resolve_primitives(write_primitive, read_primitive, reader_schema); + } + match (writer_schema, reader_schema) { + ( + Schema::Complex(ComplexType::Array(writer_array)), + Schema::Complex(ComplexType::Array(reader_array)), + ) => self.resolve_array(writer_array, reader_array, namespace), + ( + 
Schema::Complex(ComplexType::Map(writer_map)), + Schema::Complex(ComplexType::Map(reader_map)), + ) => self.resolve_map(writer_map, reader_map, namespace), + ( + Schema::Complex(ComplexType::Fixed(writer_fixed)), + Schema::Complex(ComplexType::Fixed(reader_fixed)), + ) => self.resolve_fixed(writer_fixed, reader_fixed, reader_schema, namespace), + ( + Schema::Complex(ComplexType::Record(writer_record)), + Schema::Complex(ComplexType::Record(reader_record)), + ) => self.resolve_records(writer_record, reader_record, namespace), + ( + Schema::Complex(ComplexType::Enum(writer_enum)), + Schema::Complex(ComplexType::Enum(reader_enum)), + ) => self.resolve_enums(writer_enum, reader_enum, reader_schema, namespace), + (Schema::Union(writer_variants), Schema::Union(reader_variants)) => self + .resolve_nullable_union( + writer_variants.as_slice(), + reader_variants.as_slice(), + namespace, + ), + (Schema::TypeName(TypeName::Ref(_)), _) => self.parse_type(reader_schema, namespace), + (_, Schema::TypeName(TypeName::Ref(_))) => self.parse_type(reader_schema, namespace), + _ => Err(ArrowError::NotYetImplemented( + "Other resolutions not yet implemented".to_string(), + )), + } + } + + fn resolve_array( + &mut self, + writer_array: &Array<'a>, + reader_array: &Array<'a>, + namespace: Option<&'a str>, + ) -> Result { + Ok(AvroDataType { + nullability: None, + metadata: reader_array.attributes.field_metadata(), + codec: Codec::List(Arc::new(self.make_data_type( + writer_array.items.as_ref(), + Some(reader_array.items.as_ref()), + namespace, + )?)), + resolution: None, + }) + } + + fn resolve_map( + &mut self, + writer_map: &Map<'a>, + reader_map: &Map<'a>, + namespace: Option<&'a str>, + ) -> Result { + Ok(AvroDataType { + nullability: None, + metadata: reader_map.attributes.field_metadata(), + codec: Codec::Map(Arc::new(self.make_data_type( + &writer_map.values, + Some(&reader_map.values), + namespace, + )?)), + resolution: None, + }) + } + + fn resolve_fixed<'s>( + &mut self, + writer_fixed: &Fixed<'a>, + reader_fixed: &Fixed<'a>, + reader_schema: &'s Schema<'a>, + namespace: Option<&'a str>, + ) -> Result { + ensure_names_match( + "Fixed", + writer_fixed.name, + &writer_fixed.aliases, + reader_fixed.name, + &reader_fixed.aliases, + )?; + if writer_fixed.size != reader_fixed.size { + return Err(ArrowError::SchemaError(format!( + "Fixed size mismatch for {}: writer={}, reader={}", + reader_fixed.name, writer_fixed.size, reader_fixed.size + ))); + } + self.parse_type(reader_schema, namespace) + } + + fn resolve_primitives( + &mut self, + write_primitive: PrimitiveType, + read_primitive: PrimitiveType, + reader_schema: &Schema<'a>, + ) -> Result { + if write_primitive == read_primitive { + return self.parse_type(reader_schema, None); + } + let promotion = match (write_primitive, read_primitive) { + (PrimitiveType::Int, PrimitiveType::Long) => Promotion::IntToLong, + (PrimitiveType::Int, PrimitiveType::Float) => Promotion::IntToFloat, + (PrimitiveType::Int, PrimitiveType::Double) => Promotion::IntToDouble, + (PrimitiveType::Long, PrimitiveType::Float) => Promotion::LongToFloat, + (PrimitiveType::Long, PrimitiveType::Double) => Promotion::LongToDouble, + (PrimitiveType::Float, PrimitiveType::Double) => Promotion::FloatToDouble, + (PrimitiveType::String, PrimitiveType::Bytes) => Promotion::StringToBytes, + (PrimitiveType::Bytes, PrimitiveType::String) => Promotion::BytesToString, + _ => { + return Err(ArrowError::ParseError(format!( + "Illegal promotion {write_primitive:?} to {read_primitive:?}" + ))) } - 
Ok(field) + }; + let mut datatype = self.parse_type(reader_schema, None)?; + datatype.resolution = Some(ResolutionInfo::Promotion(promotion)); + Ok(datatype) + } + + fn resolve_nullable_union<'s>( + &mut self, + writer_variants: &'s [Schema<'a>], + reader_variants: &'s [Schema<'a>], + namespace: Option<&'a str>, + ) -> Result { + match ( + nullable_union_variants(writer_variants), + nullable_union_variants(reader_variants), + ) { + (Some((_, write_nonnull)), Some((read_nb, read_nonnull))) => { + let mut dt = self.make_data_type(write_nonnull, Some(read_nonnull), namespace)?; + // Adopt reader union null ordering + dt.nullability = Some(read_nb); + Ok(dt) + } + _ => Err(ArrowError::NotYetImplemented( + "Union resolution requires both writer and reader to be 2-branch nullable unions" + .to_string(), + )), } } + + // Resolve writer vs. reader enum schemas according to Avro 1.11.1. + // + // # How enums resolve (writer to reader) + // Per “Schema Resolution”: + // * The two schemas must refer to the same (unqualified) enum name (or match + // via alias rewriting). + // * If the writer’s symbol is not present in the reader’s enum and the reader + // enum has a `default`, that `default` symbol must be used; otherwise, + // error. + // https://avro.apache.org/docs/1.11.1/specification/#schema-resolution + // * Avro “Aliases” are applied from the reader side to rewrite the writer’s + // names during resolution. For robustness across ecosystems, we also accept + // symmetry here (see note below). + // https://avro.apache.org/docs/1.11.1/specification/#aliases + // + // # Rationale for this code path + // 1. Do the work once at schema‑resolution time. Avro serializes an enum as a + // writer‑side position. Mapping positions on the hot decoder path is expensive + // if done with string lookups. This method builds a `writer_index to reader_index` + // vector once, so decoding just does an O(1) table lookup. + // 2. Adopt the reader’s symbol set and order. We return an Arrow + // `Dictionary(Int32, Utf8)` whose dictionary values are the reader enum + // symbols. This makes downstream semantics match the reader schema, including + // Avro’s sort order rule that orders enums by symbol position in the schema. + // https://avro.apache.org/docs/1.11.1/specification/#sort-order + // 3. Honor Avro’s `default` for enums. Avro 1.9+ allows a type‑level default + // on the enum. When the writer emits a symbol unknown to the reader, we map it + // to the reader’s validated `default` symbol if present; otherwise we signal an + // error at decoding time. + // https://avro.apache.org/docs/1.11.1/specification/#enums + // + // # Implementation notes + // * We first check that enum names match or are*alias‑equivalent. The Avro + // spec describes alias rewriting using reader aliases; this implementation + // additionally treats writer aliases as acceptable for name matching to be + // resilient with schemas produced by different tooling. + // * We build `EnumMapping`: + // - `mapping[i]` = reader index of the writer symbol at writer index `i`. + // - If the writer symbol is absent and the reader has a default, we store the + // reader index of that default. + // - Otherwise we store `-1` as a sentinel meaning unresolvable; the decoder + // must treat encountering such a value as an error, per the spec. + // * We persist the reader symbol list in field metadata under + // `AVRO_ENUM_SYMBOLS_METADATA_KEY`, so consumers can inspect the dictionary + // without needing the original Avro schema. 
+ // * The Arrow representation is `Dictionary(Int32, Utf8)`, which aligns with + // Avro’s integer index encoding for enums. + // + // # Examples + // * Writer `["A","B","C"]`, Reader `["A","B"]`, Reader default `"A"` + // `mapping = [0, 1, 0]`, `default_index = 0`. + // * Writer `["A","B"]`, Reader `["B","A"]` (no default) + // `mapping = [1, 0]`, `default_index = -1`. + // * Writer `["A","B","C"]`, Reader `["A","B"]` (no default) + // `mapping = [0, 1, -1]` (decode must error on `"C"`). + fn resolve_enums( + &mut self, + writer_enum: &Enum<'a>, + reader_enum: &Enum<'a>, + reader_schema: &Schema<'a>, + namespace: Option<&'a str>, + ) -> Result { + ensure_names_match( + "Enum", + writer_enum.name, + &writer_enum.aliases, + reader_enum.name, + &reader_enum.aliases, + )?; + if writer_enum.symbols == reader_enum.symbols { + return self.parse_type(reader_schema, namespace); + } + let reader_index: HashMap<&str, i32> = reader_enum + .symbols + .iter() + .enumerate() + .map(|(index, &symbol)| (symbol, index as i32)) + .collect(); + let default_index: i32 = match reader_enum.default { + Some(symbol) => *reader_index.get(symbol).ok_or_else(|| { + ArrowError::SchemaError(format!( + "Reader enum '{}' default symbol '{symbol}' not found in symbols list", + reader_enum.name, + )) + })?, + None => -1, + }; + let mapping: Vec = writer_enum + .symbols + .iter() + .map(|&write_symbol| { + reader_index + .get(write_symbol) + .copied() + .unwrap_or(default_index) + }) + .collect(); + if self.strict_mode && mapping.iter().any(|&m| m < 0) { + return Err(ArrowError::SchemaError(format!( + "Reader enum '{}' does not cover all writer symbols and no default is provided", + reader_enum.name + ))); + } + let mut dt = self.parse_type(reader_schema, namespace)?; + dt.resolution = Some(ResolutionInfo::EnumMapping(EnumMapping { + mapping: Arc::from(mapping), + default_index, + })); + let reader_ns = reader_enum.namespace.or(namespace); + self.resolver + .register(reader_enum.name, reader_ns, dt.clone()); + Ok(dt) + } + + fn resolve_records( + &mut self, + writer_record: &Record<'a>, + reader_record: &Record<'a>, + namespace: Option<&'a str>, + ) -> Result { + ensure_names_match( + "Record", + writer_record.name, + &writer_record.aliases, + reader_record.name, + &reader_record.aliases, + )?; + let writer_ns = writer_record.namespace.or(namespace); + let reader_ns = reader_record.namespace.or(namespace); + let reader_md = reader_record.attributes.field_metadata(); + let writer_index_map: HashMap<&str, usize> = writer_record + .fields + .iter() + .enumerate() + .map(|(idx, wf)| (wf.name, idx)) + .collect(); + let mut writer_to_reader: Vec> = vec![None; writer_record.fields.len()]; + let reader_fields: Vec = reader_record + .fields + .iter() + .enumerate() + .map(|(reader_idx, r_field)| -> Result { + if let Some(&writer_idx) = writer_index_map.get(r_field.name) { + let w_schema = &writer_record.fields[writer_idx].r#type; + let dt = self.make_data_type(w_schema, Some(&r_field.r#type), reader_ns)?; + writer_to_reader[writer_idx] = Some(reader_idx); + Ok(AvroField { + name: r_field.name.to_string(), + data_type: dt, + }) + } else { + let mut dt = self.parse_type(&r_field.r#type, reader_ns)?; + match r_field.default.as_ref() { + Some(default_json) => { + dt.resolution = Some(ResolutionInfo::DefaultValue( + dt.parse_and_store_default(default_json)?, + )); + } + None => { + if dt.nullability() == Some(Nullability::NullFirst) { + dt.resolution = Some(ResolutionInfo::DefaultValue( + dt.parse_and_store_default(&Value::Null)?, + 
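Aside: the enum-mapping rules documented above reduce to a small table built once per schema pair. The following standalone sketch is illustrative only (the function name is not part of the crate); it mirrors the documented semantics, including the reader `default` fallback and the `-1` "unresolvable" sentinel:

```rust
use std::collections::HashMap;

// Sketch of the writer->reader enum index mapping described in the comment above.
fn build_enum_mapping(
    writer_symbols: &[&str],
    reader_symbols: &[&str],
    reader_default: Option<&str>,
) -> Vec<i32> {
    // Index reader symbols by name so each writer symbol is resolved in O(1).
    let reader_index: HashMap<&str, i32> = reader_symbols
        .iter()
        .enumerate()
        .map(|(i, &s)| (s, i as i32))
        .collect();
    // -1 is the sentinel used when the reader has no default to fall back on.
    let default_index = reader_default
        .and_then(|d| reader_index.get(d).copied())
        .unwrap_or(-1);
    writer_symbols
        .iter()
        .map(|&s| reader_index.get(s).copied().unwrap_or(default_index))
        .collect()
}

fn main() {
    // Writer ["A","B","C"], reader ["A","B"], reader default "A" -> [0, 1, 0]
    assert_eq!(
        build_enum_mapping(&["A", "B", "C"], &["A", "B"], Some("A")),
        vec![0, 1, 0]
    );
    // No reader default: the unknown writer symbol maps to the -1 sentinel,
    // which the decoder must treat as an error when encountered.
    assert_eq!(
        build_enum_mapping(&["A", "B", "C"], &["A", "B"], None),
        vec![0, 1, -1]
    );
}
```

Doing this work once at resolution time means the hot decode path is a single table lookup per enum value rather than a string comparison.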
)); + } else { + return Err(ArrowError::SchemaError(format!( + "Reader field '{}' not present in writer schema must have a default value", + r_field.name + ))); + } + } + } + Ok(AvroField { + name: r_field.name.to_string(), + data_type: dt, + }) + } + }) + .collect::>()?; + let default_fields: Vec = reader_fields + .iter() + .enumerate() + .filter_map(|(index, field)| { + matches!( + field.data_type().resolution, + Some(ResolutionInfo::DefaultValue(_)) + ) + .then_some(index) + }) + .collect(); + let skip_fields: Vec> = writer_record + .fields + .iter() + .enumerate() + .map(|(writer_index, writer_field)| { + if writer_to_reader[writer_index].is_some() { + Ok(None) + } else { + self.parse_type(&writer_field.r#type, writer_ns).map(Some) + } + }) + .collect::>()?; + let resolved = AvroDataType::new_with_resolution( + Codec::Struct(Arc::from(reader_fields)), + reader_md, + None, + Some(ResolutionInfo::Record(ResolvedRecord { + writer_to_reader: Arc::from(writer_to_reader), + default_fields: Arc::from(default_fields), + skip_fields: Arc::from(skip_fields), + })), + ); + // Register a resolved record by reader name+namespace for potential named type refs + self.resolver + .register(reader_record.name, reader_ns, resolved.clone()); + Ok(resolved) + } } #[cfg(test)] mod tests { use super::*; - use crate::schema::{Attributes, PrimitiveType, Schema, Type, TypeName}; + use crate::schema::{Attributes, Fixed, PrimitiveType, Schema, Type, TypeName}; use serde_json; fn create_schema_with_logical_type( @@ -707,12 +1533,36 @@ mod tests { }) } + fn create_fixed_schema(size: usize, logical_type: &'static str) -> Schema<'static> { + let attributes = Attributes { + logical_type: Some(logical_type), + additional: Default::default(), + }; + + Schema::Complex(ComplexType::Fixed(Fixed { + name: "fixed_type", + namespace: None, + aliases: Vec::new(), + size, + attributes, + })) + } + + fn resolve_promotion(writer: PrimitiveType, reader: PrimitiveType) -> AvroDataType { + let writer_schema = Schema::TypeName(TypeName::Primitive(writer)); + let reader_schema = Schema::TypeName(TypeName::Primitive(reader)); + let mut maker = Maker::new(false, false); + maker + .make_data_type(&writer_schema, Some(&reader_schema), None) + .expect("promotion should resolve") + } + #[test] fn test_date_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Int, "date"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::Date32)); } @@ -721,8 +1571,8 @@ mod tests { fn test_time_millis_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Int, "time-millis"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimeMillis)); } @@ -731,8 +1581,8 @@ mod tests { fn test_time_micros_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Long, "time-micros"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, 
Codec::TimeMicros)); } @@ -741,8 +1591,8 @@ mod tests { fn test_timestamp_millis_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Long, "timestamp-millis"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimestampMillis(true))); } @@ -751,8 +1601,8 @@ mod tests { fn test_timestamp_micros_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Long, "timestamp-micros"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimestampMicros(true))); } @@ -761,8 +1611,8 @@ mod tests { fn test_local_timestamp_millis_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Long, "local-timestamp-millis"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimestampMillis(false))); } @@ -771,8 +1621,8 @@ mod tests { fn test_local_timestamp_micros_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Long, "local-timestamp-micros"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimestampMicros(false))); } @@ -780,11 +1630,9 @@ mod tests { #[test] fn test_uuid_type() { let mut codec = Codec::Fixed(16); - if let c @ Codec::Fixed(16) = &mut codec { *c = Codec::Uuid; } - assert!(matches!(codec, Codec::Uuid)); } @@ -821,13 +1669,12 @@ mod tests { panic!("Expected NotYetImplemented error"); } } - #[test] fn test_unknown_logical_type_added_to_metadata() { let schema = create_schema_with_logical_type(PrimitiveType::Int, "custom-type"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert_eq!( result.metadata.get("logicalType"), @@ -839,8 +1686,8 @@ mod tests { fn test_string_with_utf8view_enabled() { let schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String)); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, true, false).unwrap(); + let mut maker = Maker::new(true, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::Utf8View)); } @@ -849,8 +1696,8 @@ mod tests { fn test_string_without_utf8view_enabled() { let schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String)); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::Utf8)); } @@ -877,8 +1724,8 @@ mod tests { let schema = 
Schema::Complex(ComplexType::Record(record)); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, true, false).unwrap(); + let mut maker = Maker::new(true, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); if let Codec::Struct(fields) = &result.codec { let first_field_codec = &fields[0].data_type().codec; @@ -895,8 +1742,8 @@ mod tests { Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), ]); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, true); + let mut maker = Maker::new(false, true); + let result = maker.make_data_type(&schema, None, None); assert!(result.is_err()); match result { @@ -909,6 +1756,126 @@ mod tests { } } + #[test] + fn test_resolve_int_to_float_promotion() { + let result = resolve_promotion(PrimitiveType::Int, PrimitiveType::Float); + assert!(matches!(result.codec, Codec::Float32)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToFloat)) + ); + } + + #[test] + fn test_resolve_int_to_double_promotion() { + let result = resolve_promotion(PrimitiveType::Int, PrimitiveType::Double); + assert!(matches!(result.codec, Codec::Float64)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToDouble)) + ); + } + + #[test] + fn test_resolve_long_to_float_promotion() { + let result = resolve_promotion(PrimitiveType::Long, PrimitiveType::Float); + assert!(matches!(result.codec, Codec::Float32)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::LongToFloat)) + ); + } + + #[test] + fn test_resolve_long_to_double_promotion() { + let result = resolve_promotion(PrimitiveType::Long, PrimitiveType::Double); + assert!(matches!(result.codec, Codec::Float64)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::LongToDouble)) + ); + } + + #[test] + fn test_resolve_float_to_double_promotion() { + let result = resolve_promotion(PrimitiveType::Float, PrimitiveType::Double); + assert!(matches!(result.codec, Codec::Float64)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::FloatToDouble)) + ); + } + + #[test] + fn test_resolve_string_to_bytes_promotion() { + let result = resolve_promotion(PrimitiveType::String, PrimitiveType::Bytes); + assert!(matches!(result.codec, Codec::Binary)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::StringToBytes)) + ); + } + + #[test] + fn test_resolve_bytes_to_string_promotion() { + let result = resolve_promotion(PrimitiveType::Bytes, PrimitiveType::String); + assert!(matches!(result.codec, Codec::Utf8)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::BytesToString)) + ); + } + + #[test] + fn test_resolve_illegal_promotion_double_to_float_errors() { + let writer_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::Double)); + let reader_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::Float)); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&writer_schema, Some(&reader_schema), None); + assert!(result.is_err()); + match result { + Err(ArrowError::ParseError(msg)) => { + assert!(msg.contains("Illegal promotion")); + } + _ => panic!("Expected ParseError for illegal promotion Double -> Float"), + } + } + + #[test] + fn test_promotion_within_nullable_union_keeps_reader_null_ordering() { + let writer = Schema::Union(vec![ + 
Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + ]); + let reader = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Double)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ]); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + assert!(matches!(result.codec, Codec::Float64)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToDouble)) + ); + assert_eq!(result.nullability, Some(Nullability::NullSecond)); + } + + #[test] + fn test_resolve_type_promotion() { + let writer_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)); + let reader_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)); + let mut maker = Maker::new(false, false); + let result = maker + .make_data_type(&writer_schema, Some(&reader_schema), None) + .unwrap(); + assert!(matches!(result.codec, Codec::Int64)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToLong)) + ); + } + #[test] fn test_nested_record_type_reuse_without_namespace() { let schema_str = r#" @@ -935,8 +1902,8 @@ mod tests { let schema: Schema = serde_json::from_str(schema_str).unwrap(); - let mut resolver = Resolver::default(); - let avro_data_type = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let avro_data_type = maker.make_data_type(&schema, None, None).unwrap(); if let Codec::Struct(fields) = avro_data_type.codec() { assert_eq!(fields.len(), 4); @@ -1015,8 +1982,8 @@ mod tests { let schema: Schema = serde_json::from_str(schema_str).unwrap(); - let mut resolver = Resolver::default(); - let avro_data_type = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let avro_data_type = maker.make_data_type(&schema, None, None).unwrap(); if let Codec::Struct(fields) = avro_data_type.codec() { assert_eq!(fields.len(), 4); @@ -1066,4 +2033,451 @@ mod tests { panic!("Top-level schema is not a struct"); } } + + #[test] + fn test_resolve_from_writer_and_reader_defaults_root_name_for_non_record_reader() { + let writer_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String)); + let reader_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String)); + let field = + AvroField::resolve_from_writer_and_reader(&writer_schema, &reader_schema, false, false) + .expect("resolution should succeed"); + assert_eq!(field.name(), AVRO_ROOT_RECORD_DEFAULT_NAME); + assert!(matches!(field.data_type().codec(), Codec::Utf8)); + } + + fn json_string(s: &str) -> Value { + Value::String(s.to_string()) + } + + fn assert_default_stored(dt: &AvroDataType, default_json: &Value) { + let stored = dt + .metadata + .get(AVRO_FIELD_DEFAULT_METADATA_KEY) + .cloned() + .unwrap_or_default(); + let expected = serde_json::to_string(default_json).unwrap(); + assert_eq!(stored, expected, "stored default metadata should match"); + } + + #[test] + fn test_validate_and_store_default_null_and_nullability_rules() { + let mut dt_null = AvroDataType::new(Codec::Null, HashMap::new(), None); + let lit = dt_null.parse_and_store_default(&Value::Null).unwrap(); + assert_eq!(lit, AvroLiteral::Null); + assert_default_stored(&dt_null, &Value::Null); + let mut dt_int = AvroDataType::new(Codec::Int32, HashMap::new(), None); + let err = dt_int.parse_and_store_default(&Value::Null).unwrap_err(); + 
assert!( + err.to_string() + .contains("JSON null default is only valid for `null` type"), + "unexpected error: {err}" + ); + let mut dt_int_nf = + AvroDataType::new(Codec::Int32, HashMap::new(), Some(Nullability::NullFirst)); + let lit2 = dt_int_nf.parse_and_store_default(&Value::Null).unwrap(); + assert_eq!(lit2, AvroLiteral::Null); + assert_default_stored(&dt_int_nf, &Value::Null); + let mut dt_int_ns = + AvroDataType::new(Codec::Int32, HashMap::new(), Some(Nullability::NullSecond)); + let err2 = dt_int_ns.parse_and_store_default(&Value::Null).unwrap_err(); + assert!( + err2.to_string() + .contains("JSON null default is only valid for `null` type"), + "unexpected error: {err2}" + ); + } + + #[test] + fn test_validate_and_store_default_primitives_and_temporal() { + let mut dt_bool = AvroDataType::new(Codec::Boolean, HashMap::new(), None); + let lit = dt_bool.parse_and_store_default(&Value::Bool(true)).unwrap(); + assert_eq!(lit, AvroLiteral::Boolean(true)); + assert_default_stored(&dt_bool, &Value::Bool(true)); + let mut dt_i32 = AvroDataType::new(Codec::Int32, HashMap::new(), None); + let lit = dt_i32 + .parse_and_store_default(&serde_json::json!(123)) + .unwrap(); + assert_eq!(lit, AvroLiteral::Int(123)); + assert_default_stored(&dt_i32, &serde_json::json!(123)); + let err = dt_i32 + .parse_and_store_default(&serde_json::json!(i64::from(i32::MAX) + 1)) + .unwrap_err(); + assert!(format!("{err}").contains("out of i32 range")); + let mut dt_i64 = AvroDataType::new(Codec::Int64, HashMap::new(), None); + let lit = dt_i64 + .parse_and_store_default(&serde_json::json!(1234567890)) + .unwrap(); + assert_eq!(lit, AvroLiteral::Long(1234567890)); + assert_default_stored(&dt_i64, &serde_json::json!(1234567890)); + let mut dt_f32 = AvroDataType::new(Codec::Float32, HashMap::new(), None); + let lit = dt_f32 + .parse_and_store_default(&serde_json::json!(1.25)) + .unwrap(); + assert_eq!(lit, AvroLiteral::Float(1.25)); + assert_default_stored(&dt_f32, &serde_json::json!(1.25)); + let err = dt_f32 + .parse_and_store_default(&serde_json::json!(1e39)) + .unwrap_err(); + assert!(format!("{err}").contains("out of f32 range")); + let mut dt_f64 = AvroDataType::new(Codec::Float64, HashMap::new(), None); + let lit = dt_f64 + .parse_and_store_default(&serde_json::json!(std::f64::consts::PI)) + .unwrap(); + assert_eq!(lit, AvroLiteral::Double(std::f64::consts::PI)); + assert_default_stored(&dt_f64, &serde_json::json!(std::f64::consts::PI)); + let mut dt_str = AvroDataType::new(Codec::Utf8, HashMap::new(), None); + let l = dt_str + .parse_and_store_default(&json_string("hello")) + .unwrap(); + assert_eq!(l, AvroLiteral::String("hello".into())); + assert_default_stored(&dt_str, &json_string("hello")); + let mut dt_strv = AvroDataType::new(Codec::Utf8View, HashMap::new(), None); + let l = dt_strv + .parse_and_store_default(&json_string("view")) + .unwrap(); + assert_eq!(l, AvroLiteral::String("view".into())); + assert_default_stored(&dt_strv, &json_string("view")); + let mut dt_uuid = AvroDataType::new(Codec::Uuid, HashMap::new(), None); + let l = dt_uuid + .parse_and_store_default(&json_string("00000000-0000-0000-0000-000000000000")) + .unwrap(); + assert_eq!( + l, + AvroLiteral::String("00000000-0000-0000-0000-000000000000".into()) + ); + let mut dt_bin = AvroDataType::new(Codec::Binary, HashMap::new(), None); + let l = dt_bin.parse_and_store_default(&json_string("ABC")).unwrap(); + assert_eq!(l, AvroLiteral::Bytes(vec![65, 66, 67])); + let err = dt_bin + .parse_and_store_default(&json_string("€")) // U+20AC + 
.unwrap_err(); + assert!(format!("{err}").contains("Invalid codepoint")); + let mut dt_date = AvroDataType::new(Codec::Date32, HashMap::new(), None); + let ld = dt_date + .parse_and_store_default(&serde_json::json!(1)) + .unwrap(); + assert_eq!(ld, AvroLiteral::Int(1)); + let mut dt_tmill = AvroDataType::new(Codec::TimeMillis, HashMap::new(), None); + let lt = dt_tmill + .parse_and_store_default(&serde_json::json!(86_400_000)) + .unwrap(); + assert_eq!(lt, AvroLiteral::Int(86_400_000)); + let mut dt_tmicros = AvroDataType::new(Codec::TimeMicros, HashMap::new(), None); + let ltm = dt_tmicros + .parse_and_store_default(&serde_json::json!(1_000_000)) + .unwrap(); + assert_eq!(ltm, AvroLiteral::Long(1_000_000)); + let mut dt_ts_milli = AvroDataType::new(Codec::TimestampMillis(true), HashMap::new(), None); + let l1 = dt_ts_milli + .parse_and_store_default(&serde_json::json!(123)) + .unwrap(); + assert_eq!(l1, AvroLiteral::Long(123)); + let mut dt_ts_micro = + AvroDataType::new(Codec::TimestampMicros(false), HashMap::new(), None); + let l2 = dt_ts_micro + .parse_and_store_default(&serde_json::json!(456)) + .unwrap(); + assert_eq!(l2, AvroLiteral::Long(456)); + } + + #[test] + fn test_validate_and_store_default_fixed_decimal_interval() { + let mut dt_fixed = AvroDataType::new(Codec::Fixed(4), HashMap::new(), None); + let l = dt_fixed + .parse_and_store_default(&json_string("WXYZ")) + .unwrap(); + assert_eq!(l, AvroLiteral::Bytes(vec![87, 88, 89, 90])); + let err = dt_fixed + .parse_and_store_default(&json_string("TOO LONG")) + .unwrap_err(); + assert!(err.to_string().contains("Default length")); + let mut dt_dec_fixed = + AvroDataType::new(Codec::Decimal(10, Some(2), Some(3)), HashMap::new(), None); + let l = dt_dec_fixed + .parse_and_store_default(&json_string("abc")) + .unwrap(); + assert_eq!(l, AvroLiteral::Bytes(vec![97, 98, 99])); + let err = dt_dec_fixed + .parse_and_store_default(&json_string("toolong")) + .unwrap_err(); + assert!(err.to_string().contains("Default length")); + let mut dt_dec_bytes = + AvroDataType::new(Codec::Decimal(10, Some(2), None), HashMap::new(), None); + let l = dt_dec_bytes + .parse_and_store_default(&json_string("freeform")) + .unwrap(); + assert_eq!( + l, + AvroLiteral::Bytes("freeform".bytes().collect::>()) + ); + let mut dt_interval = AvroDataType::new(Codec::Interval, HashMap::new(), None); + let l = dt_interval + .parse_and_store_default(&json_string("ABCDEFGHIJKL")) + .unwrap(); + assert_eq!( + l, + AvroLiteral::Bytes("ABCDEFGHIJKL".bytes().collect::>()) + ); + let err = dt_interval + .parse_and_store_default(&json_string("short")) + .unwrap_err(); + assert!(err.to_string().contains("Default length")); + } + + #[test] + fn test_validate_and_store_default_enum_list_map_struct() { + let symbols: Arc<[String]> = ["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()] + .into_iter() + .collect(); + let mut dt_enum = AvroDataType::new(Codec::Enum(symbols), HashMap::new(), None); + let l = dt_enum + .parse_and_store_default(&json_string("GREEN")) + .unwrap(); + assert_eq!(l, AvroLiteral::Enum("GREEN".into())); + let err = dt_enum + .parse_and_store_default(&json_string("YELLOW")) + .unwrap_err(); + assert!(err.to_string().contains("Default enum symbol")); + let item = AvroDataType::new(Codec::Int64, HashMap::new(), None); + let mut dt_list = AvroDataType::new(Codec::List(Arc::new(item)), HashMap::new(), None); + let val = serde_json::json!([1, 2, 3]); + let l = dt_list.parse_and_store_default(&val).unwrap(); + assert_eq!( + l, + AvroLiteral::Array(vec![ + 
AvroLiteral::Long(1), + AvroLiteral::Long(2), + AvroLiteral::Long(3) + ]) + ); + let err = dt_list + .parse_and_store_default(&serde_json::json!({"not":"array"})) + .unwrap_err(); + assert!(err.to_string().contains("JSON array")); + let val_dt = AvroDataType::new(Codec::Float64, HashMap::new(), None); + let mut dt_map = AvroDataType::new(Codec::Map(Arc::new(val_dt)), HashMap::new(), None); + let mv = serde_json::json!({"x": 1.5, "y": 2.5}); + let l = dt_map.parse_and_store_default(&mv).unwrap(); + let mut expected = IndexMap::new(); + expected.insert("x".into(), AvroLiteral::Double(1.5)); + expected.insert("y".into(), AvroLiteral::Double(2.5)); + assert_eq!(l, AvroLiteral::Map(expected)); + // Not object -> error + let err = dt_map + .parse_and_store_default(&serde_json::json!(123)) + .unwrap_err(); + assert!(err.to_string().contains("JSON object")); + let mut field_a = AvroField { + name: "a".into(), + data_type: AvroDataType::new(Codec::Int32, HashMap::new(), None), + }; + let field_b = AvroField { + name: "b".into(), + data_type: AvroDataType::new( + Codec::Int64, + HashMap::new(), + Some(Nullability::NullFirst), + ), + }; + let mut c_md = HashMap::new(); + c_md.insert(AVRO_FIELD_DEFAULT_METADATA_KEY.into(), "\"xyz\"".into()); + let field_c = AvroField { + name: "c".into(), + data_type: AvroDataType::new(Codec::Utf8, c_md, None), + }; + field_a.data_type.metadata.insert("doc".into(), "na".into()); + let struct_fields: Arc<[AvroField]> = Arc::from(vec![field_a, field_b, field_c]); + let mut dt_struct = AvroDataType::new(Codec::Struct(struct_fields), HashMap::new(), None); + let default_obj = serde_json::json!({"a": 7}); + let l = dt_struct.parse_and_store_default(&default_obj).unwrap(); + let mut expected = IndexMap::new(); + expected.insert("a".into(), AvroLiteral::Int(7)); + expected.insert("b".into(), AvroLiteral::Null); + expected.insert("c".into(), AvroLiteral::String("xyz".into())); + assert_eq!(l, AvroLiteral::Map(expected)); + assert_default_stored(&dt_struct, &default_obj); + let req_field = AvroField { + name: "req".into(), + data_type: AvroDataType::new(Codec::Boolean, HashMap::new(), None), + }; + let mut dt_bad = AvroDataType::new( + Codec::Struct(Arc::from(vec![req_field])), + HashMap::new(), + None, + ); + let err = dt_bad + .parse_and_store_default(&serde_json::json!({})) + .unwrap_err(); + assert!( + err.to_string().contains("missing required subfield 'req'"), + "unexpected error: {err}" + ); + let err = dt_struct + .parse_and_store_default(&serde_json::json!(10)) + .unwrap_err(); + err.to_string().contains("must be a JSON object"); + } + + #[test] + fn test_resolve_array_promotion_and_reader_metadata() { + let mut w_add: HashMap<&str, Value> = HashMap::new(); + w_add.insert("who", json_string("writer")); + let mut r_add: HashMap<&str, Value> = HashMap::new(); + r_add.insert("who", json_string("reader")); + let writer_schema = Schema::Complex(ComplexType::Array(Array { + items: Box::new(Schema::TypeName(TypeName::Primitive(PrimitiveType::Int))), + attributes: Attributes { + logical_type: None, + additional: w_add, + }, + })); + let reader_schema = Schema::Complex(ComplexType::Array(Array { + items: Box::new(Schema::TypeName(TypeName::Primitive(PrimitiveType::Long))), + attributes: Attributes { + logical_type: None, + additional: r_add, + }, + })); + let mut maker = Maker::new(false, false); + let dt = maker + .make_data_type(&writer_schema, Some(&reader_schema), None) + .unwrap(); + assert_eq!(dt.metadata.get("who"), Some(&"\"reader\"".to_string())); + if let 
Codec::List(inner) = dt.codec() { + assert!(matches!(inner.codec(), Codec::Int64)); + assert_eq!( + inner.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToLong)) + ); + } else { + panic!("expected list codec"); + } + } + + #[test] + fn test_resolve_fixed_success_name_and_size_match_and_alias() { + let writer_schema = Schema::Complex(ComplexType::Fixed(Fixed { + name: "MD5", + namespace: None, + aliases: vec!["Hash16"], + size: 16, + attributes: Attributes::default(), + })); + let reader_schema = Schema::Complex(ComplexType::Fixed(Fixed { + name: "Hash16", + namespace: None, + aliases: vec![], + size: 16, + attributes: Attributes::default(), + })); + let mut maker = Maker::new(false, false); + let dt = maker + .make_data_type(&writer_schema, Some(&reader_schema), None) + .unwrap(); + assert!(matches!(dt.codec(), Codec::Fixed(16))); + } + + #[test] + fn test_resolve_records_mapping_default_fields_and_skip_fields() { + let writer = Schema::Complex(ComplexType::Record(Record { + name: "R", + namespace: None, + doc: None, + aliases: vec![], + fields: vec![ + crate::schema::Field { + name: "a", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + default: None, + }, + crate::schema::Field { + name: "skipme", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + default: None, + }, + crate::schema::Field { + name: "b", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + default: None, + }, + ], + attributes: Attributes::default(), + })); + let reader = Schema::Complex(ComplexType::Record(Record { + name: "R", + namespace: None, + doc: None, + aliases: vec![], + fields: vec![ + crate::schema::Field { + name: "b", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + default: None, + }, + crate::schema::Field { + name: "a", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + default: None, + }, + crate::schema::Field { + name: "name", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + default: Some(json_string("anon")), + }, + crate::schema::Field { + name: "opt", + doc: None, + r#type: Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + ]), + default: None, // should default to null because NullFirst + }, + ], + attributes: Attributes::default(), + })); + let mut maker = Maker::new(false, false); + let dt = maker + .make_data_type(&writer, Some(&reader), None) + .expect("record resolution"); + let fields = match dt.codec() { + Codec::Struct(f) => f, + other => panic!("expected struct, got {other:?}"), + }; + assert_eq!(fields.len(), 4); + assert_eq!(fields[0].name(), "b"); + assert_eq!(fields[1].name(), "a"); + assert_eq!(fields[2].name(), "name"); + assert_eq!(fields[3].name(), "opt"); + assert!(matches!( + fields[1].data_type().resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToLong)) + )); + let rec = match dt.resolution { + Some(ResolutionInfo::Record(ref r)) => r.clone(), + other => panic!("expected record resolution, got {other:?}"), + }; + assert_eq!(rec.writer_to_reader.as_ref(), &[Some(1), None, Some(0)]); + assert_eq!(rec.default_fields.as_ref(), &[2usize, 3usize]); + assert!(rec.skip_fields[0].is_none()); + assert!(rec.skip_fields[2].is_none()); + let skip1 = rec.skip_fields[1].as_ref().expect("skip field present"); + assert!(matches!(skip1.codec(), Codec::Utf8)); + let name_md = 
&fields[2].data_type().metadata; + assert_eq!( + name_md.get(AVRO_FIELD_DEFAULT_METADATA_KEY), + Some(&"\"anon\"".to_string()) + ); + let opt_md = &fields[3].data_type().metadata; + assert_eq!( + opt_md.get(AVRO_FIELD_DEFAULT_METADATA_KEY), + Some(&"null".to_string()) + ); + } } diff --git a/arrow-avro/src/compression.rs b/arrow-avro/src/compression.rs index 1e1960dc841f..64bacc8fd9b8 100644 --- a/arrow-avro/src/compression.rs +++ b/arrow-avro/src/compression.rs @@ -17,7 +17,7 @@ use arrow_schema::ArrowError; use std::io; -use std::io::Read; +use std::io::{Read, Write}; /// The metadata key used for storing the JSON encoded [`CompressionCodec`] pub const CODEC_METADATA_KEY: &str = "avro.codec"; @@ -112,4 +112,77 @@ impl CompressionCodec { )), } } + + pub(crate) fn compress(&self, data: &[u8]) -> Result, ArrowError> { + match self { + #[cfg(feature = "deflate")] + CompressionCodec::Deflate => { + let mut encoder = + flate2::write::DeflateEncoder::new(Vec::new(), flate2::Compression::default()); + encoder.write_all(data)?; + let compressed = encoder.finish()?; + Ok(compressed) + } + #[cfg(not(feature = "deflate"))] + CompressionCodec::Deflate => Err(ArrowError::ParseError( + "Deflate codec requires deflate feature".to_string(), + )), + + #[cfg(feature = "snappy")] + CompressionCodec::Snappy => { + let mut encoder = snap::raw::Encoder::new(); + // Allocate and compress in one step for efficiency + let mut compressed = encoder + .compress_vec(data) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + // Compute CRC32 (ISO‑HDLC poly) of **uncompressed** data + let crc_val = crc::Crc::::new(&crc::CRC_32_ISO_HDLC).checksum(data); + compressed.extend_from_slice(&crc_val.to_be_bytes()); + Ok(compressed) + } + #[cfg(not(feature = "snappy"))] + CompressionCodec::Snappy => Err(ArrowError::ParseError( + "Snappy codec requires snappy feature".to_string(), + )), + + #[cfg(feature = "zstd")] + CompressionCodec::ZStandard => { + let mut encoder = zstd::Encoder::new(Vec::new(), 0) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + encoder.write_all(data)?; + let compressed = encoder + .finish() + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + Ok(compressed) + } + #[cfg(not(feature = "zstd"))] + CompressionCodec::ZStandard => Err(ArrowError::ParseError( + "ZStandard codec requires zstd feature".to_string(), + )), + + #[cfg(feature = "bzip2")] + CompressionCodec::Bzip2 => { + let mut encoder = + bzip2::write::BzEncoder::new(Vec::new(), bzip2::Compression::default()); + encoder.write_all(data)?; + let compressed = encoder.finish()?; + Ok(compressed) + } + #[cfg(not(feature = "bzip2"))] + CompressionCodec::Bzip2 => Err(ArrowError::ParseError( + "Bzip2 codec requires bzip2 feature".to_string(), + )), + #[cfg(feature = "xz")] + CompressionCodec::Xz => { + let mut encoder = xz::write::XzEncoder::new(Vec::new(), 6); + encoder.write_all(data)?; + let compressed = encoder.finish()?; + Ok(compressed) + } + #[cfg(not(feature = "xz"))] + CompressionCodec::Xz => Err(ArrowError::ParseError( + "XZ codec requires xz feature".to_string(), + )), + } + } } diff --git a/arrow-avro/src/lib.rs b/arrow-avro/src/lib.rs index 8087a908d673..9367bc8efcb7 100644 --- a/arrow-avro/src/lib.rs +++ b/arrow-avro/src/lib.rs @@ -33,6 +33,11 @@ /// Implements the primary reader interface and record decoding logic. pub mod reader; +/// Core functionality for writing Arrow arrays as Avro data +/// +/// Implements the primary writer interface and record encoding logic. 
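Aside on the Snappy branch of `compress` above: the 4-byte suffix appended after the compressed payload is the big-endian CRC-32 (ISO-HDLC polynomial) of the *uncompressed* block. A minimal sketch (not the crate's API) computing that suffix with the same `crc` crate call used in the diff:

```rust
use crc::{Crc, CRC_32_ISO_HDLC};

// Layout of an Avro OCF Snappy block as produced above:
//   [snappy-compressed bytes][4-byte big-endian CRC32 of the uncompressed bytes]
fn snappy_crc_suffix(uncompressed: &[u8]) -> [u8; 4] {
    Crc::<u32>::new(&CRC_32_ISO_HDLC)
        .checksum(uncompressed)
        .to_be_bytes()
}

fn main() {
    let block = b"example avro block";
    println!("crc suffix: {:02x?}", snappy_crc_suffix(block));
}
```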
+pub mod writer; + /// Avro schema parsing and representation /// /// Provides types for parsing and representing Avro schema definitions. diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index e9bf7af61e1c..bf72fc92c642 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -17,49 +17,86 @@ //! Avro reader //! -//! This module provides facilities to read Apache Avro-encoded files or streams -//! into Arrow's `RecordBatch` format. In particular, it introduces: +//! Facilities to read Apache Avro–encoded data into Arrow's `RecordBatch` format. //! -//! * `ReaderBuilder`: Configures Avro reading, e.g., batch size -//! * `Reader`: Yields `RecordBatch` values, implementing `Iterator` -//! * `Decoder`: A low-level push-based decoder for Avro records +//! This module exposes three layers of the API surface, from highest to lowest-level: //! -//! # Basic Usage +//! * `ReaderBuilder`: configures how Avro is read (batch size, strict union handling, +//! string representation, reader schema, etc.) and produces either: +//! * a `Reader` for **Avro Object Container Files (OCF)** read from any `BufRead`, or +//! * a low-level `Decoder` for **single‑object encoded** Avro bytes and Confluent +//! **Schema Registry** framed messages. +//! * `Reader`: a convenient, synchronous iterator over `RecordBatch` decoded from an OCF +//! input. Implements [`Iterator>`] and +//! `RecordBatchReader`. +//! * `Decoder`: a push‑based row decoder that consumes raw Avro bytes and yields ready +//! `RecordBatch` values when batches fill. This is suitable for integrating with async +//! byte streams, network protocols, or other custom data sources. //! -//! `Reader` can be used directly with synchronous data sources, such as [`std::fs::File`]. +//! ## Encodings and when to use which type //! -//! ## Reading a Single Batch +//! * **Object Container File (OCF)**: A self‑describing file format with a header containing +//! the writer schema, optional compression codec, and a sync marker, followed by one or +//! more data blocks. Use `Reader` for this format. See the Avro specification for the +//! structure of OCF headers and blocks. +//! * **Single‑Object Encoding**: A stream‑friendly framing that prefixes each record body with +//! the 2‑byte magic `0xC3 0x01` followed by a schema fingerprint. Use `Decoder` with a +//! populated `SchemaStore` to resolve fingerprints to full +//! schemas. +//! * **Confluent Schema Registry wire format**: A 1‑byte magic `0x00`, a 4‑byte big‑endian +//! schema ID, then the Avro‑encoded body. Use `Decoder` with a +//! `SchemaStore` configured for `FingerprintAlgorithm::None` +//! and entries keyed by `Fingerprint::Id`. Confluent docs +//! describe this framing. +//! +//! ## Basic file usage (OCF) +//! +//! Use `ReaderBuilder::build` to construct a `Reader` from any `BufRead`, such as a +//! `BufReader`. The reader yields `RecordBatch` values you can iterate over or collect. +//! +//! ```no_run +//! use std::fs::File; +//! use std::io::BufReader; +//! use arrow_array::RecordBatch; +//! use arrow_avro::reader::ReaderBuilder; +//! +//! // Locate a test file (mirrors Arrow's test data layout) +//! let path = "avro/alltypes_plain.avro"; +//! let path = std::env::var("ARROW_TEST_DATA") +//! .map(|dir| format!("{dir}/{path}")) +//! .unwrap_or_else(|_| format!("../testing/data/{path}")); //! -//! ``` -//! # use std::fs::File; -//! # use std::io::BufReader; -//! # use arrow_avro::reader::ReaderBuilder; -//! # let path = "avro/alltypes_plain.avro"; -//! 
# let path = match std::env::var("ARROW_TEST_DATA") { -//! # Ok(dir) => format!("{dir}/{path}"), -//! # Err(_) => format!("../testing/data/{path}") -//! # }; //! let file = File::open(path).unwrap(); -//! let mut avro = ReaderBuilder::new().build(BufReader::new(file)).unwrap(); -//! let batch = avro.next().unwrap(); +//! let mut reader = ReaderBuilder::new().build(BufReader::new(file)).unwrap(); +//! +//! // Iterate batches +//! let mut num_rows = 0usize; +//! while let Some(batch) = reader.next() { +//! let batch: RecordBatch = batch.unwrap(); +//! num_rows += batch.num_rows(); +//! } +//! println!("decoded {num_rows} rows"); //! ``` //! -//! # Async Usage +//! ## Streaming usage (single‑object / Confluent) //! -//! The lower-level `Decoder` can be integrated with various forms of async data streams, -//! and is designed to be agnostic to different async IO primitives within -//! the Rust ecosystem. It works by incrementally decoding Avro data from byte slices. +//! The `Decoder` lets you integrate Avro decoding with **any** source of bytes by +//! periodically calling `Decoder::decode` with new data and calling `Decoder::flush` +//! to get a `RecordBatch` once at least one row is complete. //! -//! For example, see below for how it could be used with an arbitrary `Stream` of `Bytes`: +//! The example below shows how to decode from an arbitrary stream of `bytes::Bytes` using +//! `futures` utilities. Note: this is illustrative and keeps a single in‑memory `Bytes` +//! buffer for simplicity—real applications typically maintain a rolling buffer. //! -//! ``` -//! # use std::task::{Poll, ready}; -//! # use bytes::{Buf, Bytes}; -//! # use arrow_schema::ArrowError; -//! # use futures::stream::{Stream, StreamExt}; -//! # use arrow_array::RecordBatch; -//! # use arrow_avro::reader::Decoder; +//! ```no_run +//! use bytes::{Buf, Bytes}; +//! use futures::{Stream, StreamExt}; +//! use std::task::{Poll, ready}; +//! use arrow_array::RecordBatch; +//! use arrow_schema::ArrowError; +//! use arrow_avro::reader::Decoder; //! +//! /// Decode a stream of Avro-framed bytes into RecordBatch values. //! fn decode_stream + Unpin>( //! mut decoder: Decoder, //! mut input: S, @@ -70,30 +107,105 @@ //! if buffered.is_empty() { //! buffered = match ready!(input.poll_next_unpin(cx)) { //! Some(b) => b, -//! None => break, +//! None => break, // EOF //! }; //! } +//! // Feed as much as possible //! let decoded = match decoder.decode(buffered.as_ref()) { -//! Ok(decoded) => decoded, +//! Ok(n) => n, //! Err(e) => return Poll::Ready(Some(Err(e))), //! }; //! let read = buffered.len(); //! buffered.advance(decoded); //! if decoded != read { +//! // decoder made partial progress; request more bytes //! break //! } //! } -//! // Convert any fully-decoded rows to a RecordBatch, if available +//! // Return a batch if one or more rows are complete //! Poll::Ready(decoder.flush().transpose()) //! }) //! } //! ``` //! - +//! ### Building a `Decoder` for **single‑object encoding** (Rabin fingerprints) +//! +//! ```no_run +//! use arrow_avro::schema::{AvroSchema, SchemaStore}; +//! use arrow_avro::reader::ReaderBuilder; +//! +//! // Build a SchemaStore and register known writer schemas +//! let mut store = SchemaStore::new(); // Rabin by default +//! let user_schema = AvroSchema::new(r#"{"type":"record","name":"User","fields":[ +//! {"name":"id","type":"long"},{"name":"name","type":"string"}]}"#.to_string()); +//! let _fp = store.register(user_schema).unwrap(); // computes Rabin CRC-64-AVRO +//! +//! 
// Build a Decoder that expects single-object encoding (0xC3 0x01 + fingerprint and body) +//! let decoder = ReaderBuilder::new() +//! .with_writer_schema_store(store) +//! .with_batch_size(1024) +//! .build_decoder() +//! .unwrap(); +//! // Feed decoder with framed bytes (not shown; see `decode_stream` above). +//! ``` +//! +//! ### Building a `Decoder` for **Confluent Schema Registry** framed messages +//! +//! ```no_run +//! use arrow_avro::schema::{AvroSchema, SchemaStore, Fingerprint, FingerprintAlgorithm}; +//! use arrow_avro::reader::ReaderBuilder; +//! +//! // Confluent wire format uses a magic 0x00 byte + 4-byte schema id (big-endian). +//! // Create a store keyed by `Fingerprint::Id` and pre-populate with known schemas. +//! let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::None); +//! +//! // Suppose registry ID 42 corresponds to this Avro schema: +//! let avro = AvroSchema::new(r#"{"type":"string"}"#.to_string()); +//! store.set(Fingerprint::Id(42), avro).unwrap(); +//! +//! // Build a Decoder that understands Confluent framing +//! let decoder = ReaderBuilder::new() +//! .with_writer_schema_store(store) +//! .build_decoder() +//! .unwrap(); +//! // Feed decoder with 0x00 + [id:4] + Avro body frames. +//! ``` +//! +//! ## Schema evolution and batch boundaries +//! +//! `Decoder` supports mid‑stream schema changes when the input framing carries a schema +//! fingerprint (single‑object or Confluent). When a new fingerprint is observed: +//! +//! * If the current `RecordBatch` is **empty**, the decoder switches to the new schema +//! immediately. +//! * If not, the decoder finishes the current batch first and only then switches. +//! +//! Consequently, the schema of batches produced by `Decoder::flush` may change over time, +//! and `Decoder` intentionally does **not** implement `RecordBatchReader`. In contrast, +//! `Reader` (OCF) has a single writer schema for the entire file and therefore implements +//! `RecordBatchReader`. +//! +//! ## Performance & memory +//! +//! * `batch_size` controls the maximum number of rows per `RecordBatch`. Larger batches +//! amortize per‑batch overhead; smaller batches reduce peak memory usage and latency. +//! * When `utf8_view` is enabled, string columns use Arrow’s `StringViewArray`, which can +//! reduce allocations for short strings. +//! * For OCF, blocks may be compressed `Reader` will decompress using the codec specified +//! in the file header and feed uncompressed bytes to the row `Decoder`. +//! +//! ## Error handling +//! +//! * Incomplete inputs return parse errors with "Unexpected EOF"; callers typically provide +//! more bytes and try again. +//! * If a fingerprint is unknown to the provided `SchemaStore`, decoding fails with a +//! descriptive error. Populate the store up front to avoid this. +//! +//! --- use crate::codec::{AvroField, AvroFieldBuilder}; use crate::schema::{ - compare_schemas, generate_fingerprint, AvroSchema, Fingerprint, FingerprintAlgorithm, Schema, - SchemaStore, SINGLE_OBJECT_MAGIC, + compare_schemas, AvroSchema, Fingerprint, FingerprintAlgorithm, Schema, SchemaStore, + CONFLUENT_MAGIC, SINGLE_OBJECT_MAGIC, }; use arrow_array::{Array, RecordBatch, RecordBatchReader}; use arrow_schema::{ArrowError, SchemaRef}; @@ -130,7 +242,86 @@ fn read_header(mut reader: R) -> Result { }) } -/// A low-level interface for decoding Avro-encoded bytes into Arrow `RecordBatch`. 
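The two stream framings described in the module documentation above are easy to construct by hand, which is useful when feeding the `Decoder` in tests. A minimal sketch under the documented byte layouts (helper names are illustrative, not part of the crate; the record body is assumed to already be Avro binary encoded):

```rust
/// Confluent wire format: 0x00 magic, 4-byte big-endian schema id, then the Avro body.
fn confluent_frame(schema_id: u32, body: &[u8]) -> Vec<u8> {
    let mut framed = Vec::with_capacity(1 + 4 + body.len());
    framed.push(0x00);
    framed.extend_from_slice(&schema_id.to_be_bytes());
    framed.extend_from_slice(body);
    framed
}

/// Avro single-object encoding: 0xC3 0x01 magic, 8-byte little-endian Rabin
/// (CRC-64-AVRO) fingerprint of the writer schema, then the Avro body.
fn single_object_frame(rabin_fingerprint: u64, body: &[u8]) -> Vec<u8> {
    let mut framed = vec![0xC3, 0x01];
    framed.extend_from_slice(&rabin_fingerprint.to_le_bytes());
    framed.extend_from_slice(body);
    framed
}

fn main() {
    // Avro encodes the long `1` as the single zig-zag/varint byte 0x02.
    let body = [0x02u8];
    assert_eq!(confluent_frame(7, &body).len(), 1 + 4 + 1);
    assert_eq!(single_object_frame(0xDEAD_BEEF, &body)[..2], [0xC3, 0x01]);
}
```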
+// NOTE: The Current ` is_incomplete_data ` below is temporary and will be improved prior to public release +fn is_incomplete_data(err: &ArrowError) -> bool { + matches!( + err, + ArrowError::ParseError(msg) + if msg.contains("Unexpected EOF") + ) +} + +/// A low‑level, push‑based decoder from Avro bytes to Arrow `RecordBatch`. +/// +/// `Decoder` is designed for **streaming** scenarios: +/// +/// * You *feed* freshly received bytes using `Self::decode`, potentially multiple times, +/// until at least one row is complete. +/// * You then *drain* completed rows with `Self::flush`, which yields a `RecordBatch` +/// if any rows were finished since the last flush. +/// +/// Unlike `Reader`, which is specialized for Avro **Object Container Files**, `Decoder` +/// understands **framed single‑object** inputs and **Confluent Schema Registry** messages, +/// switching schemas mid‑stream when the framing indicates a new fingerprint. +/// +/// ### Supported prefixes +/// +/// On each new row boundary, `Decoder` tries to match one of the following "prefixes": +/// +/// * **Single‑Object encoding**: magic `0xC3 0x01` + schema fingerprint (length depends on +/// the configured `FingerprintAlgorithm`); see `SINGLE_OBJECT_MAGIC`. +/// * **Confluent wire format**: magic `0x00` + 4‑byte big‑endian schema id; see +/// `CONFLUENT_MAGIC`. +/// +/// The active fingerprint determines which cached row decoder is used to decode the following +/// record body bytes. +/// +/// ### Schema switching semantics +/// +/// When a new fingerprint is observed: +/// +/// * If the current batch is empty, the decoder switches immediately; +/// * Otherwise, the current batch is finalized on the next `flush` and only then +/// does the decoder switch to the new schema. This guarantees that a single `RecordBatch` +/// never mixes rows with different schemas. +/// +/// ### Examples +/// +/// Build a `Decoder` for single‑object encoding using a `SchemaStore` with Rabin fingerprints: +/// +/// ```no_run +/// use arrow_avro::schema::{AvroSchema, SchemaStore}; +/// use arrow_avro::reader::ReaderBuilder; +/// +/// let mut store = SchemaStore::new(); // Rabin by default +/// let avro = AvroSchema::new(r#""string""#.to_string()); +/// let _fp = store.register(avro).unwrap(); +/// +/// let mut decoder = ReaderBuilder::new() +/// .with_writer_schema_store(store) +/// .with_batch_size(512) +/// .build_decoder() +/// .unwrap(); +/// +/// // Feed bytes (framed as 0xC3 0x01 + fingerprint and body) +/// // decoder.decode(&bytes)?; +/// // if let Some(batch) = decoder.flush()? 
{ /* process */ } +/// ``` +/// +/// Build a `Decoder` for Confluent Registry messages (magic 0x00 + 4‑byte id): +/// +/// ```no_run +/// use arrow_avro::schema::{AvroSchema, SchemaStore, Fingerprint, FingerprintAlgorithm}; +/// use arrow_avro::reader::ReaderBuilder; +/// +/// let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::None); +/// store.set(Fingerprint::Id(7), AvroSchema::new(r#""long""#.to_string())).unwrap(); +/// +/// let mut decoder = ReaderBuilder::new() +/// .with_writer_schema_store(store) +/// .build_decoder() +/// .unwrap(); +/// ``` #[derive(Debug)] pub struct Decoder { active_decoder: RecordDecoder, @@ -139,85 +330,132 @@ pub struct Decoder { remaining_capacity: usize, cache: IndexMap, fingerprint_algorithm: FingerprintAlgorithm, - expect_prefix: bool, utf8_view: bool, strict_mode: bool, pending_schema: Option<(Fingerprint, RecordDecoder)>, + awaiting_body: bool, } impl Decoder { - /// Return the Arrow schema for the rows decoded by this decoder + /// Returns the Arrow schema for the rows decoded by this decoder. + /// + /// **Note:** With single‑object or Confluent framing, the schema may change + /// at a row boundary when the input indicates a new fingerprint. pub fn schema(&self) -> SchemaRef { self.active_decoder.schema().clone() } - /// Return the configured maximum number of rows per batch + /// Returns the configured maximum number of rows per batch. pub fn batch_size(&self) -> usize { self.batch_size } - /// Feed `data` into the decoder row by row until we either: - /// - consume all bytes in `data`, or - /// - reach `batch_size` decoded rows. + /// Feed a chunk of bytes into the decoder. + /// + /// This will: /// - /// Returns the number of bytes consumed. + /// * Decode at most `Self::batch_size` rows; + /// * Return the number of input bytes **consumed** from `data` (which may be 0 if more + /// bytes are required, or less than `data.len()` if a prefix/body straddles the + /// chunk boundary); + /// * Defer producing a `RecordBatch` until you call `Self::flush`. + /// + /// # Returns + /// The number of bytes consumed from `data`. + /// + /// # Errors + /// Returns an error if: + /// + /// * The input indicates an unknown fingerprint (not present in the provided + /// `SchemaStore`; + /// * The Avro body is malformed; + /// * A strict‑mode union rule is violated (see `ReaderBuilder::with_strict_mode`). pub fn decode(&mut self, data: &[u8]) -> Result { - if self.expect_prefix - && data.len() >= SINGLE_OBJECT_MAGIC.len() - && !data.starts_with(&SINGLE_OBJECT_MAGIC) - { - return Err(ArrowError::ParseError( - "Expected single‑object encoding fingerprint prefix for first message \ - (writer_schema_store is set but active_fingerprint is None)" - .into(), - )); - } let mut total_consumed = 0usize; - // The loop stops when the batch is full, a schema change is staged, - // or handle_prefix indicates we need more bytes (Some(0)). while total_consumed < data.len() && self.remaining_capacity > 0 { - if let Some(n) = self.handle_prefix(&data[total_consumed..])? { - // We either consumed a prefix (n > 0) and need a schema switch, or we need - // more bytes to make a decision. Either way, this decoding attempt is finished. 
- total_consumed += n; + if self.awaiting_body { + match self.active_decoder.decode(&data[total_consumed..], 1) { + Ok(n) => { + self.remaining_capacity -= 1; + total_consumed += n; + self.awaiting_body = false; + continue; + } + Err(ref e) if is_incomplete_data(e) => break, + err => return err, + }; + } + match self.handle_prefix(&data[total_consumed..])? { + Some(0) => break, // Insufficient bytes + Some(n) => { + total_consumed += n; + self.apply_pending_schema_if_batch_empty(); + self.awaiting_body = true; + } + None => { + return Err(ArrowError::ParseError( + "Missing magic bytes and fingerprint".to_string(), + )) + } } - // No prefix: decode one row and keep going. - let n = self.active_decoder.decode(&data[total_consumed..], 1)?; - self.remaining_capacity -= 1; - total_consumed += n; } Ok(total_consumed) } - // Attempt to handle a single‑object‑encoding prefix at the current position. - // + // Attempt to handle a prefix at the current position. // * Ok(None) – buffer does not start with the prefix. // * Ok(Some(0)) – prefix detected, but the buffer is too short; caller should await more bytes. // * Ok(Some(n)) – consumed `n > 0` bytes of a complete prefix (magic and fingerprint). fn handle_prefix(&mut self, buf: &[u8]) -> Result, ArrowError> { - // If there is no schema store, prefixes are unrecognized. - if !self.expect_prefix { - return Ok(None); + match self.fingerprint_algorithm { + FingerprintAlgorithm::Rabin => { + self.handle_prefix_common(buf, &SINGLE_OBJECT_MAGIC, |bytes| { + Fingerprint::Rabin(u64::from_le_bytes(bytes)) + }) + } + FingerprintAlgorithm::None => { + self.handle_prefix_common(buf, &CONFLUENT_MAGIC, |bytes| { + Fingerprint::Id(u32::from_be_bytes(bytes)) + }) + } + #[cfg(feature = "md5")] + FingerprintAlgorithm::MD5 => { + self.handle_prefix_common(buf, &SINGLE_OBJECT_MAGIC, |bytes| { + Fingerprint::MD5(bytes) + }) + } + #[cfg(feature = "sha256")] + FingerprintAlgorithm::SHA256 => { + self.handle_prefix_common(buf, &SINGLE_OBJECT_MAGIC, |bytes| { + Fingerprint::SHA256(bytes) + }) + } + } + } + + /// This method checks for the provided `magic` bytes at the start of `buf` and, if present, + /// attempts to read the following fingerprint of `N` bytes, converting it to a + /// `Fingerprint` using `fingerprint_from`. + fn handle_prefix_common( + &mut self, + buf: &[u8], + magic: &[u8; MAGIC_LEN], + fingerprint_from: impl FnOnce([u8; N]) -> Fingerprint, + ) -> Result, ArrowError> { + // Need at least the magic bytes to decide + // 2 bytes for Avro Spec and 1 byte for Confluent Wire Protocol. + if buf.len() < MAGIC_LEN { + return Ok(Some(0)); } - // Need at least the magic bytes to decide (2 bytes). - let Some(magic_bytes) = buf.get(..SINGLE_OBJECT_MAGIC.len()) else { - return Ok(Some(0)); // Get more bytes - }; // Bail out early if the magic does not match. - if magic_bytes != SINGLE_OBJECT_MAGIC { - return Ok(None); // Continue to decode the next record + if &buf[..MAGIC_LEN] != magic { + return Ok(None); } // Try to parse the fingerprint that follows the magic. - let fingerprint_size = match self.fingerprint_algorithm { - FingerprintAlgorithm::Rabin => self - .handle_fingerprint(&buf[SINGLE_OBJECT_MAGIC.len()..], |bytes| { - Fingerprint::Rabin(u64::from_le_bytes(bytes)) - })?, - }; + let consumed_fp = self.handle_fingerprint(&buf[MAGIC_LEN..], fingerprint_from)?; // Convert the inner result into a “bytes consumed” count. // NOTE: Incomplete fingerprint consumes no bytes. 
- let consumed = fingerprint_size.map_or(0, |n| n + SINGLE_OBJECT_MAGIC.len()); - Ok(Some(consumed)) + Ok(Some(consumed_fp.map_or(0, |n| n + MAGIC_LEN))) } // Attempts to read and install a new fingerprint of `N` bytes. @@ -231,7 +469,7 @@ impl Decoder { ) -> Result, ArrowError> { // Need enough bytes to get fingerprint (next N bytes) let Some(fingerprint_bytes) = buf.get(..N) else { - return Ok(None); // Insufficient bytes + return Ok(None); // insufficient bytes }; // SAFETY: length checked above. let new_fingerprint = fingerprint_from(fingerprint_bytes.try_into().unwrap()); @@ -252,15 +490,7 @@ impl Decoder { Ok(Some(N)) } - /// Produce a `RecordBatch` if at least one row is fully decoded, returning - /// `Ok(None)` if no new rows are available. - pub fn flush(&mut self) -> Result, ArrowError> { - if self.remaining_capacity == self.batch_size { - return Ok(None); - } - let batch = self.active_decoder.flush()?; - self.remaining_capacity = self.batch_size; - // Apply any staged schema switch. + fn apply_pending_schema(&mut self) { if let Some((new_fingerprint, new_decoder)) = self.pending_schema.take() { if let Some(old_fingerprint) = self.active_fingerprint.replace(new_fingerprint) { let old_decoder = std::mem::replace(&mut self.active_decoder, new_decoder); @@ -270,9 +500,36 @@ impl Decoder { self.active_decoder = new_decoder; } } + } + + fn apply_pending_schema_if_batch_empty(&mut self) { + if self.batch_is_empty() { + self.apply_pending_schema(); + } + } + + fn flush_and_reset(&mut self) -> Result, ArrowError> { + if self.batch_is_empty() { + return Ok(None); + } + let batch = self.active_decoder.flush()?; + self.remaining_capacity = self.batch_size; Ok(Some(batch)) } + /// Produce a `RecordBatch` if at least one row is fully decoded, returning + /// `Ok(None)` if no new rows are available. + /// + /// If a schema change was detected while decoding rows for the current batch, the + /// schema switch is applied **after** flushing this batch, so the **next** batch + /// (if any) may have a different schema. + pub fn flush(&mut self) -> Result, ArrowError> { + // We must flush the active decoder before switching to the pending one. + let batch = self.flush_and_reset(); + self.apply_pending_schema(); + batch + } + /// Returns the number of rows that can be added to this decoder before it is full. pub fn capacity(&self) -> usize { self.remaining_capacity @@ -282,10 +539,84 @@ impl Decoder { pub fn batch_is_full(&self) -> bool { self.remaining_capacity == 0 } + + /// Returns true if the decoder has not decoded any batches yet (i.e., the current batch is empty). + pub fn batch_is_empty(&self) -> bool { + self.remaining_capacity == self.batch_size + } + + // Decode either the block count or remaining capacity from `data` (an OCF block payload). + // + // Returns the number of bytes consumed from `data` along with the number of records decoded. + fn decode_block(&mut self, data: &[u8], count: usize) -> Result<(usize, usize), ArrowError> { + // OCF decoding never interleaves records across blocks, so no chunking. + let to_decode = std::cmp::min(count, self.remaining_capacity); + if to_decode == 0 { + return Ok((0, 0)); + } + let consumed = self.active_decoder.decode(data, to_decode)?; + self.remaining_capacity -= to_decode; + Ok((consumed, to_decode)) + } + + // Produce a `RecordBatch` if at least one row is fully decoded, returning + // `Ok(None)` if no new rows are available. 
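Because the staged schema switch is applied only around `flush`, a caller that needs to know which schema a batch was decoded under should capture `Decoder::schema` before flushing. A hedged sketch of that call pattern follows; `decode_one` is an illustrative helper, not part of the API, and it assumes the decoder was built with a `SchemaStore` containing every fingerprint on the stream.

```rust
use arrow_array::RecordBatch;
use arrow_avro::reader::Decoder;
use arrow_schema::{ArrowError, SchemaRef};

/// Decode one fully-framed message and return the resulting batch together with
/// the schema it was decoded under. The schema is captured *before* `flush`,
/// since `flush` applies any schema switch staged by the message's fingerprint,
/// so the *next* batch may already use a different schema.
fn decode_one(
    decoder: &mut Decoder,
    framed: &[u8],
) -> Result<(SchemaRef, Option<RecordBatch>), ArrowError> {
    let mut offset = 0;
    while offset < framed.len() {
        let consumed = decoder.decode(&framed[offset..])?;
        if consumed == 0 {
            break; // partial prefix/body; caller must supply more bytes
        }
        offset += consumed;
    }
    let schema = decoder.schema();
    Ok((schema, decoder.flush()?))
}
```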
+ fn flush_block(&mut self) -> Result, ArrowError> { + self.flush_and_reset() + } } -/// A builder to create an [`Avro Reader`](Reader) that reads Avro data -/// into Arrow `RecordBatch`. +/// A builder that configures and constructs Avro readers and decoders. +/// +/// `ReaderBuilder` is the primary entry point for this module. It supports: +/// +/// * OCF reading via `Self::build`, returning a `Reader` over any `BufRead`; +/// * streaming decoding via `Self::build_decoder`, returning a `Decoder`. +/// +/// ### Options +/// +/// * **`batch_size`**: Max rows per `RecordBatch` (default: `1024`). See `Self::with_batch_size`. +/// * **`utf8_view`**: Use Arrow `StringViewArray` for string columns (default: `false`). +/// See `Self::with_utf8_view`. +/// * **`strict_mode`**: Opt‑in to stricter union handling (default: `false`). +/// See `Self::with_strict_mode`. +/// * **`reader_schema`**: Optional reader schema (projection / evolution) used when decoding +/// values (default: `None`). See `Self::with_reader_schema`. +/// * **`writer_schema_store`**: Required for building a `Decoder` for single‑object or +/// Confluent framing. Maps fingerprints to Avro schemas. See `Self::with_writer_schema_store`. +/// * **`active_fingerprint`**: Optional starting fingerprint for streaming decode when the +/// first frame omits one (rare). See `Self::with_active_fingerprint`. +/// +/// ### Examples +/// +/// Read an OCF file in batches of 4096 rows: +/// +/// ```no_run +/// use std::fs::File; +/// use std::io::BufReader; +/// use arrow_avro::reader::ReaderBuilder; +/// +/// let file = File::open("data.avro")?; +/// let mut reader = ReaderBuilder::new() +/// .with_batch_size(4096) +/// .build(BufReader::new(file))?; +/// # Ok::<(), Box>(()) +/// ``` +/// +/// Build a `Decoder` for Confluent messages: +/// +/// ```no_run +/// use arrow_avro::schema::{AvroSchema, SchemaStore, Fingerprint, FingerprintAlgorithm}; +/// use arrow_avro::reader::ReaderBuilder; +/// +/// let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::None); +/// store.set(Fingerprint::Id(1234), AvroSchema::new(r#"{"type":"record","name":"E","fields":[]}"#.to_string()))?; +/// +/// let decoder = ReaderBuilder::new() +/// .with_writer_schema_store(store) +/// .build_decoder()?; +/// # Ok::<(), Box>(()) +/// ``` #[derive(Debug)] pub struct ReaderBuilder { batch_size: usize, @@ -310,13 +641,14 @@ impl Default for ReaderBuilder { } impl ReaderBuilder { - /// Creates a new [`ReaderBuilder`] with default settings: - /// - `batch_size` = 1024 - /// - `strict_mode` = false - /// - `utf8_view` = false - /// - `reader_schema` = None - /// - `writer_schema_store` = None - /// - `active_fingerprint` = None + /// Creates a new `ReaderBuilder` with defaults: + /// + /// * `batch_size = 1024` + /// * `strict_mode = false` + /// * `utf8_view = false` + /// * `reader_schema = None` + /// * `writer_schema_store = None` + /// * `active_fingerprint = None` pub fn new() -> Self { Self::default() } @@ -324,11 +656,11 @@ impl ReaderBuilder { fn make_record_decoder( &self, writer_schema: &Schema, - reader_schema: Option<&AvroSchema>, + reader_schema: Option<&Schema>, ) -> Result { let mut builder = AvroFieldBuilder::new(writer_schema); if let Some(reader_schema) = reader_schema { - builder = builder.with_reader_schema(reader_schema.clone()); + builder = builder.with_reader_schema(reader_schema); } let root = builder .with_utf8view(self.utf8_view) @@ -337,12 +669,20 @@ impl ReaderBuilder { RecordDecoder::try_new_with_options(root.data_type(), self.utf8_view) } 
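The builder docs above show OCF and Confluent usage; the single-object (Rabin) case differs only in how the store is populated, since `SchemaStore::new()` defaults to Rabin fingerprints and `register` computes them. A small sketch, with the record schema chosen only for illustration:

```rust
use arrow_avro::reader::ReaderBuilder;
use arrow_avro::schema::{AvroSchema, SchemaStore};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // `SchemaStore::new()` uses Rabin (CRC-64-AVRO) fingerprints, which is what
    // the single-object prefix carries on the wire.
    let mut store = SchemaStore::new();
    let schema = AvroSchema::new(
        r#"{"type":"record","name":"User","fields":[{"name":"name","type":"string"}]}"#
            .to_string(),
    );
    // `register` computes and returns the fingerprint; the store must contain
    // every writer schema that may appear on the stream.
    let _fingerprint = store.register(schema)?;

    let _decoder = ReaderBuilder::new()
        .with_batch_size(1024)
        .with_writer_schema_store(store)
        .build_decoder()?;
    Ok(())
}
```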
+ fn make_record_decoder_from_schemas( + &self, + writer_schema: &Schema, + reader_schema: Option<&AvroSchema>, + ) -> Result { + let reader_schema_raw = reader_schema.map(|s| s.schema()).transpose()?; + self.make_record_decoder(writer_schema, reader_schema_raw.as_ref()) + } + fn make_decoder_with_parts( &self, active_decoder: RecordDecoder, active_fingerprint: Option, cache: IndexMap, - expect_prefix: bool, fingerprint_algorithm: FingerprintAlgorithm, ) -> Decoder { Decoder { @@ -351,11 +691,11 @@ impl ReaderBuilder { active_fingerprint, active_decoder, cache, - expect_prefix, utf8_view: self.utf8_view, fingerprint_algorithm, strict_mode: self.strict_mode, pending_schema: None, + awaiting_body: false, } } @@ -371,12 +711,12 @@ impl ReaderBuilder { .ok_or_else(|| { ArrowError::ParseError("No Avro schema present in file header".into()) })?; - let record_decoder = self.make_record_decoder(&writer_schema, reader_schema)?; + let record_decoder = + self.make_record_decoder_from_schemas(&writer_schema, reader_schema)?; return Ok(self.make_decoder_with_parts( record_decoder, None, IndexMap::new(), - false, FingerprintAlgorithm::Rabin, )); } @@ -407,11 +747,12 @@ impl ReaderBuilder { } }; let writer_schema = avro_schema.schema()?; - let decoder = self.make_record_decoder(&writer_schema, reader_schema)?; + let record_decoder = + self.make_record_decoder_from_schemas(&writer_schema, reader_schema)?; if fingerprint == start_fingerprint { - active_decoder = Some(decoder); + active_decoder = Some(record_decoder); } else { - cache.insert(fingerprint, decoder); + cache.insert(fingerprint, record_decoder); } } let active_decoder = active_decoder.ok_or_else(|| { @@ -423,50 +764,60 @@ impl ReaderBuilder { active_decoder, Some(start_fingerprint), cache, - true, store.fingerprint_algorithm(), )) } - /// Sets the row-based batch size + /// Sets the **row‑based batch size**. + /// + /// Each call to `Decoder::flush` or each iteration of `Reader` yields a batch with + /// *up to* this many rows. Larger batches can reduce overhead; smaller batches can + /// reduce peak memory usage and latency. pub fn with_batch_size(mut self, batch_size: usize) -> Self { self.batch_size = batch_size; self } - /// Set whether to use StringViewArray for string data + /// Choose Arrow's `StringViewArray` for UTF‑8 string data. /// - /// When enabled, string data from Avro files will be loaded into - /// Arrow's StringViewArray instead of the standard StringArray. + /// When enabled, textual Avro fields are loaded into Arrow’s **StringViewArray** + /// instead of the standard `StringArray`. This can improve performance for workloads + /// with many short strings by reducing allocations. pub fn with_utf8_view(mut self, utf8_view: bool) -> Self { self.utf8_view = utf8_view; self } - /// Get whether StringViewArray is enabled for string data + /// Returns whether `StringViewArray` is enabled for string data. pub fn use_utf8view(&self) -> bool { self.utf8_view } - /// Controls whether certain Avro unions of the form `[T, "null"]` should produce an error. + /// Enable stricter behavior for certain Avro unions (e.g., `[T, "null"]`). + /// + /// When `true`, ambiguous or lossy unions that would otherwise be coerced may instead + /// produce a descriptive error. Use this to catch schema issues early during ingestion. pub fn with_strict_mode(mut self, strict_mode: bool) -> Self { self.strict_mode = strict_mode; self } - /// Sets the Avro reader schema. + /// Sets the **reader schema** used during decoding. 
/// - /// If a schema is not provided, the schema will be read from the Avro file header. + /// If not provided, the writer schema from the OCF header (for `Reader`) or the + /// schema looked up from the fingerprint (for `Decoder`) is used directly. + /// + /// A reader schema can be used for **schema evolution** or **projection**. pub fn with_reader_schema(mut self, schema: AvroSchema) -> Self { self.reader_schema = Some(schema); self } - /// Sets the `SchemaStore` used for resolving writer schemas. + /// Sets the `SchemaStore` used to resolve writer schemas by fingerprint. /// - /// This is necessary when decoding single-object encoded data that identifies - /// schemas by a fingerprint. The store allows the decoder to look up the - /// full writer schema from a fingerprint embedded in the data. + /// This is required when building a `Decoder` for **single‑object encoding** or the + /// **Confluent** wire format. The store maps a fingerprint (Rabin / MD5 / SHA‑256 / + /// ID) to a full Avro schema. /// /// Defaults to `None`. pub fn with_writer_schema_store(mut self, store: SchemaStore) -> Self { @@ -474,19 +825,20 @@ impl ReaderBuilder { self } - /// Sets the initial schema fingerprint for decoding single-object encoded data. - /// - /// This is useful when the data stream does not begin with a schema definition - /// or fingerprint, allowing the decoder to start with a known schema from the - /// `SchemaStore`. + /// Sets the initial schema fingerprint for stream decoding. /// - /// Defaults to `None`. + /// This can be useful for streams that **do not include** a fingerprint before the first + /// record body (uncommon). If not set, the first observed fingerprint is used. pub fn with_active_fingerprint(mut self, fp: Fingerprint) -> Self { self.active_fingerprint = Some(fp); self } - /// Create a [`Reader`] from this builder and a `BufRead` + /// Build a `Reader` (OCF) from this builder and a `BufRead`. + /// + /// This reads and validates the OCF header, initializes an internal row decoder from + /// the discovered writer (and optional reader) schema, and prepares to iterate blocks, + /// decompressing if necessary. pub fn build(self, mut reader: R) -> Result, ArrowError> { let header = read_header(&mut reader)?; let decoder = self.make_decoder(Some(&header), self.reader_schema.as_ref())?; @@ -496,12 +848,20 @@ impl ReaderBuilder { decoder, block_decoder: BlockDecoder::default(), block_data: Vec::new(), + block_count: 0, block_cursor: 0, finished: false, }) } - /// Create a [`Decoder`] from this builder. + /// Build a streaming `Decoder` from this builder. + /// + /// # Requirements + /// * `SchemaStore` **must** be provided via `Self::with_writer_schema_store`. + /// * The store should contain **all** fingerprints that may appear on the stream. + /// + /// # Errors + /// * Returns [`ArrowError::InvalidArgumentError`] if the schema store is missing pub fn build_decoder(self) -> Result { if self.writer_schema_store.is_none() { return Err(ArrowError::InvalidArgumentError( @@ -512,8 +872,15 @@ impl ReaderBuilder { } } -/// A high-level Avro `Reader` that reads container-file blocks -/// and feeds them into a row-level [`Decoder`]. +/// A high‑level Avro **Object Container File** reader. +/// +/// `Reader` pulls blocks from a `BufRead` source, handles optional block compression, +/// and decodes them row‑by‑row into Arrow `RecordBatch` values using an internal +/// `Decoder`. 
It implements both: +/// +/// * [`Iterator>`], and +/// * `RecordBatchReader`, guaranteeing a consistent schema across all produced batches. +/// #[derive(Debug)] pub struct Reader { reader: R, @@ -521,22 +888,27 @@ pub struct Reader { decoder: Decoder, block_decoder: BlockDecoder, block_data: Vec, + block_count: usize, block_cursor: usize, finished: bool, } impl Reader { - /// Return the Arrow schema discovered from the Avro file header + /// Returns the Arrow schema discovered from the Avro file header (or derived via + /// the optional reader schema). pub fn schema(&self) -> SchemaRef { self.decoder.schema() } - /// Return the Avro container-file header + /// Returns a reference to the parsed Avro container‑file header (magic, metadata, codec, sync). pub fn avro_header(&self) -> &Header { &self.header } - /// Reads the next [`RecordBatch`] from the Avro file or `Ok(None)` on EOF + /// Reads the next `RecordBatch` from the Avro file, or `Ok(None)` on EOF. + /// + /// Batches are bounded by `batch_size`; a single OCF block may yield multiple batches, + /// and a batch may also span multiple blocks. fn read(&mut self) -> Result, ArrowError> { 'outer: while !self.finished && !self.decoder.batch_is_full() { while self.block_cursor == self.block_data.len() { @@ -550,12 +922,12 @@ impl Reader { self.reader.consume(consumed); if let Some(block) = self.block_decoder.flush() { // Successfully decoded a block. - let block_data = if let Some(ref codec) = self.header.compression()? { + self.block_data = if let Some(ref codec) = self.header.compression()? { codec.decompress(&block.data)? } else { block.data }; - self.block_data = block_data; + self.block_count = block.count; self.block_cursor = 0; } else if consumed == 0 { // The block decoder made no progress on a non-empty buffer. @@ -564,11 +936,16 @@ impl Reader { )); } } - // Try to decode more rows from the current block. 
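Since `Reader` iterates `Result<RecordBatch, ArrowError>` with a stable schema, the usual consumption pattern mirrors the tests below: collect the batches and, when a single batch is wanted, concatenate them. A sketch under those assumptions; `data.avro` is a placeholder path and the `arrow::compute::concat_batches` call assumes the `arrow` crate is available, as in the test code.

```rust
use std::fs::File;
use std::io::BufReader;

use arrow_array::RecordBatch;
use arrow_avro::reader::ReaderBuilder;
use arrow_schema::ArrowError;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Any Avro Object Container File works here; "data.avro" is a placeholder.
    let file = File::open("data.avro")?;
    let reader = ReaderBuilder::new()
        .with_batch_size(4096)
        .build(BufReader::new(file))?;

    // Capture the schema before the reader is consumed by `collect`.
    let schema = reader.schema();
    let batches: Vec<RecordBatch> = reader.collect::<Result<_, ArrowError>>()?;
    let combined = arrow::compute::concat_batches(&schema, &batches)?;
    println!("read {} rows", combined.num_rows());
    Ok(())
}
```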
- let consumed = self.decoder.decode(&self.block_data[self.block_cursor..])?; - self.block_cursor += consumed; + // Decode as many rows as will fit in the current batch + if self.block_cursor < self.block_data.len() { + let (consumed, records_decoded) = self + .decoder + .decode_block(&self.block_data[self.block_cursor..], self.block_count)?; + self.block_cursor += consumed; + self.block_count -= records_decoded; + } } - self.decoder.flush() + self.decoder.flush_block() } } @@ -595,7 +972,7 @@ mod test { use crate::reader::{read_header, Decoder, Reader, ReaderBuilder}; use crate::schema::{ AvroSchema, Fingerprint, FingerprintAlgorithm, PrimitiveType, Schema as AvroRaw, - SchemaStore, SINGLE_OBJECT_MAGIC, + SchemaStore, AVRO_ENUM_SYMBOLS_METADATA_KEY, CONFLUENT_MAGIC, SINGLE_OBJECT_MAGIC, }; use crate::test_util::arrow_test_data; use arrow::array::ArrayDataBuilder; @@ -605,11 +982,12 @@ mod test { }; use arrow_array::types::{Int32Type, IntervalMonthDayNanoType}; use arrow_array::*; - use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; + use arrow_buffer::{i256, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow_schema::{ArrowError, DataType, Field, Fields, IntervalUnit, Schema}; use bytes::{Buf, BufMut, Bytes}; use futures::executor::block_on; use futures::{stream, Stream, StreamExt, TryStreamExt}; + use serde_json::Value; use std::collections::HashMap; use std::fs; use std::fs::File; @@ -696,6 +1074,17 @@ mod test { out.extend_from_slice(&v.to_le_bytes()); out } + Fingerprint::Id(v) => { + panic!("make_prefix expects a Rabin fingerprint, got ({v})"); + } + #[cfg(feature = "md5")] + Fingerprint::MD5(v) => { + panic!("make_prefix expects a Rabin fingerprint, got ({v:?})"); + } + #[cfg(feature = "sha256")] + Fingerprint::SHA256(id) => { + panic!("make_prefix expects a Rabin fingerprint, got ({id:?})"); + } } } @@ -709,163 +1098,1344 @@ mod test { .expect("decoder") } - #[test] - fn test_schema_store_register_lookup() { - let schema_int = make_record_schema(PrimitiveType::Int); - let schema_long = make_record_schema(PrimitiveType::Long); - let mut store = SchemaStore::new(); - let fp_int = store.register(schema_int.clone()).unwrap(); - let fp_long = store.register(schema_long.clone()).unwrap(); - assert_eq!(store.lookup(&fp_int).cloned(), Some(schema_int)); - assert_eq!(store.lookup(&fp_long).cloned(), Some(schema_long)); - assert_eq!(store.fingerprint_algorithm(), FingerprintAlgorithm::Rabin); + fn make_id_prefix(id: u32, additional: usize) -> Vec { + let capacity = CONFLUENT_MAGIC.len() + size_of::() + additional; + let mut out = Vec::with_capacity(capacity); + out.extend_from_slice(&CONFLUENT_MAGIC); + out.extend_from_slice(&id.to_be_bytes()); + out } - #[test] - fn test_unknown_fingerprint_is_error() { - let (store, fp_int, _fp_long, schema_int, _schema_long) = make_two_schema_store(); - let unknown_fp = Fingerprint::Rabin(0xDEAD_BEEF_DEAD_BEEF); - let prefix = make_prefix(unknown_fp); - let mut decoder = make_decoder(&store, fp_int, &schema_int); - let err = decoder.decode(&prefix).expect_err("decode should error"); - let msg = err.to_string(); - assert!( - msg.contains("Unknown fingerprint"), - "unexpected message: {msg}" - ); + fn make_message_id(id: u32, value: i64) -> Vec { + let encoded_value = encode_zigzag(value); + let mut msg = make_id_prefix(id, encoded_value.len()); + msg.extend_from_slice(&encoded_value); + msg } - #[test] - fn test_missing_initial_fingerprint_error() { - let (store, _fp_int, _fp_long, schema_int, _schema_long) = 
make_two_schema_store(); - let mut decoder = ReaderBuilder::new() - .with_batch_size(8) - .with_reader_schema(schema_int.clone()) - .with_writer_schema_store(store) - .build_decoder() - .unwrap(); - let buf = [0x02u8, 0x00u8]; - let err = decoder.decode(&buf).expect_err("decode should error"); - let msg = err.to_string(); - assert!( - msg.contains("Expected single‑object encoding fingerprint"), - "unexpected message: {msg}" + fn make_value_schema(pt: PrimitiveType) -> AvroSchema { + let json_schema = format!( + r#"{{"type":"record","name":"S","fields":[{{"name":"v","type":"{}"}}]}}"#, + pt.as_ref() ); + AvroSchema::new(json_schema) } - #[test] - fn test_handle_prefix_no_schema_store() { - let (store, fp_int, _fp_long, schema_int, _schema_long) = make_two_schema_store(); - let mut decoder = make_decoder(&store, fp_int, &schema_int); - decoder.expect_prefix = false; - let res = decoder - .handle_prefix(&SINGLE_OBJECT_MAGIC[..]) - .expect("handle_prefix"); - assert!(res.is_none(), "Expected None when expect_prefix is false"); - } - - #[test] - fn test_handle_prefix_incomplete_magic() { - let (store, fp_int, _fp_long, schema_int, _schema_long) = make_two_schema_store(); - let mut decoder = make_decoder(&store, fp_int, &schema_int); - let buf = &SINGLE_OBJECT_MAGIC[..1]; - let res = decoder.handle_prefix(buf).unwrap(); - assert_eq!(res, Some(0)); - assert!(decoder.pending_schema.is_none()); + fn encode_zigzag(value: i64) -> Vec { + let mut n = ((value << 1) ^ (value >> 63)) as u64; + let mut out = Vec::new(); + loop { + if (n & !0x7F) == 0 { + out.push(n as u8); + break; + } else { + out.push(((n & 0x7F) | 0x80) as u8); + n >>= 7; + } + } + out } - #[test] - fn test_handle_prefix_magic_mismatch() { - let (store, fp_int, _fp_long, schema_int, _schema_long) = make_two_schema_store(); - let mut decoder = make_decoder(&store, fp_int, &schema_int); - let buf = [0xFFu8, 0x00u8, 0x01u8]; - let res = decoder.handle_prefix(&buf).unwrap(); - assert!(res.is_none()); + fn make_message(fp: Fingerprint, value: i64) -> Vec { + let mut msg = make_prefix(fp); + msg.extend_from_slice(&encode_zigzag(value)); + msg } - #[test] - fn test_handle_prefix_incomplete_fingerprint() { - let (store, fp_int, fp_long, schema_int, _schema_long) = make_two_schema_store(); - let mut decoder = make_decoder(&store, fp_int, &schema_int); - let long_bytes = match fp_long { - Fingerprint::Rabin(v) => v.to_le_bytes(), - }; - let mut buf = Vec::from(SINGLE_OBJECT_MAGIC); - buf.extend_from_slice(&long_bytes[..4]); - let res = decoder.handle_prefix(&buf).unwrap(); - assert_eq!(res, Some(0)); - assert!(decoder.pending_schema.is_none()); + fn load_writer_schema_json(path: &str) -> Value { + let file = File::open(path).unwrap(); + let header = super::read_header(BufReader::new(file)).unwrap(); + let schema = header.schema().unwrap().unwrap(); + serde_json::to_value(&schema).unwrap() } - #[test] - fn test_handle_prefix_valid_prefix_switches_schema() { - let (store, fp_int, fp_long, schema_int, schema_long) = make_two_schema_store(); - let mut decoder = make_decoder(&store, fp_int, &schema_int); - let writer_schema_long = schema_long.schema().unwrap(); - let root_long = AvroFieldBuilder::new(&writer_schema_long).build().unwrap(); - let long_decoder = - RecordDecoder::try_new_with_options(root_long.data_type(), decoder.utf8_view).unwrap(); - let _ = decoder.cache.insert(fp_long, long_decoder); - let mut buf = Vec::from(SINGLE_OBJECT_MAGIC); - let Fingerprint::Rabin(v) = fp_long; - buf.extend_from_slice(&v.to_le_bytes()); - let consumed = 
decoder.handle_prefix(&buf).unwrap().unwrap(); - assert_eq!(consumed, buf.len()); - assert!(decoder.pending_schema.is_some()); - assert_eq!(decoder.pending_schema.as_ref().unwrap().0, fp_long); + fn make_reader_schema_with_promotions( + path: &str, + promotions: &HashMap<&str, &str>, + ) -> AvroSchema { + let mut root = load_writer_schema_json(path); + assert_eq!(root["type"], "record", "writer schema must be a record"); + let fields = root + .get_mut("fields") + .and_then(|f| f.as_array_mut()) + .expect("record has fields"); + for f in fields.iter_mut() { + let Some(name) = f.get("name").and_then(|n| n.as_str()) else { + continue; + }; + if let Some(new_ty) = promotions.get(name) { + let ty = f.get_mut("type").expect("field has a type"); + match ty { + Value::String(_) => { + *ty = Value::String((*new_ty).to_string()); + } + // Union + Value::Array(arr) => { + for b in arr.iter_mut() { + match b { + Value::String(s) if s != "null" => { + *b = Value::String((*new_ty).to_string()); + break; + } + Value::Object(_) => { + *b = Value::String((*new_ty).to_string()); + break; + } + _ => {} + } + } + } + Value::Object(_) => { + *ty = Value::String((*new_ty).to_string()); + } + _ => {} + } + } + } + AvroSchema::new(root.to_string()) } - #[test] - fn test_utf8view_support() { - let schema_json = r#"{ - "type": "record", - "name": "test", - "fields": [{ - "name": "str_field", - "type": "string" - }] - }"#; - - let schema: crate::schema::Schema = serde_json::from_str(schema_json).unwrap(); - let avro_field = AvroField::try_from(&schema).unwrap(); - - let data_type = avro_field.data_type(); + fn make_reader_schema_with_enum_remap( + path: &str, + remap: &HashMap<&str, Vec<&str>>, + ) -> AvroSchema { + let mut root = load_writer_schema_json(path); + assert_eq!(root["type"], "record", "writer schema must be a record"); + let fields = root + .get_mut("fields") + .and_then(|f| f.as_array_mut()) + .expect("record has fields"); + + fn to_symbols_array(symbols: &[&str]) -> Value { + Value::Array(symbols.iter().map(|s| Value::String((*s).into())).collect()) + } - struct TestHelper; - impl TestHelper { - fn with_utf8view(field: &Field) -> Field { - match field.data_type() { - DataType::Utf8 => { - Field::new(field.name(), DataType::Utf8View, field.is_nullable()) - .with_metadata(field.metadata().clone()) + fn update_enum_symbols(ty: &mut Value, symbols: &Value) { + match ty { + Value::Object(map) => { + if matches!(map.get("type"), Some(Value::String(t)) if t == "enum") { + map.insert("symbols".to_string(), symbols.clone()); } - _ => field.clone(), } + Value::Array(arr) => { + for b in arr.iter_mut() { + if let Value::Object(map) = b { + if matches!(map.get("type"), Some(Value::String(t)) if t == "enum") { + map.insert("symbols".to_string(), symbols.clone()); + } + } + } + } + _ => {} } } + for f in fields.iter_mut() { + let Some(name) = f.get("name").and_then(|n| n.as_str()) else { + continue; + }; + if let Some(new_symbols) = remap.get(name) { + let symbols_val = to_symbols_array(new_symbols); + let ty = f.get_mut("type").expect("field has a type"); + update_enum_symbols(ty, &symbols_val); + } + } + AvroSchema::new(root.to_string()) + } - let field = TestHelper::with_utf8view(&Field::new("str_field", DataType::Utf8, false)); - - assert_eq!(field.data_type(), &DataType::Utf8View); - - let array = StringViewArray::from(vec!["test1", "test2"]); - let batch = - RecordBatch::try_from_iter(vec![("str_field", Arc::new(array) as ArrayRef)]).unwrap(); + fn read_alltypes_with_reader_schema(path: &str, reader_schema: 
AvroSchema) -> RecordBatch { + let file = File::open(path).unwrap(); + let reader = ReaderBuilder::new() + .with_batch_size(1024) + .with_utf8_view(false) + .with_reader_schema(reader_schema) + .build(BufReader::new(file)) + .unwrap(); + let schema = reader.schema(); + let batches = reader.collect::, _>>().unwrap(); + arrow::compute::concat_batches(&schema, &batches).unwrap() + } - assert!(batch.column(0).as_any().is::()); + fn make_reader_schema_with_selected_fields_in_order( + path: &str, + selected: &[&str], + ) -> AvroSchema { + let mut root = load_writer_schema_json(path); + assert_eq!(root["type"], "record", "writer schema must be a record"); + let writer_fields = root + .get("fields") + .and_then(|f| f.as_array()) + .expect("record has fields"); + let mut field_map: HashMap = HashMap::with_capacity(writer_fields.len()); + for f in writer_fields { + if let Some(name) = f.get("name").and_then(|n| n.as_str()) { + field_map.insert(name.to_string(), f.clone()); + } + } + let mut new_fields = Vec::with_capacity(selected.len()); + for name in selected { + let f = field_map + .get(*name) + .unwrap_or_else(|| panic!("field '{name}' not found in writer schema")) + .clone(); + new_fields.push(f); + } + root["fields"] = Value::Array(new_fields); + AvroSchema::new(root.to_string()) } #[test] - fn test_read_zero_byte_avro_file() { - let batch = read_file("test/data/zero_byte.avro", 3, false); - let schema = batch.schema(); - assert_eq!(schema.fields().len(), 1); - let field = schema.field(0); - assert_eq!(field.name(), "data"); - assert_eq!(field.data_type(), &DataType::Binary); - assert!(field.is_nullable()); - assert_eq!(batch.num_rows(), 3); - assert_eq!(batch.num_columns(), 1); + fn test_alltypes_schema_promotion_mixed() { + let files = [ + "avro/alltypes_plain.avro", + "avro/alltypes_plain.snappy.avro", + "avro/alltypes_plain.zstandard.avro", + "avro/alltypes_plain.bzip2.avro", + "avro/alltypes_plain.xz.avro", + ]; + for file in files { + let file = arrow_test_data(file); + let mut promotions: HashMap<&str, &str> = HashMap::new(); + promotions.insert("id", "long"); + promotions.insert("tinyint_col", "float"); + promotions.insert("smallint_col", "double"); + promotions.insert("int_col", "double"); + promotions.insert("bigint_col", "double"); + promotions.insert("float_col", "double"); + promotions.insert("date_string_col", "string"); + promotions.insert("string_col", "string"); + let reader_schema = make_reader_schema_with_promotions(&file, &promotions); + let batch = read_alltypes_with_reader_schema(&file, reader_schema); + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "id", + Arc::new(Int64Array::from(vec![4i64, 5, 6, 7, 2, 3, 0, 1])) as _, + true, + ), + ( + "bool_col", + Arc::new(BooleanArray::from_iter((0..8).map(|x| Some(x % 2 == 0)))) as _, + true, + ), + ( + "tinyint_col", + Arc::new(Float32Array::from_iter_values( + (0..8).map(|x| (x % 2) as f32), + )) as _, + true, + ), + ( + "smallint_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64), + )) as _, + true, + ), + ( + "int_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64), + )) as _, + true, + ), + ( + "bigint_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| ((x % 2) * 10) as f64), + )) as _, + true, + ), + ( + "float_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| ((x % 2) as f32 * 1.1f32) as f64), + )) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64 * 
10.1), + )) as _, + true, + ), + ( + "date_string_col", + Arc::new(StringArray::from(vec![ + "03/01/09", "03/01/09", "04/01/09", "04/01/09", "02/01/09", "02/01/09", + "01/01/09", "01/01/09", + ])) as _, + true, + ), + ( + "string_col", + Arc::new(StringArray::from( + (0..8) + .map(|x| if x % 2 == 0 { "0" } else { "1" }) + .collect::>(), + )) as _, + true, + ), + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1235865600000000, // 2009-03-01T00:00:00.000 + 1235865660000000, // 2009-03-01T00:01:00.000 + 1238544000000000, // 2009-04-01T00:00:00.000 + 1238544060000000, // 2009-04-01T00:01:00.000 + 1233446400000000, // 2009-02-01T00:00:00.000 + 1233446460000000, // 2009-02-01T00:01:00.000 + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ]) + .unwrap(); + assert_eq!(batch, expected, "mismatch for file {file}"); + } + } + + #[test] + fn test_alltypes_schema_promotion_long_to_float_only() { + let files = [ + "avro/alltypes_plain.avro", + "avro/alltypes_plain.snappy.avro", + "avro/alltypes_plain.zstandard.avro", + "avro/alltypes_plain.bzip2.avro", + "avro/alltypes_plain.xz.avro", + ]; + for file in files { + let file = arrow_test_data(file); + let mut promotions: HashMap<&str, &str> = HashMap::new(); + promotions.insert("bigint_col", "float"); + let reader_schema = make_reader_schema_with_promotions(&file, &promotions); + let batch = read_alltypes_with_reader_schema(&file, reader_schema); + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "id", + Arc::new(Int32Array::from(vec![4, 5, 6, 7, 2, 3, 0, 1])) as _, + true, + ), + ( + "bool_col", + Arc::new(BooleanArray::from_iter((0..8).map(|x| Some(x % 2 == 0)))) as _, + true, + ), + ( + "tinyint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "smallint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "int_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "bigint_col", + Arc::new(Float32Array::from_iter_values( + (0..8).map(|x| ((x % 2) * 10) as f32), + )) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from_iter_values( + (0..8).map(|x| (x % 2) as f32 * 1.1), + )) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64 * 10.1), + )) as _, + true, + ), + ( + "date_string_col", + Arc::new(BinaryArray::from_iter_values([ + [48, 51, 47, 48, 49, 47, 48, 57], + [48, 51, 47, 48, 49, 47, 48, 57], + [48, 52, 47, 48, 49, 47, 48, 57], + [48, 52, 47, 48, 49, 47, 48, 57], + [48, 50, 47, 48, 49, 47, 48, 57], + [48, 50, 47, 48, 49, 47, 48, 57], + [48, 49, 47, 48, 49, 47, 48, 57], + [48, 49, 47, 48, 49, 47, 48, 57], + ])) as _, + true, + ), + ( + "string_col", + Arc::new(BinaryArray::from_iter_values((0..8).map(|x| [48 + x % 2]))) as _, + true, + ), + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1235865600000000, // 2009-03-01T00:00:00.000 + 1235865660000000, // 2009-03-01T00:01:00.000 + 1238544000000000, // 2009-04-01T00:00:00.000 + 1238544060000000, // 2009-04-01T00:01:00.000 + 1233446400000000, // 2009-02-01T00:00:00.000 + 1233446460000000, // 2009-02-01T00:01:00.000 + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ]) + .unwrap(); + assert_eq!(batch, expected, "mismatch 
for file {file}"); + } + } + + #[test] + fn test_alltypes_schema_promotion_bytes_to_string_only() { + let files = [ + "avro/alltypes_plain.avro", + "avro/alltypes_plain.snappy.avro", + "avro/alltypes_plain.zstandard.avro", + "avro/alltypes_plain.bzip2.avro", + "avro/alltypes_plain.xz.avro", + ]; + for file in files { + let file = arrow_test_data(file); + let mut promotions: HashMap<&str, &str> = HashMap::new(); + promotions.insert("date_string_col", "string"); + promotions.insert("string_col", "string"); + let reader_schema = make_reader_schema_with_promotions(&file, &promotions); + let batch = read_alltypes_with_reader_schema(&file, reader_schema); + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "id", + Arc::new(Int32Array::from(vec![4, 5, 6, 7, 2, 3, 0, 1])) as _, + true, + ), + ( + "bool_col", + Arc::new(BooleanArray::from_iter((0..8).map(|x| Some(x % 2 == 0)))) as _, + true, + ), + ( + "tinyint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "smallint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "int_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "bigint_col", + Arc::new(Int64Array::from_iter_values((0..8).map(|x| (x % 2) * 10))) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from_iter_values( + (0..8).map(|x| (x % 2) as f32 * 1.1), + )) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64 * 10.1), + )) as _, + true, + ), + ( + "date_string_col", + Arc::new(StringArray::from(vec![ + "03/01/09", "03/01/09", "04/01/09", "04/01/09", "02/01/09", "02/01/09", + "01/01/09", "01/01/09", + ])) as _, + true, + ), + ( + "string_col", + Arc::new(StringArray::from( + (0..8) + .map(|x| if x % 2 == 0 { "0" } else { "1" }) + .collect::>(), + )) as _, + true, + ), + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1235865600000000, // 2009-03-01T00:00:00.000 + 1235865660000000, // 2009-03-01T00:01:00.000 + 1238544000000000, // 2009-04-01T00:00:00.000 + 1238544060000000, // 2009-04-01T00:01:00.000 + 1233446400000000, // 2009-02-01T00:00:00.000 + 1233446460000000, // 2009-02-01T00:01:00.000 + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ]) + .unwrap(); + assert_eq!(batch, expected, "mismatch for file {file}"); + } + } + + #[test] + fn test_alltypes_illegal_promotion_bool_to_double_errors() { + let file = arrow_test_data("avro/alltypes_plain.avro"); + let mut promotions: HashMap<&str, &str> = HashMap::new(); + promotions.insert("bool_col", "double"); // illegal + let reader_schema = make_reader_schema_with_promotions(&file, &promotions); + let file_handle = File::open(&file).unwrap(); + let result = ReaderBuilder::new() + .with_reader_schema(reader_schema) + .build(BufReader::new(file_handle)); + let err = result.expect_err("expected illegal promotion to error"); + let msg = err.to_string(); + assert!( + msg.contains("Illegal promotion") || msg.contains("illegal promotion"), + "unexpected error: {msg}" + ); + } + + #[test] + fn test_simple_enum_with_reader_schema_mapping() { + let file = arrow_test_data("avro/simple_enum.avro"); + let mut remap: HashMap<&str, Vec<&str>> = HashMap::new(); + remap.insert("f1", vec!["d", "c", "b", "a"]); + remap.insert("f2", vec!["h", "g", "f", "e"]); + remap.insert("f3", vec!["k", "i", "j"]); + let 
reader_schema = make_reader_schema_with_enum_remap(&file, &remap); + let actual = read_alltypes_with_reader_schema(&file, reader_schema); + let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let f1_keys = Int32Array::from(vec![3, 2, 1, 0]); + let f1_vals = StringArray::from(vec!["d", "c", "b", "a"]); + let f1 = DictionaryArray::::try_new(f1_keys, Arc::new(f1_vals)).unwrap(); + let mut md_f1 = HashMap::new(); + md_f1.insert( + AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), + r#"["d","c","b","a"]"#.to_string(), + ); + let f1_field = Field::new("f1", dict_type.clone(), false).with_metadata(md_f1); + let f2_keys = Int32Array::from(vec![1, 0, 3, 2]); + let f2_vals = StringArray::from(vec!["h", "g", "f", "e"]); + let f2 = DictionaryArray::::try_new(f2_keys, Arc::new(f2_vals)).unwrap(); + let mut md_f2 = HashMap::new(); + md_f2.insert( + AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), + r#"["h","g","f","e"]"#.to_string(), + ); + let f2_field = Field::new("f2", dict_type.clone(), false).with_metadata(md_f2); + let f3_keys = Int32Array::from(vec![Some(2), Some(0), None, Some(1)]); + let f3_vals = StringArray::from(vec!["k", "i", "j"]); + let f3 = DictionaryArray::::try_new(f3_keys, Arc::new(f3_vals)).unwrap(); + let mut md_f3 = HashMap::new(); + md_f3.insert( + AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), + r#"["k","i","j"]"#.to_string(), + ); + let f3_field = Field::new("f3", dict_type.clone(), true).with_metadata(md_f3); + let expected_schema = Arc::new(Schema::new(vec![f1_field, f2_field, f3_field])); + let expected = RecordBatch::try_new( + expected_schema, + vec![Arc::new(f1) as ArrayRef, Arc::new(f2), Arc::new(f3)], + ) + .unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn test_schema_store_register_lookup() { + let schema_int = make_record_schema(PrimitiveType::Int); + let schema_long = make_record_schema(PrimitiveType::Long); + let mut store = SchemaStore::new(); + let fp_int = store.register(schema_int.clone()).unwrap(); + let fp_long = store.register(schema_long.clone()).unwrap(); + assert_eq!(store.lookup(&fp_int).cloned(), Some(schema_int)); + assert_eq!(store.lookup(&fp_long).cloned(), Some(schema_long)); + assert_eq!(store.fingerprint_algorithm(), FingerprintAlgorithm::Rabin); + } + + #[test] + fn test_unknown_fingerprint_is_error() { + let (store, fp_int, _fp_long, _schema_int, schema_long) = make_two_schema_store(); + let unknown_fp = Fingerprint::Rabin(0xDEAD_BEEF_DEAD_BEEF); + let prefix = make_prefix(unknown_fp); + let mut decoder = make_decoder(&store, fp_int, &schema_long); + let err = decoder.decode(&prefix).expect_err("decode should error"); + let msg = err.to_string(); + assert!( + msg.contains("Unknown fingerprint"), + "unexpected message: {msg}" + ); + } + + #[test] + fn test_handle_prefix_incomplete_magic() { + let (store, fp_int, _fp_long, _schema_int, schema_long) = make_two_schema_store(); + let mut decoder = make_decoder(&store, fp_int, &schema_long); + let buf = &SINGLE_OBJECT_MAGIC[..1]; + let res = decoder.handle_prefix(buf).unwrap(); + assert_eq!(res, Some(0)); + assert!(decoder.pending_schema.is_none()); + } + + #[test] + fn test_handle_prefix_magic_mismatch() { + let (store, fp_int, _fp_long, _schema_int, schema_long) = make_two_schema_store(); + let mut decoder = make_decoder(&store, fp_int, &schema_long); + let buf = [0xFFu8, 0x00u8, 0x01u8]; + let res = decoder.handle_prefix(&buf).unwrap(); + assert!(res.is_none()); + } + + #[test] + fn test_handle_prefix_incomplete_fingerprint() { + let (store, fp_int, fp_long, 
_schema_int, schema_long) = make_two_schema_store(); + let mut decoder = make_decoder(&store, fp_int, &schema_long); + let long_bytes = match fp_long { + Fingerprint::Rabin(v) => v.to_le_bytes(), + Fingerprint::Id(id) => panic!("expected Rabin fingerprint, got ({id})"), + #[cfg(feature = "md5")] + Fingerprint::MD5(v) => panic!("expected Rabin fingerprint, got ({v:?})"), + #[cfg(feature = "sha256")] + Fingerprint::SHA256(v) => panic!("expected Rabin fingerprint, got ({v:?})"), + }; + let mut buf = Vec::from(SINGLE_OBJECT_MAGIC); + buf.extend_from_slice(&long_bytes[..4]); + let res = decoder.handle_prefix(&buf).unwrap(); + assert_eq!(res, Some(0)); + assert!(decoder.pending_schema.is_none()); + } + + #[test] + fn test_handle_prefix_valid_prefix_switches_schema() { + let (store, fp_int, fp_long, _schema_int, schema_long) = make_two_schema_store(); + let mut decoder = make_decoder(&store, fp_int, &schema_long); + let writer_schema_long = schema_long.schema().unwrap(); + let root_long = AvroFieldBuilder::new(&writer_schema_long).build().unwrap(); + let long_decoder = + RecordDecoder::try_new_with_options(root_long.data_type(), decoder.utf8_view).unwrap(); + let _ = decoder.cache.insert(fp_long, long_decoder); + let mut buf = Vec::from(SINGLE_OBJECT_MAGIC); + match fp_long { + Fingerprint::Rabin(v) => buf.extend_from_slice(&v.to_le_bytes()), + Fingerprint::Id(id) => panic!("expected Rabin fingerprint, got ({id})"), + #[cfg(feature = "md5")] + Fingerprint::MD5(v) => panic!("expected Rabin fingerprint, got ({v:?})"), + #[cfg(feature = "sha256")] + Fingerprint::SHA256(v) => panic!("expected Rabin fingerprint, got ({v:?})"), + } + let consumed = decoder.handle_prefix(&buf).unwrap().unwrap(); + assert_eq!(consumed, buf.len()); + assert!(decoder.pending_schema.is_some()); + assert_eq!(decoder.pending_schema.as_ref().unwrap().0, fp_long); + } + + #[test] + fn test_two_messages_same_schema() { + let writer_schema = make_value_schema(PrimitiveType::Int); + let reader_schema = writer_schema.clone(); + let mut store = SchemaStore::new(); + let fp = store.register(writer_schema).unwrap(); + let msg1 = make_message(fp, 42); + let msg2 = make_message(fp, 11); + let input = [msg1.clone(), msg2.clone()].concat(); + let mut decoder = ReaderBuilder::new() + .with_batch_size(8) + .with_reader_schema(reader_schema.clone()) + .with_writer_schema_store(store) + .with_active_fingerprint(fp) + .build_decoder() + .unwrap(); + let _ = decoder.decode(&input).unwrap(); + let batch = decoder.flush().unwrap().expect("batch"); + assert_eq!(batch.num_rows(), 2); + let col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.value(0), 42); + assert_eq!(col.value(1), 11); + } + + #[test] + fn test_two_messages_schema_switch() { + let w_int = make_value_schema(PrimitiveType::Int); + let w_long = make_value_schema(PrimitiveType::Long); + let r_long = w_long.clone(); + let mut store = SchemaStore::new(); + let fp_int = store.register(w_int).unwrap(); + let fp_long = store.register(w_long).unwrap(); + let msg_int = make_message(fp_int, 1); + let msg_long = make_message(fp_long, 123456789_i64); + let mut decoder = ReaderBuilder::new() + .with_batch_size(8) + .with_writer_schema_store(store) + .with_active_fingerprint(fp_int) + .build_decoder() + .unwrap(); + let _ = decoder.decode(&msg_int).unwrap(); + let batch1 = decoder.flush().unwrap().expect("batch1"); + assert_eq!(batch1.num_rows(), 1); + assert_eq!( + batch1 + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0), + 1 + ); + let _ = 
decoder.decode(&msg_long).unwrap(); + let batch2 = decoder.flush().unwrap().expect("batch2"); + assert_eq!(batch2.num_rows(), 1); + assert_eq!( + batch2 + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0), + 123456789_i64 + ); + } + + #[test] + fn test_two_messages_same_schema_id() { + let writer_schema = make_value_schema(PrimitiveType::Int); + let reader_schema = writer_schema.clone(); + let id = 100u32; + // Set up store with None fingerprint algorithm and register schema by id + let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::None); + let _ = store + .set(Fingerprint::Id(id), writer_schema.clone()) + .expect("set id schema"); + let msg1 = make_message_id(id, 21); + let msg2 = make_message_id(id, 22); + let input = [msg1.clone(), msg2.clone()].concat(); + let mut decoder = ReaderBuilder::new() + .with_batch_size(8) + .with_reader_schema(reader_schema) + .with_writer_schema_store(store) + .with_active_fingerprint(Fingerprint::Id(id)) + .build_decoder() + .unwrap(); + let _ = decoder.decode(&input).unwrap(); + let batch = decoder.flush().unwrap().expect("batch"); + assert_eq!(batch.num_rows(), 2); + let col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.value(0), 21); + assert_eq!(col.value(1), 22); + } + + #[test] + fn test_unknown_id_fingerprint_is_error() { + let writer_schema = make_value_schema(PrimitiveType::Int); + let id_known = 7u32; + let id_unknown = 9u32; + let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::None); + let _ = store + .set(Fingerprint::Id(id_known), writer_schema.clone()) + .expect("set id schema"); + let mut decoder = ReaderBuilder::new() + .with_batch_size(8) + .with_reader_schema(writer_schema) + .with_writer_schema_store(store) + .with_active_fingerprint(Fingerprint::Id(id_known)) + .build_decoder() + .unwrap(); + let prefix = make_id_prefix(id_unknown, 0); + let err = decoder.decode(&prefix).expect_err("decode should error"); + let msg = err.to_string(); + assert!( + msg.contains("Unknown fingerprint"), + "unexpected message: {msg}" + ); + } + + #[test] + fn test_handle_prefix_id_incomplete_magic() { + let writer_schema = make_value_schema(PrimitiveType::Int); + let id = 5u32; + let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::None); + let _ = store + .set(Fingerprint::Id(id), writer_schema.clone()) + .expect("set id schema"); + let mut decoder = ReaderBuilder::new() + .with_batch_size(8) + .with_reader_schema(writer_schema) + .with_writer_schema_store(store) + .with_active_fingerprint(Fingerprint::Id(id)) + .build_decoder() + .unwrap(); + let buf = &crate::schema::CONFLUENT_MAGIC[..0]; // empty incomplete magic + let res = decoder.handle_prefix(buf).unwrap(); + assert_eq!(res, Some(0)); + assert!(decoder.pending_schema.is_none()); + } + + fn test_split_message_across_chunks() { + let writer_schema = make_value_schema(PrimitiveType::Int); + let reader_schema = writer_schema.clone(); + let mut store = SchemaStore::new(); + let fp = store.register(writer_schema).unwrap(); + let msg1 = make_message(fp, 7); + let msg2 = make_message(fp, 8); + let msg3 = make_message(fp, 9); + let (pref2, body2) = msg2.split_at(10); + let (pref3, body3) = msg3.split_at(10); + let mut decoder = ReaderBuilder::new() + .with_batch_size(8) + .with_reader_schema(reader_schema) + .with_writer_schema_store(store) + .with_active_fingerprint(fp) + .build_decoder() + .unwrap(); + let _ = decoder.decode(&msg1).unwrap(); + let batch1 = decoder.flush().unwrap().expect("batch1"); + 
assert_eq!(batch1.num_rows(), 1); + assert_eq!( + batch1 + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0), + 7 + ); + let _ = decoder.decode(pref2).unwrap(); + assert!(decoder.flush().unwrap().is_none()); + let mut chunk3 = Vec::from(body2); + chunk3.extend_from_slice(pref3); + let _ = decoder.decode(&chunk3).unwrap(); + let batch2 = decoder.flush().unwrap().expect("batch2"); + assert_eq!(batch2.num_rows(), 1); + assert_eq!( + batch2 + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0), + 8 + ); + let _ = decoder.decode(body3).unwrap(); + let batch3 = decoder.flush().unwrap().expect("batch3"); + assert_eq!(batch3.num_rows(), 1); + assert_eq!( + batch3 + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0), + 9 + ); + } + + #[test] + fn test_decode_stream_with_schema() { + struct TestCase<'a> { + name: &'a str, + schema: &'a str, + expected_error: Option<&'a str>, + } + let tests = vec![ + TestCase { + name: "success", + schema: r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"string"}]}"#, + expected_error: None, + }, + TestCase { + name: "valid schema invalid data", + schema: r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"long"}]}"#, + expected_error: Some("did not consume all bytes"), + }, + ]; + for test in tests { + let avro_schema = AvroSchema::new(test.schema.to_string()); + let mut store = SchemaStore::new(); + let fp = store.register(avro_schema.clone()).unwrap(); + let prefix = make_prefix(fp); + let record_val = "some_string"; + let mut body = prefix; + body.push((record_val.len() as u8) << 1); + body.extend_from_slice(record_val.as_bytes()); + let decoder_res = ReaderBuilder::new() + .with_batch_size(1) + .with_writer_schema_store(store) + .with_active_fingerprint(fp) + .build_decoder(); + let decoder = match decoder_res { + Ok(d) => d, + Err(e) => { + if let Some(expected) = test.expected_error { + assert!( + e.to_string().contains(expected), + "Test '{}' failed at build – expected '{expected}', got '{e}'", + test.name + ); + continue; + } else { + panic!("Test '{}' failed during build: {e}", test.name); + } + } + }; + let stream = Box::pin(stream::once(async { Bytes::from(body) })); + let decoded_stream = decode_stream(decoder, stream); + let batches_result: Result, ArrowError> = + block_on(decoded_stream.try_collect()); + match (batches_result, test.expected_error) { + (Ok(batches), None) => { + let batch = + arrow::compute::concat_batches(&batches[0].schema(), &batches).unwrap(); + let expected_field = Field::new("f2", DataType::Utf8, false); + let expected_schema = Arc::new(Schema::new(vec![expected_field])); + let expected_array = Arc::new(StringArray::from(vec![record_val])); + let expected_batch = + RecordBatch::try_new(expected_schema, vec![expected_array]).unwrap(); + assert_eq!(batch, expected_batch, "Test '{}'", test.name); + } + (Err(e), Some(expected)) => { + assert!( + e.to_string().contains(expected), + "Test '{}' – expected error containing '{expected}', got '{e}'", + test.name + ); + } + (Ok(_), Some(expected)) => { + panic!( + "Test '{}' expected failure ('{expected}') but succeeded", + test.name + ); + } + (Err(e), None) => { + panic!("Test '{}' unexpectedly failed with '{e}'", test.name); + } + } + } + } + + #[test] + fn test_utf8view_support() { + let schema_json = r#"{ + "type": "record", + "name": "test", + "fields": [{ + "name": "str_field", + "type": "string" + }] + }"#; + + let schema: crate::schema::Schema = serde_json::from_str(schema_json).unwrap(); + let avro_field 
= AvroField::try_from(&schema).unwrap(); + + let data_type = avro_field.data_type(); + + struct TestHelper; + impl TestHelper { + fn with_utf8view(field: &Field) -> Field { + match field.data_type() { + DataType::Utf8 => { + Field::new(field.name(), DataType::Utf8View, field.is_nullable()) + .with_metadata(field.metadata().clone()) + } + _ => field.clone(), + } + } + } + + let field = TestHelper::with_utf8view(&Field::new("str_field", DataType::Utf8, false)); + + assert_eq!(field.data_type(), &DataType::Utf8View); + + let array = StringViewArray::from(vec!["test1", "test2"]); + let batch = + RecordBatch::try_from_iter(vec![("str_field", Arc::new(array) as ArrayRef)]).unwrap(); + + assert!(batch.column(0).as_any().is::()); + } + + fn make_reader_schema_with_default_fields( + path: &str, + default_fields: Vec, + ) -> AvroSchema { + let mut root = load_writer_schema_json(path); + assert_eq!(root["type"], "record", "writer schema must be a record"); + root.as_object_mut() + .expect("schema is a JSON object") + .insert("fields".to_string(), Value::Array(default_fields)); + AvroSchema::new(root.to_string()) + } + + #[test] + fn test_schema_resolution_defaults_all_supported_types() { + let path = "test/data/skippable_types.avro"; + let duration_default = "\u{0000}".repeat(12); + let reader_schema = make_reader_schema_with_default_fields( + path, + vec![ + serde_json::json!({"name":"d_bool","type":"boolean","default":true}), + serde_json::json!({"name":"d_int","type":"int","default":42}), + serde_json::json!({"name":"d_long","type":"long","default":12345}), + serde_json::json!({"name":"d_float","type":"float","default":1.5}), + serde_json::json!({"name":"d_double","type":"double","default":2.25}), + serde_json::json!({"name":"d_bytes","type":"bytes","default":"XYZ"}), + serde_json::json!({"name":"d_string","type":"string","default":"hello"}), + serde_json::json!({"name":"d_date","type":{"type":"int","logicalType":"date"},"default":0}), + serde_json::json!({"name":"d_time_ms","type":{"type":"int","logicalType":"time-millis"},"default":1000}), + serde_json::json!({"name":"d_time_us","type":{"type":"long","logicalType":"time-micros"},"default":2000}), + serde_json::json!({"name":"d_ts_ms","type":{"type":"long","logicalType":"local-timestamp-millis"},"default":0}), + serde_json::json!({"name":"d_ts_us","type":{"type":"long","logicalType":"local-timestamp-micros"},"default":0}), + serde_json::json!({"name":"d_decimal","type":{"type":"bytes","logicalType":"decimal","precision":10,"scale":2},"default":""}), + serde_json::json!({"name":"d_fixed","type":{"type":"fixed","name":"F4","size":4},"default":"ABCD"}), + serde_json::json!({"name":"d_enum","type":{"type":"enum","name":"E","symbols":["A","B","C"]},"default":"A"}), + serde_json::json!({"name":"d_duration","type":{"type":"fixed","name":"Dur","size":12,"logicalType":"duration"},"default":duration_default}), + serde_json::json!({"name":"d_uuid","type":{"type":"string","logicalType":"uuid"},"default":"00000000-0000-0000-0000-000000000000"}), + serde_json::json!({"name":"d_array","type":{"type":"array","items":"int"},"default":[1,2,3]}), + serde_json::json!({"name":"d_map","type":{"type":"map","values":"long"},"default":{"a":1,"b":2}}), + serde_json::json!({"name":"d_record","type":{ + "type":"record","name":"DefaultRec","fields":[ + {"name":"x","type":"int"}, + {"name":"y","type":["null","string"],"default":null} + ] + },"default":{"x":7}}), + serde_json::json!({"name":"d_nullable_null","type":["null","int"],"default":null}), + 
serde_json::json!({"name":"d_nullable_value","type":["int","null"],"default":123}), + ], + ); + let actual = read_alltypes_with_reader_schema(path, reader_schema); + let num_rows = actual.num_rows(); + assert!(num_rows > 0, "skippable_types.avro should contain rows"); + assert_eq!( + actual.num_columns(), + 22, + "expected exactly our defaulted fields" + ); + let mut arrays: Vec> = Vec::with_capacity(22); + arrays.push(Arc::new(BooleanArray::from_iter(std::iter::repeat_n( + Some(true), + num_rows, + )))); + arrays.push(Arc::new(Int32Array::from_iter_values(std::iter::repeat_n( + 42, num_rows, + )))); + arrays.push(Arc::new(Int64Array::from_iter_values(std::iter::repeat_n( + 12345, num_rows, + )))); + arrays.push(Arc::new(Float32Array::from_iter_values( + std::iter::repeat_n(1.5f32, num_rows), + ))); + arrays.push(Arc::new(Float64Array::from_iter_values( + std::iter::repeat_n(2.25f64, num_rows), + ))); + arrays.push(Arc::new(BinaryArray::from_iter_values( + std::iter::repeat_n(b"XYZ".as_ref(), num_rows), + ))); + arrays.push(Arc::new(StringArray::from_iter_values( + std::iter::repeat_n("hello", num_rows), + ))); + arrays.push(Arc::new(Date32Array::from_iter_values( + std::iter::repeat_n(0, num_rows), + ))); + arrays.push(Arc::new(Time32MillisecondArray::from_iter_values( + std::iter::repeat_n(1_000, num_rows), + ))); + arrays.push(Arc::new(Time64MicrosecondArray::from_iter_values( + std::iter::repeat_n(2_000i64, num_rows), + ))); + arrays.push(Arc::new(TimestampMillisecondArray::from_iter_values( + std::iter::repeat_n(0i64, num_rows), + ))); + arrays.push(Arc::new(TimestampMicrosecondArray::from_iter_values( + std::iter::repeat_n(0i64, num_rows), + ))); + #[cfg(feature = "small_decimals")] + let decimal = Decimal64Array::from_iter_values(std::iter::repeat_n(0i64, num_rows)) + .with_precision_and_scale(10, 2) + .unwrap(); + #[cfg(not(feature = "small_decimals"))] + let decimal = Decimal128Array::from_iter_values(std::iter::repeat_n(0i128, num_rows)) + .with_precision_and_scale(10, 2) + .unwrap(); + arrays.push(Arc::new(decimal)); + let fixed_iter = std::iter::repeat_n(Some(*b"ABCD"), num_rows); + arrays.push(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(fixed_iter, 4).unwrap(), + )); + let enum_keys = Int32Array::from_iter_values(std::iter::repeat_n(0, num_rows)); + let enum_values = StringArray::from_iter_values(["A", "B", "C"]); + let enum_arr = + DictionaryArray::::try_new(enum_keys, Arc::new(enum_values)).unwrap(); + arrays.push(Arc::new(enum_arr)); + let duration_values = std::iter::repeat_n( + Some(IntervalMonthDayNanoType::make_value(0, 0, 0)), + num_rows, + ); + let duration_arr: IntervalMonthDayNanoArray = duration_values.collect(); + arrays.push(Arc::new(duration_arr)); + let uuid_bytes = [0u8; 16]; + let uuid_iter = std::iter::repeat_n(Some(uuid_bytes), num_rows); + arrays.push(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(uuid_iter, 16).unwrap(), + )); + let item_field = Arc::new(Field::new( + Field::LIST_FIELD_DEFAULT_NAME, + DataType::Int32, + false, + )); + let mut list_builder = ListBuilder::new(Int32Builder::new()).with_field(item_field); + for _ in 0..num_rows { + list_builder.values().append_value(1); + list_builder.values().append_value(2); + list_builder.values().append_value(3); + list_builder.append(true); + } + arrays.push(Arc::new(list_builder.finish())); + let values_field = Arc::new(Field::new("value", DataType::Int64, false)); + let mut map_builder = MapBuilder::new( + Some(builder::MapFieldNames { + entry: 
"entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }), + StringBuilder::new(), + Int64Builder::new(), + ) + .with_values_field(values_field); + for _ in 0..num_rows { + let (keys, vals) = map_builder.entries(); + keys.append_value("a"); + vals.append_value(1); + keys.append_value("b"); + vals.append_value(2); + map_builder.append(true).unwrap(); + } + arrays.push(Arc::new(map_builder.finish())); + let rec_fields: Fields = Fields::from(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, true), + ]); + let mut sb = StructBuilder::new( + rec_fields.clone(), + vec![ + Box::new(Int32Builder::new()), + Box::new(StringBuilder::new()), + ], + ); + for _ in 0..num_rows { + sb.field_builder::(0).unwrap().append_value(7); + sb.field_builder::(1).unwrap().append_null(); + sb.append(true); + } + arrays.push(Arc::new(sb.finish())); + arrays.push(Arc::new(Int32Array::from_iter(std::iter::repeat_n( + None::, + num_rows, + )))); + arrays.push(Arc::new(Int32Array::from_iter_values(std::iter::repeat_n( + 123, num_rows, + )))); + let expected = RecordBatch::try_new(actual.schema(), arrays).unwrap(); + assert_eq!( + actual, expected, + "defaults should materialize correctly for all fields" + ); + } + + #[test] + fn test_schema_resolution_default_enum_invalid_symbol_errors() { + let path = "test/data/skippable_types.avro"; + let bad_schema = make_reader_schema_with_default_fields( + path, + vec![serde_json::json!({ + "name":"bad_enum", + "type":{"type":"enum","name":"E","symbols":["A","B","C"]}, + "default":"Z" + })], + ); + let file = File::open(path).unwrap(); + let res = ReaderBuilder::new() + .with_reader_schema(bad_schema) + .build(BufReader::new(file)); + let err = res.expect_err("expected enum default validation to fail"); + let msg = err.to_string(); + let lower_msg = msg.to_lowercase(); + assert!( + lower_msg.contains("enum") + && (lower_msg.contains("symbol") || lower_msg.contains("default")), + "unexpected error: {msg}" + ); + } + + #[test] + fn test_schema_resolution_default_fixed_size_mismatch_errors() { + let path = "test/data/skippable_types.avro"; + let bad_schema = make_reader_schema_with_default_fields( + path, + vec![serde_json::json!({ + "name":"bad_fixed", + "type":{"type":"fixed","name":"F","size":4}, + "default":"ABC" + })], + ); + let file = File::open(path).unwrap(); + let res = ReaderBuilder::new() + .with_reader_schema(bad_schema) + .build(BufReader::new(file)); + let err = res.expect_err("expected fixed default validation to fail"); + let msg = err.to_string(); + let lower_msg = msg.to_lowercase(); + assert!( + lower_msg.contains("fixed") + && (lower_msg.contains("size") + || lower_msg.contains("length") + || lower_msg.contains("does not match")), + "unexpected error: {msg}" + ); + } + + #[test] + fn test_alltypes_skip_writer_fields_keep_double_only() { + let file = arrow_test_data("avro/alltypes_plain.avro"); + let reader_schema = + make_reader_schema_with_selected_fields_in_order(&file, &["double_col"]); + let batch = read_alltypes_with_reader_schema(&file, reader_schema); + let expected = RecordBatch::try_from_iter_with_nullable([( + "double_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64 * 10.1), + )) as _, + true, + )]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_alltypes_skip_writer_fields_reorder_and_skip_many() { + let file = arrow_test_data("avro/alltypes_plain.avro"); + let reader_schema = + make_reader_schema_with_selected_fields_in_order(&file, 
&["timestamp_col", "id"]); + let batch = read_alltypes_with_reader_schema(&file, reader_schema); + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1235865600000000, // 2009-03-01T00:00:00.000 + 1235865660000000, // 2009-03-01T00:01:00.000 + 1238544000000000, // 2009-04-01T00:00:00.000 + 1238544060000000, // 2009-04-01T00:01:00.000 + 1233446400000000, // 2009-02-01T00:00:00.000 + 1233446460000000, // 2009-02-01T00:01:00.000 + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ( + "id", + Arc::new(Int32Array::from(vec![4, 5, 6, 7, 2, 3, 0, 1])) as _, + true, + ), + ]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_skippable_types_project_each_field_individually() { + let path = "test/data/skippable_types.avro"; + let full = read_file(path, 1024, false); + let schema_full = full.schema(); + let num_rows = full.num_rows(); + let writer_json = load_writer_schema_json(path); + assert_eq!( + writer_json["type"], "record", + "writer schema must be a record" + ); + let fields_json = writer_json + .get("fields") + .and_then(|f| f.as_array()) + .expect("record has fields"); + assert_eq!( + schema_full.fields().len(), + fields_json.len(), + "full read column count vs writer fields" + ); + for (idx, f) in fields_json.iter().enumerate() { + let name = f + .get("name") + .and_then(|n| n.as_str()) + .unwrap_or_else(|| panic!("field at index {idx} has no name")); + let reader_schema = make_reader_schema_with_selected_fields_in_order(path, &[name]); + let projected = read_alltypes_with_reader_schema(path, reader_schema); + assert_eq!( + projected.num_columns(), + 1, + "projected batch should contain exactly the selected column '{name}'" + ); + assert_eq!( + projected.num_rows(), + num_rows, + "row count mismatch for projected column '{name}'" + ); + let field = schema_full.field(idx).clone(); + let col = full.column(idx).clone(); + let expected = + RecordBatch::try_new(Arc::new(Schema::new(vec![field])), vec![col]).unwrap(); + // Equality means: (1) read the right column values; and (2) all other + // writer fields were skipped correctly for this projection (no misalignment). 
+ assert_eq!( + projected, expected, + "projected column '{name}' mismatch vs full read column" + ); + } + } + + #[test] + fn test_read_zero_byte_avro_file() { + let batch = read_file("test/data/zero_byte.avro", 3, false); + let schema = batch.schema(); + assert_eq!(schema.fields().len(), 1); + let field = schema.field(0); + assert_eq!(field.name(), "data"); + assert_eq!(field.data_type(), &DataType::Binary); + assert!(field.is_nullable()); + assert_eq!(batch.num_rows(), 3); + assert_eq!(batch.num_columns(), 1); let binary_array = batch .column(0) .as_any() @@ -1109,18 +2679,18 @@ mod test { let expected = RecordBatch::try_from_iter_with_nullable([( "foo", Arc::new(BinaryArray::from_iter_values(vec![ - b"\x00".as_ref(), - b"\x01".as_ref(), - b"\x02".as_ref(), - b"\x03".as_ref(), - b"\x04".as_ref(), - b"\x05".as_ref(), - b"\x06".as_ref(), - b"\x07".as_ref(), - b"\x08".as_ref(), - b"\t".as_ref(), - b"\n".as_ref(), - b"\x0b".as_ref(), + b"\x00" as &[u8], + b"\x01" as &[u8], + b"\x02" as &[u8], + b"\x03" as &[u8], + b"\x04" as &[u8], + b"\x05" as &[u8], + b"\x06" as &[u8], + b"\x07" as &[u8], + b"\x08" as &[u8], + b"\t" as &[u8], + b"\n" as &[u8], + b"\x0b" as &[u8], ])) as Arc, true, )]) @@ -1129,121 +2699,140 @@ mod test { } #[test] - fn test_decode_stream_with_schema() { - struct TestCase<'a> { - name: &'a str, - schema: &'a str, - expected_error: Option<&'a str>, - } - let tests = vec![ - TestCase { - name: "success", - schema: r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"string"}]}"#, - expected_error: None, - }, - TestCase { - name: "valid schema invalid data", - schema: r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"long"}]}"#, - expected_error: Some("did not consume all bytes"), - }, + fn test_decimal() { + // Choose expected Arrow types depending on the `small_decimals` feature flag. + // With `small_decimals` enabled, Decimal32/Decimal64 are used where their + // precision allows; otherwise, those cases resolve to Decimal128. 
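The comment above describes how the expected decimal type depends on the `small_decimals` feature: the reader picks the narrowest Arrow decimal whose maximum precision can hold the Avro decimal's precision. A minimal sketch of that selection, assuming an arrow-schema version with the Decimal32/Decimal64 variants (the per-width maxima are 9, 18, 38 and 76 digits):

use arrow_schema::{ArrowError, DataType};

/// Pick the expected Arrow decimal type for a given precision/scale,
/// mirroring the selection described above (a sketch, not the reader's code).
fn expected_decimal(precision: u8, scale: i8, small_decimals: bool) -> Result<DataType, ArrowError> {
    let dt = match precision {
        0..=9 if small_decimals => DataType::Decimal32(precision, scale),
        0..=18 if small_decimals => DataType::Decimal64(precision, scale),
        0..=38 => DataType::Decimal128(precision, scale),
        39..=76 => DataType::Decimal256(precision, scale),
        p => {
            return Err(ArrowError::ParseError(format!(
                "Decimal precision {p} exceeds maximum supported"
            )))
        }
    };
    Ok(dt)
}

fn main() {
    // int32_decimal.avro carries precision 4, scale 2.
    assert_eq!(expected_decimal(4, 2, true).unwrap(), DataType::Decimal32(4, 2));
    assert_eq!(expected_decimal(4, 2, false).unwrap(), DataType::Decimal128(4, 2));
    // fixed_length_decimal_legacy.avro carries precision 13, scale 2.
    assert_eq!(expected_decimal(13, 2, true).unwrap(), DataType::Decimal64(13, 2));
}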
+ #[cfg(feature = "small_decimals")] + let files: [(&str, DataType); 8] = [ + ( + "avro/fixed_length_decimal.avro", + DataType::Decimal128(25, 2), + ), + ( + "avro/fixed_length_decimal_legacy.avro", + DataType::Decimal64(13, 2), + ), + ("avro/int32_decimal.avro", DataType::Decimal32(4, 2)), + ("avro/int64_decimal.avro", DataType::Decimal64(10, 2)), + ( + "test/data/int256_decimal.avro", + DataType::Decimal256(76, 10), + ), + ( + "test/data/fixed256_decimal.avro", + DataType::Decimal256(76, 10), + ), + ( + "test/data/fixed_length_decimal_legacy_32.avro", + DataType::Decimal32(9, 2), + ), + ("test/data/int128_decimal.avro", DataType::Decimal128(38, 2)), ]; - for test in tests { - let avro_schema = AvroSchema::new(test.schema.to_string()); - let mut store = SchemaStore::new(); - let fp = store.register(avro_schema.clone()).unwrap(); - let prefix = make_prefix(fp); - let record_val = "some_string"; - let mut body = prefix; - body.push((record_val.len() as u8) << 1); - body.extend_from_slice(record_val.as_bytes()); - let decoder_res = ReaderBuilder::new() - .with_batch_size(1) - .with_writer_schema_store(store) - .with_active_fingerprint(fp) - .build_decoder(); - let decoder = match decoder_res { - Ok(d) => d, - Err(e) => { - if let Some(expected) = test.expected_error { - assert!( - e.to_string().contains(expected), - "Test '{}' failed at build – expected '{expected}', got '{e}'", - test.name - ); - continue; - } else { - panic!("Test '{}' failed during build: {e}", test.name); + #[cfg(not(feature = "small_decimals"))] + let files: [(&str, DataType); 8] = [ + ( + "avro/fixed_length_decimal.avro", + DataType::Decimal128(25, 2), + ), + ( + "avro/fixed_length_decimal_legacy.avro", + DataType::Decimal128(13, 2), + ), + ("avro/int32_decimal.avro", DataType::Decimal128(4, 2)), + ("avro/int64_decimal.avro", DataType::Decimal128(10, 2)), + ( + "test/data/int256_decimal.avro", + DataType::Decimal256(76, 10), + ), + ( + "test/data/fixed256_decimal.avro", + DataType::Decimal256(76, 10), + ), + ( + "test/data/fixed_length_decimal_legacy_32.avro", + DataType::Decimal128(9, 2), + ), + ("test/data/int128_decimal.avro", DataType::Decimal128(38, 2)), + ]; + for (file, expected_dt) in files { + let (precision, scale) = match expected_dt { + DataType::Decimal32(p, s) + | DataType::Decimal64(p, s) + | DataType::Decimal128(p, s) + | DataType::Decimal256(p, s) => (p, s), + _ => unreachable!("Unexpected decimal type in test inputs"), + }; + assert!(scale >= 0, "test data uses non-negative scales only"); + let scale_u32 = scale as u32; + let file_path: String = if file.starts_with("avro/") { + arrow_test_data(file) + } else { + std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join(file) + .to_string_lossy() + .into_owned() + }; + let pow10: i128 = 10i128.pow(scale_u32); + let values_i128: Vec = (1..=24).map(|n| (n as i128) * pow10).collect(); + let build_expected = |dt: &DataType, values: &[i128]| -> ArrayRef { + match *dt { + #[cfg(feature = "small_decimals")] + DataType::Decimal32(p, s) => { + let it = values.iter().map(|&v| v as i32); + Arc::new( + Decimal32Array::from_iter_values(it) + .with_precision_and_scale(p, s) + .unwrap(), + ) + } + #[cfg(feature = "small_decimals")] + DataType::Decimal64(p, s) => { + let it = values.iter().map(|&v| v as i64); + Arc::new( + Decimal64Array::from_iter_values(it) + .with_precision_and_scale(p, s) + .unwrap(), + ) + } + DataType::Decimal128(p, s) => { + let it = values.iter().copied(); + Arc::new( + Decimal128Array::from_iter_values(it) + .with_precision_and_scale(p, s) 
+ .unwrap(), + ) + } + DataType::Decimal256(p, s) => { + let it = values.iter().map(|&v| i256::from_i128(v)); + Arc::new( + Decimal256Array::from_iter_values(it) + .with_precision_and_scale(p, s) + .unwrap(), + ) } + _ => unreachable!("Unexpected decimal type in test"), } }; - let stream = Box::pin(stream::once(async { Bytes::from(body) })); - let decoded_stream = decode_stream(decoder, stream); - let batches_result: Result, ArrowError> = - block_on(decoded_stream.try_collect()); - match (batches_result, test.expected_error) { - (Ok(batches), None) => { - let batch = - arrow::compute::concat_batches(&batches[0].schema(), &batches).unwrap(); - let expected_field = Field::new("f2", DataType::Utf8, false); - let expected_schema = Arc::new(Schema::new(vec![expected_field])); - let expected_array = Arc::new(StringArray::from(vec![record_val])); - let expected_batch = - RecordBatch::try_new(expected_schema, vec![expected_array]).unwrap(); - assert_eq!(batch, expected_batch, "Test '{}'", test.name); - } - (Err(e), Some(expected)) => { - assert!( - e.to_string().contains(expected), - "Test '{}' – expected error containing '{expected}', got '{e}'", - test.name - ); - } - (Ok(_), Some(expected)) => { - panic!( - "Test '{}' expected failure ('{expected}') but succeeded", - test.name - ); - } - (Err(e), None) => { - panic!("Test '{}' unexpectedly failed with '{e}'", test.name); - } - } - } - } - - #[test] - fn test_decimal() { - let files = [ - ("avro/fixed_length_decimal.avro", 25, 2), - ("avro/fixed_length_decimal_legacy.avro", 13, 2), - ("avro/int32_decimal.avro", 4, 2), - ("avro/int64_decimal.avro", 10, 2), - ]; - let decimal_values: Vec = (1..=24).map(|n| n as i128 * 100).collect(); - for (file, precision, scale) in files { - let file_path = arrow_test_data(file); let actual_batch = read_file(&file_path, 8, false); - let expected_array = Decimal128Array::from_iter_values(decimal_values.clone()) - .with_precision_and_scale(precision, scale) - .unwrap(); + let actual_nullable = actual_batch.schema().field(0).is_nullable(); + let expected_array = build_expected(&expected_dt, &values_i128); let mut meta = HashMap::new(); meta.insert("precision".to_string(), precision.to_string()); meta.insert("scale".to_string(), scale.to_string()); - let field_with_meta = Field::new("value", DataType::Decimal128(precision, scale), true) - .with_metadata(meta); - let expected_schema = Arc::new(Schema::new(vec![field_with_meta])); + let field = + Field::new("value", expected_dt.clone(), actual_nullable).with_metadata(meta); + let expected_schema = Arc::new(Schema::new(vec![field])); let expected_batch = - RecordBatch::try_new(expected_schema.clone(), vec![Arc::new(expected_array)]) - .expect("Failed to build expected RecordBatch"); + RecordBatch::try_new(expected_schema.clone(), vec![expected_array]).unwrap(); assert_eq!( actual_batch, expected_batch, - "Decoded RecordBatch does not match the expected Decimal128 data for file {file}" + "Decoded RecordBatch does not match for {file}" ); let actual_batch_small = read_file(&file_path, 3, false); assert_eq!( - actual_batch_small, - expected_batch, - "Decoded RecordBatch does not match the expected Decimal128 data for file {file} with batch size 3" + actual_batch_small, expected_batch, + "Decoded RecordBatch does not match for {file} with batch size 3" ); } } @@ -1420,19 +3009,19 @@ mod test { DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); let mut md_f1 = HashMap::new(); md_f1.insert( - "avro.enum.symbols".to_string(), + 
AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), r#"["a","b","c","d"]"#.to_string(), ); let f1_field = Field::new("f1", dict_type.clone(), false).with_metadata(md_f1); let mut md_f2 = HashMap::new(); md_f2.insert( - "avro.enum.symbols".to_string(), + AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), r#"["e","f","g","h"]"#.to_string(), ); let f2_field = Field::new("f2", dict_type.clone(), false).with_metadata(md_f2); let mut md_f3 = HashMap::new(); md_f3.insert( - "avro.enum.symbols".to_string(), + AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), r#"["i","j","k"]"#.to_string(), ); let f3_field = Field::new("f3", dict_type.clone(), true).with_metadata(md_f3); diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 180afcd2d8c3..9ca8acb45b34 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -15,30 +15,74 @@ // specific language governing permissions and limitations // under the License. -use crate::codec::{AvroDataType, Codec, Nullability}; +use crate::codec::{ + AvroDataType, AvroField, AvroLiteral, Codec, Promotion, ResolutionInfo, ResolvedRecord, +}; use crate::reader::block::{Block, BlockDecoder}; use crate::reader::cursor::AvroCursor; -use crate::reader::header::Header; -use crate::schema::*; +use crate::schema::Nullability; use arrow_array::builder::{ - ArrayBuilder, Decimal128Builder, Decimal256Builder, IntervalMonthDayNanoBuilder, - PrimitiveBuilder, + Decimal128Builder, Decimal256Builder, IntervalMonthDayNanoBuilder, StringViewBuilder, }; +#[cfg(feature = "small_decimals")] +use arrow_array::builder::{Decimal32Builder, Decimal64Builder}; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::*; use arrow_schema::{ - ArrowError, DataType, Field as ArrowField, FieldRef, Fields, IntervalUnit, - Schema as ArrowSchema, SchemaRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, + DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; +#[cfg(feature = "small_decimals")] +use arrow_schema::{DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION}; use std::cmp::Ordering; -use std::collections::HashMap; -use std::io::Read; use std::sync::Arc; use uuid::Uuid; const DEFAULT_CAPACITY: usize = 1024; +/// Macro to decode a decimal payload for a given width and integer type. +macro_rules! decode_decimal { + ($size:expr, $buf:expr, $builder:expr, $N:expr, $Int:ty) => {{ + let bytes = read_decimal_bytes_be::<{ $N }>($buf, $size)?; + $builder.append_value(<$Int>::from_be_bytes(bytes)); + }}; +} + +/// Macro to finish a decimal builder into an array with precision/scale and nulls. +macro_rules! flush_decimal { + ($builder:expr, $precision:expr, $scale:expr, $nulls:expr, $ArrayTy:ty) => {{ + let (_, vals, _) = $builder.finish().into_parts(); + let dec = <$ArrayTy>::new(vals, $nulls) + .with_precision_and_scale(*$precision as u8, $scale.unwrap_or(0) as i8) + .map_err(|e| ArrowError::ParseError(e.to_string()))?; + Arc::new(dec) as ArrayRef + }}; +} + +/// Macro to append a default decimal value from two's-complement big-endian bytes +/// into the corresponding decimal builder, with compile-time constructed error text. +macro_rules! 
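The `decode_decimal!` macro above expands, roughly, to: read the payload (a fixed of the declared size when the decimal is backed by `fixed`, otherwise a length-prefixed `bytes` value), normalize it to the target width as big-endian two's-complement, and append the resulting integer. A macro-free sketch of the 16-byte (Decimal128) case, with a toy helper standing in for the module's `sign_cast_to`:

/// Sign-extend a big-endian two's-complement payload to 16 bytes
/// (a toy stand-in for the module's `sign_cast_to`).
fn extend_to_16(raw: &[u8]) -> [u8; 16] {
    assert!(raw.len() <= 16, "wider payloads are rejected by the real reader");
    let sign = if raw.first().map_or(false, |b| b & 0x80 != 0) { 0xFF } else { 0x00 };
    let mut out = [sign; 16];
    out[16 - raw.len()..].copy_from_slice(raw);
    out
}

/// Decode one Avro decimal payload into an unscaled i128. The payload comes either
/// from a `fixed` of the declared size or from a `bytes` value; both are big-endian
/// two's-complement integers.
fn decode_decimal128(payload: &[u8]) -> i128 {
    i128::from_be_bytes(extend_to_16(payload))
}

fn main() {
    // 0x04D2 is 1234, the unscaled form of 12.34 at scale 2.
    assert_eq!(decode_decimal128(&[0x04, 0xD2]), 1234);
    // Negative values keep their sign after extension.
    assert_eq!(decode_decimal128(&[0xFF, 0x85]), -123);
}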
append_decimal_default { + ($lit:expr, $builder:expr, $N:literal, $Int:ty, $name:literal) => {{ + match $lit { + AvroLiteral::Bytes(b) => { + let ext = sign_cast_to::<$N>(b)?; + let val = <$Int>::from_be_bytes(ext); + $builder.append_value(val); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + concat!( + "Default for ", + $name, + " must be bytes (two's-complement big-endian)" + ) + .to_string(), + )), + } + }}; +} + #[derive(Debug)] pub(crate) struct RecordDecoderBuilder<'a> { data_type: &'a AvroDataType, @@ -70,6 +114,7 @@ pub(crate) struct RecordDecoder { schema: SchemaRef, fields: Vec, use_utf8view: bool, + projector: Option, } impl RecordDecoder { @@ -92,8 +137,6 @@ impl RecordDecoder { /// # Arguments /// * `data_type` - The Avro data type to decode. /// * `use_utf8view` - A flag indicating whether to use `Utf8View` for string types. - /// * `strict_mode` - A flag to enable strict decoding, returning an error if the data - /// does not conform to the schema. /// /// # Errors /// This function will return an error if the provided `data_type` is not a `Record`. @@ -101,14 +144,30 @@ impl RecordDecoder { data_type: &AvroDataType, use_utf8view: bool, ) -> Result { - match Decoder::try_new(data_type)? { - Decoder::Record(fields, encodings) => Ok(Self { - schema: Arc::new(ArrowSchema::new(fields)), - fields: encodings, - use_utf8view, - }), - encoding => Err(ArrowError::ParseError(format!( - "Expected record got {encoding:?}" + match data_type.codec() { + Codec::Struct(reader_fields) => { + // Build Arrow schema fields and per-child decoders + let mut arrow_fields = Vec::with_capacity(reader_fields.len()); + let mut encodings = Vec::with_capacity(reader_fields.len()); + for avro_field in reader_fields.iter() { + arrow_fields.push(avro_field.field()); + encodings.push(Decoder::try_new(avro_field.data_type())?); + } + let projector = match data_type.resolution.as_ref() { + Some(ResolutionInfo::Record(rec)) => { + Some(ProjectorBuilder::try_new(rec, reader_fields).build()?) 
+ } + _ => None, + }; + Ok(Self { + schema: Arc::new(ArrowSchema::new(arrow_fields)), + fields: encodings, + use_utf8view, + projector, + }) + } + other => Err(ArrowError::ParseError(format!( + "Expected record got {other:?}" ))), } } @@ -121,9 +180,18 @@ impl RecordDecoder { /// Decode `count` records from `buf` pub(crate) fn decode(&mut self, buf: &[u8], count: usize) -> Result { let mut cursor = AvroCursor::new(buf); - for _ in 0..count { - for field in &mut self.fields { - field.decode(&mut cursor)?; + match self.projector.as_mut() { + Some(proj) => { + for _ in 0..count { + proj.project_record(&mut cursor, &mut self.fields)?; + } + } + None => { + for _ in 0..count { + for field in &mut self.fields { + field.decode(&mut cursor)?; + } + } } } Ok(cursor.position()) @@ -136,11 +204,16 @@ impl RecordDecoder { .iter_mut() .map(|x| x.flush(None)) .collect::, _>>()?; - RecordBatch::try_new(self.schema.clone(), arrays) } } +#[derive(Debug)] +struct EnumResolution { + mapping: Arc<[i32]>, + default_index: i32, +} + #[derive(Debug)] enum Decoder { Null(usize), @@ -154,13 +227,21 @@ enum Decoder { TimeMicros(Vec), TimestampMillis(bool, Vec), TimestampMicros(bool, Vec), + Int32ToInt64(Vec), + Int32ToFloat32(Vec), + Int32ToFloat64(Vec), + Int64ToFloat32(Vec), + Int64ToFloat64(Vec), + Float32ToFloat64(Vec), + BytesToString(OffsetBufferBuilder, Vec), + StringToBytes(OffsetBufferBuilder, Vec), Binary(OffsetBufferBuilder, Vec), /// String data encoded as UTF-8 bytes, mapped to Arrow's StringArray String(OffsetBufferBuilder, Vec), /// String data encoded as UTF-8 bytes, but mapped to Arrow's StringViewArray StringView(OffsetBufferBuilder, Vec), Array(FieldRef, OffsetBufferBuilder, Box), - Record(Fields, Vec), + Record(Fields, Vec, Option), Map( FieldRef, OffsetBufferBuilder, @@ -169,9 +250,13 @@ enum Decoder { Box, ), Fixed(i32, Vec), - Enum(Vec, Arc<[String]>), + Enum(Vec, Arc<[String]>, Option), Duration(IntervalMonthDayNanoBuilder), Uuid(Vec), + #[cfg(feature = "small_decimals")] + Decimal32(usize, Option, Option, Decimal32Builder), + #[cfg(feature = "small_decimals")] + Decimal64(usize, Option, Option, Decimal64Builder), Decimal128(usize, Option, Option, Decimal128Builder), Decimal256(usize, Option, Option, Decimal256Builder), Nullable(Nullability, NullBufferBuilder, Box), @@ -179,76 +264,115 @@ enum Decoder { impl Decoder { fn try_new(data_type: &AvroDataType) -> Result { - let decoder = match data_type.codec() { - Codec::Null => Self::Null(0), - Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), - Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Int64 => Self::Int64(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Float32 => Self::Float32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Float64 => Self::Float64(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Binary => Self::Binary( + // Extract just the Promotion (if any) to simplify pattern matching + let promotion = match data_type.resolution.as_ref() { + Some(ResolutionInfo::Promotion(p)) => Some(p), + _ => None, + }; + let decoder = match (data_type.codec(), promotion) { + (Codec::Int64, Some(Promotion::IntToLong)) => { + Self::Int32ToInt64(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Float32, Some(Promotion::IntToFloat)) => { + Self::Int32ToFloat32(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Float64, Some(Promotion::IntToDouble)) => { + Self::Int32ToFloat64(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Float32, Some(Promotion::LongToFloat)) => { + 
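Each promotion arm above pairs the reader's codec with the writer-to-reader promotion recorded during schema resolution, so values are widened as they are read; Avro permits int to long/float/double, long to float/double, float to double, and bytes/string in either direction. A schematic model of the numeric widening itself (toy enums, not the crate's types), assuming the writer value has already been decoded:

/// Writer-side primitive value as decoded from the Avro binary encoding.
enum WriterValue {
    Int(i32),
    Long(i64),
    Float(f32),
}

/// Reader-side value after applying an Avro promotion (a sketch; the actual
/// decoder widens directly into its per-column buffers instead).
enum ReaderValue {
    Long(i64),
    Float(f32),
    Double(f64),
}

fn promote(v: WriterValue, reader: &str) -> Option<ReaderValue> {
    // Avro schema resolution: int -> long/float/double, long -> float/double, float -> double.
    Some(match (v, reader) {
        (WriterValue::Int(i), "long") => ReaderValue::Long(i as i64),
        (WriterValue::Int(i), "float") => ReaderValue::Float(i as f32),
        (WriterValue::Int(i), "double") => ReaderValue::Double(i as f64),
        (WriterValue::Long(l), "float") => ReaderValue::Float(l as f32),
        (WriterValue::Long(l), "double") => ReaderValue::Double(l as f64),
        (WriterValue::Float(f), "double") => ReaderValue::Double(f as f64),
        _ => return None, // not a permitted promotion
    })
}

fn main() {
    assert!(matches!(promote(WriterValue::Int(42), "long"), Some(ReaderValue::Long(42))));
    assert!(matches!(promote(WriterValue::Float(1.5), "double"), Some(ReaderValue::Double(d)) if d == 1.5));
    assert!(promote(WriterValue::Float(1.5), "long").is_none()); // narrowing is never allowed
}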
Self::Int64ToFloat32(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Float64, Some(Promotion::LongToDouble)) => { + Self::Int64ToFloat64(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Float64, Some(Promotion::FloatToDouble)) => { + Self::Float32ToFloat64(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Utf8, Some(Promotion::BytesToString)) + | (Codec::Utf8View, Some(Promotion::BytesToString)) => Self::BytesToString( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ), + (Codec::Binary, Some(Promotion::StringToBytes)) => Self::StringToBytes( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Utf8 => Self::String( + (Codec::Null, _) => Self::Null(0), + (Codec::Boolean, _) => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), + (Codec::Int32, _) => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Int64, _) => Self::Int64(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Float32, _) => Self::Float32(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Float64, _) => Self::Float64(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Binary, _) => Self::Binary( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Utf8View => Self::StringView( + (Codec::Utf8, _) => Self::String( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimestampMillis(is_utc) => { + (Codec::Utf8View, _) => Self::StringView( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ), + (Codec::Date32, _) => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::TimeMillis, _) => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::TimeMicros, _) => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::TimestampMillis(is_utc), _) => { Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } - Codec::TimestampMicros(is_utc) => { + (Codec::TimestampMicros(is_utc), _) => { Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } - Codec::Fixed(sz) => Self::Fixed(*sz, Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Decimal(precision, scale, size) => { + (Codec::Fixed(sz), _) => Self::Fixed(*sz, Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Decimal(precision, scale, size), _) => { let p = *precision; let s = *scale; - let sz = *size; let prec = p as u8; let scl = s.unwrap_or(0) as i8; - match (sz, p) { - (Some(fixed_size), _) if fixed_size <= 16 => { - let builder = - Decimal128Builder::new().with_precision_and_scale(prec, scl)?; - Self::Decimal128(p, s, sz, builder) - } - (Some(fixed_size), _) if fixed_size <= 32 => { - let builder = - Decimal256Builder::new().with_precision_and_scale(prec, scl)?; - Self::Decimal256(p, s, sz, builder) - } - (Some(fixed_size), _) => { + #[cfg(feature = "small_decimals")] + { + if p <= DECIMAL32_MAX_PRECISION as usize { + let builder = Decimal32Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal32(p, s, *size, builder) + } else if p <= DECIMAL64_MAX_PRECISION as usize { + let builder = Decimal64Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal64(p, s, *size, builder) + } else if p <= 
DECIMAL128_MAX_PRECISION as usize { + let builder = Decimal128Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal128(p, s, *size, builder) + } else if p <= DECIMAL256_MAX_PRECISION as usize { + let builder = Decimal256Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal256(p, s, *size, builder) + } else { return Err(ArrowError::ParseError(format!( - "Unsupported decimal size: {fixed_size:?}" + "Decimal precision {p} exceeds maximum supported" ))); } - (None, p) if p <= DECIMAL128_MAX_PRECISION as usize => { - let builder = - Decimal128Builder::new().with_precision_and_scale(prec, scl)?; - Self::Decimal128(p, s, sz, builder) - } - (None, p) if p <= DECIMAL256_MAX_PRECISION as usize => { - let builder = - Decimal256Builder::new().with_precision_and_scale(prec, scl)?; - Self::Decimal256(p, s, sz, builder) - } - (None, _) => { + } + #[cfg(not(feature = "small_decimals"))] + { + if p <= DECIMAL128_MAX_PRECISION as usize { + let builder = Decimal128Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal128(p, s, *size, builder) + } else if p <= DECIMAL256_MAX_PRECISION as usize { + let builder = Decimal256Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal256(p, s, *size, builder) + } else { return Err(ArrowError::ParseError(format!( "Decimal precision {p} exceeds maximum supported" ))); } } } - Codec::Interval => Self::Duration(IntervalMonthDayNanoBuilder::new()), - Codec::List(item) => { + (Codec::Interval, _) => Self::Duration(IntervalMonthDayNanoBuilder::new()), + (Codec::List(item), _) => { let decoder = Self::try_new(item)?; Self::Array( Arc::new(item.field_with_name("item")), @@ -256,10 +380,17 @@ impl Decoder { Box::new(decoder), ) } - Codec::Enum(symbols) => { - Self::Enum(Vec::with_capacity(DEFAULT_CAPACITY), symbols.clone()) + (Codec::Enum(symbols), _) => { + let res = match data_type.resolution.as_ref() { + Some(ResolutionInfo::EnumMapping(mapping)) => Some(EnumResolution { + mapping: mapping.mapping.clone(), + default_index: mapping.default_index, + }), + _ => None, + }; + Self::Enum(Vec::with_capacity(DEFAULT_CAPACITY), symbols.clone(), res) } - Codec::Struct(fields) => { + (Codec::Struct(fields), _) => { let mut arrow_fields = Vec::with_capacity(fields.len()); let mut encodings = Vec::with_capacity(fields.len()); for avro_field in fields.iter() { @@ -267,10 +398,16 @@ impl Decoder { arrow_fields.push(avro_field.field()); encodings.push(encoding); } - Self::Record(arrow_fields.into(), encodings) + let projector = + if let Some(ResolutionInfo::Record(rec)) = data_type.resolution.as_ref() { + Some(ProjectorBuilder::try_new(rec, fields).build()?) 
+ } else { + None + }; + Self::Record(arrow_fields.into(), encodings, projector) } - Codec::Map(child) => { - let val_field = child.field_with_name("value").with_nullable(true); + (Codec::Map(child), _) => { + let val_field = child.field_with_name("value"); let map_field = Arc::new(ArrowField::new( "entries", DataType::Struct(Fields::from(vec![ @@ -288,7 +425,7 @@ impl Decoder { Box::new(val_dec), ) } - Codec::Uuid => Self::Uuid(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Uuid, _) => Self::Uuid(Vec::with_capacity(DEFAULT_CAPACITY)), }; Ok(match data_type.nullability() { Some(nullability) => Self::Nullable( @@ -307,12 +444,20 @@ impl Decoder { Self::Boolean(b) => b.append(false), Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0), Self::Int64(v) + | Self::Int32ToInt64(v) | Self::TimeMicros(v) | Self::TimestampMillis(_, v) | Self::TimestampMicros(_, v) => v.push(0), - Self::Float32(v) => v.push(0.), - Self::Float64(v) => v.push(0.), - Self::Binary(offsets, _) | Self::String(offsets, _) | Self::StringView(offsets, _) => { + Self::Float32(v) | Self::Int32ToFloat32(v) | Self::Int64ToFloat32(v) => v.push(0.), + Self::Float64(v) + | Self::Int32ToFloat64(v) + | Self::Int64ToFloat64(v) + | Self::Float32ToFloat64(v) => v.push(0.), + Self::Binary(offsets, _) + | Self::String(offsets, _) + | Self::StringView(offsets, _) + | Self::BytesToString(offsets, _) + | Self::StringToBytes(offsets, _) => { offsets.push_length(0); } Self::Uuid(v) => { @@ -321,16 +466,20 @@ impl Decoder { Self::Array(_, offsets, e) => { offsets.push_length(0); } - Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()), + Self::Record(_, e, _) => e.iter_mut().for_each(|e| e.append_null()), Self::Map(_, _koff, moff, _, _) => { moff.push_length(0); } Self::Fixed(sz, accum) => { accum.extend(std::iter::repeat_n(0u8, *sz as usize)); } + #[cfg(feature = "small_decimals")] + Self::Decimal32(_, _, _, builder) => builder.append_value(0), + #[cfg(feature = "small_decimals")] + Self::Decimal64(_, _, _, builder) => builder.append_value(0), Self::Decimal128(_, _, _, builder) => builder.append_value(0), Self::Decimal256(_, _, _, builder) => builder.append_value(i256::ZERO), - Self::Enum(indices, _) => indices.push(0), + Self::Enum(indices, _, _) => indices.push(0), Self::Duration(builder) => builder.append_null(), Self::Nullable(_, null_buffer, inner) => { null_buffer.append(false); @@ -339,6 +488,244 @@ impl Decoder { } } + /// Append a single default literal into the decoder's buffers + fn append_default(&mut self, lit: &AvroLiteral) -> Result<(), ArrowError> { + match self { + Self::Nullable(_, nb, inner) => { + if matches!(lit, AvroLiteral::Null) { + nb.append(false); + inner.append_null(); + Ok(()) + } else { + nb.append(true); + inner.append_default(lit) + } + } + Self::Null(count) => match lit { + AvroLiteral::Null => { + *count += 1; + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Non-null default for null type".to_string(), + )), + }, + Self::Boolean(b) => match lit { + AvroLiteral::Boolean(v) => { + b.append(*v); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for boolean must be boolean".to_string(), + )), + }, + Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => match lit { + AvroLiteral::Int(i) => { + v.push(*i); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for int32/date32/time-millis must be int".to_string(), + )), + }, + Self::Int64(v) + | Self::Int32ToInt64(v) + | Self::TimeMicros(v) + | Self::TimestampMillis(_, v) + | 
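The `append_default` arms here turn a reader-schema default literal into one slot of the corresponding builder; for a nullable column the rule is that a `null` default clears the validity bit and appends the inner type's placeholder, while any other default sets the bit and is forwarded to the inner builder. A standalone sketch of that rule for a nullable int column (toy types, not the crate's decoder):

/// A toy nullable-int column: validity bits plus values (placeholder 0 for nulls),
/// mirroring how the decoder pairs a NullBufferBuilder with an inner buffer.
#[derive(Default)]
struct NullableIntColumn {
    validity: Vec<bool>,
    values: Vec<i32>,
}

enum Literal {
    Null,
    Int(i32),
}

impl NullableIntColumn {
    fn append_default(&mut self, lit: &Literal) {
        match lit {
            // Null default: mark the slot null and keep the value buffer aligned.
            Literal::Null => {
                self.validity.push(false);
                self.values.push(0);
            }
            // Non-null default: mark valid and append the literal value.
            Literal::Int(v) => {
                self.validity.push(true);
                self.values.push(*v);
            }
        }
    }
}

fn main() {
    let mut col = NullableIntColumn::default();
    col.append_default(&Literal::Null);     // like d_nullable_null with "default": null
    col.append_default(&Literal::Int(123)); // like d_nullable_value with "default": 123
    assert_eq!(col.validity, [false, true]);
    assert_eq!(col.values, [0, 123]);
}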
Self::TimestampMicros(_, v) => match lit { + AvroLiteral::Long(i) => { + v.push(*i); + Ok(()) + } + AvroLiteral::Int(i) => { + v.push(*i as i64); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for long/time-micros/timestamp must be long or int".to_string(), + )), + }, + Self::Float32(v) | Self::Int32ToFloat32(v) | Self::Int64ToFloat32(v) => match lit { + AvroLiteral::Float(f) => { + v.push(*f); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for float must be float".to_string(), + )), + }, + Self::Float64(v) + | Self::Int32ToFloat64(v) + | Self::Int64ToFloat64(v) + | Self::Float32ToFloat64(v) => match lit { + AvroLiteral::Double(f) => { + v.push(*f); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for double must be double".to_string(), + )), + }, + Self::Binary(offsets, values) | Self::StringToBytes(offsets, values) => match lit { + AvroLiteral::Bytes(b) => { + offsets.push_length(b.len()); + values.extend_from_slice(b); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for bytes must be bytes".to_string(), + )), + }, + Self::BytesToString(offsets, values) + | Self::String(offsets, values) + | Self::StringView(offsets, values) => match lit { + AvroLiteral::String(s) => { + let b = s.as_bytes(); + offsets.push_length(b.len()); + values.extend_from_slice(b); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for string must be string".to_string(), + )), + }, + Self::Uuid(values) => match lit { + AvroLiteral::String(s) => { + let uuid = Uuid::try_parse(s).map_err(|e| { + ArrowError::InvalidArgumentError(format!("Invalid UUID default: {s} ({e})")) + })?; + values.extend_from_slice(uuid.as_bytes()); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for uuid must be string".to_string(), + )), + }, + Self::Fixed(sz, accum) => match lit { + AvroLiteral::Bytes(b) => { + if b.len() != *sz as usize { + return Err(ArrowError::InvalidArgumentError(format!( + "Fixed default length {} does not match size {sz}", + b.len(), + ))); + } + accum.extend_from_slice(b); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for fixed must be bytes".to_string(), + )), + }, + #[cfg(feature = "small_decimals")] + Self::Decimal32(_, _, _, builder) => { + append_decimal_default!(lit, builder, 4, i32, "decimal32") + } + #[cfg(feature = "small_decimals")] + Self::Decimal64(_, _, _, builder) => { + append_decimal_default!(lit, builder, 8, i64, "decimal64") + } + Self::Decimal128(_, _, _, builder) => { + append_decimal_default!(lit, builder, 16, i128, "decimal128") + } + Self::Decimal256(_, _, _, builder) => { + append_decimal_default!(lit, builder, 32, i256, "decimal256") + } + Self::Duration(builder) => match lit { + AvroLiteral::Bytes(b) => { + if b.len() != 12 { + return Err(ArrowError::InvalidArgumentError(format!( + "Duration default must be exactly 12 bytes, got {}", + b.len() + ))); + } + let months = u32::from_le_bytes([b[0], b[1], b[2], b[3]]); + let days = u32::from_le_bytes([b[4], b[5], b[6], b[7]]); + let millis = u32::from_le_bytes([b[8], b[9], b[10], b[11]]); + let nanos = (millis as i64) * 1_000_000; + builder.append_value(IntervalMonthDayNano::new( + months as i32, + days as i32, + nanos, + )); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for duration must be 12-byte little-endian months/days/millis" + .to_string(), + )), + }, + Self::Array(_, offsets, inner) => match lit { + AvroLiteral::Array(items) => { + offsets.push_length(items.len()); + for 
item in items { + inner.append_default(item)?; + } + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for array must be an array literal".to_string(), + )), + }, + Self::Map(_, koff, moff, kdata, valdec) => match lit { + AvroLiteral::Map(entries) => { + moff.push_length(entries.len()); + for (k, v) in entries { + let kb = k.as_bytes(); + koff.push_length(kb.len()); + kdata.extend_from_slice(kb); + valdec.append_default(v)?; + } + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for map must be a map/object literal".to_string(), + )), + }, + Self::Enum(indices, symbols, _) => match lit { + AvroLiteral::Enum(sym) => { + let pos = symbols.iter().position(|s| s == sym).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Enum default symbol {sym:?} not in reader symbols" + )) + })?; + indices.push(pos as i32); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for enum must be a symbol".to_string(), + )), + }, + Self::Record(field_meta, decoders, projector) => match lit { + AvroLiteral::Map(entries) => { + for (i, dec) in decoders.iter_mut().enumerate() { + let name = field_meta[i].name(); + if let Some(sub) = entries.get(name) { + dec.append_default(sub)?; + } else if let Some(proj) = projector.as_ref() { + proj.project_default(dec, i)?; + } else { + dec.append_null(); + } + } + Ok(()) + } + AvroLiteral::Null => { + for (i, dec) in decoders.iter_mut().enumerate() { + if let Some(proj) = projector.as_ref() { + proj.project_default(dec, i)?; + } else { + dec.append_null(); + } + } + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for record must be a map/object or null".to_string(), + )), + }, + } + } + /// Decode a single record from `buf` fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { match self { @@ -353,7 +740,15 @@ impl Decoder { | Self::TimestampMicros(_, values) => values.push(buf.get_long()?), Self::Float32(values) => values.push(buf.get_float()?), Self::Float64(values) => values.push(buf.get_double()?), - Self::Binary(offsets, values) + Self::Int32ToInt64(values) => values.push(buf.get_int()? as i64), + Self::Int32ToFloat32(values) => values.push(buf.get_int()? as f32), + Self::Int32ToFloat64(values) => values.push(buf.get_int()? as f64), + Self::Int64ToFloat32(values) => values.push(buf.get_long()? as f32), + Self::Int64ToFloat64(values) => values.push(buf.get_long()? as f64), + Self::Float32ToFloat64(values) => values.push(buf.get_float()? 
as f64), + Self::StringToBytes(offsets, values) + | Self::BytesToString(offsets, values) + | Self::Binary(offsets, values) | Self::String(offsets, values) | Self::StringView(offsets, values) => { let data = buf.get_bytes()?; @@ -373,11 +768,14 @@ impl Decoder { let total_items = read_blocks(buf, |cursor| encoding.decode(cursor))?; off.push_length(total_items); } - Self::Record(_, encodings) => { + Self::Record(_, encodings, None) => { for encoding in encodings { encoding.decode(buf)?; } } + Self::Record(_, encodings, Some(proj)) => { + proj.project_record(buf, encodings)?; + } Self::Map(_, koff, moff, kdata, valdec) => { let newly_added = read_blocks(buf, |cur| { let kb = cur.get_bytes()?; @@ -391,29 +789,38 @@ impl Decoder { let fx = buf.get_fixed(*sz as usize)?; accum.extend_from_slice(fx); } + #[cfg(feature = "small_decimals")] + Self::Decimal32(_, _, size, builder) => { + decode_decimal!(size, buf, builder, 4, i32); + } + #[cfg(feature = "small_decimals")] + Self::Decimal64(_, _, size, builder) => { + decode_decimal!(size, buf, builder, 8, i64); + } Self::Decimal128(_, _, size, builder) => { - let raw = if let Some(s) = size { - buf.get_fixed(*s)? - } else { - buf.get_bytes()? - }; - let ext = sign_extend_to::<16>(raw)?; - let val = i128::from_be_bytes(ext); - builder.append_value(val); + decode_decimal!(size, buf, builder, 16, i128); } Self::Decimal256(_, _, size, builder) => { - let raw = if let Some(s) = size { - buf.get_fixed(*s)? - } else { - buf.get_bytes()? - }; - let ext = sign_extend_to::<32>(raw)?; - let val = i256::from_be_bytes(ext); - builder.append_value(val); + decode_decimal!(size, buf, builder, 32, i256); } - Self::Enum(indices, _) => { + Self::Enum(indices, _, None) => { indices.push(buf.get_int()?); } + Self::Enum(indices, _, Some(res)) => { + let raw = buf.get_int()?; + let resolved = usize::try_from(raw) + .ok() + .and_then(|idx| res.mapping.get(idx).copied()) + .filter(|&idx| idx >= 0) + .unwrap_or(res.default_index); + if resolved >= 0 { + indices.push(resolved); + } else { + return Err(ArrowError::ParseError(format!( + "Enum symbol index {raw} not resolvable and no default provided", + ))); + } + } Self::Duration(builder) => { let b = buf.get_fixed(12)?; let months = u32::from_le_bytes(b[0..4].try_into().unwrap()); @@ -428,12 +835,13 @@ impl Decoder { Nullability::NullFirst => branch != 0, Nullability::NullSecond => branch == 0, }; - nb.append(is_not_null); if is_not_null { + // It is important to decode before appending to null buffer in case of decode error encoding.decode(buf)?; } else { encoding.append_null(); } + nb.append(is_not_null); } } Ok(()) @@ -464,12 +872,21 @@ impl Decoder { ), Self::Float32(values) => Arc::new(flush_primitive::(values, nulls)), Self::Float64(values) => Arc::new(flush_primitive::(values, nulls)), - Self::Binary(offsets, values) => { + Self::Int32ToInt64(values) => Arc::new(flush_primitive::(values, nulls)), + Self::Int32ToFloat32(values) | Self::Int64ToFloat32(values) => { + Arc::new(flush_primitive::(values, nulls)) + } + Self::Int32ToFloat64(values) + | Self::Int64ToFloat64(values) + | Self::Float32ToFloat64(values) => { + Arc::new(flush_primitive::(values, nulls)) + } + Self::StringToBytes(offsets, values) | Self::Binary(offsets, values) => { let offsets = flush_offsets(offsets); let values = flush_values(values).into(); Arc::new(BinaryArray::new(offsets, values, nulls)) } - Self::String(offsets, values) => { + Self::BytesToString(offsets, values) | Self::String(offsets, values) => { let offsets = flush_offsets(offsets); let 
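A nullable Avro value is a two-branch union: the encoding is a zigzag-encoded branch index followed by the value for that branch, and the null branch carries no payload. The arm above decodes the value before pushing the validity bit so that a decode error cannot leave the null buffer one entry longer than the value buffer. A minimal sketch of the branch handling, using a simplified single-byte varint and illustrative names:

/// Decode a zigzag-encoded value from a single-byte varint (enough for this sketch).
fn read_zigzag(data: &[u8], pos: &mut usize) -> i64 {
    let n = data[*pos] as i64;
    *pos += 1;
    (n >> 1) ^ -(n & 1)
}

/// Decode an optional int from a `["null", "int"]` union (null is branch 0 here).
fn read_nullable_int(data: &[u8], pos: &mut usize) -> Option<i32> {
    let branch = read_zigzag(data, pos); // the union branch index is written as an Avro int
    if branch == 0 {
        None // the null branch has no payload
    } else {
        Some(read_zigzag(data, pos) as i32) // then the value for the chosen branch
    }
}

fn main() {
    // Branch 1 (0x02 zigzag) followed by the int 7 (0x0E zigzag).
    let mut pos = 0;
    assert_eq!(read_nullable_int(&[0x02, 0x0E], &mut pos), Some(7));
    // Branch 0 (0x00) selects null and nothing else follows.
    let mut pos = 0;
    assert_eq!(read_nullable_int(&[0x00], &mut pos), None);
}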
values = flush_values(values).into(); Arc::new(StringArray::new(offsets, values, nulls)) @@ -494,7 +911,7 @@ impl Decoder { let offsets = flush_offsets(offsets); Arc::new(ListArray::new(field.clone(), offsets, values, nulls)) } - Self::Record(fields, encodings) => { + Self::Record(fields, encodings, _) => { let arrays = encodings .iter_mut() .map(|x| x.flush(None)) @@ -523,14 +940,16 @@ impl Decoder { ))); } } - let entries_struct = StructArray::new( - Fields::from(vec![ - Arc::new(ArrowField::new("key", DataType::Utf8, false)), - Arc::new(ArrowField::new("value", val_arr.data_type().clone(), true)), - ]), - vec![Arc::new(key_arr), val_arr], - None, - ); + let entries_fields = match map_field.data_type() { + DataType::Struct(fields) => fields.clone(), + other => { + return Err(ArrowError::InvalidArgumentError(format!( + "Map entries field must be a Struct, got {other:?}" + ))) + } + }; + let entries_struct = + StructArray::new(entries_fields, vec![Arc::new(key_arr), val_arr], None); let map_arr = MapArray::new(map_field.clone(), moff, entries_struct, nulls, false); Arc::new(map_arr) } @@ -545,29 +964,21 @@ impl Decoder { .map_err(|e| ArrowError::ParseError(e.to_string()))?; Arc::new(arr) } + #[cfg(feature = "small_decimals")] + Self::Decimal32(precision, scale, _, builder) => { + flush_decimal!(builder, precision, scale, nulls, Decimal32Array) + } + #[cfg(feature = "small_decimals")] + Self::Decimal64(precision, scale, _, builder) => { + flush_decimal!(builder, precision, scale, nulls, Decimal64Array) + } Self::Decimal128(precision, scale, _, builder) => { - let (_, vals, _) = builder.finish().into_parts(); - let scl = scale.unwrap_or(0); - let dec = Decimal128Array::new(vals, nulls) - .with_precision_and_scale(*precision as u8, scl as i8) - .map_err(|e| ArrowError::ParseError(e.to_string()))?; - Arc::new(dec) + flush_decimal!(builder, precision, scale, nulls, Decimal128Array) } Self::Decimal256(precision, scale, _, builder) => { - let (_, vals, _) = builder.finish().into_parts(); - let scl = scale.unwrap_or(0); - let dec = Decimal256Array::new(vals, nulls) - .with_precision_and_scale(*precision as u8, scl as i8) - .map_err(|e| ArrowError::ParseError(e.to_string()))?; - Arc::new(dec) - } - Self::Enum(indices, symbols) => { - let keys = flush_primitive::(indices, nulls); - let values = Arc::new(StringArray::from( - symbols.iter().map(|s| s.as_str()).collect::>(), - )); - Arc::new(DictionaryArray::try_new(keys, values)?) 
+ flush_decimal!(builder, precision, scale, nulls, Decimal256Array) } + Self::Enum(indices, symbols, _) => flush_dict(indices, symbols, nulls)?, Self::Duration(builder) => { let (_, vals, _) = builder.finish().into_parts(); let vals = IntervalMonthDayNanoArray::try_new(vals, nulls) @@ -578,19 +989,52 @@ impl Decoder { } } +#[derive(Debug, Copy, Clone)] +enum NegativeBlockBehavior { + ProcessItems, + SkipBySize, +} + +#[inline] +fn skip_blocks( + buf: &mut AvroCursor, + mut skip_item: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, +) -> Result { + process_blockwise( + buf, + move |c| skip_item(c), + NegativeBlockBehavior::SkipBySize, + ) +} + +#[inline] +fn flush_dict( + indices: &mut Vec, + symbols: &[String], + nulls: Option, +) -> Result { + let keys = flush_primitive::(indices, nulls); + let values = Arc::new(StringArray::from_iter_values( + symbols.iter().map(|s| s.as_str()), + )); + DictionaryArray::try_new(keys, values) + .map_err(|e| ArrowError::ParseError(e.to_string())) + .map(|arr| Arc::new(arr) as ArrayRef) +} + #[inline] fn read_blocks( buf: &mut AvroCursor, decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, ) -> Result { - read_blockwise_items(buf, true, decode_entry) + process_blockwise(buf, decode_entry, NegativeBlockBehavior::ProcessItems) } #[inline] -fn read_blockwise_items( +fn process_blockwise( buf: &mut AvroCursor, - read_size_after_negative: bool, - mut decode_fn: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, + mut on_item: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, + negative_behavior: NegativeBlockBehavior, ) -> Result { let mut total = 0usize; loop { @@ -602,22 +1046,27 @@ fn read_blockwise_items( match block_count.cmp(&0) { Ordering::Equal => break, Ordering::Less => { - // If block_count is negative, read the absolute value of count, - // then read the block size as a long and discard let count = (-block_count) as usize; - if read_size_after_negative { - let _size_in_bytes = buf.get_long()?; - } - for _ in 0..count { - decode_fn(buf)?; + // A negative count is followed by a long of the size in bytes + let size_in_bytes = buf.get_long()? as usize; + match negative_behavior { + NegativeBlockBehavior::ProcessItems => { + // Process items one-by-one after reading size + for _ in 0..count { + on_item(buf)?; + } + } + NegativeBlockBehavior::SkipBySize => { + // Skip the entire block payload at once + let _ = buf.get_fixed(size_in_bytes)?; + } } total += count; } Ordering::Greater => { - // If block_count is positive, decode that many items let count = block_count as usize; - for _i in 0..count { - decode_fn(buf)?; + for _ in 0..count { + on_item(buf)?; } total += count; } @@ -644,38 +1093,350 @@ fn flush_primitive( PrimitiveArray::new(flush_values(values).into(), nulls) } -/// Sign extends a byte slice to a fixed-size array of N bytes. -/// This is done by filling the leading bytes with 0x00 for positive numbers -/// or 0xFF for negative numbers. #[inline] -fn sign_extend_to(raw: &[u8]) -> Result<[u8; N], ArrowError> { - if raw.len() > N { - return Err(ArrowError::ParseError(format!( - "Cannot extend a slice of length {} to {} bytes.", - raw.len(), - N - ))); - } - let mut arr = [0u8; N]; - let pad_len = N - raw.len(); - // Determine the byte to use for padding based on the sign bit of the raw data. 
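Avro arrays and maps are written as a series of blocks: a count, then that many items, terminated by a count of zero; a negative count means |count| items preceded by the block's byte size, which is what lets the skip-by-size path below jump over the payload without decoding each item. A small sketch of producing and walking such a block, using single-byte zigzag varints for brevity (names are illustrative):

/// Zigzag-encode a small long into one byte (sufficient for this sketch).
fn zz(v: i64) -> u8 {
    (((v << 1) ^ (v >> 63)) & 0x7F) as u8
}

fn unzz(b: u8) -> i64 {
    let n = b as i64;
    (n >> 1) ^ -(n & 1)
}

/// Encode an int array as one Avro block with a negative count:
/// -len, byte size of the payload, the items, then the terminating 0.
fn encode_array_block(items: &[i64]) -> Vec<u8> {
    let payload: Vec<u8> = items.iter().map(|&v| zz(v)).collect();
    let mut out = vec![zz(-(items.len() as i64)), zz(payload.len() as i64)];
    out.extend_from_slice(&payload);
    out.push(zz(0));
    out
}

/// Skip the array without decoding items, using the byte size that follows a
/// negative block count (mirrors the SkipBySize behavior above).
fn skip_array(data: &[u8], pos: &mut usize) {
    loop {
        let count = unzz(data[*pos]);
        *pos += 1;
        match count {
            0 => break,
            c if c < 0 => {
                let size = unzz(data[*pos]) as usize;
                *pos += 1;
                *pos += size; // jump over the whole block payload
            }
            _ => unimplemented!("positive counts require decoding each item"),
        }
    }
}

fn main() {
    let data = encode_array_block(&[1, 2, 3]);
    let mut pos = 0;
    skip_array(&data, &mut pos);
    assert_eq!(pos, data.len()); // the cursor lands exactly past the array
}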
- let extension_byte = if raw.is_empty() || (raw[0] & 0x80 == 0) { - 0x00 - } else { - 0xFF - }; - arr[..pad_len].fill(extension_byte); - arr[pad_len..].copy_from_slice(raw); - Ok(arr) +fn read_decimal_bytes_be( + buf: &mut AvroCursor<'_>, + size: &Option, +) -> Result<[u8; N], ArrowError> { + match size { + Some(n) if *n == N => { + let raw = buf.get_fixed(N)?; + let mut arr = [0u8; N]; + arr.copy_from_slice(raw); + Ok(arr) + } + Some(n) => { + let raw = buf.get_fixed(*n)?; + sign_cast_to::(raw) + } + None => { + let raw = buf.get_bytes()?; + sign_cast_to::(raw) + } + } +} + +/// Sign-extend or (when larger) validate-and-truncate a big-endian two's-complement +/// integer into exactly `N` bytes. This matches Avro's decimal binary encoding: +/// the payload is a big-endian two's-complement integer, and when narrowing it must +/// be representable without changing sign or value. +/// +/// If `raw.len() < N`, the value is sign-extended. +/// If `raw.len() > N`, all truncated leading bytes must match the sign-extension byte +/// and the MSB of the first kept byte must match the sign (to avoid silent overflow). +#[inline] +fn sign_cast_to(raw: &[u8]) -> Result<[u8; N], ArrowError> { + let len = raw.len(); + // Fast path: exact width, just copy + if len == N { + let mut out = [0u8; N]; + out.copy_from_slice(raw); + return Ok(out); + } + // Determine sign byte from MSB of first byte (empty => positive) + let first = raw.first().copied().unwrap_or(0u8); + let sign_byte = if (first & 0x80) == 0 { 0x00 } else { 0xFF }; + // Pre-fill with sign byte to support sign extension + let mut out = [sign_byte; N]; + if len > N { + // Validate truncation: all dropped leading bytes must equal sign_byte, + // and the MSB of the first kept byte must match the sign. + let extra = len - N; + // Any non-sign byte in the truncated prefix indicates overflow + if raw[..extra].iter().any(|&b| b != sign_byte) { + return Err(ArrowError::ParseError(format!( + "Decimal value with {} bytes cannot be represented in {} bytes without overflow", + len, N + ))); + } + if N > 0 { + let first_kept = raw[extra]; + let sign_bit_mismatch = ((first_kept ^ sign_byte) & 0x80) != 0; + if sign_bit_mismatch { + return Err(ArrowError::ParseError(format!( + "Decimal value with {} bytes cannot be represented in {} bytes without overflow", + len, N + ))); + } + } + out.copy_from_slice(&raw[extra..]); + return Ok(out); + } + out[N - len..].copy_from_slice(raw); + Ok(out) +} + +#[derive(Debug)] +struct Projector { + writer_to_reader: Arc<[Option]>, + skip_decoders: Vec>, + field_defaults: Vec>, + default_injections: Arc<[(usize, AvroLiteral)]>, +} + +#[derive(Debug)] +struct ProjectorBuilder<'a> { + rec: &'a ResolvedRecord, + reader_fields: Arc<[AvroField]>, +} + +impl<'a> ProjectorBuilder<'a> { + #[inline] + fn try_new(rec: &'a ResolvedRecord, reader_fields: &Arc<[AvroField]>) -> Self { + Self { + rec, + reader_fields: reader_fields.clone(), + } + } + + #[inline] + fn build(self) -> Result { + let reader_fields = self.reader_fields; + let mut field_defaults: Vec> = Vec::with_capacity(reader_fields.len()); + for avro_field in reader_fields.as_ref() { + if let Some(ResolutionInfo::DefaultValue(lit)) = + avro_field.data_type().resolution.as_ref() + { + field_defaults.push(Some(lit.clone())); + } else { + field_defaults.push(None); + } + } + let mut default_injections: Vec<(usize, AvroLiteral)> = + Vec::with_capacity(self.rec.default_fields.len()); + for &idx in self.rec.default_fields.as_ref() { + let lit = field_defaults + .get(idx) + 
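`sign_cast_to` both widens and narrows: widening fills the leading bytes with the sign byte, and narrowing is only legal when every dropped byte equals the sign byte and the most significant bit of the first kept byte still matches, otherwise the value would silently overflow. A small standalone demonstration of the same rules (not the crate's function):

/// Narrow or widen a big-endian two's-complement slice to N bytes, or report overflow.
/// This mirrors the rules documented above; it is a standalone sketch.
fn narrow<const N: usize>(raw: &[u8]) -> Result<[u8; N], String> {
    let sign = if raw.first().map_or(false, |b| b & 0x80 != 0) { 0xFF } else { 0x00 };
    let mut out = [sign; N];
    if raw.len() <= N {
        out[N - raw.len()..].copy_from_slice(raw); // pure sign extension
        return Ok(out);
    }
    let extra = raw.len() - N;
    if raw[..extra].iter().any(|&b| b != sign) || (raw[extra] ^ sign) & 0x80 != 0 {
        return Err(format!("{} bytes do not fit in {} without overflow", raw.len(), N));
    }
    out.copy_from_slice(&raw[extra..]);
    Ok(out)
}

fn main() {
    // -1 stored in one byte widens to -1 in four bytes.
    assert_eq!(i32::from_be_bytes(narrow::<4>(&[0xFF]).unwrap()), -1);
    // 0x00_00_01_00 narrows to two bytes: the dropped bytes are redundant sign bytes.
    assert_eq!(i16::from_be_bytes(narrow::<2>(&[0x00, 0x00, 0x01, 0x00]).unwrap()), 256);
    // 0x01_00_00 cannot fit in two bytes: a dropped byte is not a sign byte.
    assert!(narrow::<2>(&[0x01, 0x00, 0x00]).is_err());
    // 0x00_80 would flip sign if truncated to one byte, so it is rejected too.
    assert!(narrow::<1>(&[0x00, 0x80]).is_err());
}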
.and_then(|lit| lit.clone()) + .unwrap_or(AvroLiteral::Null); + default_injections.push((idx, lit)); + } + let mut skip_decoders: Vec> = + Vec::with_capacity(self.rec.skip_fields.len()); + for datatype in self.rec.skip_fields.as_ref() { + let skipper = match datatype { + Some(datatype) => Some(Skipper::from_avro(datatype)?), + None => None, + }; + skip_decoders.push(skipper); + } + Ok(Projector { + writer_to_reader: self.rec.writer_to_reader.clone(), + skip_decoders, + field_defaults, + default_injections: default_injections.into(), + }) + } +} + +impl Projector { + #[inline] + fn project_default(&self, decoder: &mut Decoder, index: usize) -> Result<(), ArrowError> { + // SAFETY: `index` is obtained by listing the reader's record fields (i.e., from + // `decoders.iter_mut().enumerate()`), and `field_defaults` was built in + // `ProjectorBuilder::build` to have exactly one element per reader field. + // Therefore, `index < self.field_defaults.len()` always holds here, so + // `self.field_defaults[index]` cannot panic. We only take an immutable reference + // via `.as_ref()`, and `self` is borrowed immutably. + if let Some(default_literal) = self.field_defaults[index].as_ref() { + decoder.append_default(default_literal) + } else { + decoder.append_null(); + Ok(()) + } + } + + #[inline] + fn project_record( + &mut self, + buf: &mut AvroCursor<'_>, + encodings: &mut [Decoder], + ) -> Result<(), ArrowError> { + debug_assert_eq!( + self.writer_to_reader.len(), + self.skip_decoders.len(), + "internal invariant: mapping and skipper lists must have equal length" + ); + for (i, (mapping, skipper_opt)) in self + .writer_to_reader + .iter() + .zip(self.skip_decoders.iter_mut()) + .enumerate() + { + match (mapping, skipper_opt.as_mut()) { + (Some(reader_index), _) => encodings[*reader_index].decode(buf)?, + (None, Some(skipper)) => skipper.skip(buf)?, + (None, None) => { + return Err(ArrowError::SchemaError(format!( + "No skipper available for writer-only field at index {i}", + ))); + } + } + } + for (reader_index, lit) in self.default_injections.as_ref() { + encodings[*reader_index].append_default(lit)?; + } + Ok(()) + } +} + +/// Lightweight skipper for non‑projected writer fields +/// (fields present in the writer schema but omitted by the reader/projection); +/// per Avro 1.11.1 schema resolution these fields are ignored. 
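`project_record` below walks fields in the writer's order: entries of `writer_to_reader` with `Some(reader_index)` decode into that reader column, `None` entries are consumed with the prebuilt skipper, and once the writer record is exhausted the reader-only fields listed in `default_injections` are filled from their defaults. A schematic walk of the same bookkeeping over plain vectors (illustrative only, no Avro decoding):

fn main() {
    // Writer record: (a, b, c); reader record: (c, a, d) where d is reader-only.
    // writer_to_reader[i] is Some(reader column index) or None (skip this writer field).
    let writer_to_reader: Vec<Option<usize>> = vec![Some(1), None, Some(0)];
    // Reader-only columns filled from reader-schema defaults after each record.
    let default_injections: Vec<(usize, i64)> = vec![(2, 99)];

    // One writer record's decoded field values, in writer order.
    let writer_values: Vec<i64> = vec![10, 20, 30];
    // Reader columns being accumulated (one value per column for this record).
    let mut reader_columns: Vec<Vec<i64>> = vec![Vec::new(), Vec::new(), Vec::new()];

    for (writer_idx, target) in writer_to_reader.iter().enumerate() {
        match target {
            Some(reader_idx) => reader_columns[*reader_idx].push(writer_values[writer_idx]),
            None => { /* writer-only field: a real decoder skips its bytes here */ }
        }
    }
    for (reader_idx, default) in &default_injections {
        reader_columns[*reader_idx].push(*default);
    }

    assert_eq!(reader_columns, vec![vec![30], vec![10], vec![99]]);
}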
+/// +/// +#[derive(Debug)] +enum Skipper { + Null, + Boolean, + Int32, + Int64, + Float32, + Float64, + Bytes, + String, + Date32, + TimeMillis, + TimeMicros, + TimestampMillis, + TimestampMicros, + Fixed(usize), + Decimal(Option), + UuidString, + Enum, + DurationFixed12, + List(Box), + Map(Box), + Struct(Vec), + Nullable(Nullability, Box), +} + +impl Skipper { + fn from_avro(dt: &AvroDataType) -> Result { + let mut base = match dt.codec() { + Codec::Null => Self::Null, + Codec::Boolean => Self::Boolean, + Codec::Int32 | Codec::Date32 | Codec::TimeMillis => Self::Int32, + Codec::Int64 => Self::Int64, + Codec::TimeMicros => Self::TimeMicros, + Codec::TimestampMillis(_) => Self::TimestampMillis, + Codec::TimestampMicros(_) => Self::TimestampMicros, + Codec::Float32 => Self::Float32, + Codec::Float64 => Self::Float64, + Codec::Binary => Self::Bytes, + Codec::Utf8 | Codec::Utf8View => Self::String, + Codec::Fixed(sz) => Self::Fixed(*sz as usize), + Codec::Decimal(_, _, size) => Self::Decimal(*size), + Codec::Uuid => Self::UuidString, // encoded as string + Codec::Enum(_) => Self::Enum, + Codec::List(item) => Self::List(Box::new(Skipper::from_avro(item)?)), + Codec::Struct(fields) => Self::Struct( + fields + .iter() + .map(|f| Skipper::from_avro(f.data_type())) + .collect::>()?, + ), + Codec::Map(values) => Self::Map(Box::new(Skipper::from_avro(values)?)), + Codec::Interval => Self::DurationFixed12, + _ => { + return Err(ArrowError::NotYetImplemented(format!( + "Skipper not implemented for codec {:?}", + dt.codec() + ))); + } + }; + if let Some(n) = dt.nullability() { + base = Self::Nullable(n, Box::new(base)); + } + Ok(base) + } + + fn skip(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { + match self { + Self::Null => Ok(()), + Self::Boolean => { + buf.get_bool()?; + Ok(()) + } + Self::Int32 | Self::Date32 | Self::TimeMillis => { + buf.get_int()?; + Ok(()) + } + Self::Int64 | Self::TimeMicros | Self::TimestampMillis | Self::TimestampMicros => { + buf.get_long()?; + Ok(()) + } + Self::Float32 => { + buf.get_float()?; + Ok(()) + } + Self::Float64 => { + buf.get_double()?; + Ok(()) + } + Self::Bytes | Self::String | Self::UuidString => { + buf.get_bytes()?; + Ok(()) + } + Self::Fixed(sz) => { + buf.get_fixed(*sz)?; + Ok(()) + } + Self::Decimal(size) => { + if let Some(s) = size { + buf.get_fixed(*s) + } else { + buf.get_bytes() + }?; + Ok(()) + } + Self::Enum => { + buf.get_int()?; + Ok(()) + } + Self::DurationFixed12 => { + buf.get_fixed(12)?; + Ok(()) + } + Self::List(item) => { + skip_blocks(buf, |c| item.skip(c))?; + Ok(()) + } + Self::Map(value) => { + skip_blocks(buf, |c| { + c.get_bytes()?; // key + value.skip(c) + })?; + Ok(()) + } + Self::Struct(fields) => { + for f in fields.iter_mut() { + f.skip(buf)? 
+ } + Ok(()) + } + Self::Nullable(order, inner) => { + let branch = buf.read_vlq()?; + let is_not_null = match *order { + Nullability::NullFirst => branch != 0, + Nullability::NullSecond => branch == 0, + }; + if is_not_null { + inner.skip(buf)?; + } + Ok(()) + } + } + } } #[cfg(test)] mod tests { use super::*; - use arrow_array::{ - cast::AsArray, Array, Decimal128Array, DictionaryArray, FixedSizeBinaryArray, - IntervalMonthDayNanoArray, ListArray, MapArray, StringArray, StructArray, - }; + use crate::codec::AvroField; + use crate::schema::{PrimitiveType, Schema, TypeName}; + use arrow_array::cast::AsArray; + use indexmap::IndexMap; fn encode_avro_int(value: i32) -> Vec { let mut buf = Vec::new(); @@ -709,46 +1470,225 @@ mod tests { AvroDataType::new(codec, Default::default(), None) } + fn decoder_for_promotion( + writer: PrimitiveType, + reader: PrimitiveType, + use_utf8view: bool, + ) -> Decoder { + let ws = Schema::TypeName(TypeName::Primitive(writer)); + let rs = Schema::TypeName(TypeName::Primitive(reader)); + let field = + AvroField::resolve_from_writer_and_reader(&ws, &rs, use_utf8view, false).unwrap(); + Decoder::try_new(field.data_type()).unwrap() + } + #[test] - fn test_map_decoding_one_entry() { - let value_type = avro_from_codec(Codec::Utf8); - let map_type = avro_from_codec(Codec::Map(Arc::new(value_type))); - let mut decoder = Decoder::try_new(&map_type).unwrap(); - // Encode a single map with one entry: {"hello": "world"} - let mut data = Vec::new(); - data.extend_from_slice(&encode_avro_long(1)); - data.extend_from_slice(&encode_avro_bytes(b"hello")); // key - data.extend_from_slice(&encode_avro_bytes(b"world")); // value - data.extend_from_slice(&encode_avro_long(0)); - let mut cursor = AvroCursor::new(&data); - decoder.decode(&mut cursor).unwrap(); - let array = decoder.flush(None).unwrap(); - let map_arr = array.as_any().downcast_ref::().unwrap(); - assert_eq!(map_arr.len(), 1); // one map - assert_eq!(map_arr.value_length(0), 1); - let entries = map_arr.value(0); - let struct_entries = entries.as_any().downcast_ref::().unwrap(); - assert_eq!(struct_entries.len(), 1); - let key_arr = struct_entries - .column_by_name("key") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - let val_arr = struct_entries - .column_by_name("value") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - assert_eq!(key_arr.value(0), "hello"); - assert_eq!(val_arr.value(0), "world"); + fn test_schema_resolution_promotion_int_to_long() { + let mut dec = decoder_for_promotion(PrimitiveType::Int, PrimitiveType::Long, false); + assert!(matches!(dec, Decoder::Int32ToInt64(_))); + for v in [0, 1, -2, 123456] { + let data = encode_avro_int(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), 0); + assert_eq!(a.value(1), 1); + assert_eq!(a.value(2), -2); + assert_eq!(a.value(3), 123456); } #[test] - fn test_map_decoding_empty() { - let value_type = avro_from_codec(Codec::Utf8); - let map_type = avro_from_codec(Codec::Map(Arc::new(value_type))); + fn test_schema_resolution_promotion_int_to_float() { + let mut dec = decoder_for_promotion(PrimitiveType::Int, PrimitiveType::Float, false); + assert!(matches!(dec, Decoder::Int32ToFloat32(_))); + for v in [0, 42, -7] { + let data = encode_avro_int(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = 
arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), 0.0); + assert_eq!(a.value(1), 42.0); + assert_eq!(a.value(2), -7.0); + } + + #[test] + fn test_schema_resolution_promotion_int_to_double() { + let mut dec = decoder_for_promotion(PrimitiveType::Int, PrimitiveType::Double, false); + assert!(matches!(dec, Decoder::Int32ToFloat64(_))); + for v in [1, -1, 10_000] { + let data = encode_avro_int(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), 1.0); + assert_eq!(a.value(1), -1.0); + assert_eq!(a.value(2), 10_000.0); + } + + #[test] + fn test_schema_resolution_promotion_long_to_float() { + let mut dec = decoder_for_promotion(PrimitiveType::Long, PrimitiveType::Float, false); + assert!(matches!(dec, Decoder::Int64ToFloat32(_))); + for v in [0_i64, 1_000_000_i64, -123_i64] { + let data = encode_avro_long(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), 0.0); + assert_eq!(a.value(1), 1_000_000.0); + assert_eq!(a.value(2), -123.0); + } + + #[test] + fn test_schema_resolution_promotion_long_to_double() { + let mut dec = decoder_for_promotion(PrimitiveType::Long, PrimitiveType::Double, false); + assert!(matches!(dec, Decoder::Int64ToFloat64(_))); + for v in [2_i64, -2_i64, 9_223_372_i64] { + let data = encode_avro_long(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), 2.0); + assert_eq!(a.value(1), -2.0); + assert_eq!(a.value(2), 9_223_372.0); + } + + #[test] + fn test_schema_resolution_promotion_float_to_double() { + let mut dec = decoder_for_promotion(PrimitiveType::Float, PrimitiveType::Double, false); + assert!(matches!(dec, Decoder::Float32ToFloat64(_))); + for v in [0.5_f32, -3.25_f32, 1.0e6_f32] { + let data = v.to_le_bytes().to_vec(); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), 0.5_f64); + assert_eq!(a.value(1), -3.25_f64); + assert_eq!(a.value(2), 1.0e6_f64); + } + + #[test] + fn test_schema_resolution_promotion_bytes_to_string_utf8() { + let mut dec = decoder_for_promotion(PrimitiveType::Bytes, PrimitiveType::String, false); + assert!(matches!(dec, Decoder::BytesToString(_, _))); + for s in ["hello", "world", "héllo"] { + let data = encode_avro_bytes(s.as_bytes()); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), "hello"); + assert_eq!(a.value(1), "world"); + assert_eq!(a.value(2), "héllo"); + } + + #[test] + fn test_schema_resolution_promotion_bytes_to_string_utf8view_enabled() { + let mut dec = decoder_for_promotion(PrimitiveType::Bytes, PrimitiveType::String, true); + assert!(matches!(dec, Decoder::BytesToString(_, _))); + let data = encode_avro_bytes("abc".as_bytes()); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), "abc"); + } + + #[test] + fn test_schema_resolution_promotion_string_to_bytes() { + let mut dec = 
decoder_for_promotion(PrimitiveType::String, PrimitiveType::Bytes, false); + assert!(matches!(dec, Decoder::StringToBytes(_, _))); + for s in ["", "abc", "data"] { + let data = encode_avro_bytes(s.as_bytes()); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), b""); + assert_eq!(a.value(1), b"abc"); + assert_eq!(a.value(2), "data".as_bytes()); + } + + #[test] + fn test_schema_resolution_no_promotion_passthrough_int() { + let ws = Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)); + let rs = Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)); + let field = AvroField::resolve_from_writer_and_reader(&ws, &rs, false, false).unwrap(); + let mut dec = Decoder::try_new(field.data_type()).unwrap(); + assert!(matches!(dec, Decoder::Int32(_))); + for v in [7, -9] { + let data = encode_avro_int(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), 7); + assert_eq!(a.value(1), -9); + } + + #[test] + fn test_schema_resolution_illegal_promotion_int_to_boolean_errors() { + let ws = Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)); + let rs = Schema::TypeName(TypeName::Primitive(PrimitiveType::Boolean)); + let res = AvroField::resolve_from_writer_and_reader(&ws, &rs, false, false); + assert!(res.is_err(), "expected error for illegal promotion"); + } + + #[test] + fn test_map_decoding_one_entry() { + let value_type = avro_from_codec(Codec::Utf8); + let map_type = avro_from_codec(Codec::Map(Arc::new(value_type))); + let mut decoder = Decoder::try_new(&map_type).unwrap(); + // Encode a single map with one entry: {"hello": "world"} + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_long(1)); + data.extend_from_slice(&encode_avro_bytes(b"hello")); // key + data.extend_from_slice(&encode_avro_bytes(b"world")); // value + data.extend_from_slice(&encode_avro_long(0)); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + let array = decoder.flush(None).unwrap(); + let map_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(map_arr.len(), 1); // one map + assert_eq!(map_arr.value_length(0), 1); + let entries = map_arr.value(0); + let struct_entries = entries.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_entries.len(), 1); + let key_arr = struct_entries + .column_by_name("key") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let val_arr = struct_entries + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(key_arr.value(0), "hello"); + assert_eq!(val_arr.value(0), "world"); + } + + #[test] + fn test_map_decoding_empty() { + let value_type = avro_from_codec(Codec::Utf8); + let map_type = avro_from_codec(Codec::Map(Arc::new(value_type))); let mut decoder = Decoder::try_new(&map_type).unwrap(); let data = encode_avro_long(0); decoder.decode(&mut AvroCursor::new(&data)).unwrap(); @@ -942,7 +1882,7 @@ mod tests { #[test] fn test_decimal_decoding_fixed256() { - let dt = avro_from_codec(Codec::Decimal(5, Some(2), Some(32))); + let dt = avro_from_codec(Codec::Decimal(50, Some(2), Some(32))); let mut decoder = Decoder::try_new(&dt).unwrap(); let row1 = [ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -969,7 +1909,7 @@ mod tests { #[test] fn 
test_decimal_decoding_fixed128() { - let dt = avro_from_codec(Codec::Decimal(5, Some(2), Some(16))); + let dt = avro_from_codec(Codec::Decimal(28, Some(2), Some(16))); let mut decoder = Decoder::try_new(&dt).unwrap(); let row1 = [ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -992,6 +1932,79 @@ mod tests { assert_eq!(dec.value_as_string(1), "-1.23"); } + #[test] + fn test_decimal_decoding_fixed32_from_32byte_fixed_storage() { + let dt = avro_from_codec(Codec::Decimal(5, Some(2), Some(32))); + let mut decoder = Decoder::try_new(&dt).unwrap(); + let row1 = [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x30, 0x39, + ]; + let row2 = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0x85, + ]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + #[cfg(feature = "small_decimals")] + { + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + #[cfg(not(feature = "small_decimals"))] + { + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + } + + #[test] + fn test_decimal_decoding_fixed32_from_16byte_fixed_storage() { + let dt = avro_from_codec(Codec::Decimal(5, Some(2), Some(16))); + let mut decoder = Decoder::try_new(&dt).unwrap(); + let row1 = [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x30, 0x39, + ]; + let row2 = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0x85, + ]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + + let arr = decoder.flush(None).unwrap(); + #[cfg(feature = "small_decimals")] + { + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + #[cfg(not(feature = "small_decimals"))] + { + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + } + #[test] fn test_decimal_decoding_bytes_with_nulls() { let dt = avro_from_codec(Codec::Decimal(4, Some(1), None)); @@ -1008,21 +2021,34 @@ mod tests { data.extend_from_slice(&encode_avro_int(0)); data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E])); let mut cursor = AvroCursor::new(&data); - decoder.decode(&mut cursor).unwrap(); // row1 - decoder.decode(&mut cursor).unwrap(); // row2 - decoder.decode(&mut cursor).unwrap(); // row3 + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); let arr = decoder.flush(None).unwrap(); - let dec_arr = arr.as_any().downcast_ref::().unwrap(); - 
assert_eq!(dec_arr.len(), 3); - assert!(dec_arr.is_valid(0)); - assert!(!dec_arr.is_valid(1)); - assert!(dec_arr.is_valid(2)); - assert_eq!(dec_arr.value_as_string(0), "123.4"); - assert_eq!(dec_arr.value_as_string(2), "-123.4"); + #[cfg(feature = "small_decimals")] + { + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "123.4"); + assert_eq!(dec_arr.value_as_string(2), "-123.4"); + } + #[cfg(not(feature = "small_decimals"))] + { + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "123.4"); + assert_eq!(dec_arr.value_as_string(2), "-123.4"); + } } #[test] - fn test_decimal_decoding_bytes_with_nulls_fixed_size() { + fn test_decimal_decoding_bytes_with_nulls_fixed_size_narrow_result() { let dt = avro_from_codec(Codec::Decimal(6, Some(2), Some(16))); let inner = Decoder::try_new(&dt).unwrap(); let mut decoder = Decoder::Nullable( @@ -1049,13 +2075,26 @@ mod tests { decoder.decode(&mut cursor).unwrap(); decoder.decode(&mut cursor).unwrap(); let arr = decoder.flush(None).unwrap(); - let dec_arr = arr.as_any().downcast_ref::().unwrap(); - assert_eq!(dec_arr.len(), 3); - assert!(dec_arr.is_valid(0)); - assert!(!dec_arr.is_valid(1)); - assert!(dec_arr.is_valid(2)); - assert_eq!(dec_arr.value_as_string(0), "1234.56"); - assert_eq!(dec_arr.value_as_string(2), "-1234.56"); + #[cfg(feature = "small_decimals")] + { + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "1234.56"); + assert_eq!(dec_arr.value_as_string(2), "-1234.56"); + } + #[cfg(not(feature = "small_decimals"))] + { + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "1234.56"); + assert_eq!(dec_arr.value_as_string(2), "-1234.56"); + } } #[test] @@ -1076,7 +2115,6 @@ mod tests { .as_any() .downcast_ref::>() .unwrap(); - assert_eq!(dict_array.len(), 3); let values = dict_array .values() @@ -1188,4 +2226,778 @@ mod tests { let array = decoder.flush(None).unwrap(); assert_eq!(array.len(), 0); } + + #[test] + fn test_nullable_decode_error_bitmap_corruption() { + // Nullable Int32 with ['T','null'] encoding (NullSecond) + let avro_type = AvroDataType::new( + Codec::Int32, + Default::default(), + Some(Nullability::NullSecond), + ); + let mut decoder = Decoder::try_new(&avro_type).unwrap(); + + // Row 1: union branch 1 (null) + let mut row1 = Vec::new(); + row1.extend_from_slice(&encode_avro_int(1)); + + // Row 2: union branch 0 (non-null) but missing the int payload -> decode error + let mut row2 = Vec::new(); + row2.extend_from_slice(&encode_avro_int(0)); // branch = 0 => non-null + + // Row 3: union branch 0 (non-null) with correct int payload -> should succeed + let mut row3 = Vec::new(); + row3.extend_from_slice(&encode_avro_int(0)); // branch + row3.extend_from_slice(&encode_avro_int(42)); // actual value + + decoder.decode(&mut AvroCursor::new(&row1)).unwrap(); + assert!(decoder.decode(&mut AvroCursor::new(&row2)).is_err()); // decode error + 
decoder.decode(&mut AvroCursor::new(&row3)).unwrap(); + + let array = decoder.flush(None).unwrap(); + + // Should contain 2 elements: row1 (null) and row3 (42) + assert_eq!(array.len(), 2); + let int_array = array.as_any().downcast_ref::().unwrap(); + assert!(int_array.is_null(0)); // row1 is null + assert_eq!(int_array.value(1), 42); // row3 value is 42 + } + + #[test] + fn test_enum_mapping_reordered_symbols() { + let reader_symbols: Arc<[String]> = + vec!["B".to_string(), "C".to_string(), "A".to_string()].into(); + let mapping: Arc<[i32]> = Arc::from(vec![2, 0, 1]); + let default_index: i32 = -1; + let mut dec = Decoder::Enum( + Vec::with_capacity(DEFAULT_CAPACITY), + reader_symbols.clone(), + Some(EnumResolution { + mapping, + default_index, + }), + ); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(2)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + dec.decode(&mut cur).unwrap(); + dec.decode(&mut cur).unwrap(); + let arr = dec.flush(None).unwrap(); + let dict = arr + .as_any() + .downcast_ref::>() + .unwrap(); + let expected_keys = Int32Array::from(vec![2, 0, 1]); + assert_eq!(dict.keys(), &expected_keys); + let values = dict + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(values.value(0), "B"); + assert_eq!(values.value(1), "C"); + assert_eq!(values.value(2), "A"); + } + + #[test] + fn test_enum_mapping_unknown_symbol_and_out_of_range_fall_back_to_default() { + let reader_symbols: Arc<[String]> = vec!["A".to_string(), "B".to_string()].into(); + let default_index: i32 = 1; + let mapping: Arc<[i32]> = Arc::from(vec![0, 1]); + let mut dec = Decoder::Enum( + Vec::with_capacity(DEFAULT_CAPACITY), + reader_symbols.clone(), + Some(EnumResolution { + mapping, + default_index, + }), + ); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(99)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + dec.decode(&mut cur).unwrap(); + dec.decode(&mut cur).unwrap(); + let arr = dec.flush(None).unwrap(); + let dict = arr + .as_any() + .downcast_ref::>() + .unwrap(); + let expected_keys = Int32Array::from(vec![0, 1, 1]); + assert_eq!(dict.keys(), &expected_keys); + let values = dict + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(values.value(0), "A"); + assert_eq!(values.value(1), "B"); + } + + #[test] + fn test_enum_mapping_unknown_symbol_without_default_errors() { + let reader_symbols: Arc<[String]> = vec!["A".to_string()].into(); + let default_index: i32 = -1; // indicates no default at type-level + let mapping: Arc<[i32]> = Arc::from(vec![-1]); + let mut dec = Decoder::Enum( + Vec::with_capacity(DEFAULT_CAPACITY), + reader_symbols, + Some(EnumResolution { + mapping, + default_index, + }), + ); + let data = encode_avro_int(0); + let mut cur = AvroCursor::new(&data); + let err = dec + .decode(&mut cur) + .expect_err("expected decode error for unresolved enum without default"); + let msg = err.to_string(); + assert!( + msg.contains("not resolvable") && msg.contains("no default"), + "unexpected error message: {msg}" + ); + } + + fn make_record_resolved_decoder( + reader_fields: &[(&str, DataType, bool)], + writer_to_reader: Vec>, + skip_decoders: Vec>, + ) -> Decoder { + let mut field_refs: Vec = Vec::with_capacity(reader_fields.len()); + let mut encodings: Vec = 
Vec::with_capacity(reader_fields.len()); + for (name, dt, nullable) in reader_fields { + field_refs.push(Arc::new(ArrowField::new(*name, dt.clone(), *nullable))); + let enc = match dt { + DataType::Int32 => Decoder::Int32(Vec::new()), + DataType::Int64 => Decoder::Int64(Vec::new()), + DataType::Utf8 => { + Decoder::String(OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::new()) + } + other => panic!("Unsupported test reader field type: {other:?}"), + }; + encodings.push(enc); + } + let fields: Fields = field_refs.into(); + Decoder::Record( + fields, + encodings, + Some(Projector { + writer_to_reader: Arc::from(writer_to_reader), + skip_decoders, + field_defaults: vec![None; reader_fields.len()], + default_injections: Arc::from(Vec::<(usize, AvroLiteral)>::new()), + }), + ) + } + + #[test] + fn test_skip_writer_trailing_field_int32() { + let mut dec = make_record_resolved_decoder( + &[("id", arrow_schema::DataType::Int32, false)], + vec![Some(0), None], + vec![None, Some(super::Skipper::Int32)], + ); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(7)); + data.extend_from_slice(&encode_avro_int(999)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + assert_eq!(cur.position(), data.len()); + let arr = dec.flush(None).unwrap(); + let struct_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_arr.len(), 1); + let id = struct_arr + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.value(0), 7); + } + + #[test] + fn test_skip_writer_middle_field_string() { + let mut dec = make_record_resolved_decoder( + &[ + ("id", DataType::Int32, false), + ("score", DataType::Int64, false), + ], + vec![Some(0), None, Some(1)], + vec![None, Some(Skipper::String), None], + ); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(42)); + data.extend_from_slice(&encode_avro_bytes(b"abcdef")); + data.extend_from_slice(&encode_avro_long(1000)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + assert_eq!(cur.position(), data.len()); + let arr = dec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let id = s + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let score = s + .column_by_name("score") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.value(0), 42); + assert_eq!(score.value(0), 1000); + } + + #[test] + fn test_skip_writer_array_with_negative_block_count_fast() { + let mut dec = make_record_resolved_decoder( + &[("id", DataType::Int32, false)], + vec![None, Some(0)], + vec![Some(super::Skipper::List(Box::new(Skipper::Int32))), None], + ); + let mut array_payload = Vec::new(); + array_payload.extend_from_slice(&encode_avro_int(1)); + array_payload.extend_from_slice(&encode_avro_int(2)); + array_payload.extend_from_slice(&encode_avro_int(3)); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_long(-3)); + data.extend_from_slice(&encode_avro_long(array_payload.len() as i64)); + data.extend_from_slice(&array_payload); + data.extend_from_slice(&encode_avro_long(0)); + data.extend_from_slice(&encode_avro_int(5)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + assert_eq!(cur.position(), data.len()); + let arr = dec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let id = s + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.len(), 1); + assert_eq!(id.value(0), 5); + } + + 
#[test] + fn test_skip_writer_map_with_negative_block_count_fast() { + let mut dec = make_record_resolved_decoder( + &[("id", DataType::Int32, false)], + vec![None, Some(0)], + vec![Some(Skipper::Map(Box::new(Skipper::Int32))), None], + ); + let mut entries = Vec::new(); + entries.extend_from_slice(&encode_avro_bytes(b"k1")); + entries.extend_from_slice(&encode_avro_int(10)); + entries.extend_from_slice(&encode_avro_bytes(b"k2")); + entries.extend_from_slice(&encode_avro_int(20)); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_long(-2)); + data.extend_from_slice(&encode_avro_long(entries.len() as i64)); + data.extend_from_slice(&entries); + data.extend_from_slice(&encode_avro_long(0)); + data.extend_from_slice(&encode_avro_int(123)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + assert_eq!(cur.position(), data.len()); + let arr = dec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let id = s + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.len(), 1); + assert_eq!(id.value(0), 123); + } + + #[test] + fn test_skip_writer_nullable_field_union_nullfirst() { + let mut dec = make_record_resolved_decoder( + &[("id", DataType::Int32, false)], + vec![None, Some(0)], + vec![ + Some(super::Skipper::Nullable( + Nullability::NullFirst, + Box::new(super::Skipper::Int32), + )), + None, + ], + ); + let mut row1 = Vec::new(); + row1.extend_from_slice(&encode_avro_long(0)); + row1.extend_from_slice(&encode_avro_int(5)); + let mut row2 = Vec::new(); + row2.extend_from_slice(&encode_avro_long(1)); + row2.extend_from_slice(&encode_avro_int(123)); + row2.extend_from_slice(&encode_avro_int(7)); + let mut cur1 = AvroCursor::new(&row1); + let mut cur2 = AvroCursor::new(&row2); + dec.decode(&mut cur1).unwrap(); + dec.decode(&mut cur2).unwrap(); + assert_eq!(cur1.position(), row1.len()); + assert_eq!(cur2.position(), row2.len()); + let arr = dec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let id = s + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.len(), 2); + assert_eq!(id.value(0), 5); + assert_eq!(id.value(1), 7); + } + + fn make_record_decoder_with_projector_defaults( + reader_fields: &[(&str, DataType, bool)], + field_defaults: Vec>, + default_injections: Vec<(usize, AvroLiteral)>, + writer_to_reader_len: usize, + ) -> Decoder { + assert_eq!( + field_defaults.len(), + reader_fields.len(), + "field_defaults must have one entry per reader field" + ); + let mut field_refs: Vec = Vec::with_capacity(reader_fields.len()); + let mut encodings: Vec = Vec::with_capacity(reader_fields.len()); + for (name, dt, nullable) in reader_fields { + field_refs.push(Arc::new(ArrowField::new(*name, dt.clone(), *nullable))); + let enc = match dt { + DataType::Int32 => Decoder::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), + DataType::Int64 => Decoder::Int64(Vec::with_capacity(DEFAULT_CAPACITY)), + DataType::Utf8 => Decoder::String( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ), + other => panic!("Unsupported test field type in helper: {other:?}"), + }; + encodings.push(enc); + } + let fields: Fields = field_refs.into(); + let skip_decoders: Vec> = + (0..writer_to_reader_len).map(|_| None::).collect(); + let projector = Projector { + writer_to_reader: Arc::from(vec![None; writer_to_reader_len]), + skip_decoders, + field_defaults, + default_injections: Arc::from(default_injections), + }; + 
Decoder::Record(fields, encodings, Some(projector)) + } + + #[test] + fn test_default_append_int32_and_int64_from_int_and_long() { + let mut d_i32 = Decoder::Int32(Vec::with_capacity(DEFAULT_CAPACITY)); + d_i32.append_default(&AvroLiteral::Int(42)).unwrap(); + let arr = d_i32.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.len(), 1); + assert_eq!(a.value(0), 42); + let mut d_i64 = Decoder::Int64(Vec::with_capacity(DEFAULT_CAPACITY)); + d_i64.append_default(&AvroLiteral::Int(5)).unwrap(); + d_i64.append_default(&AvroLiteral::Long(7)).unwrap(); + let arr64 = d_i64.flush(None).unwrap(); + let a64 = arr64.as_any().downcast_ref::().unwrap(); + assert_eq!(a64.len(), 2); + assert_eq!(a64.value(0), 5); + assert_eq!(a64.value(1), 7); + } + + #[test] + fn test_default_append_floats_and_doubles() { + let mut d_f32 = Decoder::Float32(Vec::with_capacity(DEFAULT_CAPACITY)); + d_f32.append_default(&AvroLiteral::Float(1.5)).unwrap(); + let arr32 = d_f32.flush(None).unwrap(); + let a = arr32.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), 1.5); + let mut d_f64 = Decoder::Float64(Vec::with_capacity(DEFAULT_CAPACITY)); + d_f64.append_default(&AvroLiteral::Double(2.25)).unwrap(); + let arr64 = d_f64.flush(None).unwrap(); + let b = arr64.as_any().downcast_ref::().unwrap(); + assert_eq!(b.value(0), 2.25); + } + + #[test] + fn test_default_append_string_and_bytes() { + let mut d_str = Decoder::String( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ); + d_str + .append_default(&AvroLiteral::String("hi".into())) + .unwrap(); + let s_arr = d_str.flush(None).unwrap(); + let arr = s_arr.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), "hi"); + let mut d_bytes = Decoder::Binary( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ); + d_bytes + .append_default(&AvroLiteral::Bytes(vec![1, 2, 3])) + .unwrap(); + let b_arr = d_bytes.flush(None).unwrap(); + let barr = b_arr.as_any().downcast_ref::().unwrap(); + assert_eq!(barr.value(0), &[1, 2, 3]); + let mut d_str_err = Decoder::String( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ); + let err = d_str_err + .append_default(&AvroLiteral::Bytes(vec![0x61, 0x62])) + .unwrap_err(); + assert!( + err.to_string() + .contains("Default for string must be string"), + "unexpected error: {err:?}" + ); + } + + #[test] + fn test_default_append_nullable_int32_null_and_value() { + let inner = Decoder::Int32(Vec::with_capacity(DEFAULT_CAPACITY)); + let mut dec = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(inner), + ); + dec.append_default(&AvroLiteral::Null).unwrap(); + dec.append_default(&AvroLiteral::Int(11)).unwrap(); + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.len(), 2); + assert!(a.is_null(0)); + assert_eq!(a.value(1), 11); + } + + #[test] + fn test_default_append_array_of_ints() { + let list_dt = avro_from_codec(Codec::List(Arc::new(avro_from_codec(Codec::Int32)))); + let mut d = Decoder::try_new(&list_dt).unwrap(); + let items = vec![ + AvroLiteral::Int(1), + AvroLiteral::Int(2), + AvroLiteral::Int(3), + ]; + d.append_default(&AvroLiteral::Array(items)).unwrap(); + let arr = d.flush(None).unwrap(); + let list = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(list.len(), 1); + assert_eq!(list.value_length(0), 3); + let vals = list.values().as_any().downcast_ref::().unwrap(); + 
assert_eq!(vals.values(), &[1, 2, 3]); + } + + #[test] + fn test_default_append_map_string_to_int() { + let map_dt = avro_from_codec(Codec::Map(Arc::new(avro_from_codec(Codec::Int32)))); + let mut d = Decoder::try_new(&map_dt).unwrap(); + let mut m: IndexMap = IndexMap::new(); + m.insert("k1".to_string(), AvroLiteral::Int(10)); + m.insert("k2".to_string(), AvroLiteral::Int(20)); + d.append_default(&AvroLiteral::Map(m)).unwrap(); + let arr = d.flush(None).unwrap(); + let map = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(map.len(), 1); + assert_eq!(map.value_length(0), 2); + let binding = map.value(0); + let entries = binding.as_any().downcast_ref::().unwrap(); + let k = entries + .column_by_name("key") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let v = entries + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let keys: std::collections::HashSet<&str> = (0..k.len()).map(|i| k.value(i)).collect(); + assert_eq!(keys, ["k1", "k2"].into_iter().collect()); + let vals: std::collections::HashSet = (0..v.len()).map(|i| v.value(i)).collect(); + assert_eq!(vals, [10, 20].into_iter().collect()); + } + + #[test] + fn test_default_append_enum_by_symbol() { + let symbols: Arc<[String]> = vec!["A".into(), "B".into(), "C".into()].into(); + let mut d = Decoder::Enum(Vec::with_capacity(DEFAULT_CAPACITY), symbols.clone(), None); + d.append_default(&AvroLiteral::Enum("B".into())).unwrap(); + let arr = d.flush(None).unwrap(); + let dict = arr + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(dict.len(), 1); + let expected = Int32Array::from(vec![1]); + assert_eq!(dict.keys(), &expected); + let values = dict + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(values.value(1), "B"); + } + + #[test] + fn test_default_append_uuid_and_type_error() { + let mut d = Decoder::Uuid(Vec::with_capacity(DEFAULT_CAPACITY)); + let uuid_str = "123e4567-e89b-12d3-a456-426614174000"; + d.append_default(&AvroLiteral::String(uuid_str.into())) + .unwrap(); + let arr_ref = d.flush(None).unwrap(); + let arr = arr_ref + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(arr.value_length(), 16); + assert_eq!(arr.len(), 1); + let mut d2 = Decoder::Uuid(Vec::with_capacity(DEFAULT_CAPACITY)); + let err = d2 + .append_default(&AvroLiteral::Bytes(vec![0u8; 16])) + .unwrap_err(); + assert!( + err.to_string().contains("Default for uuid must be string"), + "unexpected error: {err:?}" + ); + } + + #[test] + fn test_default_append_fixed_and_length_mismatch() { + let mut d = Decoder::Fixed(4, Vec::with_capacity(DEFAULT_CAPACITY)); + d.append_default(&AvroLiteral::Bytes(vec![1, 2, 3, 4])) + .unwrap(); + let arr_ref = d.flush(None).unwrap(); + let arr = arr_ref + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(arr.value_length(), 4); + assert_eq!(arr.value(0), &[1, 2, 3, 4]); + let mut d_err = Decoder::Fixed(4, Vec::with_capacity(DEFAULT_CAPACITY)); + let err = d_err + .append_default(&AvroLiteral::Bytes(vec![1, 2, 3])) + .unwrap_err(); + assert!( + err.to_string().contains("Fixed default length"), + "unexpected error: {err:?}" + ); + } + + #[test] + fn test_default_append_duration_and_length_validation() { + let dt = avro_from_codec(Codec::Interval); + let mut d = Decoder::try_new(&dt).unwrap(); + let mut bytes = Vec::with_capacity(12); + bytes.extend_from_slice(&1u32.to_le_bytes()); + bytes.extend_from_slice(&2u32.to_le_bytes()); + bytes.extend_from_slice(&3u32.to_le_bytes()); + d.append_default(&AvroLiteral::Bytes(bytes)).unwrap(); + let arr_ref = 
d.flush(None).unwrap(); + let arr = arr_ref + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(arr.len(), 1); + let v = arr.value(0); + assert_eq!(v.months, 1); + assert_eq!(v.days, 2); + assert_eq!(v.nanoseconds, 3_000_000); + let mut d_err = Decoder::try_new(&avro_from_codec(Codec::Interval)).unwrap(); + let err = d_err + .append_default(&AvroLiteral::Bytes(vec![0u8; 11])) + .unwrap_err(); + assert!( + err.to_string() + .contains("Duration default must be exactly 12 bytes"), + "unexpected error: {err:?}" + ); + } + + #[test] + fn test_default_append_decimal256_from_bytes() { + let dt = avro_from_codec(Codec::Decimal(50, Some(2), Some(32))); + let mut d = Decoder::try_new(&dt).unwrap(); + let pos: [u8; 32] = [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x30, 0x39, + ]; + d.append_default(&AvroLiteral::Bytes(pos.to_vec())).unwrap(); + let neg: [u8; 32] = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0x85, + ]; + d.append_default(&AvroLiteral::Bytes(neg.to_vec())).unwrap(); + let arr = d.flush(None).unwrap(); + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + + #[test] + fn test_record_append_default_map_missing_fields_uses_projector_field_defaults() { + let field_defaults = vec![None, Some(AvroLiteral::String("hi".into()))]; + let mut rec = make_record_decoder_with_projector_defaults( + &[("a", DataType::Int32, false), ("b", DataType::Utf8, false)], + field_defaults, + vec![], + 0, + ); + let mut map: IndexMap = IndexMap::new(); + map.insert("a".to_string(), AvroLiteral::Int(7)); + rec.append_default(&AvroLiteral::Map(map)).unwrap(); + let arr = rec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let a = s + .column_by_name("a") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let b = s + .column_by_name("b") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(a.value(0), 7); + assert_eq!(b.value(0), "hi"); + } + + #[test] + fn test_record_append_default_null_uses_projector_field_defaults() { + let field_defaults = vec![ + Some(AvroLiteral::Int(5)), + Some(AvroLiteral::String("x".into())), + ]; + let mut rec = make_record_decoder_with_projector_defaults( + &[("a", DataType::Int32, false), ("b", DataType::Utf8, false)], + field_defaults, + vec![], + 0, + ); + rec.append_default(&AvroLiteral::Null).unwrap(); + let arr = rec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let a = s + .column_by_name("a") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let b = s + .column_by_name("b") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(a.value(0), 5); + assert_eq!(b.value(0), "x"); + } + + #[test] + fn test_record_append_default_missing_fields_without_projector_defaults_yields_type_nulls_or_empties( + ) { + let fields = vec![("a", DataType::Int32, true), ("b", DataType::Utf8, true)]; + let mut field_refs: Vec = Vec::new(); + let mut encoders: Vec = Vec::new(); + for (name, dt, nullable) in &fields { + field_refs.push(Arc::new(ArrowField::new(*name, dt.clone(), *nullable))); + } + let enc_a = Decoder::Nullable( + Nullability::NullSecond, + 
NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(Decoder::Int32(Vec::with_capacity(DEFAULT_CAPACITY))), + ); + let enc_b = Decoder::Nullable( + Nullability::NullSecond, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(Decoder::String( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + )), + ); + encoders.push(enc_a); + encoders.push(enc_b); + let projector = Projector { + writer_to_reader: Arc::from(vec![]), + skip_decoders: vec![], + field_defaults: vec![None, None], // no defaults -> append_null + default_injections: Arc::from(Vec::<(usize, AvroLiteral)>::new()), + }; + let mut rec = Decoder::Record(field_refs.into(), encoders, Some(projector)); + let mut map: IndexMap = IndexMap::new(); + map.insert("a".to_string(), AvroLiteral::Int(9)); + rec.append_default(&AvroLiteral::Map(map)).unwrap(); + let arr = rec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let a = s + .column_by_name("a") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let b = s + .column_by_name("b") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert!(a.is_valid(0)); + assert_eq!(a.value(0), 9); + assert!(b.is_null(0)); + } + + #[test] + fn test_projector_default_injection_when_writer_lacks_fields() { + let defaults = vec![None, None]; + let injections = vec![ + (0, AvroLiteral::Int(99)), + (1, AvroLiteral::String("alice".into())), + ]; + let mut rec = make_record_decoder_with_projector_defaults( + &[ + ("id", DataType::Int32, false), + ("name", DataType::Utf8, false), + ], + defaults, + injections, + 0, + ); + rec.decode(&mut AvroCursor::new(&[])).unwrap(); + let arr = rec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let id = s + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let name = s + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.value(0), 99); + assert_eq!(name.value(0), "alice"); + } } diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index 539e7b02f306..511ba280f7ae 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -15,28 +15,67 @@ // specific language governing permissions and limitations // under the License. -use arrow_schema::ArrowError; +use arrow_schema::{ + ArrowError, DataType, Field as ArrowField, IntervalUnit, Schema as ArrowSchema, TimeUnit, +}; use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; +use serde_json::{json, Map as JsonMap, Value}; +#[cfg(feature = "sha256")] +use sha2::{Digest, Sha256}; use std::cmp::PartialEq; use std::collections::hash_map::Entry; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use strum_macros::AsRefStr; +/// The Avro single‑object encoding “magic” bytes (`0xC3 0x01`) +pub const SINGLE_OBJECT_MAGIC: [u8; 2] = [0xC3, 0x01]; + +/// The Confluent "magic" byte (`0x00`) +pub const CONFLUENT_MAGIC: [u8; 1] = [0x00]; + /// The metadata key used for storing the JSON encoded [`Schema`] pub const SCHEMA_METADATA_KEY: &str = "avro.schema"; -/// The Avro single‑object encoding “magic” bytes (`0xC3 0x01`) -pub const SINGLE_OBJECT_MAGIC: [u8; 2] = [0xC3, 0x01]; +/// Metadata key used to represent Avro enum symbols in an Arrow schema. +pub const AVRO_ENUM_SYMBOLS_METADATA_KEY: &str = "avro.enum.symbols"; + +/// Metadata key used to store the default value of a field in an Avro schema. 
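// ---------------------------------------------------------------------------
// Illustrative aside (not part of this patch): what the two magic prefixes
// above look like on the wire. Avro single-object encoding is `C3 01`
// followed by the 8-byte little-endian Rabin fingerprint of the writer
// schema; the Confluent wire format is a single `0x00` byte followed by a
// 4-byte big-endian schema registry id. The helpers below are only sketches.
fn frame_single_object(rabin_fingerprint: u64, body: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(2 + 8 + body.len());
    out.extend_from_slice(&[0xC3, 0x01]); // SINGLE_OBJECT_MAGIC
    out.extend_from_slice(&rabin_fingerprint.to_le_bytes()); // little-endian fingerprint
    out.extend_from_slice(body);
    out
}

fn frame_confluent(schema_id: u32, body: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(1 + 4 + body.len());
    out.push(0x00); // CONFLUENT_MAGIC
    out.extend_from_slice(&schema_id.to_be_bytes()); // big-endian registry id
    out.extend_from_slice(body);
    out
}
// ---------------------------------------------------------------------------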
+pub const AVRO_FIELD_DEFAULT_METADATA_KEY: &str = "avro.field.default";
+
+/// Metadata key used to store the name of a type in an Avro schema.
+pub const AVRO_NAME_METADATA_KEY: &str = "avro.name";
+
+/// Metadata key used to store the namespace of a type in an Avro schema.
+pub const AVRO_NAMESPACE_METADATA_KEY: &str = "avro.namespace";
+
+/// Metadata key used to store the documentation for a type in an Avro schema.
+pub const AVRO_DOC_METADATA_KEY: &str = "avro.doc";
+
+/// Default name for the root record in an Avro schema.
+pub const AVRO_ROOT_RECORD_DEFAULT_NAME: &str = "topLevelRecord";
 
 /// Compare two Avro schemas for equality (identical schemas).
 /// Returns true if the schemas have the same parsing canonical form (i.e., logically identical).
 pub fn compare_schemas(writer: &Schema, reader: &Schema) -> Result<bool, ArrowError> {
-    let canon_writer = generate_canonical_form(writer)?;
-    let canon_reader = generate_canonical_form(reader)?;
+    let canon_writer = AvroSchema::generate_canonical_form(writer)?;
+    let canon_reader = AvroSchema::generate_canonical_form(reader)?;
     Ok(canon_writer == canon_reader)
 }
 
+/// Avro types are not nullable, with nullability instead encoded as a union
+/// where one of the variants is the null type.
+///
+/// To accommodate this, we specially case two-variant unions where one of the
+/// variants is the null type, and use this to derive arrow's notion of nullability
+#[derive(Debug, Copy, Clone, PartialEq, Default)]
+pub enum Nullability {
+    /// The nulls are encoded as the first union variant
+    #[default]
+    NullFirst,
+    /// The nulls are encoded as the second union variant
+    NullSecond,
+}
+
 /// Either a [`PrimitiveType`] or a reference to a previously defined named type
 ///
 ///
@@ -91,7 +130,7 @@ pub struct Attributes<'a> {
 
     /// Additional JSON attributes
     #[serde(flatten)]
-    pub additional: HashMap<&'a str, serde_json::Value>,
+    pub additional: HashMap<&'a str, Value>,
 }
 
 impl Attributes<'_> {
@@ -198,8 +237,8 @@ pub struct Field<'a> {
     #[serde(borrow)]
     pub r#type: Schema<'a>,
     /// Optional default value for this field
-    #[serde(borrow, default)]
-    pub default: Option<&'a str>,
+    #[serde(default)]
+    pub default: Option<Value>,
 }
 
 /// An enumeration
@@ -284,6 +323,17 @@ pub struct AvroSchema {
     pub json_string: String,
 }
 
+impl TryFrom<&ArrowSchema> for AvroSchema {
+    type Error = ArrowError;
+
+    /// Converts an `ArrowSchema` to `AvroSchema`, delegating to
+    /// `AvroSchema::from_arrow_with_options` with `None` so that the
+    /// union null ordering is decided by `Nullability::default()`.
+    fn try_from(schema: &ArrowSchema) -> Result<Self, Self::Error> {
+        AvroSchema::from_arrow_with_options(schema, None)
+    }
+}
+
 impl AvroSchema {
     /// Creates a new `AvroSchema` from a JSON string.
     pub fn new(json_string: String) -> Self {
@@ -300,17 +350,160 @@ impl AvroSchema {
     /// Returns the Rabin fingerprint of the schema.
     pub fn fingerprint(&self) -> Result<Fingerprint, ArrowError> {
-        generate_fingerprint_rabin(&self.schema()?)
+        Self::generate_fingerprint_rabin(&self.schema()?)
+    }
+
+    /// Generates a fingerprint for the given `Schema` using the specified [`FingerprintAlgorithm`].
+    ///
+    /// The fingerprint is computed over the schema's Parsed Canonical Form
+    /// as defined by the Avro specification.
Depending on `hash_type`, this + /// will return one of the supported [`Fingerprint`] variants: + /// - [`Fingerprint::Rabin`] for [`FingerprintAlgorithm::Rabin`] + /// - [`Fingerprint::MD5`] for [`FingerprintAlgorithm::MD5`] + /// - [`Fingerprint::SHA256`] for [`FingerprintAlgorithm::SHA256`] + /// + /// Note: [`FingerprintAlgorithm::None`] cannot be used to generate a fingerprint + /// and will result in an error. If you intend to use a Schema Registry ID-based + /// wire format, load or set the [`Fingerprint::Id`] directly via [`Fingerprint::load_fingerprint_id`] + /// or [`SchemaStore::set`]. + /// + /// See also: + /// + /// # Errors + /// Returns an error if generating the canonical form of the schema fails, + /// or if `hash_type` is [`FingerprintAlgorithm::None`]. + /// + /// # Examples + /// ```no_run + /// use arrow_avro::schema::{AvroSchema, FingerprintAlgorithm}; + /// + /// let avro = AvroSchema::new("\"string\"".to_string()); + /// let schema = avro.schema().unwrap(); + /// let fp = AvroSchema::generate_fingerprint(&schema, FingerprintAlgorithm::Rabin).unwrap(); + /// ``` + pub fn generate_fingerprint( + schema: &Schema, + hash_type: FingerprintAlgorithm, + ) -> Result { + let canonical = Self::generate_canonical_form(schema).map_err(|e| { + ArrowError::ComputeError(format!("Failed to generate canonical form for schema: {e}")) + })?; + match hash_type { + FingerprintAlgorithm::Rabin => { + Ok(Fingerprint::Rabin(compute_fingerprint_rabin(&canonical))) + } + FingerprintAlgorithm::None => Err(ArrowError::SchemaError( + "FingerprintAlgorithm of None cannot be used to generate a fingerprint; \ + if using Fingerprint::Id, pass the registry ID in instead using the set method." + .to_string(), + )), + #[cfg(feature = "md5")] + FingerprintAlgorithm::MD5 => Ok(Fingerprint::MD5(compute_fingerprint_md5(&canonical))), + #[cfg(feature = "sha256")] + FingerprintAlgorithm::SHA256 => { + Ok(Fingerprint::SHA256(compute_fingerprint_sha256(&canonical))) + } + } + } + + /// Generates the 64-bit Rabin fingerprint for the given `Schema`. + /// + /// The fingerprint is computed from the canonical form of the schema. + /// This is also known as `CRC-64-AVRO`. + /// + /// # Returns + /// A `Fingerprint::Rabin` variant containing the 64-bit fingerprint. + pub fn generate_fingerprint_rabin(schema: &Schema) -> Result { + Self::generate_fingerprint(schema, FingerprintAlgorithm::Rabin) + } + + /// Generates the Parsed Canonical Form for the given [`Schema`]. + /// + /// The canonical form is a standardized JSON representation of the schema, + /// primarily used for generating a schema fingerprint for equality checking. + /// + /// This form strips attributes that do not affect the schema's identity, + /// such as `doc` fields, `aliases`, and any properties not defined in the + /// Avro specification. + /// + /// + pub fn generate_canonical_form(schema: &Schema) -> Result { + build_canonical(schema, None) + } + + /// Build Avro JSON from an Arrow [`ArrowSchema`], applying the given null‑union order. + /// + /// If the input Arrow schema already contains Avro JSON in + /// [`SCHEMA_METADATA_KEY`], that JSON is returned verbatim to preserve + /// the exact header encoding alignment; otherwise, a new JSON is generated + /// honoring `null_union_order` at **all nullable sites**. 
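// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of this patch) for the conversion
// declared just below. Assumes the `AvroSchema`, `Nullability`, and arrow
// items imported above are in scope; the field name is made up for the example.
fn null_order_example() -> Result<(), ArrowError> {
    let arrow = ArrowSchema::new(vec![ArrowField::new("id", DataType::Int64, true)]);
    // Default ordering puts "null" first in every nullable union.
    let default_order = AvroSchema::try_from(&arrow)?;
    // Explicitly request ["long", "null"] instead at every nullable site.
    let null_second =
        AvroSchema::from_arrow_with_options(&arrow, Some(Nullability::NullSecond))?;
    assert_ne!(default_order.json_string, null_second.json_string);
    Ok(())
}
// ---------------------------------------------------------------------------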
+ pub fn from_arrow_with_options( + schema: &ArrowSchema, + null_order: Option, + ) -> Result { + if let Some(json) = schema.metadata.get(SCHEMA_METADATA_KEY) { + return Ok(AvroSchema::new(json.clone())); + } + let order = null_order.unwrap_or_default(); + let mut name_gen = NameGenerator::default(); + let fields_json = schema + .fields() + .iter() + .map(|f| arrow_field_to_avro(f, &mut name_gen, order)) + .collect::, _>>()?; + let record_name = schema + .metadata + .get(AVRO_NAME_METADATA_KEY) + .map_or(AVRO_ROOT_RECORD_DEFAULT_NAME, |s| s.as_str()); + let mut record = JsonMap::with_capacity(schema.metadata.len() + 4); + record.insert("type".into(), Value::String("record".into())); + record.insert( + "name".into(), + Value::String(sanitise_avro_name(record_name)), + ); + if let Some(ns) = schema.metadata.get(AVRO_NAMESPACE_METADATA_KEY) { + record.insert("namespace".into(), Value::String(ns.clone())); + } + if let Some(doc) = schema.metadata.get(AVRO_DOC_METADATA_KEY) { + record.insert("doc".into(), Value::String(doc.clone())); + } + record.insert("fields".into(), Value::Array(fields_json)); + extend_with_passthrough_metadata(&mut record, &schema.metadata); + let json_string = serde_json::to_string(&Value::Object(record)) + .map_err(|e| ArrowError::SchemaError(format!("Serializing Avro JSON failed: {e}")))?; + Ok(AvroSchema::new(json_string)) } } /// Supported fingerprint algorithms for Avro schema identification. -/// Currently only `Rabin` is supported, `SHA256` and `MD5` support will come in a future update +/// For use with Confluent Schema Registry IDs, set to None. #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)] pub enum FingerprintAlgorithm { /// 64‑bit CRC‑64‑AVRO Rabin fingerprint. #[default] Rabin, + /// Represents a fingerprint not based on a hash algorithm, (e.g., a 32-bit Schema Registry ID.) + None, + #[cfg(feature = "md5")] + /// 128-bit MD5 message digest. + MD5, + #[cfg(feature = "sha256")] + /// 256-bit SHA-256 digest. + SHA256, +} + +/// Allow easy extraction of the algorithm used to create a fingerprint. +impl From<&Fingerprint> for FingerprintAlgorithm { + fn from(fp: &Fingerprint) -> Self { + match fp { + Fingerprint::Rabin(_) => FingerprintAlgorithm::Rabin, + Fingerprint::Id(_) => FingerprintAlgorithm::None, + #[cfg(feature = "md5")] + Fingerprint::MD5(_) => FingerprintAlgorithm::MD5, + #[cfg(feature = "sha256")] + Fingerprint::SHA256(_) => FingerprintAlgorithm::SHA256, + } + } } /// A schema fingerprint in one of the supported formats. @@ -318,64 +511,36 @@ pub enum FingerprintAlgorithm { /// This is used as the key inside `SchemaStore` `HashMap`. Each `SchemaStore` /// instance always stores only one variant, matching its configured /// `FingerprintAlgorithm`, but the enum makes the API uniform. -/// Currently only `Rabin` is supported /// /// +/// #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum Fingerprint { /// A 64-bit Rabin fingerprint. Rabin(u64), + /// A 32-bit Schema Registry ID. + Id(u32), + #[cfg(feature = "md5")] + /// A 128-bit MD5 fingerprint. + MD5([u8; 16]), + #[cfg(feature = "sha256")] + /// A 256-bit SHA-256 fingerprint. + SHA256([u8; 32]), } -/// Allow easy extraction of the algorithm used to create a fingerprint. -impl From<&Fingerprint> for FingerprintAlgorithm { - fn from(fp: &Fingerprint) -> Self { - match fp { - Fingerprint::Rabin(_) => FingerprintAlgorithm::Rabin, - } - } -} - -/// Generates a fingerprint for the given `Schema` using the specified `FingerprintAlgorithm`. 
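// ---------------------------------------------------------------------------
// Illustrative aside (not part of this patch): the `From<&Fingerprint>`
// conversion above makes it easy to branch on how a fingerprint was produced,
// e.g. to report the identifier width. Widths are the digest/id sizes only,
// not framing overhead.
fn fingerprint_width_bytes(fp: &Fingerprint) -> usize {
    match FingerprintAlgorithm::from(fp) {
        FingerprintAlgorithm::Rabin => 8, // 64-bit CRC-64-AVRO
        FingerprintAlgorithm::None => 4,  // 32-bit schema registry id
        #[cfg(feature = "md5")]
        FingerprintAlgorithm::MD5 => 16,
        #[cfg(feature = "sha256")]
        FingerprintAlgorithm::SHA256 => 32,
    }
}
// ---------------------------------------------------------------------------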
-pub(crate) fn generate_fingerprint( - schema: &Schema, - hash_type: FingerprintAlgorithm, -) -> Result { - let canonical = generate_canonical_form(schema).map_err(|e| { - ArrowError::ComputeError(format!("Failed to generate canonical form for schema: {e}")) - })?; - match hash_type { - FingerprintAlgorithm::Rabin => { - Ok(Fingerprint::Rabin(compute_fingerprint_rabin(&canonical))) - } +impl Fingerprint { + /// Loads the 32-bit Schema Registry fingerprint (Confluent Schema Registry ID). + /// + /// The provided `id` is in big-endian wire order; this converts it to host order + /// and returns `Fingerprint::Id`. + /// + /// # Returns + /// A `Fingerprint::Id` variant containing the 32-bit fingerprint. + pub fn load_fingerprint_id(id: u32) -> Self { + Fingerprint::Id(u32::from_be(id)) } } -/// Generates the 64-bit Rabin fingerprint for the given `Schema`. -/// -/// The fingerprint is computed from the canonical form of the schema. -/// This is also known as `CRC-64-AVRO`. -/// -/// # Returns -/// A `Fingerprint::Rabin` variant containing the 64-bit fingerprint. -pub fn generate_fingerprint_rabin(schema: &Schema) -> Result { - generate_fingerprint(schema, FingerprintAlgorithm::Rabin) -} - -/// Generates the Parsed Canonical Form for the given [`Schema`]. -/// -/// The canonical form is a standardized JSON representation of the schema, -/// primarily used for generating a schema fingerprint for equality checking. -/// -/// This form strips attributes that do not affect the schema's identity, -/// such as `doc` fields, `aliases`, and any properties not defined in the -/// Avro specification. -/// -/// -pub fn generate_canonical_form(schema: &Schema) -> Result { - build_canonical(schema, None) -} - /// An in-memory cache of Avro schemas, indexed by their fingerprint. /// /// `SchemaStore` provides a mechanism to store and retrieve Avro schemas efficiently. @@ -410,17 +575,16 @@ pub struct SchemaStore { schemas: HashMap, } -impl TryFrom<&[AvroSchema]> for SchemaStore { +impl TryFrom> for SchemaStore { type Error = ArrowError; - /// Creates a `SchemaStore` from a slice of schemas. - /// Each schema in the slice is registered with the new store. - fn try_from(schemas: &[AvroSchema]) -> Result { - let mut store = SchemaStore::new(); - for schema in schemas { - store.register(schema.clone())?; - } - Ok(store) + /// Creates a `SchemaStore` from a HashMap of schemas. + /// Each schema in the HashMap is registered with the new store. + fn try_from(schemas: HashMap) -> Result { + Ok(Self { + schemas, + ..Self::default() + }) } } @@ -430,23 +594,35 @@ impl SchemaStore { Self::default() } - /// Registers a schema with the store and returns its fingerprint. + /// Creates an empty `SchemaStore` using the default fingerprinting algorithm (64-bit Rabin). + pub fn new_with_type(fingerprint_algorithm: FingerprintAlgorithm) -> Self { + Self { + fingerprint_algorithm, + ..Self::default() + } + } + + /// Registers a schema with the store and the provided fingerprint. + /// Note: Confluent wire format implementations should leverage this method. /// - /// A fingerprint is calculated for the given schema using the store's configured - /// hash type. If a schema with the same fingerprint does not already exist in the - /// store, the new schema is inserted. If the fingerprint already exists, the - /// existing schema is not overwritten. + /// A schema is set in the store, using the provided fingerprint. If a schema + /// with the same fingerprint does not already exist in the store, the new schema + /// is inserted. 
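// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of this patch): wiring the pieces above
// together for a Confluent-style registry, where schemas are keyed by a
// 32-bit id instead of a hash. The id value is made up for the example.
fn confluent_store_example() -> Result<SchemaStore, ArrowError> {
    let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::None);
    let schema = AvroSchema::new(r#""string""#.to_string());
    // The id as read off the wire is big-endian; `load_fingerprint_id`
    // converts it to host order before it is used as a map key.
    let fingerprint = Fingerprint::load_fingerprint_id(1234u32.to_be());
    store.set(fingerprint, schema)?;
    Ok(store)
}
// ---------------------------------------------------------------------------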
If the fingerprint already exists, the existing schema is not overwritten. /// /// # Arguments /// + /// * `fingerprint` - A reference to the `Fingerprint` of the schema to register. /// * `schema` - The `AvroSchema` to register. /// /// # Returns /// - /// A `Result` containing the `Fingerprint` of the schema if successful, + /// A `Result` returning the provided `Fingerprint` of the schema if successful, /// or an `ArrowError` on failure. - pub fn register(&mut self, schema: AvroSchema) -> Result { - let fingerprint = generate_fingerprint(&schema.schema()?, self.fingerprint_algorithm)?; + pub fn set( + &mut self, + fingerprint: Fingerprint, + schema: AvroSchema, + ) -> Result { match self.schemas.entry(fingerprint) { Entry::Occupied(entry) => { if entry.get() != &schema { @@ -462,6 +638,37 @@ impl SchemaStore { Ok(fingerprint) } + /// Registers a schema with the store and returns its fingerprint. + /// + /// A fingerprint is calculated for the given schema using the store's configured + /// hash type. If a schema with the same fingerprint does not already exist in the + /// store, the new schema is inserted. If the fingerprint already exists, the + /// existing schema is not overwritten. If FingerprintAlgorithm is set to None, this + /// method will return an error. Confluent wire format implementations should leverage the + /// set method instead. + /// + /// # Arguments + /// + /// * `schema` - The `AvroSchema` to register. + /// + /// # Returns + /// + /// A `Result` containing the `Fingerprint` of the schema if successful, + /// or an `ArrowError` on failure. + pub fn register(&mut self, schema: AvroSchema) -> Result { + if self.fingerprint_algorithm == FingerprintAlgorithm::None { + return Err(ArrowError::SchemaError( + "Invalid FingerprintAlgorithm; unable to generate fingerprint. \ + Use the set method directly instead, providing a valid fingerprint" + .to_string(), + )); + } + let fingerprint = + AvroSchema::generate_fingerprint(&schema.schema()?, self.fingerprint_algorithm)?; + self.set(fingerprint, schema)?; + Ok(fingerprint) + } + /// Looks up a schema by its `Fingerprint`. /// /// # Arguments @@ -647,12 +854,436 @@ pub(crate) fn compute_fingerprint_rabin(canonical_form: &str) -> u64 { fp } +#[cfg(feature = "md5")] +/// Compute the **128‑bit MD5** fingerprint of the canonical form. +/// +/// Returns a 16‑byte array (`[u8; 16]`) containing the full MD5 digest, +/// exactly as required by the Avro specification. +#[inline] +pub(crate) fn compute_fingerprint_md5(canonical_form: &str) -> [u8; 16] { + let digest = md5::compute(canonical_form.as_bytes()); + digest.0 +} + +#[cfg(feature = "sha256")] +/// Compute the **256‑bit SHA‑256** fingerprint of the canonical form. +/// +/// Returns a 32‑byte array (`[u8; 32]`) containing the full SHA‑256 digest. +#[inline] +pub(crate) fn compute_fingerprint_sha256(canonical_form: &str) -> [u8; 32] { + let mut hasher = Sha256::new(); + hasher.update(canonical_form.as_bytes()); + let digest = hasher.finalize(); + digest.into() +} + +#[inline] +fn is_internal_arrow_key(key: &str) -> bool { + key.starts_with("ARROW:") || key == SCHEMA_METADATA_KEY +} + +/// Copies Arrow schema metadata entries to the provided JSON map, +/// skipping keys that are Avro-reserved, internal Arrow keys, or +/// nested under the `avro.schema.` namespace. Values that parse as +/// JSON are inserted as JSON; otherwise the raw string is preserved. 
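// ---------------------------------------------------------------------------
// Illustrative aside (not part of this patch): expected behaviour of the
// pass-through rule described above, assuming the helper defined next.
// Reserved "avro.*" keys and internal "ARROW:*" keys are dropped; everything
// else is copied, parsed as JSON when possible.
fn passthrough_example() {
    let metadata: HashMap<String, String> = HashMap::from([
        ("owner".to_string(), "analytics".to_string()), // kept as a JSON string
        ("retention_days".to_string(), "30".to_string()), // kept as JSON number 30
        ("avro.name".to_string(), "ignored".to_string()), // reserved, skipped
        ("ARROW:extension:name".to_string(), "uuid".to_string()), // internal, skipped
    ]);
    let mut record = JsonMap::new();
    extend_with_passthrough_metadata(&mut record, &metadata);
    assert_eq!(record.get("retention_days"), Some(&Value::from(30)));
    assert_eq!(record.get("owner"), Some(&Value::String("analytics".into())));
    assert!(!record.contains_key("avro.name"));
    assert!(!record.contains_key("ARROW:extension:name"));
}
// ---------------------------------------------------------------------------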
+fn extend_with_passthrough_metadata( + target: &mut JsonMap, + metadata: &HashMap, +) { + for (meta_key, meta_val) in metadata { + if meta_key.starts_with("avro.") || is_internal_arrow_key(meta_key) { + continue; + } + let json_val = + serde_json::from_str(meta_val).unwrap_or_else(|_| Value::String(meta_val.clone())); + target.insert(meta_key.clone(), json_val); + } +} + +// Sanitize an arbitrary string so it is a valid Avro field or type name +fn sanitise_avro_name(base_name: &str) -> String { + if base_name.is_empty() { + return "_".to_owned(); + } + let mut out: String = base_name + .chars() + .map(|char| { + if char.is_ascii_alphanumeric() || char == '_' { + char + } else { + '_' + } + }) + .collect(); + if out.as_bytes()[0].is_ascii_digit() { + out.insert(0, '_'); + } + out +} + +#[derive(Default)] +struct NameGenerator { + used: HashSet, + counters: HashMap, +} + +impl NameGenerator { + fn make_unique(&mut self, field_name: &str) -> String { + let field_name = sanitise_avro_name(field_name); + if self.used.insert(field_name.clone()) { + self.counters.insert(field_name.clone(), 1); + return field_name; + } + let counter = self.counters.entry(field_name.clone()).or_insert(1); + loop { + let candidate = format!("{field_name}_{}", *counter); + if self.used.insert(candidate.clone()) { + return candidate; + } + *counter += 1; + } + } +} + +fn merge_extras(schema: Value, mut extras: JsonMap) -> Value { + if extras.is_empty() { + return schema; + } + match schema { + Value::Object(mut map) => { + map.extend(extras); + Value::Object(map) + } + Value::Array(mut union) => { + if let Some(non_null) = union.iter_mut().find(|val| val.as_str() != Some("null")) { + let original = std::mem::take(non_null); + *non_null = merge_extras(original, extras); + } + Value::Array(union) + } + primitive => { + let mut map = JsonMap::with_capacity(extras.len() + 1); + map.insert("type".into(), primitive); + map.extend(extras); + Value::Object(map) + } + } +} + +fn wrap_nullable(inner: Value, null_order: Nullability) -> Value { + let null = Value::String("null".into()); + let elements = match null_order { + Nullability::NullFirst => vec![null, inner], + Nullability::NullSecond => vec![inner, null], + }; + Value::Array(elements) +} + +fn datatype_to_avro( + dt: &DataType, + field_name: &str, + metadata: &HashMap, + name_gen: &mut NameGenerator, + null_order: Nullability, +) -> Result<(Value, JsonMap), ArrowError> { + let mut extras = JsonMap::new(); + let mut handle_decimal = |precision: &u8, scale: &i8| -> Result { + if *scale < 0 { + return Err(ArrowError::SchemaError(format!( + "Invalid Avro decimal for field '{field_name}': scale ({scale}) must be >= 0" + ))); + } + if (*scale as usize) > (*precision as usize) { + return Err(ArrowError::SchemaError(format!( + "Invalid Avro decimal for field '{field_name}': scale ({scale}) \ + must be <= precision ({precision})" + ))); + } + + let mut meta = JsonMap::from_iter([ + ("logicalType".into(), json!("decimal")), + ("precision".into(), json!(*precision)), + ("scale".into(), json!(*scale)), + ]); + if let Some(size) = metadata + .get("size") + .and_then(|val| val.parse::().ok()) + { + meta.insert("type".into(), json!("fixed")); + meta.insert("size".into(), json!(size)); + meta.insert("name".into(), json!(name_gen.make_unique(field_name))); + } else { + meta.insert("type".into(), json!("bytes")); + } + Ok(Value::Object(meta)) + }; + let val = match dt { + DataType::Null => Value::String("null".into()), + DataType::Boolean => Value::String("boolean".into()), + 
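+        // Note: Avro has no unsigned or sub-32-bit integer types, so the arms below widen the
+        // narrower and unsigned Arrow widths to Avro `int` / `long`.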
DataType::Int8 | DataType::Int16 | DataType::UInt8 | DataType::UInt16 | DataType::Int32 => { + Value::String("int".into()) + } + DataType::UInt32 | DataType::Int64 | DataType::UInt64 => Value::String("long".into()), + DataType::Float16 | DataType::Float32 => Value::String("float".into()), + DataType::Float64 => Value::String("double".into()), + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Value::String("string".into()), + DataType::Binary | DataType::LargeBinary => Value::String("bytes".into()), + DataType::FixedSizeBinary(len) => { + let is_uuid = metadata + .get("logicalType") + .is_some_and(|value| value == "uuid") + || (*len == 16 + && metadata + .get("ARROW:extension:name") + .is_some_and(|value| value == "uuid")); + if is_uuid { + json!({ "type": "string", "logicalType": "uuid" }) + } else { + json!({ + "type": "fixed", + "name": name_gen.make_unique(field_name), + "size": len + }) + } + } + #[cfg(feature = "small_decimals")] + DataType::Decimal32(precision, scale) | DataType::Decimal64(precision, scale) => { + handle_decimal(precision, scale)? + } + DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => { + handle_decimal(precision, scale)? + } + DataType::Date32 => json!({ "type": "int", "logicalType": "date" }), + DataType::Date64 => json!({ "type": "long", "logicalType": "local-timestamp-millis" }), + DataType::Time32(unit) => match unit { + TimeUnit::Millisecond => json!({ "type": "int", "logicalType": "time-millis" }), + TimeUnit::Second => { + extras.insert("arrowTimeUnit".into(), Value::String("second".into())); + Value::String("int".into()) + } + _ => Value::String("int".into()), + }, + DataType::Time64(unit) => match unit { + TimeUnit::Microsecond => json!({ "type": "long", "logicalType": "time-micros" }), + TimeUnit::Nanosecond => { + extras.insert("arrowTimeUnit".into(), Value::String("nanosecond".into())); + Value::String("long".into()) + } + _ => Value::String("long".into()), + }, + DataType::Timestamp(unit, tz) => { + let logical_type = match (unit, tz.is_some()) { + (TimeUnit::Millisecond, true) => "timestamp-millis", + (TimeUnit::Millisecond, false) => "local-timestamp-millis", + (TimeUnit::Microsecond, true) => "timestamp-micros", + (TimeUnit::Microsecond, false) => "local-timestamp-micros", + (TimeUnit::Second, _) => { + extras.insert("arrowTimeUnit".into(), Value::String("second".into())); + return Ok((Value::String("long".into()), extras)); + } + (TimeUnit::Nanosecond, _) => { + extras.insert("arrowTimeUnit".into(), Value::String("nanosecond".into())); + return Ok((Value::String("long".into()), extras)); + } + }; + json!({ "type": "long", "logicalType": logical_type }) + } + DataType::Duration(unit) => { + extras.insert( + "arrowDurationUnit".into(), + Value::String(format!("{unit:?}").to_lowercase()), + ); + Value::String("long".into()) + } + DataType::Interval(IntervalUnit::MonthDayNano) => json!({ + "type": "fixed", + "name": name_gen.make_unique(&format!("{field_name}_duration")), + "size": 12, + "logicalType": "duration" + }), + DataType::Interval(IntervalUnit::YearMonth) => { + extras.insert( + "arrowIntervalUnit".into(), + Value::String("yearmonth".into()), + ); + Value::String("long".into()) + } + DataType::Interval(IntervalUnit::DayTime) => { + extras.insert("arrowIntervalUnit".into(), Value::String("daytime".into())); + Value::String("long".into()) + } + DataType::List(child) | DataType::LargeList(child) => { + if matches!(dt, DataType::LargeList(_)) { + extras.insert("arrowLargeList".into(), 
Value::Bool(true)); + } + let items_schema = process_datatype( + child.data_type(), + child.name(), + child.metadata(), + name_gen, + null_order, + child.is_nullable(), + )?; + json!({ + "type": "array", + "items": items_schema + }) + } + DataType::FixedSizeList(child, len) => { + extras.insert("arrowFixedSize".into(), json!(len)); + let items_schema = process_datatype( + child.data_type(), + child.name(), + child.metadata(), + name_gen, + null_order, + child.is_nullable(), + )?; + json!({ + "type": "array", + "items": items_schema + }) + } + DataType::Map(entries, _) => { + let value_field = match entries.data_type() { + DataType::Struct(fs) => &fs[1], + _ => { + return Err(ArrowError::SchemaError( + "Map 'entries' field must be Struct(key,value)".into(), + )) + } + }; + let values_schema = process_datatype( + value_field.data_type(), + value_field.name(), + value_field.metadata(), + name_gen, + null_order, + value_field.is_nullable(), + )?; + json!({ + "type": "map", + "values": values_schema + }) + } + DataType::Struct(fields) => { + let avro_fields = fields + .iter() + .map(|field| arrow_field_to_avro(field, name_gen, null_order)) + .collect::, _>>()?; + json!({ + "type": "record", + "name": name_gen.make_unique(field_name), + "fields": avro_fields + }) + } + DataType::Dictionary(_, value) => { + if let Some(j) = metadata.get(AVRO_ENUM_SYMBOLS_METADATA_KEY) { + let symbols: Vec<&str> = + serde_json::from_str(j).map_err(|e| ArrowError::ParseError(e.to_string()))?; + json!({ + "type": "enum", + "name": name_gen.make_unique(field_name), + "symbols": symbols + }) + } else { + process_datatype( + value.as_ref(), + field_name, + metadata, + name_gen, + null_order, + false, + )? + } + } + DataType::RunEndEncoded(_, values) => process_datatype( + values.data_type(), + values.name(), + values.metadata(), + name_gen, + null_order, + false, + )?, + DataType::Union(_, _) => { + return Err(ArrowError::NotYetImplemented( + "Arrow Union to Avro Union not yet supported".into(), + )) + } + other => { + return Err(ArrowError::NotYetImplemented(format!( + "Arrow type {other:?} has no Avro representation" + ))) + } + }; + Ok((val, extras)) +} + +fn process_datatype( + dt: &DataType, + field_name: &str, + metadata: &HashMap, + name_gen: &mut NameGenerator, + null_order: Nullability, + is_nullable: bool, +) -> Result { + let (schema, extras) = datatype_to_avro(dt, field_name, metadata, name_gen, null_order)?; + let mut merged = merge_extras(schema, extras); + if is_nullable { + merged = wrap_nullable(merged, null_order) + } + Ok(merged) +} + +fn arrow_field_to_avro( + field: &ArrowField, + name_gen: &mut NameGenerator, + null_order: Nullability, +) -> Result { + let avro_name = sanitise_avro_name(field.name()); + let schema_value = process_datatype( + field.data_type(), + &avro_name, + field.metadata(), + name_gen, + null_order, + field.is_nullable(), + )?; + // Build the field map + let mut map = JsonMap::with_capacity(field.metadata().len() + 3); + map.insert("name".into(), Value::String(avro_name)); + map.insert("type".into(), schema_value); + // Transfer selected metadata + for (meta_key, meta_val) in field.metadata() { + if is_internal_arrow_key(meta_key) { + continue; + } + match meta_key.as_str() { + AVRO_DOC_METADATA_KEY => { + map.insert("doc".into(), Value::String(meta_val.clone())); + } + AVRO_FIELD_DEFAULT_METADATA_KEY => { + let default_value = serde_json::from_str(meta_val) + .unwrap_or_else(|_| Value::String(meta_val.clone())); + map.insert("default".into(), default_value); + } + _ => { + let 
json_val = serde_json::from_str(meta_val) + .unwrap_or_else(|_| Value::String(meta_val.clone())); + map.insert(meta_key.clone(), json_val); + } + } + } + Ok(Value::Object(map)) +} + #[cfg(test)] mod tests { use super::*; use crate::codec::{AvroDataType, AvroField}; - use arrow_schema::{DataType, Fields, TimeUnit}; + use arrow_schema::{DataType, Fields, SchemaBuilder, TimeUnit}; use serde_json::json; + use std::sync::Arc; fn int_schema() -> Schema<'static> { Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)) @@ -682,6 +1313,19 @@ mod tests { })) } + fn single_field_schema(field: ArrowField) -> arrow_schema::Schema { + let mut sb = SchemaBuilder::new(); + sb.push(field); + sb.finish() + } + + fn assert_json_contains(avro_json: &str, needle: &str) { + assert!( + avro_json.contains(needle), + "JSON did not contain `{needle}` : {avro_json}" + ) + } + #[test] fn test_deserialize() { let t: Schema = serde_json::from_str("\"string\"").unwrap(); @@ -988,8 +1632,16 @@ mod tests { fn test_try_from_schemas_rabin() { let int_avro_schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); let record_avro_schema = AvroSchema::new(serde_json::to_string(&record_schema()).unwrap()); - let schemas = vec![int_avro_schema.clone(), record_avro_schema.clone()]; - let store = SchemaStore::try_from(schemas.as_slice()).unwrap(); + let mut schemas: HashMap = HashMap::new(); + schemas.insert( + int_avro_schema.fingerprint().unwrap(), + int_avro_schema.clone(), + ); + schemas.insert( + record_avro_schema.fingerprint().unwrap(), + record_avro_schema.clone(), + ); + let store = SchemaStore::try_from(schemas).unwrap(); let int_fp = int_avro_schema.fingerprint().unwrap(); assert_eq!(store.lookup(&int_fp).cloned(), Some(int_avro_schema)); let rec_fp = record_avro_schema.fingerprint().unwrap(); @@ -1000,12 +1652,21 @@ mod tests { fn test_try_from_with_duplicates() { let int_avro_schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); let record_avro_schema = AvroSchema::new(serde_json::to_string(&record_schema()).unwrap()); - let schemas = vec![ + let mut schemas: HashMap = HashMap::new(); + schemas.insert( + int_avro_schema.fingerprint().unwrap(), int_avro_schema.clone(), - record_avro_schema, + ); + schemas.insert( + record_avro_schema.fingerprint().unwrap(), + record_avro_schema.clone(), + ); + // Insert duplicate of int schema + schemas.insert( + int_avro_schema.fingerprint().unwrap(), int_avro_schema.clone(), - ]; - let store = SchemaStore::try_from(schemas.as_slice()).unwrap(); + ); + let store = SchemaStore::try_from(schemas).unwrap(); assert_eq!(store.schemas.len(), 2); let int_fp = int_avro_schema.fingerprint().unwrap(); assert_eq!(store.lookup(&int_fp).cloned(), Some(int_avro_schema)); @@ -1016,14 +1677,40 @@ mod tests { let mut store = SchemaStore::new(); let schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); let fp_enum = store.register(schema.clone()).unwrap(); - let Fingerprint::Rabin(fp_val) = fp_enum; - assert_eq!( - store.lookup(&Fingerprint::Rabin(fp_val)).cloned(), - Some(schema.clone()) - ); - assert!(store - .lookup(&Fingerprint::Rabin(fp_val.wrapping_add(1))) - .is_none()); + match fp_enum { + Fingerprint::Rabin(fp_val) => { + assert_eq!( + store.lookup(&Fingerprint::Rabin(fp_val)).cloned(), + Some(schema.clone()) + ); + assert!(store + .lookup(&Fingerprint::Rabin(fp_val.wrapping_add(1))) + .is_none()); + } + Fingerprint::Id(id) => { + unreachable!("This test should only generate Rabin fingerprints") + } + #[cfg(feature = "md5")] + 
Fingerprint::MD5(id) => { + unreachable!("This test should only generate Rabin fingerprints") + } + #[cfg(feature = "sha256")] + Fingerprint::SHA256(id) => { + unreachable!("This test should only generate Rabin fingerprints") + } + } + } + + #[test] + fn test_set_and_lookup_id() { + let mut store = SchemaStore::new(); + let schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let id = 42u32; + let fp = Fingerprint::Id(id); + let out_fp = store.set(fp, schema.clone()).unwrap(); + assert_eq!(out_fp, fp); + assert_eq!(store.lookup(&fp).cloned(), Some(schema.clone())); + assert!(store.lookup(&Fingerprint::Id(id.wrapping_add(1))).is_none()); } #[test] @@ -1037,10 +1724,43 @@ mod tests { assert_eq!(store.schemas.len(), 1); } + #[test] + fn test_set_and_lookup_with_provided_fingerprint() { + let mut store = SchemaStore::new(); + let schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let fp = schema.fingerprint().unwrap(); + let out_fp = store.set(fp, schema.clone()).unwrap(); + assert_eq!(out_fp, fp); + assert_eq!(store.lookup(&fp).cloned(), Some(schema)); + } + + #[test] + fn test_set_duplicate_same_schema_ok() { + let mut store = SchemaStore::new(); + let schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let fp = schema.fingerprint().unwrap(); + let _ = store.set(fp, schema.clone()).unwrap(); + let _ = store.set(fp, schema.clone()).unwrap(); + assert_eq!(store.schemas.len(), 1); + } + + #[test] + fn test_set_duplicate_different_schema_collision_error() { + let mut store = SchemaStore::new(); + let schema1 = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let schema2 = AvroSchema::new(serde_json::to_string(&record_schema()).unwrap()); + // Use the same Fingerprint::Id to simulate a collision across different schemas + let fp = Fingerprint::Id(123); + let _ = store.set(fp, schema1).unwrap(); + let err = store.set(fp, schema2).unwrap_err(); + let msg = format!("{err}"); + assert!(msg.contains("Schema fingerprint collision")); + } + #[test] fn test_canonical_form_generation_primitive() { let schema = int_schema(); - let canonical_form = generate_canonical_form(&schema).unwrap(); + let canonical_form = AvroSchema::generate_canonical_form(&schema).unwrap(); assert_eq!(canonical_form, r#""int""#); } @@ -1048,7 +1768,7 @@ mod tests { fn test_canonical_form_generation_record() { let schema = record_schema(); let expected_canonical_form = r#"{"name":"test.namespace.record1","type":"record","fields":[{"name":"field1","type":"int"},{"name":"field2","type":"string"}]}"#; - let canonical_form = generate_canonical_form(&schema).unwrap(); + let canonical_form = AvroSchema::generate_canonical_form(&schema).unwrap(); assert_eq!(canonical_form, expected_canonical_form); } @@ -1105,7 +1825,7 @@ mod tests { r#type: Schema::Type(Type { r#type: TypeName::Primitive(PrimitiveType::Bytes), attributes: Attributes { - logical_type: Some("decimal"), + logical_type: None, additional: HashMap::from([("precision", json!(4))]), }, }), @@ -1117,7 +1837,328 @@ mod tests { }, })); let expected_canonical_form = r#"{"name":"record_with_attrs","type":"record","fields":[{"name":"f1","type":"bytes"}]}"#; - let canonical_form = generate_canonical_form(&schema_with_attrs).unwrap(); + let canonical_form = AvroSchema::generate_canonical_form(&schema_with_attrs).unwrap(); assert_eq!(canonical_form, expected_canonical_form); } + + #[test] + fn test_primitive_mappings() { + let cases = vec![ + (DataType::Boolean, "\"boolean\""), + (DataType::Int8, 
"\"int\""), + (DataType::Int16, "\"int\""), + (DataType::Int32, "\"int\""), + (DataType::Int64, "\"long\""), + (DataType::UInt8, "\"int\""), + (DataType::UInt16, "\"int\""), + (DataType::UInt32, "\"long\""), + (DataType::UInt64, "\"long\""), + (DataType::Float16, "\"float\""), + (DataType::Float32, "\"float\""), + (DataType::Float64, "\"double\""), + (DataType::Utf8, "\"string\""), + (DataType::Binary, "\"bytes\""), + ]; + for (dt, avro_token) in cases { + let field = ArrowField::new("col", dt.clone(), false); + let arrow_schema = single_field_schema(field); + let avro = AvroSchema::try_from(&arrow_schema).unwrap(); + assert_json_contains(&avro.json_string, avro_token); + } + } + + #[test] + fn test_temporal_mappings() { + let cases = vec![ + (DataType::Date32, "\"logicalType\":\"date\""), + ( + DataType::Time32(TimeUnit::Millisecond), + "\"logicalType\":\"time-millis\"", + ), + ( + DataType::Time64(TimeUnit::Microsecond), + "\"logicalType\":\"time-micros\"", + ), + ( + DataType::Timestamp(TimeUnit::Millisecond, None), + "\"logicalType\":\"local-timestamp-millis\"", + ), + ( + DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())), + "\"logicalType\":\"timestamp-micros\"", + ), + ]; + for (dt, needle) in cases { + let field = ArrowField::new("ts", dt.clone(), true); + let arrow_schema = single_field_schema(field); + let avro = AvroSchema::try_from(&arrow_schema).unwrap(); + assert_json_contains(&avro.json_string, needle); + } + } + + #[test] + fn test_decimal_and_uuid() { + let decimal_field = ArrowField::new("amount", DataType::Decimal128(25, 2), false); + let dec_schema = single_field_schema(decimal_field); + let avro_dec = AvroSchema::try_from(&dec_schema).unwrap(); + assert_json_contains(&avro_dec.json_string, "\"logicalType\":\"decimal\""); + assert_json_contains(&avro_dec.json_string, "\"precision\":25"); + assert_json_contains(&avro_dec.json_string, "\"scale\":2"); + let mut md = HashMap::new(); + md.insert("logicalType".into(), "uuid".into()); + let uuid_field = + ArrowField::new("id", DataType::FixedSizeBinary(16), false).with_metadata(md); + let uuid_schema = single_field_schema(uuid_field); + let avro_uuid = AvroSchema::try_from(&uuid_schema).unwrap(); + assert_json_contains(&avro_uuid.json_string, "\"logicalType\":\"uuid\""); + } + + #[test] + fn test_interval_duration() { + let interval_field = ArrowField::new( + "span", + DataType::Interval(IntervalUnit::MonthDayNano), + false, + ); + let s = single_field_schema(interval_field); + let avro = AvroSchema::try_from(&s).unwrap(); + assert_json_contains(&avro.json_string, "\"logicalType\":\"duration\""); + assert_json_contains(&avro.json_string, "\"size\":12"); + let dur_field = ArrowField::new("latency", DataType::Duration(TimeUnit::Nanosecond), false); + let s2 = single_field_schema(dur_field); + let avro2 = AvroSchema::try_from(&s2).unwrap(); + assert_json_contains(&avro2.json_string, "\"arrowDurationUnit\""); + } + + #[test] + fn test_complex_types() { + let list_dt = DataType::List(Arc::new(ArrowField::new("item", DataType::Int32, true))); + let list_schema = single_field_schema(ArrowField::new("numbers", list_dt, false)); + let avro_list = AvroSchema::try_from(&list_schema).unwrap(); + assert_json_contains(&avro_list.json_string, "\"type\":\"array\""); + assert_json_contains(&avro_list.json_string, "\"items\""); + let value_field = ArrowField::new("value", DataType::Boolean, true); + let entries_struct = ArrowField::new( + "entries", + DataType::Struct(Fields::from(vec![ + ArrowField::new("key", DataType::Utf8, 
false), + value_field.clone(), + ])), + false, + ); + let map_dt = DataType::Map(Arc::new(entries_struct), false); + let map_schema = single_field_schema(ArrowField::new("props", map_dt, false)); + let avro_map = AvroSchema::try_from(&map_schema).unwrap(); + assert_json_contains(&avro_map.json_string, "\"type\":\"map\""); + assert_json_contains(&avro_map.json_string, "\"values\""); + let struct_dt = DataType::Struct(Fields::from(vec![ + ArrowField::new("f1", DataType::Int64, false), + ArrowField::new("f2", DataType::Utf8, true), + ])); + let struct_schema = single_field_schema(ArrowField::new("person", struct_dt, true)); + let avro_struct = AvroSchema::try_from(&struct_schema).unwrap(); + assert_json_contains(&avro_struct.json_string, "\"type\":\"record\""); + assert_json_contains(&avro_struct.json_string, "\"null\""); + } + + #[test] + fn test_enum_dictionary() { + let mut md = HashMap::new(); + md.insert( + AVRO_ENUM_SYMBOLS_METADATA_KEY.into(), + "[\"OPEN\",\"CLOSED\"]".into(), + ); + let enum_dt = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let field = ArrowField::new("status", enum_dt, false).with_metadata(md); + let schema = single_field_schema(field); + let avro = AvroSchema::try_from(&schema).unwrap(); + assert_json_contains(&avro.json_string, "\"type\":\"enum\""); + assert_json_contains(&avro.json_string, "\"symbols\":[\"OPEN\",\"CLOSED\"]"); + } + + #[test] + fn test_run_end_encoded() { + let ree_dt = DataType::RunEndEncoded( + Arc::new(ArrowField::new("run_ends", DataType::Int32, false)), + Arc::new(ArrowField::new("values", DataType::Utf8, false)), + ); + let s = single_field_schema(ArrowField::new("text", ree_dt, false)); + let avro = AvroSchema::try_from(&s).unwrap(); + assert_json_contains(&avro.json_string, "\"string\""); + } + + #[test] + fn test_dense_union_error() { + use arrow_schema::UnionFields; + let uf: UnionFields = vec![(0i8, Arc::new(ArrowField::new("a", DataType::Int32, false)))] + .into_iter() + .collect(); + let union_dt = DataType::Union(uf, arrow_schema::UnionMode::Dense); + let s = single_field_schema(ArrowField::new("u", union_dt, false)); + let err = AvroSchema::try_from(&s).unwrap_err(); + assert!(err + .to_string() + .contains("Arrow Union to Avro Union not yet supported")); + } + + #[test] + fn round_trip_primitive() { + let arrow_schema = ArrowSchema::new(vec![ArrowField::new("f1", DataType::Int32, false)]); + let avro_schema = AvroSchema::try_from(&arrow_schema).unwrap(); + let decoded = avro_schema.schema().unwrap(); + assert!(matches!(decoded, Schema::Complex(_))); + } + + #[test] + fn test_name_generator_sanitization_and_uniqueness() { + let f1 = ArrowField::new("weird-name", DataType::FixedSizeBinary(8), false); + let f2 = ArrowField::new("weird name", DataType::FixedSizeBinary(8), false); + let f3 = ArrowField::new("123bad", DataType::FixedSizeBinary(8), false); + let arrow_schema = ArrowSchema::new(vec![f1, f2, f3]); + let avro = AvroSchema::try_from(&arrow_schema).unwrap(); + assert_json_contains(&avro.json_string, "\"name\":\"weird_name\""); + assert_json_contains(&avro.json_string, "\"name\":\"weird_name_1\""); + assert_json_contains(&avro.json_string, "\"name\":\"_123bad\""); + } + + #[test] + fn test_date64_logical_type_mapping() { + let field = ArrowField::new("d", DataType::Date64, true); + let schema = single_field_schema(field); + let avro = AvroSchema::try_from(&schema).unwrap(); + assert_json_contains( + &avro.json_string, + "\"logicalType\":\"local-timestamp-millis\"", + ); + } + + #[test] + fn 
test_duration_list_extras_propagated() { + let child = ArrowField::new("lat", DataType::Duration(TimeUnit::Microsecond), false); + let list_dt = DataType::List(Arc::new(child)); + let arrow_schema = single_field_schema(ArrowField::new("durations", list_dt, false)); + let avro = AvroSchema::try_from(&arrow_schema).unwrap(); + assert_json_contains(&avro.json_string, "\"arrowDurationUnit\":\"microsecond\""); + } + + #[test] + fn test_interval_yearmonth_extra() { + let field = ArrowField::new("iv", DataType::Interval(IntervalUnit::YearMonth), false); + let schema = single_field_schema(field); + let avro = AvroSchema::try_from(&schema).unwrap(); + assert_json_contains(&avro.json_string, "\"arrowIntervalUnit\":\"yearmonth\""); + } + + #[test] + fn test_interval_daytime_extra() { + let field = ArrowField::new("iv_dt", DataType::Interval(IntervalUnit::DayTime), false); + let schema = single_field_schema(field); + let avro = AvroSchema::try_from(&schema).unwrap(); + assert_json_contains(&avro.json_string, "\"arrowIntervalUnit\":\"daytime\""); + } + + #[test] + fn test_fixed_size_list_extra() { + let child = ArrowField::new("item", DataType::Int32, false); + let dt = DataType::FixedSizeList(Arc::new(child), 3); + let schema = single_field_schema(ArrowField::new("triples", dt, false)); + let avro = AvroSchema::try_from(&schema).unwrap(); + assert_json_contains(&avro.json_string, "\"arrowFixedSize\":3"); + } + + #[test] + fn test_map_duration_value_extra() { + let val_field = ArrowField::new("value", DataType::Duration(TimeUnit::Second), true); + let entries_struct = ArrowField::new( + "entries", + DataType::Struct(Fields::from(vec![ + ArrowField::new("key", DataType::Utf8, false), + val_field, + ])), + false, + ); + let map_dt = DataType::Map(Arc::new(entries_struct), false); + let schema = single_field_schema(ArrowField::new("metrics", map_dt, false)); + let avro = AvroSchema::try_from(&schema).unwrap(); + assert_json_contains(&avro.json_string, "\"arrowDurationUnit\":\"second\""); + } + + #[test] + fn test_schema_with_non_string_defaults_decodes_successfully() { + let schema_json = r#"{ + "type": "record", + "name": "R", + "fields": [ + {"name": "a", "type": "int", "default": 0}, + {"name": "b", "type": {"type": "array", "items": "long"}, "default": [1, 2, 3]}, + {"name": "c", "type": {"type": "map", "values": "double"}, "default": {"x": 1.5, "y": 2.5}}, + {"name": "inner", "type": {"type": "record", "name": "Inner", "fields": [ + {"name": "flag", "type": "boolean", "default": true}, + {"name": "name", "type": "string", "default": "hi"} + ]}, "default": {"flag": false, "name": "d"}}, + {"name": "u", "type": ["int", "null"], "default": 42} + ] + }"#; + + let schema: Schema = serde_json::from_str(schema_json).expect("schema should parse"); + match &schema { + Schema::Complex(ComplexType::Record(_)) => {} + other => panic!("expected record schema, got: {:?}", other), + } + // Avro to Arrow conversion + let field = crate::codec::AvroField::try_from(&schema) + .expect("Avro->Arrow conversion should succeed"); + let arrow_field = field.field(); + + // Build expected Arrow field + let expected_list_item = ArrowField::new( + arrow_schema::Field::LIST_FIELD_DEFAULT_NAME, + DataType::Int64, + false, + ); + let expected_b = ArrowField::new("b", DataType::List(Arc::new(expected_list_item)), false); + + let expected_map_value = ArrowField::new("value", DataType::Float64, false); + let expected_entries = ArrowField::new( + "entries", + DataType::Struct(Fields::from(vec![ + ArrowField::new("key", DataType::Utf8, 
false), + expected_map_value, + ])), + false, + ); + let expected_c = + ArrowField::new("c", DataType::Map(Arc::new(expected_entries), false), false); + + let expected_inner = ArrowField::new( + "inner", + DataType::Struct(Fields::from(vec![ + ArrowField::new("flag", DataType::Boolean, false), + ArrowField::new("name", DataType::Utf8, false), + ])), + false, + ); + + let expected = ArrowField::new( + "R", + DataType::Struct(Fields::from(vec![ + ArrowField::new("a", DataType::Int32, false), + expected_b, + expected_c, + expected_inner, + ArrowField::new("u", DataType::Int32, true), + ])), + false, + ); + + assert_eq!(arrow_field, expected); + } + + #[test] + fn default_order_is_consistent() { + let arrow_schema = ArrowSchema::new(vec![ArrowField::new("s", DataType::Utf8, true)]); + let a = AvroSchema::try_from(&arrow_schema).unwrap().json_string; + let b = AvroSchema::from_arrow_with_options(&arrow_schema, None); + assert_eq!(a, b.unwrap().json_string); + } } diff --git a/arrow-avro/src/writer/encoder.rs b/arrow-avro/src/writer/encoder.rs new file mode 100644 index 000000000000..fd619249617e --- /dev/null +++ b/arrow-avro/src/writer/encoder.rs @@ -0,0 +1,2002 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Avro Encoder for Arrow types. + +use crate::codec::{AvroDataType, AvroField, Codec}; +use crate::schema::Nullability; +use arrow_array::cast::AsArray; +use arrow_array::types::{ + ArrowPrimitiveType, Float32Type, Float64Type, Int32Type, Int64Type, IntervalDayTimeType, + IntervalMonthDayNanoType, IntervalYearMonthType, TimestampMicrosecondType, +}; +use arrow_array::{ + Array, Decimal128Array, Decimal256Array, DictionaryArray, FixedSizeBinaryArray, + GenericBinaryArray, GenericListArray, GenericStringArray, LargeListArray, ListArray, MapArray, + OffsetSizeTrait, PrimitiveArray, RecordBatch, StringArray, StructArray, +}; +#[cfg(feature = "small_decimals")] +use arrow_array::{Decimal32Array, Decimal64Array}; +use arrow_buffer::NullBuffer; +use arrow_schema::{ArrowError, DataType, Field, IntervalUnit, Schema as ArrowSchema, TimeUnit}; +use std::io::Write; +use std::sync::Arc; +use uuid::Uuid; + +/// Encode a single Avro-`long` using ZigZag + variable length, buffered. 
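+///
+/// For example, 0 encodes as `0x00`, -1 as `0x01`, 1 as `0x02`, and 64 as `0x80 0x01`.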
+/// +/// Spec: +#[inline] +pub fn write_long(out: &mut W, value: i64) -> Result<(), ArrowError> { + let mut zz = ((value << 1) ^ (value >> 63)) as u64; + // At most 10 bytes for 64-bit varint + let mut buf = [0u8; 10]; + let mut i = 0; + while (zz & !0x7F) != 0 { + buf[i] = ((zz & 0x7F) as u8) | 0x80; + i += 1; + zz >>= 7; + } + buf[i] = (zz & 0x7F) as u8; + i += 1; + out.write_all(&buf[..i]) + .map_err(|e| ArrowError::IoError(format!("write long: {e}"), e)) +} + +#[inline] +fn write_int(out: &mut W, value: i32) -> Result<(), ArrowError> { + write_long(out, value as i64) +} + +#[inline] +fn write_len_prefixed(out: &mut W, bytes: &[u8]) -> Result<(), ArrowError> { + write_long(out, bytes.len() as i64)?; + out.write_all(bytes) + .map_err(|e| ArrowError::IoError(format!("write bytes: {e}"), e)) +} + +#[inline] +fn write_bool(out: &mut W, v: bool) -> Result<(), ArrowError> { + out.write_all(&[if v { 1 } else { 0 }]) + .map_err(|e| ArrowError::IoError(format!("write bool: {e}"), e)) +} + +/// Minimal two's-complement big-endian representation helper for Avro decimal (bytes). +/// +/// For positive numbers, trim leading 0x00 until an essential byte is reached. +/// For negative numbers, trim leading 0xFF until an essential byte is reached. +/// The resulting slice still encodes the same signed value. +/// +/// See Avro spec: decimal over `bytes` uses two's-complement big-endian +/// representation of the unscaled integer value. 1.11.1 specification. +#[inline] +fn minimal_twos_complement(be: &[u8]) -> &[u8] { + if be.is_empty() { + return be; + } + let sign_byte = if (be[0] & 0x80) != 0 { 0xFF } else { 0x00 }; + let mut k = 0usize; + while k < be.len() && be[k] == sign_byte { + k += 1; + } + if k == 0 { + return be; + } + if k == be.len() { + return &be[be.len() - 1..]; + } + let drop = if ((be[k] ^ sign_byte) & 0x80) == 0 { + k + } else { + k - 1 + }; + &be[drop..] +} + +/// Sign-extend (or validate/truncate) big-endian integer bytes to exactly `n` bytes. +/// +/// +/// - If shorter than `n`, the slice is sign-extended by left-padding with the +/// sign byte (`0x00` for positive, `0xFF` for negative). +/// - If longer than `n`, the slice is truncated from the left. An overflow error +/// is returned if any of the truncated bytes are not redundant sign bytes, +/// or if the resulting value's sign bit would differ from the original. +/// - If the slice is already `n` bytes long, it is copied. +/// +/// Used for encoding Avro decimal values into `fixed(N)` fields. +#[inline] +fn write_sign_extended( + out: &mut W, + src_be: &[u8], + n: usize, +) -> Result<(), ArrowError> { + let len = src_be.len(); + if len == n { + return out + .write_all(src_be) + .map_err(|e| ArrowError::IoError(format!("write decimal fixed: {e}"), e)); + } + let sign_byte = if len > 0 && (src_be[0] & 0x80) != 0 { + 0xFF + } else { + 0x00 + }; + if len > n { + let extra = len - n; + if n == 0 && src_be.iter().all(|&b| b == sign_byte) { + return Ok(()); + } + // All truncated bytes must equal the sign byte, and the MSB of the first + // retained byte must match the sign (otherwise overflow). 
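+        // For example, truncating [0x00, 0x7F] (127) to one byte keeps 0x7F, whereas
+        // truncating [0x01, 0x2C] (300) to one byte fails: 0x01 is not a redundant sign byte.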
+ if src_be[..extra].iter().any(|&b| b != sign_byte) + || ((src_be[extra] ^ sign_byte) & 0x80) != 0 + { + return Err(ArrowError::InvalidArgumentError(format!( + "Decimal value with {len} bytes cannot be represented in {n} bytes without overflow", + ))); + } + return out + .write_all(&src_be[extra..]) + .map_err(|e| ArrowError::IoError(format!("write decimal fixed: {e}"), e)); + } + // len < n: prepend sign bytes (sign extension) then the payload + let pad_len = n - len; + // Fixed-size stack pads to avoid heap allocation on the hot path + const ZPAD: [u8; 64] = [0x00; 64]; + const FPAD: [u8; 64] = [0xFF; 64]; + let pad = if sign_byte == 0x00 { + &ZPAD[..] + } else { + &FPAD[..] + }; + // Emit padding in 64‑byte chunks (minimizes write calls without allocating), + // then write the original bytes. + let mut rem = pad_len; + while rem >= pad.len() { + out.write_all(pad) + .map_err(|e| ArrowError::IoError(format!("write decimal fixed: {e}"), e))?; + rem -= pad.len(); + } + if rem > 0 { + out.write_all(&pad[..rem]) + .map_err(|e| ArrowError::IoError(format!("write decimal fixed: {e}"), e))?; + } + out.write_all(src_be) + .map_err(|e| ArrowError::IoError(format!("write decimal fixed: {e}"), e)) +} + +/// Write the union branch index for an optional field. +/// +/// Branch index is 0-based per Avro unions: +/// - Null-first (default): null => 0, value => 1 +/// - Null-second (Impala): value => 0, null => 1 +fn write_optional_index( + out: &mut W, + is_null: bool, + null_order: Nullability, +) -> Result<(), ArrowError> { + let byte = union_value_branch_byte(null_order, is_null); + out.write_all(&[byte]) + .map_err(|e| ArrowError::IoError(format!("write union branch: {e}"), e)) +} + +#[derive(Debug, Clone)] +enum NullState { + NonNullable, + NullableNoNulls { + union_value_byte: u8, + }, + Nullable { + nulls: NullBuffer, + null_order: Nullability, + }, +} + +/// Arrow to Avro FieldEncoder: +/// - Holds the inner `Encoder` (by value) +/// - Carries the per-site nullability **state** as a single enum that enforces invariants +pub struct FieldEncoder<'a> { + encoder: Encoder<'a>, + null_state: NullState, +} + +impl<'a> FieldEncoder<'a> { + fn make_encoder( + array: &'a dyn Array, + field: &Field, + plan: &FieldPlan, + nullability: Option, + ) -> Result { + let encoder = match plan { + FieldPlan::Scalar => match array.data_type() { + DataType::Boolean => Encoder::Boolean(BooleanEncoder(array.as_boolean())), + DataType::Utf8 => { + Encoder::Utf8(Utf8GenericEncoder::(array.as_string::())) + } + DataType::LargeUtf8 => { + Encoder::Utf8Large(Utf8GenericEncoder::(array.as_string::())) + } + DataType::Int32 => Encoder::Int(IntEncoder(array.as_primitive::())), + DataType::Int64 => Encoder::Long(LongEncoder(array.as_primitive::())), + DataType::Float32 => { + Encoder::Float32(F32Encoder(array.as_primitive::())) + } + DataType::Float64 => { + Encoder::Float64(F64Encoder(array.as_primitive::())) + } + DataType::Binary => Encoder::Binary(BinaryEncoder(array.as_binary::())), + DataType::LargeBinary => { + Encoder::LargeBinary(BinaryEncoder(array.as_binary::())) + } + DataType::FixedSizeBinary(len) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::SchemaError("Expected FixedSizeBinaryArray".into()) + })?; + Encoder::Fixed(FixedEncoder(arr)) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => Encoder::Timestamp(LongEncoder( + array.as_primitive::(), + )), + DataType::Interval(unit) => match unit { + IntervalUnit::MonthDayNano => { + 
Encoder::IntervalMonthDayNano(DurationEncoder( + array.as_primitive::(), + )) + } + IntervalUnit::YearMonth => { + Encoder::IntervalYearMonth(DurationEncoder( + array.as_primitive::(), + )) + } + IntervalUnit::DayTime => Encoder::IntervalDayTime(DurationEncoder( + array.as_primitive::(), + )), + } + DataType::Duration(_) => { + return Err(ArrowError::NotYetImplemented( + "Avro writer: Arrow Duration(TimeUnit) has no standard Avro mapping; cast to Interval(MonthDayNano) to use Avro 'duration'".into(), + )); + } + other => { + return Err(ArrowError::NotYetImplemented(format!( + "Avro scalar type not yet supported: {other:?}" + ))); + } + }, + FieldPlan::Struct { encoders } => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected StructArray".into()))?; + Encoder::Struct(Box::new(StructEncoder::try_new(arr, encoders)?)) + } + FieldPlan::List { + items_nullability, + item_plan, + } => match array.data_type() { + DataType::List(_) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected ListArray".into()))?; + Encoder::List(Box::new(ListEncoder32::try_new( + arr, + *items_nullability, + item_plan.as_ref(), + )?)) + } + DataType::LargeList(_) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected LargeListArray".into()))?; + Encoder::LargeList(Box::new(ListEncoder64::try_new( + arr, + *items_nullability, + item_plan.as_ref(), + )?)) + } + other => { + return Err(ArrowError::SchemaError(format!( + "Avro array site requires Arrow List/LargeList, found: {other:?}" + ))) + } + }, + FieldPlan::Decimal {size} => match array.data_type() { + #[cfg(feature = "small_decimals")] + DataType::Decimal32(_,_) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected Decimal32Array".into()))?; + Encoder::Decimal32(DecimalEncoder::<4, Decimal32Array>::new(arr, *size)) + } + #[cfg(feature = "small_decimals")] + DataType::Decimal64(_,_) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected Decimal64Array".into()))?; + Encoder::Decimal64(DecimalEncoder::<8, Decimal64Array>::new(arr, *size)) + } + DataType::Decimal128(_,_) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected Decimal128Array".into()))?; + Encoder::Decimal128(DecimalEncoder::<16, Decimal128Array>::new(arr, *size)) + } + DataType::Decimal256(_,_) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected Decimal256Array".into()))?; + Encoder::Decimal256(DecimalEncoder::<32, Decimal256Array>::new(arr, *size)) + } + other => { + return Err(ArrowError::SchemaError(format!( + "Avro decimal site requires Arrow Decimal 32, 64, 128, or 256, found: {other:?}" + ))) + } + }, + FieldPlan::Uuid => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected FixedSizeBinaryArray".into()))?; + Encoder::Uuid(UuidEncoder(arr)) + } + FieldPlan::Map { values_nullability, + value_plan } => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected MapArray".into()))?; + Encoder::Map(Box::new(MapEncoder::try_new(arr, *values_nullability, value_plan.as_ref())?)) + } + FieldPlan::Enum { symbols} => match array.data_type() { + DataType::Dictionary(key_dt, value_dt) => { + if **key_dt != DataType::Int32 || **value_dt != DataType::Utf8 { + return 
Err(ArrowError::SchemaError( + "Avro enum requires Dictionary".into(), + )); + } + let dict = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::SchemaError("Expected DictionaryArray".into()) + })?; + + let values = dict + .values() + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::SchemaError("Dictionary values must be Utf8".into()) + })?; + if values.len() != symbols.len() { + return Err(ArrowError::SchemaError(format!( + "Enum symbol length {} != dictionary size {}", + symbols.len(), + values.len() + ))); + } + for i in 0..values.len() { + if values.value(i) != symbols[i].as_str() { + return Err(ArrowError::SchemaError(format!( + "Enum symbol mismatch at {i}: schema='{}' dict='{}'", + symbols[i], + values.value(i) + ))); + } + } + let keys = dict.keys(); + Encoder::Enum(EnumEncoder { keys }) + } + other => { + return Err(ArrowError::SchemaError(format!( + "Avro enum site requires DataType::Dictionary, found: {other:?}" + ))) + } + } + other => { + return Err(ArrowError::NotYetImplemented(format!( + "Avro writer: {other:?} not yet supported", + ))); + } + }; + // Compute the effective null state from writer-declared nullability and data nulls. + let null_state = match (nullability, array.null_count() > 0) { + (None, false) => NullState::NonNullable, + (None, true) => { + return Err(ArrowError::InvalidArgumentError(format!( + "Avro site '{}' is non-nullable, but array contains nulls", + field.name() + ))); + } + (Some(order), false) => { + // Optimization: drop any bitmap; emit a constant "value" branch byte. + NullState::NullableNoNulls { + union_value_byte: union_value_branch_byte(order, false), + } + } + (Some(null_order), true) => { + let Some(nulls) = array.nulls().cloned() else { + return Err(ArrowError::InvalidArgumentError(format!( + "Array for Avro site '{}' reports nulls but has no null buffer", + field.name() + ))); + }; + NullState::Nullable { nulls, null_order } + } + }; + Ok(Self { + encoder, + null_state, + }) + } + + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + match &self.null_state { + NullState::NonNullable => {} + NullState::NullableNoNulls { union_value_byte } => out + .write_all(&[*union_value_byte]) + .map_err(|e| ArrowError::IoError(format!("write union value branch: {e}"), e))?, + NullState::Nullable { nulls, null_order } if nulls.is_null(idx) => { + return write_optional_index(out, true, *null_order); // no value to write + } + NullState::Nullable { null_order, .. } => { + write_optional_index(out, false, *null_order)?; + } + } + self.encoder.encode(out, idx) + } +} + +fn union_value_branch_byte(null_order: Nullability, is_null: bool) -> u8 { + let nulls_first = null_order == Nullability::default(); + if nulls_first == is_null { + 0x00 + } else { + 0x02 + } +} + +/// Per‑site encoder plan for a field. This mirrors the Avro structure, so nested +/// optional branch order can be honored exactly as declared by the schema. +#[derive(Debug, Clone)] +enum FieldPlan { + /// Non-nested scalar/logical type + Scalar, + /// Record/Struct with Avro‑ordered children + Struct { encoders: Vec }, + /// Array with item‑site nullability and nested plan + List { + items_nullability: Option, + item_plan: Box, + }, + /// Avro decimal logical type (bytes or fixed). 
`size=None` => bytes(decimal), `Some(n)` => fixed(n) + Decimal { size: Option }, + /// Avro UUID logical type (fixed) + Uuid, + /// Avro map with value‑site nullability and nested plan + Map { + values_nullability: Option, + value_plan: Box, + }, + /// Avro enum; maps to Arrow Dictionary with dictionary values + /// exactly equal and ordered as the Avro enum `symbols`. + Enum { symbols: Arc<[String]> }, +} + +#[derive(Debug, Clone)] +struct FieldBinding { + /// Index of the Arrow field/column associated with this Avro field site + arrow_index: usize, + /// Nullability/order for this site (None for required fields) + nullability: Option, + /// Nested plan for this site + plan: FieldPlan, +} + +/// Builder for `RecordEncoder` write plan +#[derive(Debug)] +pub struct RecordEncoderBuilder<'a> { + avro_root: &'a AvroField, + arrow_schema: &'a ArrowSchema, +} + +impl<'a> RecordEncoderBuilder<'a> { + /// Create a new builder from the Avro root and Arrow schema. + pub fn new(avro_root: &'a AvroField, arrow_schema: &'a ArrowSchema) -> Self { + Self { + avro_root, + arrow_schema, + } + } + + /// Build the `RecordEncoder` by walking the Avro **record** root in Avro order, + /// resolving each field to an Arrow index by name. + pub fn build(self) -> Result { + let avro_root_dt = self.avro_root.data_type(); + let Codec::Struct(root_fields) = avro_root_dt.codec() else { + return Err(ArrowError::SchemaError( + "Top-level Avro schema must be a record/struct".into(), + )); + }; + let mut columns = Vec::with_capacity(root_fields.len()); + for root_field in root_fields.as_ref() { + let name = root_field.name(); + let arrow_index = self.arrow_schema.index_of(name).map_err(|e| { + ArrowError::SchemaError(format!("Schema mismatch for field '{name}': {e}")) + })?; + columns.push(FieldBinding { + arrow_index, + nullability: root_field.data_type().nullability(), + plan: FieldPlan::build( + root_field.data_type(), + self.arrow_schema.field(arrow_index), + )?, + }); + } + Ok(RecordEncoder { columns }) + } +} + +/// A pre-computed plan for encoding a `RecordBatch` to Avro. +/// +/// Derived from an Avro schema and an Arrow schema. It maps +/// top-level Avro fields to Arrow columns and contains a nested encoding plan +/// for each column. +#[derive(Debug, Clone)] +pub struct RecordEncoder { + columns: Vec, +} + +impl RecordEncoder { + fn prepare_for_batch<'a>( + &'a self, + batch: &'a RecordBatch, + ) -> Result>, ArrowError> { + let schema_binding = batch.schema(); + let fields = schema_binding.fields(); + let arrays = batch.columns(); + let mut out = Vec::with_capacity(self.columns.len()); + for col_plan in self.columns.iter() { + let arrow_index = col_plan.arrow_index; + let array = arrays.get(arrow_index).ok_or_else(|| { + ArrowError::SchemaError(format!("Column index {arrow_index} out of range")) + })?; + let field = fields[arrow_index].as_ref(); + let encoder = prepare_value_site_encoder( + array.as_ref(), + field, + col_plan.nullability, + &col_plan.plan, + )?; + out.push(encoder); + } + Ok(out) + } + + /// Encode a `RecordBatch` using this encoder plan. + /// + /// Tip: Wrap `out` in a `std::io::BufWriter` to reduce the overhead of many small writes. 
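+    ///
+    /// A sketch of the intended call pattern (`avro_root`, `batch`, and `file` are
+    /// placeholders supplied by the caller; not compiled here):
+    ///
+    /// ```ignore
+    /// let plan = RecordEncoderBuilder::new(&avro_root, batch.schema().as_ref()).build()?;
+    /// let mut out = std::io::BufWriter::new(file);
+    /// plan.encode(&mut out, &batch)?;
+    /// out.flush()?;
+    /// ```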
+ pub fn encode(&self, out: &mut W, batch: &RecordBatch) -> Result<(), ArrowError> { + let mut column_encoders = self.prepare_for_batch(batch)?; + for row in 0..batch.num_rows() { + for encoder in column_encoders.iter_mut() { + encoder.encode(out, row)?; + } + } + Ok(()) + } +} + +fn find_struct_child_index(fields: &arrow_schema::Fields, name: &str) -> Option { + fields.iter().position(|f| f.name() == name) +} + +fn find_map_value_field_index(fields: &arrow_schema::Fields) -> Option { + // Prefer common Arrow field names; fall back to second child if exactly two + find_struct_child_index(fields, "value") + .or_else(|| find_struct_child_index(fields, "values")) + .or_else(|| if fields.len() == 2 { Some(1) } else { None }) +} + +impl FieldPlan { + fn build(avro_dt: &AvroDataType, arrow_field: &Field) -> Result { + if let DataType::FixedSizeBinary(len) = arrow_field.data_type() { + // Extension-based detection (only when the feature is enabled) + let ext_is_uuid = { + #[cfg(feature = "canonical_extension_types")] + { + matches!( + arrow_field.extension_type_name(), + Some("arrow.uuid") | Some("uuid") + ) + } + #[cfg(not(feature = "canonical_extension_types"))] + { + false + } + }; + let md_is_uuid = arrow_field + .metadata() + .get("logicalType") + .map(|s| s.as_str()) + == Some("uuid"); + if ext_is_uuid || md_is_uuid { + if *len != 16 { + return Err(ArrowError::InvalidArgumentError( + "logicalType=uuid requires FixedSizeBinary(16)".into(), + )); + } + return Ok(FieldPlan::Uuid); + } + } + match avro_dt.codec() { + Codec::Struct(avro_fields) => { + let fields = match arrow_field.data_type() { + DataType::Struct(struct_fields) => struct_fields, + other => { + return Err(ArrowError::SchemaError(format!( + "Avro struct maps to Arrow Struct, found: {other:?}" + ))) + } + }; + let mut encoders = Vec::with_capacity(avro_fields.len()); + for avro_field in avro_fields.iter() { + let name = avro_field.name().to_string(); + let idx = find_struct_child_index(fields, &name).ok_or_else(|| { + ArrowError::SchemaError(format!( + "Struct field '{name}' not present in Arrow field '{}'", + arrow_field.name() + )) + })?; + encoders.push(FieldBinding { + arrow_index: idx, + nullability: avro_field.data_type().nullability(), + plan: FieldPlan::build(avro_field.data_type(), fields[idx].as_ref())?, + }); + } + Ok(FieldPlan::Struct { encoders }) + } + Codec::List(items_dt) => match arrow_field.data_type() { + DataType::List(field_ref) => Ok(FieldPlan::List { + items_nullability: items_dt.nullability(), + item_plan: Box::new(FieldPlan::build(items_dt.as_ref(), field_ref.as_ref())?), + }), + DataType::LargeList(field_ref) => Ok(FieldPlan::List { + items_nullability: items_dt.nullability(), + item_plan: Box::new(FieldPlan::build(items_dt.as_ref(), field_ref.as_ref())?), + }), + other => Err(ArrowError::SchemaError(format!( + "Avro array maps to Arrow List/LargeList, found: {other:?}" + ))), + }, + Codec::Map(values_dt) => { + let entries_field = match arrow_field.data_type() { + DataType::Map(entries, _sorted) => entries.as_ref(), + other => { + return Err(ArrowError::SchemaError(format!( + "Avro map maps to Arrow DataType::Map, found: {other:?}" + ))) + } + }; + let entries_struct_fields = match entries_field.data_type() { + DataType::Struct(fs) => fs, + other => { + return Err(ArrowError::SchemaError(format!( + "Arrow Map entries must be Struct, found: {other:?}" + ))) + } + }; + let value_idx = + find_map_value_field_index(entries_struct_fields).ok_or_else(|| { + ArrowError::SchemaError("Map entries struct missing 
value field".into()) + })?; + let value_field = entries_struct_fields[value_idx].as_ref(); + let value_plan = FieldPlan::build(values_dt.as_ref(), value_field)?; + Ok(FieldPlan::Map { + values_nullability: values_dt.nullability(), + value_plan: Box::new(value_plan), + }) + } + Codec::Enum(symbols) => match arrow_field.data_type() { + DataType::Dictionary(key_dt, value_dt) => { + if **key_dt != DataType::Int32 { + return Err(ArrowError::SchemaError( + "Avro enum requires Dictionary".into(), + )); + } + if **value_dt != DataType::Utf8 { + return Err(ArrowError::SchemaError( + "Avro enum requires Dictionary".into(), + )); + } + Ok(FieldPlan::Enum { + symbols: symbols.clone(), + }) + } + other => Err(ArrowError::SchemaError(format!( + "Avro enum maps to Arrow Dictionary, found: {other:?}" + ))), + }, + // decimal site (bytes or fixed(N)) with precision/scale validation + Codec::Decimal(precision, scale_opt, fixed_size_opt) => { + let (ap, as_) = match arrow_field.data_type() { + #[cfg(feature = "small_decimals")] + DataType::Decimal32(p, s) => (*p as usize, *s as i32), + #[cfg(feature = "small_decimals")] + DataType::Decimal64(p, s) => (*p as usize, *s as i32), + DataType::Decimal128(p, s) => (*p as usize, *s as i32), + DataType::Decimal256(p, s) => (*p as usize, *s as i32), + other => { + return Err(ArrowError::SchemaError(format!( + "Avro decimal requires Arrow decimal, got {other:?} for field '{}'", + arrow_field.name() + ))) + } + }; + let sc = scale_opt.unwrap_or(0) as i32; // Avro scale defaults to 0 if absent + if ap != *precision || as_ != sc { + return Err(ArrowError::SchemaError(format!( + "Decimal precision/scale mismatch for field '{}': Avro({precision},{sc}) vs Arrow({ap},{as_})", + arrow_field.name() + ))); + } + Ok(FieldPlan::Decimal { + size: *fixed_size_opt, + }) + } + Codec::Interval => match arrow_field.data_type() { + DataType::Interval(IntervalUnit::MonthDayNano | IntervalUnit::YearMonth | IntervalUnit::DayTime + ) => Ok(FieldPlan::Scalar), + other => Err(ArrowError::SchemaError(format!( + "Avro duration logical type requires Arrow Interval(MonthDayNano), found: {other:?}" + ))), + } + _ => Ok(FieldPlan::Scalar), + } + } +} + +enum Encoder<'a> { + Boolean(BooleanEncoder<'a>), + Int(IntEncoder<'a, Int32Type>), + Long(LongEncoder<'a, Int64Type>), + Timestamp(LongEncoder<'a, TimestampMicrosecondType>), + Float32(F32Encoder<'a>), + Float64(F64Encoder<'a>), + Binary(BinaryEncoder<'a, i32>), + LargeBinary(BinaryEncoder<'a, i64>), + Utf8(Utf8Encoder<'a>), + Utf8Large(Utf8LargeEncoder<'a>), + List(Box>), + LargeList(Box>), + Struct(Box>), + /// Avro `fixed` encoder (raw bytes, no length) + Fixed(FixedEncoder<'a>), + /// Avro `uuid` logical type encoder (string with RFC‑4122 hyphenated text) + Uuid(UuidEncoder<'a>), + /// Avro `duration` logical type (Arrow Interval(MonthDayNano)) encoder + IntervalMonthDayNano(DurationEncoder<'a, IntervalMonthDayNanoType>), + /// Avro `duration` logical type (Arrow Interval(YearMonth)) encoder + IntervalYearMonth(DurationEncoder<'a, IntervalYearMonthType>), + /// Avro `duration` logical type (Arrow Interval(DayTime)) encoder + IntervalDayTime(DurationEncoder<'a, IntervalDayTimeType>), + #[cfg(feature = "small_decimals")] + Decimal32(Decimal32Encoder<'a>), + #[cfg(feature = "small_decimals")] + Decimal64(Decimal64Encoder<'a>), + Decimal128(Decimal128Encoder<'a>), + Decimal256(Decimal256Encoder<'a>), + /// Avro `enum` encoder: writes the key (int) as the enum index. 
+ Enum(EnumEncoder<'a>), + Map(Box>), +} + +impl<'a> Encoder<'a> { + /// Encode the value at `idx`. + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + match self { + Encoder::Boolean(e) => e.encode(out, idx), + Encoder::Int(e) => e.encode(out, idx), + Encoder::Long(e) => e.encode(out, idx), + Encoder::Timestamp(e) => e.encode(out, idx), + Encoder::Float32(e) => e.encode(out, idx), + Encoder::Float64(e) => e.encode(out, idx), + Encoder::Binary(e) => e.encode(out, idx), + Encoder::LargeBinary(e) => e.encode(out, idx), + Encoder::Utf8(e) => e.encode(out, idx), + Encoder::Utf8Large(e) => e.encode(out, idx), + Encoder::List(e) => e.encode(out, idx), + Encoder::LargeList(e) => e.encode(out, idx), + Encoder::Struct(e) => e.encode(out, idx), + Encoder::Fixed(e) => (e).encode(out, idx), + Encoder::Uuid(e) => (e).encode(out, idx), + Encoder::IntervalMonthDayNano(e) => (e).encode(out, idx), + Encoder::IntervalYearMonth(e) => (e).encode(out, idx), + Encoder::IntervalDayTime(e) => (e).encode(out, idx), + #[cfg(feature = "small_decimals")] + Encoder::Decimal32(e) => (e).encode(out, idx), + #[cfg(feature = "small_decimals")] + Encoder::Decimal64(e) => (e).encode(out, idx), + Encoder::Decimal128(e) => (e).encode(out, idx), + Encoder::Decimal256(e) => (e).encode(out, idx), + Encoder::Map(e) => (e).encode(out, idx), + Encoder::Enum(e) => (e).encode(out, idx), + } + } +} + +struct BooleanEncoder<'a>(&'a arrow_array::BooleanArray); +impl BooleanEncoder<'_> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + write_bool(out, self.0.value(idx)) + } +} + +/// Generic Avro `int` encoder for primitive arrays with `i32` native values. +struct IntEncoder<'a, P: ArrowPrimitiveType>(&'a PrimitiveArray

); +impl<'a, P: ArrowPrimitiveType> IntEncoder<'a, P> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + write_int(out, self.0.value(idx)) + } +} + +/// Generic Avro `long` encoder for primitive arrays with `i64` native values. +struct LongEncoder<'a, P: ArrowPrimitiveType>(&'a PrimitiveArray
<P>
); +impl<'a, P: ArrowPrimitiveType> LongEncoder<'a, P> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + write_long(out, self.0.value(idx)) + } +} + +/// Unified binary encoder generic over offset size (i32/i64). +struct BinaryEncoder<'a, O: OffsetSizeTrait>(&'a GenericBinaryArray); +impl<'a, O: OffsetSizeTrait> BinaryEncoder<'a, O> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + write_len_prefixed(out, self.0.value(idx)) + } +} + +struct F32Encoder<'a>(&'a arrow_array::Float32Array); +impl F32Encoder<'_> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + // Avro float: 4 bytes, IEEE-754 little-endian + let bits = self.0.value(idx).to_bits(); + out.write_all(&bits.to_le_bytes()) + .map_err(|e| ArrowError::IoError(format!("write f32: {e}"), e)) + } +} + +struct F64Encoder<'a>(&'a arrow_array::Float64Array); +impl F64Encoder<'_> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + // Avro double: 8 bytes, IEEE-754 little-endian + let bits = self.0.value(idx).to_bits(); + out.write_all(&bits.to_le_bytes()) + .map_err(|e| ArrowError::IoError(format!("write f64: {e}"), e)) + } +} + +struct Utf8GenericEncoder<'a, O: OffsetSizeTrait>(&'a GenericStringArray); + +impl<'a, O: OffsetSizeTrait> Utf8GenericEncoder<'a, O> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + write_len_prefixed(out, self.0.value(idx).as_bytes()) + } +} + +type Utf8Encoder<'a> = Utf8GenericEncoder<'a, i32>; +type Utf8LargeEncoder<'a> = Utf8GenericEncoder<'a, i64>; + +/// Internal key array kind used by Map encoder. +enum KeyKind<'a> { + Utf8(&'a GenericStringArray), + LargeUtf8(&'a GenericStringArray), +} +struct MapEncoder<'a> { + map: &'a MapArray, + keys: KeyKind<'a>, + values: FieldEncoder<'a>, + keys_offset: usize, + values_offset: usize, +} + +impl<'a> MapEncoder<'a> { + fn try_new( + map: &'a MapArray, + values_nullability: Option, + value_plan: &FieldPlan, + ) -> Result { + let keys_arr = map.keys(); + let keys_kind = match keys_arr.data_type() { + DataType::Utf8 => KeyKind::Utf8(keys_arr.as_string::()), + DataType::LargeUtf8 => KeyKind::LargeUtf8(keys_arr.as_string::()), + other => { + return Err(ArrowError::SchemaError(format!( + "Avro map requires string keys; Arrow key type must be Utf8/LargeUtf8, found: {other:?}" + ))) + } + }; + + let entries_struct_fields = match map.data_type() { + DataType::Map(entries, _) => match entries.data_type() { + DataType::Struct(fs) => fs, + other => { + return Err(ArrowError::SchemaError(format!( + "Arrow Map entries must be Struct, found: {other:?}" + ))) + } + }, + _ => { + return Err(ArrowError::SchemaError( + "Expected MapArray with DataType::Map".into(), + )) + } + }; + + let v_idx = find_map_value_field_index(entries_struct_fields).ok_or_else(|| { + ArrowError::SchemaError("Map entries struct missing value field".into()) + })?; + let value_field = entries_struct_fields[v_idx].as_ref(); + + let values_enc = prepare_value_site_encoder( + map.values().as_ref(), + value_field, + values_nullability, + value_plan, + )?; + + Ok(Self { + map, + keys: keys_kind, + values: values_enc, + keys_offset: keys_arr.offset(), + values_offset: map.values().offset(), + }) + } + + fn encode_map_entries( + out: &mut W, + keys: &GenericStringArray, + keys_offset: usize, + start: usize, + end: usize, + mut write_item: impl FnMut(&mut W, usize) -> Result<(), ArrowError>, + ) -> Result<(), ArrowError> + where + W: Write + ?Sized, + O: OffsetSizeTrait, + { 
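+        // Avro maps are framed like arrays of key/value entries: a block count,
+        // then for each entry the key as an Avro string (length-prefixed UTF-8)
+        // immediately followed by the encoded value, and finally a zero terminator
+        // for the row. The closure below writes one such key/value pair per index.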
+ encode_blocked_range(out, start, end, |out, j| { + let j_key = j.saturating_sub(keys_offset); + write_len_prefixed(out, keys.value(j_key).as_bytes())?; + write_item(out, j) + }) + } + + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let offsets = self.map.offsets(); + let start = offsets[idx] as usize; + let end = offsets[idx + 1] as usize; + + let mut write_item = |out: &mut W, j: usize| { + let j_val = j.saturating_sub(self.values_offset); + self.values.encode(out, j_val) + }; + + match self.keys { + KeyKind::Utf8(arr) => MapEncoder::<'a>::encode_map_entries( + out, + arr, + self.keys_offset, + start, + end, + write_item, + ), + KeyKind::LargeUtf8(arr) => MapEncoder::<'a>::encode_map_entries( + out, + arr, + self.keys_offset, + start, + end, + write_item, + ), + } + } +} + +/// Avro `enum` encoder for Arrow `DictionaryArray`. +/// +/// Per Avro spec, an enum is encoded as an **int** equal to the +/// zero-based position of the symbol in the schema’s `symbols` list. +/// We validate at construction that the dictionary values equal the symbols, +/// so we can directly write the key value here. +struct EnumEncoder<'a> { + keys: &'a PrimitiveArray, +} +impl EnumEncoder<'_> { + fn encode(&mut self, out: &mut W, row: usize) -> Result<(), ArrowError> { + write_int(out, self.keys.value(row)) + } +} + +struct StructEncoder<'a> { + encoders: Vec>, +} + +impl<'a> StructEncoder<'a> { + fn try_new( + array: &'a StructArray, + field_bindings: &[FieldBinding], + ) -> Result { + let DataType::Struct(fields) = array.data_type() else { + return Err(ArrowError::SchemaError("Expected Struct".into())); + }; + let mut encoders = Vec::with_capacity(field_bindings.len()); + for field_binding in field_bindings { + let idx = field_binding.arrow_index; + let column = array.columns().get(idx).ok_or_else(|| { + ArrowError::SchemaError(format!("Struct child index {idx} out of range")) + })?; + let field = fields.get(idx).ok_or_else(|| { + ArrowError::SchemaError(format!("Struct child index {idx} out of range")) + })?; + let encoder = prepare_value_site_encoder( + column.as_ref(), + field, + field_binding.nullability, + &field_binding.plan, + )?; + encoders.push(encoder); + } + Ok(Self { encoders }) + } + + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + for encoder in self.encoders.iter_mut() { + encoder.encode(out, idx)?; + } + Ok(()) + } +} + +/// Encode a blocked range of items with Avro array block framing. +/// +/// `write_item` must take `(out, index)` to maintain the "out-first" convention. +fn encode_blocked_range( + out: &mut W, + start: usize, + end: usize, + mut write_item: F, +) -> Result<(), ArrowError> +where + F: FnMut(&mut W, usize) -> Result<(), ArrowError>, +{ + let len = end.saturating_sub(start); + if len == 0 { + // Zero-length terminator per Avro spec. + write_long(out, 0)?; + return Ok(()); + } + // Emit a single positive block for performance, then the end marker. 
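+    // Illustrative byte layout for a 3-item range (see the list encoder tests below):
+    //   long(3), item0, item1, item2, long(0)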
+ write_long(out, len as i64)?; + for row in start..end { + write_item(out, row)?; + } + write_long(out, 0)?; + Ok(()) +} + +struct ListEncoder<'a, O: OffsetSizeTrait> { + list: &'a GenericListArray, + values: FieldEncoder<'a>, + values_offset: usize, +} + +type ListEncoder32<'a> = ListEncoder<'a, i32>; +type ListEncoder64<'a> = ListEncoder<'a, i64>; + +impl<'a, O: OffsetSizeTrait> ListEncoder<'a, O> { + fn try_new( + list: &'a GenericListArray, + items_nullability: Option, + item_plan: &FieldPlan, + ) -> Result { + let child_field = match list.data_type() { + DataType::List(field) => field.as_ref(), + DataType::LargeList(field) => field.as_ref(), + _ => { + return Err(ArrowError::SchemaError( + "Expected List or LargeList for ListEncoder".into(), + )) + } + }; + let values_enc = prepare_value_site_encoder( + list.values().as_ref(), + child_field, + items_nullability, + item_plan, + )?; + Ok(Self { + list, + values: values_enc, + values_offset: list.values().offset(), + }) + } + + fn encode_list_range( + &mut self, + out: &mut W, + start: usize, + end: usize, + ) -> Result<(), ArrowError> { + encode_blocked_range(out, start, end, |out, row| { + self.values + .encode(out, row.saturating_sub(self.values_offset)) + }) + } + + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let offsets = self.list.offsets(); + let start = offsets[idx].to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!("Error converting offset[{idx}] to usize")) + })?; + let end = offsets[idx + 1].to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Error converting offset[{}] to usize", + idx + 1 + )) + })?; + self.encode_list_range(out, start, end) + } +} + +fn prepare_value_site_encoder<'a>( + values_array: &'a dyn Array, + value_field: &Field, + nullability: Option, + plan: &FieldPlan, +) -> Result, ArrowError> { + // Effective nullability is computed here from the writer-declared site nullability and data. + FieldEncoder::make_encoder(values_array, value_field, plan, nullability) +} + +/// Avro `fixed` encoder for Arrow `FixedSizeBinaryArray`. +/// Spec: a fixed is encoded as exactly `size` bytes, with no length prefix. +struct FixedEncoder<'a>(&'a FixedSizeBinaryArray); +impl FixedEncoder<'_> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let v = self.0.value(idx); // &[u8] of fixed width + out.write_all(v) + .map_err(|e| ArrowError::IoError(format!("write fixed bytes: {e}"), e)) + } +} + +/// Avro UUID logical type encoder: Arrow FixedSizeBinary(16) → Avro string (UUID). +/// Spec: uuid is a logical type over string (RFC‑4122). We output hyphenated form. +struct UuidEncoder<'a>(&'a FixedSizeBinaryArray); +impl UuidEncoder<'_> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let mut buf = [0u8; 1 + uuid::fmt::Hyphenated::LENGTH]; + buf[0] = 0x48; + let v = self.0.value(idx); + let u = Uuid::from_slice(v) + .map_err(|e| ArrowError::InvalidArgumentError(format!("Invalid UUID bytes: {e}")))?; + let _ = u.hyphenated().encode_lower(&mut buf[1..]); + out.write_all(&buf) + .map_err(|e| ArrowError::IoError(format!("write uuid: {e}"), e)) + } +} + +#[derive(Copy, Clone)] +struct DurationParts { + months: u32, + days: u32, + millis: u32, +} +/// Trait mapping an Arrow interval native value to Avro duration `(months, days, millis)`. 
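+/// For example, an Arrow `IntervalMonthDayNano` of (1 month, 2 days, 3_000_000 ns)
+/// maps to the duration parts (1, 2, 3); negative components or nanosecond values
+/// that are not whole milliseconds are rejected rather than silently truncated.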
+trait IntervalToDurationParts: ArrowPrimitiveType { + fn duration_parts(native: Self::Native) -> Result; +} +impl IntervalToDurationParts for IntervalMonthDayNanoType { + fn duration_parts(native: Self::Native) -> Result { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(native); + if months < 0 || days < 0 || nanos < 0 { + return Err(ArrowError::InvalidArgumentError( + "Avro 'duration' cannot encode negative months/days/nanoseconds".into(), + )); + } + if nanos % 1_000_000 != 0 { + return Err(ArrowError::InvalidArgumentError( + "Avro 'duration' requires whole milliseconds; nanoseconds must be divisible by 1_000_000" + .into(), + )); + } + let millis = nanos / 1_000_000; + if millis > u32::MAX as i64 { + return Err(ArrowError::InvalidArgumentError( + "Avro 'duration' milliseconds exceed u32::MAX".into(), + )); + } + Ok(DurationParts { + months: months as u32, + days: days as u32, + millis: millis as u32, + }) + } +} +impl IntervalToDurationParts for IntervalYearMonthType { + fn duration_parts(native: Self::Native) -> Result { + if native < 0 { + return Err(ArrowError::InvalidArgumentError( + "Avro 'duration' cannot encode negative months".into(), + )); + } + Ok(DurationParts { + months: native as u32, + days: 0, + millis: 0, + }) + } +} +impl IntervalToDurationParts for IntervalDayTimeType { + fn duration_parts(native: Self::Native) -> Result { + let (days, millis) = IntervalDayTimeType::to_parts(native); + if days < 0 || millis < 0 { + return Err(ArrowError::InvalidArgumentError( + "Avro 'duration' cannot encode negative days or milliseconds".into(), + )); + } + Ok(DurationParts { + months: 0, + days: days as u32, + millis: millis as u32, + }) + } +} +/// Single generic encoder used for all three interval units. +/// Writes Avro `fixed(12)` as three little-endian u32 values in one call. +struct DurationEncoder<'a, P: ArrowPrimitiveType + IntervalToDurationParts>(&'a PrimitiveArray
<P>
); +impl<'a, P: ArrowPrimitiveType + IntervalToDurationParts> DurationEncoder<'a, P> { + #[inline(always)] + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let parts = P::duration_parts(self.0.value(idx))?; + let months = parts.months.to_le_bytes(); + let days = parts.days.to_le_bytes(); + let ms = parts.millis.to_le_bytes(); + // SAFETY + // - Endianness & layout: Avro's `duration` logical type is encoded as fixed(12) + // with three *little-endian* unsigned 32-bit integers in order: (months, days, millis). + // We explicitly materialize exactly those 12 bytes. + // - In-bounds indexing: `to_le_bytes()` on `u32` returns `[u8; 4]` by contract, + // therefore, the constant indices 0..=3 used below are *always* in-bounds. + // Rust will panic on out-of-bounds indexing, but there is no such path here; + // the compiler can also elide the bound checks for constant, provably in-range + // indices. [std docs; Rust Performance Book on bounds-check elimination] + // - Memory safety: The `[u8; 12]` array is built on the stack by value, with no + // aliasing and no uninitialized memory. There is no `unsafe`. + // - I/O: `write_all(&buf)` is fallible and its `Result` is propagated and mapped + // into `ArrowError`, so I/O errors are reported, not panicked. + // Consequently, constructing `buf` with the constant indices below is safe and + // panic-free under these validated preconditions. + let buf = [ + months[0], months[1], months[2], months[3], days[0], days[1], days[2], days[3], ms[0], + ms[1], ms[2], ms[3], + ]; + out.write_all(&buf) + .map_err(|e| ArrowError::IoError(format!("write duration: {e}"), e)) + } +} + +/// Minimal trait to obtain a big-endian fixed-size byte array for a decimal's +/// unscaled integer value at `idx`. +trait DecimalBeBytes { + fn value_be_bytes(&self, idx: usize) -> [u8; N]; +} +#[cfg(feature = "small_decimals")] +impl DecimalBeBytes<4> for Decimal32Array { + fn value_be_bytes(&self, idx: usize) -> [u8; 4] { + self.value(idx).to_be_bytes() + } +} +#[cfg(feature = "small_decimals")] +impl DecimalBeBytes<8> for Decimal64Array { + fn value_be_bytes(&self, idx: usize) -> [u8; 8] { + self.value(idx).to_be_bytes() + } +} +impl DecimalBeBytes<16> for Decimal128Array { + fn value_be_bytes(&self, idx: usize) -> [u8; 16] { + self.value(idx).to_be_bytes() + } +} +impl DecimalBeBytes<32> for Decimal256Array { + fn value_be_bytes(&self, idx: usize) -> [u8; 32] { + // Arrow i256 → [u8; 32] big-endian + self.value(idx).to_be_bytes() + } +} + +/// Generic Avro decimal encoder over Arrow decimal arrays. +/// - When `fixed_size` is `None` → Avro `bytes(decimal)`; writes the minimal +/// two's-complement representation with a length prefix. +/// - When `Some(n)` → Avro `fixed(n, decimal)`; sign-extends (or validates) +/// to exactly `n` bytes and writes them directly. 
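+///
+/// Illustrative values (matching the decimal tests below): the unscaled value `-1`
+/// becomes the single length-prefixed byte `0xFF` under `bytes(decimal)`, and
+/// sixteen `0xFF` bytes under `fixed(16, decimal)`.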
+struct DecimalEncoder<'a, const N: usize, A: DecimalBeBytes> { + arr: &'a A, + fixed_size: Option, +} + +impl<'a, const N: usize, A: DecimalBeBytes> DecimalEncoder<'a, N, A> { + fn new(arr: &'a A, fixed_size: Option) -> Self { + Self { arr, fixed_size } + } + + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let be = self.arr.value_be_bytes(idx); + match self.fixed_size { + Some(n) => write_sign_extended(out, &be, n), + None => write_len_prefixed(out, minimal_twos_complement(&be)), + } + } +} + +#[cfg(feature = "small_decimals")] +type Decimal32Encoder<'a> = DecimalEncoder<'a, 4, Decimal32Array>; +#[cfg(feature = "small_decimals")] +type Decimal64Encoder<'a> = DecimalEncoder<'a, 8, Decimal64Array>; +type Decimal128Encoder<'a> = DecimalEncoder<'a, 16, Decimal128Array>; +type Decimal256Encoder<'a> = DecimalEncoder<'a, 32, Decimal256Array>; + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::types::Int32Type; + use arrow_array::{ + Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, + Int64Array, LargeBinaryArray, LargeListArray, LargeStringArray, ListArray, StringArray, + TimestampMicrosecondArray, + }; + use arrow_schema::{DataType, Field, Fields}; + + fn zigzag_i64(v: i64) -> u64 { + ((v << 1) ^ (v >> 63)) as u64 + } + + fn varint(mut x: u64) -> Vec { + let mut out = Vec::new(); + while (x & !0x7f) != 0 { + out.push(((x & 0x7f) as u8) | 0x80); + x >>= 7; + } + out.push((x & 0x7f) as u8); + out + } + + fn avro_long_bytes(v: i64) -> Vec { + varint(zigzag_i64(v)) + } + + fn avro_len_prefixed_bytes(payload: &[u8]) -> Vec { + let mut out = avro_long_bytes(payload.len() as i64); + out.extend_from_slice(payload); + out + } + + fn duration_fixed12(months: u32, days: u32, millis: u32) -> [u8; 12] { + let m = months.to_le_bytes(); + let d = days.to_le_bytes(); + let ms = millis.to_le_bytes(); + [ + m[0], m[1], m[2], m[3], d[0], d[1], d[2], d[3], ms[0], ms[1], ms[2], ms[3], + ] + } + + fn encode_all( + array: &dyn Array, + plan: &FieldPlan, + nullability: Option, + ) -> Vec { + let field = Field::new("f", array.data_type().clone(), true); + let mut enc = FieldEncoder::make_encoder(array, &field, plan, nullability).unwrap(); + let mut out = Vec::new(); + for i in 0..array.len() { + enc.encode(&mut out, i).unwrap(); + } + out + } + + fn assert_bytes_eq(actual: &[u8], expected: &[u8]) { + if actual != expected { + let to_hex = |b: &[u8]| { + b.iter() + .map(|x| format!("{:02X}", x)) + .collect::>() + .join(" ") + }; + panic!( + "mismatch\n expected: [{}]\n actual: [{}]", + to_hex(expected), + to_hex(actual) + ); + } + } + + #[test] + fn binary_encoder() { + let values: Vec<&[u8]> = vec![b"", b"ab", b"\x00\xFF"]; + let arr = BinaryArray::from_vec(values); + let mut expected = Vec::new(); + for payload in [b"" as &[u8], b"ab", b"\x00\xFF"] { + expected.extend(avro_len_prefixed_bytes(payload)); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn large_binary_encoder() { + let values: Vec<&[u8]> = vec![b"xyz", b""]; + let arr = LargeBinaryArray::from_vec(values); + let mut expected = Vec::new(); + for payload in [b"xyz" as &[u8], b""] { + expected.extend(avro_len_prefixed_bytes(payload)); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn utf8_encoder() { + let arr = StringArray::from(vec!["", "A", "BC"]); + let mut expected = Vec::new(); + for s in ["", "A", "BC"] { + 
expected.extend(avro_len_prefixed_bytes(s.as_bytes())); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn large_utf8_encoder() { + let arr = LargeStringArray::from(vec!["hello", ""]); + let mut expected = Vec::new(); + for s in ["hello", ""] { + expected.extend(avro_len_prefixed_bytes(s.as_bytes())); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn list_encoder_int32() { + // Build ListArray [[1,2], [], [3]] + let values = Int32Array::from(vec![1, 2, 3]); + let offsets = vec![0, 2, 2, 3]; + let list = ListArray::new( + Field::new("item", DataType::Int32, true).into(), + arrow_buffer::OffsetBuffer::new(offsets.into()), + Arc::new(values) as ArrayRef, + None, + ); + // Avro array encoding per row + let mut expected = Vec::new(); + // row 0: block len 2, items 1,2 then 0 + expected.extend(avro_long_bytes(2)); + expected.extend(avro_long_bytes(1)); + expected.extend(avro_long_bytes(2)); + expected.extend(avro_long_bytes(0)); + // row 1: empty + expected.extend(avro_long_bytes(0)); + // row 2: one item 3 + expected.extend(avro_long_bytes(1)); + expected.extend(avro_long_bytes(3)); + expected.extend(avro_long_bytes(0)); + + let plan = FieldPlan::List { + items_nullability: None, + item_plan: Box::new(FieldPlan::Scalar), + }; + let got = encode_all(&list, &plan, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn struct_encoder_two_fields() { + // Struct { a: Int32, b: Utf8 } + let a = Int32Array::from(vec![1, 2]); + let b = StringArray::from(vec!["x", "y"]); + let fields = Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]); + let struct_arr = StructArray::new( + fields.clone(), + vec![Arc::new(a) as ArrayRef, Arc::new(b) as ArrayRef], + None, + ); + let plan = FieldPlan::Struct { + encoders: vec![ + FieldBinding { + arrow_index: 0, + nullability: None, + plan: FieldPlan::Scalar, + }, + FieldBinding { + arrow_index: 1, + nullability: None, + plan: FieldPlan::Scalar, + }, + ], + }; + let got = encode_all(&struct_arr, &plan, None); + // Expected: rows concatenated: a then b + let mut expected = Vec::new(); + expected.extend(avro_long_bytes(1)); // a=1 + expected.extend(avro_len_prefixed_bytes(b"x")); // b="x" + expected.extend(avro_long_bytes(2)); // a=2 + expected.extend(avro_len_prefixed_bytes(b"y")); // b="y" + assert_bytes_eq(&got, &expected); + } + + #[test] + fn enum_encoder_dictionary() { + // symbols: ["A","B","C"], keys [2,0,1] + let dict_values = StringArray::from(vec!["A", "B", "C"]); + let keys = Int32Array::from(vec![2, 0, 1]); + let dict = + DictionaryArray::::try_new(keys, Arc::new(dict_values) as ArrayRef).unwrap(); + let symbols = Arc::<[String]>::from( + vec!["A".to_string(), "B".to_string(), "C".to_string()].into_boxed_slice(), + ); + let plan = FieldPlan::Enum { symbols }; + let got = encode_all(&dict, &plan, None); + let mut expected = Vec::new(); + expected.extend(avro_long_bytes(2)); + expected.extend(avro_long_bytes(0)); + expected.extend(avro_long_bytes(1)); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn decimal_bytes_and_fixed() { + // Use Decimal128 with small positives and negatives + let dec = Decimal128Array::from(vec![1i128, -1i128, 0i128]) + .with_precision_and_scale(20, 0) + .unwrap(); + // bytes(decimal): minimal two's complement length-prefixed + let plan_bytes = FieldPlan::Decimal { size: None }; + let got_bytes = encode_all(&dec, &plan_bytes, None); + 
// 1 -> 0x01; -1 -> 0xFF; 0 -> 0x00 + let mut expected_bytes = Vec::new(); + expected_bytes.extend(avro_len_prefixed_bytes(&[0x01])); + expected_bytes.extend(avro_len_prefixed_bytes(&[0xFF])); + expected_bytes.extend(avro_len_prefixed_bytes(&[0x00])); + assert_bytes_eq(&got_bytes, &expected_bytes); + + let plan_fixed = FieldPlan::Decimal { size: Some(16) }; + let got_fixed = encode_all(&dec, &plan_fixed, None); + let mut expected_fixed = Vec::new(); + expected_fixed.extend_from_slice(&1i128.to_be_bytes()); + expected_fixed.extend_from_slice(&(-1i128).to_be_bytes()); + expected_fixed.extend_from_slice(&0i128.to_be_bytes()); + assert_bytes_eq(&got_fixed, &expected_fixed); + } + + #[test] + fn decimal_bytes_256() { + use arrow_buffer::i256; + // Use Decimal256 with small positives and negatives + let dec = Decimal256Array::from(vec![ + i256::from_i128(1), + i256::from_i128(-1), + i256::from_i128(0), + ]) + .with_precision_and_scale(76, 0) + .unwrap(); + // bytes(decimal): minimal two's complement length-prefixed + let plan_bytes = FieldPlan::Decimal { size: None }; + let got_bytes = encode_all(&dec, &plan_bytes, None); + // 1 -> 0x01; -1 -> 0xFF; 0 -> 0x00 + let mut expected_bytes = Vec::new(); + expected_bytes.extend(avro_len_prefixed_bytes(&[0x01])); + expected_bytes.extend(avro_len_prefixed_bytes(&[0xFF])); + expected_bytes.extend(avro_len_prefixed_bytes(&[0x00])); + assert_bytes_eq(&got_bytes, &expected_bytes); + + // fixed(32): 32-byte big-endian two's complement + let plan_fixed = FieldPlan::Decimal { size: Some(32) }; + let got_fixed = encode_all(&dec, &plan_fixed, None); + let mut expected_fixed = Vec::new(); + expected_fixed.extend_from_slice(&i256::from_i128(1).to_be_bytes()); + expected_fixed.extend_from_slice(&i256::from_i128(-1).to_be_bytes()); + expected_fixed.extend_from_slice(&i256::from_i128(0).to_be_bytes()); + assert_bytes_eq(&got_fixed, &expected_fixed); + } + + #[cfg(feature = "small_decimals")] + #[test] + fn decimal_bytes_and_fixed_32() { + // Use Decimal32 with small positives and negatives + let dec = Decimal32Array::from(vec![1i32, -1i32, 0i32]) + .with_precision_and_scale(9, 0) + .unwrap(); + // bytes(decimal) + let plan_bytes = FieldPlan::Decimal { size: None }; + let got_bytes = encode_all(&dec, &plan_bytes, None); + let mut expected_bytes = Vec::new(); + expected_bytes.extend(avro_len_prefixed_bytes(&[0x01])); + expected_bytes.extend(avro_len_prefixed_bytes(&[0xFF])); + expected_bytes.extend(avro_len_prefixed_bytes(&[0x00])); + assert_bytes_eq(&got_bytes, &expected_bytes); + // fixed(4) + let plan_fixed = FieldPlan::Decimal { size: Some(4) }; + let got_fixed = encode_all(&dec, &plan_fixed, None); + let mut expected_fixed = Vec::new(); + expected_fixed.extend_from_slice(&1i32.to_be_bytes()); + expected_fixed.extend_from_slice(&(-1i32).to_be_bytes()); + expected_fixed.extend_from_slice(&0i32.to_be_bytes()); + assert_bytes_eq(&got_fixed, &expected_fixed); + } + + #[cfg(feature = "small_decimals")] + #[test] + fn decimal_bytes_and_fixed_64() { + // Use Decimal64 with small positives and negatives + let dec = Decimal64Array::from(vec![1i64, -1i64, 0i64]) + .with_precision_and_scale(18, 0) + .unwrap(); + // bytes(decimal) + let plan_bytes = FieldPlan::Decimal { size: None }; + let got_bytes = encode_all(&dec, &plan_bytes, None); + let mut expected_bytes = Vec::new(); + expected_bytes.extend(avro_len_prefixed_bytes(&[0x01])); + expected_bytes.extend(avro_len_prefixed_bytes(&[0xFF])); + expected_bytes.extend(avro_len_prefixed_bytes(&[0x00])); + 
assert_bytes_eq(&got_bytes, &expected_bytes); + // fixed(8) + let plan_fixed = FieldPlan::Decimal { size: Some(8) }; + let got_fixed = encode_all(&dec, &plan_fixed, None); + let mut expected_fixed = Vec::new(); + expected_fixed.extend_from_slice(&1i64.to_be_bytes()); + expected_fixed.extend_from_slice(&(-1i64).to_be_bytes()); + expected_fixed.extend_from_slice(&0i64.to_be_bytes()); + assert_bytes_eq(&got_fixed, &expected_fixed); + } + + #[test] + fn float32_and_float64_encoders() { + let f32a = Float32Array::from(vec![0.0f32, -1.5f32, f32::from_bits(0x7fc00000)]); // includes a quiet NaN bit pattern + let f64a = Float64Array::from(vec![0.0f64, -2.25f64]); + // f32 expected + let mut expected32 = Vec::new(); + for v in [0.0f32, -1.5f32, f32::from_bits(0x7fc00000)] { + expected32.extend_from_slice(&v.to_bits().to_le_bytes()); + } + let got32 = encode_all(&f32a, &FieldPlan::Scalar, None); + assert_bytes_eq(&got32, &expected32); + // f64 expected + let mut expected64 = Vec::new(); + for v in [0.0f64, -2.25f64] { + expected64.extend_from_slice(&v.to_bits().to_le_bytes()); + } + let got64 = encode_all(&f64a, &FieldPlan::Scalar, None); + assert_bytes_eq(&got64, &expected64); + } + + #[test] + fn long_encoder_int64() { + let arr = Int64Array::from(vec![0i64, 1i64, -1i64, 2i64, -2i64, i64::MIN + 1]); + let mut expected = Vec::new(); + for v in [0, 1, -1, 2, -2, i64::MIN + 1] { + expected.extend(avro_long_bytes(v)); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn fixed_encoder_plain() { + // Two values of width 4 + let data = [[0xDE, 0xAD, 0xBE, 0xEF], [0x00, 0x01, 0x02, 0x03]]; + let values: Vec> = data.iter().map(|x| x.to_vec()).collect(); + let arr = FixedSizeBinaryArray::try_from_iter(values.into_iter()).unwrap(); + let got = encode_all(&arr, &FieldPlan::Scalar, None); + let mut expected = Vec::new(); + expected.extend_from_slice(&data[0]); + expected.extend_from_slice(&data[1]); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn uuid_encoder_test() { + // Happy path + let u = Uuid::parse_str("00112233-4455-6677-8899-aabbccddeeff").unwrap(); + let bytes = *u.as_bytes(); + let arr_ok = FixedSizeBinaryArray::try_from_iter(vec![bytes.to_vec()].into_iter()).unwrap(); + // Expected: length 36 (0x48) followed by hyphenated lowercase text + let mut expected = Vec::new(); + expected.push(0x48); + expected.extend_from_slice(u.hyphenated().to_string().as_bytes()); + let got = encode_all(&arr_ok, &FieldPlan::Uuid, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn uuid_encoder_error() { + // Invalid UUID bytes: wrong length + let arr = + FixedSizeBinaryArray::try_new(10, arrow_buffer::Buffer::from(vec![0u8; 10]), None) + .unwrap(); + let plan = FieldPlan::Uuid; + + let field = Field::new("f", arr.data_type().clone(), true); + let mut enc = FieldEncoder::make_encoder(&arr, &field, &plan, None).unwrap(); + let mut out = Vec::new(); + let err = enc.encode(&mut out, 0).unwrap_err(); + match err { + ArrowError::InvalidArgumentError(msg) => { + assert!(msg.contains("Invalid UUID bytes")) + } + other => panic!("expected InvalidArgumentError, got {other:?}"), + } + } + + #[test] + fn map_encoder_string_keys_int_values() { + // Build MapArray with two rows + // Row0: {"k1":1, "k2":2} + // Row1: {} + let keys = StringArray::from(vec!["k1", "k2"]); + let values = Int32Array::from(vec![1, 2]); + let entries_fields = Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ]); + 
let entries = StructArray::new( + entries_fields, + vec![Arc::new(keys) as ArrayRef, Arc::new(values) as ArrayRef], + None, + ); + let offsets = arrow_buffer::OffsetBuffer::new(vec![0i32, 2, 2].into()); + let map = MapArray::new( + Field::new("entries", entries.data_type().clone(), false).into(), + offsets, + entries, + None, + false, + ); + let plan = FieldPlan::Map { + values_nullability: None, + value_plan: Box::new(FieldPlan::Scalar), + }; + let got = encode_all(&map, &plan, None); + let mut expected = Vec::new(); + // Row0: block 2 then pairs + expected.extend(avro_long_bytes(2)); + expected.extend(avro_len_prefixed_bytes(b"k1")); + expected.extend(avro_long_bytes(1)); + expected.extend(avro_len_prefixed_bytes(b"k2")); + expected.extend(avro_long_bytes(2)); + expected.extend(avro_long_bytes(0)); + // Row1: empty + expected.extend(avro_long_bytes(0)); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn list64_encoder_int32() { + // LargeList [[1,2,3], []] + let values = Int32Array::from(vec![1, 2, 3]); + let offsets: Vec = vec![0, 3, 3]; + let list = LargeListArray::new( + Field::new("item", DataType::Int32, true).into(), + arrow_buffer::OffsetBuffer::new(offsets.into()), + Arc::new(values) as ArrayRef, + None, + ); + let plan = FieldPlan::List { + items_nullability: None, + item_plan: Box::new(FieldPlan::Scalar), + }; + let got = encode_all(&list, &plan, None); + // Expected one block of 3 and then 0, then empty 0 + let mut expected = Vec::new(); + expected.extend(avro_long_bytes(3)); + expected.extend(avro_long_bytes(1)); + expected.extend(avro_long_bytes(2)); + expected.extend(avro_long_bytes(3)); + expected.extend(avro_long_bytes(0)); + expected.extend(avro_long_bytes(0)); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn int_encoder_test() { + let ints = Int32Array::from(vec![0, -1, 2]); + let mut expected_i = Vec::new(); + for v in [0i32, -1, 2] { + expected_i.extend(avro_long_bytes(v as i64)); + } + let got_i = encode_all(&ints, &FieldPlan::Scalar, None); + assert_bytes_eq(&got_i, &expected_i); + } + + #[test] + fn boolean_encoder_test() { + let bools = BooleanArray::from(vec![true, false]); + let mut expected_b = Vec::new(); + expected_b.extend_from_slice(&[1]); + expected_b.extend_from_slice(&[0]); + let got_b = encode_all(&bools, &FieldPlan::Scalar, None); + assert_bytes_eq(&got_b, &expected_b); + } + + #[test] + fn duration_encoder_year_month_happy_path() { + let arr: PrimitiveArray = vec![0i32, 1i32, 25i32].into(); + let mut expected = Vec::new(); + for m in [0u32, 1u32, 25u32] { + expected.extend_from_slice(&duration_fixed12(m, 0, 0)); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn duration_encoder_year_month_rejects_negative() { + let arr: PrimitiveArray = vec![-1i32].into(); + let field = Field::new("f", DataType::Interval(IntervalUnit::YearMonth), true); + let mut enc = FieldEncoder::make_encoder(&arr, &field, &FieldPlan::Scalar, None).unwrap(); + let mut out = Vec::new(); + let err = enc.encode(&mut out, 0).unwrap_err(); + match err { + ArrowError::InvalidArgumentError(msg) => { + assert!(msg.contains("cannot encode negative months")) + } + other => panic!("expected InvalidArgumentError, got {other:?}"), + } + } + + #[test] + fn duration_encoder_day_time_happy_path() { + let v0 = IntervalDayTimeType::make_value(2, 500); // days=2, millis=500 + let v1 = IntervalDayTimeType::make_value(0, 0); + let arr: PrimitiveArray = vec![v0, v1].into(); + let mut expected = Vec::new(); + 
expected.extend_from_slice(&duration_fixed12(0, 2, 500)); + expected.extend_from_slice(&duration_fixed12(0, 0, 0)); + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn duration_encoder_day_time_rejects_negative() { + let bad = IntervalDayTimeType::make_value(-1, 0); + let arr: PrimitiveArray = vec![bad].into(); + let field = Field::new("f", DataType::Interval(IntervalUnit::DayTime), true); + let mut enc = FieldEncoder::make_encoder(&arr, &field, &FieldPlan::Scalar, None).unwrap(); + let mut out = Vec::new(); + let err = enc.encode(&mut out, 0).unwrap_err(); + match err { + ArrowError::InvalidArgumentError(msg) => { + assert!(msg.contains("cannot encode negative days")) + } + other => panic!("expected InvalidArgumentError, got {other:?}"), + } + } + + #[test] + fn duration_encoder_month_day_nano_happy_path() { + let v0 = IntervalMonthDayNanoType::make_value(1, 2, 3_000_000); // -> millis = 3 + let v1 = IntervalMonthDayNanoType::make_value(0, 0, 0); + let arr: PrimitiveArray = vec![v0, v1].into(); + let mut expected = Vec::new(); + expected.extend_from_slice(&duration_fixed12(1, 2, 3)); + expected.extend_from_slice(&duration_fixed12(0, 0, 0)); + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn duration_encoder_month_day_nano_rejects_non_ms_multiple() { + let bad = IntervalMonthDayNanoType::make_value(0, 0, 1); + let arr: PrimitiveArray = vec![bad].into(); + let field = Field::new("f", DataType::Interval(IntervalUnit::MonthDayNano), true); + let mut enc = FieldEncoder::make_encoder(&arr, &field, &FieldPlan::Scalar, None).unwrap(); + let mut out = Vec::new(); + let err = enc.encode(&mut out, 0).unwrap_err(); + match err { + ArrowError::InvalidArgumentError(msg) => { + assert!(msg.contains("requires whole milliseconds") || msg.contains("divisible")) + } + other => panic!("expected InvalidArgumentError, got {other:?}"), + } + } + + #[test] + fn minimal_twos_complement_test() { + let pos = [0x00, 0x00, 0x01]; + assert_eq!(minimal_twos_complement(&pos), &pos[2..]); + let neg = [0xFF, 0xFF, 0x80]; // negative minimal is 0x80 + assert_eq!(minimal_twos_complement(&neg), &neg[2..]); + let zero = [0x00, 0x00, 0x00]; + assert_eq!(minimal_twos_complement(&zero), &zero[2..]); + } + + #[test] + fn write_sign_extend_test() { + let mut out = Vec::new(); + write_sign_extended(&mut out, &[0x01], 4).unwrap(); + assert_eq!(out, vec![0x00, 0x00, 0x00, 0x01]); + out.clear(); + write_sign_extended(&mut out, &[0xFF], 4).unwrap(); + assert_eq!(out, vec![0xFF, 0xFF, 0xFF, 0xFF]); + out.clear(); + // truncation success (sign bytes only removed) + write_sign_extended(&mut out, &[0xFF, 0xFF, 0x80], 2).unwrap(); + assert_eq!(out, vec![0xFF, 0x80]); + out.clear(); + // truncation overflow + let err = write_sign_extended(&mut out, &[0x01, 0x00], 1).unwrap_err(); + match err { + ArrowError::InvalidArgumentError(_) => {} + _ => panic!("expected InvalidArgumentError"), + } + } + + #[test] + fn duration_month_day_nano_overflow_millis() { + // nanos leading to millis > u32::MAX + let nanos = ((u64::from(u32::MAX) + 1) * 1_000_000) as i64; + let v = IntervalMonthDayNanoType::make_value(0, 0, nanos); + let arr: PrimitiveArray = vec![v].into(); + let field = Field::new("f", DataType::Interval(IntervalUnit::MonthDayNano), true); + let mut enc = FieldEncoder::make_encoder(&arr, &field, &FieldPlan::Scalar, None).unwrap(); + let mut out = Vec::new(); + let err = enc.encode(&mut out, 0).unwrap_err(); + match err { + 
ArrowError::InvalidArgumentError(msg) => assert!(msg.contains("exceed u32::MAX")), + _ => panic!("expected InvalidArgumentError"), + } + } + + #[test] + fn fieldplan_decimal_precision_scale_mismatch_errors() { + // Avro expects (10,2), Arrow has (12,2) + use crate::codec::Codec; + use std::collections::HashMap; + let arrow_field = Field::new("d", DataType::Decimal128(12, 2), true); + let avro_dt = AvroDataType::new(Codec::Decimal(10, Some(2), None), HashMap::new(), None); + let err = FieldPlan::build(&avro_dt, &arrow_field).unwrap_err(); + match err { + ArrowError::SchemaError(msg) => { + assert!(msg.contains("Decimal precision/scale mismatch")) + } + _ => panic!("expected SchemaError"), + } + } +} diff --git a/arrow-avro/src/writer/format.rs b/arrow-avro/src/writer/format.rs new file mode 100644 index 000000000000..6fac9e8286a2 --- /dev/null +++ b/arrow-avro/src/writer/format.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::compression::{CompressionCodec, CODEC_METADATA_KEY}; +use crate::schema::{AvroSchema, SCHEMA_METADATA_KEY}; +use crate::writer::encoder::write_long; +use arrow_schema::{ArrowError, Schema}; +use rand::RngCore; +use std::fmt::Debug; +use std::io::Write; + +/// Format abstraction implemented by each container‐level writer. +pub trait AvroFormat: Debug + Default { + /// Write any bytes required at the very beginning of the output stream + /// Implementations **must not** write any record data. + fn start_stream( + &mut self, + writer: &mut W, + schema: &Schema, + compression: Option, + ) -> Result<(), ArrowError>; + + /// Return the 16‑byte sync marker (OCF) or `None` (binary stream). + fn sync_marker(&self) -> Option<&[u8; 16]>; +} + +/// Avro Object Container File (OCF) format writer. 
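+///
+/// `start_stream` writes the OCF header: the magic bytes `Obj\x01`, a two-entry
+/// metadata map (the schema JSON under `SCHEMA_METADATA_KEY` and the codec name
+/// under `CODEC_METADATA_KEY`), and a randomly generated 16-byte sync marker that
+/// is then repeated after every data block.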
+#[derive(Debug, Default)] +pub struct AvroOcfFormat { + sync_marker: [u8; 16], +} + +impl AvroFormat for AvroOcfFormat { + fn start_stream( + &mut self, + writer: &mut W, + schema: &Schema, + compression: Option, + ) -> Result<(), ArrowError> { + let mut rng = rand::rng(); + rng.fill_bytes(&mut self.sync_marker); + let avro_schema = AvroSchema::try_from(schema)?; + writer + .write_all(b"Obj\x01") + .map_err(|e| ArrowError::IoError(format!("write OCF magic: {e}"), e))?; + let codec_str = match compression { + Some(CompressionCodec::Deflate) => "deflate", + Some(CompressionCodec::Snappy) => "snappy", + Some(CompressionCodec::ZStandard) => "zstandard", + Some(CompressionCodec::Bzip2) => "bzip2", + Some(CompressionCodec::Xz) => "xz", + None => "null", + }; + write_long(writer, 2)?; + write_string(writer, SCHEMA_METADATA_KEY)?; + write_bytes(writer, avro_schema.json_string.as_bytes())?; + write_string(writer, CODEC_METADATA_KEY)?; + write_bytes(writer, codec_str.as_bytes())?; + write_long(writer, 0)?; + // Sync marker (16 bytes) + writer + .write_all(&self.sync_marker) + .map_err(|e| ArrowError::IoError(format!("write OCF sync marker: {e}"), e))?; + + Ok(()) + } + + fn sync_marker(&self) -> Option<&[u8; 16]> { + Some(&self.sync_marker) + } +} + +/// Raw Avro binary streaming format (no header or footer). +#[derive(Debug, Default)] +pub struct AvroBinaryFormat; + +impl AvroFormat for AvroBinaryFormat { + fn start_stream( + &mut self, + _writer: &mut W, + _schema: &Schema, + _compression: Option, + ) -> Result<(), ArrowError> { + Err(ArrowError::NotYetImplemented( + "avro binary format not yet implemented".to_string(), + )) + } + + fn sync_marker(&self) -> Option<&[u8; 16]> { + None + } +} + +#[inline] +fn write_string(writer: &mut W, s: &str) -> Result<(), ArrowError> { + write_bytes(writer, s.as_bytes()) +} + +#[inline] +fn write_bytes(writer: &mut W, bytes: &[u8]) -> Result<(), ArrowError> { + write_long(writer, bytes.len() as i64)?; + writer + .write_all(bytes) + .map_err(|e| ArrowError::IoError(format!("write bytes: {e}"), e)) +} diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs new file mode 100644 index 000000000000..f5e84eeb50bb --- /dev/null +++ b/arrow-avro/src/writer/mod.rs @@ -0,0 +1,636 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Avro writer implementation for the `arrow-avro` crate. +//! +//! # Overview +//! +//! * Use **`AvroWriter`** (Object Container File) when you want a +//! self‑contained Avro file with header, schema JSON, optional compression, +//! blocks, and sync markers. +//! * Use **`AvroStreamWriter`** (raw binary stream) when you already know the +//! schema out‑of‑band (i.e., via a schema registry) and need a stream +//! of Avro‑encoded records with minimal framing. +//! 
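+//!
+//! A minimal, illustrative sketch of OCF usage (writing to an in-memory buffer).
+//! The `arrow_avro::writer::AvroWriter` import path is assumed here, so the
+//! example is marked `ignore`:
+//!
+//! ```ignore
+//! use std::sync::Arc;
+//! use arrow_array::{ArrayRef, Int32Array, RecordBatch};
+//! use arrow_schema::{DataType, Field, Schema};
+//! use arrow_avro::writer::AvroWriter; // assumed public path
+//!
+//! let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
+//! let batch = RecordBatch::try_new(
+//!     Arc::new(schema.clone()),
+//!     vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
+//! )
+//! .expect("valid batch");
+//! // Write one batch as a single OCF block and retrieve the finished bytes.
+//! let mut writer = AvroWriter::new(Vec::<u8>::new(), schema).expect("create writer");
+//! writer.write(&batch).expect("write batch");
+//! writer.finish().expect("finish file");
+//! let ocf_bytes = writer.into_inner();
+//! ```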
+ +/// Encodes `RecordBatch` into the Avro binary format. +pub mod encoder; +/// Logic for different Avro container file formats. +pub mod format; + +use crate::codec::AvroFieldBuilder; +use crate::compression::CompressionCodec; +use crate::schema::{AvroSchema, SCHEMA_METADATA_KEY}; +use crate::writer::encoder::{write_long, RecordEncoder, RecordEncoderBuilder}; +use crate::writer::format::{AvroBinaryFormat, AvroFormat, AvroOcfFormat}; +use arrow_array::RecordBatch; +use arrow_schema::{ArrowError, Schema}; +use std::io::Write; +use std::sync::Arc; + +/// Builder to configure and create a `Writer`. +#[derive(Debug, Clone)] +pub struct WriterBuilder { + schema: Schema, + codec: Option, + capacity: usize, +} + +impl WriterBuilder { + /// Create a new builder with default settings. + pub fn new(schema: Schema) -> Self { + Self { + schema, + codec: None, + capacity: 1024, + } + } + + /// Change the compression codec. + pub fn with_compression(mut self, codec: Option) -> Self { + self.codec = codec; + self + } + + /// Sets the capacity for the given object and returns the modified instance. + pub fn with_capacity(mut self, capacity: usize) -> Self { + self.capacity = capacity; + self + } + + /// Create a new `Writer` with specified `AvroFormat` and builder options. + /// Performs one‑time startup (header/stream init, encoder plan). + pub fn build(self, mut writer: W) -> Result, ArrowError> + where + W: Write, + F: AvroFormat, + { + let mut format = F::default(); + let avro_schema = match self.schema.metadata.get(SCHEMA_METADATA_KEY) { + Some(json) => AvroSchema::new(json.clone()), + None => AvroSchema::try_from(&self.schema)?, + }; + let mut md = self.schema.metadata().clone(); + md.insert( + SCHEMA_METADATA_KEY.to_string(), + avro_schema.clone().json_string, + ); + let schema = Arc::new(Schema::new_with_metadata(self.schema.fields().clone(), md)); + format.start_stream(&mut writer, &schema, self.codec)?; + let avro_root = AvroFieldBuilder::new(&avro_schema.schema()?).build()?; + let encoder = RecordEncoderBuilder::new(&avro_root, schema.as_ref()).build()?; + Ok(Writer { + writer, + schema, + format, + compression: self.codec, + capacity: self.capacity, + encoder, + }) + } +} + +/// Generic Avro writer. +#[derive(Debug)] +pub struct Writer { + writer: W, + schema: Arc, + format: F, + compression: Option, + capacity: usize, + encoder: RecordEncoder, +} + +/// Alias for an Avro **Object Container File** writer. +pub type AvroWriter = Writer; +/// Alias for a raw Avro **binary stream** writer. +pub type AvroStreamWriter = Writer; + +impl Writer { + /// Convenience constructor – same as [`WriterBuilder::build`] with `AvroOcfFormat`. + pub fn new(writer: W, schema: Schema) -> Result { + WriterBuilder::new(schema).build::(writer) + } + + /// Return a reference to the 16‑byte sync marker generated for this file. + pub fn sync_marker(&self) -> Option<&[u8; 16]> { + self.format.sync_marker() + } +} + +impl Writer { + /// Convenience constructor to create a new [`AvroStreamWriter`]. + pub fn new(writer: W, schema: Schema) -> Result { + WriterBuilder::new(schema).build::(writer) + } +} + +impl Writer { + /// Serialize one [`RecordBatch`] to the output. 
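+    ///
+    /// For the OCF format each call emits one self-contained block: the row count,
+    /// the (optionally compressed) payload length, the payload bytes, and the
+    /// 16-byte sync marker. For the raw binary stream format the encoded records
+    /// are appended directly with no framing.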
+ pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { + if batch.schema().fields() != self.schema.fields() { + return Err(ArrowError::SchemaError( + "Schema of RecordBatch differs from Writer schema".to_string(), + )); + } + match self.format.sync_marker() { + Some(&sync) => self.write_ocf_block(batch, &sync), + None => self.write_stream(batch), + } + } + + /// A convenience method to write a slice of [`RecordBatch`]. + /// + /// This is equivalent to calling `write` for each batch in the slice. + pub fn write_batches(&mut self, batches: &[&RecordBatch]) -> Result<(), ArrowError> { + for b in batches { + self.write(b)?; + } + Ok(()) + } + + /// Flush remaining buffered data and (for OCF) ensure the header is present. + pub fn finish(&mut self) -> Result<(), ArrowError> { + self.writer + .flush() + .map_err(|e| ArrowError::IoError(format!("Error flushing writer: {e}"), e)) + } + + /// Consume the writer, returning the underlying output object. + pub fn into_inner(self) -> W { + self.writer + } + + fn write_ocf_block(&mut self, batch: &RecordBatch, sync: &[u8; 16]) -> Result<(), ArrowError> { + let mut buf = Vec::::with_capacity(1024); + self.encoder.encode(&mut buf, batch)?; + let encoded = match self.compression { + Some(codec) => codec.compress(&buf)?, + None => buf, + }; + write_long(&mut self.writer, batch.num_rows() as i64)?; + write_long(&mut self.writer, encoded.len() as i64)?; + self.writer + .write_all(&encoded) + .map_err(|e| ArrowError::IoError(format!("Error writing Avro block: {e}"), e))?; + self.writer + .write_all(sync) + .map_err(|e| ArrowError::IoError(format!("Error writing Avro sync: {e}"), e))?; + Ok(()) + } + + fn write_stream(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { + self.encoder.encode(&mut self.writer, batch) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::compression::CompressionCodec; + use crate::reader::ReaderBuilder; + use crate::schema::{AvroSchema, SchemaStore}; + use crate::test_util::arrow_test_data; + use arrow_array::{ArrayRef, BinaryArray, Int32Array, RecordBatch}; + use arrow_schema::{DataType, Field, IntervalUnit, Schema}; + use std::fs::File; + use std::io::{BufReader, Cursor}; + use std::path::PathBuf; + use std::sync::Arc; + use tempfile::NamedTempFile; + + fn make_schema() -> Schema { + Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Binary, false), + ]) + } + + fn make_batch() -> RecordBatch { + let ids = Int32Array::from(vec![1, 2, 3]); + let names = BinaryArray::from_vec(vec![b"a".as_ref(), b"b".as_ref(), b"c".as_ref()]); + RecordBatch::try_new( + Arc::new(make_schema()), + vec![Arc::new(ids) as ArrayRef, Arc::new(names) as ArrayRef], + ) + .expect("failed to build test RecordBatch") + } + + #[test] + fn test_ocf_writer_generates_header_and_sync() -> Result<(), ArrowError> { + let batch = make_batch(); + let buffer: Vec = Vec::new(); + let mut writer = AvroWriter::new(buffer, make_schema())?; + writer.write(&batch)?; + writer.finish()?; + let out = writer.into_inner(); + assert_eq!(&out[..4], b"Obj\x01", "OCF magic bytes missing/incorrect"); + let trailer = &out[out.len() - 16..]; + assert_eq!(trailer.len(), 16, "expected 16‑byte sync marker"); + Ok(()) + } + + #[test] + fn test_schema_mismatch_yields_error() { + let batch = make_batch(); + let alt_schema = Schema::new(vec![Field::new("x", DataType::Int32, false)]); + let buffer = Vec::::new(); + let mut writer = AvroWriter::new(buffer, alt_schema).unwrap(); + let err = 
writer.write(&batch).unwrap_err(); + assert!(matches!(err, ArrowError::SchemaError(_))); + } + + #[test] + fn test_write_batches_accumulates_multiple() -> Result<(), ArrowError> { + let batch1 = make_batch(); + let batch2 = make_batch(); + let buffer = Vec::::new(); + let mut writer = AvroWriter::new(buffer, make_schema())?; + writer.write_batches(&[&batch1, &batch2])?; + writer.finish()?; + let out = writer.into_inner(); + assert!(out.len() > 4, "combined batches produced tiny file"); + Ok(()) + } + + #[test] + fn test_finish_without_write_adds_header() -> Result<(), ArrowError> { + let buffer = Vec::::new(); + let mut writer = AvroWriter::new(buffer, make_schema())?; + writer.finish()?; + let out = writer.into_inner(); + assert_eq!(&out[..4], b"Obj\x01", "finish() should emit OCF header"); + Ok(()) + } + + #[test] + fn test_write_long_encodes_zigzag_varint() -> Result<(), ArrowError> { + let mut buf = Vec::new(); + write_long(&mut buf, 0)?; + write_long(&mut buf, -1)?; + write_long(&mut buf, 1)?; + write_long(&mut buf, -2)?; + write_long(&mut buf, 2147483647)?; + assert!( + buf.starts_with(&[0x00, 0x01, 0x02, 0x03]), + "zig‑zag varint encodings incorrect: {buf:?}" + ); + Ok(()) + } + + #[test] + fn test_roundtrip_alltypes_roundtrip_writer() -> Result<(), ArrowError> { + let files = [ + "avro/alltypes_plain.avro", + "avro/alltypes_plain.snappy.avro", + "avro/alltypes_plain.zstandard.avro", + "avro/alltypes_plain.bzip2.avro", + "avro/alltypes_plain.xz.avro", + ]; + for rel in files { + let path = arrow_test_data(rel); + let rdr_file = File::open(&path).expect("open input avro"); + let mut reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader"); + let schema = reader.schema(); + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&schema, &input_batches).expect("concat input"); + let tmp = NamedTempFile::new().expect("create temp file"); + let out_path = tmp.into_temp_path(); + let out_file = File::create(&out_path).expect("create temp avro"); + let codec = if rel.contains(".snappy.") { + Some(CompressionCodec::Snappy) + } else if rel.contains(".zstandard.") { + Some(CompressionCodec::ZStandard) + } else if rel.contains(".bzip2.") { + Some(CompressionCodec::Bzip2) + } else if rel.contains(".xz.") { + Some(CompressionCodec::Xz) + } else { + None + }; + let mut writer = WriterBuilder::new(original.schema().as_ref().clone()) + .with_compression(codec) + .build::<_, AvroOcfFormat>(out_file)?; + writer.write(&original)?; + writer.finish()?; + drop(writer); + let rt_file = File::open(&out_path).expect("open roundtrip avro"); + let mut rt_reader = ReaderBuilder::new() + .build(BufReader::new(rt_file)) + .expect("build roundtrip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let roundtrip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip"); + assert_eq!( + roundtrip, original, + "Round-trip batch mismatch for file: {}", + rel + ); + } + Ok(()) + } + + #[test] + fn test_roundtrip_nested_records_writer() -> Result<(), ArrowError> { + let path = arrow_test_data("avro/nested_records.avro"); + let rdr_file = File::open(&path).expect("open nested_records.avro"); + let mut reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for nested_records.avro"); + let schema = reader.schema(); + let batches = reader.collect::, _>>()?; + let original = arrow::compute::concat_batches(&schema, &batches).expect("concat 
original"); + let tmp = NamedTempFile::new().expect("create temp file"); + let out_path = tmp.into_temp_path(); + { + let out_file = File::create(&out_path).expect("create output avro"); + let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + } + let rt_file = File::open(&out_path).expect("open round_trip avro"); + let mut rt_reader = ReaderBuilder::new() + .build(BufReader::new(rt_file)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + assert_eq!( + round_trip, original, + "Round-trip batch mismatch for nested_records.avro" + ); + Ok(()) + } + + #[test] + fn test_roundtrip_nested_lists_writer() -> Result<(), ArrowError> { + let path = arrow_test_data("avro/nested_lists.snappy.avro"); + let rdr_file = File::open(&path).expect("open nested_lists.snappy.avro"); + let mut reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for nested_lists.snappy.avro"); + let schema = reader.schema(); + let batches = reader.collect::, _>>()?; + let original = arrow::compute::concat_batches(&schema, &batches).expect("concat original"); + let tmp = NamedTempFile::new().expect("create temp file"); + let out_path = tmp.into_temp_path(); + { + let out_file = File::create(&out_path).expect("create output avro"); + let mut writer = WriterBuilder::new(original.schema().as_ref().clone()) + .with_compression(Some(CompressionCodec::Snappy)) + .build::<_, AvroOcfFormat>(out_file)?; + writer.write(&original)?; + writer.finish()?; + } + let rt_file = File::open(&out_path).expect("open round_trip avro"); + let mut rt_reader = ReaderBuilder::new() + .build(BufReader::new(rt_file)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + assert_eq!( + round_trip, original, + "Round-trip batch mismatch for nested_lists.snappy.avro" + ); + Ok(()) + } + + #[test] + fn test_round_trip_simple_fixed_ocf() -> Result<(), ArrowError> { + let path = arrow_test_data("avro/simple_fixed.avro"); + let rdr_file = File::open(&path).expect("open avro/simple_fixed.avro"); + let mut reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build avro reader"); + let schema = reader.schema(); + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&schema, &input_batches).expect("concat input"); + let tmp = NamedTempFile::new().expect("create temp file"); + let out_file = File::create(tmp.path()).expect("create temp avro"); + let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + drop(writer); + let rt_file = File::open(tmp.path()).expect("open round_trip avro"); + let mut rt_reader = ReaderBuilder::new() + .build(BufReader::new(rt_file)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + assert_eq!(round_trip, original); + Ok(()) + } + + #[cfg(not(feature = "canonical_extension_types"))] + #[test] + fn test_round_trip_duration_and_uuid_ocf() -> Result<(), ArrowError> { 
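+        // Round-trips a fixture containing Interval(MonthDayNano) duration columns
+        // and a FixedSizeBinary(16) uuid column through the OCF writer and reader.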
+ let in_file = + File::open("test/data/duration_uuid.avro").expect("open test/data/duration_uuid.avro"); + let mut reader = ReaderBuilder::new() + .build(BufReader::new(in_file)) + .expect("build reader for duration_uuid.avro"); + let in_schema = reader.schema(); + let has_mdn = in_schema.fields().iter().any(|f| { + matches!( + f.data_type(), + DataType::Interval(IntervalUnit::MonthDayNano) + ) + }); + assert!( + has_mdn, + "expected at least one Interval(MonthDayNano) field in duration_uuid.avro" + ); + let has_uuid_fixed = in_schema + .fields() + .iter() + .any(|f| matches!(f.data_type(), DataType::FixedSizeBinary(16))); + assert!( + has_uuid_fixed, + "expected at least one FixedSizeBinary(16) (uuid) field in duration_uuid.avro" + ); + let input_batches = reader.collect::, _>>()?; + let input = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + let tmp = NamedTempFile::new().expect("create temp file"); + { + let out_file = File::create(tmp.path()).expect("create temp avro"); + let mut writer = AvroWriter::new(out_file, in_schema.as_ref().clone())?; + writer.write(&input)?; + writer.finish()?; + } + let rt_file = File::open(tmp.path()).expect("open round_trip avro"); + let mut rt_reader = ReaderBuilder::new() + .build(BufReader::new(rt_file)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + assert_eq!(round_trip, input); + Ok(()) + } + + // This test reads the same 'nonnullable.impala.avro' used by the reader tests, + // writes it back out with the writer (hitting Map encoding paths), then reads it + // again and asserts exact Arrow equivalence. 
+ #[test] + fn test_nonnullable_impala_roundtrip_writer() -> Result<(), ArrowError> { + // Load source Avro with Map fields + let path = arrow_test_data("avro/nonnullable.impala.avro"); + let rdr_file = File::open(&path).expect("open avro/nonnullable.impala.avro"); + let mut reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for nonnullable.impala.avro"); + // Collect all input batches and concatenate to a single RecordBatch + let in_schema = reader.schema(); + // Sanity: ensure the file actually contains at least one Map field + let has_map = in_schema + .fields() + .iter() + .any(|f| matches!(f.data_type(), DataType::Map(_, _))); + assert!( + has_map, + "expected at least one Map field in avro/nonnullable.impala.avro" + ); + + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + // Write out using the OCF writer into an in-memory Vec + let buffer = Vec::::new(); + let mut writer = AvroWriter::new(buffer, in_schema.as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + let out_bytes = writer.into_inner(); + // Read the produced bytes back with the Reader + let mut rt_reader = ReaderBuilder::new() + .build(Cursor::new(out_bytes)) + .expect("build reader for round-tripped in-memory OCF"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let roundtrip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip"); + // Exact value fidelity (schema + data) + assert_eq!( + roundtrip, original, + "Round-trip Avro map data mismatch for nonnullable.impala.avro" + ); + Ok(()) + } + + #[test] + fn test_roundtrip_decimals_via_writer() -> Result<(), ArrowError> { + // (file, resolve via ARROW_TEST_DATA?) 
+ let files: [(&str, bool); 8] = [ + ("avro/fixed_length_decimal.avro", true), // fixed-backed -> Decimal128(25,2) + ("avro/fixed_length_decimal_legacy.avro", true), // legacy fixed[8] -> Decimal64(13,2) + ("avro/int32_decimal.avro", true), // bytes-backed -> Decimal32(4,2) + ("avro/int64_decimal.avro", true), // bytes-backed -> Decimal64(10,2) + ("test/data/int256_decimal.avro", false), // bytes-backed -> Decimal256(76,2) + ("test/data/fixed256_decimal.avro", false), // fixed[32]-backed -> Decimal256(76,10) + ("test/data/fixed_length_decimal_legacy_32.avro", false), // legacy fixed[4] -> Decimal32(9,2) + ("test/data/int128_decimal.avro", false), // bytes-backed -> Decimal128(38,2) + ]; + for (rel, in_test_data_dir) in files { + // Resolve path the same way as reader::test_decimal + let path: String = if in_test_data_dir { + arrow_test_data(rel) + } else { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join(rel) + .to_string_lossy() + .into_owned() + }; + // Read original file into a single RecordBatch for comparison + let f_in = File::open(&path).expect("open input avro"); + let mut rdr = ReaderBuilder::new().build(BufReader::new(f_in))?; + let in_schema = rdr.schema(); + let in_batches = rdr.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &in_batches).expect("concat input"); + // Write it out with the OCF writer (no special compression) + let tmp = NamedTempFile::new().expect("create temp file"); + let out_path = tmp.into_temp_path(); + let out_file = File::create(&out_path).expect("create temp avro"); + let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + // Read back the file we just wrote and compare equality (schema + data) + let f_rt = File::open(&out_path).expect("open roundtrip avro"); + let mut rt_rdr = ReaderBuilder::new().build(BufReader::new(f_rt))?; + let rt_schema = rt_rdr.schema(); + let rt_batches = rt_rdr.collect::, _>>()?; + let roundtrip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat rt"); + assert_eq!(roundtrip, original, "decimal round-trip mismatch for {rel}"); + } + Ok(()) + } + + #[test] + fn test_enum_roundtrip_uses_reader_fixture() -> Result<(), ArrowError> { + // Read the known-good enum file (same as reader::test_simple) + let path = arrow_test_data("avro/simple_enum.avro"); + let rdr_file = File::open(&path).expect("open avro/simple_enum.avro"); + let mut reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for simple_enum.avro"); + // Concatenate all batches to one RecordBatch for a clean equality check + let in_schema = reader.schema(); + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + // Sanity: expect at least one Dictionary(Int32, Utf8) column (enum) + let has_enum_dict = in_schema.fields().iter().any(|f| { + matches!( + f.data_type(), + DataType::Dictionary(k, v) if **k == DataType::Int32 && **v == DataType::Utf8 + ) + }); + assert!( + has_enum_dict, + "Expected at least one enum-mapped Dictionary field" + ); + // Write with OCF writer into memory using the reader-provided Arrow schema. + // The writer will embed the Avro JSON from `avro.schema` metadata if present. 
+ let buffer: Vec = Vec::new(); + let mut writer = AvroWriter::new(buffer, in_schema.as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + let bytes = writer.into_inner(); + // Read back and compare for exact equality (schema + data) + let mut rt_reader = ReaderBuilder::new() + .build(Cursor::new(bytes)) + .expect("reader for round-trip"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let roundtrip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip"); + assert_eq!(roundtrip, original, "Avro enum round-trip mismatch"); + Ok(()) + } +} diff --git a/arrow-avro/test/data/README.md b/arrow-avro/test/data/README.md new file mode 100644 index 000000000000..51416c8416d4 --- /dev/null +++ b/arrow-avro/test/data/README.md @@ -0,0 +1,147 @@ + + +# Avro test files for `arrow-avro` + +This directory contains small Avro Object Container Files (OCF) used by +`arrow-avro` tests to validate the `Reader` implementation. These files are generated from +a set of python scripts and will gradually be removed as they are merged into `arrow-testing`. + +## Decimal Files + +This directory contains OCF files used to exercise decoding of Avro’s `decimal` logical type +across both `bytes` and `fixed` encodings, and to cover Arrow decimal widths ranging +from `Decimal32` up through `Decimal256`. The files were generated from a +script (see **How these files were created** below). + +> **Avro decimal recap.** Avro’s `decimal` logical type annotates either a +> `bytes` or `fixed` primitive and stores the **two’s‑complement big‑endian +> representation of the unscaled integer** (value × 10^scale). Implementations +> should reject invalid combinations such as `scale > precision`. + +> **Arrow decimal recap.** Arrow defines `Decimal32`, `Decimal64`, `Decimal128`, +> and `Decimal256` types with maximum precisions of 9, 18, 38, and 76 digits, +> respectively. Tests here validate that the Avro reader selects compatible +> Arrow decimal widths given the Avro decimal’s precision and storage. + +--- + +All files are one‑column Avro OCFs with a field named `value`. Each contains 24 +rows with the sequence `1 … 24` rendered at the file’s declared `scale` +(i.e., at scale 10: `1.0000000000`, `2.0000000000`). 
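As a concrete illustration of the decimal recap above (this helper is not part of the generator script, just a sketch of the byte layout): the first row of the scale-10 files is `1.0000000000`, i.e. the unscaled integer `10_000_000_000`, and the `bytes`-backed files store its minimal two's-complement big-endian form, while the `fixed`-backed files sign-extend the same bytes to the declared size.

```rust
/// Minimal two's-complement big-endian encoding of an unscaled decimal value,
/// as the Avro `decimal` logical type over `bytes` requires (illustrative only).
fn unscaled_be_bytes(unscaled: i128) -> Vec<u8> {
    let be = unscaled.to_be_bytes();
    // Drop redundant leading sign bytes while keeping the sign bit intact.
    let mut start = 0;
    while start < be.len() - 1
        && ((be[start] == 0x00 && be[start + 1] & 0x80 == 0)
            || (be[start] == 0xFF && be[start + 1] & 0x80 != 0))
    {
        start += 1;
    }
    be[start..].to_vec()
}

fn main() {
    // 1.0000000000 at scale 10 => unscaled 10_000_000_000 => 0x02 0x54 0x0B 0xE4 0x00
    assert_eq!(
        unscaled_be_bytes(10_000_000_000),
        vec![0x02, 0x54, 0x0B, 0xE4, 0x00]
    );
    // -1.0000000000 => two's complement of the same magnitude
    assert_eq!(
        unscaled_be_bytes(-10_000_000_000),
        vec![0xFD, 0xAB, 0xF4, 0x1C, 0x00]
    );
}
```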
+ +| File | Avro storage | Decimal (precision, scale) | Intended Arrow width | +|---|---|---|---| +| `int256_decimal.avro` | `bytes` + `logicalType: decimal` | (76, 10) | `Decimal256` | +| `fixed256_decimal.avro` | `fixed[32]` + `logicalType: decimal` | (76, 10) | `Decimal256` | +| `fixed_length_decimal_legacy_32.avro` | `fixed[4]` + `logicalType: decimal` | (9, 2) | `Decimal32` (legacy fixed‑width path) | +| `int128_decimal.avro` | `bytes` + `logicalType: decimal` | (38, 2) | `Decimal128` | + +### Schemas (for reference) + +#### int256_decimal.avro + +```json +{ + "type": "record", + "name": "OneColDecimal256Bytes", + "fields": [{ + "name": "value", + "type": { "type": "bytes", "logicalType": "decimal", "precision": 76, "scale": 10 } + }] +} +``` + +#### fixed256_decimal.avro + +```json +{ + "type": "record", + "name": "OneColDecimal256Fixed", + "fields": [{ + "name": "value", + "type": { + "type": "fixed", "name": "Decimal256Fixed", "size": 32, + "logicalType": "decimal", "precision": 76, "scale": 10 + } + }] +} +``` + +#### fixed_length_decimal_legacy_32.avro + +```json +{ + "type": "record", + "name": "OneColDecimal32FixedLegacy", + "fields": [{ + "name": "value", + "type": { + "type": "fixed", "name": "Decimal32FixedLegacy", "size": 4, + "logicalType": "decimal", "precision": 9, "scale": 2 + } + }] +} +``` + +#### int128_decimal.avro + +```json +{ + "type": "record", + "name": "OneColDecimal128Bytes", + "fields": [{ + "name": "value", + "type": { "type": "bytes", "logicalType": "decimal", "precision": 38, "scale": 2 } + }] +} +``` + +### How these files were created + +All four files were generated by the Python script +`create_avro_decimal_files.py` authored for this purpose. The script uses +`fastavro` to write OCFs and encodes decimal values as required by the Avro +spec (two’s‑complement big‑endian of the unscaled integer). + +#### Re‑generation + +From the repository root (defaults write into arrow-avro/test/data): + +```bash +# 1) Ensure Python 3 is available, then install fastavro +python -m pip install --upgrade fastavro + +# 2) Fetch the script +curl -L -o create_avro_decimal_files.py \ +https://gist.githubusercontent.com/jecsand838/3890349bdb33082a3e8fdcae3257eef7/raw/create_avro_decimal_files.py + +# 3) Generate the files (prints a verification dump by default) +python create_avro_decimal_files.py -o arrow-avro/test/data +``` + +Options: +* --num-rows (default 24) — number of rows to emit per file +* --scale (default 10) — the decimal scale used for the 256 files +* --no-verify — skip reading the files back for printed verification + +## Other Files + +This directory contains other small OCF files used by `arrow-avro` tests. Details on these will be added in +follow-up PRs. 
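A quick way to sanity-check the "Intended Arrow width" column is to open one of these files with the same `ReaderBuilder` API the round-trip tests above use and inspect the inferred schema. A sketch (not part of this PR's test suite), run from the `arrow-avro` crate root:

```rust
use std::fs::File;
use std::io::BufReader;

use arrow_avro::reader::ReaderBuilder;
use arrow_schema::DataType;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Relative to the arrow-avro crate root, like the paths in the tests above.
    let file = File::open("test/data/int256_decimal.avro")?;
    let reader = ReaderBuilder::new().build(BufReader::new(file))?;
    // The bytes-backed (precision 76, scale 10) column should map to Decimal256(76, 10).
    assert_eq!(
        reader.schema().field(0).data_type(),
        &DataType::Decimal256(76, 10)
    );
    Ok(())
}
```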
\ No newline at end of file diff --git a/arrow-avro/test/data/fixed256_decimal.avro b/arrow-avro/test/data/fixed256_decimal.avro new file mode 100644 index 000000000000..d1fc97dd8c83 Binary files /dev/null and b/arrow-avro/test/data/fixed256_decimal.avro differ diff --git a/arrow-avro/test/data/fixed_length_decimal_legacy_32.avro b/arrow-avro/test/data/fixed_length_decimal_legacy_32.avro new file mode 100644 index 000000000000..b746df9619b5 Binary files /dev/null and b/arrow-avro/test/data/fixed_length_decimal_legacy_32.avro differ diff --git a/arrow-avro/test/data/int128_decimal.avro b/arrow-avro/test/data/int128_decimal.avro new file mode 100644 index 000000000000..bd54d20ba487 Binary files /dev/null and b/arrow-avro/test/data/int128_decimal.avro differ diff --git a/arrow-avro/test/data/int256_decimal.avro b/arrow-avro/test/data/int256_decimal.avro new file mode 100644 index 000000000000..62ad7ea4df08 Binary files /dev/null and b/arrow-avro/test/data/int256_decimal.avro differ diff --git a/arrow-avro/test/data/skippable_types.avro b/arrow-avro/test/data/skippable_types.avro new file mode 100644 index 000000000000..b0518e0056b5 Binary files /dev/null and b/arrow-avro/test/data/skippable_types.avro differ diff --git a/arrow-buffer/benches/i256.rs b/arrow-buffer/benches/i256.rs index 7dec226bbc08..11aaa83c8d53 100644 --- a/arrow-buffer/benches/i256.rs +++ b/arrow-buffer/benches/i256.rs @@ -17,6 +17,7 @@ use arrow_buffer::i256; use criterion::*; +use num::cast::ToPrimitive; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use std::{hint, str::FromStr}; @@ -36,13 +37,19 @@ fn criterion_benchmark(c: &mut Criterion) { i256::MAX, ]; - for number in numbers { + for number in numbers.iter() { let t = hint::black_box(number.to_string()); c.bench_function(&format!("i256_parse({t})"), |b| { b.iter(|| i256::from_str(&t).unwrap()); }); } + for number in numbers.iter() { + c.bench_function(&format!("i256_to_f64({number})"), |b| { + b.iter(|| (*number).to_f64().unwrap()) + }); + } + let mut rng = StdRng::seed_from_u64(42); let numerators: Vec<_> = (0..SIZE) diff --git a/arrow-buffer/src/bigint/mod.rs b/arrow-buffer/src/bigint/mod.rs index 92f11d68d318..d7959a71abb2 100644 --- a/arrow-buffer/src/bigint/mod.rs +++ b/arrow-buffer/src/bigint/mod.rs @@ -586,6 +586,25 @@ impl i256 { pub const fn is_positive(self) -> bool { self.high.is_positive() || self.high == 0 && self.low != 0 } + + fn leading_zeros(&self) -> u32 { + match self.high { + 0 => u128::BITS + self.low.leading_zeros(), + _ => self.high.leading_zeros(), + } + } + + fn redundant_leading_sign_bits_i256(n: i256) -> u8 { + let mask = n >> 255; // all ones or all zeros + ((n ^ mask).leading_zeros() - 1) as u8 // we only need one sign bit + } + + fn i256_to_f64(input: i256) -> f64 { + let k = i256::redundant_leading_sign_bits_i256(input); + let n = input << k; // left-justify (no redundant sign bits) + let n = (n.high >> 64) as i64; // throw away the lower 192 bits + (n as f64) * f64::powi(2.0, 192 - (k as i32)) // convert to f64 and scale it, as we left-shift k bit previous, so we need to scale it by 2^(192-k) + } } /// Temporary workaround due to lack of stable const array slicing @@ -822,19 +841,14 @@ impl ToPrimitive for i256 { } fn to_f64(&self) -> Option { - let mag = if let Some(u) = self.checked_abs() { - let (low, high) = u.to_parts(); - (high as f64) * 2_f64.powi(128) + (low as f64) - } else { - // self == MIN - 2_f64.powi(255) - }; - if *self < i256::ZERO { - Some(-mag) - } else { - Some(mag) + match *self { + Self::MIN => 
Some(-2_f64.powi(255)), + Self::ZERO => Some(0f64), + Self::ONE => Some(1f64), + n => Some(Self::i256_to_f64(n)), } } + fn to_u64(&self) -> Option { let as_i128 = self.low as i128; @@ -1286,6 +1300,20 @@ mod tests { let v = i256::from_i128(-123456789012345678i128); assert_eq!(v.to_f64().unwrap(), -123456789012345678.0); + + let v = i256::from_string("0").unwrap(); + assert_eq!(v.to_f64().unwrap(), 0.0); + + let v = i256::from_string("1").unwrap(); + assert_eq!(v.to_f64().unwrap(), 1.0); + + let mut rng = rng(); + for _ in 0..10 { + let f64_value = + (rng.random_range(i128::MIN..i128::MAX) as f64) * rng.random_range(0.0..1.0); + let big = i256::from_f64(f64_value).unwrap(); + assert_eq!(big.to_f64().unwrap(), f64_value); + } } #[test] diff --git a/arrow-cast/Cargo.toml b/arrow-cast/Cargo.toml index 49145cf987f9..32bbd35e811d 100644 --- a/arrow-cast/Cargo.toml +++ b/arrow-cast/Cargo.toml @@ -50,7 +50,8 @@ half = { version = "2.1", default-features = false } num = { version = "0.4", default-features = false, features = ["std"] } lexical-core = { version = "1.0", default-features = false, features = ["write-integers", "write-floats", "parse-integers", "parse-floats"] } atoi = "2.0.0" -comfy-table = { version = "7.0", optional = true, default-features = false } +# unpin after MSRV bump to 1.85 +comfy-table = { version = "=7.1.2", optional = true, default-features = false } base64 = "0.22" ryu = "1.0.16" diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs index 597f384fa452..095e31274887 100644 --- a/arrow-cast/src/cast/decimal.rs +++ b/arrow-cast/src/cast/decimal.rs @@ -20,6 +20,10 @@ use crate::cast::*; /// A utility trait that provides checked conversions between /// decimal types inspired by [`NumCast`] pub(crate) trait DecimalCast: Sized { + fn to_i32(self) -> Option; + + fn to_i64(self) -> Option; + fn to_i128(self) -> Option; fn to_i256(self) -> Option; @@ -29,7 +33,69 @@ pub(crate) trait DecimalCast: Sized { fn from_f64(n: f64) -> Option; } +impl DecimalCast for i32 { + fn to_i32(self) -> Option { + Some(self) + } + + fn to_i64(self) -> Option { + Some(self as i64) + } + + fn to_i128(self) -> Option { + Some(self as i128) + } + + fn to_i256(self) -> Option { + Some(i256::from_i128(self as i128)) + } + + fn from_decimal(n: T) -> Option { + n.to_i32() + } + + fn from_f64(n: f64) -> Option { + n.to_i32() + } +} + +impl DecimalCast for i64 { + fn to_i32(self) -> Option { + i32::try_from(self).ok() + } + + fn to_i64(self) -> Option { + Some(self) + } + + fn to_i128(self) -> Option { + Some(self as i128) + } + + fn to_i256(self) -> Option { + Some(i256::from_i128(self as i128)) + } + + fn from_decimal(n: T) -> Option { + n.to_i64() + } + + fn from_f64(n: f64) -> Option { + // Call implementation explicitly otherwise this resolves to `to_i64` + // in arrow-buffer that behaves differently. + num::traits::ToPrimitive::to_i64(&n) + } +} + impl DecimalCast for i128 { + fn to_i32(self) -> Option { + i32::try_from(self).ok() + } + + fn to_i64(self) -> Option { + i64::try_from(self).ok() + } + fn to_i128(self) -> Option { Some(self) } @@ -48,6 +114,14 @@ impl DecimalCast for i128 { } impl DecimalCast for i256 { + fn to_i32(self) -> Option { + self.to_i128().map(|x| i32::try_from(x).ok())? + } + + fn to_i64(self) -> Option { + self.to_i128().map(|x| i64::try_from(x).ok())? 
+ } + fn to_i128(self) -> Option { self.to_i128() } diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index 8fb0c4fdd15d..fc241bea48da 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -148,8 +148,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { can_cast_types(list_from.data_type(), list_to.data_type()) } (List(_), _) => false, - (FixedSizeList(list_from,_), List(list_to)) | - (FixedSizeList(list_from,_), LargeList(list_to)) => { + (FixedSizeList(list_from, _), List(list_to)) + | (FixedSizeList(list_from, _), LargeList(list_to)) => { can_cast_types(list_from.data_type(), list_to.data_type()) } (FixedSizeList(inner, size), FixedSizeList(inner_to, size_to)) if size == size_to => { @@ -157,38 +157,66 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { } (_, List(list_to)) => can_cast_types(from_type, list_to.data_type()), (_, LargeList(list_to)) => can_cast_types(from_type, list_to.data_type()), - (_, FixedSizeList(list_to,size)) if *size == 1 => { - can_cast_types(from_type, list_to.data_type())}, - (FixedSizeList(list_from,size), _) if *size == 1 => { - can_cast_types(list_from.data_type(), to_type)}, - (Map(from_entries,ordered_from), Map(to_entries, ordered_to)) if ordered_from == ordered_to => - match (key_field(from_entries), key_field(to_entries), value_field(from_entries), value_field(to_entries)) { - (Some(from_key), Some(to_key), Some(from_value), Some(to_value)) => - can_cast_types(from_key.data_type(), to_key.data_type()) && can_cast_types(from_value.data_type(), to_value.data_type()), - _ => false - }, + (_, FixedSizeList(list_to, size)) if *size == 1 => { + can_cast_types(from_type, list_to.data_type()) + } + (FixedSizeList(list_from, size), _) if *size == 1 => { + can_cast_types(list_from.data_type(), to_type) + } + (Map(from_entries, ordered_from), Map(to_entries, ordered_to)) + if ordered_from == ordered_to => + { + match ( + key_field(from_entries), + key_field(to_entries), + value_field(from_entries), + value_field(to_entries), + ) { + (Some(from_key), Some(to_key), Some(from_value), Some(to_value)) => { + can_cast_types(from_key.data_type(), to_key.data_type()) + && can_cast_types(from_value.data_type(), to_value.data_type()) + } + _ => false, + } + } // cast one decimal type to another decimal type - (Decimal128(_, _), Decimal128(_, _)) => true, - (Decimal256(_, _), Decimal256(_, _)) => true, - (Decimal128(_, _), Decimal256(_, _)) => true, - (Decimal256(_, _), Decimal128(_, _)) => true, + ( + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + ) => true, // unsigned integer to decimal - (UInt8 | UInt16 | UInt32 | UInt64, Decimal128(_, _)) | - (UInt8 | UInt16 | UInt32 | UInt64, Decimal256(_, _)) | + ( + UInt8 | UInt16 | UInt32 | UInt64, + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + ) => true, // signed numeric to decimal - (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal128(_, _)) | - (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal256(_, _)) | + ( + Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + ) => true, // decimal to unsigned numeric - (Decimal128(_, _) | Decimal256(_, _), UInt8 | UInt16 | UInt32 | UInt64) | + ( + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + UInt8 | UInt16 | UInt32 | UInt64, + ) 
=> true, // decimal to signed numeric - (Decimal128(_, _) | Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) => true, + ( + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, + ) => true, // decimal to string - (Decimal128(_, _) | Decimal256(_, _), Utf8View | Utf8 | LargeUtf8) => true, + ( + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + Utf8View | Utf8 | LargeUtf8, + ) => true, // string to decimal - (Utf8View | Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true, + ( + Utf8View | Utf8 | LargeUtf8, + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + ) => true, (Struct(from_fields), Struct(to_fields)) => { - from_fields.len() == to_fields.len() && - from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| { + from_fields.len() == to_fields.len() + && from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| { // Assume that nullability between two structs are compatible, if not, // cast kernel will return error. can_cast_types(f1.data_type(), f2.data_type()) @@ -211,8 +239,12 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { || to_type == &LargeUtf8 } - (Binary, LargeBinary | Utf8 | LargeUtf8 | FixedSizeBinary(_) | BinaryView | Utf8View ) => true, - (LargeBinary, Binary | Utf8 | LargeUtf8 | FixedSizeBinary(_) | BinaryView | Utf8View ) => true, + (Binary, LargeBinary | Utf8 | LargeUtf8 | FixedSizeBinary(_) | BinaryView | Utf8View) => { + true + } + (LargeBinary, Binary | Utf8 | LargeUtf8 | FixedSizeBinary(_) | BinaryView | Utf8View) => { + true + } (FixedSizeBinary(_), Binary | LargeBinary | BinaryView) => true, ( Utf8 | LargeUtf8 | Utf8View, @@ -236,15 +268,16 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Utf8 | LargeUtf8, Utf8View) => true, (BinaryView, Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View) => true, (Utf8View | Utf8 | LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16, - (_, Utf8 | LargeUtf8) => from_type.is_primitive(), - (_, Utf8View) => from_type.is_numeric(), + (_, Utf8 | Utf8View | LargeUtf8) => from_type.is_primitive(), (_, Binary | LargeBinary) => from_type.is_integer(), // start numeric casts ( - UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 | Float64, - UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 | Float64, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 + | Float64, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 + | Float64, ) => true, // end numeric casts @@ -847,6 +880,26 @@ pub fn cast_with_options( cast_map_values(array.as_map(), to_type, cast_options, ordered1.to_owned()) } // Decimal to decimal, same width + (Decimal32(p1, s1), Decimal32(p2, s2)) => { + cast_decimal_to_decimal_same_type::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal64(p1, s1), Decimal64(p2, s2)) => { + cast_decimal_to_decimal_same_type::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } (Decimal128(p1, s1), Decimal128(p2, s2)) => { cast_decimal_to_decimal_same_type::( array.as_primitive(), @@ -868,6 +921,86 @@ pub fn cast_with_options( ) } // Decimal to decimal, different width + (Decimal32(p1, s1), Decimal64(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + 
(Decimal32(p1, s1), Decimal128(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal32(p1, s1), Decimal256(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal64(p1, s1), Decimal32(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal64(p1, s1), Decimal128(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal64(p1, s1), Decimal256(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal128(p1, s1), Decimal32(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal128(p1, s1), Decimal64(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } (Decimal128(p1, s1), Decimal256(p2, s2)) => { cast_decimal_to_decimal::( array.as_primitive(), @@ -878,6 +1011,26 @@ pub fn cast_with_options( cast_options, ) } + (Decimal256(p1, s1), Decimal32(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal256(p1, s1), Decimal64(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } (Decimal256(p1, s1), Decimal128(p2, s2)) => { cast_decimal_to_decimal::( array.as_primitive(), @@ -889,6 +1042,28 @@ pub fn cast_with_options( ) } // Decimal to non-decimal + (Decimal32(_, scale), _) if !to_type.is_temporal() => { + cast_from_decimal::( + array, + 10_i32, + scale, + from_type, + to_type, + |x: i32| x as f64, + cast_options, + ) + } + (Decimal64(_, scale), _) if !to_type.is_temporal() => { + cast_from_decimal::( + array, + 10_i64, + scale, + from_type, + to_type, + |x: i64| x as f64, + cast_options, + ) + } (Decimal128(_, scale), _) if !to_type.is_temporal() => { cast_from_decimal::( array, @@ -912,6 +1087,28 @@ pub fn cast_with_options( ) } // Non-decimal to decimal + (_, Decimal32(precision, scale)) if !from_type.is_temporal() => { + cast_to_decimal::( + array, + 10_i32, + precision, + scale, + from_type, + to_type, + cast_options, + ) + } + (_, Decimal64(precision, scale)) if !from_type.is_temporal() => { + cast_to_decimal::( + array, + 10_i64, + precision, + scale, + from_type, + to_type, + cast_options, + ) + } (_, Decimal128(precision, scale)) if !from_type.is_temporal() => { cast_to_decimal::( array, @@ -2524,6 +2721,28 @@ mod tests { } } + fn create_decimal32_array( + array: Vec>, + precision: u8, + scale: i8, + ) -> Result { + array + .into_iter() + .collect::() + .with_precision_and_scale(precision, scale) + } + + fn create_decimal64_array( + array: Vec>, + precision: u8, + scale: i8, + ) -> Result { + array + .into_iter() + .collect::() + .with_precision_and_scale(precision, scale) + } + fn create_decimal128_array( array: Vec>, precision: u8, @@ -2672,8 +2891,77 @@ mod tests { ); } + #[test] + fn test_cast_decimal32_to_decimal32() { + // test changing precision + let input_type = DataType::Decimal32(9, 3); + let output_type = DataType::Decimal32(9, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal32_array(array, 9, 3).unwrap(); + generate_cast_test_case!( + 
&array, + Decimal32Array, + &output_type, + vec![ + Some(11234560_i32), + Some(21234560_i32), + Some(31234560_i32), + None + ] + ); + // negative test + let array = vec![Some(123456), None]; + let array = create_decimal32_array(array, 9, 0).unwrap(); + let result_safe = cast(&array, &DataType::Decimal32(2, 2)); + assert!(result_safe.is_ok()); + let options = CastOptions { + safe: false, + ..Default::default() + }; + + let result_unsafe = cast_with_options(&array, &DataType::Decimal32(2, 2), &options); + assert_eq!("Invalid argument error: 12345600 is too large to store in a Decimal32 of precision 2. Max is 99", + result_unsafe.unwrap_err().to_string()); + } + + #[test] + fn test_cast_decimal64_to_decimal64() { + // test changing precision + let input_type = DataType::Decimal64(17, 3); + let output_type = DataType::Decimal64(17, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal64_array(array, 17, 3).unwrap(); + generate_cast_test_case!( + &array, + Decimal64Array, + &output_type, + vec![ + Some(11234560_i64), + Some(21234560_i64), + Some(31234560_i64), + None + ] + ); + // negative test + let array = vec![Some(123456), None]; + let array = create_decimal64_array(array, 9, 0).unwrap(); + let result_safe = cast(&array, &DataType::Decimal64(2, 2)); + assert!(result_safe.is_ok()); + let options = CastOptions { + safe: false, + ..Default::default() + }; + + let result_unsafe = cast_with_options(&array, &DataType::Decimal64(2, 2), &options); + assert_eq!("Invalid argument error: 12345600 is too large to store in a Decimal64 of precision 2. Max is 99", + result_unsafe.unwrap_err().to_string()); + } + #[test] fn test_cast_decimal128_to_decimal128() { + // test changing precision let input_type = DataType::Decimal128(20, 3); let output_type = DataType::Decimal128(20, 4); assert!(can_cast_types(&input_type, &output_type)); @@ -2705,6 +2993,38 @@ mod tests { result_unsafe.unwrap_err().to_string()); } + #[test] + fn test_cast_decimal32_to_decimal32_dict() { + let p = 9; + let s = 3; + let input_type = DataType::Decimal32(p, s); + let output_type = DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Decimal32(p, s)), + ); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal32_array(array, p, s).unwrap(); + let cast_array = cast_with_options(&array, &output_type, &CastOptions::default()).unwrap(); + assert_eq!(cast_array.data_type(), &output_type); + } + + #[test] + fn test_cast_decimal64_to_decimal64_dict() { + let p = 15; + let s = 3; + let input_type = DataType::Decimal64(p, s); + let output_type = DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Decimal64(p, s)), + ); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal64_array(array, p, s).unwrap(); + let cast_array = cast_with_options(&array, &output_type, &CastOptions::default()).unwrap(); + assert_eq!(cast_array.data_type(), &output_type); + } + #[test] fn test_cast_decimal128_to_decimal128_dict() { let p = 20; @@ -2737,6 +3057,79 @@ mod tests { assert_eq!(cast_array.data_type(), &output_type); } + #[test] + fn test_cast_decimal32_to_decimal32_overflow() { + let input_type = DataType::Decimal32(9, 3); + let output_type = DataType::Decimal32(9, 9); + assert!(can_cast_types(&input_type, &output_type)); 
+ + let array = vec![Some(i32::MAX)]; + let array = create_decimal32_array(array, 9, 3).unwrap(); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!( + "Cast error: Cannot cast to Decimal32(9, 9). Overflowing on 2147483647", + result.unwrap_err().to_string() + ); + } + + #[test] + fn test_cast_decimal64_to_decimal64_overflow() { + let input_type = DataType::Decimal64(18, 3); + let output_type = DataType::Decimal64(18, 18); + assert!(can_cast_types(&input_type, &output_type)); + + let array = vec![Some(i64::MAX)]; + let array = create_decimal64_array(array, 18, 3).unwrap(); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!( + "Cast error: Cannot cast to Decimal64(18, 18). Overflowing on 9223372036854775807", + result.unwrap_err().to_string() + ); + } + + #[test] + fn test_cast_floating_to_decimals() { + for output_type in [ + DataType::Decimal32(9, 3), + DataType::Decimal64(9, 3), + DataType::Decimal128(9, 3), + DataType::Decimal256(9, 3), + ] { + let input_type = DataType::Float64; + assert!(can_cast_types(&input_type, &output_type)); + + let array = vec![Some(1.1_f64)]; + let array = PrimitiveArray::::from_iter(array); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert!( + result.is_ok(), + "Failed to cast to {output_type} with: {}", + result.unwrap_err() + ); + } + } + #[test] fn test_cast_decimal128_to_decimal128_overflow() { let input_type = DataType::Decimal128(38, 3); @@ -2777,6 +3170,44 @@ mod tests { result.unwrap_err().to_string()); } + #[test] + fn test_cast_decimal32_to_decimal256() { + let input_type = DataType::Decimal32(8, 3); + let output_type = DataType::Decimal256(20, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal32_array(array, 8, 3).unwrap(); + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some(i256::from_i128(11234560_i128)), + Some(i256::from_i128(21234560_i128)), + Some(i256::from_i128(31234560_i128)), + None + ] + ); + } + #[test] + fn test_cast_decimal64_to_decimal256() { + let input_type = DataType::Decimal64(12, 3); + let output_type = DataType::Decimal256(20, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal64_array(array, 12, 3).unwrap(); + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some(i256::from_i128(11234560_i128)), + Some(i256::from_i128(21234560_i128)), + Some(i256::from_i128(31234560_i128)), + None + ] + ); + } #[test] fn test_cast_decimal128_to_decimal256() { let input_type = DataType::Decimal128(20, 3); @@ -2973,6 +3404,22 @@ mod tests { ); } + #[test] + fn test_cast_decimal32_to_numeric() { + let value_array: Vec> = vec![Some(125), Some(225), Some(325), None, Some(525)]; + let array = create_decimal32_array(value_array, 8, 2).unwrap(); + + generate_decimal_to_numeric_cast_test_case(&array); + } + + #[test] + fn test_cast_decimal64_to_numeric() { + let value_array: Vec> = vec![Some(125), Some(225), Some(325), None, Some(525)]; + let array = create_decimal64_array(value_array, 8, 2).unwrap(); + + generate_decimal_to_numeric_cast_test_case(&array); + } + 
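The new `Decimal32`/`Decimal64` arms above route every pairing of decimal widths through `cast_decimal_to_decimal`, so narrowing, widening, and rescaling all flow through the public `cast` kernel. A minimal sketch of what that enables, assuming the `Decimal32Array`/`Decimal64Array` types this change targets:

```rust
use arrow_array::{Array, Decimal32Array, Decimal64Array};
use arrow_cast::cast;
use arrow_schema::{ArrowError, DataType};

fn main() -> Result<(), ArrowError> {
    // 1.123 and 2.123 stored as Decimal32(9, 3)
    let input = Decimal32Array::from(vec![Some(1123), Some(2123), None])
        .with_precision_and_scale(9, 3)?;

    // Widen to Decimal64(10, 4): values are rescaled, not reinterpreted.
    let output = cast(&input, &DataType::Decimal64(10, 4))?;
    let output = output
        .as_any()
        .downcast_ref::<Decimal64Array>()
        .expect("Decimal64 output");
    assert_eq!(output.value(0), 11230); // 1.1230
    assert_eq!(output.value(1), 21230); // 2.1230
    assert!(output.is_null(2));
    Ok(())
}
```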
#[test] fn test_cast_decimal128_to_numeric() { let value_array: Vec> = vec![Some(125), Some(225), Some(325), None, Some(525)]; @@ -5356,28 +5803,9 @@ mod tests { assert!(c.is_null(2)); } - #[test] - fn test_cast_date32_to_string() { - let array = Date32Array::from(vec![10000, 17890]); - let b = cast(&array, &DataType::Utf8).unwrap(); - let c = b.as_any().downcast_ref::().unwrap(); - assert_eq!(&DataType::Utf8, c.data_type()); - assert_eq!("1997-05-19", c.value(0)); - assert_eq!("2018-12-25", c.value(1)); - } - - #[test] - fn test_cast_date64_to_string() { - let array = Date64Array::from(vec![10000 * 86400000, 17890 * 86400000]); - let b = cast(&array, &DataType::Utf8).unwrap(); - let c = b.as_any().downcast_ref::().unwrap(); - assert_eq!(&DataType::Utf8, c.data_type()); - assert_eq!("1997-05-19T00:00:00", c.value(0)); - assert_eq!("2018-12-25T00:00:00", c.value(1)); - } - - macro_rules! assert_cast_timestamp_to_string { + macro_rules! assert_cast { ($array:expr, $datatype:expr, $output_array_type: ty, $expected:expr) => {{ + assert!(can_cast_types($array.data_type(), &$datatype)); let out = cast(&$array, &$datatype).unwrap(); let actual = out .as_any() @@ -5388,6 +5816,7 @@ mod tests { assert_eq!(actual, $expected); }}; ($array:expr, $datatype:expr, $output_array_type: ty, $options:expr, $expected:expr) => {{ + assert!(can_cast_types($array.data_type(), &$datatype)); let out = cast_with_options(&$array, &$datatype, &$options).unwrap(); let actual = out .as_any() @@ -5399,6 +5828,44 @@ mod tests { }}; } + #[test] + fn test_cast_date32_to_string() { + let array = Date32Array::from(vec![Some(0), Some(10000), Some(13036), Some(17890), None]); + let expected = vec![ + Some("1970-01-01"), + Some("1997-05-19"), + Some("2005-09-10"), + Some("2018-12-25"), + None, + ]; + + assert_cast!(array, DataType::Utf8View, StringViewArray, expected); + assert_cast!(array, DataType::Utf8, StringArray, expected); + assert_cast!(array, DataType::LargeUtf8, LargeStringArray, expected); + } + + #[test] + fn test_cast_date64_to_string() { + let array = Date64Array::from(vec![ + Some(0), + Some(10000 * 86400000), + Some(13036 * 86400000), + Some(17890 * 86400000), + None, + ]); + let expected = vec![ + Some("1970-01-01T00:00:00"), + Some("1997-05-19T00:00:00"), + Some("2005-09-10T00:00:00"), + Some("2018-12-25T00:00:00"), + None, + ]; + + assert_cast!(array, DataType::Utf8View, StringViewArray, expected); + assert_cast!(array, DataType::Utf8, StringArray, expected); + assert_cast!(array, DataType::LargeUtf8, LargeStringArray, expected); + } + #[test] fn test_cast_date32_to_timestamp_and_timestamp_with_timezone() { let tz = "+0545"; // UTC + 0545 is Asia/Kathmandu @@ -5601,9 +6068,9 @@ mod tests { None, ]; - assert_cast_timestamp_to_string!(array, DataType::Utf8View, StringViewArray, expected); - assert_cast_timestamp_to_string!(array, DataType::Utf8, StringArray, expected); - assert_cast_timestamp_to_string!(array, DataType::LargeUtf8, LargeStringArray, expected); + assert_cast!(array, DataType::Utf8View, StringViewArray, expected); + assert_cast!(array, DataType::Utf8, StringArray, expected); + assert_cast!(array, DataType::LargeUtf8, LargeStringArray, expected); } #[test] @@ -5625,21 +6092,21 @@ mod tests { Some("2018-12-25 00:00:02.001000"), None, ]; - assert_cast_timestamp_to_string!( + assert_cast!( array_without_tz, DataType::Utf8View, StringViewArray, cast_options, expected ); - assert_cast_timestamp_to_string!( + assert_cast!( array_without_tz, DataType::Utf8, StringArray, cast_options, expected ); - 
assert_cast_timestamp_to_string!( + assert_cast!( array_without_tz, DataType::LargeUtf8, LargeStringArray, @@ -5655,21 +6122,21 @@ mod tests { Some("2018-12-25 05:45:02.001000"), None, ]; - assert_cast_timestamp_to_string!( + assert_cast!( array_with_tz, DataType::Utf8View, StringViewArray, cast_options, expected ); - assert_cast_timestamp_to_string!( + assert_cast!( array_with_tz, DataType::Utf8, StringArray, cast_options, expected ); - assert_cast_timestamp_to_string!( + assert_cast!( array_with_tz, DataType::LargeUtf8, LargeStringArray, @@ -9559,6 +10026,14 @@ mod tests { #[test] fn test_cast_decimal_to_string() { + assert!(can_cast_types( + &DataType::Decimal32(9, 4), + &DataType::Utf8View + )); + assert!(can_cast_types( + &DataType::Decimal64(16, 4), + &DataType::Utf8View + )); assert!(can_cast_types( &DataType::Decimal128(10, 4), &DataType::Utf8View @@ -9603,7 +10078,7 @@ mod tests { } } - let array128: Vec> = vec![ + let array32: Vec> = vec![ Some(1123454), Some(2123456), Some(-3123453), @@ -9614,11 +10089,40 @@ mod tests { Some(-123456789), None, ]; + let array64: Vec> = array32.iter().map(|num| num.map(|x| x as i64)).collect(); + let array128: Vec> = + array64.iter().map(|num| num.map(|x| x as i128)).collect(); let array256: Vec> = array128 .iter() .map(|num| num.map(i256::from_i128)) .collect(); + test_decimal_to_string::( + DataType::Utf8View, + create_decimal32_array(array32.clone(), 7, 3).unwrap(), + ); + test_decimal_to_string::( + DataType::Utf8, + create_decimal32_array(array32.clone(), 7, 3).unwrap(), + ); + test_decimal_to_string::( + DataType::LargeUtf8, + create_decimal32_array(array32, 7, 3).unwrap(), + ); + + test_decimal_to_string::( + DataType::Utf8View, + create_decimal64_array(array64.clone(), 7, 3).unwrap(), + ); + test_decimal_to_string::( + DataType::Utf8, + create_decimal64_array(array64.clone(), 7, 3).unwrap(), + ); + test_decimal_to_string::( + DataType::LargeUtf8, + create_decimal64_array(array64, 7, 3).unwrap(), + ); + test_decimal_to_string::( DataType::Utf8View, create_decimal128_array(array128.clone(), 7, 3).unwrap(), diff --git a/arrow-csv/src/writer.rs b/arrow-csv/src/writer.rs index c2cb38a226b6..e10943a6a91c 100644 --- a/arrow-csv/src/writer.rs +++ b/arrow-csv/src/writer.rs @@ -102,7 +102,7 @@ impl Writer { WriterBuilder::new().with_delimiter(delimiter).build(writer) } - /// Write a vector of record batches to a writable object + /// Write a RecordBatch to a writable object pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { let num_columns = batch.num_columns(); if self.beginning { diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index ca0d1c5e4b3d..854a149473d1 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -70,7 +70,7 @@ tls-ring = ["tonic/tls-ring"] tls-webpki-roots = ["tonic/tls-webpki-roots"] # Enable CLI tools -cli = ["arrow-array/chrono-tz", "arrow-cast/prettyprint", "tonic/tls-webpki-roots", "dep:anyhow", "dep:clap", "dep:tracing-log", "dep:tracing-subscriber", "dep:tokio"] +cli = ["arrow-array/chrono-tz", "arrow-cast/prettyprint", "tonic/tls-webpki-roots", "tonic/gzip", "tonic/deflate", "tonic/zstd", "dep:anyhow", "dep:clap", "dep:tracing-log", "dep:tracing-subscriber", "dep:tokio"] [dev-dependencies] arrow-cast = { workspace = true, features = ["prettyprint"] } diff --git a/arrow-flight/src/bin/flight_sql_client.rs b/arrow-flight/src/bin/flight_sql_client.rs index 7b9e34898ac8..154b59f5d379 100644 --- a/arrow-flight/src/bin/flight_sql_client.rs +++ 
b/arrow-flight/src/bin/flight_sql_client.rs @@ -21,11 +21,12 @@ use anyhow::{bail, Context, Result}; use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray}; use arrow_cast::{cast_with_options, pretty::pretty_format_batches, CastOptions}; use arrow_flight::{ + flight_service_client::FlightServiceClient, sql::{client::FlightSqlServiceClient, CommandGetDbSchemas, CommandGetTables}, FlightInfo, }; use arrow_schema::Schema; -use clap::{Parser, Subcommand}; +use clap::{Parser, Subcommand, ValueEnum}; use core::str; use futures::TryStreamExt; use tonic::{ @@ -53,6 +54,24 @@ pub struct LoggingArgs { log_verbose_count: u8, } +/// gRPC/HTTP compression algorithms. +#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)] +pub enum CompressionEncoding { + Gzip, + Deflate, + Zstd, +} + +impl From for tonic::codec::CompressionEncoding { + fn from(encoding: CompressionEncoding) -> Self { + match encoding { + CompressionEncoding::Gzip => Self::Gzip, + CompressionEncoding::Deflate => Self::Deflate, + CompressionEncoding::Zstd => Self::Zstd, + } + } +} + #[derive(Debug, Parser)] struct ClientArgs { /// Additional headers. @@ -85,6 +104,14 @@ struct ClientArgs { #[clap(long)] tls: bool, + /// Dump TLS key log. + /// + /// The target file is specified by the `SSLKEYLOGFILE` environment variable. + /// + /// Requires `--tls`. + #[clap(long, requires = "tls")] + key_log: bool, + /// Server host. /// /// Required. @@ -96,6 +123,34 @@ struct ClientArgs { /// Defaults to `443` if `tls` is set, otherwise defaults to `80`. #[clap(long)] port: Option, + + /// Compression accepted by the client for responses sent by the server. + /// + /// The client will send this information to the server as part of the request. The server is free to pick an + /// algorithm from that list or use no compression (called "identity" encoding). + /// + /// You may define multiple algorithms by using a comma-separated list. + #[clap(long, value_delimiter = ',')] + accept_compression: Vec, + + /// Compression of requests sent by the client to the server. + /// + /// Since the client needs to decide on the compression before sending the request, there is no client<->server + /// negotiation. If the server does NOT support the chosen compression, it will respond with an error a la: + /// + /// ``` + /// Ipc error: Status { + /// code: Unimplemented, + /// message: "Content is compressed with `zstd` which isn't supported", + /// metadata: MetadataMap { headers: {"grpc-accept-encoding": "identity", ...} }, + /// ... + /// } + /// ``` + /// + /// Based on the algorithms listed in the `grpc-accept-encoding` header, you may make a more educated guess for + /// your next request. Note that `identity` is a synonym for "no compression". 
+ #[clap(long)] + send_compression: Option, } #[derive(Debug, Parser)] @@ -357,7 +412,11 @@ async fn setup_client(args: ClientArgs) -> Result Result { .len(run_array_length) .offset(0) .add_child_data(run_ends.into_data()) - .add_child_data(values.into_data()); + .add_child_data(values.into_data()) + .null_count(run_node.null_count() as usize); + self.create_array_from_builder(builder) } // Create dictionary array from RecordBatch @@ -247,7 +253,7 @@ impl RecordBatchDecoder<'_> { ) -> Result { let length = field_node.length() as usize; let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone()); - let builder = match data_type { + let mut builder = match data_type { Utf8 | Binary | LargeBinary | LargeUtf8 => { // read 3 buffers: null buffer (optional), offsets buffer and data buffer ArrayData::builder(data_type.clone()) @@ -269,6 +275,8 @@ impl RecordBatchDecoder<'_> { t => unreachable!("Data type {:?} either unsupported or not primitive", t), }; + builder = builder.null_count(field_node.null_count() as usize); + self.create_array_from_builder(builder) } @@ -294,7 +302,7 @@ impl RecordBatchDecoder<'_> { let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone()); let length = field_node.length() as usize; let child_data = child_array.into_data(); - let builder = match data_type { + let mut builder = match data_type { List(_) | LargeList(_) | Map(_, _) => ArrayData::builder(data_type.clone()) .len(length) .add_buffer(buffers[1].clone()) @@ -309,6 +317,8 @@ impl RecordBatchDecoder<'_> { _ => unreachable!("Cannot create list or map array from {:?}", data_type), }; + builder = builder.null_count(field_node.null_count() as usize); + self.create_array_from_builder(builder) } @@ -321,15 +331,38 @@ impl RecordBatchDecoder<'_> { ) -> Result { let null_count = struct_node.null_count() as usize; let len = struct_node.length() as usize; + let skip_validation = self.skip_validation.get(); + + let nulls = if null_count > 0 { + let validity_buffer = BooleanBuffer::new(null_buffer, 0, len); + let null_buffer = if skip_validation { + // safety: flag can only be set via unsafe code + unsafe { NullBuffer::new_unchecked(validity_buffer, null_count) } + } else { + let null_buffer = NullBuffer::new(validity_buffer); - let nulls = (null_count > 0).then(|| BooleanBuffer::new(null_buffer, 0, len).into()); + if null_buffer.null_count() != null_count { + return Err(ArrowError::InvalidArgumentError(format!( + "null_count value ({}) doesn't match actual number of nulls in array ({})", + null_count, + null_buffer.null_count() + ))); + } + + null_buffer + }; + + Some(null_buffer) + } else { + None + }; if struct_arrays.is_empty() { // `StructArray::from` can't infer the correct row count // if we have zero fields return Ok(Arc::new(StructArray::new_empty_fields(len, nulls))); } - let struct_array = if self.skip_validation.get() { + let struct_array = if skip_validation { // safety: flag can only be set via unsafe code unsafe { StructArray::new_unchecked(struct_fields.clone(), struct_arrays, nulls) } } else { @@ -354,7 +387,8 @@ impl RecordBatchDecoder<'_> { .len(field_node.length() as usize) .add_buffer(buffers[1].clone()) .add_child_data(value_array.into_data()) - .null_bit_buffer(null_buffer); + .null_bit_buffer(null_buffer) + .null_count(field_node.null_count() as usize); self.create_array_from_builder(builder) } else { unreachable!("Cannot create dictionary array from {:?}", data_type) @@ -366,7 +400,8 @@ impl RecordBatchDecoder<'_> { /// [`RecordBatch`] /// /// [IPC RecordBatch]: 
crate::RecordBatch -struct RecordBatchDecoder<'a> { +/// +pub struct RecordBatchDecoder<'a> { /// The flatbuffers encoded record batch batch: crate::RecordBatch<'a>, /// The output schema @@ -678,12 +713,72 @@ fn read_dictionary_impl( require_alignment: bool, skip_validation: UnsafeFlag, ) -> Result<(), ArrowError> { - if batch.isDelta() { - return Err(ArrowError::InvalidArgumentError( - "delta dictionary batches not supported".to_string(), - )); - } + let id = batch.id(); + + let dictionary_values = get_dictionary_values( + buf, + batch, + schema, + dictionaries_by_id, + metadata, + require_alignment, + skip_validation, + )?; + + update_dictionaries(dictionaries_by_id, batch.isDelta(), id, dictionary_values)?; + Ok(()) +} + +/// Updates the `dictionaries_by_id` with the provided dictionary values and id. +/// +/// # Errors +/// - If `is_delta` is true and there is no existing dictionary for the given +/// `dict_id` +/// - If `is_delta` is true and the concatenation of the existing and new +/// dictionary fails. This usually signals a type mismatch between the old and +/// new values. +fn update_dictionaries( + dictionaries_by_id: &mut HashMap, + is_delta: bool, + dict_id: i64, + dict_values: ArrayRef, +) -> Result<(), ArrowError> { + if !is_delta { + // We don't currently record the isOrdered field. This could be general + // attributes of arrays. + // Add (possibly multiple) array refs to the dictionaries array. + dictionaries_by_id.insert(dict_id, dict_values.clone()); + return Ok(()); + } + + let existing = dictionaries_by_id.get(&dict_id).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "No existing dictionary for delta dictionary with id '{dict_id}'" + )) + })?; + + let combined = concat::concat(&[existing, &dict_values]).map_err(|e| { + ArrowError::InvalidArgumentError(format!("Failed to concat delta dictionary: {e}")) + })?; + + dictionaries_by_id.insert(dict_id, combined); + + Ok(()) +} + +/// Given a dictionary batch IPC message/body along with the full state of a +/// stream including schema, dictionary cache, metadata, and other flags, this +/// function will parse the buffer into an array of dictionary values. +fn get_dictionary_values( + buf: &Buffer, + batch: crate::DictionaryBatch, + schema: &Schema, + dictionaries_by_id: &mut HashMap, + metadata: &MetadataVersion, + require_alignment: bool, + skip_validation: UnsafeFlag, +) -> Result { let id = batch.id(); #[allow(deprecated)] let fields_using_this_dictionary = schema.fields_with_dict_id(id); @@ -719,12 +814,7 @@ fn read_dictionary_impl( ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema")) })?; - // We don't currently record the isOrdered field. This could be general - // attributes of arrays. - // Add (possibly multiple) array refs to the dictionaries array. 
- dictionaries_by_id.insert(id, dictionary_values.clone()); - - Ok(()) + Ok(dictionary_values) } /// Read the data for a given block @@ -742,7 +832,7 @@ fn read_block(mut reader: R, block: &Block) -> Result -fn parse_message(buf: &[u8]) -> Result, ArrowError> { +fn parse_message(buf: &[u8]) -> Result, ArrowError> { let buf = match buf[..4] == CONTINUATION_MARKER { true => &buf[8..], false => &buf[4..], @@ -893,7 +983,7 @@ impl FileDecoder { self } - fn read_message<'a>(&self, buf: &'a [u8]) -> Result, ArrowError> { + fn read_message<'a>(&self, buf: &'a [u8]) -> Result, ArrowError> { let message = parse_message(buf)?; // some old test data's footer metadata is not set, so we account for that @@ -1329,7 +1419,7 @@ impl RecordBatchReader for FileReader { /// [IPC Streaming Format]: https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format pub struct StreamReader { /// Stream reader - reader: R, + reader: MessageReader, /// The schema that is read from the stream's first message schema: SchemaRef, @@ -1387,32 +1477,28 @@ impl StreamReader { /// An ['Err'](Result::Err) may be returned if the reader does not encounter a schema /// as the first message in the stream. pub fn try_new( - mut reader: R, + reader: R, projection: Option>, ) -> Result, ArrowError> { - // determine metadata length - let mut meta_size: [u8; 4] = [0; 4]; - reader.read_exact(&mut meta_size)?; - let meta_len = { - // If a continuation marker is encountered, skip over it and read - // the size from the next four bytes. - if meta_size == CONTINUATION_MARKER { - reader.read_exact(&mut meta_size)?; - } - i32::from_le_bytes(meta_size) + let mut msg_reader = MessageReader::new(reader); + let message = msg_reader.maybe_next()?; + let Some((message, _)) = message else { + return Err(ArrowError::IpcError( + "Expected schema message, found empty stream.".to_string(), + )); }; - let mut meta_buffer = vec![0; meta_len as usize]; - reader.read_exact(&mut meta_buffer)?; + if message.header_type() != Message::MessageHeader::Schema { + return Err(ArrowError::IpcError(format!( + "Expected a schema as the first message in the stream, got: {:?}", + message.header_type() + ))); + } - let message = crate::root_as_message(meta_buffer.as_slice()).map_err(|err| { - ArrowError::ParseError(format!("Unable to get root as message: {err:?}")) - })?; - // message header is a Schema, so read it - let ipc_schema: crate::Schema = message.header_as_schema().ok_or_else(|| { - ArrowError::ParseError("Unable to read IPC message as schema".to_string()) + let schema = message.header_as_schema().ok_or_else(|| { + ArrowError::ParseError("Failed to parse schema from message header".to_string()) })?; - let schema = crate::convert::fb_to_schema(ipc_schema); + let schema = crate::convert::fb_to_schema(schema); // Create an array of optional dictionary value arrays, one per field. 
let dictionaries_by_id = HashMap::new(); @@ -1424,8 +1510,9 @@ impl StreamReader { } _ => None, }; + Ok(Self { - reader, + reader: msg_reader, schema: Arc::new(schema), finished: false, dictionaries_by_id, @@ -1457,114 +1544,127 @@ impl StreamReader { if self.finished { return Ok(None); } - // determine metadata length - let mut meta_size: [u8; 4] = [0; 4]; - - match self.reader.read_exact(&mut meta_size) { - Ok(()) => (), - Err(e) => { - return if e.kind() == std::io::ErrorKind::UnexpectedEof { - // Handle EOF without the "0xFFFFFFFF 0x00000000" - // valid according to: - // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format - self.finished = true; - Ok(None) - } else { - Err(ArrowError::from(e)) - }; - } - } - let meta_len = { - // If a continuation marker is encountered, skip over it and read - // the size from the next four bytes. - if meta_size == CONTINUATION_MARKER { - self.reader.read_exact(&mut meta_size)?; - } - i32::from_le_bytes(meta_size) - }; + // Read messages until we get a record batch or end of stream + loop { + let message = self.next_ipc_message()?; + let Some(message) = message else { + // If the message is None, we have reached the end of the stream. + self.finished = true; + return Ok(None); + }; - if meta_len == 0 { - // the stream has ended, mark the reader as finished - self.finished = true; - return Ok(None); + match message { + IpcMessage::Schema(_) => { + return Err(ArrowError::IpcError( + "Expected a record batch, but found a schema".to_string(), + )); + } + IpcMessage::RecordBatch(record_batch) => { + return Ok(Some(record_batch)); + } + IpcMessage::DictionaryBatch { .. } => { + continue; + } + }; } + } - let mut meta_buffer = vec![0; meta_len as usize]; - self.reader.read_exact(&mut meta_buffer)?; - - let vecs = &meta_buffer.to_vec(); - let message = crate::root_as_message(vecs).map_err(|err| { - ArrowError::ParseError(format!("Unable to get root as message: {err:?}")) - })?; + /// Reads and fully parses the next IPC message from the stream. Whereas + /// [`Self::maybe_next`] is a higher level method focused on reading + /// `RecordBatch`es, this method returns the individual fully parsed IPC + /// messages from the underlying stream. + /// + /// This is useful primarily for testing reader/writer behaviors as it + /// allows a full view into the messages that have been written to a stream. + pub(crate) fn next_ipc_message(&mut self) -> Result, ArrowError> { + let message = self.reader.maybe_next()?; + let Some((message, body)) = message else { + // If the message is None, we have reached the end of the stream. 
+ return Ok(None); + }; - match message.header_type() { - crate::MessageHeader::Schema => Err(ArrowError::IpcError( - "Not expecting a schema when messages are read".to_string(), - )), - crate::MessageHeader::RecordBatch => { + let ipc_message = match message.header_type() { + Message::MessageHeader::Schema => { + let schema = message.header_as_schema().ok_or_else(|| { + ArrowError::ParseError("Failed to parse schema from message header".to_string()) + })?; + let arrow_schema = crate::convert::fb_to_schema(schema); + IpcMessage::Schema(arrow_schema) + } + Message::MessageHeader::RecordBatch => { let batch = message.header_as_record_batch().ok_or_else(|| { ArrowError::IpcError("Unable to read IPC message as record batch".to_string()) })?; - // read the block that makes up the record batch into a buffer - let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize); - self.reader.read_exact(&mut buf)?; - RecordBatchDecoder::try_new( - &buf.into(), + let version = message.version(); + let schema = self.schema.clone(); + let record_batch = RecordBatchDecoder::try_new( + &body.into(), batch, - self.schema(), + schema, &self.dictionaries_by_id, - &message.version(), + &version, )? .with_projection(self.projection.as_ref().map(|x| x.0.as_ref())) .with_require_alignment(false) .with_skip_validation(self.skip_validation.clone()) - .read_record_batch() - .map(Some) + .read_record_batch()?; + IpcMessage::RecordBatch(record_batch) } - crate::MessageHeader::DictionaryBatch => { - let batch = message.header_as_dictionary_batch().ok_or_else(|| { - ArrowError::IpcError( - "Unable to read IPC message as dictionary batch".to_string(), + Message::MessageHeader::DictionaryBatch => { + let dict = message.header_as_dictionary_batch().ok_or_else(|| { + ArrowError::ParseError( + "Failed to parse dictionary batch from message header".to_string(), ) })?; - // read the block that makes up the dictionary batch into a buffer - let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize); - self.reader.read_exact(&mut buf)?; - read_dictionary_impl( - &buf.into(), - batch, + let version = message.version(); + let dict_values = get_dictionary_values( + &body.into(), + dict, &self.schema, &mut self.dictionaries_by_id, - &message.version(), + &version, false, self.skip_validation.clone(), )?; - // read the next message until we encounter a RecordBatch - self.maybe_next() + update_dictionaries( + &mut self.dictionaries_by_id, + dict.isDelta(), + dict.id(), + dict_values.clone(), + )?; + + IpcMessage::DictionaryBatch { + id: dict.id(), + is_delta: (dict.isDelta()), + values: (dict_values), + } } - crate::MessageHeader::NONE => Ok(None), - t => Err(ArrowError::InvalidArgumentError(format!( - "Reading types other than record batches not yet supported, unable to read {t:?} " - ))), - } + x => { + return Err(ArrowError::ParseError(format!( + "Unsupported message header type in IPC stream: '{x:?}'" + ))); + } + }; + + Ok(Some(ipc_message)) } /// Gets a reference to the underlying reader. /// /// It is inadvisable to directly read from the underlying reader. pub fn get_ref(&self) -> &R { - &self.reader + self.reader.inner() } /// Gets a mutable reference to the underlying reader. /// /// It is inadvisable to directly read from the underlying reader. 
     pub fn get_mut(&mut self) -> &mut R {
-        &mut self.reader
+        self.reader.inner_mut()
     }
 
     /// Specifies if validation should be skipped when reading data (defaults to `false`)
@@ -1592,8 +1692,126 @@ impl<R: Read> RecordBatchReader for StreamReader<R> {
     }
 }
 
+/// Representation of a fully parsed IpcMessage from the underlying stream.
+/// Parsing this kind of message is done by higher level constructs such as
+/// [`StreamReader`], because fully interpreting the messages into a record
+/// batch or dictionary batch requires access to stream state such as the
+/// schema and the full dictionary cache.
+#[derive(Debug)]
+#[allow(dead_code)]
+pub(crate) enum IpcMessage {
+    Schema(arrow_schema::Schema),
+    RecordBatch(RecordBatch),
+    DictionaryBatch {
+        id: i64,
+        is_delta: bool,
+        values: ArrayRef,
+    },
+}
+
+/// A low-level construct that reads [`Message::Message`]s from a reader while
+/// re-using a buffer for metadata. This is composed into [`StreamReader`].
+struct MessageReader<R> {
+    reader: R,
+    buf: Vec<u8>,
+}
+
+impl<R: Read> MessageReader<R> {
+    fn new(reader: R) -> Self {
+        Self {
+            reader,
+            buf: Vec::new(),
+        }
+    }
+
+    /// Reads the entire next message from the underlying reader which includes
+    /// the metadata length, the metadata, and the body.
+    ///
+    /// # Returns
+    /// - `Ok(None)` if the reader signals the end of stream with EOF on
+    ///   the first read
+    /// - `Err(_)` if the reader returns an error other than on the first
+    ///   read, or if the metadata length is invalid
+    /// - `Ok(Some(_))` with the Message and buffer containing the
+    ///   body bytes otherwise.
+    fn maybe_next(&mut self) -> Result<Option<(Message::Message<'_>, MutableBuffer)>, ArrowError> {
+        let meta_len = self.read_meta_len()?;
+        let Some(meta_len) = meta_len else {
+            return Ok(None);
+        };
+
+        self.buf.resize(meta_len, 0);
+        self.reader.read_exact(&mut self.buf)?;
+
+        let message = crate::root_as_message(self.buf.as_slice()).map_err(|err| {
+            ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
+        })?;
+
+        let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
+        self.reader.read_exact(&mut buf)?;
+
+        Ok(Some((message, buf)))
+    }
+
+    /// Get a mutable reference to the underlying reader.
+    fn inner_mut(&mut self) -> &mut R {
+        &mut self.reader
+    }
+
+    /// Get an immutable reference to the underlying reader.
+    fn inner(&self) -> &R {
+        &self.reader
+    }
+
+    /// Read the metadata length for the next message from the underlying stream.
+    ///
+    /// # Returns
+    /// - `Ok(None)` if the reader signals the end of stream with EOF on
+    ///   the first read
+    /// - `Err(_)` if the reader returns an error other than on the first
+    ///   read, or if the metadata length is less than 0.
+    /// - `Ok(Some(_))` with the length otherwise.
+    pub fn read_meta_len(&mut self) -> Result<Option<usize>, ArrowError> {
+        let mut meta_len: [u8; 4] = [0; 4];
+        match self.reader.read_exact(&mut meta_len) {
+            Ok(_) => {}
+            Err(e) => {
+                return if e.kind() == std::io::ErrorKind::UnexpectedEof {
+                    // Handle EOF without the "0xFFFFFFFF 0x00000000"
+                    // valid according to:
+                    // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format
+                    Ok(None)
+                } else {
+                    Err(ArrowError::from(e))
+                };
+            }
+        };
+
+        let meta_len = {
+            // If a continuation marker is encountered, skip over it and read
+            // the size from the next four bytes.
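+            // (An encapsulated IPC message is framed as an optional
+            // 0xFFFFFFFF continuation marker, a little-endian i32 metadata
+            // length, the metadata flatbuffer, and then the message body.)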
+ if meta_len == CONTINUATION_MARKER { + self.reader.read_exact(&mut meta_len)?; + } + + i32::from_le_bytes(meta_len) + }; + + if meta_len == 0 { + return Ok(None); + } + + let meta_len = usize::try_from(meta_len) + .map_err(|_| ArrowError::ParseError(format!("Invalid metadata length: {meta_len}")))?; + + Ok(Some(meta_len)) + } +} + #[cfg(test)] mod tests { + use std::io::Cursor; + use crate::convert::fb_to_schema; use crate::writer::{ unslice_run_array, write_message, DictionaryTracker, IpcDataGenerator, IpcWriteOptions, @@ -1740,6 +1958,49 @@ mod tests { .unwrap() } + #[test] + fn test_negative_meta_len_start_stream() { + let bytes = i32::to_le_bytes(-1); + let mut buf = vec![]; + buf.extend(CONTINUATION_MARKER); + buf.extend(bytes); + + let reader_err = StreamReader::try_new(Cursor::new(buf), None).err(); + assert!(reader_err.is_some()); + assert_eq!( + reader_err.unwrap().to_string(), + "Parser error: Invalid metadata length: -1" + ); + } + + #[test] + fn test_negative_meta_len_mid_stream() { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let mut buf = Vec::new(); + { + let mut writer = crate::writer::StreamWriter::try_new(&mut buf, &schema).unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(Int32Array::from(vec![1]))]) + .unwrap(); + writer.write(&batch).unwrap(); + } + + let bytes = i32::to_le_bytes(-1); + buf.extend(CONTINUATION_MARKER); + buf.extend(bytes); + + let mut reader = StreamReader::try_new(Cursor::new(buf), None).unwrap(); + // Read the valid value + assert!(reader.maybe_next().is_ok()); + // Read the invalid meta len + let batch_err = reader.maybe_next().err(); + assert!(batch_err.is_some()); + assert_eq!( + batch_err.unwrap().to_string(), + "Parser error: Invalid metadata length: -1" + ); + } + #[test] fn test_projection_array_values() { // define schema @@ -2871,4 +3132,15 @@ mod tests { assert_eq!(schema, new_schema); } + + #[test] + fn test_negative_meta_len() { + let bytes = i32::to_le_bytes(-1); + let mut buf = vec![]; + buf.extend(CONTINUATION_MARKER); + buf.extend(bytes); + + let reader = StreamReader::try_new(Cursor::new(buf), None); + assert!(reader.is_err()); + } } diff --git a/arrow-ipc/src/tests/delta_dictionary.rs b/arrow-ipc/src/tests/delta_dictionary.rs new file mode 100644 index 000000000000..3f2f99b751ca --- /dev/null +++ b/arrow-ipc/src/tests/delta_dictionary.rs @@ -0,0 +1,479 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
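+//! Cross-functional tests for delta dictionary support in the IPC stream
+//! writer and reader.
+//!
+//! The tests below write batches with [`DictionaryHandling::Delta`] (and
+//! [`DictionaryHandling::Resend`] for comparison) and assert on the exact
+//! sequence of dictionary and record batch messages observed on the wire via
+//! `StreamReader::next_ipc_message`.
+//!
+//! For example, writing the batches `["A"]` and `["A", "B"]` with delta
+//! handling is expected to produce the message sequence (the schema message
+//! is consumed by `StreamReader::try_new` and so does not appear):
+//!
+//! ```text
+//! Dict(["A"]), RecordBatch, DeltaDict(["B"]), RecordBatch
+//! ```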
+ +use crate::{ + reader::IpcMessage, + writer::{DictionaryHandling, IpcWriteOptions, StreamWriter}, +}; +use crate::{ + reader::{FileReader, StreamReader}, + writer::FileWriter, +}; +use arrow_array::{ + builder::StringDictionaryBuilder, types::Int32Type, Array, ArrayRef, DictionaryArray, + RecordBatch, StringArray, +}; +use arrow_schema::{DataType, Field, Schema}; +use std::io::Cursor; +use std::sync::Arc; + +#[test] +fn test_zero_row_dict() { + let batches: &[&[&str]] = &[&[], &["A"], &[], &["B", "C"], &[]]; + run_delta_sequence_test( + batches, + &[ + MessageType::Dict(vec![]), + MessageType::RecordBatch, + MessageType::DeltaDict(str_vec(&["A"])), + MessageType::RecordBatch, + MessageType::RecordBatch, + MessageType::DeltaDict(str_vec(&["B", "C"])), + MessageType::RecordBatch, + ], + ); + + run_resend_sequence_test( + batches, + &[ + MessageType::Dict(vec![]), + MessageType::RecordBatch, + MessageType::Dict(str_vec(&["A"])), + MessageType::RecordBatch, + MessageType::RecordBatch, + MessageType::Dict(str_vec(&["A", "B", "C"])), + MessageType::RecordBatch, + ], + ); +} + +#[test] +fn test_mixed_delta() { + let batches: &[&[&str]] = &[ + &["A"], + &["A", "B"], + &["C"], + &["D", "E"], + &["A", "B", "C", "D", "E"], + ]; + + run_delta_sequence_test( + batches, + &[ + MessageType::Dict(str_vec(&["A"])), + MessageType::RecordBatch, + MessageType::DeltaDict(str_vec(&["B"])), + MessageType::RecordBatch, + MessageType::DeltaDict(str_vec(&["C"])), + MessageType::RecordBatch, + MessageType::DeltaDict(str_vec(&["D", "E"])), + MessageType::RecordBatch, + MessageType::RecordBatch, + ], + ); + + run_resend_sequence_test( + batches, + &[ + MessageType::Dict(str_vec(&["A"])), + MessageType::RecordBatch, + MessageType::Dict(str_vec(&["A", "B"])), + MessageType::RecordBatch, + MessageType::Dict(str_vec(&["A", "B", "C"])), + MessageType::RecordBatch, + MessageType::Dict(str_vec(&["A", "B", "C", "D", "E"])), + MessageType::RecordBatch, + MessageType::RecordBatch, + ], + ); +} + +#[test] +fn test_disjoint_delta() { + let batches: &[&[&str]] = &[&["A"], &["B"], &["C", "E"]]; + run_delta_sequence_test( + batches, + &[ + MessageType::Dict(str_vec(&["A"])), + MessageType::RecordBatch, + MessageType::DeltaDict(str_vec(&["B"])), + MessageType::RecordBatch, + MessageType::DeltaDict(str_vec(&["C", "E"])), + MessageType::RecordBatch, + ], + ); + + run_resend_sequence_test( + batches, + &[ + MessageType::Dict(str_vec(&["A"])), + MessageType::RecordBatch, + MessageType::Dict(str_vec(&["A", "B"])), + MessageType::RecordBatch, + MessageType::Dict(str_vec(&["A", "B", "C", "E"])), + MessageType::RecordBatch, + ], + ); +} + +#[test] +fn test_increasing_delta() { + let batches: &[&[&str]] = &[&["A"], &["A", "B"], &["A", "B", "C"]]; + run_delta_sequence_test( + batches, + &[ + MessageType::Dict(str_vec(&["A"])), + MessageType::RecordBatch, + MessageType::DeltaDict(str_vec(&["B"])), + MessageType::RecordBatch, + MessageType::DeltaDict(str_vec(&["C"])), + MessageType::RecordBatch, + ], + ); + + run_resend_sequence_test( + batches, + &[ + MessageType::Dict(str_vec(&["A"])), + MessageType::RecordBatch, + MessageType::Dict(str_vec(&["A", "B"])), + MessageType::RecordBatch, + MessageType::Dict(str_vec(&["A", "B", "C"])), + MessageType::RecordBatch, + ], + ); +} + +#[test] +fn test_single_delta() { + let batches: &[&[&str]] = &[&["A", "B", "C"], &["D"]]; + run_delta_sequence_test( + batches, + &[ + MessageType::Dict(str_vec(&["A", "B", "C"])), + MessageType::RecordBatch, + MessageType::DeltaDict(str_vec(&["D"])), + 
MessageType::RecordBatch, + ], + ); + + run_resend_sequence_test( + batches, + &[ + MessageType::Dict(str_vec(&["A", "B", "C"])), + MessageType::RecordBatch, + MessageType::Dict(str_vec(&["A", "B", "C", "D"])), + MessageType::RecordBatch, + ], + ); +} + +#[test] +fn test_single_same_value_sequence() { + let batches: &[&[&str]] = &[&["A"], &["A"], &["A"], &["A"]]; + run_delta_sequence_test( + batches, + &[ + MessageType::Dict(str_vec(&["A"])), + MessageType::RecordBatch, + MessageType::RecordBatch, + MessageType::RecordBatch, + MessageType::RecordBatch, + ], + ); + + run_resend_sequence_test( + batches, + &[ + MessageType::Dict(str_vec(&["A"])), + MessageType::RecordBatch, + MessageType::RecordBatch, + MessageType::RecordBatch, + MessageType::RecordBatch, + ], + ); +} + +fn str_vec(strings: &[&str]) -> Vec { + strings.iter().map(|s| s.to_string()).collect() +} + +#[test] +fn test_multi_same_value_sequence() { + let batches: &[&[&str]] = &[&["A", "B", "C"], &["A", "B", "C"]]; + run_delta_sequence_test( + batches, + &[ + MessageType::Dict(str_vec(&["A", "B", "C"])), + MessageType::RecordBatch, + ], + ); +} + +#[derive(Debug, PartialEq)] +enum MessageType { + Schema, + Dict(Vec), + DeltaDict(Vec), + RecordBatch, +} + +fn run_resend_sequence_test(batches: &[&[&str]], sequence: &[MessageType]) { + let opts = IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Resend); + run_sequence_test(batches, sequence, opts); +} + +fn run_delta_sequence_test(batches: &[&[&str]], sequence: &[MessageType]) { + let opts = IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Delta); + run_sequence_test(batches, sequence, opts); +} + +fn run_sequence_test(batches: &[&[&str]], sequence: &[MessageType], options: IpcWriteOptions) { + let stream_buf = write_all_to_stream(options.clone(), batches); + let ipc_stream = get_ipc_message_stream(stream_buf); + for (message, expected) in ipc_stream.iter().zip(sequence.iter()) { + match message { + IpcMessage::Schema(_) => { + assert_eq!(expected, &MessageType::Schema, "Expected schema message"); + } + IpcMessage::RecordBatch(_) => { + assert_eq!( + expected, + &MessageType::RecordBatch, + "Expected record batch message" + ); + } + IpcMessage::DictionaryBatch { + id: _, + is_delta, + values, + } => { + let expected_values = if *is_delta { + let MessageType::DeltaDict(values) = expected else { + panic!("Expected DeltaDict message type"); + }; + + values + } else { + let MessageType::Dict(values) = expected else { + panic!("Expected Dict message type"); + }; + values + }; + + let values: Vec = values + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.map(|s| s.to_string()).unwrap_or_default()) + .collect(); + + assert_eq!(*expected_values, values) + } + } + } +} + +fn get_ipc_message_stream(buf: Vec) -> Vec { + let mut reader = StreamReader::try_new(Cursor::new(buf), None).unwrap(); + let mut results = vec![]; + + loop { + match reader.next_ipc_message() { + Ok(Some(message)) => results.push(message), + Ok(None) => break, // End of stream + Err(e) => panic!("Error reading IPC message: {e:?}"), + } + } + + results +} + +#[test] +fn test_replace_same_length() { + let batches: &[&[&str]] = &[ + &["A", "B", "C", "D", "E", "F"], + &["A", "G", "H", "I", "J", "K"], + ]; + run_parity_test(batches); +} + +#[test] +fn test_sparse_deltas() { + let batches: &[&[&str]] = &[ + &["A"], + &["C"], + &["E", "F", "D"], + &["FOO"], + &["parquet", "B"], + &["123", "B", "C"], + ]; + run_parity_test(batches); +} + +#[test] +fn 
test_deltas_with_reset() { + // Dictionary resets at ["C", "D"] + let batches: &[&[&str]] = &[&["A"], &["A", "B"], &["C", "D"], &["A", "B", "C", "D"]]; + run_parity_test(batches); +} + +/// FileWriter can only tolerate very specific patterns of delta dictionaries, +/// because the dictionary cannot be replaced/reset. +#[test] +fn test_deltas_with_file() { + let batches: &[&[&str]] = &[&["A"], &["A", "B"], &["A", "B", "C"], &["A", "B", "C", "D"]]; + run_parity_test(batches); +} + +/// Encode all batches three times and compare all three for the same results +/// on the other end. +/// +/// - Stream encoding with delta +/// - Stream encoding without delta +/// - File encoding with delta (File format does not allow replacement +/// dictionaries) +fn run_parity_test(batches: &[&[&str]]) { + let delta_options = + IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Delta); + let delta_stream_buf = write_all_to_stream(delta_options.clone(), batches); + + let resend_options = + IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Resend); + let resend_stream_buf = write_all_to_stream(resend_options.clone(), batches); + + let delta_file_buf = write_all_to_file(delta_options, batches); + + let mut streams = [ + get_stream_batches(delta_stream_buf), + get_stream_batches(resend_stream_buf), + get_file_batches(delta_file_buf), + ]; + + let (first_stream, other_streams) = streams.split_first_mut().unwrap(); + + for (idx, batch) in first_stream.by_ref().enumerate() { + let first_dict = extract_dictionary(batch); + let expected_values = batches[idx]; + assert_eq!(expected_values, &dict_to_vec(first_dict.clone())); + + for stream in other_streams.iter_mut() { + let next_batch = stream + .next() + .expect("All streams should yield same number of elements"); + let next_dict = extract_dictionary(next_batch); + assert_eq!(expected_values, &dict_to_vec(next_dict.clone())); + assert_eq!(first_dict, next_dict); + } + } + + for stream in other_streams.iter_mut() { + assert!( + stream.next().is_none(), + "All streams should yield same number of elements" + ); + } +} + +fn dict_to_vec(dict: DictionaryArray) -> Vec { + dict.downcast_dict::() + .unwrap() + .into_iter() + .map(|v| v.unwrap_or_default().to_string()) + .collect() +} + +fn get_stream_batches(buf: Vec) -> Box> { + let reader = StreamReader::try_new(Cursor::new(buf), None).unwrap(); + Box::new( + reader + .collect::>>() + .into_iter() + .map(|r| r.unwrap()), + ) +} + +fn get_file_batches(buf: Vec) -> Box> { + let reader = FileReader::try_new(Cursor::new(buf), None).unwrap(); + Box::new( + reader + .collect::>>() + .into_iter() + .map(|r| r.unwrap()), + ) +} + +fn extract_dictionary(batch: RecordBatch) -> DictionaryArray { + batch + .column(0) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() +} + +fn write_all_to_file(options: IpcWriteOptions, vals: &[&[&str]]) -> Vec { + let batches = build_batches(vals); + let mut buf: Vec = Vec::new(); + let mut writer = + FileWriter::try_new_with_options(&mut buf, &batches[0].schema(), options).unwrap(); + for batch in batches { + writer.write(&batch).unwrap(); + } + writer.finish().unwrap(); + buf +} + +fn write_all_to_stream(options: IpcWriteOptions, vals: &[&[&str]]) -> Vec { + let batches = build_batches(vals); + + let mut buf: Vec = Vec::new(); + let mut writer = + StreamWriter::try_new_with_options(&mut buf, &batches[0].schema(), options).unwrap(); + for batch in batches { + writer.write(&batch).unwrap(); + } + + writer.finish().unwrap(); + + buf +} + +fn 
build_batches(vals: &[&[&str]]) -> Vec { + let mut builder = StringDictionaryBuilder::::new(); + vals.iter().map(|v| build_batch(v, &mut builder)).collect() +} + +fn build_batch( + vals: &[&str], + builder: &mut StringDictionaryBuilder, +) -> RecordBatch { + for &val in vals { + builder.append_value(val); + } + + let array = builder.finish_preserve_values(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "dict", + DataType::Dictionary(Box::from(DataType::Int32), Box::from(DataType::Utf8)), + true, + )])); + + RecordBatch::try_new(schema.clone(), vec![Arc::new(array) as ArrayRef]).unwrap() +} diff --git a/arrow-ipc/src/tests/mod.rs b/arrow-ipc/src/tests/mod.rs new file mode 100644 index 000000000000..e98b28de1482 --- /dev/null +++ b/arrow-ipc/src/tests/mod.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/*! +This module contains cross-functional tests for various ipc components. Some +tests rely on functionality that is not public and so they're placed here rather +than in integration tests or unit tests for a specific module. +*/ +mod delta_dictionary; diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 114f3a42e3a5..59a1a3c0a190 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -65,6 +65,8 @@ pub struct IpcWriteOptions { /// Compression, if desired. 
Will result in a runtime error /// if the corresponding feature is not enabled batch_compression_type: Option, + /// How to handle updating dictionaries in IPC messages + dictionary_handling: DictionaryHandling, } impl IpcWriteOptions { @@ -113,6 +115,7 @@ impl IpcWriteOptions { write_legacy_ipc_format, metadata_version, batch_compression_type: None, + dictionary_handling: DictionaryHandling::default(), }), crate::MetadataVersion::V5 => { if write_legacy_ipc_format { @@ -125,6 +128,7 @@ impl IpcWriteOptions { write_legacy_ipc_format, metadata_version, batch_compression_type: None, + dictionary_handling: DictionaryHandling::default(), }) } } @@ -133,6 +137,12 @@ impl IpcWriteOptions { ))), } } + + /// Configure how dictionaries are handled in IPC messages + pub fn with_dictionary_handling(mut self, dictionary_handling: DictionaryHandling) -> Self { + self.dictionary_handling = dictionary_handling; + self + } } impl Default for IpcWriteOptions { @@ -142,6 +152,7 @@ impl Default for IpcWriteOptions { write_legacy_ipc_format: false, metadata_version: crate::MetadataVersion::V5, batch_compression_type: None, + dictionary_handling: DictionaryHandling::default(), } } } @@ -363,21 +374,35 @@ impl IpcDataGenerator { dict_id_seq, )?; - // It's importnat to only take the dict_id at this point, because the dict ID + // It's important to only take the dict_id at this point, because the dict ID // sequence is assigned depth-first, so we need to first encode children and have // them take their assigned dict IDs before we take the dict ID for this field. let dict_id = dict_id_seq.next().ok_or_else(|| { ArrowError::IpcError(format!("no dict id for field {}", field.name())) })?; - let emit = dictionary_tracker.insert(dict_id, column)?; - - if emit { - encoded_dictionaries.push(self.dictionary_batch_to_bytes( - dict_id, - dict_values, - write_options, - )?); + match dictionary_tracker.insert_column( + dict_id, + column, + write_options.dictionary_handling, + )? { + DictionaryUpdate::None => {} + DictionaryUpdate::New | DictionaryUpdate::Replaced => { + encoded_dictionaries.push(self.dictionary_batch_to_bytes( + dict_id, + dict_values, + write_options, + false, + )?); + } + DictionaryUpdate::Delta(data) => { + encoded_dictionaries.push(self.dictionary_batch_to_bytes( + dict_id, + &data, + write_options, + true, + )?); + } } } _ => self._encode_dictionaries( @@ -519,6 +544,7 @@ impl IpcDataGenerator { dict_id: i64, array_data: &ArrayData, write_options: &IpcWriteOptions, + is_delta: bool, ) -> Result { let mut fbb = FlatBufferBuilder::new(); @@ -587,6 +613,7 @@ impl IpcDataGenerator { let mut batch_builder = crate::DictionaryBatchBuilder::new(&mut fbb); batch_builder.add_id(dict_id); batch_builder.add_data(root); + batch_builder.add_isDelta(is_delta); batch_builder.finish().as_union_value() }; @@ -700,6 +727,39 @@ fn into_zero_offset_run_array( Ok(array_data.into()) } +/// Controls how dictionaries are handled in Arrow IPC messages +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DictionaryHandling { + /// Send the entire dictionary every time it is encountered (default) + Resend, + /// Send only new dictionary values since the last batch (delta encoding) + /// + /// When a dictionary is first encountered, the entire dictionary is sent. + /// For subsequent batches, only values that are new (not previously sent) + /// are transmitted with the `isDelta` flag set to true. 
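+    ///
+    /// Delta encoding only pays off when successive batches reuse the same
+    /// key assignments (for example, dictionaries produced with
+    /// `StringDictionaryBuilder::finish_preserve_values`); if a later batch
+    /// is not a pure extension of the previous dictionary, the writer falls
+    /// back to sending a replacement dictionary. Note that the IPC file
+    /// format does not permit replacement dictionaries, so when writing files
+    /// the dictionary for a field may only ever grow.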
+ Delta, +} + +impl Default for DictionaryHandling { + fn default() -> Self { + Self::Resend + } +} + +/// Describes what kind of update took place after a call to [`DictionaryTracker::insert`]. +#[derive(Debug, Clone)] +pub enum DictionaryUpdate { + /// No dictionary was written, the dictionary was identical to what was already + /// in the tracker. + None, + /// No dictionary was present in the tracker + New, + /// Dictionary was replaced with the new data + Replaced, + /// Dictionary was updated, ArrayData is the delta between old and new + Delta(ArrayData), +} + /// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary /// multiple times. /// @@ -718,11 +778,6 @@ impl DictionaryTracker { /// If `error_on_replacement` /// is true, an error will be generated if an update to an /// existing dictionary is attempted. - /// - /// If `preserve_dict_id` is true, the dictionary ID defined in the schema - /// is used, otherwise a unique dictionary ID will be assigned by incrementing - /// the last seen dictionary ID (or using `0` if no other dictionary IDs have been - /// seen) pub fn new(error_on_replacement: bool) -> Self { #[allow(deprecated)] Self { @@ -760,6 +815,7 @@ impl DictionaryTracker { /// * If the tracker has not been configured to error on replacement or this dictionary /// has never been seen before, return `Ok(true)` to indicate that the dictionary was just /// inserted. + #[deprecated(since = "56.1.0", note = "Use `insert_column` instead")] pub fn insert(&mut self, dict_id: i64, column: &ArrayRef) -> Result { let dict_data = column.to_data(); let dict_values = &dict_data.child_data()[0]; @@ -788,6 +844,125 @@ impl DictionaryTracker { self.written.insert(dict_id, dict_data); Ok(true) } + + /// Keep track of the dictionary with the given ID and values. The return + /// value indicates what, if any, update to the internal map took place + /// and how it should be interpreted based on the `dict_handling` parameter. + /// + /// # Returns + /// + /// * `Ok(Dictionary::New)` - If the dictionary was not previously written + /// * `Ok(Dictionary::Replaced)` - If the dictionary was previously written + /// with completely different data, or if the data is a delta of the existing, + /// but with `dict_handling` set to `DictionaryHandling::Resend` + /// * `Ok(Dictionary::Delta)` - If the dictionary was previously written, but + /// the new data is a delta of the old and the `dict_handling` is set to + /// `DictionaryHandling::Delta` + /// * `Err(e)` - If the dictionary was previously written with different data, + /// and `error_on_replacement` is set to `true`. + pub fn insert_column( + &mut self, + dict_id: i64, + column: &ArrayRef, + dict_handling: DictionaryHandling, + ) -> Result { + let new_data = column.to_data(); + let new_values = &new_data.child_data()[0]; + + // If there is no existing dictionary with this ID, we always insert + let Some(old) = self.written.get(&dict_id) else { + self.written.insert(dict_id, new_data); + return Ok(DictionaryUpdate::New); + }; + + // Fast path - If the array data points to the same buffer as the + // existing then they're the same. 
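+        // (If the exact same values array is attached to several consecutive
+        // batches, this pointer check avoids the element-wise comparison
+        // below.)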
+ let old_values = &old.child_data()[0]; + if ArrayData::ptr_eq(old_values, new_values) { + return Ok(DictionaryUpdate::None); + } + + // Slow path - Compare the dictionaries value by value + let comparison = compare_dictionaries(old_values, new_values); + if matches!(comparison, DictionaryComparison::Equal) { + return Ok(DictionaryUpdate::None); + } + + const REPLACEMENT_ERROR: &str = + "Dictionary replacement detected when writing IPC file format. \ + Arrow IPC files only support a single dictionary for a given field \ + across all batches."; + + match comparison { + DictionaryComparison::NotEqual => { + if self.error_on_replacement { + return Err(ArrowError::InvalidArgumentError( + REPLACEMENT_ERROR.to_string(), + )); + } + + self.written.insert(dict_id, new_data); + Ok(DictionaryUpdate::Replaced) + } + DictionaryComparison::Delta => match dict_handling { + DictionaryHandling::Resend => { + if self.error_on_replacement { + return Err(ArrowError::InvalidArgumentError( + REPLACEMENT_ERROR.to_string(), + )); + } + + self.written.insert(dict_id, new_data); + Ok(DictionaryUpdate::Replaced) + } + DictionaryHandling::Delta => { + let delta = + new_values.slice(old_values.len(), new_values.len() - old_values.len()); + self.written.insert(dict_id, new_data); + Ok(DictionaryUpdate::Delta(delta)) + } + }, + DictionaryComparison::Equal => unreachable!("Already checked equal case"), + } + } +} + +/// Describes how two dictionary arrays compare to each other. +#[derive(Debug, Clone)] +enum DictionaryComparison { + /// Neither a delta, nor an exact match + NotEqual, + /// Exact element-wise match + Equal, + /// The two arrays are dictionary deltas of each other, meaning the first + /// is a prefix of the second. + Delta, +} + +// Compares two dictionaries and returns a [`DictionaryComparison`]. 
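+//
+// For example: old = ["A", "B"], new = ["A", "B", "C"] is a `Delta` (old is an
+// exact prefix of new); old = ["A", "B"], new = ["A", "C"] is `NotEqual`; and
+// arrays of equal length are compared element-wise, yielding `Equal` or
+// `NotEqual`.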
+fn compare_dictionaries(old: &ArrayData, new: &ArrayData) -> DictionaryComparison { + // Check for exact match + let existing_len = old.len(); + let new_len = new.len(); + if existing_len == new_len { + if *old == *new { + return DictionaryComparison::Equal; + } else { + return DictionaryComparison::NotEqual; + } + } + + // Can't be a delta if the new is shorter than the existing + if new_len < existing_len { + return DictionaryComparison::NotEqual; + } + + // Check for delta + if new.slice(0, existing_len) == *old { + return DictionaryComparison::Delta; + } + + DictionaryComparison::NotEqual } /// Arrow File Writer @@ -926,6 +1101,7 @@ impl FileWriter { } let (meta, data) = write_message(&mut self.writer, encoded_message, &self.write_options)?; + // add a record block for the footer let block = crate::Block::new( self.block_offsets as i64, @@ -1041,7 +1217,7 @@ impl RecordBatchWriter for FileWriter { /// /// * [`FileWriter`] for writing IPC Files /// -/// # Example +/// # Example - Basic usage /// ``` /// # use arrow_array::record_batch; /// # use arrow_ipc::writer::StreamWriter; @@ -1054,7 +1230,57 @@ impl RecordBatchWriter for FileWriter { /// // When all batches are written, call finish to flush all buffers /// writer.finish().unwrap(); /// ``` +/// # Example - Efficient delta dictionaries +/// ``` +/// # use arrow_array::record_batch; +/// # use arrow_ipc::writer::{StreamWriter, IpcWriteOptions}; +/// # use arrow_ipc::writer::DictionaryHandling; +/// # use arrow_schema::{DataType, Field, Schema, SchemaRef}; +/// # use arrow_array::{ +/// # builder::StringDictionaryBuilder, types::Int32Type, Array, ArrayRef, DictionaryArray, +/// # RecordBatch, StringArray, +/// # }; +/// # use std::sync::Arc; /// +/// let schema = Arc::new(Schema::new(vec![Field::new( +/// "col1", +/// DataType::Dictionary(Box::from(DataType::Int32), Box::from(DataType::Utf8)), +/// true, +/// )])); +/// +/// let mut builder = StringDictionaryBuilder::::new(); +/// +/// // `finish_preserve_values` will keep the dictionary values along with their +/// // key assignments so that they can be re-used in the next batch. +/// builder.append("a").unwrap(); +/// builder.append("b").unwrap(); +/// let array1 = builder.finish_preserve_values(); +/// let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(array1) as ArrayRef]).unwrap(); +/// +/// // In this batch, 'a' will have the same dictionary key as 'a' in the previous batch, +/// // and 'd' will take the next available key. +/// builder.append("a").unwrap(); +/// builder.append("d").unwrap(); +/// let array2 = builder.finish_preserve_values(); +/// let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(array2) as ArrayRef]).unwrap(); +/// +/// let mut stream = vec![]; +/// // You must set `.with_dictionary_handling(DictionaryHandling::Delta)` to +/// // enable delta dictionaries in the writer +/// let options = IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Delta); +/// let mut writer = StreamWriter::try_new(&mut stream, &schema).unwrap(); +/// +/// // When writing the first batch, a dictionary message with 'a' and 'b' will be written +/// // prior to the record batch. +/// writer.write(&batch1).unwrap(); +/// // With the second batch only a delta dictionary with 'd' will be written +/// // prior to the record batch. This is only possible with `finish_preserve_values`. 
+/// // Without it, 'a' and 'd' in this batch would have different keys than the +/// // first batch and so we'd have to send a replacement dictionary with new keys +/// // for both. +/// writer.write(&batch2).unwrap(); +/// writer.finish().unwrap(); +/// ``` /// [IPC Streaming Format]: https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format pub struct StreamWriter { /// The object to write to diff --git a/arrow-ipc/tests/test_delta_dictionary.rs b/arrow-ipc/tests/test_delta_dictionary.rs new file mode 100644 index 000000000000..f7c4e7f32554 --- /dev/null +++ b/arrow-ipc/tests/test_delta_dictionary.rs @@ -0,0 +1,590 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::{ + builder::{ListBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder}, + Array, ArrayRef, DictionaryArray, ListArray, RecordBatch, StringArray, +}; +use arrow_ipc::reader::StreamReader; +use arrow_ipc::writer::{DictionaryHandling, IpcWriteOptions, StreamWriter}; +use arrow_schema::{ArrowError, DataType, Field, Schema}; +use std::io::Cursor; +use std::sync::Arc; + +#[test] +fn test_dictionary_handling_option() { + // Test that DictionaryHandling can be set + let _options = IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Delta); + + // Verify it was set (we can't access private field directly) + // This test just verifies the API exists +} + +#[test] +fn test_nested_dictionary_with_delta() -> Result<(), ArrowError> { + // Test writing nested dictionaries with delta option + // Create a simple nested structure for testing + + // Create dictionary arrays + let mut dict_builder = StringDictionaryBuilder::::new(); + dict_builder.append_value("hello"); + dict_builder.append_value("world"); + let dict_array = dict_builder.finish(); + + // Create a list of dictionaries + let mut list_builder = + ListBuilder::new(StringDictionaryBuilder::::new()); + list_builder.values().append_value("item1"); + list_builder.values().append_value("item2"); + list_builder.append(true); + list_builder.values().append_value("item3"); + list_builder.append(true); + let list_array = list_builder.finish(); + + // Create schema with nested dictionaries + let schema = Arc::new(Schema::new(vec![ + Field::new("dict", dict_array.data_type().clone(), true), + Field::new("list_of_dict", list_array.data_type().clone(), true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(dict_array) as ArrayRef, + Arc::new(list_array) as ArrayRef, + ], + )?; + + // Write with delta dictionary handling + let mut buffer = Vec::new(); + { + let options = + IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Delta); + let mut writer = StreamWriter::try_new_with_options(&mut buffer, &schema, options)?; + 
writer.write(&batch)?; + writer.finish()?; + } + + // Read back and verify + let reader = StreamReader::try_new(Cursor::new(buffer), None)?; + let read_batches: Result, _> = reader.collect(); + let read_batches = read_batches?; + assert_eq!(read_batches.len(), 1); + + let read_batch = &read_batches[0]; + assert_eq!(read_batch.num_columns(), 2); + assert_eq!(read_batch.num_rows(), 2); + let dict_array = read_batch + .column(0) + .as_any() + .downcast_ref::>() + .unwrap(); + let dict_values = dict_array + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(dict_values.len(), 2); + assert_eq!(dict_values.value(0), "hello"); + assert_eq!(dict_values.value(1), "world"); + let list_array = read_batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + let list_dict_array = list_array + .values() + .as_any() + .downcast_ref::>() + .unwrap(); + let list_values = list_dict_array + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(list_values.len(), 3); + assert_eq!(list_values.value(0), "item1"); + assert_eq!(list_values.value(1), "item2"); + assert_eq!(list_values.value(2), "item3"); + + Ok(()) +} + +#[test] +fn test_complex_nested_dictionaries() -> Result<(), ArrowError> { + // Test nested structure with dictionaries at multiple levels + + // Create a nested structure: List(Dictionary(List(Dictionary))) + + // Inner dictionary for the nested list + let _inner_dict_field = Field::new( + "inner_item", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ); + + // Create a list of dictionaries + let mut list_builder = + ListBuilder::new(StringDictionaryBuilder::::new()); + + // First list + list_builder.values().append_value("inner_a"); + list_builder.values().append_value("inner_b"); + list_builder.append(true); + + // Second list + list_builder.values().append_value("inner_c"); + list_builder.values().append_value("inner_d"); + list_builder.append(true); + + let list_array = list_builder.finish(); + + // Create outer dictionary containing the list + let mut outer_dict_builder = StringDictionaryBuilder::::new(); + outer_dict_builder.append_value("outer_1"); + outer_dict_builder.append_value("outer_2"); + let outer_dict = outer_dict_builder.finish(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("outer_dict", outer_dict.data_type().clone(), true), + Field::new("nested_list", list_array.data_type().clone(), true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(outer_dict) as ArrayRef, + Arc::new(list_array) as ArrayRef, + ], + )?; + + // Write with delta dictionary handling + let mut buffer = Vec::new(); + { + let options = + IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Delta); + let mut writer = StreamWriter::try_new_with_options(&mut buffer, &schema, options)?; + writer.write(&batch)?; + writer.finish()?; + } + + // Verify it writes without error + assert!(!buffer.is_empty()); + + // Read back and verify + let reader = StreamReader::try_new(Cursor::new(buffer), None)?; + let read_batches: Result, _> = reader.collect(); + let read_batches = read_batches?; + + assert_eq!(read_batches.len(), 1); + + let read_batch = &read_batches[0]; + assert_eq!(read_batch.num_columns(), 2); + assert_eq!(read_batch.num_rows(), 2); + let outer_dict_array = read_batch + .column(0) + .as_any() + .downcast_ref::>() + .unwrap(); + let outer_dict_values = outer_dict_array + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(outer_dict_values.len(), 2); + 
assert_eq!(outer_dict_values.value(0), "outer_1"); + assert_eq!(outer_dict_values.value(1), "outer_2"); + + let nested_list_array = read_batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + let nested_dict_array = nested_list_array + .values() + .as_any() + .downcast_ref::>() + .unwrap(); + let nested_dict_values = nested_dict_array + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(nested_dict_values.len(), 4); + assert_eq!(nested_dict_values.value(0), "inner_a"); + assert_eq!(nested_dict_values.value(1), "inner_b"); + assert_eq!(nested_dict_values.value(2), "inner_c"); + assert_eq!(nested_dict_values.value(3), "inner_d"); + + Ok(()) +} + +#[test] +fn test_multiple_dictionary_types() -> Result<(), ArrowError> { + // Test different dictionary value types in one schema + + // String dictionary + let mut string_dict_builder = StringDictionaryBuilder::::new(); + string_dict_builder.append_value("apple"); + string_dict_builder.append_value("banana"); + string_dict_builder.append_value("apple"); + let string_dict = string_dict_builder.finish(); + + // Integer dictionary + let mut int_dict_builder = PrimitiveDictionaryBuilder::< + arrow_array::types::Int32Type, + arrow_array::types::Int64Type, + >::new(); + int_dict_builder.append_value(100); + int_dict_builder.append_value(200); + int_dict_builder.append_value(100); + let int_dict = int_dict_builder.finish(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("string_dict", string_dict.data_type().clone(), true), + Field::new("int_dict", int_dict.data_type().clone(), true), + ])); + + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(string_dict) as ArrayRef, + Arc::new(int_dict) as ArrayRef, + ], + )?; + + // Create second batch with extended dictionaries + let mut string_dict_builder2 = StringDictionaryBuilder::::new(); + string_dict_builder2.append_value("apple"); + string_dict_builder2.append_value("banana"); + string_dict_builder2.append_value("cherry"); // new + string_dict_builder2.append_value("date"); // new + let string_dict2 = string_dict_builder2.finish(); + + let mut int_dict_builder2 = PrimitiveDictionaryBuilder::< + arrow_array::types::Int32Type, + arrow_array::types::Int64Type, + >::new(); + int_dict_builder2.append_value(100); + int_dict_builder2.append_value(200); + int_dict_builder2.append_value(300); // new + int_dict_builder2.append_value(400); // new + let int_dict2 = int_dict_builder2.finish(); + + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(string_dict2) as ArrayRef, + Arc::new(int_dict2) as ArrayRef, + ], + )?; + + // Write with delta dictionary handling + let mut buffer = Vec::new(); + { + let options = + IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Delta); + let mut writer = StreamWriter::try_new_with_options(&mut buffer, &schema, options)?; + writer.write(&batch1)?; + writer.write(&batch2)?; + writer.finish()?; + } + + // Read back and verify + let reader = StreamReader::try_new(Cursor::new(buffer), None)?; + let read_batches: Result, _> = reader.collect(); + let read_batches = read_batches?; + + assert_eq!(read_batches.len(), 2); + + // Check string dictionary in second batch + let read_batch2 = &read_batches[1]; + let string_dict_array = read_batch2 + .column(0) + .as_any() + .downcast_ref::>() + .unwrap(); + + let string_values = string_dict_array + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + // Should have all 4 string values + assert_eq!(string_values.len(), 4); + 
assert_eq!(string_values.value(0), "apple"); + assert_eq!(string_values.value(1), "banana"); + assert_eq!(string_values.value(2), "cherry"); + assert_eq!(string_values.value(3), "date"); + + Ok(()) +} + +#[test] +fn test_empty_dictionary_delta() -> Result<(), ArrowError> { + // Test edge case with empty dictionaries + + // First batch with empty dictionary + let mut builder1 = StringDictionaryBuilder::::new(); + builder1.append_null(); + builder1.append_null(); + let array1 = builder1.finish(); + + // Second batch with some values + let mut builder2 = StringDictionaryBuilder::::new(); + builder2.append_value("first"); + builder2.append_value("second"); + let array2 = builder2.finish(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "dict", + array1.data_type().clone(), + true, + )])); + + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(array1) as ArrayRef])?; + + let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(array2) as ArrayRef])?; + + // Write with delta dictionary handling + let mut buffer = Vec::new(); + { + let options = + IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Delta); + let mut writer = StreamWriter::try_new_with_options(&mut buffer, &schema, options)?; + writer.write(&batch1)?; + writer.write(&batch2)?; + writer.finish()?; + } + + // Read back and verify + let reader = StreamReader::try_new(Cursor::new(buffer), None)?; + let read_batches: Result, _> = reader.collect(); + let read_batches = read_batches?; + + assert_eq!(read_batches.len(), 2); + + // Second batch should have the dictionary values + let read_batch2 = &read_batches[1]; + let dict_array = read_batch2 + .column(0) + .as_any() + .downcast_ref::>() + .unwrap(); + + let dict_values = dict_array + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(dict_values.len(), 2); + assert_eq!(dict_values.value(0), "first"); + assert_eq!(dict_values.value(1), "second"); + + Ok(()) +} + +#[test] +fn test_delta_with_shared_dictionary_data() -> Result<(), ArrowError> { + // Test efficient delta detection when dictionaries share underlying data + + // Create initial dictionary + let mut builder = StringDictionaryBuilder::::new(); + builder.append_value("alpha"); + builder.append_value("beta"); + let dict1 = builder.finish(); + + // Create a dictionary that extends the first one by sharing its data + // This simulates a common pattern where dictionaries are built incrementally + let dict1_values = dict1.values(); + let mut builder2 = StringDictionaryBuilder::::new(); + // First, add the existing values + for i in 0..dict1_values.len() { + builder2.append_value( + dict1_values + .as_any() + .downcast_ref::() + .unwrap() + .value(i), + ); + } + // Then add new values + builder2.append_value("gamma"); + builder2.append_value("delta"); + let dict2 = builder2.finish(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "dict", + dict1.data_type().clone(), + true, + )])); + + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(dict1) as ArrayRef])?; + + let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(dict2) as ArrayRef])?; + + // Write with delta dictionary handling + let mut buffer = Vec::new(); + { + let options = + IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Delta); + let mut writer = StreamWriter::try_new_with_options(&mut buffer, &schema, options)?; + writer.write(&batch1)?; + writer.write(&batch2)?; + writer.finish()?; + } + + // Read back and verify delta was used correctly + let 
reader = StreamReader::try_new(Cursor::new(buffer), None)?; + let read_batches: Result, _> = reader.collect(); + let read_batches = read_batches?; + + assert_eq!(read_batches.len(), 2); + + // Verify second batch has all values + let read_batch2 = &read_batches[1]; + let dict_array = read_batch2 + .column(0) + .as_any() + .downcast_ref::>() + .unwrap(); + + let dict_values = dict_array + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(dict_values.len(), 4); + assert_eq!(dict_values.value(0), "alpha"); + assert_eq!(dict_values.value(1), "beta"); + assert_eq!(dict_values.value(2), "gamma"); + assert_eq!(dict_values.value(3), "delta"); + + Ok(()) +} + +#[test] +fn test_large_dictionary_delta_performance() -> Result<(), ArrowError> { + // Test delta dictionary with large dictionaries to ensure efficiency + + // Create a large initial dictionary + let mut builder1 = StringDictionaryBuilder::::new(); + for i in 0..1000 { + builder1.append_value(format!("value_{i}")); + } + let dict1 = builder1.finish(); + + // Create extended dictionary + let mut builder2 = StringDictionaryBuilder::::new(); + for i in 0..1000 { + builder2.append_value(format!("value_{i}")); + } + // Add just a few new values + for i in 1000..1005 { + builder2.append_value(format!("value_{i}")); + } + let dict2 = builder2.finish(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "dict", + dict1.data_type().clone(), + true, + )])); + + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(dict1) as ArrayRef])?; + + let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(dict2) as ArrayRef])?; + + // Write with delta dictionary handling + let mut buffer = Vec::new(); + { + let options = + IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Delta); + let mut writer = StreamWriter::try_new_with_options(&mut buffer, &schema, options)?; + writer.write(&batch1)?; + writer.write(&batch2)?; + writer.finish()?; + } + + // The buffer should be relatively small since we only sent 5 new values + // as delta instead of resending all 1005 values + let buffer_size = buffer.len(); + + // Write without delta for comparison + let mut buffer_no_delta = Vec::new(); + { + let options = + IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Resend); + let mut writer = + StreamWriter::try_new_with_options(&mut buffer_no_delta, &schema, options)?; + writer.write(&batch1)?; + writer.write(&batch2)?; + writer.finish()?; + } + + let buffer_no_delta_size = buffer_no_delta.len(); + + // Delta encoding should result in smaller output + println!("Delta buffer size: {buffer_size}"); + println!("Non-delta buffer size: {buffer_size}"); + + // Delta encoding should result in significantly smaller output + assert!( + buffer_size < buffer_no_delta_size, + "Delta buffer ({buffer_size}) should be smaller than non-delta buffer ({buffer_no_delta_size})" + ); + + // The delta should save approximately the size of the second dictionary minus the delta + // We sent 5 values instead of 1005, saving ~99.5% on the second dictionary + let savings_ratio = (buffer_no_delta_size - buffer_size) as f64 / buffer_no_delta_size as f64; + println!("Space savings: {:.1}%", savings_ratio * 100.0); + + // We should save at least 30% (conservative estimate accounting for metadata overhead) + assert!( + savings_ratio > 0.30, + "Delta encoding should provide significant space savings (got {:.1}%)", + savings_ratio * 100.0 + ); + + // Verify correctness + let reader = 
StreamReader::try_new(Cursor::new(buffer), None)?; + let read_batches: Result, _> = reader.collect(); + let read_batches = read_batches?; + + assert_eq!(read_batches.len(), 2); + + let read_batch2 = &read_batches[1]; + let dict_array = read_batch2 + .column(0) + .as_any() + .downcast_ref::>() + .unwrap(); + + let dict_values = dict_array + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(dict_values.len(), 1005); + assert_eq!(dict_values.value(1004), "value_1004"); + + Ok(()) +} diff --git a/arrow-ord/src/comparison.rs b/arrow-ord/src/comparison.rs index bb82f54d4918..f4daff8501b6 100644 --- a/arrow-ord/src/comparison.rs +++ b/arrow-ord/src/comparison.rs @@ -3059,6 +3059,120 @@ mod tests { ); } + fn create_decimal_array(data: Vec>) -> PrimitiveArray { + data.into_iter().collect::>() + } + + fn test_cmp_dict_decimal( + values1: Vec>, + values2: Vec>, + ) { + let values = create_decimal_array::(values1); + let keys = Int8Array::from_iter_values([1_i8, 2, 5, 4, 3, 0]); + let array1 = DictionaryArray::new(keys, Arc::new(values)); + + let values = create_decimal_array::(values2); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]); + let array2 = DictionaryArray::new(keys, Arc::new(values)); + + let expected = BooleanArray::from(vec![false, false, false, true, true, false]); + assert_eq!(crate::cmp::eq(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from(vec![true, true, false, false, false, true]); + assert_eq!(crate::cmp::lt(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from(vec![true, true, false, true, true, true]); + assert_eq!(crate::cmp::lt_eq(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from(vec![false, false, true, false, false, false]); + assert_eq!(crate::cmp::gt(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from(vec![false, false, true, true, true, false]); + assert_eq!(crate::cmp::gt_eq(&array1, &array2).unwrap(), expected); + } + + #[test] + fn test_cmp_dict_decimal32() { + test_cmp_dict_decimal::( + vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)], + vec![Some(7), Some(-3), Some(4), Some(3), Some(5)], + ); + } + + #[test] + fn test_cmp_dict_non_dict_decimal32() { + let array1: Decimal32Array = Decimal32Array::from_iter_values([1, 2, 5, 4, 3, 0]); + + let values = Decimal32Array::from_iter_values([7, -3, 4, 3, 5]); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]); + let array2 = DictionaryArray::new(keys, Arc::new(values)); + + let expected = BooleanArray::from(vec![false, false, false, true, true, false]); + assert_eq!(crate::cmp::eq(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from(vec![true, true, false, false, false, true]); + assert_eq!(crate::cmp::lt(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from(vec![true, true, false, true, true, true]); + assert_eq!(crate::cmp::lt_eq(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from(vec![false, false, true, false, false, false]); + assert_eq!(crate::cmp::gt(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from(vec![false, false, true, true, true, false]); + assert_eq!(crate::cmp::gt_eq(&array1, &array2).unwrap(), expected); + } + + #[test] + fn test_cmp_dict_decimal64() { + let values = Decimal64Array::from_iter_values([0, 1, 2, 3, 4, 5]); + let keys = Int8Array::from_iter_values([1_i8, 2, 5, 4, 3, 0]); + let array1 = DictionaryArray::new(keys, Arc::new(values)); + + let 
values = Decimal64Array::from_iter_values([7, -3, 4, 3, 5]); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]); + let array2 = DictionaryArray::new(keys, Arc::new(values)); + + let expected = BooleanArray::from(vec![false, false, false, true, true, false]); + assert_eq!(crate::cmp::eq(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from(vec![true, true, false, false, false, true]); + assert_eq!(crate::cmp::lt(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from(vec![true, true, false, true, true, true]); + assert_eq!(crate::cmp::lt_eq(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from(vec![false, false, true, false, false, false]); + assert_eq!(crate::cmp::gt(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from(vec![false, false, true, true, true, false]); + assert_eq!(crate::cmp::gt_eq(&array1, &array2).unwrap(), expected); + } + + #[test] + fn test_cmp_dict_non_dict_decimal64() { + let array1: Decimal64Array = Decimal64Array::from_iter_values([1, 2, 5, 4, 3, 0]); + + let values = Decimal64Array::from_iter_values([7, -3, 4, 3, 5]); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]); + let array2 = DictionaryArray::new(keys, Arc::new(values)); + + let expected = BooleanArray::from(vec![false, false, false, true, true, false]); + assert_eq!(crate::cmp::eq(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from(vec![true, true, false, false, false, true]); + assert_eq!(crate::cmp::lt(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from(vec![true, true, false, true, true, true]); + assert_eq!(crate::cmp::lt_eq(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from(vec![false, false, true, false, false, false]); + assert_eq!(crate::cmp::gt(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from(vec![false, false, true, true, true, false]); + assert_eq!(crate::cmp::gt_eq(&array1, &array2).unwrap(), expected); + } + #[test] fn test_cmp_dict_decimal128() { let values = Decimal128Array::from_iter_values([0, 1, 2, 3, 4, 5]); @@ -3163,6 +3277,103 @@ mod tests { assert_eq!(crate::cmp::gt_eq(&array1, &array2).unwrap(), expected); } + #[test] + fn test_decimal32() { + let a = Decimal32Array::from_iter_values([1, 2, 4, 5]); + let b = Decimal32Array::from_iter_values([7, -3, 4, 3]); + let e = BooleanArray::from(vec![false, false, true, false]); + let r = crate::cmp::eq(&a, &b).unwrap(); + assert_eq!(e, r); + + let e = BooleanArray::from(vec![true, false, false, false]); + let r = crate::cmp::lt(&a, &b).unwrap(); + assert_eq!(e, r); + + let e = BooleanArray::from(vec![true, false, true, false]); + let r = crate::cmp::lt_eq(&a, &b).unwrap(); + assert_eq!(e, r); + + let e = BooleanArray::from(vec![false, true, false, true]); + let r = crate::cmp::gt(&a, &b).unwrap(); + assert_eq!(e, r); + + let e = BooleanArray::from(vec![false, true, true, true]); + let r = crate::cmp::gt_eq(&a, &b).unwrap(); + assert_eq!(e, r); + } + + #[test] + fn test_decimal32_scalar() { + let a = Decimal32Array::from(vec![Some(1), Some(2), Some(3), None, Some(4), Some(5)]); + let b = Decimal32Array::new_scalar(3_i32); + // array eq scalar + let e = BooleanArray::from( + vec![Some(false), Some(false), Some(true), None, Some(false), Some(false)], + ); + let r = crate::cmp::eq(&a, &b).unwrap(); + assert_eq!(e, r); + + // array neq scalar + let e = BooleanArray::from( + vec![Some(true), Some(true), Some(false), None, Some(true), 
Some(true)], + ); + let r = crate::cmp::neq(&a, &b).unwrap(); + assert_eq!(e, r); + + // array lt scalar + let e = BooleanArray::from( + vec![Some(true), Some(true), Some(false), None, Some(false), Some(false)], + ); + let r = crate::cmp::lt(&a, &b).unwrap(); + assert_eq!(e, r); + + // array lt_eq scalar + let e = BooleanArray::from( + vec![Some(true), Some(true), Some(true), None, Some(false), Some(false)], + ); + let r = crate::cmp::lt_eq(&a, &b).unwrap(); + assert_eq!(e, r); + + // array gt scalar + let e = BooleanArray::from( + vec![Some(false), Some(false), Some(false), None, Some(true), Some(true)], + ); + let r = crate::cmp::gt(&a, &b).unwrap(); + assert_eq!(e, r); + + // array gt_eq scalar + let e = BooleanArray::from( + vec![Some(false), Some(false), Some(true), None, Some(true), Some(true)], + ); + let r = crate::cmp::gt_eq(&a, &b).unwrap(); + assert_eq!(e, r); + } + + #[test] + fn test_decimal64() { + let a = Decimal64Array::from_iter_values([1, 2, 4, 5]); + let b = Decimal64Array::from_iter_values([7, -3, 4, 3]); + let e = BooleanArray::from(vec![false, false, true, false]); + let r = crate::cmp::eq(&a, &b).unwrap(); + assert_eq!(e, r); + + let e = BooleanArray::from(vec![true, false, false, false]); + let r = crate::cmp::lt(&a, &b).unwrap(); + assert_eq!(e, r); + + let e = BooleanArray::from(vec![true, false, true, false]); + let r = crate::cmp::lt_eq(&a, &b).unwrap(); + assert_eq!(e, r); + + let e = BooleanArray::from(vec![false, true, false, true]); + let r = crate::cmp::gt(&a, &b).unwrap(); + assert_eq!(e, r); + + let e = BooleanArray::from(vec![false, true, true, true]); + let r = crate::cmp::gt_eq(&a, &b).unwrap(); + assert_eq!(e, r); + } + #[test] fn test_decimal128() { let a = Decimal128Array::from_iter_values([1, 2, 4, 5]); diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs index 7d1c9b0c13dd..6ff076632491 100644 --- a/arrow-ord/src/ord.rs +++ b/arrow-ord/src/ord.rs @@ -575,7 +575,33 @@ mod tests { } #[test] - fn test_decimal() { + fn test_decimali32() { + let array = vec![Some(5_i32), Some(2_i32), Some(3_i32)] + .into_iter() + .collect::() + .with_precision_and_scale(8, 6) + .unwrap(); + + let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap(); + assert_eq!(Ordering::Less, cmp(1, 0)); + assert_eq!(Ordering::Greater, cmp(0, 2)); + } + + #[test] + fn test_decimali64() { + let array = vec![Some(5_i64), Some(2_i64), Some(3_i64)] + .into_iter() + .collect::() + .with_precision_and_scale(16, 6) + .unwrap(); + + let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap(); + assert_eq!(Ordering::Less, cmp(1, 0)); + assert_eq!(Ordering::Greater, cmp(0, 2)); + } + + #[test] + fn test_decimali128() { let array = vec![Some(5_i128), Some(2_i128), Some(3_i128)] .into_iter() .collect::() diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs index ba026af637d7..797c2246738c 100644 --- a/arrow-ord/src/sort.rs +++ b/arrow-ord/src/sort.rs @@ -841,7 +841,7 @@ pub struct SortColumn { /// Sort a list of `ArrayRef` using `SortOptions` provided for each array. /// -/// Performs a stable lexicographical sort on values and indices. +/// Performs an unstable lexicographical sort on values and indices. /// /// Returns an `ArrowError::ComputeError(String)` if any of the array type is either unsupported by /// `lexsort_to_indices` or `take`. 
@@ -2307,6 +2307,16 @@ mod tests { ); } + #[test] + fn test_sort_indices_decimal32() { + test_sort_indices_decimal::<Decimal32Type>(8, 3); + } + + #[test] + fn test_sort_indices_decimal64() { + test_sort_indices_decimal::<Decimal64Type>(17, 5); + } + #[test] fn test_sort_indices_decimal128() { test_sort_indices_decimal::<Decimal128Type>(23, 6); @@ -2460,6 +2470,16 @@ mod tests { ); } + #[test] + fn test_sort_decimal32() { + test_sort_decimal::<Decimal32Type>(8, 3); + } + + #[test] + fn test_sort_decimal64() { + test_sort_decimal::<Decimal64Type>(17, 5); + } + #[test] fn test_sort_decimal128() { test_sort_decimal::<Decimal128Type>(23, 6); diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 9508249324ee..cdb52a8ee7fd 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -97,7 +97,7 @@ //! assert_eq!(&c2_values, &["a", "f", "c", "e"]); //! ``` //! -//! # Lexsort +//! # Lexicographic Sorts (lexsort) //! //! The row format can also be used to implement a fast multi-column / lexicographic sort //! @@ -117,6 +117,33 @@ //! } //! ``` //! +//! # Flattening Dictionaries +//! +//! For performance reasons, dictionary arrays are flattened ("hydrated") to their +//! underlying values during row conversion. See [the issue] for more details. +//! +//! This means that the arrays that come out of [`RowConverter::convert_rows`] +//! may not have the same data types as the input arrays. For example, encoding +//! a `Dictionary` array and then converting the rows back will produce a `Utf8` array. +//! +//! ``` +//! # use arrow_array::{Array, ArrayRef, DictionaryArray}; +//! # use arrow_array::types::Int8Type; +//! # use arrow_row::{RowConverter, SortField}; +//! # use arrow_schema::DataType; +//! # use std::sync::Arc; +//! // Input is a Dictionary array +//! let dict: DictionaryArray::<Int8Type> = ["a", "b", "c", "a", "b"].into_iter().collect(); +//! let sort_fields = vec![SortField::new(dict.data_type().clone())]; +//! let arrays = vec![Arc::new(dict) as ArrayRef]; +//! let converter = RowConverter::new(sort_fields).unwrap(); +//! // Convert to rows +//! let rows = converter.convert_columns(&arrays).unwrap(); +//! let converted = converter.convert_rows(&rows).unwrap(); +//! // result was a Utf8 array, not a Dictionary array +//! assert_eq!(converted[0].data_type(), &DataType::Utf8); +//! ``` +//! +//! [non-comparison sorts]: https://en.wikipedia.org/wiki/Sorting_algorithm#Non-comparison_sorts +//! [radix sort]: https://en.wikipedia.org/wiki/Radix_sort +//! [normalized for sorting]: http://wwwlgis.informatik.uni-kl.de/archiv/wwwdvs.informatik.uni-kl.de/courses/DBSREAL/SS2005/Vorlesungsunterlagen/Implementing_Sorting.pdf @@ -124,6 +151,7 @@ //! [`lexsort`]: https://docs.rs/arrow-ord/latest/arrow_ord/sort/fn.lexsort.html //! [compared]: PartialOrd //! [compare]: PartialOrd +//!
[the issue]: https://github.com/apache/arrow-rs/issues/4811 #![doc( html_logo_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_white-bg.svg", @@ -139,7 +167,7 @@ use arrow_array::cast::*; use arrow_array::types::ArrowDictionaryKeyType; use arrow_array::*; use arrow_buffer::{ArrowNativeType, Buffer, OffsetBuffer, ScalarBuffer}; -use arrow_data::ArrayDataBuilder; +use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::*; use variable::{decode_binary_view, decode_string_view}; @@ -661,6 +689,8 @@ impl RowConverter { /// /// See [`Row`] for information on when [`Row`] can be compared /// + /// See [`Self::convert_rows`] for converting [`Rows`] back into [`ArrayRef`] + /// /// # Panics /// /// Panics if the schema of `columns` does not match that provided to [`RowConverter::new`] @@ -768,6 +798,8 @@ impl RowConverter { /// Convert [`Rows`] columns into [`ArrayRef`] /// + /// See [`Self::convert_columns`] for converting [`ArrayRef`] into [`Rows`] + /// /// # Panics /// /// Panics if the rows were not produced by this [`RowConverter`] @@ -1668,8 +1700,24 @@ unsafe fn decode_column( rows.iter_mut().for_each(|row| *row = &row[1..]); let children = converter.convert_raw(rows, validate_utf8)?; - let child_data = children.iter().map(|c| c.to_data()).collect(); - let builder = ArrayDataBuilder::new(field.data_type.clone()) + let child_data: Vec = children.iter().map(|c| c.to_data()).collect(); + // Since RowConverter flattens certain data types (i.e. Dictionary), + // we need to use updated data type instead of original field + let corrected_fields: Vec = match &field.data_type { + DataType::Struct(struct_fields) => struct_fields + .iter() + .zip(child_data.iter()) + .map(|(orig_field, child_array)| { + orig_field + .as_ref() + .clone() + .with_data_type(child_array.data_type().clone()) + }) + .collect(), + _ => unreachable!("Only Struct types should be corrected here"), + }; + let corrected_struct_type = DataType::Struct(corrected_fields.into()); + let builder = ArrayDataBuilder::new(corrected_struct_type) .len(rows.len()) .null_count(null_count) .null_bit_buffer(Some(nulls)) @@ -1800,6 +1848,66 @@ mod tests { } } + #[test] + fn test_decimal32() { + let converter = RowConverter::new(vec![SortField::new(DataType::Decimal32( + DECIMAL32_MAX_PRECISION, + 7, + ))]) + .unwrap(); + let col = Arc::new( + Decimal32Array::from_iter([ + None, + Some(i32::MIN), + Some(-13), + Some(46_i32), + Some(5456_i32), + Some(i32::MAX), + ]) + .with_precision_and_scale(9, 7) + .unwrap(), + ) as ArrayRef; + + let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap(); + for i in 0..rows.num_rows() - 1 { + assert!(rows.row(i) < rows.row(i + 1)); + } + + let back = converter.convert_rows(&rows).unwrap(); + assert_eq!(back.len(), 1); + assert_eq!(col.as_ref(), back[0].as_ref()) + } + + #[test] + fn test_decimal64() { + let converter = RowConverter::new(vec![SortField::new(DataType::Decimal64( + DECIMAL64_MAX_PRECISION, + 7, + ))]) + .unwrap(); + let col = Arc::new( + Decimal64Array::from_iter([ + None, + Some(i64::MIN), + Some(-13), + Some(46_i64), + Some(5456_i64), + Some(i64::MAX), + ]) + .with_precision_and_scale(18, 7) + .unwrap(), + ) as ArrayRef; + + let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap(); + for i in 0..rows.num_rows() - 1 { + assert!(rows.row(i) < rows.row(i + 1)); + } + + let back = converter.convert_rows(&rows).unwrap(); + assert_eq!(back.len(), 1); + assert_eq!(col.as_ref(), back[0].as_ref()) + } + #[test] fn test_decimal128() { let converter = 
RowConverter::new(vec![SortField::new(DataType::Decimal128( @@ -2148,6 +2256,177 @@ mod tests { back[0].to_data().validate_full().unwrap(); } + #[test] + fn test_dictionary_in_struct() { + let builder = StringDictionaryBuilder::::new(); + let mut struct_builder = StructBuilder::new( + vec![Field::new_dictionary( + "foo", + DataType::Int32, + DataType::Utf8, + true, + )], + vec![Box::new(builder)], + ); + + let dict_builder = struct_builder + .field_builder::>(0) + .unwrap(); + + // Flattened: ["a", null, "a", "b"] + dict_builder.append_value("a"); + dict_builder.append_null(); + dict_builder.append_value("a"); + dict_builder.append_value("b"); + + for _ in 0..4 { + struct_builder.append(true); + } + + let s = Arc::new(struct_builder.finish()) as ArrayRef; + let sort_fields = vec![SortField::new(s.data_type().clone())]; + let converter = RowConverter::new(sort_fields).unwrap(); + let r = converter.convert_columns(&[Arc::clone(&s)]).unwrap(); + + let back = converter.convert_rows(&r).unwrap(); + let [s2] = back.try_into().unwrap(); + + // RowConverter flattens Dictionary + // s.ty = Struct(foo Dictionary(Int32, Utf8)), s2.ty = Struct(foo Utf8) + assert_ne!(&s.data_type(), &s2.data_type()); + s2.to_data().validate_full().unwrap(); + + // Check if the logical data remains the same + // Keys: [0, null, 0, 1] + // Values: ["a", "b"] + let s1_struct = s.as_struct(); + let s1_0 = s1_struct.column(0); + let s1_idx_0 = s1_0.as_dictionary::(); + let keys = s1_idx_0.keys(); + let values = s1_idx_0.values().as_string::(); + // Flattened: ["a", null, "a", "b"] + let s2_struct = s2.as_struct(); + let s2_0 = s2_struct.column(0); + let s2_idx_0 = s2_0.as_string::(); + + for i in 0..keys.len() { + if keys.is_null(i) { + assert!(s2_idx_0.is_null(i)); + } else { + let dict_index = keys.value(i) as usize; + assert_eq!(values.value(dict_index), s2_idx_0.value(i)); + } + } + } + + #[test] + fn test_dictionary_in_struct_empty() { + let ty = DataType::Struct( + vec![Field::new_dictionary( + "foo", + DataType::Int32, + DataType::Int32, + false, + )] + .into(), + ); + let s = arrow_array::new_empty_array(&ty); + + let sort_fields = vec![SortField::new(s.data_type().clone())]; + let converter = RowConverter::new(sort_fields).unwrap(); + let r = converter.convert_columns(&[Arc::clone(&s)]).unwrap(); + + let back = converter.convert_rows(&r).unwrap(); + let [s2] = back.try_into().unwrap(); + + // RowConverter flattens Dictionary + // s.ty = Struct(foo Dictionary(Int32, Int32)), s2.ty = Struct(foo Int32) + assert_ne!(&s.data_type(), &s2.data_type()); + s2.to_data().validate_full().unwrap(); + assert_eq!(s.len(), 0); + assert_eq!(s2.len(), 0); + } + + #[test] + fn test_list_of_string_dictionary() { + let mut builder = ListBuilder::>::default(); + // List[0] = ["a", "b", "zero", null, "c", "b", "d" (dict)] + builder.values().append("a").unwrap(); + builder.values().append("b").unwrap(); + builder.values().append("zero").unwrap(); + builder.values().append_null(); + builder.values().append("c").unwrap(); + builder.values().append("b").unwrap(); + builder.values().append("d").unwrap(); + builder.append(true); + // List[1] = null + builder.append(false); + // List[2] = ["e", "zero", "a" (dict)] + builder.values().append("e").unwrap(); + builder.values().append("zero").unwrap(); + builder.values().append("a").unwrap(); + builder.append(true); + + let a = Arc::new(builder.finish()) as ArrayRef; + let data_type = a.data_type().clone(); + + let field = SortField::new(data_type.clone()); + let converter = 
RowConverter::new(vec![field]).unwrap(); + let rows = converter.convert_columns(&[Arc::clone(&a)]).unwrap(); + + let back = converter.convert_rows(&rows).unwrap(); + assert_eq!(back.len(), 1); + let [a2] = back.try_into().unwrap(); + + // RowConverter flattens Dictionary + // a.ty: List(Dictionary(Int32, Utf8)), a2.ty: List(Utf8) + assert_ne!(&a.data_type(), &a2.data_type()); + + a2.to_data().validate_full().unwrap(); + + let a2_list = a2.as_list::(); + let a1_list = a.as_list::(); + + // Check if the logical data remains the same + // List[0] = ["a", "b", "zero", null, "c", "b", "d" (dict)] + let a1_0 = a1_list.value(0); + let a1_idx_0 = a1_0.as_dictionary::(); + let keys = a1_idx_0.keys(); + let values = a1_idx_0.values().as_string::(); + let a2_0 = a2_list.value(0); + let a2_idx_0 = a2_0.as_string::(); + + for i in 0..keys.len() { + if keys.is_null(i) { + assert!(a2_idx_0.is_null(i)); + } else { + let dict_index = keys.value(i) as usize; + assert_eq!(values.value(dict_index), a2_idx_0.value(i)); + } + } + + // List[1] = null + assert!(a1_list.is_null(1)); + assert!(a2_list.is_null(1)); + + // List[2] = ["e", "zero", "a" (dict)] + let a1_2 = a1_list.value(2); + let a1_idx_2 = a1_2.as_dictionary::(); + let keys = a1_idx_2.keys(); + let values = a1_idx_2.values().as_string::(); + let a2_2 = a2_list.value(2); + let a2_idx_2 = a2_2.as_string::(); + + for i in 0..keys.len() { + if keys.is_null(i) { + assert!(a2_idx_2.is_null(i)); + } else { + let dict_index = keys.value(i) as usize; + assert_eq!(values.value(dict_index), a2_idx_2.value(i)); + } + } + } + #[test] fn test_primitive_dictionary() { let mut builder = PrimitiveDictionaryBuilder::::new(); @@ -2171,6 +2450,10 @@ mod tests { assert!(rows.row(3) < rows.row(2)); assert!(rows.row(6) < rows.row(2)); assert!(rows.row(3) < rows.row(6)); + + let back = converter.convert_rows(&rows).unwrap(); + assert_eq!(back.len(), 1); + back[0].to_data().validate_full().unwrap(); } #[test] diff --git a/arrow-row/src/list.rs b/arrow-row/src/list.rs index 91c788fc8f41..72d93d2f4bbe 100644 --- a/arrow-row/src/list.rs +++ b/arrow-row/src/list.rs @@ -20,7 +20,7 @@ use arrow_array::{new_null_array, Array, FixedSizeListArray, GenericListArray, O use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; use arrow_data::ArrayDataBuilder; use arrow_schema::{ArrowError, DataType, SortOptions}; -use std::ops::Range; +use std::{ops::Range, sync::Arc}; pub fn compute_lengths( lengths: &mut [usize], @@ -179,7 +179,25 @@ pub unsafe fn decode( let child_data = child[0].to_data(); - let builder = ArrayDataBuilder::new(field.data_type.clone()) + // Since RowConverter flattens certain data types (i.e. 
Dictionary), + // we need to use updated data type instead of original field + let corrected_type = match &field.data_type { + DataType::List(inner_field) => DataType::List(Arc::new( + inner_field + .as_ref() + .clone() + .with_data_type(child_data.data_type().clone()), + )), + DataType::LargeList(inner_field) => DataType::LargeList(Arc::new( + inner_field + .as_ref() + .clone() + .with_data_type(child_data.data_type().clone()), + )), + _ => unreachable!(), + }; + + let builder = ArrayDataBuilder::new(corrected_type) .len(rows.len()) .null_count(null_count) .null_bit_buffer(Some(nulls.into())) diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs index 469c930d31c7..3beae35795e4 100644 --- a/arrow-schema/src/field.rs +++ b/arrow-schema/src/field.rs @@ -547,7 +547,7 @@ impl Field { /// # Error /// /// Returns an error if - /// - this field does have a canonical extension type (mismatch or missing) + /// - this field does not have a canonical extension type (mismatch or missing) /// - the canonical extension is not supported /// - the construction of the extension type fails #[cfg(feature = "canonical_extension_types")] diff --git a/arrow-select/src/coalesce.rs b/arrow-select/src/coalesce.rs index 891d62fc3aa6..3ae31612c903 100644 --- a/arrow-select/src/coalesce.rs +++ b/arrow-select/src/coalesce.rs @@ -142,6 +142,8 @@ pub struct BatchCoalescer { buffered_rows: usize, /// Completed batches completed: VecDeque, + /// Biggest coalesce batch size. See [`Self::with_biggest_coalesce_batch_size`] + biggest_coalesce_batch_size: Option, } impl BatchCoalescer { @@ -166,9 +168,41 @@ impl BatchCoalescer { // We will for sure store at least one completed batch completed: VecDeque::with_capacity(1), buffered_rows: 0, + biggest_coalesce_batch_size: None, } } + /// Set the coalesce batch size limit (default `None`) + /// + /// This limit determine when batches should bypass coalescing. Intuitively, + /// batches that are already large are costly to coalesce and are efficient + /// enough to process directly without coalescing. + /// + /// If `Some(limit)`, batches larger than this limit will bypass coalescing + /// when there is no buffered data, or when the previously buffered data + /// already exceeds this limit. + /// + /// If `None`, all batches will be coalesced according to the + /// target_batch_size. + pub fn with_biggest_coalesce_batch_size(mut self, limit: Option) -> Self { + self.biggest_coalesce_batch_size = limit; + self + } + + /// Get the current biggest coalesce batch size limit + /// + /// See [`Self::with_biggest_coalesce_batch_size`] for details + pub fn biggest_coalesce_batch_size(&self) -> Option { + self.biggest_coalesce_batch_size + } + + /// Set the biggest coalesce batch size limit + /// + /// See [`Self::with_biggest_coalesce_batch_size`] for details + pub fn set_biggest_coalesce_batch_size(&mut self, limit: Option) { + self.biggest_coalesce_batch_size = limit; + } + /// Return the schema of the output batches pub fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) @@ -236,11 +270,160 @@ impl BatchCoalescer { /// assert_eq!(completed_batch, expected_batch); /// ``` pub fn push_batch(&mut self, batch: RecordBatch) -> Result<(), ArrowError> { - let (_schema, arrays, mut num_rows) = batch.into_parts(); - if num_rows == 0 { + // Large batch bypass optimization: + // When biggest_coalesce_batch_size is configured and a batch exceeds this limit, + // we can avoid expensive split-and-merge operations by passing it through directly. 
+ // + // IMPORTANT: This optimization is OPTIONAL and only active when biggest_coalesce_batch_size + // is explicitly set via with_biggest_coalesce_batch_size(Some(limit)). + // If not set (None), ALL batches follow normal coalescing behavior regardless of size. + + // ============================================================================= + // CASE 1: No buffer + large batch → Direct bypass + // ============================================================================= + // Example scenario (target_batch_size=1000, biggest_coalesce_batch_size=Some(500)): + // Input sequence: [600, 1200, 300] + // + // With biggest_coalesce_batch_size=Some(500) (optimization enabled): + // 600 → large batch detected! buffered_rows=0 → Case 1: direct bypass + // → output: [600] (bypass, preserves large batch) + // 1200 → large batch detected! buffered_rows=0 → Case 1: direct bypass + // → output: [1200] (bypass, preserves large batch) + // 300 → normal batch, buffer: [300] + // Result: [600], [1200], [300] - large batches preserved, mixed sizes + + // ============================================================================= + // CASE 2: Buffer too large + large batch → Flush first, then bypass + // ============================================================================= + // This case prevents creating extremely large merged batches that would + // significantly exceed both target_batch_size and biggest_coalesce_batch_size. + // + // Example 1: Buffer exceeds limit before large batch arrives + // target_batch_size=1000, biggest_coalesce_batch_size=Some(400) + // Input: [350, 200, 800] + // + // Step 1: push_batch([350]) + // → batch_size=350 <= 400, normal path + // → buffer: [350], buffered_rows=350 + // + // Step 2: push_batch([200]) + // → batch_size=200 <= 400, normal path + // → buffer: [350, 200], buffered_rows=550 + // + // Step 3: push_batch([800]) + // → batch_size=800 > 400, large batch path + // → buffered_rows=550 > 400 → Case 2: flush first + // → flush: output [550] (combined [350, 200]) + // → then bypass: output [800] + // Result: [550], [800] - buffer flushed to prevent oversized merge + // + // Example 2: Multiple small batches accumulate before large batch + // target_batch_size=1000, biggest_coalesce_batch_size=Some(300) + // Input: [150, 100, 80, 900] + // + // Step 1-3: Accumulate small batches + // 150 → buffer: [150], buffered_rows=150 + // 100 → buffer: [150, 100], buffered_rows=250 + // 80 → buffer: [150, 100, 80], buffered_rows=330 + // + // Step 4: push_batch([900]) + // → batch_size=900 > 300, large batch path + // → buffered_rows=330 > 300 → Case 2: flush first + // → flush: output [330] (combined [150, 100, 80]) + // → then bypass: output [900] + // Result: [330], [900] - prevents merge into [1230] which would be too large + + // ============================================================================= + // CASE 3: Small buffer + large batch → Normal coalescing (no bypass) + // ============================================================================= + // When buffer is small enough, we still merge to maintain efficiency + // Example: target_batch_size=1000, biggest_coalesce_batch_size=Some(500) + // Input: [300, 1200] + // + // Step 1: push_batch([300]) + // → batch_size=300 <= 500, normal path + // → buffer: [300], buffered_rows=300 + // + // Step 2: push_batch([1200]) + // → batch_size=1200 > 500, large batch path + // → buffered_rows=300 <= 500 → Case 3: normal merge + // → buffer: [300, 1200] (1500 total) + // → 1500 > target_batch_size → split: 
output [1000], buffer [500] + // Result: [1000], [500] - normal split/merge behavior maintained + + // ============================================================================= + // Comparison: Default vs Optimized Behavior + // ============================================================================= + // target_batch_size=1000, biggest_coalesce_batch_size=Some(500) + // Input: [600, 1200, 300] + // + // DEFAULT BEHAVIOR (biggest_coalesce_batch_size=None): + // 600 → buffer: [600] + // 1200 → buffer: [600, 1200] (1800 rows total) + // → split: output [1000 rows], buffer [800 rows remaining] + // 300 → buffer: [800, 300] (1100 rows total) + // → split: output [1000 rows], buffer [100 rows remaining] + // Result: [1000], [1000], [100] - all outputs respect target_batch_size + // + // OPTIMIZED BEHAVIOR (biggest_coalesce_batch_size=Some(500)): + // 600 → Case 1: direct bypass → output: [600] + // 1200 → Case 1: direct bypass → output: [1200] + // 300 → normal path → buffer: [300] + // Result: [600], [1200], [300] - large batches preserved + + // ============================================================================= + // Benefits and Trade-offs + // ============================================================================= + // Benefits of the optimization: + // - Large batches stay intact (better for downstream vectorized processing) + // - Fewer split/merge operations (better CPU performance) + // - More predictable memory usage patterns + // - Maintains streaming efficiency while preserving batch boundaries + // + // Trade-offs: + // - Output batch sizes become variable (not always target_batch_size) + // - May produce smaller partial batches when flushing before large batches + // - Requires tuning the biggest_coalesce_batch_size parameter for optimal performance + + // TODO: for unsorted batches, we may be able to filter out all large batches and coalesce all + // small batches together?
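For context, a minimal usage sketch of the `with_biggest_coalesce_batch_size` knob that the comments above describe. The schema, row counts, and the `make_batch` helper are illustrative only; it assumes `BatchCoalescer` is used via `arrow_select::coalesce`, the module this hunk modifies:

```rust
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{ArrowError, DataType, Field, Schema};
use arrow_select::coalesce::BatchCoalescer;

/// Build a single-column Int32 batch with `n` rows (illustrative helper).
fn make_batch(n: usize) -> RecordBatch {
    let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)]));
    let values = Int32Array::from_iter_values(0..n as i32);
    RecordBatch::try_new(schema, vec![Arc::new(values)]).unwrap()
}

fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)]));

    // Coalesce towards 1000-row outputs, but let batches over 500 rows bypass.
    let mut coalescer =
        BatchCoalescer::new(schema, 1000).with_biggest_coalesce_batch_size(Some(500));

    coalescer.push_batch(make_batch(600))?; // > 500, nothing buffered: emitted as-is (Case 1)
    coalescer.push_batch(make_batch(300))?; // <= 500: buffered towards the 1000-row target

    while let Some(batch) = coalescer.next_completed_batch() {
        println!("completed batch with {} rows", batch.num_rows());
    }

    // Flush whatever is still buffered (the 300-row remainder here).
    coalescer.finish_buffered_batch()?;
    assert!(coalescer.next_completed_batch().is_some());
    Ok(())
}
```

With these settings the 600-row batch is emitted unchanged, while the 300-row batch is buffered until `finish_buffered_batch` flushes it, matching the case analysis spelled out above.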
+ + let batch_size = batch.num_rows(); + + // Fast path: skip empty batches + if batch_size == 0 { return Ok(()); } + // Large batch optimization: bypass coalescing for oversized batches + if let Some(limit) = self.biggest_coalesce_batch_size { + if batch_size > limit { + // Case 1: No buffered data - emit large batch directly + // Example: [] + [1200] → output [1200], buffer [] + if self.buffered_rows == 0 { + self.completed.push_back(batch); + return Ok(()); + } + + // Case 2: Buffer too large - flush then emit to avoid oversized merge + // Example: [850] + [1200] → output [850], then output [1200] + // This prevents creating batches much larger than both target_batch_size + // and biggest_coalesce_batch_size, which could cause memory issues + if self.buffered_rows > limit { + self.finish_buffered_batch()?; + self.completed.push_back(batch); + return Ok(()); + } + + // Case 3: Small buffer - proceed with normal coalescing + // Example: [300] + [1200] → split and merge normally + // This ensures small batches still get properly coalesced + // while allowing some controlled growth beyond the limit + } + } + + let (_schema, arrays, mut num_rows) = batch.into_parts(); + // setup input rows assert_eq!(arrays.len(), self.in_progress_arrays.len()); self.in_progress_arrays @@ -290,6 +473,11 @@ impl BatchCoalescer { Ok(()) } + /// Returns the number of buffered rows + pub fn get_buffered_rows(&self) -> usize { + self.buffered_rows + } + /// Concatenates any buffered batches into a single `RecordBatch` and /// clears any output buffers /// @@ -394,7 +582,7 @@ mod tests { use arrow_array::builder::StringViewBuilder; use arrow_array::cast::AsArray; use arrow_array::{ - BinaryViewArray, Int64Array, RecordBatchOptions, StringArray, StringViewArray, + BinaryViewArray, Int32Array, Int64Array, RecordBatchOptions, StringArray, StringViewArray, TimestampNanosecondArray, UInt32Array, }; use arrow_schema::{DataType, Field, Schema}; @@ -1314,4 +1502,436 @@ mod tests { let options = RecordBatchOptions::new().with_row_count(Some(row_count)); RecordBatch::try_new_with_options(schema, columns, &options).unwrap() } + + /// Helper function to create a test batch with specified number of rows + fn create_test_batch(num_rows: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])); + let array = Int32Array::from_iter_values(0..num_rows as i32); + RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap() + } + #[test] + fn test_biggest_coalesce_batch_size_none_default() { + // Test that default behavior (None) coalesces all batches + let mut coalescer = BatchCoalescer::new( + Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])), + 100, + ); + + // Push a large batch (1000 rows) - should be coalesced normally + let large_batch = create_test_batch(1000); + coalescer.push_batch(large_batch).unwrap(); + + // Should produce multiple batches of target size (100) + let mut output_batches = vec![]; + while let Some(batch) = coalescer.next_completed_batch() { + output_batches.push(batch); + } + + coalescer.finish_buffered_batch().unwrap(); + while let Some(batch) = coalescer.next_completed_batch() { + output_batches.push(batch); + } + + // Should have 10 batches of 100 rows each + assert_eq!(output_batches.len(), 10); + for batch in output_batches { + assert_eq!(batch.num_rows(), 100); + } + } + + #[test] + fn test_biggest_coalesce_batch_size_bypass_large_batch() { + // Test that batches larger than biggest_coalesce_batch_size bypass coalescing + let mut 
coalescer = BatchCoalescer::new( + Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])), + 100, + ); + coalescer.set_biggest_coalesce_batch_size(Some(500)); + + // Push a large batch (1000 rows) - should bypass coalescing + let large_batch = create_test_batch(1000); + coalescer.push_batch(large_batch.clone()).unwrap(); + + // Should have one completed batch immediately (the original large batch) + assert!(coalescer.has_completed_batch()); + let output_batch = coalescer.next_completed_batch().unwrap(); + assert_eq!(output_batch.num_rows(), 1000); + + // Should be no more completed batches + assert!(!coalescer.has_completed_batch()); + assert_eq!(coalescer.get_buffered_rows(), 0); + } + + #[test] + fn test_biggest_coalesce_batch_size_coalesce_small_batch() { + // Test that batches smaller than biggest_coalesce_batch_size are coalesced normally + let mut coalescer = BatchCoalescer::new( + Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])), + 100, + ); + coalescer.set_biggest_coalesce_batch_size(Some(500)); + + // Push small batches that should be coalesced + let small_batch = create_test_batch(50); + coalescer.push_batch(small_batch.clone()).unwrap(); + + // Should not have completed batch yet (only 50 rows, target is 100) + assert!(!coalescer.has_completed_batch()); + assert_eq!(coalescer.get_buffered_rows(), 50); + + // Push another small batch + coalescer.push_batch(small_batch).unwrap(); + + // Now should have a completed batch (100 rows total) + assert!(coalescer.has_completed_batch()); + let output_batch = coalescer.next_completed_batch().unwrap(); + assert_eq!(output_batch.num_rows(), 100); + + assert_eq!(coalescer.get_buffered_rows(), 0); + } + + #[test] + fn test_biggest_coalesce_batch_size_equal_boundary() { + // Test behavior when batch size equals biggest_coalesce_batch_size + let mut coalescer = BatchCoalescer::new( + Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])), + 100, + ); + coalescer.set_biggest_coalesce_batch_size(Some(500)); + + // Push a batch exactly equal to the limit + let boundary_batch = create_test_batch(500); + coalescer.push_batch(boundary_batch).unwrap(); + + // Should be coalesced (not bypass) since it's equal, not greater + let mut output_count = 0; + while coalescer.next_completed_batch().is_some() { + output_count += 1; + } + + coalescer.finish_buffered_batch().unwrap(); + while coalescer.next_completed_batch().is_some() { + output_count += 1; + } + + // Should have 5 batches of 100 rows each + assert_eq!(output_count, 5); + } + + #[test] + fn test_biggest_coalesce_batch_size_first_large_then_consecutive_bypass() { + // Test the new consecutive large batch bypass behavior + // Pattern: small batches -> first large batch (coalesced) -> consecutive large batches (bypass) + let mut coalescer = BatchCoalescer::new( + Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])), + 100, + ); + coalescer.set_biggest_coalesce_batch_size(Some(200)); + + let small_batch = create_test_batch(50); + + // Push small batch first to create buffered data + coalescer.push_batch(small_batch).unwrap(); + assert_eq!(coalescer.get_buffered_rows(), 50); + assert!(!coalescer.has_completed_batch()); + + // Push first large batch - should go through normal coalescing due to buffered data + let large_batch1 = create_test_batch(250); + coalescer.push_batch(large_batch1).unwrap(); + + // 50 + 250 = 300 -> 3 complete batches of 100, 0 rows buffered + let mut completed_batches = vec![]; + while let Some(batch) = 
coalescer.next_completed_batch() { + completed_batches.push(batch); + } + assert_eq!(completed_batches.len(), 3); + assert_eq!(coalescer.get_buffered_rows(), 0); + + // Now push consecutive large batches - they should bypass + let large_batch2 = create_test_batch(300); + let large_batch3 = create_test_batch(400); + + // Push second large batch - should bypass since it's consecutive and buffer is empty + coalescer.push_batch(large_batch2).unwrap(); + assert!(coalescer.has_completed_batch()); + let output = coalescer.next_completed_batch().unwrap(); + assert_eq!(output.num_rows(), 300); // bypassed with original size + assert_eq!(coalescer.get_buffered_rows(), 0); + + // Push third large batch - should also bypass + coalescer.push_batch(large_batch3).unwrap(); + assert!(coalescer.has_completed_batch()); + let output = coalescer.next_completed_batch().unwrap(); + assert_eq!(output.num_rows(), 400); // bypassed with original size + assert_eq!(coalescer.get_buffered_rows(), 0); + } + + #[test] + fn test_biggest_coalesce_batch_size_empty_batch() { + // Test that empty batches don't trigger the bypass logic + let mut coalescer = BatchCoalescer::new( + Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])), + 100, + ); + coalescer.set_biggest_coalesce_batch_size(Some(50)); + + let empty_batch = create_test_batch(0); + coalescer.push_batch(empty_batch).unwrap(); + + // Empty batch should be handled normally (no effect) + assert!(!coalescer.has_completed_batch()); + assert_eq!(coalescer.get_buffered_rows(), 0); + } + + #[test] + fn test_biggest_coalesce_batch_size_with_buffered_data_no_bypass() { + // Test that when there is buffered data, large batches do NOT bypass (unless consecutive) + let mut coalescer = BatchCoalescer::new( + Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])), + 100, + ); + coalescer.set_biggest_coalesce_batch_size(Some(200)); + + // Add some buffered data first + let small_batch = create_test_batch(30); + coalescer.push_batch(small_batch.clone()).unwrap(); + coalescer.push_batch(small_batch).unwrap(); + assert_eq!(coalescer.get_buffered_rows(), 60); + + // Push large batch that would normally bypass, but shouldn't because buffered_rows > 0 + let large_batch = create_test_batch(250); + coalescer.push_batch(large_batch).unwrap(); + + // The large batch should be processed through normal coalescing logic + // Total: 60 (buffered) + 250 (new) = 310 rows + // Output: 3 complete batches of 100 rows each, 10 rows remain buffered + + let mut completed_batches = vec![]; + while let Some(batch) = coalescer.next_completed_batch() { + completed_batches.push(batch); + } + + assert_eq!(completed_batches.len(), 3); + for batch in &completed_batches { + assert_eq!(batch.num_rows(), 100); + } + assert_eq!(coalescer.get_buffered_rows(), 10); + } + + #[test] + fn test_biggest_coalesce_batch_size_zero_limit() { + // Test edge case where limit is 0 (all batches bypass when no buffered data) + let mut coalescer = BatchCoalescer::new( + Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])), + 100, + ); + coalescer.set_biggest_coalesce_batch_size(Some(0)); + + // Even a 1-row batch should bypass when there's no buffered data + let tiny_batch = create_test_batch(1); + coalescer.push_batch(tiny_batch).unwrap(); + + assert!(coalescer.has_completed_batch()); + let output = coalescer.next_completed_batch().unwrap(); + assert_eq!(output.num_rows(), 1); + } + + #[test] + fn test_biggest_coalesce_batch_size_bypass_only_when_no_buffer() { + // Test that 
bypass only occurs when buffered_rows == 0 + let mut coalescer = BatchCoalescer::new( + Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])), + 100, + ); + coalescer.set_biggest_coalesce_batch_size(Some(200)); + + // First, push a large batch with no buffered data - should bypass + let large_batch = create_test_batch(300); + coalescer.push_batch(large_batch.clone()).unwrap(); + + assert!(coalescer.has_completed_batch()); + let output = coalescer.next_completed_batch().unwrap(); + assert_eq!(output.num_rows(), 300); // bypassed + assert_eq!(coalescer.get_buffered_rows(), 0); + + // Now add some buffered data + let small_batch = create_test_batch(50); + coalescer.push_batch(small_batch).unwrap(); + assert_eq!(coalescer.get_buffered_rows(), 50); + + // Push the same large batch again - should NOT bypass this time (not consecutive) + coalescer.push_batch(large_batch).unwrap(); + + // Should process through normal coalescing: 50 + 300 = 350 rows + // Output: 3 complete batches of 100 rows, 50 rows buffered + let mut completed_batches = vec![]; + while let Some(batch) = coalescer.next_completed_batch() { + completed_batches.push(batch); + } + + assert_eq!(completed_batches.len(), 3); + for batch in &completed_batches { + assert_eq!(batch.num_rows(), 100); + } + assert_eq!(coalescer.get_buffered_rows(), 50); + } + + #[test] + fn test_biggest_coalesce_batch_size_consecutive_large_batches_scenario() { + // Test your exact scenario: 20, 20, 30, 700, 600, 700, 900, 700, 600 + let mut coalescer = BatchCoalescer::new( + Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])), + 1000, + ); + coalescer.set_biggest_coalesce_batch_size(Some(500)); + + // Push small batches first + coalescer.push_batch(create_test_batch(20)).unwrap(); + coalescer.push_batch(create_test_batch(20)).unwrap(); + coalescer.push_batch(create_test_batch(30)).unwrap(); + + assert_eq!(coalescer.get_buffered_rows(), 70); + assert!(!coalescer.has_completed_batch()); + + // Push first large batch (700) - should coalesce due to buffered data + coalescer.push_batch(create_test_batch(700)).unwrap(); + + // 70 + 700 = 770 rows, not enough for 1000, so all stay buffered + assert_eq!(coalescer.get_buffered_rows(), 770); + assert!(!coalescer.has_completed_batch()); + + // Push second large batch (600) - should bypass since previous was large + coalescer.push_batch(create_test_batch(600)).unwrap(); + + // Should flush buffer (770 rows) and bypass the 600 + let mut outputs = vec![]; + while let Some(batch) = coalescer.next_completed_batch() { + outputs.push(batch); + } + assert_eq!(outputs.len(), 2); // one flushed buffer batch (770) + one bypassed (600) + assert_eq!(outputs[0].num_rows(), 770); + assert_eq!(outputs[1].num_rows(), 600); + assert_eq!(coalescer.get_buffered_rows(), 0); + + // Push remaining large batches - should all bypass + let remaining_batches = [700, 900, 700, 600]; + for &size in &remaining_batches { + coalescer.push_batch(create_test_batch(size)).unwrap(); + + assert!(coalescer.has_completed_batch()); + let output = coalescer.next_completed_batch().unwrap(); + assert_eq!(output.num_rows(), size); + assert_eq!(coalescer.get_buffered_rows(), 0); + } + } + + #[test] + fn test_biggest_coalesce_batch_size_truly_consecutive_large_bypass() { + // Test truly consecutive large batches that should all bypass + // This test ensures buffer is completely empty between large batches + let mut coalescer = BatchCoalescer::new( + Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])), + 100, + 
); + coalescer.set_biggest_coalesce_batch_size(Some(200)); + + // Push consecutive large batches with no prior buffered data + let large_batches = vec![ + create_test_batch(300), + create_test_batch(400), + create_test_batch(350), + create_test_batch(500), + ]; + + let mut all_outputs = vec![]; + + for (i, large_batch) in large_batches.into_iter().enumerate() { + let expected_size = large_batch.num_rows(); + + // Buffer should be empty before each large batch + assert_eq!( + coalescer.get_buffered_rows(), + 0, + "Buffer should be empty before batch {}", + i + ); + + coalescer.push_batch(large_batch).unwrap(); + + // Each large batch should bypass and produce exactly one output batch + assert!( + coalescer.has_completed_batch(), + "Should have completed batch after pushing batch {}", + i + ); + + let output = coalescer.next_completed_batch().unwrap(); + assert_eq!( + output.num_rows(), + expected_size, + "Batch {} should have bypassed with original size", + i + ); + + // Should be no more batches and buffer should be empty + assert!( + !coalescer.has_completed_batch(), + "Should have no more completed batches after batch {}", + i + ); + assert_eq!( + coalescer.get_buffered_rows(), + 0, + "Buffer should be empty after batch {}", + i + ); + + all_outputs.push(output); + } + + // Verify we got exactly 4 output batches with original sizes + assert_eq!(all_outputs.len(), 4); + assert_eq!(all_outputs[0].num_rows(), 300); + assert_eq!(all_outputs[1].num_rows(), 400); + assert_eq!(all_outputs[2].num_rows(), 350); + assert_eq!(all_outputs[3].num_rows(), 500); + } + + #[test] + fn test_biggest_coalesce_batch_size_reset_consecutive_on_small_batch() { + // Test that small batches reset the consecutive large batch tracking + let mut coalescer = BatchCoalescer::new( + Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])), + 100, + ); + coalescer.set_biggest_coalesce_batch_size(Some(200)); + + // Push first large batch - should bypass (no buffered data) + coalescer.push_batch(create_test_batch(300)).unwrap(); + let output = coalescer.next_completed_batch().unwrap(); + assert_eq!(output.num_rows(), 300); + + // Push second large batch - should bypass (consecutive) + coalescer.push_batch(create_test_batch(400)).unwrap(); + let output = coalescer.next_completed_batch().unwrap(); + assert_eq!(output.num_rows(), 400); + + // Push small batch - resets consecutive tracking + coalescer.push_batch(create_test_batch(50)).unwrap(); + assert_eq!(coalescer.get_buffered_rows(), 50); + + // Push large batch again - should NOT bypass due to buffered data + coalescer.push_batch(create_test_batch(350)).unwrap(); + + // Should coalesce: 50 + 350 = 400 -> 4 complete batches of 100 + let mut outputs = vec![]; + while let Some(batch) = coalescer.next_completed_batch() { + outputs.push(batch); + } + assert_eq!(outputs.len(), 4); + for batch in outputs { + assert_eq!(batch.num_rows(), 100); + } + assert_eq!(coalescer.get_buffered_rows(), 0); + } } diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs index 6636988305c5..bd93650055bc 100644 --- a/arrow-select/src/concat.rs +++ b/arrow-select/src/concat.rs @@ -1189,11 +1189,9 @@ mod tests { // 3 * 3 = 9 // ------------+ // 909 - // closest 64 byte aligned cap = 960 let arr = concat(&[&a, &b, &c]).unwrap(); - // this would have been 1280 if we did not precompute the value lengths. 
- assert_eq!(arr.to_data().buffers()[1].capacity(), 960); + assert_eq!(arr.to_data().buffers()[1].capacity(), 909); } #[test] @@ -1328,12 +1326,12 @@ mod tests { let a = concat(&[&a, &b]).unwrap(); let data = a.to_data(); assert_eq!(data.buffers()[0].len(), 440); - assert_eq!(data.buffers()[0].capacity(), 448); // Nearest multiple of 64 + assert_eq!(data.buffers()[0].capacity(), 440); let a = concat(&[&a.slice(10, 20), &b]).unwrap(); let data = a.to_data(); assert_eq!(data.buffers()[0].len(), 120); - assert_eq!(data.buffers()[0].capacity(), 128); // Nearest multiple of 64 + assert_eq!(data.buffers()[0].capacity(), 120); let a = StringArray::from_iter_values(std::iter::repeat_n("foo", 100)); let b = StringArray::from(vec!["bingo", "bongo", "lorem", ""]); @@ -1342,21 +1340,21 @@ mod tests { let data = a.to_data(); // (100 + 4 + 1) * size_of() assert_eq!(data.buffers()[0].len(), 420); - assert_eq!(data.buffers()[0].capacity(), 448); // Nearest multiple of 64 + assert_eq!(data.buffers()[0].capacity(), 420); // len("foo") * 100 + len("bingo") + len("bongo") + len("lorem") assert_eq!(data.buffers()[1].len(), 315); - assert_eq!(data.buffers()[1].capacity(), 320); // Nearest multiple of 64 + assert_eq!(data.buffers()[1].capacity(), 315); let a = concat(&[&a.slice(10, 40), &b]).unwrap(); let data = a.to_data(); // (40 + 4 + 5) * size_of() assert_eq!(data.buffers()[0].len(), 180); - assert_eq!(data.buffers()[0].capacity(), 192); // Nearest multiple of 64 + assert_eq!(data.buffers()[0].capacity(), 180); // len("foo") * 40 + len("bingo") + len("bongo") + len("lorem") assert_eq!(data.buffers()[1].len(), 135); - assert_eq!(data.buffers()[1].capacity(), 192); // Nearest multiple of 64 + assert_eq!(data.buffers()[1].capacity(), 135); let a = LargeBinaryArray::from_iter_values(std::iter::repeat_n(b"foo", 100)); let b = LargeBinaryArray::from_iter_values(std::iter::repeat_n(b"cupcakes", 10)); @@ -1365,21 +1363,21 @@ mod tests { let data = a.to_data(); // (100 + 10 + 1) * size_of() assert_eq!(data.buffers()[0].len(), 888); - assert_eq!(data.buffers()[0].capacity(), 896); // Nearest multiple of 64 + assert_eq!(data.buffers()[0].capacity(), 888); // len("foo") * 100 + len("cupcakes") * 10 assert_eq!(data.buffers()[1].len(), 380); - assert_eq!(data.buffers()[1].capacity(), 384); // Nearest multiple of 64 + assert_eq!(data.buffers()[1].capacity(), 380); let a = concat(&[&a.slice(10, 40), &b]).unwrap(); let data = a.to_data(); // (40 + 10 + 1) * size_of() assert_eq!(data.buffers()[0].len(), 408); - assert_eq!(data.buffers()[0].capacity(), 448); // Nearest multiple of 64 + assert_eq!(data.buffers()[0].capacity(), 408); // len("foo") * 40 + len("cupcakes") * 10 assert_eq!(data.buffers()[1].len(), 200); - assert_eq!(data.buffers()[1].capacity(), 256); // Nearest multiple of 64 + assert_eq!(data.buffers()[1].capacity(), 200); } #[test] diff --git a/arrow/benches/array_from_vec.rs b/arrow/benches/array_from_vec.rs index 2850eae5d718..dc1b2d7b749d 100644 --- a/arrow/benches/array_from_vec.rs +++ b/arrow/benches/array_from_vec.rs @@ -73,6 +73,28 @@ fn struct_array_from_vec( hint::black_box(StructArray::try_from(vec![(field1, strings), (field2, ints)]).unwrap()); } +fn decimal32_array_from_vec(array: &[Option]) { + hint::black_box( + array + .iter() + .copied() + .collect::() + .with_precision_and_scale(9, 2) + .unwrap(), + ); +} + +fn decimal64_array_from_vec(array: &[Option]) { + hint::black_box( + array + .iter() + .copied() + .collect::() + .with_precision_and_scale(17, 2) + .unwrap(), + ); +} + fn 
decimal128_array_from_vec(array: &[Option]) { hint::black_box( array @@ -96,6 +118,30 @@ fn decimal256_array_from_vec(array: &[Option]) { } fn decimal_benchmark(c: &mut Criterion) { + // bench decimal32 array + // create option array + let size: usize = 1 << 15; + let mut rng = rand::rng(); + let mut array = vec![]; + for _ in 0..size { + array.push(Some(rng.random_range::(0..99999999))); + } + c.bench_function("decimal32_array_from_vec 32768", |b| { + b.iter(|| decimal32_array_from_vec(array.as_slice())) + }); + + // bench decimal64 array + // create option array + let size: usize = 1 << 15; + let mut rng = rand::rng(); + let mut array = vec![]; + for _ in 0..size { + array.push(Some(rng.random_range::(0..9999999999))); + } + c.bench_function("decimal64_array_from_vec 32768", |b| { + b.iter(|| decimal64_array_from_vec(array.as_slice())) + }); + // bench decimal128 array // create option array let size: usize = 1 << 15; diff --git a/arrow/benches/builder.rs b/arrow/benches/builder.rs index 46dd18c0fa52..2374797961a1 100644 --- a/arrow/benches/builder.rs +++ b/arrow/benches/builder.rs @@ -108,6 +108,42 @@ fn bench_string(c: &mut Criterion) { group.finish(); } +fn bench_decimal32(c: &mut Criterion) { + c.bench_function("bench_decimal32_builder", |b| { + b.iter(|| { + let mut rng = rand::rng(); + let mut decimal_builder = Decimal32Builder::with_capacity(BATCH_SIZE); + for _ in 0..BATCH_SIZE { + decimal_builder.append_value(rng.random_range::(0..999999999)); + } + hint::black_box( + decimal_builder + .finish() + .with_precision_and_scale(9, 0) + .unwrap(), + ); + }) + }); +} + +fn bench_decimal64(c: &mut Criterion) { + c.bench_function("bench_decimal64_builder", |b| { + b.iter(|| { + let mut rng = rand::rng(); + let mut decimal_builder = Decimal64Builder::with_capacity(BATCH_SIZE); + for _ in 0..BATCH_SIZE { + decimal_builder.append_value(rng.random_range::(0..9999999999)); + } + hint::black_box( + decimal_builder + .finish() + .with_precision_and_scale(18, 0) + .unwrap(), + ); + }) + }); +} + fn bench_decimal128(c: &mut Criterion) { c.bench_function("bench_decimal128_builder", |b| { b.iter(|| { @@ -151,6 +187,8 @@ criterion_group!( bench_primitive_nulls, bench_bool, bench_string, + bench_decimal32, + bench_decimal64, bench_decimal128, bench_decimal256, ); diff --git a/arrow/benches/cast_kernels.rs b/arrow/benches/cast_kernels.rs index d01031be5fd4..179fde0a70be 100644 --- a/arrow/benches/cast_kernels.rs +++ b/arrow/benches/cast_kernels.rs @@ -83,6 +83,36 @@ fn build_utf8_date_time_array(size: usize, with_nulls: bool) -> ArrayRef { Arc::new(builder.finish()) } +fn build_decimal32_array(size: usize, precision: u8, scale: i8) -> ArrayRef { + let mut rng = seedable_rng(); + let mut builder = Decimal32Builder::with_capacity(size); + + for _ in 0..size { + builder.append_value(rng.random_range::(0..1000000)); + } + Arc::new( + builder + .finish() + .with_precision_and_scale(precision, scale) + .unwrap(), + ) +} + +fn build_decimal64_array(size: usize, precision: u8, scale: i8) -> ArrayRef { + let mut rng = seedable_rng(); + let mut builder = Decimal64Builder::with_capacity(size); + + for _ in 0..size { + builder.append_value(rng.random_range::(0..1000000000)); + } + Arc::new( + builder + .finish() + .with_precision_and_scale(precision, scale) + .unwrap(), + ) +} + fn build_decimal128_array(size: usize, precision: u8, scale: i8) -> ArrayRef { let mut rng = seedable_rng(); let mut builder = Decimal128Builder::with_capacity(size); @@ -159,6 +189,8 @@ fn add_benchmark(c: &mut Criterion) { let 
utf8_date_array = build_utf8_date_array(512, true); let utf8_date_time_array = build_utf8_date_time_array(512, true); + let decimal32_array = build_decimal32_array(512, 9, 3); + let decimal64_array = build_decimal64_array(512, 10, 3); let decimal128_array = build_decimal128_array(512, 10, 3); let decimal256_array = build_decimal256_array(512, 50, 3); let string_array = build_string_array(512); @@ -248,6 +280,22 @@ fn add_benchmark(c: &mut Criterion) { b.iter(|| cast_array(&utf8_date_time_array, DataType::Date64)) }); + c.bench_function("cast decimal32 to decimal32 512", |b| { + b.iter(|| cast_array(&decimal32_array, DataType::Decimal32(9, 4))) + }); + c.bench_function("cast decimal32 to decimal32 512 lower precision", |b| { + b.iter(|| cast_array(&decimal32_array, DataType::Decimal32(6, 5))) + }); + c.bench_function("cast decimal32 to decimal64 512", |b| { + b.iter(|| cast_array(&decimal32_array, DataType::Decimal64(11, 5))) + }); + c.bench_function("cast decimal64 to decimal32 512", |b| { + b.iter(|| cast_array(&decimal64_array, DataType::Decimal32(9, 2))) + }); + c.bench_function("cast decimal64 to decimal64 512", |b| { + b.iter(|| cast_array(&decimal64_array, DataType::Decimal64(12, 4))) + }); + c.bench_function("cast decimal128 to decimal128 512", |b| { b.iter(|| cast_array(&decimal128_array, DataType::Decimal128(30, 5))) }); diff --git a/arrow/benches/decimal_validate.rs b/arrow/benches/decimal_validate.rs index dfa4f5992023..7867b10ba222 100644 --- a/arrow/benches/decimal_validate.rs +++ b/arrow/benches/decimal_validate.rs @@ -18,7 +18,10 @@ #[macro_use] extern crate criterion; -use arrow::array::{Array, Decimal128Array, Decimal128Builder, Decimal256Array, Decimal256Builder}; +use arrow::array::{ + Array, Decimal128Array, Decimal128Builder, Decimal256Array, Decimal256Builder, Decimal32Array, + Decimal32Builder, Decimal64Array, Decimal64Builder, +}; use criterion::Criterion; use rand::Rng; @@ -26,6 +29,14 @@ extern crate arrow; use arrow_buffer::i256; +fn validate_decimal32_array(array: Decimal32Array) { + array.with_precision_and_scale(8, 0).unwrap(); +} + +fn validate_decimal64_array(array: Decimal64Array) { + array.with_precision_and_scale(16, 0).unwrap(); +} + fn validate_decimal128_array(array: Decimal128Array) { array.with_precision_and_scale(35, 0).unwrap(); } @@ -34,6 +45,46 @@ fn validate_decimal256_array(array: Decimal256Array) { array.with_precision_and_scale(35, 0).unwrap(); } +fn validate_decimal32_benchmark(c: &mut Criterion) { + let mut rng = rand::rng(); + let size: i32 = 20000; + let mut decimal_builder = Decimal32Builder::with_capacity(size as usize); + for _ in 0..size { + decimal_builder.append_value(rng.random_range::(0..99999999)); + } + let decimal_array = decimal_builder + .finish() + .with_precision_and_scale(9, 0) + .unwrap(); + let data = decimal_array.into_data(); + c.bench_function("validate_decimal32_array 20000", |b| { + b.iter(|| { + let array = Decimal32Array::from(data.clone()); + validate_decimal32_array(array); + }) + }); +} + +fn validate_decimal64_benchmark(c: &mut Criterion) { + let mut rng = rand::rng(); + let size: i64 = 20000; + let mut decimal_builder = Decimal64Builder::with_capacity(size as usize); + for _ in 0..size { + decimal_builder.append_value(rng.random_range::(0..999999999999)); + } + let decimal_array = decimal_builder + .finish() + .with_precision_and_scale(18, 0) + .unwrap(); + let data = decimal_array.into_data(); + c.bench_function("validate_decimal64_array 20000", |b| { + b.iter(|| { + let array = 
Decimal64Array::from(data.clone()); + validate_decimal64_array(array); + }) + }); +} + fn validate_decimal128_benchmark(c: &mut Criterion) { let mut rng = rand::rng(); let size: i128 = 20000; @@ -78,6 +129,8 @@ fn validate_decimal256_benchmark(c: &mut Criterion) { criterion_group!( benches, + validate_decimal32_benchmark, + validate_decimal64_benchmark, validate_decimal128_benchmark, validate_decimal256_benchmark, ); diff --git a/arrow/src/tensor.rs b/arrow/src/tensor.rs index cd135a2f04df..3b65ea7b52f9 100644 --- a/arrow/src/tensor.rs +++ b/arrow/src/tensor.rs @@ -86,6 +86,10 @@ pub type BooleanTensor<'a> = Tensor<'a, BooleanType>; pub type Date32Tensor<'a> = Tensor<'a, Date32Type>; /// [Tensor] of type [Int16Type] pub type Date64Tensor<'a> = Tensor<'a, Date64Type>; +/// [Tensor] of type [Decimal32Type] +pub type Decimal32Tensor<'a> = Tensor<'a, Decimal32Type>; +/// [Tensor] of type [Decimal64Type] +pub type Decimal64Tensor<'a> = Tensor<'a, Decimal64Type>; /// [Tensor] of type [Decimal128Type] pub type Decimal128Tensor<'a> = Tensor<'a, Decimal128Type>; /// [Tensor] of type [Decimal256Type] diff --git a/arrow/tests/array_cast.rs b/arrow/tests/array_cast.rs index da7d37fc48a4..522687c3e493 100644 --- a/arrow/tests/array_cast.rs +++ b/arrow/tests/array_cast.rs @@ -18,19 +18,21 @@ use arrow_array::builder::{PrimitiveDictionaryBuilder, StringDictionaryBuilder, UnionBuilder}; use arrow_array::cast::AsArray; use arrow_array::types::{ - ArrowDictionaryKeyType, Decimal128Type, Decimal256Type, Int16Type, Int32Type, Int64Type, - Int8Type, TimestampMicrosecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ArrowDictionaryKeyType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, + Int16Type, Int32Type, Int64Type, Int8Type, TimestampMicrosecondType, UInt16Type, UInt32Type, + UInt64Type, UInt8Type, }; use arrow_array::{ Array, ArrayRef, ArrowPrimitiveType, BinaryArray, BooleanArray, Date32Array, Date64Array, - Decimal128Array, DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, - DurationSecondArray, FixedSizeBinaryArray, FixedSizeListArray, Float16Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, - IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, LargeListArray, - LargeStringArray, ListArray, NullArray, PrimitiveArray, StringArray, StructArray, - Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, UnionArray, + Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array, DurationMicrosecondArray, + DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, FixedSizeBinaryArray, + FixedSizeListArray, Float16Array, Float32Array, Float64Array, Int16Array, Int32Array, + Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, + LargeBinaryArray, LargeListArray, LargeStringArray, ListArray, NullArray, PrimitiveArray, + StringArray, StructArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, UnionArray, }; use arrow_buffer::{i256, Buffer, IntervalDayTime, IntervalMonthDayNano}; use arrow_cast::pretty::pretty_format_columns; @@ -261,7 +263,37 @@ 
fn get_arrays_of_all_types() -> Vec { Arc::new(DurationMillisecondArray::from(vec![1000, 2000])), Arc::new(DurationMicrosecondArray::from(vec![1000, 2000])), Arc::new(DurationNanosecondArray::from(vec![1000, 2000])), + Arc::new(create_decimal32_array(vec![Some(1), Some(2), Some(3)], 9, 0).unwrap()), + Arc::new(create_decimal64_array(vec![Some(1), Some(2), Some(3)], 18, 0).unwrap()), Arc::new(create_decimal128_array(vec![Some(1), Some(2), Some(3)], 38, 0).unwrap()), + Arc::new( + create_decimal256_array( + vec![ + Some(i256::from_i128(1)), + Some(i256::from_i128(2)), + Some(i256::from_i128(3)), + ], + 40, + 0, + ) + .unwrap(), + ), + make_dictionary_primitive::(vec![1, 2]), + make_dictionary_primitive::(vec![1, 2]), + make_dictionary_primitive::(vec![1, 2]), + make_dictionary_primitive::(vec![1, 2]), + make_dictionary_primitive::(vec![1, 2]), + make_dictionary_primitive::(vec![1, 2]), + make_dictionary_primitive::(vec![1, 2]), + make_dictionary_primitive::(vec![1, 2]), + make_dictionary_primitive::(vec![1, 2]), + make_dictionary_primitive::(vec![1, 2]), + make_dictionary_primitive::(vec![1, 2]), + make_dictionary_primitive::(vec![1, 2]), + make_dictionary_primitive::(vec![1, 2]), + make_dictionary_primitive::(vec![1, 2]), + make_dictionary_primitive::(vec![1, 2]), + make_dictionary_primitive::(vec![1, 2]), make_dictionary_primitive::(vec![1, 2]), make_dictionary_primitive::(vec![1, 2]), make_dictionary_primitive::(vec![1, 2]), @@ -411,6 +443,28 @@ fn make_dictionary_utf8() -> ArrayRef { Arc::new(b.finish()) } +fn create_decimal32_array( + array: Vec>, + precision: u8, + scale: i8, +) -> Result { + array + .into_iter() + .collect::() + .with_precision_and_scale(precision, scale) +} + +fn create_decimal64_array( + array: Vec>, + precision: u8, + scale: i8, +) -> Result { + array + .into_iter() + .collect::() + .with_precision_and_scale(precision, scale) +} + fn create_decimal128_array( array: Vec>, precision: u8, @@ -422,6 +476,17 @@ fn create_decimal128_array( .with_precision_and_scale(precision, scale) } +fn create_decimal256_array( + array: Vec>, + precision: u8, + scale: i8, +) -> Result { + array + .into_iter() + .collect::() + .with_precision_and_scale(precision, scale) +} + // Get a selection of datatypes to try and cast to fn get_all_types() -> Vec { use DataType::*; @@ -501,6 +566,8 @@ fn get_all_types() -> Vec { Dictionary(Box::new(key_type.clone()), Box::new(LargeUtf8)), Dictionary(Box::new(key_type.clone()), Box::new(Binary)), Dictionary(Box::new(key_type.clone()), Box::new(LargeBinary)), + Dictionary(Box::new(key_type.clone()), Box::new(Decimal32(9, 0))), + Dictionary(Box::new(key_type.clone()), Box::new(Decimal64(18, 0))), Dictionary(Box::new(key_type.clone()), Box::new(Decimal128(38, 0))), Dictionary(Box::new(key_type), Box::new(Decimal256(76, 0))), ] diff --git a/dev/release/update_change_log.sh b/dev/release/update_change_log.sh index e447909fd362..b99a21ffa708 100755 --- a/dev/release/update_change_log.sh +++ b/dev/release/update_change_log.sh @@ -29,8 +29,8 @@ set -e -SINCE_TAG="55.2.0" -FUTURE_RELEASE="56.0.0" +SINCE_TAG="56.0.0" +FUTURE_RELEASE="56.1.0" SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" SOURCE_TOP_DIR="$(cd "${SOURCE_DIR}/../../" && pwd)" diff --git a/parquet-testing b/parquet-testing index b68bea40fed8..a3d96a65e11e 160000 --- a/parquet-testing +++ b/parquet-testing @@ -1 +1 @@ -Subproject commit b68bea40fed8d1a780a9e09dd2262017e04b19ad +Subproject commit a3d96a65e11e2bbca7d22a894e8313ede90a33a3 diff --git a/parquet-variant-compute/Cargo.toml 
b/parquet-variant-compute/Cargo.toml index 0aa926ee7fa4..819a131f9c42 100644 --- a/parquet-variant-compute/Cargo.toml +++ b/parquet-variant-compute/Cargo.toml @@ -36,6 +36,7 @@ arrow-schema = { workspace = true } half = { version = "2.1", default-features = false } parquet-variant = { workspace = true } parquet-variant-json = { workspace = true } +chrono = { workspace = true } [lib] name = "parquet_variant_compute" diff --git a/parquet-variant-compute/benches/variant_kernels.rs b/parquet-variant-compute/benches/variant_kernels.rs index 8fd6af333fed..5e97f948b231 100644 --- a/parquet-variant-compute/benches/variant_kernels.rs +++ b/parquet-variant-compute/benches/variant_kernels.rs @@ -20,7 +20,7 @@ use arrow::util::test_util::seedable_rng; use criterion::{criterion_group, criterion_main, Criterion}; use parquet_variant::{Variant, VariantBuilder}; use parquet_variant_compute::variant_get::{variant_get, GetOptions}; -use parquet_variant_compute::{batch_json_string_to_variant, VariantArray, VariantArrayBuilder}; +use parquet_variant_compute::{json_to_variant, VariantArray, VariantArrayBuilder}; use rand::distr::Alphanumeric; use rand::rngs::StdRng; use rand::Rng; @@ -34,7 +34,7 @@ fn benchmark_batch_json_string_to_variant(c: &mut Criterion) { "batch_json_string_to_variant repeated_struct 8k string", |b| { b.iter(|| { - let _ = batch_json_string_to_variant(&array_ref).unwrap(); + let _ = json_to_variant(&array_ref).unwrap(); }); }, ); @@ -43,7 +43,7 @@ fn benchmark_batch_json_string_to_variant(c: &mut Criterion) { let array_ref: ArrayRef = Arc::new(input_array); c.bench_function("batch_json_string_to_variant json_list 8k string", |b| { b.iter(|| { - let _ = batch_json_string_to_variant(&array_ref).unwrap(); + let _ = json_to_variant(&array_ref).unwrap(); }); }); @@ -60,7 +60,7 @@ fn benchmark_batch_json_string_to_variant(c: &mut Criterion) { let array_ref: ArrayRef = Arc::new(input_array); c.bench_function(&id, |b| { b.iter(|| { - let _ = batch_json_string_to_variant(&array_ref).unwrap(); + let _ = json_to_variant(&array_ref).unwrap(); }); }); @@ -77,7 +77,7 @@ fn benchmark_batch_json_string_to_variant(c: &mut Criterion) { let array_ref: ArrayRef = Arc::new(input_array); c.bench_function(&id, |b| { b.iter(|| { - let _ = batch_json_string_to_variant(&array_ref).unwrap(); + let _ = json_to_variant(&array_ref).unwrap(); }); }); } diff --git a/parquet-variant-compute/src/arrow_to_variant.rs b/parquet-variant-compute/src/arrow_to_variant.rs new file mode 100644 index 000000000000..26713ce8ee19 --- /dev/null +++ b/parquet-variant-compute/src/arrow_to_variant.rs @@ -0,0 +1,1976 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::collections::HashMap; + +use crate::type_conversion::{decimal_to_variant_decimal, CastOptions}; +use arrow::array::{ + Array, AsArray, GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, +}; +use arrow::compute::kernels::cast; +use arrow::datatypes::{ + ArrowNativeType, ArrowPrimitiveType, ArrowTemporalType, ArrowTimestampType, Date32Type, + Date64Type, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + RunEndIndexType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow::temporal_conversions::{as_date, as_datetime, as_time}; +use arrow_schema::{ArrowError, DataType, TimeUnit}; +use chrono::{DateTime, TimeZone, Utc}; +use parquet_variant::{ + ObjectFieldBuilder, Variant, VariantBuilderExt, VariantDecimal16, VariantDecimal4, + VariantDecimal8, +}; + +// ============================================================================ +// Row-oriented builders for efficient Arrow-to-Variant conversion +// ============================================================================ + +/// Row builder for converting Arrow arrays to VariantArray row by row +pub(crate) enum ArrowToVariantRowBuilder<'a> { + Null(NullArrowToVariantBuilder), + Boolean(BooleanArrowToVariantBuilder<'a>), + PrimitiveInt8(PrimitiveArrowToVariantBuilder<'a, Int8Type>), + PrimitiveInt16(PrimitiveArrowToVariantBuilder<'a, Int16Type>), + PrimitiveInt32(PrimitiveArrowToVariantBuilder<'a, Int32Type>), + PrimitiveInt64(PrimitiveArrowToVariantBuilder<'a, Int64Type>), + PrimitiveUInt8(PrimitiveArrowToVariantBuilder<'a, UInt8Type>), + PrimitiveUInt16(PrimitiveArrowToVariantBuilder<'a, UInt16Type>), + PrimitiveUInt32(PrimitiveArrowToVariantBuilder<'a, UInt32Type>), + PrimitiveUInt64(PrimitiveArrowToVariantBuilder<'a, UInt64Type>), + PrimitiveFloat16(PrimitiveArrowToVariantBuilder<'a, Float16Type>), + PrimitiveFloat32(PrimitiveArrowToVariantBuilder<'a, Float32Type>), + PrimitiveFloat64(PrimitiveArrowToVariantBuilder<'a, Float64Type>), + Decimal32(Decimal32ArrowToVariantBuilder<'a>), + Decimal64(Decimal64ArrowToVariantBuilder<'a>), + Decimal128(Decimal128ArrowToVariantBuilder<'a>), + Decimal256(Decimal256ArrowToVariantBuilder<'a>), + TimestampSecond(TimestampArrowToVariantBuilder<'a, TimestampSecondType>), + TimestampMillisecond(TimestampArrowToVariantBuilder<'a, TimestampMillisecondType>), + TimestampMicrosecond(TimestampArrowToVariantBuilder<'a, TimestampMicrosecondType>), + TimestampNanosecond(TimestampArrowToVariantBuilder<'a, TimestampNanosecondType>), + Date32(DateArrowToVariantBuilder<'a, Date32Type>), + Date64(DateArrowToVariantBuilder<'a, Date64Type>), + Time32Second(TimeArrowToVariantBuilder<'a, Time32SecondType>), + Time32Millisecond(TimeArrowToVariantBuilder<'a, Time32MillisecondType>), + Time64Microsecond(TimeArrowToVariantBuilder<'a, Time64MicrosecondType>), + Time64Nanosecond(TimeArrowToVariantBuilder<'a, Time64NanosecondType>), + Binary(BinaryArrowToVariantBuilder<'a, i32>), + LargeBinary(BinaryArrowToVariantBuilder<'a, i64>), + BinaryView(BinaryViewArrowToVariantBuilder<'a>), + FixedSizeBinary(FixedSizeBinaryArrowToVariantBuilder<'a>), + Utf8(StringArrowToVariantBuilder<'a, i32>), + LargeUtf8(StringArrowToVariantBuilder<'a, i64>), + Utf8View(StringViewArrowToVariantBuilder<'a>), + List(ListArrowToVariantBuilder<'a, i32>), + 
LargeList(ListArrowToVariantBuilder<'a, i64>), + Struct(StructArrowToVariantBuilder<'a>), + Map(MapArrowToVariantBuilder<'a>), + Union(UnionArrowToVariantBuilder<'a>), + Dictionary(DictionaryArrowToVariantBuilder<'a>), + RunEndEncodedInt16(RunEndEncodedArrowToVariantBuilder<'a, Int16Type>), + RunEndEncodedInt32(RunEndEncodedArrowToVariantBuilder<'a, Int32Type>), + RunEndEncodedInt64(RunEndEncodedArrowToVariantBuilder<'a, Int64Type>), +} + +impl<'a> ArrowToVariantRowBuilder<'a> { + /// Appends a single row at the given index to the supplied builder. + pub fn append_row( + &mut self, + builder: &mut impl VariantBuilderExt, + index: usize, + ) -> Result<(), ArrowError> { + use ArrowToVariantRowBuilder::*; + match self { + Null(b) => b.append_row(builder, index), + Boolean(b) => b.append_row(builder, index), + PrimitiveInt8(b) => b.append_row(builder, index), + PrimitiveInt16(b) => b.append_row(builder, index), + PrimitiveInt32(b) => b.append_row(builder, index), + PrimitiveInt64(b) => b.append_row(builder, index), + PrimitiveUInt8(b) => b.append_row(builder, index), + PrimitiveUInt16(b) => b.append_row(builder, index), + PrimitiveUInt32(b) => b.append_row(builder, index), + PrimitiveUInt64(b) => b.append_row(builder, index), + PrimitiveFloat16(b) => b.append_row(builder, index), + PrimitiveFloat32(b) => b.append_row(builder, index), + PrimitiveFloat64(b) => b.append_row(builder, index), + Decimal32(b) => b.append_row(builder, index), + Decimal64(b) => b.append_row(builder, index), + Decimal128(b) => b.append_row(builder, index), + Decimal256(b) => b.append_row(builder, index), + TimestampSecond(b) => b.append_row(builder, index), + TimestampMillisecond(b) => b.append_row(builder, index), + TimestampMicrosecond(b) => b.append_row(builder, index), + TimestampNanosecond(b) => b.append_row(builder, index), + Date32(b) => b.append_row(builder, index), + Date64(b) => b.append_row(builder, index), + Time32Second(b) => b.append_row(builder, index), + Time32Millisecond(b) => b.append_row(builder, index), + Time64Microsecond(b) => b.append_row(builder, index), + Time64Nanosecond(b) => b.append_row(builder, index), + Binary(b) => b.append_row(builder, index), + LargeBinary(b) => b.append_row(builder, index), + BinaryView(b) => b.append_row(builder, index), + FixedSizeBinary(b) => b.append_row(builder, index), + Utf8(b) => b.append_row(builder, index), + LargeUtf8(b) => b.append_row(builder, index), + Utf8View(b) => b.append_row(builder, index), + List(b) => b.append_row(builder, index), + LargeList(b) => b.append_row(builder, index), + Struct(b) => b.append_row(builder, index), + Map(b) => b.append_row(builder, index), + Union(b) => b.append_row(builder, index), + Dictionary(b) => b.append_row(builder, index), + RunEndEncodedInt16(b) => b.append_row(builder, index), + RunEndEncodedInt32(b) => b.append_row(builder, index), + RunEndEncodedInt64(b) => b.append_row(builder, index), + } + } +} + +/// Factory function to create the appropriate row builder for a given DataType +pub(crate) fn make_arrow_to_variant_row_builder<'a>( + data_type: &'a DataType, + array: &'a dyn Array, + options: &'a CastOptions, +) -> Result, ArrowError> { + use ArrowToVariantRowBuilder::*; + let builder = + match data_type { + DataType::Null => Null(NullArrowToVariantBuilder), + DataType::Boolean => Boolean(BooleanArrowToVariantBuilder::new(array)), + DataType::Int8 => PrimitiveInt8(PrimitiveArrowToVariantBuilder::new(array)), + DataType::Int16 => PrimitiveInt16(PrimitiveArrowToVariantBuilder::new(array)), + DataType::Int32 => 
PrimitiveInt32(PrimitiveArrowToVariantBuilder::new(array)), + DataType::Int64 => PrimitiveInt64(PrimitiveArrowToVariantBuilder::new(array)), + DataType::UInt8 => PrimitiveUInt8(PrimitiveArrowToVariantBuilder::new(array)), + DataType::UInt16 => PrimitiveUInt16(PrimitiveArrowToVariantBuilder::new(array)), + DataType::UInt32 => PrimitiveUInt32(PrimitiveArrowToVariantBuilder::new(array)), + DataType::UInt64 => PrimitiveUInt64(PrimitiveArrowToVariantBuilder::new(array)), + DataType::Float16 => PrimitiveFloat16(PrimitiveArrowToVariantBuilder::new(array)), + DataType::Float32 => PrimitiveFloat32(PrimitiveArrowToVariantBuilder::new(array)), + DataType::Float64 => PrimitiveFloat64(PrimitiveArrowToVariantBuilder::new(array)), + DataType::Decimal32(_, scale) => { + Decimal32(Decimal32ArrowToVariantBuilder::new(array, *scale)) + } + DataType::Decimal64(_, scale) => { + Decimal64(Decimal64ArrowToVariantBuilder::new(array, *scale)) + } + DataType::Decimal128(_, scale) => { + Decimal128(Decimal128ArrowToVariantBuilder::new(array, *scale)) + } + DataType::Decimal256(_, scale) => { + Decimal256(Decimal256ArrowToVariantBuilder::new(array, *scale)) + } + DataType::Timestamp(time_unit, time_zone) => { + match time_unit { + TimeUnit::Second => TimestampSecond(TimestampArrowToVariantBuilder::new( + array, + options, + time_zone.is_some(), + )), + TimeUnit::Millisecond => TimestampMillisecond( + TimestampArrowToVariantBuilder::new(array, options, time_zone.is_some()), + ), + TimeUnit::Microsecond => TimestampMicrosecond( + TimestampArrowToVariantBuilder::new(array, options, time_zone.is_some()), + ), + TimeUnit::Nanosecond => TimestampNanosecond( + TimestampArrowToVariantBuilder::new(array, options, time_zone.is_some()), + ), + } + } + DataType::Date32 => Date32(DateArrowToVariantBuilder::new(array, options)), + DataType::Date64 => Date64(DateArrowToVariantBuilder::new(array, options)), + DataType::Time32(time_unit) => match time_unit { + TimeUnit::Second => Time32Second(TimeArrowToVariantBuilder::new(array, options)), + TimeUnit::Millisecond => { + Time32Millisecond(TimeArrowToVariantBuilder::new(array, options)) + } + _ => { + return Err(ArrowError::CastError(format!( + "Unsupported Time32 unit: {time_unit:?}" + ))) + } + }, + DataType::Time64(time_unit) => match time_unit { + TimeUnit::Microsecond => { + Time64Microsecond(TimeArrowToVariantBuilder::new(array, options)) + } + TimeUnit::Nanosecond => { + Time64Nanosecond(TimeArrowToVariantBuilder::new(array, options)) + } + _ => { + return Err(ArrowError::CastError(format!( + "Unsupported Time64 unit: {time_unit:?}" + ))) + } + }, + DataType::Duration(_) | DataType::Interval(_) => { + return Err(ArrowError::InvalidArgumentError( + "Casting duration/interval types to Variant is not supported. \ + The Variant format does not define duration/interval types." 
+ .to_string(), + )) + } + DataType::Binary => Binary(BinaryArrowToVariantBuilder::new(array)), + DataType::LargeBinary => LargeBinary(BinaryArrowToVariantBuilder::new(array)), + DataType::BinaryView => BinaryView(BinaryViewArrowToVariantBuilder::new(array)), + DataType::FixedSizeBinary(_) => { + FixedSizeBinary(FixedSizeBinaryArrowToVariantBuilder::new(array)) + } + DataType::Utf8 => Utf8(StringArrowToVariantBuilder::new(array)), + DataType::LargeUtf8 => LargeUtf8(StringArrowToVariantBuilder::new(array)), + DataType::Utf8View => Utf8View(StringViewArrowToVariantBuilder::new(array)), + DataType::List(_) => List(ListArrowToVariantBuilder::new(array, options)?), + DataType::LargeList(_) => LargeList(ListArrowToVariantBuilder::new(array, options)?), + DataType::Struct(_) => Struct(StructArrowToVariantBuilder::new( + array.as_struct(), + options, + )?), + DataType::Map(_, _) => Map(MapArrowToVariantBuilder::new(array, options)?), + DataType::Union(_, _) => Union(UnionArrowToVariantBuilder::new(array, options)?), + DataType::Dictionary(_, _) => { + Dictionary(DictionaryArrowToVariantBuilder::new(array, options)?) + } + DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() { + DataType::Int16 => { + RunEndEncodedInt16(RunEndEncodedArrowToVariantBuilder::new(array, options)?) + } + DataType::Int32 => { + RunEndEncodedInt32(RunEndEncodedArrowToVariantBuilder::new(array, options)?) + } + DataType::Int64 => { + RunEndEncodedInt64(RunEndEncodedArrowToVariantBuilder::new(array, options)?) + } + _ => { + return Err(ArrowError::CastError(format!( + "Unsupported run ends type: {:?}", + run_ends.data_type() + ))); + } + }, + dt => { + return Err(ArrowError::CastError(format!( + "Unsupported data type for casting to Variant: {dt:?}", + ))); + } + }; + Ok(builder) +} + +/// Macro to define (possibly generic) row builders with consistent structure and behavior. +/// +/// The macro optionally allows defining a transform for values read from the underlying +/// array. Transforms of the form `|value| { ... }` are infallible (and should produce something +/// that implements `Into<Variant>`), while transforms of the form `|value| -> Option<_> { ... }` +/// are fallible (and should produce `Option<impl Into<Variant>>`); a failed transform will either +/// append null to the builder or return an error, depending on cast options. +/// +/// Also supports optional extra fields that are passed to the constructor and which are available +/// by reference in the value transform. Providing a fallible value transform requires also +/// providing the extra field `options: &'a CastOptions`. +// TODO: If/when the macro_metavar_expr feature stabilizes, the `ignore` meta-function would allow +// us to "use" captured tokens without emitting them: +// +// ``` +// $( +// ${ignore($value)} +// $( +// ${ignore($option_ty)} +// options: &$lifetime CastOptions, +// )? +// )? +// ``` +// +// That, in turn, would allow us to inject the `options` field whenever the user specifies a +// fallible value transform, instead of requiring them to manually define it. This might not be +// worth the trouble, though, because it makes for some pretty bulky and unwieldy macro expansions. +macro_rules! define_row_builder { + ( + struct $name:ident<$lifetime:lifetime $(, $generic:ident: $bound:path )?> + $( where $where_path:path: $where_bound:path $(,)? )? + $({ $($field:ident: $field_type:ty),+ $(,)? })?, + |$array_param:ident| -> $array_type:ty { $init_expr:expr } + $(, |$value:ident| $(-> Option<$option_ty:ty>)? $value_transform:expr)?
+ ) => { + pub(crate) struct $name<$lifetime $(, $generic: $bound )?> + $( where $where_path: $where_bound )? + { + array: &$lifetime $array_type, + $( $( $field: $field_type, )+ )? + } + + impl<$lifetime $(, $generic: $bound+ )?> $name<$lifetime $(, $generic)?> + $( where $where_path: $where_bound )? + { + pub(crate) fn new($array_param: &$lifetime dyn Array $(, $( $field: $field_type ),+ )?) -> Self { + Self { + array: $init_expr, + $( $( $field, )+ )? + } + } + + fn append_row(&self, builder: &mut impl VariantBuilderExt, index: usize) -> Result<(), ArrowError> { + if self.array.is_null(index) { + builder.append_null(); + } else { + // Macro hygiene: Give any extra fields names the value transform can access. + // + // The value transform doesn't normally reference cast options, but the macro's + // caller still has to declare the field because stable rust has no way to "use" + // a captured token without emitting it. So, silence unused variable warnings, + // assuming that's the `options` field. Unfortunately, that also silences + // legitimate compiler warnings if an infallible value transform fails to use + // its first extra field. + $( + #[allow(unused)] + $( let $field = &self.$field; )+ + )? + + // Apply the value transform, if any (with name swapping for hygiene) + let value = self.array.value(index); + $( + let $value = value; + let value = $value_transform; + $( + // NOTE: The `?` macro expansion fails without the type annotation. + let Some(value): Option<$option_ty> = value else { + if self.options.strict { + return Err(ArrowError::ComputeError(format!( + "Failed to convert value at index {index}: conversion failed", + ))); + } else { + builder.append_null(); + return Ok(()); + } + }; + )? + )? + builder.append_value(value); + } + Ok(()) + } + } + }; +} + +define_row_builder!( + struct BooleanArrowToVariantBuilder<'a>, + |array| -> arrow::array::BooleanArray { array.as_boolean() } +); + +define_row_builder!( + struct PrimitiveArrowToVariantBuilder<'a, T: ArrowPrimitiveType> + where T::Native: Into>, + |array| -> PrimitiveArray { array.as_primitive() } +); + +define_row_builder!( + struct Decimal32ArrowToVariantBuilder<'a> { + scale: i8, + }, + |array| -> arrow::array::Decimal32Array { array.as_primitive() }, + |value| decimal_to_variant_decimal!(value, scale, i32, VariantDecimal4) +); + +define_row_builder!( + struct Decimal64ArrowToVariantBuilder<'a> { + scale: i8, + }, + |array| -> arrow::array::Decimal64Array { array.as_primitive() }, + |value| decimal_to_variant_decimal!(value, scale, i64, VariantDecimal8) +); + +define_row_builder!( + struct Decimal128ArrowToVariantBuilder<'a> { + scale: i8, + }, + |array| -> arrow::array::Decimal128Array { array.as_primitive() }, + |value| decimal_to_variant_decimal!(value, scale, i128, VariantDecimal16) +); + +define_row_builder!( + struct Decimal256ArrowToVariantBuilder<'a> { + scale: i8, + }, + |array| -> arrow::array::Decimal256Array { array.as_primitive() }, + |value| { + // Decimal256 needs special handling - convert to i128 if possible + match value.to_i128() { + Some(i128_val) => decimal_to_variant_decimal!(i128_val, scale, i128, VariantDecimal16), + None => Variant::Null, // Value too large for i128 + } + } +); + +define_row_builder!( + struct TimestampArrowToVariantBuilder<'a, T: ArrowTimestampType> { + options: &'a CastOptions, + has_time_zone: bool, + }, + |array| -> arrow::array::PrimitiveArray { array.as_primitive() }, + |value| -> Option<_> { + // Convert using Arrow's temporal conversion functions + 
as_datetime::(value).map(|naive_datetime| { + if *has_time_zone { + // Has timezone -> DateTime -> TimestampMicros/TimestampNanos + let utc_dt: DateTime = Utc.from_utc_datetime(&naive_datetime); + Variant::from(utc_dt) // Uses From> for Variant + } else { + // No timezone -> NaiveDateTime -> TimestampNtzMicros/TimestampNtzNanos + Variant::from(naive_datetime) // Uses From for Variant + } + }) + } +); + +define_row_builder!( + struct DateArrowToVariantBuilder<'a, T: ArrowTemporalType> + where + i64: From, + { + options: &'a CastOptions, + }, + |array| -> PrimitiveArray { array.as_primitive() }, + |value| -> Option<_> { + let date_value = i64::from(value); + as_date::(date_value) + } +); + +define_row_builder!( + struct TimeArrowToVariantBuilder<'a, T: ArrowTemporalType> + where + i64: From, + { + options: &'a CastOptions, + }, + |array| -> PrimitiveArray { array.as_primitive() }, + |value| -> Option<_> { + let time_value = i64::from(value); + as_time::(time_value) + } +); + +define_row_builder!( + struct BinaryArrowToVariantBuilder<'a, O: OffsetSizeTrait>, + |array| -> GenericBinaryArray { array.as_binary() } +); + +define_row_builder!( + struct BinaryViewArrowToVariantBuilder<'a>, + |array| -> arrow::array::BinaryViewArray { array.as_byte_view() } +); + +define_row_builder!( + struct FixedSizeBinaryArrowToVariantBuilder<'a>, + |array| -> arrow::array::FixedSizeBinaryArray { array.as_fixed_size_binary() } +); + +define_row_builder!( + struct StringArrowToVariantBuilder<'a, O: OffsetSizeTrait>, + |array| -> GenericStringArray { array.as_string() } +); + +define_row_builder!( + struct StringViewArrowToVariantBuilder<'a>, + |array| -> arrow::array::StringViewArray { array.as_string_view() } +); + +/// Null builder that always appends null +pub(crate) struct NullArrowToVariantBuilder; + +impl NullArrowToVariantBuilder { + fn append_row( + &mut self, + builder: &mut impl VariantBuilderExt, + _index: usize, + ) -> Result<(), ArrowError> { + builder.append_null(); + Ok(()) + } +} + +/// Generic list builder for List and LargeList types +pub(crate) struct ListArrowToVariantBuilder<'a, O: OffsetSizeTrait> { + list_array: &'a arrow::array::GenericListArray, + values_builder: Box>, +} + +impl<'a, O: OffsetSizeTrait> ListArrowToVariantBuilder<'a, O> { + pub(crate) fn new(array: &'a dyn Array, options: &'a CastOptions) -> Result { + let list_array = array.as_list(); + let values = list_array.values(); + let values_builder = + make_arrow_to_variant_row_builder(values.data_type(), values.as_ref(), options)?; + + Ok(Self { + list_array, + values_builder: Box::new(values_builder), + }) + } + + fn append_row( + &mut self, + builder: &mut impl VariantBuilderExt, + index: usize, + ) -> Result<(), ArrowError> { + if self.list_array.is_null(index) { + builder.append_null(); + return Ok(()); + } + + let offsets = self.list_array.offsets(); + let start = offsets[index].as_usize(); + let end = offsets[index + 1].as_usize(); + + let mut list_builder = builder.try_new_list()?; + for value_index in start..end { + self.values_builder + .append_row(&mut list_builder, value_index)?; + } + list_builder.finish(); + Ok(()) + } +} + +/// Struct builder for StructArray +pub(crate) struct StructArrowToVariantBuilder<'a> { + struct_array: &'a arrow::array::StructArray, + field_builders: Vec<(&'a str, ArrowToVariantRowBuilder<'a>)>, +} + +impl<'a> StructArrowToVariantBuilder<'a> { + pub(crate) fn new( + struct_array: &'a arrow::array::StructArray, + options: &'a CastOptions, + ) -> Result { + let mut field_builders = Vec::new(); 
+ + // Create a row builder for each field + for (field_name, field_array) in struct_array + .column_names() + .iter() + .zip(struct_array.columns().iter()) + { + let field_builder = make_arrow_to_variant_row_builder( + field_array.data_type(), + field_array.as_ref(), + options, + )?; + field_builders.push((*field_name, field_builder)); + } + + Ok(Self { + struct_array, + field_builders, + }) + } + + fn append_row( + &mut self, + builder: &mut impl VariantBuilderExt, + index: usize, + ) -> Result<(), ArrowError> { + if self.struct_array.is_null(index) { + builder.append_null(); + } else { + // Create object builder for this struct row + let mut obj_builder = builder.try_new_object()?; + + // Process each field + for (field_name, row_builder) in &mut self.field_builders { + let mut field_builder = + parquet_variant::ObjectFieldBuilder::new(field_name, &mut obj_builder); + row_builder.append_row(&mut field_builder, index)?; + } + + obj_builder.finish(); + } + Ok(()) + } +} + +/// Map builder for MapArray types +pub(crate) struct MapArrowToVariantBuilder<'a> { + map_array: &'a arrow::array::MapArray, + key_strings: arrow::array::StringArray, + values_builder: Box>, +} + +impl<'a> MapArrowToVariantBuilder<'a> { + pub(crate) fn new(array: &'a dyn Array, options: &'a CastOptions) -> Result { + let map_array = array.as_map(); + + // Pre-cast keys to strings once + let keys = cast(map_array.keys(), &DataType::Utf8)?; + let key_strings = keys.as_string::().clone(); + + // Create recursive builder for values + let values = map_array.values(); + let values_builder = + make_arrow_to_variant_row_builder(values.data_type(), values.as_ref(), options)?; + + Ok(Self { + map_array, + key_strings, + values_builder: Box::new(values_builder), + }) + } + + fn append_row( + &mut self, + builder: &mut impl VariantBuilderExt, + index: usize, + ) -> Result<(), ArrowError> { + // Check for NULL map first (via null bitmap) + if self.map_array.is_null(index) { + builder.append_null(); + return Ok(()); + } + + let offsets = self.map_array.offsets(); + let start = offsets[index].as_usize(); + let end = offsets[index + 1].as_usize(); + + // Create object builder for this map + let mut object_builder = builder.try_new_object()?; + + // Add each key-value pair (loop does nothing for empty maps - correct!) + for kv_index in start..end { + let key = self.key_strings.value(kv_index); + let mut field_builder = ObjectFieldBuilder::new(key, &mut object_builder); + self.values_builder + .append_row(&mut field_builder, kv_index)?; + } + + object_builder.finish(); + Ok(()) + } +} + +/// Union builder for both sparse and dense union arrays +/// +/// NOTE: Union type ids are _not_ required to be dense, hence the hash map for child builders. 
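+/// For sparse unions the child row index is the row index itself; for dense unions it is the
+/// per-row value offset. `append_row` relies on `UnionArray::value_offset`, which returns the
+/// appropriate index for both layouts.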
+pub(crate) struct UnionArrowToVariantBuilder<'a> { + union_array: &'a arrow::array::UnionArray, + child_builders: HashMap>>, +} + +impl<'a> UnionArrowToVariantBuilder<'a> { + pub(crate) fn new(array: &'a dyn Array, options: &'a CastOptions) -> Result { + let union_array = array.as_union(); + let type_ids = union_array.type_ids(); + + // Create child builders for each union field + let mut child_builders = HashMap::new(); + for &type_id in type_ids { + let child_array = union_array.child(type_id); + let child_builder = make_arrow_to_variant_row_builder( + child_array.data_type(), + child_array.as_ref(), + options, + )?; + child_builders.insert(type_id, Box::new(child_builder)); + } + + Ok(Self { + union_array, + child_builders, + }) + } + + fn append_row( + &mut self, + builder: &mut impl VariantBuilderExt, + index: usize, + ) -> Result<(), ArrowError> { + let type_id = self.union_array.type_id(index); + let value_offset = self.union_array.value_offset(index); + + // Delegate to the appropriate child builder, or append null to handle an invalid type_id + match self.child_builders.get_mut(&type_id) { + Some(child_builder) => child_builder.append_row(builder, value_offset)?, + None => builder.append_null(), + } + + Ok(()) + } +} + +/// Dictionary array builder with simple O(1) indexing +pub(crate) struct DictionaryArrowToVariantBuilder<'a> { + keys: &'a dyn Array, // only needed for null checks + normalized_keys: Vec, + values_builder: Box>, +} + +impl<'a> DictionaryArrowToVariantBuilder<'a> { + pub(crate) fn new(array: &'a dyn Array, options: &'a CastOptions) -> Result { + let dict_array = array.as_any_dictionary(); + let values = dict_array.values(); + let values_builder = + make_arrow_to_variant_row_builder(values.data_type(), values.as_ref(), options)?; + + // WARNING: normalized_keys panics if values is empty + let normalized_keys = match values.len() { + 0 => Vec::new(), + _ => dict_array.normalized_keys(), + }; + + Ok(Self { + keys: dict_array.keys(), + normalized_keys, + values_builder: Box::new(values_builder), + }) + } + + fn append_row( + &mut self, + builder: &mut impl VariantBuilderExt, + index: usize, + ) -> Result<(), ArrowError> { + if self.keys.is_null(index) { + builder.append_null(); + } else { + let normalized_key = self.normalized_keys[index]; + self.values_builder.append_row(builder, normalized_key)?; + } + Ok(()) + } +} + +/// Run-end encoded array builder with efficient sequential access +pub(crate) struct RunEndEncodedArrowToVariantBuilder<'a, R: RunEndIndexType> { + run_array: &'a arrow::array::RunArray, + values_builder: Box>, + + run_ends: &'a [R::Native], + run_number: usize, // Physical index into run_ends and values + run_start: usize, // Logical start index of current run +} + +impl<'a, R: RunEndIndexType> RunEndEncodedArrowToVariantBuilder<'a, R> { + pub(crate) fn new(array: &'a dyn Array, options: &'a CastOptions) -> Result { + let Some(run_array) = array.as_run_opt() else { + return Err(ArrowError::CastError("Expected RunArray".to_string())); + }; + + let values = run_array.values(); + let values_builder = + make_arrow_to_variant_row_builder(values.data_type(), values.as_ref(), options)?; + + Ok(Self { + run_array, + values_builder: Box::new(values_builder), + run_ends: run_array.run_ends().values(), + run_number: 0, + run_start: 0, + }) + } + + fn set_run_for_index(&mut self, index: usize) -> Result<(), ArrowError> { + if index >= self.run_start { + let Some(run_end) = self.run_ends.get(self.run_number) else { + return Err(ArrowError::CastError(format!( + 
"Index {index} beyond run array" + ))); + }; + if index < run_end.as_usize() { + return Ok(()); + } + if index == run_end.as_usize() { + self.run_number += 1; + self.run_start = run_end.as_usize(); + return Ok(()); + } + } + + // Use partition_point for all non-sequential cases + let run_number = self + .run_ends + .partition_point(|&run_end| run_end.as_usize() <= index); + if run_number >= self.run_ends.len() { + return Err(ArrowError::CastError(format!( + "Index {index} beyond run array" + ))); + } + self.run_number = run_number; + self.run_start = match run_number { + 0 => 0, + _ => self.run_ends[run_number - 1].as_usize(), + }; + Ok(()) + } + + fn append_row( + &mut self, + builder: &mut impl VariantBuilderExt, + index: usize, + ) -> Result<(), ArrowError> { + self.set_run_for_index(index)?; + + // Handle null values + if self.run_array.values().is_null(self.run_number) { + builder.append_null(); + return Ok(()); + } + + // Re-encode the value + self.values_builder.append_row(builder, self.run_number)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{VariantArray, VariantArrayBuilder}; + use arrow::array::{ArrayRef, BooleanArray, Int32Array, StringArray}; + use std::sync::Arc; + + /// Builds a VariantArray from an Arrow array using the row builder. + fn execute_row_builder_test(array: &dyn Array) -> VariantArray { + let options = CastOptions::default(); + let mut row_builder = + make_arrow_to_variant_row_builder(array.data_type(), array, &options).unwrap(); + + let mut array_builder = VariantArrayBuilder::new(array.len()); + + // The repetitive loop that appears in every test + for i in 0..array.len() { + row_builder.append_row(&mut array_builder, i).unwrap(); + } + + let variant_array = array_builder.build(); + assert_eq!(variant_array.len(), array.len()); + variant_array + } + + /// Generic helper function to test row builders with basic assertion patterns. + /// Uses execute_row_builder_test and adds simple value comparison assertions. 
+ fn test_row_builder_basic(array: &dyn Array, expected_values: Vec>) { + let variant_array = execute_row_builder_test(array); + + // The repetitive assertion pattern + for (i, expected) in expected_values.iter().enumerate() { + match expected { + Some(variant) => { + assert_eq!(variant_array.value(i), *variant, "Mismatch at index {}", i) + } + None => assert!(variant_array.is_null(i), "Expected null at index {}", i), + } + } + } + + #[test] + fn test_primitive_row_builder() { + let int_array = Int32Array::from(vec![Some(42), None, Some(100)]); + test_row_builder_basic( + &int_array, + vec![Some(Variant::Int32(42)), None, Some(Variant::Int32(100))], + ); + } + + #[test] + fn test_string_row_builder() { + let string_array = StringArray::from(vec![Some("hello"), None, Some("world")]); + test_row_builder_basic( + &string_array, + vec![ + Some(Variant::from("hello")), + None, + Some(Variant::from("world")), + ], + ); + } + + #[test] + fn test_boolean_row_builder() { + let bool_array = BooleanArray::from(vec![Some(true), None, Some(false)]); + test_row_builder_basic( + &bool_array, + vec![Some(Variant::from(true)), None, Some(Variant::from(false))], + ); + } + + #[test] + fn test_struct_row_builder() { + use arrow::array::{ArrayRef, Int32Array, StringArray, StructArray}; + use arrow_schema::{DataType, Field}; + use std::sync::Arc; + + // Create a struct array with int and string fields + let int_field = Field::new("id", DataType::Int32, true); + let string_field = Field::new("name", DataType::Utf8, true); + + let int_array = Int32Array::from(vec![Some(1), None, Some(3)]); + let string_array = StringArray::from(vec![Some("Alice"), Some("Bob"), None]); + + let struct_array = StructArray::try_new( + vec![int_field, string_field].into(), + vec![ + Arc::new(int_array) as ArrayRef, + Arc::new(string_array) as ArrayRef, + ], + None, + ) + .unwrap(); + + let variant_array = execute_row_builder_test(&struct_array); + + // Check first row - should have both fields + let first_variant = variant_array.value(0); + assert_eq!(first_variant.get_object_field("id"), Some(Variant::from(1))); + assert_eq!( + first_variant.get_object_field("name"), + Some(Variant::from("Alice")) + ); + + // Check second row - should have name field but not id (null field omitted) + let second_variant = variant_array.value(1); + assert_eq!(second_variant.get_object_field("id"), None); // null field omitted + assert_eq!( + second_variant.get_object_field("name"), + Some(Variant::from("Bob")) + ); + + // Check third row - should have id field but not name (null field omitted) + let third_variant = variant_array.value(2); + assert_eq!(third_variant.get_object_field("id"), Some(Variant::from(3))); + assert_eq!(third_variant.get_object_field("name"), None); // null field omitted + } + + #[test] + fn test_run_end_encoded_row_builder() { + use arrow::array::{Int32Array, RunArray}; + use arrow::datatypes::Int32Type; + + // Create a run-end encoded array: [A, A, B, B, B, C] + // run_ends: [2, 5, 6] + // values: ["A", "B", "C"] + let values = StringArray::from(vec!["A", "B", "C"]); + let run_ends = Int32Array::from(vec![2, 5, 6]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + + let variant_array = execute_row_builder_test(&run_array); + + // Verify the values + assert_eq!(variant_array.value(0), Variant::from("A")); // Run 0 + assert_eq!(variant_array.value(1), Variant::from("A")); // Run 0 + assert_eq!(variant_array.value(2), Variant::from("B")); // Run 1 + assert_eq!(variant_array.value(3), Variant::from("B")); // 
Run 1 + assert_eq!(variant_array.value(4), Variant::from("B")); // Run 1 + assert_eq!(variant_array.value(5), Variant::from("C")); // Run 2 + } + + #[test] + fn test_run_end_encoded_random_access() { + use arrow::array::{Int32Array, RunArray}; + use arrow::datatypes::Int32Type; + + // Create a run-end encoded array: [A, A, B, B, B, C] + let values = StringArray::from(vec!["A", "B", "C"]); + let run_ends = Int32Array::from(vec![2, 5, 6]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + + let options = CastOptions::default(); + let mut row_builder = + make_arrow_to_variant_row_builder(run_array.data_type(), &run_array, &options).unwrap(); + + // Test random access pattern (backward jumps, forward jumps) + let access_pattern = [0, 5, 2, 4, 1, 3]; // Mix of all cases + let expected_values = ["A", "C", "B", "B", "A", "B"]; + + for (i, &index) in access_pattern.iter().enumerate() { + let mut array_builder = VariantArrayBuilder::new(1); + row_builder.append_row(&mut array_builder, index).unwrap(); + let variant_array = array_builder.build(); + assert_eq!(variant_array.value(0), Variant::from(expected_values[i])); + } + } + + #[test] + fn test_run_end_encoded_with_nulls() { + use arrow::array::{Int32Array, RunArray}; + use arrow::datatypes::Int32Type; + + // Create a run-end encoded array with null values: [A, A, null, null, B] + let values = StringArray::from(vec![Some("A"), None, Some("B")]); + let run_ends = Int32Array::from(vec![2, 4, 5]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + + let options = CastOptions::default(); + let mut row_builder = + make_arrow_to_variant_row_builder(run_array.data_type(), &run_array, &options).unwrap(); + let mut array_builder = VariantArrayBuilder::new(5); + + // Test sequential access + for i in 0..5 { + row_builder.append_row(&mut array_builder, i).unwrap(); + } + + let variant_array = array_builder.build(); + assert_eq!(variant_array.len(), 5); + + // Verify the values + assert_eq!(variant_array.value(0), Variant::from("A")); // Run 0 + assert_eq!(variant_array.value(1), Variant::from("A")); // Run 0 + assert!(variant_array.is_null(2)); // Run 1 (null) + assert!(variant_array.is_null(3)); // Run 1 (null) + assert_eq!(variant_array.value(4), Variant::from("B")); // Run 2 + } + + #[test] + fn test_dictionary_row_builder() { + use arrow::array::{DictionaryArray, Int32Array}; + use arrow::datatypes::Int32Type; + + // Create a dictionary array: keys=[0, 1, 0, 2, 1], values=["apple", "banana", "cherry"] + let values = StringArray::from(vec!["apple", "banana", "cherry"]); + let keys = Int32Array::from(vec![0, 1, 0, 2, 1]); + let dict_array = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + let variant_array = execute_row_builder_test(&dict_array); + + // Verify the values match the dictionary lookup + assert_eq!(variant_array.value(0), Variant::from("apple")); // keys[0] = 0 -> values[0] = "apple" + assert_eq!(variant_array.value(1), Variant::from("banana")); // keys[1] = 1 -> values[1] = "banana" + assert_eq!(variant_array.value(2), Variant::from("apple")); // keys[2] = 0 -> values[0] = "apple" + assert_eq!(variant_array.value(3), Variant::from("cherry")); // keys[3] = 2 -> values[2] = "cherry" + assert_eq!(variant_array.value(4), Variant::from("banana")); // keys[4] = 1 -> values[1] = "banana" + } + + #[test] + fn test_dictionary_with_nulls() { + use arrow::array::{DictionaryArray, Int32Array}; + use arrow::datatypes::Int32Type; + + // Create a dictionary array with null keys: keys=[0, null, 1, null, 
2], values=["x", "y", "z"] + let values = StringArray::from(vec!["x", "y", "z"]); + let keys = Int32Array::from(vec![Some(0), None, Some(1), None, Some(2)]); + let dict_array = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + let options = CastOptions::default(); + let mut row_builder = + make_arrow_to_variant_row_builder(dict_array.data_type(), &dict_array, &options) + .unwrap(); + let mut array_builder = VariantArrayBuilder::new(5); + + // Test sequential access + for i in 0..5 { + row_builder.append_row(&mut array_builder, i).unwrap(); + } + + let variant_array = array_builder.build(); + assert_eq!(variant_array.len(), 5); + + // Verify the values and nulls + assert_eq!(variant_array.value(0), Variant::from("x")); // keys[0] = 0 -> values[0] = "x" + assert!(variant_array.is_null(1)); // keys[1] = null + assert_eq!(variant_array.value(2), Variant::from("y")); // keys[2] = 1 -> values[1] = "y" + assert!(variant_array.is_null(3)); // keys[3] = null + assert_eq!(variant_array.value(4), Variant::from("z")); // keys[4] = 2 -> values[2] = "z" + } + + #[test] + fn test_dictionary_random_access() { + use arrow::array::{DictionaryArray, Int32Array}; + use arrow::datatypes::Int32Type; + + // Create a dictionary array: keys=[0, 1, 2, 0, 1, 2], values=["red", "green", "blue"] + let values = StringArray::from(vec!["red", "green", "blue"]); + let keys = Int32Array::from(vec![0, 1, 2, 0, 1, 2]); + let dict_array = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + let options = CastOptions::default(); + let mut row_builder = + make_arrow_to_variant_row_builder(dict_array.data_type(), &dict_array, &options) + .unwrap(); + + // Test random access pattern + let access_pattern = [5, 0, 3, 1, 4, 2]; // Random order + let expected_values = ["blue", "red", "red", "green", "green", "blue"]; + + for (i, &index) in access_pattern.iter().enumerate() { + let mut array_builder = VariantArrayBuilder::new(1); + row_builder.append_row(&mut array_builder, index).unwrap(); + let variant_array = array_builder.build(); + assert_eq!(variant_array.value(0), Variant::from(expected_values[i])); + } + } + + #[test] + fn test_nested_dictionary() { + use arrow::array::{DictionaryArray, Int32Array, StructArray}; + use arrow::datatypes::{Field, Int32Type}; + + // Create a dictionary with struct values + let id_array = Int32Array::from(vec![1, 2, 3]); + let name_array = StringArray::from(vec!["Alice", "Bob", "Charlie"]); + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("id", DataType::Int32, false)), + Arc::new(id_array) as ArrayRef, + ), + ( + Arc::new(Field::new("name", DataType::Utf8, false)), + Arc::new(name_array) as ArrayRef, + ), + ]); + + let keys = Int32Array::from(vec![0, 1, 0, 2, 1]); + let dict_array = + DictionaryArray::::try_new(keys, Arc::new(struct_array)).unwrap(); + + let options = CastOptions::default(); + let mut row_builder = + make_arrow_to_variant_row_builder(dict_array.data_type(), &dict_array, &options) + .unwrap(); + let mut array_builder = VariantArrayBuilder::new(5); + + // Test sequential access + for i in 0..5 { + row_builder.append_row(&mut array_builder, i).unwrap(); + } + + let variant_array = array_builder.build(); + assert_eq!(variant_array.len(), 5); + + // Verify the nested struct values + let first_variant = variant_array.value(0); + assert_eq!(first_variant.get_object_field("id"), Some(Variant::from(1))); + assert_eq!( + first_variant.get_object_field("name"), + Some(Variant::from("Alice")) + ); + + let second_variant = 
variant_array.value(1); + assert_eq!( + second_variant.get_object_field("id"), + Some(Variant::from(2)) + ); + assert_eq!( + second_variant.get_object_field("name"), + Some(Variant::from("Bob")) + ); + + // Test that repeated keys give same values + let third_variant = variant_array.value(2); + assert_eq!(third_variant.get_object_field("id"), Some(Variant::from(1))); + assert_eq!( + third_variant.get_object_field("name"), + Some(Variant::from("Alice")) + ); + } + + #[test] + fn test_list_row_builder() { + use arrow::array::ListArray; + + // Create a list array: [[1, 2], [3, 4, 5], null, []] + let data = vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(3), Some(4), Some(5)]), + None, + Some(vec![]), + ]; + let list_array = ListArray::from_iter_primitive::(data); + + let variant_array = execute_row_builder_test(&list_array); + + // Row 0: [1, 2] + let row0 = variant_array.value(0); + let list0 = row0.as_list().unwrap(); + assert_eq!(list0.len(), 2); + assert_eq!(list0.get(0), Some(Variant::from(1))); + assert_eq!(list0.get(1), Some(Variant::from(2))); + + // Row 1: [3, 4, 5] + let row1 = variant_array.value(1); + let list1 = row1.as_list().unwrap(); + assert_eq!(list1.len(), 3); + assert_eq!(list1.get(0), Some(Variant::from(3))); + assert_eq!(list1.get(1), Some(Variant::from(4))); + assert_eq!(list1.get(2), Some(Variant::from(5))); + + // Row 2: null + assert!(variant_array.is_null(2)); + + // Row 3: [] + let row3 = variant_array.value(3); + let list3 = row3.as_list().unwrap(); + assert_eq!(list3.len(), 0); + } + + #[test] + fn test_sliced_list_row_builder() { + use arrow::array::ListArray; + + // Create a list array: [[1, 2], [3, 4, 5], [6]] + let data = vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(3), Some(4), Some(5)]), + Some(vec![Some(6)]), + ]; + let list_array = ListArray::from_iter_primitive::(data); + + // Slice to get just the middle element: [[3, 4, 5]] + let sliced_array = list_array.slice(1, 1); + + let options = CastOptions::default(); + let mut row_builder = + make_arrow_to_variant_row_builder(sliced_array.data_type(), &sliced_array, &options) + .unwrap(); + let mut variant_array_builder = VariantArrayBuilder::new(sliced_array.len()); + + // Test the single row + row_builder + .append_row(&mut variant_array_builder, 0) + .unwrap(); + let variant_array = variant_array_builder.build(); + + // Verify result + assert_eq!(variant_array.len(), 1); + + // Row 0: [3, 4, 5] + let row0 = variant_array.value(0); + let list0 = row0.as_list().unwrap(); + assert_eq!(list0.len(), 3); + assert_eq!(list0.get(0), Some(Variant::from(3))); + assert_eq!(list0.get(1), Some(Variant::from(4))); + assert_eq!(list0.get(2), Some(Variant::from(5))); + } + + #[test] + fn test_nested_list_row_builder() { + use arrow::array::ListArray; + use arrow::datatypes::Field; + + // Build the nested structure manually + let inner_field = Arc::new(Field::new("item", DataType::Int32, true)); + let inner_list_field = Arc::new(Field::new("item", DataType::List(inner_field), true)); + + let values_data = vec![Some(vec![Some(1), Some(2)]), Some(vec![Some(3)])]; + let values_list = ListArray::from_iter_primitive::(values_data); + + let outer_offsets = arrow::buffer::OffsetBuffer::new(vec![0i32, 2, 2].into()); + let outer_list = ListArray::new( + inner_list_field, + outer_offsets, + Arc::new(values_list), + Some(arrow::buffer::NullBuffer::from(vec![true, false])), + ); + + let options = CastOptions::default(); + let mut row_builder = + make_arrow_to_variant_row_builder(outer_list.data_type(), &outer_list, 
&options) + .unwrap(); + let mut variant_array_builder = VariantArrayBuilder::new(outer_list.len()); + + for i in 0..outer_list.len() { + row_builder + .append_row(&mut variant_array_builder, i) + .unwrap(); + } + + let variant_array = variant_array_builder.build(); + + // Verify results + assert_eq!(variant_array.len(), 2); + + // Row 0: [[1, 2], [3]] + let row0 = variant_array.value(0); + let outer_list0 = row0.as_list().unwrap(); + assert_eq!(outer_list0.len(), 2); + + let inner_list0_0 = outer_list0.get(0).unwrap(); + let inner_list0_0 = inner_list0_0.as_list().unwrap(); + assert_eq!(inner_list0_0.len(), 2); + assert_eq!(inner_list0_0.get(0), Some(Variant::from(1))); + assert_eq!(inner_list0_0.get(1), Some(Variant::from(2))); + + let inner_list0_1 = outer_list0.get(1).unwrap(); + let inner_list0_1 = inner_list0_1.as_list().unwrap(); + assert_eq!(inner_list0_1.len(), 1); + assert_eq!(inner_list0_1.get(0), Some(Variant::from(3))); + + // Row 1: null + assert!(variant_array.is_null(1)); + } + + #[test] + fn test_map_row_builder() { + use arrow::array::{Int32Array, MapArray, StringArray, StructArray}; + use arrow::buffer::{NullBuffer, OffsetBuffer}; + use arrow::datatypes::{DataType, Field, Fields}; + use std::sync::Arc; + + // Create the entries struct array (key-value pairs) + let keys = StringArray::from(vec!["key1", "key2", "key3"]); + let values = Int32Array::from(vec![1, 2, 3]); + let entries_fields = Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ]); + let entries = StructArray::new( + entries_fields.clone(), + vec![Arc::new(keys), Arc::new(values)], + None, // No nulls in the entries themselves + ); + + // Create offsets for 4 maps: [0..1], [1..1], [1..1], [1..3] + // Map 0: {"key1": 1} (1 entry) + // Map 1: {} (0 entries - empty) + // Map 2: null (0 entries but NULL via null buffer) + // Map 3: {"key2": 2, "key3": 3} (2 entries) + let offsets = OffsetBuffer::new(vec![0, 1, 1, 1, 3].into()); + + // Create null buffer - map at index 2 is NULL + let null_buffer = Some(NullBuffer::from(vec![true, true, false, true])); + + // Create the map field + let map_field = Arc::new(Field::new( + "entries", + DataType::Struct(entries_fields), + false, // Keys are non-nullable + )); + + // Create MapArray using try_new + let map_array = MapArray::try_new( + map_field, + offsets, + entries, + null_buffer, + false, // not ordered + ) + .unwrap(); + + let variant_array = execute_row_builder_test(&map_array); + + // Map 0: {"key1": 1} + let map0 = variant_array.value(0); + let obj0 = map0.as_object().unwrap(); + assert_eq!(obj0.len(), 1); + assert_eq!(obj0.get("key1"), Some(Variant::from(1))); + + // Map 1: {} (empty object, not null) + let map1 = variant_array.value(1); + let obj1 = map1.as_object().unwrap(); + assert_eq!(obj1.len(), 0); // Empty object + + // Map 2: null (actual NULL) + assert!(variant_array.is_null(2)); + + // Map 3: {"key2": 2, "key3": 3} + let map3 = variant_array.value(3); + let obj3 = map3.as_object().unwrap(); + assert_eq!(obj3.len(), 2); + assert_eq!(obj3.get("key2"), Some(Variant::from(2))); + assert_eq!(obj3.get("key3"), Some(Variant::from(3))); + } + + #[test] + fn test_union_sparse_row_builder() { + use arrow::array::{Float64Array, Int32Array, StringArray, UnionArray}; + use arrow::buffer::ScalarBuffer; + use arrow::datatypes::{DataType, Field, UnionFields}; + use std::sync::Arc; + + // Create a sparse union array with mixed types (int, float, string) + let int_array = Int32Array::from(vec![Some(1), None, 
None, None, Some(34), None]); + let float_array = Float64Array::from(vec![None, Some(3.2), None, Some(32.5), None, None]); + let string_array = StringArray::from(vec![None, None, Some("hello"), None, None, None]); + let type_ids = [0, 1, 2, 1, 0, 0].into_iter().collect::>(); + + let union_fields = UnionFields::new( + vec![0, 1, 2], + vec![ + Field::new("int_field", DataType::Int32, false), + Field::new("float_field", DataType::Float64, false), + Field::new("string_field", DataType::Utf8, false), + ], + ); + + let children: Vec> = vec![ + Arc::new(int_array), + Arc::new(float_array), + Arc::new(string_array), + ]; + + let union_array = UnionArray::try_new( + union_fields, + type_ids, + None, // Sparse union + children, + ) + .unwrap(); + + let variant_array = execute_row_builder_test(&union_array); + assert_eq!(variant_array.value(0), Variant::Int32(1)); + assert_eq!(variant_array.value(1), Variant::Double(3.2)); + assert_eq!(variant_array.value(2), Variant::from("hello")); + assert_eq!(variant_array.value(3), Variant::Double(32.5)); + assert_eq!(variant_array.value(4), Variant::Int32(34)); + assert!(variant_array.is_null(5)); + } + + #[test] + fn test_union_dense_row_builder() { + use arrow::array::{Float64Array, Int32Array, StringArray, UnionArray}; + use arrow::buffer::ScalarBuffer; + use arrow::datatypes::{DataType, Field, UnionFields}; + use std::sync::Arc; + + // Create a dense union array with mixed types (int, float, string) + let int_array = Int32Array::from(vec![Some(1), Some(34), None]); + let float_array = Float64Array::from(vec![3.2, 32.5]); + let string_array = StringArray::from(vec!["hello"]); + let type_ids = [0, 1, 2, 1, 0, 0].into_iter().collect::>(); + let offsets = [0, 0, 0, 1, 1, 2] + .into_iter() + .collect::>(); + + let union_fields = UnionFields::new( + vec![0, 1, 2], + vec![ + Field::new("int_field", DataType::Int32, false), + Field::new("float_field", DataType::Float64, false), + Field::new("string_field", DataType::Utf8, false), + ], + ); + + let children: Vec> = vec![ + Arc::new(int_array), + Arc::new(float_array), + Arc::new(string_array), + ]; + + let union_array = UnionArray::try_new( + union_fields, + type_ids, + Some(offsets), // Dense union + children, + ) + .unwrap(); + + // Test the row builder + let options = CastOptions::default(); + let mut row_builder = + make_arrow_to_variant_row_builder(union_array.data_type(), &union_array, &options) + .unwrap(); + + let mut variant_builder = VariantArrayBuilder::new(union_array.len()); + for i in 0..union_array.len() { + row_builder.append_row(&mut variant_builder, i).unwrap(); + } + let variant_array = variant_builder.build(); + + assert_eq!(variant_array.len(), 6); + assert_eq!(variant_array.value(0), Variant::Int32(1)); + assert_eq!(variant_array.value(1), Variant::Double(3.2)); + assert_eq!(variant_array.value(2), Variant::from("hello")); + assert_eq!(variant_array.value(3), Variant::Double(32.5)); + assert_eq!(variant_array.value(4), Variant::Int32(34)); + assert!(variant_array.is_null(5)); + } + + #[test] + fn test_union_sparse_type_ids_row_builder() { + use arrow::array::{Int32Array, StringArray, UnionArray}; + use arrow::buffer::ScalarBuffer; + use arrow::datatypes::{DataType, Field, UnionFields}; + use std::sync::Arc; + + // Create a sparse union with non-contiguous type IDs (1, 3) + let int_array = Int32Array::from(vec![Some(42), None]); + let string_array = StringArray::from(vec![None, Some("test")]); + let type_ids = [1, 3].into_iter().collect::>(); + + let union_fields = UnionFields::new( + vec![1, 
3], // Non-contiguous type IDs + vec![ + Field::new("int_field", DataType::Int32, false), + Field::new("string_field", DataType::Utf8, false), + ], + ); + + let children: Vec> = vec![Arc::new(int_array), Arc::new(string_array)]; + + let union_array = UnionArray::try_new( + union_fields, + type_ids, + None, // Sparse union + children, + ) + .unwrap(); + + // Test the row builder + let options = CastOptions::default(); + let mut row_builder = + make_arrow_to_variant_row_builder(union_array.data_type(), &union_array, &options) + .unwrap(); + + let mut variant_builder = VariantArrayBuilder::new(union_array.len()); + for i in 0..union_array.len() { + row_builder.append_row(&mut variant_builder, i).unwrap(); + } + let variant_array = variant_builder.build(); + + // Verify results + assert_eq!(variant_array.len(), 2); + + // Row 0: int 42 (type_id = 1) + assert_eq!(variant_array.value(0), Variant::Int32(42)); + + // Row 1: string "test" (type_id = 3) + assert_eq!(variant_array.value(1), Variant::from("test")); + } + + #[test] + fn test_decimal32_row_builder() { + use arrow::array::Decimal32Array; + use parquet_variant::VariantDecimal4; + + // Test Decimal32Array with scale 2 (e.g., for currency: 12.34) + let decimal_array = Decimal32Array::from(vec![Some(1234), None, Some(-5678)]) + .with_precision_and_scale(9, 2) + .unwrap(); + + test_row_builder_basic( + &decimal_array, + vec![ + Some(Variant::from(VariantDecimal4::try_new(1234, 2).unwrap())), + None, + Some(Variant::from(VariantDecimal4::try_new(-5678, 2).unwrap())), + ], + ); + } + + #[test] + fn test_decimal128_row_builder() { + use arrow::array::Decimal128Array; + use parquet_variant::VariantDecimal16; + + // Test Decimal128Array with negative scale (multiply by 10^|scale|) + let decimal_array = Decimal128Array::from(vec![Some(123), None, Some(456)]) + .with_precision_and_scale(10, -2) + .unwrap(); + + test_row_builder_basic( + &decimal_array, + vec![ + Some(Variant::from(VariantDecimal16::try_new(12300, 0).unwrap())), + None, + Some(Variant::from(VariantDecimal16::try_new(45600, 0).unwrap())), + ], + ); + } + + #[test] + fn test_decimal256_overflow_row_builder() { + use arrow::array::Decimal256Array; + use arrow::datatypes::i256; + + // Test Decimal256Array with a value that overflows i128 + let large_value = i256::from_i128(i128::MAX) + i256::from(1); // Overflows i128 + let decimal_array = Decimal256Array::from(vec![Some(large_value), Some(i256::from(123))]) + .with_precision_and_scale(76, 3) + .unwrap(); + + test_row_builder_basic( + &decimal_array, + vec![ + Some(Variant::Null), // Overflow value becomes Null + Some(Variant::from(VariantDecimal16::try_new(123, 3).unwrap())), + ], + ); + } + + #[test] + fn test_binary_row_builder() { + use arrow::array::BinaryArray; + + let binary_data = vec![ + Some(b"hello".as_slice()), + None, + Some(b"\x00\x01\x02\xFF".as_slice()), + Some(b"".as_slice()), // Empty binary + ]; + let binary_array = BinaryArray::from(binary_data); + + test_row_builder_basic( + &binary_array, + vec![ + Some(Variant::from(b"hello".as_slice())), + None, + Some(Variant::from([0x00, 0x01, 0x02, 0xFF].as_slice())), + Some(Variant::from([].as_slice())), + ], + ); + } + + #[test] + fn test_binary_view_row_builder() { + use arrow::array::BinaryViewArray; + + let binary_data = vec![ + Some(b"short".as_slice()), + None, + Some(b"this is a longer binary view that exceeds inline storage".as_slice()), + ]; + let binary_view_array = BinaryViewArray::from(binary_data); + + test_row_builder_basic( + &binary_view_array, + vec![ + 
Some(Variant::from(b"short".as_slice())), + None, + Some(Variant::from( + b"this is a longer binary view that exceeds inline storage".as_slice(), + )), + ], + ); + } + + #[test] + fn test_fixed_size_binary_row_builder() { + use arrow::array::FixedSizeBinaryArray; + + let binary_data = vec![ + Some([0x01, 0x02, 0x03, 0x04]), + None, + Some([0xFF, 0xFE, 0xFD, 0xFC]), + ]; + let fixed_binary_array = + FixedSizeBinaryArray::try_from_sparse_iter_with_size(binary_data.into_iter(), 4) + .unwrap(); + + test_row_builder_basic( + &fixed_binary_array, + vec![ + Some(Variant::from([0x01, 0x02, 0x03, 0x04].as_slice())), + None, + Some(Variant::from([0xFF, 0xFE, 0xFD, 0xFC].as_slice())), + ], + ); + } + + #[test] + fn test_utf8_view_row_builder() { + use arrow::array::StringViewArray; + + let string_data = vec![ + Some("short"), + None, + Some("this is a much longer string that will be stored out-of-line in the buffer"), + ]; + let string_view_array = StringViewArray::from(string_data); + + test_row_builder_basic( + &string_view_array, + vec![ + Some(Variant::from("short")), + None, + Some(Variant::from( + "this is a much longer string that will be stored out-of-line in the buffer", + )), + ], + ); + } + + #[test] + fn test_timestamp_second_row_builder() { + use arrow::array::TimestampSecondArray; + + let timestamp_data = vec![ + Some(1609459200), // 2021-01-01 00:00:00 UTC + None, + Some(1640995200), // 2022-01-01 00:00:00 UTC + ]; + let timestamp_array = TimestampSecondArray::from(timestamp_data); + + let expected_naive1 = DateTime::from_timestamp(1609459200, 0).unwrap().naive_utc(); + let expected_naive2 = DateTime::from_timestamp(1640995200, 0).unwrap().naive_utc(); + + test_row_builder_basic( + ×tamp_array, + vec![ + Some(Variant::from(expected_naive1)), + None, + Some(Variant::from(expected_naive2)), + ], + ); + } + + #[test] + fn test_timestamp_with_timezone_row_builder() { + use arrow::array::TimestampMicrosecondArray; + use chrono::DateTime; + + let timestamp_data = vec![ + Some(1609459200000000), // 2021-01-01 00:00:00 UTC (in microseconds) + None, + Some(1640995200000000), // 2022-01-01 00:00:00 UTC (in microseconds) + ]; + let timezone = "UTC".to_string(); + let timestamp_array = + TimestampMicrosecondArray::from(timestamp_data).with_timezone(timezone); + + let expected_utc1 = DateTime::from_timestamp(1609459200, 0).unwrap(); + let expected_utc2 = DateTime::from_timestamp(1640995200, 0).unwrap(); + + test_row_builder_basic( + ×tamp_array, + vec![ + Some(Variant::from(expected_utc1)), + None, + Some(Variant::from(expected_utc2)), + ], + ); + } + + #[test] + fn test_timestamp_nanosecond_precision_row_builder() { + use arrow::array::TimestampNanosecondArray; + + let timestamp_data = vec![ + Some(1609459200123456789), // 2021-01-01 00:00:00.123456789 UTC + None, + Some(1609459200000000000), // 2021-01-01 00:00:00.000000000 UTC (no fractional seconds) + ]; + let timestamp_array = TimestampNanosecondArray::from(timestamp_data); + + let expected_with_nanos = DateTime::from_timestamp(1609459200, 123456789) + .unwrap() + .naive_utc(); + let expected_no_nanos = DateTime::from_timestamp(1609459200, 0).unwrap().naive_utc(); + + test_row_builder_basic( + ×tamp_array, + vec![ + Some(Variant::from(expected_with_nanos)), + None, + Some(Variant::from(expected_no_nanos)), + ], + ); + } + + #[test] + fn test_timestamp_millisecond_row_builder() { + use arrow::array::TimestampMillisecondArray; + + let timestamp_data = vec![ + Some(1609459200123), // 2021-01-01 00:00:00.123 UTC + None, + Some(1609459200000), // 
2021-01-01 00:00:00.000 UTC + ]; + let timestamp_array = TimestampMillisecondArray::from(timestamp_data); + + let expected_with_millis = DateTime::from_timestamp(1609459200, 123000000) + .unwrap() + .naive_utc(); + let expected_no_millis = DateTime::from_timestamp(1609459200, 0).unwrap().naive_utc(); + + test_row_builder_basic( + &timestamp_array, + vec![ + Some(Variant::from(expected_with_millis)), + None, + Some(Variant::from(expected_no_millis)), + ], + ); + } + + #[test] + fn test_date32_row_builder() { + use arrow::array::Date32Array; + use chrono::NaiveDate; + + let date_data = vec![ + Some(0), // 1970-01-01 + None, + Some(19723), // 2024-01-01 (days since epoch) + Some(-719162), // 0001-01-01 (near minimum) + ]; + let date_array = Date32Array::from(date_data); + + let expected_epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + let expected_2024 = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(); + let expected_min = NaiveDate::from_ymd_opt(1, 1, 1).unwrap(); + + test_row_builder_basic( + &date_array, + vec![ + Some(Variant::from(expected_epoch)), + None, + Some(Variant::from(expected_2024)), + Some(Variant::from(expected_min)), + ], + ); + } + + #[test] + fn test_date64_row_builder() { + use arrow::array::Date64Array; + use chrono::NaiveDate; + + // Test Date64Array with various dates (milliseconds since epoch) + let date_data = vec![ + Some(0), // 1970-01-01 + None, + Some(1704067200000), // 2024-01-01 (milliseconds since epoch) + Some(86400000), // 1970-01-02 + ]; + let date_array = Date64Array::from(date_data); + + let expected_epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + let expected_2024 = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(); + let expected_next_day = NaiveDate::from_ymd_opt(1970, 1, 2).unwrap(); + + test_row_builder_basic( + &date_array, + vec![ + Some(Variant::from(expected_epoch)), + None, + Some(Variant::from(expected_2024)), + Some(Variant::from(expected_next_day)), + ], + ); + } + + #[test] + fn test_time32_second_row_builder() { + use arrow::array::Time32SecondArray; + use chrono::NaiveTime; + + // Test Time32SecondArray with various times (seconds since midnight) + let time_data = vec![ + Some(0), // 00:00:00 + None, + Some(3661), // 01:01:01 + Some(86399), // 23:59:59 + ]; + let time_array = Time32SecondArray::from(time_data); + + let expected_midnight = NaiveTime::from_hms_opt(0, 0, 0).unwrap(); + let expected_time = NaiveTime::from_hms_opt(1, 1, 1).unwrap(); + let expected_last = NaiveTime::from_hms_opt(23, 59, 59).unwrap(); + + test_row_builder_basic( + &time_array, + vec![ + Some(Variant::from(expected_midnight)), + None, + Some(Variant::from(expected_time)), + Some(Variant::from(expected_last)), + ], + ); + } + + #[test] + fn test_time32_millisecond_row_builder() { + use arrow::array::Time32MillisecondArray; + use chrono::NaiveTime; + + // Test Time32MillisecondArray with various times (milliseconds since midnight) + let time_data = vec![ + Some(0), // 00:00:00.000 + None, + Some(3661123), // 01:01:01.123 + Some(86399999), // 23:59:59.999 + ]; + let time_array = Time32MillisecondArray::from(time_data); + + let expected_midnight = NaiveTime::from_hms_milli_opt(0, 0, 0, 0).unwrap(); + let expected_time = NaiveTime::from_hms_milli_opt(1, 1, 1, 123).unwrap(); + let expected_last = NaiveTime::from_hms_milli_opt(23, 59, 59, 999).unwrap(); + + test_row_builder_basic( + &time_array, + vec![ + Some(Variant::from(expected_midnight)), + None, + Some(Variant::from(expected_time)), + Some(Variant::from(expected_last)), + ], + ); + } + + #[test] + fn 
test_time64_microsecond_row_builder() { + use arrow::array::Time64MicrosecondArray; + use chrono::NaiveTime; + + // Test Time64MicrosecondArray with various times (microseconds since midnight) + let time_data = vec![ + Some(0), // 00:00:00.000000 + None, + Some(3661123456), // 01:01:01.123456 + Some(86399999999), // 23:59:59.999999 + ]; + let time_array = Time64MicrosecondArray::from(time_data); + + let expected_midnight = NaiveTime::from_hms_micro_opt(0, 0, 0, 0).unwrap(); + let expected_time = NaiveTime::from_hms_micro_opt(1, 1, 1, 123456).unwrap(); + let expected_last = NaiveTime::from_hms_micro_opt(23, 59, 59, 999999).unwrap(); + + test_row_builder_basic( + &time_array, + vec![ + Some(Variant::from(expected_midnight)), + None, + Some(Variant::from(expected_time)), + Some(Variant::from(expected_last)), + ], + ); + } + + #[test] + fn test_time64_nanosecond_row_builder() { + use arrow::array::Time64NanosecondArray; + use chrono::NaiveTime; + + // Test Time64NanosecondArray with various times (nanoseconds since midnight) + let time_data = vec![ + Some(0), // 00:00:00.000000000 + None, + Some(3661123456789), // 01:01:01.123456789 + Some(86399999999999), // 23:59:59.999999999 + ]; + let time_array = Time64NanosecondArray::from(time_data); + + let expected_midnight = NaiveTime::from_hms_nano_opt(0, 0, 0, 0).unwrap(); + // Nanoseconds are truncated to microsecond precision in Variant + let expected_time = NaiveTime::from_hms_micro_opt(1, 1, 1, 123456).unwrap(); + let expected_last = NaiveTime::from_hms_micro_opt(23, 59, 59, 999999).unwrap(); + + test_row_builder_basic( + &time_array, + vec![ + Some(Variant::from(expected_midnight)), + None, + Some(Variant::from(expected_time)), + Some(Variant::from(expected_last)), + ], + ); + } +} diff --git a/parquet-variant-compute/src/cast_to_variant.rs b/parquet-variant-compute/src/cast_to_variant.rs index 446baf30384c..295019645f62 100644 --- a/parquet-variant-compute/src/cast_to_variant.rs +++ b/parquet-variant-compute/src/cast_to_variant.rs @@ -15,47 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::{VariantArray, VariantArrayBuilder}; -use arrow::array::{Array, AsArray}; -use arrow::datatypes::{ - BinaryType, BinaryViewType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, - Int64Type, Int8Type, LargeBinaryType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, -}; -use arrow_schema::{ArrowError, DataType}; -use half::f16; -use parquet_variant::Variant; - -/// Convert the input array of a specific primitive type to a `VariantArray` -/// row by row -macro_rules! primitive_conversion { - ($t:ty, $input:expr, $builder:expr) => {{ - let array = $input.as_primitive::<$t>(); - for i in 0..array.len() { - if array.is_null(i) { - $builder.append_null(); - continue; - } - $builder.append_variant(Variant::from(array.value(i))); - } - }}; -} - -/// Convert the input array to a `VariantArray` row by row, using `method` -/// to downcast the generic array to a specific array type and `cast_fn` -/// to transform each element to a type compatible with Variant -macro_rules! 
cast_conversion { - ($t:ty, $method:ident, $cast_fn:expr, $input:expr, $builder:expr) => {{ - let array = $input.$method::<$t>(); - for i in 0..array.len() { - if array.is_null(i) { - $builder.append_null(); - continue; - } - let cast_value = $cast_fn(array.value(i)); - $builder.append_variant(Variant::from(cast_value)); - } - }}; -} +use crate::arrow_to_variant::make_arrow_to_variant_row_builder; +use crate::{CastOptions, VariantArray, VariantArrayBuilder}; +use arrow::array::Array; +use arrow_schema::ArrowError; /// Casts a typed arrow [`Array`] to a [`VariantArray`]. This is useful when you /// need to convert a specific data type @@ -80,67 +43,41 @@ macro_rules! cast_conversion { /// assert!(result.is_null(1)); // note null, not Variant::Null /// assert_eq!(result.value(2), Variant::Int64(3)); /// ``` -pub fn cast_to_variant(input: &dyn Array) -> Result { - let mut builder = VariantArrayBuilder::new(input.len()); +/// +/// For `DataType::Timestamp`s: if the timestamp has any level of precision +/// greater than a microsecond, it will be truncated. For example +/// `1970-01-01T00:00:01.234567890Z` +/// will be truncated to +/// `1970-01-01T00:00:01.234567Z` +/// +/// # Arguments +/// * `input` - The array to convert to VariantArray +/// * `options` - Options controlling conversion behavior +pub fn cast_to_variant_with_options( + input: &dyn Array, + options: &CastOptions, +) -> Result { + // Create row builder for the input array type + let mut row_builder = make_arrow_to_variant_row_builder(input.data_type(), input, options)?; - let input_type = input.data_type(); - // todo: handle other types like Boolean, Strings, Date, Timestamp, etc. - match input_type { - DataType::Binary => { - cast_conversion!(BinaryType, as_bytes, |v| v, input, builder); - } - DataType::LargeBinary => { - cast_conversion!(LargeBinaryType, as_bytes, |v| v, input, builder); - } - DataType::BinaryView => { - cast_conversion!(BinaryViewType, as_byte_view, |v| v, input, builder); - } - DataType::Int8 => { - primitive_conversion!(Int8Type, input, builder); - } - DataType::Int16 => { - primitive_conversion!(Int16Type, input, builder); - } - DataType::Int32 => { - primitive_conversion!(Int32Type, input, builder); - } - DataType::Int64 => { - primitive_conversion!(Int64Type, input, builder); - } - DataType::UInt8 => { - primitive_conversion!(UInt8Type, input, builder); - } - DataType::UInt16 => { - primitive_conversion!(UInt16Type, input, builder); - } - DataType::UInt32 => { - primitive_conversion!(UInt32Type, input, builder); - } - DataType::UInt64 => { - primitive_conversion!(UInt64Type, input, builder); - } - DataType::Float16 => { - cast_conversion!( - Float16Type, - as_primitive, - |v: f16| -> f32 { v.into() }, - input, - builder - ); - } - DataType::Float32 => { - primitive_conversion!(Float32Type, input, builder); - } - DataType::Float64 => { - primitive_conversion!(Float64Type, input, builder); - } - dt => { - return Err(ArrowError::CastError(format!( - "Unsupported data type for casting to Variant: {dt:?}", - ))); - } - }; - Ok(builder.build()) + // Create output array builder + let mut array_builder = VariantArrayBuilder::new(input.len()); + + // Process each row using the row builder + for i in 0..input.len() { + row_builder.append_row(&mut array_builder, i)?; + } + + Ok(array_builder.build()) +} + +/// Convert an array to a [`VariantArray`] with strict mode enabled (returns errors on conversion +/// failures). +/// +/// This function provides backward compatibility. 
For non-strict behavior, +/// use [`cast_to_variant_with_options`] with `CastOptions { strict: false }`. +pub fn cast_to_variant(input: &dyn Array) -> Result { + cast_to_variant_with_options(input, &CastOptions::default()) } // TODO do we need a cast_with_options to allow specifying conversion behavior, @@ -151,63 +88,59 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { mod tests { use super::*; use arrow::array::{ - ArrayRef, Float16Array, Float32Array, Float64Array, GenericByteBuilder, - GenericByteViewBuilder, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, - UInt32Array, UInt64Array, UInt8Array, + ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, + Decimal256Array, Decimal32Array, Decimal64Array, DictionaryArray, DurationMicrosecondArray, + DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, + FixedSizeBinaryBuilder, Float16Array, Float32Array, Float64Array, GenericByteBuilder, + GenericByteViewBuilder, Int16Array, Int32Array, Int64Array, Int8Array, + IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeListArray, + LargeStringArray, ListArray, MapArray, NullArray, StringArray, StringRunBuilder, + StringViewArray, StructArray, Time32MillisecondArray, Time32SecondArray, + Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, + TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array, + UInt32Array, UInt64Array, UInt8Array, UnionArray, + }; + use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; + use arrow::datatypes::{ + i256, BinaryType, BinaryViewType, Date32Type, Date64Type, Int32Type, Int64Type, Int8Type, + IntervalDayTime, IntervalMonthDayNano, LargeBinaryType, + }; + use arrow_schema::{DataType, Field, Fields, UnionFields}; + use arrow_schema::{ + DECIMAL128_MAX_PRECISION, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, }; - use parquet_variant::{Variant, VariantDecimal16}; - use std::sync::Arc; + use chrono::{DateTime, NaiveDate, NaiveTime}; + use half::f16; + use parquet_variant::{ + Variant, VariantBuilder, VariantDecimal16, VariantDecimal4, VariantDecimal8, + }; + use std::{sync::Arc, vec}; - #[test] - fn test_cast_to_variant_binary() { - // BinaryType - let mut builder = GenericByteBuilder::::new(); - builder.append_value(b"hello"); - builder.append_value(b""); - builder.append_null(); - builder.append_value(b"world"); - let binary_array = builder.finish(); - run_test( - Arc::new(binary_array), - vec![ - Some(Variant::Binary(b"hello")), - Some(Variant::Binary(b"")), - None, - Some(Variant::Binary(b"world")), - ], - ); + macro_rules! 
max_unscaled_value { + (32, $precision:expr) => { + (u32::pow(10, $precision as u32) - 1) as i32 + }; + (64, $precision:expr) => { + (u64::pow(10, $precision as u32) - 1) as i64 + }; + (128, $precision:expr) => { + (u128::pow(10, $precision as u32) - 1) as i128 + }; + } - // LargeBinaryType - let mut builder = GenericByteBuilder::::new(); - builder.append_value(b"hello"); - builder.append_value(b""); - builder.append_null(); - builder.append_value(b"world"); - let large_binary_array = builder.finish(); - run_test( - Arc::new(large_binary_array), - vec![ - Some(Variant::Binary(b"hello")), - Some(Variant::Binary(b"")), - None, - Some(Variant::Binary(b"world")), - ], - ); + #[test] + fn test_cast_to_variant_null() { + run_test(Arc::new(NullArray::new(2)), vec![None, None]) + } - // BinaryViewType - let mut builder = GenericByteViewBuilder::::new(); - builder.append_value(b"hello"); - builder.append_value(b""); - builder.append_null(); - builder.append_value(b"world"); - let byte_view_array = builder.finish(); + #[test] + fn test_cast_to_variant_bool() { run_test( - Arc::new(byte_view_array), + Arc::new(BooleanArray::from(vec![Some(true), None, Some(false)])), vec![ - Some(Variant::Binary(b"hello")), - Some(Variant::Binary(b"")), + Some(Variant::BooleanTrue), None, - Some(Variant::Binary(b"world")), + Some(Variant::BooleanFalse), ], ); } @@ -441,23 +374,1557 @@ mod tests { ) } - /// Converts the given `Array` to a `VariantArray` and tests the conversion - /// against the expected values. It also tests the handling of nulls by - /// setting one element to null and verifying the output. - fn run_test(values: ArrayRef, expected: Vec>) { - // test without nulls - let variant_array = cast_to_variant(&values).unwrap(); - assert_eq!(variant_array.len(), expected.len()); - for (i, expected_value) in expected.iter().enumerate() { - match expected_value { - Some(value) => { - assert!(!variant_array.is_null(i), "Expected non-null at index {i}"); - assert_eq!(variant_array.value(i), *value, "mismatch at index {i}"); - } - None => { - assert!(variant_array.is_null(i), "Expected null at index {i}"); + #[test] + fn test_cast_to_variant_decimal32() { + run_test( + Arc::new( + Decimal32Array::from(vec![ + Some(i32::MIN), + Some(-max_unscaled_value!(32, DECIMAL32_MAX_PRECISION) - 1), // Overflow value will be cast to Null + Some(-max_unscaled_value!(32, DECIMAL32_MAX_PRECISION)), // The min of Decimal32 with positive scale that can be cast to VariantDecimal4 + None, + Some(-123), + Some(0), + Some(123), + Some(max_unscaled_value!(32, DECIMAL32_MAX_PRECISION)), // The max of Decimal32 with positive scale that can be cast to VariantDecimal4 + Some(max_unscaled_value!(32, DECIMAL32_MAX_PRECISION) + 1), // Overflow value will be cast to Null + Some(i32::MAX), + ]) + .with_precision_and_scale(DECIMAL32_MAX_PRECISION, 3) + .unwrap(), + ), + vec![ + Some(Variant::Null), + Some(Variant::Null), + Some( + VariantDecimal4::try_new(-max_unscaled_value!(32, DECIMAL32_MAX_PRECISION), 3) + .unwrap() + .into(), + ), + None, + Some(VariantDecimal4::try_new(-123, 3).unwrap().into()), + Some(VariantDecimal4::try_new(0, 3).unwrap().into()), + Some(VariantDecimal4::try_new(123, 3).unwrap().into()), + Some( + VariantDecimal4::try_new(max_unscaled_value!(32, DECIMAL32_MAX_PRECISION), 3) + .unwrap() + .into(), + ), + Some(Variant::Null), + Some(Variant::Null), + ], + ) + } + + #[test] + fn test_cast_to_variant_decimal32_negative_scale() { + run_test( + Arc::new( + Decimal32Array::from(vec![ + Some(i32::MIN), + 
Some(-max_unscaled_value!(32, DECIMAL32_MAX_PRECISION - 3) - 1), // Overflow value will be cast to Null + Some(-max_unscaled_value!(32, DECIMAL32_MAX_PRECISION - 3)), // The min of Decimal32 with scale -3 that can be cast to VariantDecimal4 + None, + Some(-123), + Some(0), + Some(123), + Some(max_unscaled_value!(32, DECIMAL32_MAX_PRECISION - 3)), // The max of Decimal32 with scale -3 that can be cast to VariantDecimal4 + Some(max_unscaled_value!(32, DECIMAL32_MAX_PRECISION - 3) + 1), // Overflow value will be cast to Null + Some(i32::MAX), + ]) + .with_precision_and_scale(DECIMAL32_MAX_PRECISION, -3) + .unwrap(), + ), + vec![ + Some(Variant::Null), + Some(Variant::Null), + Some( + VariantDecimal4::try_new( + -max_unscaled_value!(32, DECIMAL32_MAX_PRECISION - 3) * 1000, + 0, + ) + .unwrap() + .into(), + ), + None, + Some(VariantDecimal4::try_new(-123_000, 0).unwrap().into()), + Some(VariantDecimal4::try_new(0, 0).unwrap().into()), + Some(VariantDecimal4::try_new(123_000, 0).unwrap().into()), + Some( + VariantDecimal4::try_new( + max_unscaled_value!(32, DECIMAL32_MAX_PRECISION - 3) * 1000, + 0, + ) + .unwrap() + .into(), + ), + Some(Variant::Null), + Some(Variant::Null), + ], + ) + } + + #[test] + fn test_cast_to_variant_decimal64() { + run_test( + Arc::new( + Decimal64Array::from(vec![ + Some(i64::MIN), + Some(-max_unscaled_value!(64, DECIMAL64_MAX_PRECISION) - 1), // Overflow value will be cast to Null + Some(-max_unscaled_value!(64, DECIMAL64_MAX_PRECISION)), // The min of Decimal64 with positive scale that can be cast to VariantDecimal8 + None, + Some(-123), + Some(0), + Some(123), + Some(max_unscaled_value!(64, DECIMAL64_MAX_PRECISION)), // The max of Decimal64 with positive scale that can be cast to VariantDecimal8 + Some(max_unscaled_value!(64, DECIMAL64_MAX_PRECISION) + 1), // Overflow value will be cast to Null + Some(i64::MAX), + ]) + .with_precision_and_scale(DECIMAL64_MAX_PRECISION, 3) + .unwrap(), + ), + vec![ + Some(Variant::Null), + Some(Variant::Null), + Some( + VariantDecimal8::try_new(-max_unscaled_value!(64, DECIMAL64_MAX_PRECISION), 3) + .unwrap() + .into(), + ), + None, + Some(VariantDecimal8::try_new(-123, 3).unwrap().into()), + Some(VariantDecimal8::try_new(0, 3).unwrap().into()), + Some(VariantDecimal8::try_new(123, 3).unwrap().into()), + Some( + VariantDecimal8::try_new(max_unscaled_value!(64, DECIMAL64_MAX_PRECISION), 3) + .unwrap() + .into(), + ), + Some(Variant::Null), + Some(Variant::Null), + ], + ) + } + + #[test] + fn test_cast_to_variant_decimal64_negative_scale() { + run_test( + Arc::new( + Decimal64Array::from(vec![ + Some(i64::MIN), + Some(-max_unscaled_value!(64, DECIMAL64_MAX_PRECISION - 3) - 1), // Overflow value will be cast to Null + Some(-max_unscaled_value!(64, DECIMAL64_MAX_PRECISION - 3)), // The min of Decimal64 with scale -3 that can be cast to VariantDecimal8 + None, + Some(-123), + Some(0), + Some(123), + Some(max_unscaled_value!(64, DECIMAL64_MAX_PRECISION - 3)), // The max of Decimal64 with scale -3 that can be cast to VariantDecimal8 + Some(max_unscaled_value!(64, DECIMAL64_MAX_PRECISION - 3) + 1), // Overflow value will be cast to Null + Some(i64::MAX), + ]) + .with_precision_and_scale(DECIMAL64_MAX_PRECISION, -3) + .unwrap(), + ), + vec![ + Some(Variant::Null), + Some(Variant::Null), + Some( + VariantDecimal8::try_new( + -max_unscaled_value!(64, DECIMAL64_MAX_PRECISION - 3) * 1000, + 0, + ) + .unwrap() + .into(), + ), + None, + Some(VariantDecimal8::try_new(-123_000, 0).unwrap().into()), + Some(VariantDecimal8::try_new(0, 
0).unwrap().into()), + Some(VariantDecimal8::try_new(123_000, 0).unwrap().into()), + Some( + VariantDecimal8::try_new( + max_unscaled_value!(64, DECIMAL64_MAX_PRECISION - 3) * 1000, + 0, + ) + .unwrap() + .into(), + ), + Some(Variant::Null), + Some(Variant::Null), + ], + ) + } + + #[test] + fn test_cast_to_variant_decimal128() { + run_test( + Arc::new( + Decimal128Array::from(vec![ + Some(i128::MIN), + Some(-max_unscaled_value!(128, DECIMAL128_MAX_PRECISION) - 1), // Overflow value will be cast to Null + Some(-max_unscaled_value!(128, DECIMAL128_MAX_PRECISION)), // The min of Decimal128 with positive scale that can be cast to VariantDecimal16 + None, + Some(-123), + Some(0), + Some(123), + Some(max_unscaled_value!(128, DECIMAL128_MAX_PRECISION)), // The max of Decimal128 with positive scale that can be cast to VariantDecimal16 + Some(max_unscaled_value!(128, DECIMAL128_MAX_PRECISION) + 1), // Overflow value will be cast to Null + Some(i128::MAX), + ]) + .with_precision_and_scale(DECIMAL128_MAX_PRECISION, 3) + .unwrap(), + ), + vec![ + Some(Variant::Null), + Some(Variant::Null), + Some( + VariantDecimal16::try_new( + -max_unscaled_value!(128, DECIMAL128_MAX_PRECISION), + 3, + ) + .unwrap() + .into(), + ), + None, + Some(VariantDecimal16::try_new(-123, 3).unwrap().into()), + Some(VariantDecimal16::try_new(0, 3).unwrap().into()), + Some(VariantDecimal16::try_new(123, 3).unwrap().into()), + Some( + VariantDecimal16::try_new( + max_unscaled_value!(128, DECIMAL128_MAX_PRECISION), + 3, + ) + .unwrap() + .into(), + ), + Some(Variant::Null), + Some(Variant::Null), + ], + ) + } + + #[test] + fn test_cast_to_variant_decimal128_negative_scale() { + run_test( + Arc::new( + Decimal128Array::from(vec![ + Some(i128::MIN), + Some(-max_unscaled_value!(128, DECIMAL128_MAX_PRECISION - 3) - 1), // Overflow value will be cast to Null + Some(-max_unscaled_value!(128, DECIMAL128_MAX_PRECISION - 3)), // The min of Decimal128 with scale -3 that can be cast to VariantDecimal16 + None, + Some(-123), + Some(0), + Some(123), + Some(max_unscaled_value!(128, DECIMAL128_MAX_PRECISION - 3)), // The max of Decimal128 with scale -3 that can be cast to VariantDecimal16 + Some(max_unscaled_value!(128, DECIMAL128_MAX_PRECISION - 3) + 1), // Overflow value will be cast to Null + Some(i128::MAX), + ]) + .with_precision_and_scale(DECIMAL128_MAX_PRECISION, -3) + .unwrap(), + ), + vec![ + Some(Variant::Null), + Some(Variant::Null), + Some( + VariantDecimal16::try_new( + -max_unscaled_value!(128, DECIMAL128_MAX_PRECISION - 3) * 1000, + 0, + ) + .unwrap() + .into(), + ), + None, + Some(VariantDecimal16::try_new(-123_000, 0).unwrap().into()), + Some(VariantDecimal16::try_new(0, 0).unwrap().into()), + Some(VariantDecimal16::try_new(123_000, 0).unwrap().into()), + Some( + VariantDecimal16::try_new( + max_unscaled_value!(128, DECIMAL128_MAX_PRECISION - 3) * 1000, + 0, + ) + .unwrap() + .into(), + ), + Some(Variant::Null), + Some(Variant::Null), + ], + ) + } + + #[test] + fn test_cast_to_variant_decimal256() { + run_test( + Arc::new( + Decimal256Array::from(vec![ + Some(i256::MIN), + Some(i256::from_i128( + -max_unscaled_value!(128, DECIMAL128_MAX_PRECISION) - 1, + )), // Overflow value will be cast to Null + Some(i256::from_i128(-max_unscaled_value!( + 128, + DECIMAL128_MAX_PRECISION + ))), // The min of Decimal256 with positive scale that can be cast to VariantDecimal16 + None, + Some(i256::from_i128(-123)), + Some(i256::from_i128(0)), + Some(i256::from_i128(123)), + Some(i256::from_i128(max_unscaled_value!( + 128, + 
DECIMAL128_MAX_PRECISION + ))), // The max of Decimal256 with positive scale that can be cast to VariantDecimal16 + Some(i256::from_i128( + max_unscaled_value!(128, DECIMAL128_MAX_PRECISION) + 1, + )), // Overflow value will be cast to Null + Some(i256::MAX), + ]) + .with_precision_and_scale(DECIMAL128_MAX_PRECISION, 3) + .unwrap(), + ), + vec![ + Some(Variant::Null), + Some(Variant::Null), + Some( + VariantDecimal16::try_new( + -max_unscaled_value!(128, DECIMAL128_MAX_PRECISION), + 3, + ) + .unwrap() + .into(), + ), + None, + Some(VariantDecimal16::try_new(-123, 3).unwrap().into()), + Some(VariantDecimal16::try_new(0, 3).unwrap().into()), + Some(VariantDecimal16::try_new(123, 3).unwrap().into()), + Some( + VariantDecimal16::try_new( + max_unscaled_value!(128, DECIMAL128_MAX_PRECISION), + 3, + ) + .unwrap() + .into(), + ), + Some(Variant::Null), + Some(Variant::Null), + ], + ) + } + + #[test] + fn test_cast_to_variant_decimal256_negative_scale() { + run_test( + Arc::new( + Decimal256Array::from(vec![ + Some(i256::MIN), + Some(i256::from_i128( + -max_unscaled_value!(128, DECIMAL128_MAX_PRECISION - 3) - 1, + )), // Overflow value will be cast to Null + Some(i256::from_i128(-max_unscaled_value!( + 128, + DECIMAL128_MAX_PRECISION - 3 + ))), // The min of Decimal256 with scale -3 that can be cast to VariantDecimal16 + None, + Some(i256::from_i128(-123)), + Some(i256::from_i128(0)), + Some(i256::from_i128(123)), + Some(i256::from_i128(max_unscaled_value!( + 128, + DECIMAL128_MAX_PRECISION - 3 + ))), // The max of Decimal256 with scale -3 that can be cast to VariantDecimal16 + Some(i256::from_i128( + max_unscaled_value!(128, DECIMAL128_MAX_PRECISION - 3) + 1, + )), // Overflow value will be cast to Null + Some(i256::MAX), + ]) + .with_precision_and_scale(DECIMAL128_MAX_PRECISION, -3) + .unwrap(), + ), + vec![ + Some(Variant::Null), + Some(Variant::Null), + Some( + VariantDecimal16::try_new( + -max_unscaled_value!(128, DECIMAL128_MAX_PRECISION - 3) * 1000, + 0, + ) + .unwrap() + .into(), + ), + None, + Some(VariantDecimal16::try_new(-123_000, 0).unwrap().into()), + Some(VariantDecimal16::try_new(0, 0).unwrap().into()), + Some(VariantDecimal16::try_new(123_000, 0).unwrap().into()), + Some( + VariantDecimal16::try_new( + max_unscaled_value!(128, DECIMAL128_MAX_PRECISION - 3) * 1000, + 0, + ) + .unwrap() + .into(), + ), + Some(Variant::Null), + Some(Variant::Null), + ], + ) + } + + #[test] + fn test_cast_to_variant_timestamp() { + let run_array_tests = + |microseconds: i64, array_ntz: Arc, array_tz: Arc| { + let timestamp = DateTime::from_timestamp_nanos(microseconds * 1000); + run_test( + array_tz, + vec![Some(Variant::TimestampMicros(timestamp)), None], + ); + run_test( + array_ntz, + vec![ + Some(Variant::TimestampNtzMicros(timestamp.naive_utc())), + None, + ], + ); + }; + + let nanosecond = 1234567890; + let microsecond = 1234567; + let millisecond = 1234; + let second = 1; + + let second_array = TimestampSecondArray::from(vec![Some(second), None]); + run_array_tests( + second * 1000 * 1000, + Arc::new(second_array.clone()), + Arc::new(second_array.with_timezone("+01:00".to_string())), + ); + + let millisecond_array = TimestampMillisecondArray::from(vec![Some(millisecond), None]); + run_array_tests( + millisecond * 1000, + Arc::new(millisecond_array.clone()), + Arc::new(millisecond_array.with_timezone("+01:00".to_string())), + ); + + let microsecond_array = TimestampMicrosecondArray::from(vec![Some(microsecond), None]); + run_array_tests( + microsecond, + Arc::new(microsecond_array.clone()), + 
Arc::new(microsecond_array.with_timezone("+01:00".to_string())), + ); + + let timestamp = DateTime::from_timestamp_nanos(nanosecond); + let nanosecond_array = TimestampNanosecondArray::from(vec![Some(nanosecond), None]); + run_test( + Arc::new(nanosecond_array.clone()), + vec![ + Some(Variant::TimestampNtzNanos(timestamp.naive_utc())), + None, + ], + ); + run_test( + Arc::new(nanosecond_array.with_timezone("+01:00".to_string())), + vec![Some(Variant::TimestampNanos(timestamp)), None], + ); + } + + #[test] + fn test_cast_to_variant_date() { + // Date32Array + run_test( + Arc::new(Date32Array::from(vec![ + Some(Date32Type::from_naive_date(NaiveDate::MIN)), + None, + Some(Date32Type::from_naive_date( + NaiveDate::from_ymd_opt(2025, 8, 1).unwrap(), + )), + Some(Date32Type::from_naive_date(NaiveDate::MAX)), + ])), + vec![ + Some(Variant::Date(NaiveDate::MIN)), + None, + Some(Variant::Date(NaiveDate::from_ymd_opt(2025, 8, 1).unwrap())), + Some(Variant::Date(NaiveDate::MAX)), + ], + ); + + // Date64Array + run_test( + Arc::new(Date64Array::from(vec![ + Some(Date64Type::from_naive_date(NaiveDate::MIN)), + None, + Some(Date64Type::from_naive_date( + NaiveDate::from_ymd_opt(2025, 8, 1).unwrap(), + )), + Some(Date64Type::from_naive_date(NaiveDate::MAX)), + ])), + vec![ + Some(Variant::Date(NaiveDate::MIN)), + None, + Some(Variant::Date(NaiveDate::from_ymd_opt(2025, 8, 1).unwrap())), + Some(Variant::Date(NaiveDate::MAX)), + ], + ); + } + + #[test] + fn test_cast_to_variant_time32_second() { + let array: Time32SecondArray = vec![Some(1), Some(86_399), None].into(); + let values = Arc::new(array); + run_test( + values, + vec![ + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(1, 0).unwrap(), + )), + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(86_399, 0).unwrap(), + )), + None, + ], + ) + } + + #[test] + fn test_cast_to_variant_time32_millisecond() { + let array: Time32MillisecondArray = vec![Some(123_456), Some(456_000), None].into(); + let values = Arc::new(array); + run_test( + values, + vec![ + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(123, 456_000_000).unwrap(), + )), + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(456, 0).unwrap(), + )), + None, + ], + ) + } + + #[test] + fn test_cast_to_variant_time64_micro() { + let array: Time64MicrosecondArray = vec![Some(1), Some(123_456_789), None].into(); + let values = Arc::new(array); + run_test( + values, + vec![ + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(0, 1_000).unwrap(), + )), + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(123, 456_789_000).unwrap(), + )), + None, + ], + ) + } + + #[test] + fn test_cast_to_variant_time64_nano() { + let array: Time64NanosecondArray = + vec![Some(1), Some(1001), Some(123_456_789_012), None].into(); + run_test( + Arc::new(array), + // Variant times only carry microsecond precision, so sub-microsecond nanoseconds are truncated (rounded down) + vec![ + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(0, 0).unwrap(), + )), + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(0, 1_000).unwrap(), + )), + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(123, 456_789_000).unwrap(), + )), + None, + ], + ) + } + + #[test] + fn test_cast_to_variant_duration_or_interval_errors() { + let arrays: Vec<Box<dyn Array>> = vec![ + // Duration types + Box::new(DurationSecondArray::from(vec![Some(10), None, Some(-5)])), + Box::new(DurationMillisecondArray::from(vec![ + Some(10), + None, 
+ Some(-5), + ])), + Box::new(DurationMicrosecondArray::from(vec![ + Some(10), + None, + Some(-5), + ])), + Box::new(DurationNanosecondArray::from(vec![ + Some(10), + None, + Some(-5), + ])), + // Interval types + Box::new(IntervalYearMonthArray::from(vec![Some(12), None, Some(-6)])), + Box::new(IntervalDayTimeArray::from(vec![ + Some(IntervalDayTime::new(12, 0)), + None, + Some(IntervalDayTime::new(-6, 0)), + ])), + Box::new(IntervalMonthDayNanoArray::from(vec![ + Some(IntervalMonthDayNano::new(12, 0, 0)), + None, + Some(IntervalMonthDayNano::new(-6, 0, 0)), + ])), + ]; + + for array in arrays { + let result = cast_to_variant(array.as_ref()); + assert!(result.is_err()); + match result.unwrap_err() { + ArrowError::InvalidArgumentError(msg) => { + assert!( + msg.contains("Casting duration/interval types to Variant is not supported") + ); + assert!( + msg.contains("The Variant format does not define duration/interval types") + ); } + _ => panic!("Expected InvalidArgumentError"), } } } + + #[test] + fn test_cast_to_variant_binary() { + // BinaryType + let mut builder = GenericByteBuilder::::new(); + builder.append_value(b"hello"); + builder.append_value(b""); + builder.append_null(); + builder.append_value(b"world"); + let binary_array = builder.finish(); + run_test( + Arc::new(binary_array), + vec![ + Some(Variant::Binary(b"hello")), + Some(Variant::Binary(b"")), + None, + Some(Variant::Binary(b"world")), + ], + ); + + // LargeBinaryType + let mut builder = GenericByteBuilder::::new(); + builder.append_value(b"hello"); + builder.append_value(b""); + builder.append_null(); + builder.append_value(b"world"); + let large_binary_array = builder.finish(); + run_test( + Arc::new(large_binary_array), + vec![ + Some(Variant::Binary(b"hello")), + Some(Variant::Binary(b"")), + None, + Some(Variant::Binary(b"world")), + ], + ); + + // BinaryViewType + let mut builder = GenericByteViewBuilder::::new(); + builder.append_value(b"hello"); + builder.append_value(b""); + builder.append_null(); + builder.append_value(b"world"); + let byte_view_array = builder.finish(); + run_test( + Arc::new(byte_view_array), + vec![ + Some(Variant::Binary(b"hello")), + Some(Variant::Binary(b"")), + None, + Some(Variant::Binary(b"world")), + ], + ); + } + + #[test] + fn test_cast_to_variant_fixed_size_binary() { + let v1 = vec![1, 2]; + let v2 = vec![3, 4]; + let v3 = vec![5, 6]; + + let mut builder = FixedSizeBinaryBuilder::new(2); + builder.append_value(&v1).unwrap(); + builder.append_value(&v2).unwrap(); + builder.append_null(); + builder.append_value(&v3).unwrap(); + let array = builder.finish(); + + run_test( + Arc::new(array), + vec![ + Some(Variant::Binary(&v1)), + Some(Variant::Binary(&v2)), + None, + Some(Variant::Binary(&v3)), + ], + ); + } + + #[test] + fn test_cast_to_variant_utf8() { + // Test with short strings (should become ShortString variants) + let short_strings = vec![Some("hello"), Some(""), None, Some("world"), Some("test")]; + let string_array = StringArray::from(short_strings.clone()); + + run_test( + Arc::new(string_array), + vec![ + Some(Variant::from("hello")), + Some(Variant::from("")), + None, + Some(Variant::from("world")), + Some(Variant::from("test")), + ], + ); + + // Test with a long string (should become String variant) + let long_string = "a".repeat(100); // > 63 bytes, so will be Variant::String + let long_strings = vec![Some(long_string.clone()), None, Some("short".to_string())]; + let string_array = StringArray::from(long_strings); + + run_test( + Arc::new(string_array), + vec![ + 
Some(Variant::from(long_string.as_str())), + None, + Some(Variant::from("short")), + ], + ); + } + + #[test] + fn test_cast_to_variant_large_utf8() { + // Test with short strings (should become ShortString variants) + let short_strings = vec![Some("hello"), Some(""), None, Some("world")]; + let string_array = LargeStringArray::from(short_strings.clone()); + + run_test( + Arc::new(string_array), + vec![ + Some(Variant::from("hello")), + Some(Variant::from("")), + None, + Some(Variant::from("world")), + ], + ); + + // Test with a long string (should become String variant) + let long_string = "b".repeat(100); // > 63 bytes, so will be Variant::String + let long_strings = vec![Some(long_string.clone()), None, Some("short".to_string())]; + let string_array = LargeStringArray::from(long_strings); + + run_test( + Arc::new(string_array), + vec![ + Some(Variant::from(long_string.as_str())), + None, + Some(Variant::from("short")), + ], + ); + } + + #[test] + fn test_cast_to_variant_utf8_view() { + // Test with short strings (should become ShortString variants) + let short_strings = vec![Some("hello"), Some(""), None, Some("world")]; + let string_view_array = StringViewArray::from(short_strings.clone()); + + run_test( + Arc::new(string_view_array), + vec![ + Some(Variant::from("hello")), + Some(Variant::from("")), + None, + Some(Variant::from("world")), + ], + ); + + // Test with a long string (should become String variant) + let long_string = "c".repeat(100); // > 63 bytes, so will be Variant::String + let long_strings = vec![Some(long_string.clone()), None, Some("short".to_string())]; + let string_view_array = StringViewArray::from(long_strings); + + run_test( + Arc::new(string_view_array), + vec![ + Some(Variant::from(long_string.as_str())), + None, + Some(Variant::from("short")), + ], + ); + } + + #[test] + fn test_cast_to_variant_list() { + // List Array + let data = vec![Some(vec![Some(0), Some(1), Some(2)]), None]; + let list_array = ListArray::from_iter_primitive::(data); + + // Expected value + let (metadata, value) = { + let mut builder = VariantBuilder::new(); + let mut list = builder.new_list(); + list.append_value(0); + list.append_value(1); + list.append_value(2); + list.finish(); + builder.finish() + }; + let variant = Variant::new(&metadata, &value); + + run_test(Arc::new(list_array), vec![Some(variant), None]); + } + + #[test] + fn test_cast_to_variant_sliced_list() { + // List Array + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + Some(vec![Some(3), Some(4), Some(5)]), + None, + ]; + let list_array = ListArray::from_iter_primitive::(data); + + // Expected value + let (metadata, value) = { + let mut builder = VariantBuilder::new(); + let mut list = builder.new_list(); + list.append_value(3); + list.append_value(4); + list.append_value(5); + list.finish(); + builder.finish() + }; + let variant = Variant::new(&metadata, &value); + + run_test(Arc::new(list_array.slice(1, 2)), vec![Some(variant), None]); + } + + #[test] + fn test_cast_to_variant_large_list() { + // Large List Array + let data = vec![Some(vec![Some(0), Some(1), Some(2)]), None]; + let large_list_array = LargeListArray::from_iter_primitive::(data); + + // Expected value + let (metadata, value) = { + let mut builder = VariantBuilder::new(); + let mut list = builder.new_list(); + list.append_value(0i64); + list.append_value(1i64); + list.append_value(2i64); + list.finish(); + builder.finish() + }; + let variant = Variant::new(&metadata, &value); + + run_test(Arc::new(large_list_array), vec![Some(variant), 
None]); + } + + #[test] + fn test_cast_to_variant_sliced_large_list() { + // List Array + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + Some(vec![Some(3), Some(4), Some(5)]), + None, + ]; + let large_list_array = ListArray::from_iter_primitive::(data); + + // Expected value + let (metadata, value) = { + let mut builder = VariantBuilder::new(); + let mut list = builder.new_list(); + list.append_value(3i64); + list.append_value(4i64); + list.append_value(5i64); + list.finish(); + builder.finish() + }; + let variant = Variant::new(&metadata, &value); + + run_test( + Arc::new(large_list_array.slice(1, 2)), + vec![Some(variant), None], + ); + } + + #[test] + fn test_cast_to_variant_struct() { + // Test a simple struct with two fields: id (int64) and age (int32) + let id_array = Int64Array::from(vec![Some(1001), Some(1002), None, Some(1003)]); + let age_array = Int32Array::from(vec![Some(25), Some(30), Some(35), None]); + + let fields = Fields::from(vec![ + Field::new("id", DataType::Int64, true), + Field::new("age", DataType::Int32, true), + ]); + + let struct_array = StructArray::new( + fields, + vec![Arc::new(id_array), Arc::new(age_array)], + None, // no nulls at the struct level + ); + + let result = cast_to_variant(&struct_array).unwrap(); + assert_eq!(result.len(), 4); + + // Check first row: {"id": 1001, "age": 25} + let variant1 = result.value(0); + let obj1 = variant1.as_object().unwrap(); + assert_eq!(obj1.get("id"), Some(Variant::from(1001i64))); + assert_eq!(obj1.get("age"), Some(Variant::from(25i32))); + + // Check second row: {"id": 1002, "age": 30} + let variant2 = result.value(1); + let obj2 = variant2.as_object().unwrap(); + assert_eq!(obj2.get("id"), Some(Variant::from(1002i64))); + assert_eq!(obj2.get("age"), Some(Variant::from(30i32))); + + // Check third row: {"age": 35} (id is null, so omitted) + let variant3 = result.value(2); + let obj3 = variant3.as_object().unwrap(); + assert_eq!(obj3.get("id"), None); + assert_eq!(obj3.get("age"), Some(Variant::from(35i32))); + + // Check fourth row: {"id": 1003} (age is null, so omitted) + let variant4 = result.value(3); + let obj4 = variant4.as_object().unwrap(); + assert_eq!(obj4.get("id"), Some(Variant::from(1003i64))); + assert_eq!(obj4.get("age"), None); + } + + #[test] + fn test_cast_to_variant_struct_with_nulls() { + // Test struct with null values at the struct level + let id_array = Int64Array::from(vec![Some(1001), Some(1002)]); + let age_array = Int32Array::from(vec![Some(25), Some(30)]); + + let fields = Fields::from(vec![ + Field::new("id", DataType::Int64, false), + Field::new("age", DataType::Int32, false), + ]); + + // Create null buffer to make second row null + let null_buffer = NullBuffer::from(vec![true, false]); + + let struct_array = StructArray::new( + fields, + vec![Arc::new(id_array), Arc::new(age_array)], + Some(null_buffer), + ); + + let result = cast_to_variant(&struct_array).unwrap(); + assert_eq!(result.len(), 2); + + // Check first row: {"id": 1001, "age": 25} + assert!(!result.is_null(0)); + let variant1 = result.value(0); + let obj1 = variant1.as_object().unwrap(); + assert_eq!(obj1.get("id"), Some(Variant::from(1001i64))); + assert_eq!(obj1.get("age"), Some(Variant::from(25i32))); + + // Check second row: null struct + assert!(result.is_null(1)); + } + + #[test] + fn test_cast_to_variant_struct_performance() { + // Test with a larger struct to demonstrate performance optimization + // This test ensures that field arrays are only converted once, not per row + let size = 1000; + + let 
id_array = Int64Array::from((0..size).map(|i| Some(i as i64)).collect::>()); + let age_array = Int32Array::from( + (0..size) + .map(|i| Some((i % 100) as i32)) + .collect::>(), + ); + let score_array = + Float64Array::from((0..size).map(|i| Some(i as f64 * 0.1)).collect::>()); + + let fields = Fields::from(vec![ + Field::new("id", DataType::Int64, false), + Field::new("age", DataType::Int32, false), + Field::new("score", DataType::Float64, false), + ]); + + let struct_array = StructArray::new( + fields, + vec![ + Arc::new(id_array), + Arc::new(age_array), + Arc::new(score_array), + ], + None, + ); + + let result = cast_to_variant(&struct_array).unwrap(); + assert_eq!(result.len(), size); + + // Verify a few sample rows + let variant0 = result.value(0); + let obj0 = variant0.as_object().unwrap(); + assert_eq!(obj0.get("id"), Some(Variant::from(0i64))); + assert_eq!(obj0.get("age"), Some(Variant::from(0i32))); + assert_eq!(obj0.get("score"), Some(Variant::from(0.0f64))); + + let variant999 = result.value(999); + let obj999 = variant999.as_object().unwrap(); + assert_eq!(obj999.get("id"), Some(Variant::from(999i64))); + assert_eq!(obj999.get("age"), Some(Variant::from(99i32))); // 999 % 100 = 99 + assert_eq!(obj999.get("score"), Some(Variant::from(99.9f64))); + } + + #[test] + fn test_cast_to_variant_struct_performance_large() { + // Test with even larger struct and more fields to demonstrate optimization benefits + let size = 10000; + let num_fields = 10; + + // Create arrays for many fields + let mut field_arrays: Vec = Vec::new(); + let mut fields = Vec::new(); + + for field_idx in 0..num_fields { + match field_idx % 4 { + 0 => { + // Int64 fields + let array = Int64Array::from( + (0..size) + .map(|i| Some(i as i64 + field_idx as i64)) + .collect::>(), + ); + field_arrays.push(Arc::new(array)); + fields.push(Field::new( + format!("int_field_{}", field_idx), + DataType::Int64, + false, + )); + } + 1 => { + // Int32 fields + let array = Int32Array::from( + (0..size) + .map(|i| Some((i % 1000) as i32 + field_idx as i32)) + .collect::>(), + ); + field_arrays.push(Arc::new(array)); + fields.push(Field::new( + format!("int32_field_{}", field_idx), + DataType::Int32, + false, + )); + } + 2 => { + // Float64 fields + let array = Float64Array::from( + (0..size) + .map(|i| Some(i as f64 * 0.1 + field_idx as f64)) + .collect::>(), + ); + field_arrays.push(Arc::new(array)); + fields.push(Field::new( + format!("float_field_{}", field_idx), + DataType::Float64, + false, + )); + } + _ => { + // Binary fields + let binary_data: Vec> = (0..size) + .map(|i| { + // Use static data to avoid lifetime issues in tests + match i % 3 { + 0 => Some(b"test_data_0" as &[u8]), + 1 => Some(b"test_data_1" as &[u8]), + _ => Some(b"test_data_2" as &[u8]), + } + }) + .collect(); + let array = BinaryArray::from(binary_data); + field_arrays.push(Arc::new(array)); + fields.push(Field::new( + format!("binary_field_{}", field_idx), + DataType::Binary, + false, + )); + } + } + } + + let struct_array = StructArray::new(Fields::from(fields), field_arrays, None); + + let result = cast_to_variant(&struct_array).unwrap(); + assert_eq!(result.len(), size); + + // Verify a sample of rows + for sample_idx in [0, size / 4, size / 2, size - 1] { + let variant = result.value(sample_idx); + let obj = variant.as_object().unwrap(); + + // Should have all fields + assert_eq!(obj.len(), num_fields); + + // Verify a few field values + if let Some(int_field_0) = obj.get("int_field_0") { + assert_eq!(int_field_0, Variant::from(sample_idx as i64)); 
+ } + if let Some(float_field_2) = obj.get("float_field_2") { + assert_eq!(float_field_2, Variant::from(sample_idx as f64 * 0.1 + 2.0)); + } + } + } + + #[test] + fn test_cast_to_variant_nested_struct() { + // Test nested struct: person with location struct + let id_array = Int64Array::from(vec![Some(1001), Some(1002)]); + let x_array = Float64Array::from(vec![Some(40.7), Some(37.8)]); + let y_array = Float64Array::from(vec![Some(-74.0), Some(-122.4)]); + + // Create location struct + let location_fields = Fields::from(vec![ + Field::new("x", DataType::Float64, true), + Field::new("y", DataType::Float64, true), + ]); + let location_struct = StructArray::new( + location_fields.clone(), + vec![Arc::new(x_array), Arc::new(y_array)], + None, + ); + + // Create person struct containing location + let person_fields = Fields::from(vec![ + Field::new("id", DataType::Int64, true), + Field::new("location", DataType::Struct(location_fields), true), + ]); + let person_struct = StructArray::new( + person_fields, + vec![Arc::new(id_array), Arc::new(location_struct)], + None, + ); + + let result = cast_to_variant(&person_struct).unwrap(); + assert_eq!(result.len(), 2); + + // Check first row + let variant1 = result.value(0); + let obj1 = variant1.as_object().unwrap(); + assert_eq!(obj1.get("id"), Some(Variant::from(1001i64))); + + let location_variant1 = obj1.get("location").unwrap(); + let location_obj1 = location_variant1.as_object().unwrap(); + assert_eq!(location_obj1.get("x"), Some(Variant::from(40.7f64))); + assert_eq!(location_obj1.get("y"), Some(Variant::from(-74.0f64))); + + // Check second row + let variant2 = result.value(1); + let obj2 = variant2.as_object().unwrap(); + assert_eq!(obj2.get("id"), Some(Variant::from(1002i64))); + + let location_variant2 = obj2.get("location").unwrap(); + let location_obj2 = location_variant2.as_object().unwrap(); + assert_eq!(location_obj2.get("x"), Some(Variant::from(37.8f64))); + assert_eq!(location_obj2.get("y"), Some(Variant::from(-122.4f64))); + } + + #[test] + fn test_cast_to_variant_map() { + let keys = vec!["key1", "key2", "key3"]; + let values_data = Int32Array::from(vec![1, 2, 3]); + let entry_offsets = vec![0, 1, 3]; + let map_array = + MapArray::new_from_strings(keys.clone().into_iter(), &values_data, &entry_offsets) + .unwrap(); + + let result = cast_to_variant(&map_array).unwrap(); + // [{"key1":1}] + let variant1 = result.value(0); + assert_eq!( + variant1.as_object().unwrap().get("key1").unwrap(), + Variant::from(1) + ); + + // [{"key2":2},{"key3":3}] + let variant2 = result.value(1); + assert_eq!( + variant2.as_object().unwrap().get("key2").unwrap(), + Variant::from(2) + ); + assert_eq!( + variant2.as_object().unwrap().get("key3").unwrap(), + Variant::from(3) + ); + } + + #[test] + fn test_cast_to_variant_map_with_nulls_and_empty() { + use arrow::array::{Int32Array, MapArray, StringArray, StructArray}; + use arrow::buffer::{NullBuffer, OffsetBuffer}; + use arrow::datatypes::{DataType, Field, Fields}; + use std::sync::Arc; + + // Create entries struct array + let keys = StringArray::from(vec!["key1", "key2", "key3"]); + let values = Int32Array::from(vec![1, 2, 3]); + let entries_fields = Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ]); + let entries = StructArray::new( + entries_fields.clone(), + vec![Arc::new(keys), Arc::new(values)], + None, + ); + + // Create offsets for 4 maps: [0..1], [1..1], [1..1], [1..3] + let offsets = OffsetBuffer::new(vec![0, 1, 1, 1, 3].into()); + + 
// Create null buffer - map at index 2 is NULL + let null_buffer = Some(NullBuffer::from(vec![true, true, false, true])); + + let map_field = Arc::new(Field::new( + "entries", + DataType::Struct(entries_fields), + false, + )); + + let map_array = MapArray::try_new(map_field, offsets, entries, null_buffer, false).unwrap(); + + let result = cast_to_variant(&map_array).unwrap(); + + // Map 0: {"key1": 1} + let variant0 = result.value(0); + assert_eq!( + variant0.as_object().unwrap().get("key1").unwrap(), + Variant::from(1) + ); + + // Map 1: {} (empty, not null) + let variant1 = result.value(1); + let obj1 = variant1.as_object().unwrap(); + assert_eq!(obj1.len(), 0); // Empty object + + // Map 2: null (actual NULL) + assert!(result.is_null(2)); + + // Map 3: {"key2": 2, "key3": 3} + let variant3 = result.value(3); + assert_eq!( + variant3.as_object().unwrap().get("key2").unwrap(), + Variant::from(2) + ); + assert_eq!( + variant3.as_object().unwrap().get("key3").unwrap(), + Variant::from(3) + ); + } + + #[test] + fn test_cast_to_variant_map_with_non_string_keys() { + let offsets = OffsetBuffer::new(vec![0, 1, 3].into()); + let fields = Fields::from(vec![ + Field::new("key", DataType::Int32, false), + Field::new("values", DataType::Int32, false), + ]); + let columns = vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])) as _, + Arc::new(Int32Array::from(vec![1, 2, 3])) as _, + ]; + + let entries = StructArray::new(fields.clone(), columns, None); + let field = Arc::new(Field::new("entries", DataType::Struct(fields), false)); + + let map_array = MapArray::new(field.clone(), offsets.clone(), entries.clone(), None, false); + + let result = cast_to_variant(&map_array).unwrap(); + + let variant1 = result.value(0); + assert_eq!( + variant1.as_object().unwrap().get("1").unwrap(), + Variant::from(1) + ); + + let variant2 = result.value(1); + assert_eq!( + variant2.as_object().unwrap().get("2").unwrap(), + Variant::from(2) + ); + assert_eq!( + variant2.as_object().unwrap().get("3").unwrap(), + Variant::from(3) + ); + } + + #[test] + fn test_cast_to_variant_union_sparse() { + // Create a sparse union array with mixed types (int, float, string) + let int_array = Int32Array::from(vec![Some(1), None, None, None, Some(34), None]); + let float_array = Float64Array::from(vec![None, Some(3.2), None, Some(32.5), None, None]); + let string_array = StringArray::from(vec![None, None, Some("hello"), None, None, None]); + let type_ids = [0, 1, 2, 1, 0, 0].into_iter().collect::>(); + + let union_fields = UnionFields::new( + vec![0, 1, 2], + vec![ + Field::new("int_field", DataType::Int32, false), + Field::new("float_field", DataType::Float64, false), + Field::new("string_field", DataType::Utf8, false), + ], + ); + + let children: Vec> = vec![ + Arc::new(int_array), + Arc::new(float_array), + Arc::new(string_array), + ]; + + let union_array = UnionArray::try_new( + union_fields, + type_ids, + None, // Sparse union + children, + ) + .unwrap(); + + run_test( + Arc::new(union_array), + vec![ + Some(Variant::Int32(1)), + Some(Variant::Double(3.2)), + Some(Variant::from("hello")), + Some(Variant::Double(32.5)), + Some(Variant::Int32(34)), + None, + ], + ); + } + + #[test] + fn test_cast_to_variant_union_dense() { + // Create a dense union array with mixed types (int, float, string) + let int_array = Int32Array::from(vec![Some(1), Some(34), None]); + let float_array = Float64Array::from(vec![3.2, 32.5]); + let string_array = StringArray::from(vec!["hello"]); + let type_ids = [0, 1, 2, 1, 0, 0].into_iter().collect::>(); + let 
offsets = [0, 0, 0, 1, 1, 2] + .into_iter() + .collect::>(); + + let union_fields = UnionFields::new( + vec![0, 1, 2], + vec![ + Field::new("int_field", DataType::Int32, false), + Field::new("float_field", DataType::Float64, false), + Field::new("string_field", DataType::Utf8, false), + ], + ); + + let children: Vec> = vec![ + Arc::new(int_array), + Arc::new(float_array), + Arc::new(string_array), + ]; + + let union_array = UnionArray::try_new( + union_fields, + type_ids, + Some(offsets), // Dense union + children, + ) + .unwrap(); + + run_test( + Arc::new(union_array), + vec![ + Some(Variant::Int32(1)), + Some(Variant::Double(3.2)), + Some(Variant::from("hello")), + Some(Variant::Double(32.5)), + Some(Variant::Int32(34)), + None, + ], + ); + } + + #[test] + fn test_cast_to_variant_dictionary() { + let values = StringArray::from(vec!["apple", "banana", "cherry", "date"]); + let keys = Int32Array::from(vec![Some(0), Some(1), None, Some(2), Some(0), Some(3)]); + let dict_array = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + run_test( + Arc::new(dict_array), + vec![ + Some(Variant::from("apple")), + Some(Variant::from("banana")), + None, + Some(Variant::from("cherry")), + Some(Variant::from("apple")), + Some(Variant::from("date")), + ], + ); + } + + #[test] + fn test_cast_to_variant_dictionary_with_nulls() { + // Test dictionary with null values in the values array + let values = StringArray::from(vec![Some("a"), None, Some("c")]); + let keys = Int8Array::from(vec![Some(0), Some(1), Some(2), Some(0)]); + let dict_array = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + run_test( + Arc::new(dict_array), + vec![ + Some(Variant::from("a")), + None, // key 1 points to null value + Some(Variant::from("c")), + Some(Variant::from("a")), + ], + ); + } + + #[test] + fn test_cast_to_variant_run_end_encoded() { + let mut builder = StringRunBuilder::::new(); + builder.append_value("apple"); + builder.append_value("apple"); + builder.append_value("banana"); + builder.append_value("banana"); + builder.append_value("banana"); + builder.append_value("cherry"); + let run_array = builder.finish(); + + run_test( + Arc::new(run_array), + vec![ + Some(Variant::from("apple")), + Some(Variant::from("apple")), + Some(Variant::from("banana")), + Some(Variant::from("banana")), + Some(Variant::from("banana")), + Some(Variant::from("cherry")), + ], + ); + } + + #[test] + fn test_cast_to_variant_run_end_encoded_with_nulls() { + use arrow::array::StringRunBuilder; + use arrow::datatypes::Int32Type; + + // Test run-end encoded array with nulls + let mut builder = StringRunBuilder::::new(); + builder.append_value("apple"); + builder.append_null(); + builder.append_value("banana"); + builder.append_value("banana"); + builder.append_null(); + builder.append_null(); + let run_array = builder.finish(); + + run_test( + Arc::new(run_array), + vec![ + Some(Variant::from("apple")), + None, + Some(Variant::from("banana")), + Some(Variant::from("banana")), + None, + None, + ], + ); + } + + /// Converts the given `Array` to a `VariantArray` and tests the conversion + /// against the expected values. It also tests the handling of nulls by + /// setting one element to null and verifying the output. 
+ fn run_test_with_options(values: ArrayRef, expected: Vec>, strict: bool) { + let options = CastOptions { strict }; + let variant_array = cast_to_variant_with_options(&values, &options).unwrap(); + assert_eq!(variant_array.len(), expected.len()); + for (i, expected_value) in expected.iter().enumerate() { + match expected_value { + Some(value) => { + assert!(!variant_array.is_null(i), "Expected non-null at index {i}"); + assert_eq!(variant_array.value(i), *value, "mismatch at index {i}"); + } + None => { + assert!(variant_array.is_null(i), "Expected null at index {i}"); + } + } + } + } + + fn run_test(values: ArrayRef, expected: Vec>) { + run_test_with_options(values, expected, true); + } + + fn run_test_non_strict(values: ArrayRef, expected: Vec>) { + run_test_with_options(values, expected, false); + } + + #[test] + fn test_cast_to_variant_non_strict_mode_date64() { + let date64_values = Date64Array::from(vec![Some(i64::MAX), Some(0), Some(i64::MIN)]); + + let values = Arc::new(date64_values); + run_test_non_strict( + values, + vec![ + None, + Some(Variant::Date(Date64Type::to_naive_date_opt(0).unwrap())), + None, + ], + ); + } + + #[test] + fn test_cast_to_variant_non_strict_mode_time32() { + let time32_array = Time32SecondArray::from(vec![Some(90000), Some(3600), Some(-1)]); + + let values = Arc::new(time32_array); + run_test_non_strict( + values, + vec![ + None, + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(3600, 0).unwrap(), + )), + None, + ], + ); + } + + #[test] + fn test_cast_to_variant_non_strict_mode_timestamp() { + use arrow::temporal_conversions::timestamp_s_to_datetime; + + let ts_array = TimestampSecondArray::from(vec![Some(i64::MAX), Some(0), Some(1609459200)]) + .with_timezone_opt(None::<&str>); + + let values = Arc::new(ts_array); + run_test_non_strict( + values, + vec![ + None, // Invalid timestamp becomes null + Some(Variant::TimestampNtzMicros( + timestamp_s_to_datetime(0).unwrap(), + )), + Some(Variant::TimestampNtzMicros( + timestamp_s_to_datetime(1609459200).unwrap(), + )), + ], + ); + } } diff --git a/parquet-variant-compute/src/from_json.rs b/parquet-variant-compute/src/from_json.rs index a101bf01cfda..0983147132a2 100644 --- a/parquet-variant-compute/src/from_json.rs +++ b/parquet-variant-compute/src/from_json.rs @@ -19,46 +19,61 @@ //! STRUCT use crate::{VariantArray, VariantArrayBuilder}; -use arrow::array::{Array, ArrayRef, StringArray}; +use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray}; use arrow_schema::ArrowError; -use parquet_variant_json::json_to_variant; +use parquet_variant_json::JsonToVariant; + +/// Macro to convert string array to variant array +macro_rules! string_array_to_variant { + ($input:expr, $array:expr, $builder:expr) => {{ + for i in 0..$input.len() { + if $input.is_null(i) { + $builder.append_null(); + } else { + $builder.append_json($array.value(i))?; + } + } + }}; +} /// Parse a batch of JSON strings into a batch of Variants represented as /// STRUCT where nulls are preserved. The JSON strings in the input /// must be valid. 
-pub fn batch_json_string_to_variant(input: &ArrayRef) -> Result { - let input_string_array = match input.as_any().downcast_ref::() { - Some(string_array) => Ok(string_array), - None => Err(ArrowError::CastError( - "Expected reference to StringArray as input".into(), - )), - }?; - - let mut variant_array_builder = VariantArrayBuilder::new(input_string_array.len()); - for i in 0..input.len() { - if input.is_null(i) { - // The subfields are expected to be non-nullable according to the parquet variant spec. - variant_array_builder.append_null(); - } else { - let mut vb = variant_array_builder.variant_builder(); - // parse JSON directly to the variant builder - json_to_variant(input_string_array.value(i), &mut vb)?; - vb.finish() - } +/// +/// Supports the following string array types: +/// - [`StringArray`] +/// - [`LargeStringArray`] +/// - [`StringViewArray`] +pub fn json_to_variant(input: &ArrayRef) -> Result { + let mut variant_array_builder = VariantArrayBuilder::new(input.len()); + + // Try each string array type in sequence + if let Some(string_array) = input.as_any().downcast_ref::() { + string_array_to_variant!(input, string_array, variant_array_builder); + } else if let Some(large_string_array) = input.as_any().downcast_ref::() { + string_array_to_variant!(input, large_string_array, variant_array_builder); + } else if let Some(string_view_array) = input.as_any().downcast_ref::() { + string_array_to_variant!(input, string_view_array, variant_array_builder); + } else { + return Err(ArrowError::CastError( + "Expected reference to StringArray, LargeStringArray, or StringViewArray as input" + .into(), + )); } + Ok(variant_array_builder.build()) } #[cfg(test)] mod test { - use crate::batch_json_string_to_variant; - use arrow::array::{Array, ArrayRef, StringArray}; + use crate::json_to_variant; + use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray}; use arrow_schema::ArrowError; use parquet_variant::{Variant, VariantBuilder}; use std::sync::Arc; #[test] - fn test_batch_json_string_to_variant() -> Result<(), ArrowError> { + fn test_json_to_variant() -> Result<(), ArrowError> { let input = StringArray::from(vec![ Some("1"), None, @@ -67,7 +82,105 @@ mod test { None, ]); let array_ref: ArrayRef = Arc::new(input); - let variant_array = batch_json_string_to_variant(&array_ref).unwrap(); + let variant_array = json_to_variant(&array_ref).unwrap(); + + let metadata_array = variant_array.metadata_field(); + let value_array = variant_array.value_field().expect("value field"); + + // Compare row 0 + assert!(!variant_array.is_null(0)); + assert_eq!(variant_array.value(0), Variant::Int8(1)); + + // Compare row 1 + assert!(variant_array.is_null(1)); + + // Compare row 2 + assert!(!variant_array.is_null(2)); + { + let mut vb = VariantBuilder::new(); + let mut ob = vb.new_object(); + ob.insert("a", Variant::Int8(32)); + ob.finish(); + let (object_metadata, object_value) = vb.finish(); + let expected = Variant::new(&object_metadata, &object_value); + assert_eq!(variant_array.value(2), expected); + } + + // Compare row 3 (Note this is a variant NULL, not a null row) + assert!(!variant_array.is_null(3)); + assert_eq!(variant_array.value(3), Variant::Null); + + // Compare row 4 + assert!(variant_array.is_null(4)); + + // Ensure that the subfields are not nullable + assert!(!metadata_array.is_null(1)); + assert!(!value_array.is_null(1)); + assert!(!metadata_array.is_null(4)); + assert!(!value_array.is_null(4)); + Ok(()) + } + + #[test] + fn test_json_to_variant_large_string() -> 
Result<(), ArrowError> { + let input = LargeStringArray::from(vec![ + Some("1"), + None, + Some("{\"a\": 32}"), + Some("null"), + None, + ]); + let array_ref: ArrayRef = Arc::new(input); + let variant_array = json_to_variant(&array_ref).unwrap(); + + let metadata_array = variant_array.metadata_field(); + let value_array = variant_array.value_field().expect("value field"); + + // Compare row 0 + assert!(!variant_array.is_null(0)); + assert_eq!(variant_array.value(0), Variant::Int8(1)); + + // Compare row 1 + assert!(variant_array.is_null(1)); + + // Compare row 2 + assert!(!variant_array.is_null(2)); + { + let mut vb = VariantBuilder::new(); + let mut ob = vb.new_object(); + ob.insert("a", Variant::Int8(32)); + ob.finish(); + let (object_metadata, object_value) = vb.finish(); + let expected = Variant::new(&object_metadata, &object_value); + assert_eq!(variant_array.value(2), expected); + } + + // Compare row 3 (Note this is a variant NULL, not a null row) + assert!(!variant_array.is_null(3)); + assert_eq!(variant_array.value(3), Variant::Null); + + // Compare row 4 + assert!(variant_array.is_null(4)); + + // Ensure that the subfields are not nullable + assert!(!metadata_array.is_null(1)); + assert!(!value_array.is_null(1)); + assert!(!metadata_array.is_null(4)); + assert!(!value_array.is_null(4)); + Ok(()) + } + + #[test] + fn test_json_to_variant_string_view() -> Result<(), ArrowError> { + let input = StringViewArray::from(vec![ + Some("1"), + None, + Some("{\"a\": 32}"), + Some("null"), + None, + ]); + let array_ref: ArrayRef = Arc::new(input); + let variant_array = json_to_variant(&array_ref).unwrap(); let metadata_array = variant_array.metadata_field(); let value_array = variant_array.value_field().expect("value field"); @@ -85,7 +198,7 @@ mod test { let mut vb = VariantBuilder::new(); let mut ob = vb.new_object(); ob.insert("a", Variant::Int8(32)); - ob.finish()?; + ob.finish(); let (object_metadata, object_value) = vb.finish(); let expected = Variant::new(&object_metadata, &object_value); assert_eq!(variant_array.value(2), expected); diff --git a/parquet-variant-compute/src/lib.rs b/parquet-variant-compute/src/lib.rs index de7fc720be93..999e118367ac 100644 --- a/parquet-variant-compute/src/lib.rs +++ b/parquet-variant-compute/src/lib.rs @@ -20,9 +20,9 @@ //! ## Main APIs //! - [`VariantArray`] : Represents an array of `Variant` values. //! - [`VariantArrayBuilder`]: For building [`VariantArray`] -//! - [`batch_json_string_to_variant`]: Function to convert a batch of JSON strings to a `VariantArray`. -//! - [`batch_variant_to_json_string`]: Function to convert a `VariantArray` to a batch of JSON strings. -//! - [`cast_to_variant`]: Module to cast other Arrow arrays to `VariantArray`. +//! - [`json_to_variant`]: Function to convert a batch of JSON strings to a `VariantArray`. +//! - [`variant_to_json`]: Function to convert a `VariantArray` to a batch of JSON strings. +//! - [`mod@cast_to_variant`]: Module to cast other Arrow arrays to `VariantArray`. //! - [`variant_get`]: Module to get values from a `VariantArray` using a specified [`VariantPath`] //! //! ## 🚧 Work In Progress @@ -35,15 +35,20 @@ //! [`VariantPath`]: parquet_variant::VariantPath //! 
[Variant issue]: https://github.com/apache/arrow-rs/issues/6736 +mod arrow_to_variant; pub mod cast_to_variant; mod from_json; mod to_json; +mod type_conversion; mod variant_array; mod variant_array_builder; pub mod variant_get; +mod variant_to_arrow; pub use variant_array::{ShreddingState, VariantArray}; -pub use variant_array_builder::{VariantArrayBuilder, VariantArrayVariantBuilder}; +pub use variant_array_builder::VariantArrayBuilder; -pub use from_json::batch_json_string_to_variant; -pub use to_json::batch_variant_to_json_string; +pub use cast_to_variant::{cast_to_variant, cast_to_variant_with_options}; +pub use from_json::json_to_variant; +pub use to_json::variant_to_json; +pub use type_conversion::CastOptions; diff --git a/parquet-variant-compute/src/to_json.rs b/parquet-variant-compute/src/to_json.rs index c7c4653ac780..1d6f51ca2446 100644 --- a/parquet-variant-compute/src/to_json.rs +++ b/parquet-variant-compute/src/to_json.rs @@ -23,11 +23,11 @@ use arrow::buffer::{Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow::datatypes::DataType; use arrow_schema::ArrowError; use parquet_variant::Variant; -use parquet_variant_json::variant_to_json; +use parquet_variant_json::VariantToJson; /// Transform a batch of Variant represented as STRUCT to a batch /// of JSON strings where nulls are preserved. The JSON strings in the input must be valid. -pub fn batch_variant_to_json_string(input: &ArrayRef) -> Result { +pub fn variant_to_json(input: &ArrayRef) -> Result { let struct_array = input .as_any() .downcast_ref::() @@ -83,7 +83,7 @@ pub fn batch_variant_to_json_string(input: &ArrayRef) -> Result Result Self { + Self { strict: true } + } +} + +/// Helper trait for converting `Variant` values to arrow primitive values. +pub(crate) trait VariantAsPrimitive { + fn as_primitive(&self) -> Option; +} + +impl VariantAsPrimitive for Variant<'_, '_> { + fn as_primitive(&self) -> Option { + self.as_int32() + } +} +impl VariantAsPrimitive for Variant<'_, '_> { + fn as_primitive(&self) -> Option { + self.as_int16() + } +} +impl VariantAsPrimitive for Variant<'_, '_> { + fn as_primitive(&self) -> Option { + self.as_int8() + } +} +impl VariantAsPrimitive for Variant<'_, '_> { + fn as_primitive(&self) -> Option { + self.as_int64() + } +} +impl VariantAsPrimitive for Variant<'_, '_> { + fn as_primitive(&self) -> Option { + self.as_f16() + } +} +impl VariantAsPrimitive for Variant<'_, '_> { + fn as_primitive(&self) -> Option { + self.as_f32() + } +} +impl VariantAsPrimitive for Variant<'_, '_> { + fn as_primitive(&self) -> Option { + self.as_f64() + } +} + +impl VariantAsPrimitive for Variant<'_, '_> { + fn as_primitive(&self) -> Option { + self.as_u8() + } +} + +impl VariantAsPrimitive for Variant<'_, '_> { + fn as_primitive(&self) -> Option { + self.as_u16() + } +} + +impl VariantAsPrimitive for Variant<'_, '_> { + fn as_primitive(&self) -> Option { + self.as_u32() + } +} + +impl VariantAsPrimitive for Variant<'_, '_> { + fn as_primitive(&self) -> Option { + self.as_u64() + } +} + +/// Convert the value at a specific index in the given array into a `Variant`. +macro_rules! 
non_generic_conversion_single_value { + ($array:expr, $cast_fn:expr, $index:expr) => {{ + let array = $array; + if array.is_null($index) { + Variant::Null + } else { + let cast_value = $cast_fn(array.value($index)); + Variant::from(cast_value) + } + }}; +} +pub(crate) use non_generic_conversion_single_value; + +/// Convert the value at a specific index in the given array into a `Variant`, +/// using `method` requiring a generic type to downcast the generic array +/// to a specific array type and `cast_fn` to transform the element. +macro_rules! generic_conversion_single_value { + ($t:ty, $method:ident, $cast_fn:expr, $input:expr, $index:expr) => {{ + $crate::type_conversion::non_generic_conversion_single_value!( + $input.$method::<$t>(), + $cast_fn, + $index + ) + }}; +} +pub(crate) use generic_conversion_single_value; + +/// Convert the value at a specific index in the given array into a `Variant`. +macro_rules! primitive_conversion_single_value { + ($t:ty, $input:expr, $index:expr) => {{ + $crate::type_conversion::generic_conversion_single_value!( + $t, + as_primitive, + |v| v, + $input, + $index + ) + }}; +} +pub(crate) use primitive_conversion_single_value; + +/// Convert a decimal value to a `VariantDecimal` +macro_rules! decimal_to_variant_decimal { + ($v:ident, $scale:expr, $value_type:ty, $variant_type:ty) => {{ + let (v, scale) = if *$scale < 0 { + // For negative scale, we need to multiply the value by 10^|scale| + // For example: 123 with scale -2 becomes 12300 with scale 0 + let multiplier = <$value_type>::pow(10, (-*$scale) as u32); + (<$value_type>::checked_mul($v, multiplier), 0u8) + } else { + (Some($v), *$scale as u8) + }; + + v.and_then(|v| <$variant_type>::try_new(v, scale).ok()) + .map_or(Variant::Null, Variant::from) + }}; +} +pub(crate) use decimal_to_variant_decimal; diff --git a/parquet-variant-compute/src/variant_array.rs b/parquet-variant-compute/src/variant_array.rs index d51df550622d..4abffa65c23f 100644 --- a/parquet-variant-compute/src/variant_array.rs +++ b/parquet-variant-compute/src/variant_array.rs @@ -17,10 +17,15 @@ //! [`VariantArray`] implementation +use crate::type_conversion::primitive_conversion_single_value; use arrow::array::{Array, ArrayData, ArrayRef, AsArray, BinaryViewArray, StructArray}; use arrow::buffer::NullBuffer; -use arrow::datatypes::Int32Type; -use arrow_schema::{ArrowError, DataType}; +use arrow::datatypes::{ + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, +}; +use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields}; +use parquet_variant::Uuid; use parquet_variant::Variant; use std::any::Any; use std::sync::Arc; @@ -48,6 +53,9 @@ pub struct VariantArray { /// Reference to the underlying StructArray inner: StructArray, + /// The metadata column of this variant + metadata: BinaryViewArray, + /// how is this variant array shredded? 
shredding_state: ShreddingState, } @@ -102,31 +110,56 @@ impl VariantArray { ))); }; - // Find the value field, if present - let value = inner - .column_by_name("value") - .map(|v| { - v.as_binary_view_opt().ok_or_else(|| { - ArrowError::NotYetImplemented(format!( - "VariantArray 'value' field must be BinaryView, got {}", - v.data_type() - )) - }) - }) - .transpose()?; - - // Find the typed_value field, if present - let typed_value = inner.column_by_name("typed_value"); + // Extract value and typed_value fields + let value = if let Some(value_col) = inner.column_by_name("value") { + if let Some(binary_view) = value_col.as_binary_view_opt() { + Some(binary_view.clone()) + } else { + return Err(ArrowError::NotYetImplemented(format!( + "VariantArray 'value' field must be BinaryView, got {}", + value_col.data_type() + ))); + } + } else { + None + }; + let typed_value = inner.column_by_name("typed_value").cloned(); // Note these clones are cheap, they just bump the ref count - let inner = inner.clone(); - let shredding_state = - ShreddingState::try_new(metadata.clone(), value.cloned(), typed_value.cloned())?; - Ok(Self { + inner: inner.clone(), + metadata: metadata.clone(), + shredding_state: ShreddingState::try_new(value, typed_value)?, + }) + } + + pub(crate) fn from_parts( + metadata: BinaryViewArray, + value: Option, + typed_value: Option, + nulls: Option, + ) -> Self { + let mut builder = + StructArrayBuilder::new().with_field("metadata", Arc::new(metadata.clone()), false); + if let Some(value) = value.clone() { + builder = builder.with_field("value", Arc::new(value), true); + } + if let Some(typed_value) = typed_value.clone() { + builder = builder.with_field("typed_value", typed_value, true); + } + if let Some(nulls) = nulls { + builder = builder.with_nulls(nulls); + } + + // This would be a lot simpler if ShreddingState were just a pair of Option... we already + // have everything we need. + let inner = builder.build(); + let shredding_state = ShreddingState::try_new(value, typed_value).unwrap(); // valid by construction + Self { inner, + metadata, shredding_state, - }) + } } /// Returns a reference to the underlying [`StructArray`]. @@ -146,8 +179,8 @@ impl VariantArray { /// Return the [`Variant`] instance stored at the given row /// - /// Consistently with other Arrow arrays types, this API requires you to - /// check for nulls first using [`Self::is_valid`]. + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. /// /// # Panics /// * if the index is out of bounds @@ -166,10 +199,12 @@ impl VariantArray { /// caller to ensure that the metadata and value were constructed correctly. pub fn value(&self, index: usize) -> Variant<'_, '_> { match &self.shredding_state { - ShreddingState::Unshredded { metadata, value } => { - Variant::new(metadata.value(index), value.value(index)) + ShreddingState::Unshredded { value, .. } => { + // Unshredded case + Variant::new(self.metadata.value(index), value.value(index)) } ShreddingState::Typed { typed_value, .. } => { + // Typed case (formerly PerfectlyShredded) if typed_value.is_null(index) { Variant::Null } else { @@ -177,22 +212,129 @@ impl VariantArray { } } ShreddingState::PartiallyShredded { - metadata, - value, - typed_value, + value, typed_value, .. 
} => { + // PartiallyShredded case (formerly ImperfectlyShredded) + if typed_value.is_null(index) { - Variant::new(metadata.value(index), value.value(index)) + Variant::new(self.metadata.value(index), value.value(index)) } else { typed_value_to_variant(typed_value, index) } } + ShreddingState::AllNull => { + // AllNull case: neither value nor typed_value fields exist + // NOTE: This handles the case where neither value nor typed_value fields exist. + // For top-level variants, this returns Variant::Null (JSON null). + // For shredded object fields, this technically should indicate SQL NULL, + // but the current API cannot distinguish these contexts. + Variant::Null + } } } /// Return a reference to the metadata field of the [`StructArray`] pub fn metadata_field(&self) -> &BinaryViewArray { - self.shredding_state.metadata_field() + &self.metadata + } + + /// Return a reference to the value field of the `StructArray` + pub fn value_field(&self) -> Option<&BinaryViewArray> { + self.shredding_state.value_field() + } + + /// Return a reference to the typed_value field of the `StructArray`, if present + pub fn typed_value_field(&self) -> Option<&ArrayRef> { + self.shredding_state.typed_value_field() + } +} + +/// One shredded field of a partially or perfectly shredded variant. For example, suppose the +/// shredding schema for variant `v` treats it as an object with a single field `a`, where `a` is +/// itself a struct with the single field `b` of type INT. Then the physical layout of the column +/// is: +/// +/// ```text +/// v: VARIANT { +/// metadata: BINARY, +/// value: BINARY, +/// typed_value: STRUCT { +/// a: SHREDDED_VARIANT_FIELD { +/// value: BINARY, +/// typed_value: STRUCT { +/// b: SHREDDED_VARIANT_FIELD { +/// value: BINARY, +/// typed_value: INT, +/// }, +/// }, +/// }, +/// }, +/// } +/// ``` +/// +/// In the above, each row of `v.value` is either a variant value (shredding failed, `v` was not an +/// object at all) or a variant object (partial shredding, `v` was an object but included unexpected +/// fields other than `a`), or is NULL (perfect shredding, `v` was an object containing only the +/// single expected field `a`). +/// +/// A similar story unfolds for each `v.typed_value.a.value` -- a variant value if shredding failed +/// (`v:a` was not an object at all), or a variant object (`v:a` was an object with unexpected +/// additional fields), or NULL (`v:a` was an object containing only the single expected field `b`). +/// +/// Finally, `v.typed_value.a.typed_value.b.value` is either NULL (`v:a.b` was an integer) or else a +/// variant value (which could be `Variant::Null`). +#[derive(Debug)] +pub struct ShreddedVariantFieldArray { + /// Reference to the underlying StructArray + inner: StructArray, + shredding_state: ShreddingState, +} + +#[allow(unused)] +impl ShreddedVariantFieldArray { + /// Creates a new `ShreddedVariantFieldArray` from a [`StructArray`]. + /// + /// # Arguments + /// - `inner` - The underlying [`StructArray`] that contains the variant data. + /// + /// # Returns + /// - A new instance of `ShreddedVariantFieldArray`. + /// + /// # Errors: + /// - If the `StructArray` does not contain the required fields + /// + /// # Requirements of the `StructArray` + /// + /// 1. An optional field named `value` that is binary, large_binary, or + /// binary_view + /// + /// 2. 
An optional field named `typed_value` which can be any primitive type + /// or be a list, large_list, list_view or struct + /// + /// Currently, only `value` columns of type [`BinaryViewArray`] are supported. + pub fn try_new(inner: ArrayRef) -> Result { + let Some(inner_struct) = inner.as_struct_opt() else { + return Err(ArrowError::InvalidArgumentError( + "Invalid ShreddedVariantFieldArray: requires StructArray as input".to_string(), + )); + }; + + // Extract value and typed_value fields (metadata is not expected in ShreddedVariantFieldArray) + let value = inner_struct + .column_by_name("value") + .and_then(|col| col.as_binary_view_opt().cloned()); + let typed_value = inner_struct.column_by_name("typed_value").cloned(); + + // Note this clone is cheap, it just bumps the ref count + let inner = inner_struct.clone(); + Ok(Self { + inner: inner.clone(), + shredding_state: ShreddingState::try_new(value, typed_value)?, + }) + } + + /// Return the shredding state of this `VariantArray` + pub fn shredding_state(&self) -> &ShreddingState { + &self.shredding_state } /// Return a reference to the value field of the `StructArray` @@ -204,6 +346,65 @@ impl VariantArray { pub fn typed_value_field(&self) -> Option<&ArrayRef> { self.shredding_state.typed_value_field() } + + /// Returns a reference to the underlying [`StructArray`]. + pub fn inner(&self) -> &StructArray { + &self.inner + } +} + +impl Array for ShreddedVariantFieldArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.inner.to_data() + } + + fn into_data(self) -> ArrayData { + self.inner.into_data() + } + + fn data_type(&self) -> &DataType { + self.inner.data_type() + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + let inner = self.inner.slice(offset, length); + let shredding_state = self.shredding_state.slice(offset, length); + Arc::new(Self { + inner, + shredding_state, + }) + } + + fn len(&self) -> usize { + self.inner.len() + } + + fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + fn offset(&self) -> usize { + self.inner.offset() + } + + fn nulls(&self) -> Option<&NullBuffer> { + // According to the shredding spec, ShreddedVariantFieldArray should be + // physically non-nullable - SQL NULL is inferred by both value and + // typed_value being physically NULL + None + } + + fn get_buffer_memory_size(&self) -> usize { + self.inner.get_buffer_memory_size() + } + + fn get_array_memory_size(&self) -> usize { + self.inner.get_array_memory_size() + } } /// Represents the shredding state of a [`VariantArray`] @@ -226,62 +427,45 @@ impl VariantArray { /// [Parquet Variant Shredding Spec]: https://github.com/apache/parquet-format/blob/master/VariantShredding.md#value-shredding #[derive(Debug)] pub enum ShreddingState { - // TODO: add missing state where there is neither value nor typed_value - // Missing { metadata: BinaryViewArray }, /// This variant has no typed_value field - Unshredded { - metadata: BinaryViewArray, - value: BinaryViewArray, - }, + Unshredded { value: BinaryViewArray }, /// This variant has a typed_value field and no value field /// meaning it is the shredded type - Typed { - metadata: BinaryViewArray, - typed_value: ArrayRef, - }, - /// Partially shredded: - /// * value is an object - /// * typed_value is a shredded object. + Typed { typed_value: ArrayRef }, + /// Imperfectly shredded: Shredded values reside in `typed_value` while those that failed to + /// shred reside in `value`. 
Missing field values are NULL in both columns, while NULL primitive + /// values have NULL `typed_value` and `Variant::Null` in `value`. /// - /// Note the spec says "Writers must not produce data where both value and - /// typed_value are non-null, unless the Variant value is an object." + /// NOTE: A partially shredded struct is a special kind of imperfect shredding, where + /// `typed_value` and `value` are both non-NULL. The `typed_value` is a struct containing the + /// subset of fields for which shredding was attempted (each field will then have its own value + /// and/or typed_value sub-fields that indicate how shredding actually turned out). Meanwhile, + /// the `value` is a variant object containing the subset of fields for which shredding was + /// not even attempted. PartiallyShredded { - metadata: BinaryViewArray, value: BinaryViewArray, typed_value: ArrayRef, }, + /// All values are null, only metadata is present. + /// + /// This state occurs when neither `value` nor `typed_value` fields exist in the schema. + /// Note: By strict spec interpretation, this should only be valid for shredded object fields, + /// not top-level variants. However, we allow it and treat as Variant::Null for pragmatic + /// handling of missing data. + AllNull, } impl ShreddingState { /// try to create a new `ShreddingState` from the given fields pub fn try_new( - metadata: BinaryViewArray, value: Option, typed_value: Option, ) -> Result { - match (metadata, value, typed_value) { - (metadata, Some(value), Some(typed_value)) => Ok(Self::PartiallyShredded { - metadata, - value, - typed_value, - }), - (metadata, Some(value), None) => Ok(Self::Unshredded { metadata, value }), - (metadata, None, Some(typed_value)) => Ok(Self::Typed { - metadata, - typed_value, - }), - (_metadata_field, None, None) => Err(ArrowError::InvalidArgumentError(String::from( - "VariantArray has neither value nor typed_value field", - ))), - } - } - - /// Return a reference to the metadata field - pub fn metadata_field(&self) -> &BinaryViewArray { - match self { - ShreddingState::Unshredded { metadata, .. } => metadata, - ShreddingState::Typed { metadata, .. } => metadata, - ShreddingState::PartiallyShredded { metadata, .. } => metadata, + match (value, typed_value) { + (Some(value), Some(typed_value)) => Ok(Self::PartiallyShredded { value, typed_value }), + (Some(value), None) => Ok(Self::Unshredded { value }), + (None, Some(typed_value)) => Ok(Self::Typed { typed_value }), + (None, None) => Ok(Self::AllNull), } } @@ -291,6 +475,7 @@ impl ShreddingState { ShreddingState::Unshredded { value, .. } => Some(value), ShreddingState::Typed { .. } => None, ShreddingState::PartiallyShredded { value, .. } => Some(value), + ShreddingState::AllNull => None, } } @@ -300,49 +485,138 @@ impl ShreddingState { ShreddingState::Unshredded { .. } => None, ShreddingState::Typed { typed_value, .. } => Some(typed_value), ShreddingState::PartiallyShredded { typed_value, .. 
} => Some(typed_value), + ShreddingState::AllNull => None, } } /// Slice all the underlying arrays pub fn slice(&self, offset: usize, length: usize) -> Self { match self { - ShreddingState::Unshredded { metadata, value } => ShreddingState::Unshredded { - metadata: metadata.slice(offset, length), + ShreddingState::Unshredded { value } => ShreddingState::Unshredded { value: value.slice(offset, length), }, - ShreddingState::Typed { - metadata, - typed_value, - } => ShreddingState::Typed { - metadata: metadata.slice(offset, length), - typed_value: typed_value.slice(offset, length), - }, - ShreddingState::PartiallyShredded { - metadata, - value, - typed_value, - } => ShreddingState::PartiallyShredded { - metadata: metadata.slice(offset, length), - value: value.slice(offset, length), + ShreddingState::Typed { typed_value } => ShreddingState::Typed { typed_value: typed_value.slice(offset, length), }, + ShreddingState::PartiallyShredded { value, typed_value } => { + ShreddingState::PartiallyShredded { + value: value.slice(offset, length), + typed_value: typed_value.slice(offset, length), + } + } + ShreddingState::AllNull => ShreddingState::AllNull, } } } +/// Builds struct arrays from component fields +/// +/// TODO: move to arrow crate +#[derive(Debug, Default, Clone)] +pub(crate) struct StructArrayBuilder { + fields: Vec, + arrays: Vec, + nulls: Option, +} + +impl StructArrayBuilder { + pub fn new() -> Self { + Default::default() + } + + /// Add an array to this struct array as a field with the specified name. + pub fn with_field(mut self, field_name: &str, array: ArrayRef, nullable: bool) -> Self { + let field = Field::new(field_name, array.data_type().clone(), nullable); + self.fields.push(Arc::new(field)); + self.arrays.push(array); + self + } + + /// Set the null buffer for this struct array. 
+ pub fn with_nulls(mut self, nulls: NullBuffer) -> Self { + self.nulls = Some(nulls); + self + } + + pub fn build(self) -> StructArray { + let Self { + fields, + arrays, + nulls, + } = self; + StructArray::new(Fields::from(fields), arrays, nulls) + } +} + /// returns the non-null element at index as a Variant fn typed_value_to_variant(typed_value: &ArrayRef, index: usize) -> Variant<'_, '_> { match typed_value.data_type() { + DataType::Boolean => { + let boolean_array = typed_value.as_boolean(); + let value = boolean_array.value(index); + Variant::from(value) + } + DataType::FixedSizeBinary(binary_len) => { + let array = typed_value.as_fixed_size_binary(); + // Try to treat 16 byte FixedSizeBinary as UUID + let value = array.value(index); + if *binary_len == 16 { + if let Ok(uuid) = Uuid::from_slice(value) { + return Variant::from(uuid); + } + } + let value = array.value(index); + Variant::from(value) + } + DataType::BinaryView => { + let array = typed_value.as_binary_view(); + let value = array.value(index); + Variant::from(value) + } + DataType::Utf8 => { + let array = typed_value.as_string::(); + let value = array.value(index); + Variant::from(value) + } + DataType::Int8 => { + primitive_conversion_single_value!(Int8Type, typed_value, index) + } + DataType::Int16 => { + primitive_conversion_single_value!(Int16Type, typed_value, index) + } DataType::Int32 => { - let typed_value = typed_value.as_primitive::(); - Variant::from(typed_value.value(index)) + primitive_conversion_single_value!(Int32Type, typed_value, index) + } + DataType::Int64 => { + primitive_conversion_single_value!(Int64Type, typed_value, index) + } + DataType::UInt8 => { + primitive_conversion_single_value!(UInt8Type, typed_value, index) + } + DataType::UInt16 => { + primitive_conversion_single_value!(UInt16Type, typed_value, index) + } + DataType::UInt32 => { + primitive_conversion_single_value!(UInt32Type, typed_value, index) + } + DataType::UInt64 => { + primitive_conversion_single_value!(UInt64Type, typed_value, index) + } + DataType::Float16 => { + primitive_conversion_single_value!(Float16Type, typed_value, index) + } + DataType::Float32 => { + primitive_conversion_single_value!(Float32Type, typed_value, index) + } + DataType::Float64 => { + primitive_conversion_single_value!(Float64Type, typed_value, index) } // todo other types here (note this is very similar to cast_to_variant.rs) // so it would be great to figure out how to share this code _ => { // We shouldn't panic in production code, but this is a // placeholder until we implement more types - // TODO tickets: XXXX + // https://github.com/apache/arrow-rs/issues/8091 debug_assert!( false, "Unsupported typed_value type: {:?}", @@ -372,9 +646,11 @@ impl Array for VariantArray { fn slice(&self, offset: usize, length: usize) -> ArrayRef { let inner = self.inner.slice(offset, length); + let metadata = self.metadata.slice(offset, length); let shredding_state = self.shredding_state.slice(offset, length); Arc::new(Self { inner, + metadata, shredding_state, }) } @@ -434,15 +710,27 @@ mod test { } #[test] - fn invalid_missing_value() { + fn all_null_missing_value_and_typed_value() { let fields = Fields::from(vec![Field::new("metadata", DataType::BinaryView, false)]); let array = StructArray::new(fields, vec![make_binary_view_array()], None); - // Should fail because the StructArray does not contain a 'value' field - let err = VariantArray::try_new(Arc::new(array)); - assert_eq!( - err.unwrap_err().to_string(), - "Invalid argument error: VariantArray has neither value 
nor typed_value field" - ); + + // NOTE: By strict spec interpretation, this case (top-level variant with null/null) + // should be invalid, but we currently allow it and treat it as Variant::Null. + // This is a pragmatic decision to handle missing data gracefully. + let variant_array = VariantArray::try_new(Arc::new(array)).unwrap(); + + // Verify the shredding state is AllNull + assert!(matches!( + variant_array.shredding_state(), + ShreddingState::AllNull + )); + + // Verify that value() returns Variant::Null (compensating for spec violation) + for i in 0..variant_array.len() { + if variant_array.is_valid(i) { + assert_eq!(variant_array.value(i), parquet_variant::Variant::Null); + } + } } #[test] @@ -488,4 +776,79 @@ mod test { fn make_binary_array() -> ArrayRef { Arc::new(BinaryArray::from(vec![b"test" as &[u8]])) } + + #[test] + fn all_null_shredding_state() { + let shredding_state = ShreddingState::try_new(None, None).unwrap(); + + // Verify the shredding state is AllNull + assert!(matches!(shredding_state, ShreddingState::AllNull)); + } + + #[test] + fn all_null_variant_array_construction() { + let metadata = BinaryViewArray::from(vec![b"test" as &[u8]; 3]); + let nulls = NullBuffer::from(vec![false, false, false]); // all null + + let fields = Fields::from(vec![Field::new("metadata", DataType::BinaryView, false)]); + let struct_array = StructArray::new(fields, vec![Arc::new(metadata)], Some(nulls)); + + let variant_array = VariantArray::try_new(Arc::new(struct_array)).unwrap(); + + // Verify the shredding state is AllNull + assert!(matches!( + variant_array.shredding_state(), + ShreddingState::AllNull + )); + + // Verify all values are null + assert_eq!(variant_array.len(), 3); + assert!(!variant_array.is_valid(0)); + assert!(!variant_array.is_valid(1)); + assert!(!variant_array.is_valid(2)); + + // Verify that value() returns Variant::Null for all indices + for i in 0..variant_array.len() { + assert!( + !variant_array.is_valid(i), + "Expected value at index {i} to be null" + ); + } + } + + #[test] + fn value_field_present_but_all_null_should_be_unshredded() { + // This test demonstrates the issue: when a value field exists in schema + // but all its values are null, it should remain Unshredded, not AllNull + let metadata = BinaryViewArray::from(vec![b"test" as &[u8]; 3]); + + // Create a value field with all null values + let value_nulls = NullBuffer::from(vec![false, false, false]); // all null + let value_array = BinaryViewArray::from_iter_values(vec![""; 3]); + let value_data = value_array + .to_data() + .into_builder() + .nulls(Some(value_nulls)) + .build() + .unwrap(); + let value = BinaryViewArray::from(value_data); + + let fields = Fields::from(vec![ + Field::new("metadata", DataType::BinaryView, false), + Field::new("value", DataType::BinaryView, true), // Field exists in schema + ]); + let struct_array = StructArray::new( + fields, + vec![Arc::new(metadata), Arc::new(value)], + None, // struct itself is not null, just the value field is all null + ); + + let variant_array = VariantArray::try_new(Arc::new(struct_array)).unwrap(); + + // This should be Unshredded, not AllNull, because value field exists in schema + assert!(matches!( + variant_array.shredding_state(), + ShreddingState::Unshredded { .. 
} + )); + } } diff --git a/parquet-variant-compute/src/variant_array_builder.rs b/parquet-variant-compute/src/variant_array_builder.rs index 36bd6567700b..9779d4a06d4a 100644 --- a/parquet-variant-compute/src/variant_array_builder.rs +++ b/parquet-variant-compute/src/variant_array_builder.rs @@ -19,8 +19,11 @@ use crate::VariantArray; use arrow::array::{ArrayRef, BinaryViewArray, BinaryViewBuilder, NullBufferBuilder, StructArray}; -use arrow_schema::{DataType, Field, Fields}; -use parquet_variant::{ListBuilder, ObjectBuilder, Variant, VariantBuilder, VariantBuilderExt}; +use arrow_schema::{ArrowError, DataType, Field, Fields}; +use parquet_variant::{ + BuilderSpecificState, ListBuilder, MetadataBuilder, ObjectBuilder, Variant, VariantBuilderExt, +}; +use parquet_variant::{ParentState, ValueBuilder, WritableMetadataBuilder}; use std::sync::Arc; /// A builder for [`VariantArray`] @@ -45,13 +48,10 @@ use std::sync::Arc; /// builder.append_variant(Variant::from(42)); /// // append a null row (note not a Variant::Null) /// builder.append_null(); -/// // append an object to the builder -/// let mut vb = builder.variant_builder(); -/// vb.new_object() +/// // append an object to the builder using VariantBuilderExt methods directly +/// builder.new_object() /// .with_field("foo", "bar") -/// .finish() -/// .unwrap(); -/// vb.finish(); // must call finish to write the variant to the buffers +/// .finish(); /// /// // create the final VariantArray /// let variant_array = builder.build(); @@ -72,14 +72,14 @@ use std::sync::Arc; pub struct VariantArrayBuilder { /// Nulls nulls: NullBufferBuilder, - /// buffer for all the metadata - metadata_buffer: Vec, - /// (offset, len) pairs for locations of metadata in the buffer - metadata_locations: Vec<(usize, usize)>, - /// buffer for values - value_buffer: Vec, - /// (offset, len) pairs for locations of values in the buffer - value_locations: Vec<(usize, usize)>, + /// builder for all the metadata + metadata_builder: WritableMetadataBuilder, + /// ending offset for each serialized metadata dictionary in the buffer + metadata_offsets: Vec, + /// builder for values + value_builder: ValueBuilder, + /// ending offset for each serialized variant value in the buffer + value_offsets: Vec, /// The fields of the final `StructArray` /// /// TODO: 1) Add extension type metadata @@ -95,10 +95,10 @@ impl VariantArrayBuilder { Self { nulls: NullBufferBuilder::new(row_capacity), - metadata_buffer: Vec::new(), // todo allocation capacity - metadata_locations: Vec::with_capacity(row_capacity), - value_buffer: Vec::new(), - value_locations: Vec::with_capacity(row_capacity), + metadata_builder: WritableMetadataBuilder::default(), + metadata_offsets: Vec::with_capacity(row_capacity), + value_builder: ValueBuilder::new(), + value_offsets: Vec::with_capacity(row_capacity), fields: Fields::from(vec![metadata_field, value_field]), } } @@ -107,16 +107,18 @@ impl VariantArrayBuilder { pub fn build(self) -> VariantArray { let Self { mut nulls, - metadata_buffer, - metadata_locations, - value_buffer, - value_locations, + metadata_builder, + metadata_offsets, + value_builder, + value_offsets, fields, } = self; - let metadata_array = binary_view_array_from_buffers(metadata_buffer, metadata_locations); + let metadata_buffer = metadata_builder.into_inner(); + let metadata_array = binary_view_array_from_buffers(metadata_buffer, metadata_offsets); - let value_array = binary_view_array_from_buffers(value_buffer, value_locations); + let value_buffer = value_builder.into_inner(); + let 
value_array = binary_view_array_from_buffers(value_buffer, value_offsets); // The build the final struct array let inner = StructArray::new( @@ -136,221 +138,88 @@ impl VariantArrayBuilder { pub fn append_null(&mut self) { self.nulls.append_null(); // The subfields are expected to be non-nullable according to the parquet variant spec. - let metadata_offset = self.metadata_buffer.len(); - let metadata_length = 0; - self.metadata_locations - .push((metadata_offset, metadata_length)); - let value_offset = self.value_buffer.len(); - let value_length = 0; - self.value_locations.push((value_offset, value_length)); + self.metadata_offsets.push(self.metadata_builder.offset()); + self.value_offsets.push(self.value_builder.offset()); } /// Append the [`Variant`] to the builder as the next row pub fn append_variant(&mut self, variant: Variant) { - let mut direct_builder = self.variant_builder(); - direct_builder.variant_builder.append_value(variant); - direct_builder.finish() + ValueBuilder::append_variant(self.parent_state(), variant); } - /// Return a `VariantArrayVariantBuilder` that writes directly to the - /// buffers of this builder. - /// - /// You must call [`VariantArrayVariantBuilder::finish`] to complete the builder - /// - /// # Example - /// ``` - /// # use parquet_variant::{Variant, VariantBuilder, VariantBuilderExt}; - /// # use parquet_variant_compute::{VariantArray, VariantArrayBuilder}; - /// let mut array_builder = VariantArrayBuilder::new(10); - /// - /// // First row has a string - /// let mut variant_builder = array_builder.variant_builder(); - /// variant_builder.append_value("Hello, World!"); - /// // must call finish to write the variant to the buffers - /// variant_builder.finish(); - /// - /// // Second row is an object - /// let mut variant_builder = array_builder.variant_builder(); - /// variant_builder - /// .new_object() - /// .with_field("my_field", 42i64) - /// .finish() - /// .unwrap(); - /// variant_builder.finish(); - /// - /// // finalize the array - /// let variant_array: VariantArray = array_builder.build(); - /// - /// // verify what we wrote is still there - /// assert_eq!(variant_array.value(0), Variant::from("Hello, World!")); - /// assert!(variant_array.value(1).as_object().is_some()); - /// ``` - pub fn variant_builder(&mut self) -> VariantArrayVariantBuilder<'_> { - // append directly into the metadata and value buffers - let metadata_buffer = std::mem::take(&mut self.metadata_buffer); - let value_buffer = std::mem::take(&mut self.value_buffer); - VariantArrayVariantBuilder::new(self, metadata_buffer, value_buffer) + /// Creates a builder-specific parent state + fn parent_state(&mut self) -> ParentState<'_, ArrayBuilderState<'_>> { + let state = ArrayBuilderState { + metadata_offsets: &mut self.metadata_offsets, + value_offsets: &mut self.value_offsets, + nulls: &mut self.nulls, + }; + + ParentState::new(&mut self.value_builder, &mut self.metadata_builder, state) } } -/// A `VariantBuilderExt` that writes directly to the buffers of a `VariantArrayBuilder`. -/// -// This struct implements [`VariantBuilderExt`], so in most cases it can be used as a -// [`VariantBuilder`] to perform variant-related operations for [`VariantArrayBuilder`]. -/// -/// If [`Self::finish`] is not called, any changes will be rolled back -/// -/// See [`VariantArrayBuilder::variant_builder`] for an example -pub struct VariantArrayVariantBuilder<'a> { - /// was finish called? 
- finished: bool, - /// starting offset in the variant_builder's `metadata` buffer - metadata_offset: usize, - /// starting offset in the variant_builder's `value` buffer - value_offset: usize, - /// Parent array builder that this variant builder writes to. Buffers - /// have been moved into the variant builder, and must be returned on - /// drop - array_builder: &'a mut VariantArrayBuilder, - /// Builder for the in progress variant value, temporarily owns the buffers - /// from `array_builder` - variant_builder: VariantBuilder, +/// Builder-specific state for array building that manages array-level offsets and nulls. See +/// [`VariantBuilderExt`] for details. +#[derive(Debug)] +pub struct ArrayBuilderState<'a> { + metadata_offsets: &'a mut Vec, + value_offsets: &'a mut Vec, + nulls: &'a mut NullBufferBuilder, } -impl<'a> VariantBuilderExt for VariantArrayVariantBuilder<'a> { - fn append_value<'m, 'v>(&mut self, value: impl Into>) { - self.variant_builder.append_value(value); - } - - fn new_list(&mut self) -> ListBuilder<'_> { - self.variant_builder.new_list() - } - - fn new_object(&mut self) -> ObjectBuilder<'_> { - self.variant_builder.new_object() +// All changes are pending until finalized +impl BuilderSpecificState for ArrayBuilderState<'_> { + fn finish( + &mut self, + metadata_builder: &mut dyn MetadataBuilder, + value_builder: &mut ValueBuilder, + ) { + self.metadata_offsets.push(metadata_builder.finish()); + self.value_offsets.push(value_builder.offset()); + self.nulls.append_non_null(); } } -impl<'a> VariantArrayVariantBuilder<'a> { - /// Constructs a new VariantArrayVariantBuilder - /// - /// Note this is not public as this is a structure that is logically - /// part of the [`VariantArrayBuilder`] and relies on its internal structure - fn new( - array_builder: &'a mut VariantArrayBuilder, - metadata_buffer: Vec, - value_buffer: Vec, - ) -> Self { - let metadata_offset = metadata_buffer.len(); - let value_offset = value_buffer.len(); - VariantArrayVariantBuilder { - finished: false, - metadata_offset, - value_offset, - variant_builder: VariantBuilder::new_with_buffers(metadata_buffer, value_buffer), - array_builder, - } - } +impl VariantBuilderExt for VariantArrayBuilder { + type State<'a> + = ArrayBuilderState<'a> + where + Self: 'a; - /// Return a reference to the underlying `VariantBuilder` - pub fn inner(&self) -> &VariantBuilder { - &self.variant_builder + /// Appending NULL to a variant array produces an actual NULL value + fn append_null(&mut self) { + self.append_null(); } - /// Return a mutable reference to the underlying `VariantBuilder` - pub fn inner_mut(&mut self) -> &mut VariantBuilder { - &mut self.variant_builder + fn append_value<'m, 'v>(&mut self, value: impl Into>) { + self.append_variant(value.into()); } - /// Called to finish the in progress variant and write it to the underlying - /// buffers - /// - /// Note if you do not call finish, on drop any changes made to the - /// underlying buffers will be rolled back. 
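For comparison with the sub-builder API being removed here, the replacement workflow looks like this: `VariantArrayBuilder` now implements `VariantBuilderExt` itself, so rows are written straight into the array buffers and there is no separate sub-builder `finish()` call to forget. A sketch based on the updated doc example and tests in this patch:

```rust
use arrow::array::Array;
use parquet_variant::{Variant, VariantBuilderExt};
use parquet_variant_compute::VariantArrayBuilder;

fn main() {
    let mut builder = VariantArrayBuilder::new(4);
    builder.append_null();                        // a true NULL row
    builder.append_variant(Variant::from(42i32)); // a primitive row
    // VariantBuilderExt methods write directly; finishing the object/list commits the row
    builder.new_object().with_field("foo", "bar").finish();
    builder.new_list().with_value(Variant::from(1i32)).finish();

    let variant_array = builder.build();
    assert_eq!(variant_array.len(), 4);
    assert!(variant_array.is_null(0));
    assert_eq!(variant_array.value(1), Variant::from(42i32));
}
```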
- pub fn finish(mut self) { - self.finished = true; - - let metadata_offset = self.metadata_offset; - let value_offset = self.value_offset; - // get the buffers back from the variant builder - let (metadata_buffer, value_buffer) = std::mem::take(&mut self.variant_builder).finish(); - - // Sanity Check: if the buffers got smaller, something went wrong (previous data was lost) - let metadata_len = metadata_buffer - .len() - .checked_sub(metadata_offset) - .expect("metadata length decreased unexpectedly"); - let value_len = value_buffer - .len() - .checked_sub(value_offset) - .expect("value length decreased unexpectedly"); - - // commit the changes by putting the - // offsets and lengths into the parent array builder. - self.array_builder - .metadata_locations - .push((metadata_offset, metadata_len)); - self.array_builder - .value_locations - .push((value_offset, value_len)); - self.array_builder.nulls.append_non_null(); - // put the buffers back into the array builder - self.array_builder.metadata_buffer = metadata_buffer; - self.array_builder.value_buffer = value_buffer; + fn try_new_list(&mut self) -> Result>, ArrowError> { + Ok(ListBuilder::new(self.parent_state(), false)) } -} - -impl<'a> Drop for VariantArrayVariantBuilder<'a> { - /// If the builder was not finished, roll back any changes made to the - /// underlying buffers (by truncating them) - fn drop(&mut self) { - if self.finished { - return; - } - - // if the object was not finished, need to rollback any changes by - // truncating the buffers to the original offsets - let metadata_offset = self.metadata_offset; - let value_offset = self.value_offset; - - // get the buffers back from the variant builder - let (mut metadata_buffer, mut value_buffer) = - std::mem::take(&mut self.variant_builder).into_buffers(); - - // Sanity Check: if the buffers got smaller, something went wrong (previous data was lost) so panic immediately - metadata_buffer - .len() - .checked_sub(metadata_offset) - .expect("metadata length decreased unexpectedly"); - value_buffer - .len() - .checked_sub(value_offset) - .expect("value length decreased unexpectedly"); - // Note this truncate is fast because truncate doesn't free any memory: - // it just has to drop elements (and u8 doesn't have a destructor) - metadata_buffer.truncate(metadata_offset); - value_buffer.truncate(value_offset); - - // put the buffers back into the array builder - self.array_builder.metadata_buffer = metadata_buffer; - self.array_builder.value_buffer = value_buffer; + fn try_new_object(&mut self) -> Result>, ArrowError> { + Ok(ObjectBuilder::new(self.parent_state(), false)) } } -fn binary_view_array_from_buffers( - buffer: Vec, - locations: Vec<(usize, usize)>, -) -> BinaryViewArray { - let mut builder = BinaryViewBuilder::with_capacity(locations.len()); +fn binary_view_array_from_buffers(buffer: Vec, offsets: Vec) -> BinaryViewArray { + // All offsets are less than or equal to the buffer length, so we can safely cast all offsets + // inside the loop below, as long as the buffer length fits in u32. 
+ u32::try_from(buffer.len()).expect("buffer length should fit in u32"); + + let mut builder = BinaryViewBuilder::with_capacity(offsets.len()); let block = builder.append_block(buffer.into()); // TODO this can be much faster if it creates the views directly during append - for (offset, length) in locations { - let offset = offset.try_into().expect("offset should fit in u32"); - let length = length.try_into().expect("length should fit in u32"); + let mut start = 0; + for end in offsets { + let end = end as u32; // Safe cast: validated max offset fits in u32 above builder - .try_append_view(block, offset, length) + .try_append_view(block, start, end - start) .expect("Failed to append view"); + start = end; } builder.finish() } @@ -388,30 +257,22 @@ mod test { } } - /// Test using sub builders to append variants + /// Test using appending variants to the array builder #[test] - fn test_variant_array_builder_variant_builder() { + fn test_variant_array_builder() { let mut builder = VariantArrayBuilder::new(10); builder.append_null(); // should not panic builder.append_variant(Variant::from(42i32)); - // let's make a sub-object in the next row - let mut sub_builder = builder.variant_builder(); - sub_builder - .new_object() - .with_field("foo", "bar") - .finish() - .unwrap(); - sub_builder.finish(); // must call finish to write the variant to the buffers + // make an object in the next row + builder.new_object().with_field("foo", "bar").finish(); // append a new list - let mut sub_builder = builder.variant_builder(); - sub_builder + builder .new_list() .with_value(Variant::from(1i32)) .with_value(Variant::from(2i32)) .finish(); - sub_builder.finish(); let variant_array = builder.build(); assert_eq!(variant_array.len(), 4); @@ -427,51 +288,4 @@ mod test { let list = variant.as_list().expect("variant to be a list"); assert_eq!(list.len(), 2); } - - /// Test using non-finished sub builders to append variants - #[test] - fn test_variant_array_builder_variant_builder_reset() { - let mut builder = VariantArrayBuilder::new(10); - - // make a sub-object in the first row - let mut sub_builder = builder.variant_builder(); - sub_builder - .new_object() - .with_field("foo", 1i32) - .finish() - .unwrap(); - sub_builder.finish(); // must call finish to write the variant to the buffers - - // start appending an object but don't finish - let mut sub_builder = builder.variant_builder(); - sub_builder - .new_object() - .with_field("bar", 2i32) - .finish() - .unwrap(); - drop(sub_builder); // drop the sub builder without finishing it - - // make a third sub-object (this should reset the previous unfinished object) - let mut sub_builder = builder.variant_builder(); - sub_builder - .new_object() - .with_field("baz", 3i32) - .finish() - .unwrap(); - sub_builder.finish(); // must call finish to write the variant to the buffers - - let variant_array = builder.build(); - - // only the two finished objects should be present - assert_eq!(variant_array.len(), 2); - assert!(!variant_array.is_null(0)); - let variant = variant_array.value(0); - let variant = variant.as_object().expect("variant to be an object"); - assert_eq!(variant.get("foo").unwrap(), Variant::from(1i32)); - - assert!(!variant_array.is_null(1)); - let variant = variant_array.value(1); - let variant = variant.as_object().expect("variant to be an object"); - assert_eq!(variant.get("baz").unwrap(), Variant::from(3i32)); - } } diff --git a/parquet-variant-compute/src/variant_get.rs b/parquet-variant-compute/src/variant_get.rs new file mode 100644 index 
000000000000..9d32c7f5a613 --- /dev/null +++ b/parquet-variant-compute/src/variant_get.rs @@ -0,0 +1,2761 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +use arrow::{ + array::{self, Array, ArrayRef, BinaryViewArray, StructArray}, + compute::CastOptions, + datatypes::Field, + error::Result, +}; +use arrow_schema::{ArrowError, DataType, FieldRef}; +use parquet_variant::{VariantPath, VariantPathElement}; + +use crate::variant_array::{ShreddedVariantFieldArray, ShreddingState}; +use crate::variant_to_arrow::make_variant_to_arrow_row_builder; +use crate::VariantArray; + +use std::sync::Arc; + +pub(crate) enum ShreddedPathStep<'a> { + /// Path step succeeded, return the new shredding state + Success(&'a ShreddingState), + /// The path element is not present in the `typed_value` column and there is no `value` column, + /// so we know it does not exist. It, and all paths under it, are all-NULL. + Missing, + /// The path element is not present in the `typed_value` column and must be retrieved from the `value` + /// column instead. The caller should be prepared to handle any value, including the requested + /// type, an arbitrary "wrong" type, or `Variant::Null`. + NotShredded, +} + +/// Given a shredded variant field -- a `(value?, typed_value?)` pair -- try to take one path step +/// deeper. For a `VariantPathElement::Field`, the step fails if there is no `typed_value` at this +/// level, or if `typed_value` is not a struct, or if the requested field name does not exist. +/// +/// TODO: Support `VariantPathElement::Index`? It wouldn't be easy, and maybe not even possible. +pub(crate) fn follow_shredded_path_element<'a>( + shredding_state: &'a ShreddingState, + path_element: &VariantPathElement<'_>, + cast_options: &CastOptions, +) -> Result<ShreddedPathStep<'a>> { + // If the requested path element is not present in `typed_value`, and `value` is missing, then + // we know it does not exist; it, and all paths under it, are all-NULL. + let missing_path_step = || { + let Some(_value_field) = shredding_state.value_field() else { + return ShreddedPathStep::Missing; + }; + ShreddedPathStep::NotShredded + }; + + let Some(typed_value) = shredding_state.typed_value_field() else { + return Ok(missing_path_step()); + }; + + match path_element { + VariantPathElement::Field { name } => { + // Try to step into the requested field name of a struct. 
+            // First, try to downcast to StructArray
+            let Some(struct_array) = typed_value.as_any().downcast_ref::<StructArray>() else {
+                // Downcast failure - if strict cast options are enabled, this should be an error
+                if !cast_options.safe {
+                    return Err(ArrowError::CastError(format!(
+                        "Cannot access field '{}' on non-struct type: {}",
+                        name,
+                        typed_value.data_type()
+                    )));
+                }
+                // With safe cast options, return NULL (missing_path_step)
+                return Ok(missing_path_step());
+            };
+
+            // Now try to find the column - missing column in a present struct is just missing data
+            let Some(field) = struct_array.column_by_name(name) else {
+                // Missing column in a present struct is just missing, not wrong - return Ok
+                return Ok(missing_path_step());
+            };
+
+            let field = field
+                .as_any()
+                .downcast_ref::<ShreddedVariantFieldArray>()
+                .ok_or_else(|| {
+                    // TODO: Should we blow up? Or just end the traversal and let the normal
+                    // variant pathing code sort out the mess that it must anyway be
+                    // prepared to handle?
+                    ArrowError::InvalidArgumentError(format!(
+                        "Expected a ShreddedVariantFieldArray, got {:?} instead",
+                        field.data_type(),
+                    ))
+                })?;
+
+            Ok(ShreddedPathStep::Success(field.shredding_state()))
+        }
+        VariantPathElement::Index { .. } => {
+            // TODO: Support array indexing. Among other things, it will require slicing not
+            // only the array we have here, but also the corresponding metadata and null masks.
+            Err(ArrowError::NotYetImplemented(
+                "Pathing into shredded variant array index".into(),
+            ))
+        }
+    }
+}
+
+/// Follows the given path as far as possible through shredded variant fields. If the path ends on a
+/// shredded field, return it directly. Otherwise, use a row shredder to follow the rest of the path
+/// and extract the requested value on a per-row basis.
+fn shredded_get_path(
+    input: &VariantArray,
+    path: &[VariantPathElement<'_>],
+    as_field: Option<&Field>,
+    cast_options: &CastOptions,
+) -> Result<ArrayRef> {
+    // Helper that creates a new VariantArray from the given nested value and typed_value columns,
+    // properly accounting for accumulated nulls from path traversal
+    let make_target_variant =
+        |value: Option<BinaryViewArray>,
+         typed_value: Option<ArrayRef>,
+         accumulated_nulls: Option<arrow::buffer::NullBuffer>| {
+            let metadata = input.metadata_field().clone();
+            VariantArray::from_parts(metadata, value, typed_value, accumulated_nulls)
+        };
+
+    // Helper that shreds a VariantArray to a specific type.
+    let shred_basic_variant =
+        |target: VariantArray, path: VariantPath<'_>, as_field: Option<&Field>| {
+            let as_type = as_field.map(|f| f.data_type());
+            let mut builder =
+                make_variant_to_arrow_row_builder(path, as_type, cast_options, target.len())?;
+            for i in 0..target.len() {
+                if target.is_null(i) {
+                    builder.append_null()?;
+                } else {
+                    builder.append_value(&target.value(i))?;
+                }
+            }
+            builder.finish()
+        };
+
+    // Peel away the prefix of path elements that traverses the shredded parts of this variant
+    // column. Shredding will traverse the rest of the path on a per-row basis.
+    let mut shredding_state = input.shredding_state();
+    let mut accumulated_nulls = input.inner().nulls().cloned();
+    let mut path_index = 0;
+    for path_element in path {
+        match follow_shredded_path_element(shredding_state, path_element, cast_options)? {
+            ShreddedPathStep::Success(state) => {
+                // Union nulls from the typed_value we just accessed
+                if let Some(typed_value) = shredding_state.typed_value_field() {
+                    accumulated_nulls = arrow::buffer::NullBuffer::union(
+                        accumulated_nulls.as_ref(),
+                        typed_value.nulls(),
+                    );
+                }
+                shredding_state = state;
+                path_index += 1;
+                continue;
+            }
+            ShreddedPathStep::Missing => {
+                let num_rows = input.len();
+                let arr = match as_field.map(|f| f.data_type()) {
+                    Some(data_type) => Arc::new(array::new_null_array(data_type, num_rows)) as _,
+                    None => Arc::new(array::NullArray::new(num_rows)) as _,
+                };
+                return Ok(arr);
+            }
+            ShreddedPathStep::NotShredded => {
+                let target = make_target_variant(
+                    shredding_state.value_field().cloned(),
+                    None,
+                    accumulated_nulls,
+                );
+                return shred_basic_variant(target, path[path_index..].into(), as_field);
+            }
+        };
+    }
+
+    // Path exhausted! Create a new `VariantArray` for the location we landed on.
+    let target = make_target_variant(
+        shredding_state.value_field().cloned(),
+        shredding_state.typed_value_field().cloned(),
+        accumulated_nulls,
+    );
+
+    // If our caller did not request any specific type, we can just return whatever we landed on.
+    let Some(as_field) = as_field else {
+        return Ok(Arc::new(target));
+    };
+
+    // Structs are special. Recurse into each field separately, hoping to follow the shredding even
+    // further, and build up the final struct from those individually shredded results.
+    if let DataType::Struct(fields) = as_field.data_type() {
+        let children = fields
+            .iter()
+            .map(|field| {
+                shredded_get_path(
+                    &target,
+                    &[VariantPathElement::from(field.name().as_str())],
+                    Some(field),
+                    cast_options,
+                )
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        let struct_nulls = target.nulls().cloned();
+
+        return Ok(Arc::new(StructArray::try_new(
+            fields.clone(),
+            children,
+            struct_nulls,
+        )?));
+    }
+
+    // Not a struct, so directly shred the variant as the requested type
+    shred_basic_variant(target, VariantPath::default(), Some(as_field))
+}
+
+/// Returns an array with the specified path extracted from the variant values.
+///
+/// The return array type depends on the `as_type` field of the options parameter
+/// 1. `as_type: None`: a VariantArray is returned. The values in this new VariantArray will point
+///    to the specified path.
+/// 2. `as_type: Some(<type>)`: an array of the specified type is returned.
+///
+/// TODO: How would a caller request a struct or list type where the fields/elements can be any
+/// variant? Caller can pass None as the requested type to fetch a specific path, but it would
+/// quickly become annoying (and inefficient) to call `variant_get` for each leaf value in a struct or
+/// list and then try to assemble the results.
+pub fn variant_get(input: &ArrayRef, options: GetOptions) -> Result<ArrayRef> {
+    let variant_array: &VariantArray = input.as_any().downcast_ref().ok_or_else(|| {
+        ArrowError::InvalidArgumentError(
+            "expected a VariantArray as the input for variant_get".to_owned(),
+        )
+    })?;
+
+    let GetOptions {
+        as_type,
+        path,
+        cast_options,
+    } = options;
+
+    shredded_get_path(variant_array, &path, as_type.as_deref(), &cast_options)
+}
+
+/// Controls the action of the variant_get kernel.
+#[derive(Debug, Clone, Default)]
+pub struct GetOptions<'a> {
+    /// What path to extract
+    pub path: VariantPath<'a>,
+    /// if `as_type` is None, the returned array will itself be a VariantArray.
+    ///
+    /// if `as_type` is `Some(type)` the field is returned as the specified type.
+    pub as_type: Option<FieldRef>,
+    /// Controls the casting behavior (e.g. error vs substituting null on cast error).
+    pub cast_options: CastOptions<'a>,
+}
+
+impl<'a> GetOptions<'a> {
+    /// Construct default options to get the specified path as a variant.
+    pub fn new() -> Self {
+        Default::default()
+    }
+
+    /// Construct options to get the specified path as a variant.
+    pub fn new_with_path(path: VariantPath<'a>) -> Self {
+        Self {
+            path,
+            as_type: None,
+            cast_options: Default::default(),
+        }
+    }
+
+    /// Specify the type to return.
+    pub fn with_as_type(mut self, as_type: Option<FieldRef>) -> Self {
+        self.as_type = as_type;
+        self
+    }
+
+    /// Specify the cast options to use when casting to the specified type.
+    pub fn with_cast_options(mut self, cast_options: CastOptions<'a>) -> Self {
+        self.cast_options = cast_options;
+        self
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::sync::Arc;
+
+    use arrow::array::{
+        Array, ArrayRef, BinaryViewArray, Float16Array, Float32Array, Float64Array, Int16Array,
+        Int32Array, Int64Array, Int8Array, StringArray, StructArray, UInt16Array, UInt32Array,
+        UInt64Array, UInt8Array,
+    };
+    use arrow::buffer::NullBuffer;
+    use arrow::compute::CastOptions;
+    use arrow::datatypes::DataType::{Int16, Int32, Int64, UInt16, UInt32, UInt64, UInt8};
+    use arrow_schema::{DataType, Field, FieldRef, Fields};
+    use parquet_variant::{Variant, VariantPath, EMPTY_VARIANT_METADATA_BYTES};
+
+    use crate::json_to_variant;
+    use crate::variant_array::{ShreddedVariantFieldArray, StructArrayBuilder};
+    use crate::VariantArray;
+
+    use super::{variant_get, GetOptions};
+
+    fn single_variant_get_test(input_json: &str, path: VariantPath, expected_json: &str) {
+        // Create input array from JSON string
+        let input_array_ref: ArrayRef = Arc::new(StringArray::from(vec![Some(input_json)]));
+        let input_variant_array_ref: ArrayRef =
+            Arc::new(json_to_variant(&input_array_ref).unwrap());
+
+        let result =
+            variant_get(&input_variant_array_ref, GetOptions::new_with_path(path)).unwrap();
+
+        // Create expected array from JSON string
+        let expected_array_ref: ArrayRef = Arc::new(StringArray::from(vec![Some(expected_json)]));
+        let expected_variant_array = json_to_variant(&expected_array_ref).unwrap();
+
+        let result_array: &VariantArray = result.as_any().downcast_ref().unwrap();
+        assert_eq!(
+            result_array.len(),
+            1,
+            "Expected result array to have length 1"
+        );
+        assert!(
+            result_array.nulls().is_none(),
+            "Expected no nulls in result array"
+        );
+        let result_variant = result_array.value(0);
+        let expected_variant = expected_variant_array.value(0);
+        assert_eq!(
+            result_variant, expected_variant,
+            "Result variant does not match expected variant"
+        );
+    }
+
+    #[test]
+    fn get_primitive_variant_field() {
+        single_variant_get_test(
+            r#"{"some_field": 1234}"#,
+            VariantPath::from("some_field"),
+            "1234",
+        );
+    }
+
+    #[test]
+    fn get_primitive_variant_list_index() {
+        single_variant_get_test("[1234, 5678]", VariantPath::from(0), "1234");
+    }
+
+    #[test]
+    fn get_primitive_variant_inside_object_of_object() {
+        single_variant_get_test(
+            r#"{"top_level_field": {"inner_field": 1234}}"#,
+            VariantPath::from("top_level_field").join("inner_field"),
+            "1234",
+        );
+    }
+
+    #[test]
+    fn get_primitive_variant_inside_list_of_object() {
+        single_variant_get_test(
+            r#"[{"some_field": 1234}]"#,
+            VariantPath::from(0).join("some_field"),
+            "1234",
+        );
+    }
+
+    #[test]
+    fn get_primitive_variant_inside_object_of_list() {
+        single_variant_get_test(
+            r#"{"some_field": [1234]}"#,
+            VariantPath::from("some_field").join(0),
+            "1234",
+        );
+    }
+
+    #[test]
+    fn get_complex_variant() {
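+        // Usage sketch (assumed equivalence, using this module's API): with no `as_type` set,
+        // `variant_get(&input, GetOptions::new_with_path(VariantPath::from("top_level_field")))`
+        // returns a VariantArray whose row 0 is the nested object {"inner_field": 1234}.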
single_variant_get_test( + r#"{"top_level_field": {"inner_field": 1234}}"#, + VariantPath::from("top_level_field"), + r#"{"inner_field": 1234}"#, + ); + } + + /// Partial Shredding: extract a value as a VariantArray + macro_rules! numeric_partially_shredded_test { + ($primitive_type:ty, $data_fn:ident) => { + let array = $data_fn(); + let options = GetOptions::new(); + let result = variant_get(&array, options).unwrap(); + + // expect the result is a VariantArray + let result: &VariantArray = result.as_any().downcast_ref().unwrap(); + assert_eq!(result.len(), 4); + + // Expect the values are the same as the original values + assert_eq!( + result.value(0), + Variant::from(<$primitive_type>::try_from(34u8).unwrap()) + ); + assert!(!result.is_valid(1)); + assert_eq!(result.value(2), Variant::from("n/a")); + assert_eq!( + result.value(3), + Variant::from(<$primitive_type>::try_from(100u8).unwrap()) + ); + }; + } + + #[test] + fn get_variant_partially_shredded_int8_as_variant() { + numeric_partially_shredded_test!(i8, partially_shredded_int8_variant_array); + } + + #[test] + fn get_variant_partially_shredded_int16_as_variant() { + numeric_partially_shredded_test!(i16, partially_shredded_int16_variant_array); + } + + #[test] + fn get_variant_partially_shredded_int32_as_variant() { + numeric_partially_shredded_test!(i32, partially_shredded_int32_variant_array); + } + + #[test] + fn get_variant_partially_shredded_int64_as_variant() { + numeric_partially_shredded_test!(i64, partially_shredded_int64_variant_array); + } + + #[test] + fn get_variant_partially_shredded_uint8_as_variant() { + numeric_partially_shredded_test!(u8, partially_shredded_uint8_variant_array); + } + + #[test] + fn get_variant_partially_shredded_uint16_as_variant() { + numeric_partially_shredded_test!(u16, partially_shredded_uint16_variant_array); + } + + #[test] + fn get_variant_partially_shredded_uint32_as_variant() { + numeric_partially_shredded_test!(u32, partially_shredded_uint32_variant_array); + } + + #[test] + fn get_variant_partially_shredded_uint64_as_variant() { + numeric_partially_shredded_test!(u64, partially_shredded_uint64_variant_array); + } + + #[test] + fn get_variant_partially_shredded_float16_as_variant() { + numeric_partially_shredded_test!(half::f16, partially_shredded_float16_variant_array); + } + + #[test] + fn get_variant_partially_shredded_float32_as_variant() { + numeric_partially_shredded_test!(f32, partially_shredded_float32_variant_array); + } + + #[test] + fn get_variant_partially_shredded_float64_as_variant() { + numeric_partially_shredded_test!(f64, partially_shredded_float64_variant_array); + } + + #[test] + fn get_variant_partially_shredded_bool_as_variant() { + let array = partially_shredded_bool_variant_array(); + let options = GetOptions::new(); + let result = variant_get(&array, options).unwrap(); + + // expect the result is a VariantArray + let result: &VariantArray = result.as_any().downcast_ref().unwrap(); + assert_eq!(result.len(), 4); + + // Expect the values are the same as the original values + assert_eq!(result.value(0), Variant::from(true)); + assert!(!result.is_valid(1)); + assert_eq!(result.value(2), Variant::from("n/a")); + assert_eq!(result.value(3), Variant::from(false)); + } + + #[test] + fn get_variant_partially_shredded_fixed_size_binary_as_variant() { + let array = partially_shredded_fixed_size_binary_variant_array(); + let options = GetOptions::new(); + let result = variant_get(&array, options).unwrap(); + + // expect the result is a VariantArray + let result: 
&VariantArray = result.as_any().downcast_ref().unwrap(); + assert_eq!(result.len(), 4); + + // Expect the values are the same as the original values + assert_eq!(result.value(0), Variant::from(&[1u8, 2u8, 3u8][..])); + assert!(!result.is_valid(1)); + assert_eq!(result.value(2), Variant::from("n/a")); + assert_eq!(result.value(3), Variant::from(&[4u8, 5u8, 6u8][..])); + } + + #[test] + fn get_variant_partially_shredded_utf8_as_variant() { + let array = partially_shredded_utf8_variant_array(); + let options = GetOptions::new(); + let result = variant_get(&array, options).unwrap(); + + // expect the result is a VariantArray + let result: &VariantArray = result.as_any().downcast_ref().unwrap(); + assert_eq!(result.len(), 4); + + // Expect the values are the same as the original values + assert_eq!(result.value(0), Variant::from("hello")); + assert!(!result.is_valid(1)); + assert_eq!(result.value(2), Variant::from("n/a")); + assert_eq!(result.value(3), Variant::from("world")); + } + + #[test] + fn get_variant_partially_shredded_binary_view_as_variant() { + let array = partially_shredded_binary_view_variant_array(); + let options = GetOptions::new(); + let result = variant_get(&array, options).unwrap(); + + // expect the result is a VariantArray + let result: &VariantArray = result.as_any().downcast_ref().unwrap(); + assert_eq!(result.len(), 4); + + // Expect the values are the same as the original values + assert_eq!(result.value(0), Variant::from(&[1u8, 2u8, 3u8][..])); + assert!(!result.is_valid(1)); + assert_eq!(result.value(2), Variant::from("n/a")); + assert_eq!(result.value(3), Variant::from(&[4u8, 5u8, 6u8][..])); + } + + /// Shredding: extract a value as an Int32Array + #[test] + fn get_variant_shredded_int32_as_int32_safe_cast() { + // Extract the typed value as Int32Array + let array = partially_shredded_int32_variant_array(); + // specify we want the typed value as Int32 + let field = Field::new("typed_value", DataType::Int32, true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&array, options).unwrap(); + let expected: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(34), + None, + None, // "n/a" is not an Int32 so converted to null + Some(100), + ])); + assert_eq!(&result, &expected) + } + + /// Shredding: extract a value as an Int32Array, unsafe cast (should error on "n/a") + #[test] + fn get_variant_shredded_int32_as_int32_unsafe_cast() { + // Extract the typed value as Int32Array + let array = partially_shredded_int32_variant_array(); + let field = Field::new("typed_value", DataType::Int32, true); + let cast_options = CastOptions { + safe: false, // unsafe cast + ..Default::default() + }; + let options = GetOptions::new() + .with_as_type(Some(FieldRef::from(field))) + .with_cast_options(cast_options); + + let err = variant_get(&array, options).unwrap_err(); + // TODO make this error message nicer (not Debug format) + assert_eq!(err.to_string(), "Cast error: Failed to extract primitive of type Int32 from variant ShortString(ShortString(\"n/a\")) at path VariantPath([])"); + } + + /// Perfect Shredding: extract the typed value as a VariantArray + macro_rules! 
numeric_perfectly_shredded_test {
+        ($primitive_type:ty, $data_fn:ident) => {
+            let array = $data_fn();
+            let options = GetOptions::new();
+            let result = variant_get(&array, options).unwrap();
+
+            // expect the result is a VariantArray
+            let result: &VariantArray = result.as_any().downcast_ref().unwrap();
+            assert_eq!(result.len(), 3);
+
+            // Expect the values are the same as the original values
+            assert_eq!(
+                result.value(0),
+                Variant::from(<$primitive_type>::try_from(1u8).unwrap())
+            );
+            assert_eq!(
+                result.value(1),
+                Variant::from(<$primitive_type>::try_from(2u8).unwrap())
+            );
+            assert_eq!(
+                result.value(2),
+                Variant::from(<$primitive_type>::try_from(3u8).unwrap())
+            );
+        };
+    }
+
+    #[test]
+    fn get_variant_perfectly_shredded_int8_as_variant() {
+        numeric_perfectly_shredded_test!(i8, perfectly_shredded_int8_variant_array);
+    }
+
+    #[test]
+    fn get_variant_perfectly_shredded_int16_as_variant() {
+        numeric_perfectly_shredded_test!(i16, perfectly_shredded_int16_variant_array);
+    }
+
+    #[test]
+    fn get_variant_perfectly_shredded_int32_as_variant() {
+        numeric_perfectly_shredded_test!(i32, perfectly_shredded_int32_variant_array);
+    }
+
+    #[test]
+    fn get_variant_perfectly_shredded_int64_as_variant() {
+        numeric_perfectly_shredded_test!(i64, perfectly_shredded_int64_variant_array);
+    }
+
+    #[test]
+    fn get_variant_perfectly_shredded_uint8_as_variant() {
+        numeric_perfectly_shredded_test!(u8, perfectly_shredded_uint8_variant_array);
+    }
+
+    #[test]
+    fn get_variant_perfectly_shredded_uint16_as_variant() {
+        numeric_perfectly_shredded_test!(u16, perfectly_shredded_uint16_variant_array);
+    }
+
+    #[test]
+    fn get_variant_perfectly_shredded_uint32_as_variant() {
+        numeric_perfectly_shredded_test!(u32, perfectly_shredded_uint32_variant_array);
+    }
+
+    #[test]
+    fn get_variant_perfectly_shredded_uint64_as_variant() {
+        numeric_perfectly_shredded_test!(u64, perfectly_shredded_uint64_variant_array);
+    }
+
+    #[test]
+    fn get_variant_perfectly_shredded_float16_as_variant() {
+        numeric_perfectly_shredded_test!(half::f16, perfectly_shredded_float16_variant_array);
+    }
+
+    #[test]
+    fn get_variant_perfectly_shredded_float32_as_variant() {
+        numeric_perfectly_shredded_test!(f32, perfectly_shredded_float32_variant_array);
+    }
+
+    #[test]
+    fn get_variant_perfectly_shredded_float64_as_variant() {
+        numeric_perfectly_shredded_test!(f64, perfectly_shredded_float64_variant_array);
+    }
+
+    /// AllNull: extract a value as a VariantArray
+    #[test]
+    fn get_variant_all_null_as_variant() {
+        let array = all_null_variant_array();
+        let options = GetOptions::new();
+        let result = variant_get(&array, options).unwrap();
+
+        // expect the result is a VariantArray
+        let result: &VariantArray = result.as_any().downcast_ref().unwrap();
+        assert_eq!(result.len(), 3);
+
+        // All values should be null
+        assert!(!result.is_valid(0));
+        assert!(!result.is_valid(1));
+        assert!(!result.is_valid(2));
+    }
+
+    /// AllNull: extract a value as an Int32Array
+    #[test]
+    fn get_variant_all_null_as_int32() {
+        let array = all_null_variant_array();
+        // specify we want the typed value as Int32
+        let field = Field::new("typed_value", DataType::Int32, true);
+        let options = GetOptions::new().with_as_type(Some(FieldRef::from(field)));
+        let result = variant_get(&array, options).unwrap();
+
+        let expected: ArrayRef = Arc::new(Int32Array::from(vec![
+            Option::<i32>::None,
+            Option::<i32>::None,
+            Option::<i32>::None,
+        ]));
+        assert_eq!(&result, &expected)
+    }
+
+    macro_rules!
perfectly_shredded_to_arrow_primitive_test { + ($name:ident, $primitive_type:ident, $perfectly_shredded_array_gen_fun:ident, $expected_array:expr) => { + #[test] + fn $name() { + let array = $perfectly_shredded_array_gen_fun(); + let field = Field::new("typed_value", $primitive_type, true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&array, options).unwrap(); + let expected_array: ArrayRef = Arc::new($expected_array); + assert_eq!(&result, &expected_array); + } + }; + } + + perfectly_shredded_to_arrow_primitive_test!( + get_variant_perfectly_shredded_int16_as_int16, + Int16, + perfectly_shredded_int16_variant_array, + Int16Array::from(vec![Some(1), Some(2), Some(3)]) + ); + + perfectly_shredded_to_arrow_primitive_test!( + get_variant_perfectly_shredded_int32_as_int32, + Int32, + perfectly_shredded_int32_variant_array, + Int32Array::from(vec![Some(1), Some(2), Some(3)]) + ); + + perfectly_shredded_to_arrow_primitive_test!( + get_variant_perfectly_shredded_int64_as_int64, + Int64, + perfectly_shredded_int64_variant_array, + Int64Array::from(vec![Some(1), Some(2), Some(3)]) + ); + + perfectly_shredded_to_arrow_primitive_test!( + get_variant_perfectly_shredded_uint8_as_int8, + UInt8, + perfectly_shredded_uint8_variant_array, + UInt8Array::from(vec![Some(1), Some(2), Some(3)]) + ); + + perfectly_shredded_to_arrow_primitive_test!( + get_variant_perfectly_shredded_uint16_as_uint16, + UInt16, + perfectly_shredded_uint16_variant_array, + UInt16Array::from(vec![Some(1), Some(2), Some(3)]) + ); + + perfectly_shredded_to_arrow_primitive_test!( + get_variant_perfectly_shredded_uint32_as_uint32, + UInt32, + perfectly_shredded_uint32_variant_array, + UInt32Array::from(vec![Some(1), Some(2), Some(3)]) + ); + + perfectly_shredded_to_arrow_primitive_test!( + get_variant_perfectly_shredded_uint64_as_uint64, + UInt64, + perfectly_shredded_uint64_variant_array, + UInt64Array::from(vec![Some(1), Some(2), Some(3)]) + ); + + /// Return a VariantArray that represents a perfectly "shredded" variant + /// for the given typed value. + /// + /// The schema of the corresponding `StructArray` would look like this: + /// + /// ```text + /// StructArray { + /// metadata: BinaryViewArray, + /// typed_value: Int32Array, + /// } + /// ``` + macro_rules! numeric_perfectly_shredded_variant_array_fn { + ($func:ident, $array_type:ident, $primitive_type:ty) => { + fn $func() -> ArrayRef { + // At the time of writing, the `VariantArrayBuilder` does not support shredding. + // so we must construct the array manually. 
see https://github.com/apache/arrow-rs/issues/7895 + let metadata = BinaryViewArray::from_iter_values(std::iter::repeat_n( + EMPTY_VARIANT_METADATA_BYTES, + 3, + )); + let typed_value = $array_type::from(vec![ + Some(<$primitive_type>::try_from(1u8).unwrap()), + Some(<$primitive_type>::try_from(2u8).unwrap()), + Some(<$primitive_type>::try_from(3u8).unwrap()), + ]); + + let struct_array = StructArrayBuilder::new() + .with_field("metadata", Arc::new(metadata), false) + .with_field("typed_value", Arc::new(typed_value), true) + .build(); + + Arc::new( + VariantArray::try_new(Arc::new(struct_array)) + .expect("should create variant array"), + ) + } + }; + } + + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_int8_variant_array, + Int8Array, + i8 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_int16_variant_array, + Int16Array, + i16 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_int32_variant_array, + Int32Array, + i32 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_int64_variant_array, + Int64Array, + i64 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_uint8_variant_array, + UInt8Array, + u8 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_uint16_variant_array, + UInt16Array, + u16 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_uint32_variant_array, + UInt32Array, + u32 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_uint64_variant_array, + UInt64Array, + u64 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_float16_variant_array, + Float16Array, + half::f16 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_float32_variant_array, + Float32Array, + f32 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_float64_variant_array, + Float64Array, + f64 + ); + + /// Return a VariantArray that represents a normal "shredded" variant + /// for the following example + /// + /// Based on the example from [the doc] + /// + /// [the doc]: https://docs.google.com/document/d/1pw0AWoMQY3SjD7R4LgbPvMjG_xSCtXp3rZHkVp9jpZ4/edit?tab=t.0 + /// + /// ```text + /// 34 + /// null (an Arrow NULL, not a Variant::Null) + /// "n/a" (a string) + /// 100 + /// ``` + /// + /// The schema of the corresponding `StructArray` would look like this: + /// + /// ```text + /// StructArray { + /// metadata: BinaryViewArray, + /// value: BinaryViewArray, + /// typed_value: Int32Array, + /// } + /// ``` + macro_rules! numeric_partially_shredded_variant_array_fn { + ($func:ident, $array_type:ident, $primitive_type:ty) => { + fn $func() -> ArrayRef { + // At the time of writing, the `VariantArrayBuilder` does not support shredding. + // so we must construct the array manually. see https://github.com/apache/arrow-rs/issues/7895 + let (metadata, string_value) = { + let mut builder = parquet_variant::VariantBuilder::new(); + builder.append_value("n/a"); + builder.finish() + }; + + let nulls = NullBuffer::from(vec![ + true, // row 0 non null + false, // row 1 is null + true, // row 2 non null + true, // row 3 non null + ]); + + // metadata is the same for all rows + let metadata = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 4)); + + // See https://docs.google.com/document/d/1pw0AWoMQY3SjD7R4LgbPvMjG_xSCtXp3rZHkVp9jpZ4/edit?disco=AAABml8WQrY + // about why row1 is an empty but non null, value. 
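+                // Layout sketch of the four rows built below (an assumed reading of this fixture):
+                //   row 0: typed_value = 34    value = NULL          (shredded)
+                //   row 1: top-level NULL      value = empty bytes
+                //   row 2: typed_value = NULL  value = "n/a"         (falls back to `value`)
+                //   row 3: typed_value = 100   value = NULL          (shredded)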
+ let values = BinaryViewArray::from(vec![ + None, // row 0 is shredded, so no value + Some(b"" as &[u8]), // row 1 is null, so empty value (why?) + Some(&string_value), // copy the string value "N/A" + None, // row 3 is shredded, so no value + ]); + + let typed_value = $array_type::from(vec![ + Some(<$primitive_type>::try_from(34u8).unwrap()), // row 0 is shredded, so it has a value + None, // row 1 is null, so no value + None, // row 2 is a string, so no typed value + Some(<$primitive_type>::try_from(100u8).unwrap()), // row 3 is shredded, so it has a value + ]); + + let struct_array = StructArrayBuilder::new() + .with_field("metadata", Arc::new(metadata), false) + .with_field("typed_value", Arc::new(typed_value), true) + .with_field("value", Arc::new(values), true) + .with_nulls(nulls) + .build(); + + Arc::new( + VariantArray::try_new(Arc::new(struct_array)) + .expect("should create variant array"), + ) + } + }; + } + + numeric_partially_shredded_variant_array_fn!( + partially_shredded_int8_variant_array, + Int8Array, + i8 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_int16_variant_array, + Int16Array, + i16 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_int32_variant_array, + Int32Array, + i32 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_int64_variant_array, + Int64Array, + i64 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_uint8_variant_array, + UInt8Array, + u8 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_uint16_variant_array, + UInt16Array, + u16 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_uint32_variant_array, + UInt32Array, + u32 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_uint64_variant_array, + UInt64Array, + u64 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_float16_variant_array, + Float16Array, + half::f16 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_float32_variant_array, + Float32Array, + f32 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_float64_variant_array, + Float64Array, + f64 + ); + + /// Return a VariantArray that represents a partially "shredded" variant for bool + fn partially_shredded_bool_variant_array() -> ArrayRef { + let (metadata, string_value) = { + let mut builder = parquet_variant::VariantBuilder::new(); + builder.append_value("n/a"); + builder.finish() + }; + + let nulls = NullBuffer::from(vec![ + true, // row 0 non null + false, // row 1 is null + true, // row 2 non null + true, // row 3 non null + ]); + + // metadata is the same for all rows + let metadata = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 4)); + + // See https://docs.google.com/document/d/1pw0AWoMQY3SjD7R4LgbPvMjG_xSCtXp3rZHkVp9jpZ4/edit?disco=AAABml8WQrY + // about why row1 is an empty but non null, value. + let values = BinaryViewArray::from(vec![ + None, // row 0 is shredded, so no value + Some(b"" as &[u8]), // row 1 is null, so empty value (why?) 
+ Some(&string_value), // copy the string value "N/A" + None, // row 3 is shredded, so no value + ]); + + let typed_value = arrow::array::BooleanArray::from(vec![ + Some(true), // row 0 is shredded, so it has a value + None, // row 1 is null, so no value + None, // row 2 is a string, so no typed value + Some(false), // row 3 is shredded, so it has a value + ]); + + let struct_array = StructArrayBuilder::new() + .with_field("metadata", Arc::new(metadata), true) + .with_field("typed_value", Arc::new(typed_value), true) + .with_field("value", Arc::new(values), true) + .with_nulls(nulls) + .build(); + + Arc::new( + VariantArray::try_new(Arc::new(struct_array)).expect("should create variant array"), + ) + } + + /// Return a VariantArray that represents a partially "shredded" variant for fixed size binary + fn partially_shredded_fixed_size_binary_variant_array() -> ArrayRef { + let (metadata, string_value) = { + let mut builder = parquet_variant::VariantBuilder::new(); + builder.append_value("n/a"); + builder.finish() + }; + + // Create the null buffer for the overall array + let nulls = NullBuffer::from(vec![ + true, // row 0 non null + false, // row 1 is null + true, // row 2 non null + true, // row 3 non null + ]); + + // metadata is the same for all rows + let metadata = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 4)); + + // See https://docs.google.com/document/d/1pw0AWoMQY3SjD7R4LgbPvMjG_xSCtXp3rZHkVp9jpZ4/edit?disco=AAABml8WQrY + // about why row1 is an empty but non null, value. + let values = BinaryViewArray::from(vec![ + None, // row 0 is shredded, so no value + Some(b"" as &[u8]), // row 1 is null, so empty value + Some(&string_value), // copy the string value "N/A" + None, // row 3 is shredded, so no value + ]); + + // Create fixed size binary array with 3-byte values + let data = vec![ + 1u8, 2u8, 3u8, // row 0 is shredded + 0u8, 0u8, 0u8, // row 1 is null (value doesn't matter) + 0u8, 0u8, 0u8, // row 2 is a string (value doesn't matter) + 4u8, 5u8, 6u8, // row 3 is shredded + ]; + let typed_value_nulls = arrow::buffer::NullBuffer::from(vec![ + true, // row 0 has value + false, // row 1 is null + false, // row 2 is string + true, // row 3 has value + ]); + let typed_value = arrow::array::FixedSizeBinaryArray::try_new( + 3, // byte width + arrow::buffer::Buffer::from(data), + Some(typed_value_nulls), + ) + .expect("should create fixed size binary array"); + + let struct_array = StructArrayBuilder::new() + .with_field("metadata", Arc::new(metadata), true) + .with_field("typed_value", Arc::new(typed_value), true) + .with_field("value", Arc::new(values), true) + .with_nulls(nulls) + .build(); + + Arc::new( + VariantArray::try_new(Arc::new(struct_array)).expect("should create variant array"), + ) + } + + /// Return a VariantArray that represents a partially "shredded" variant for UTF8 + fn partially_shredded_utf8_variant_array() -> ArrayRef { + let (metadata, string_value) = { + let mut builder = parquet_variant::VariantBuilder::new(); + builder.append_value("n/a"); + builder.finish() + }; + + // Create the null buffer for the overall array + let nulls = NullBuffer::from(vec![ + true, // row 0 non null + false, // row 1 is null + true, // row 2 non null + true, // row 3 non null + ]); + + // metadata is the same for all rows + let metadata = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 4)); + + // See https://docs.google.com/document/d/1pw0AWoMQY3SjD7R4LgbPvMjG_xSCtXp3rZHkVp9jpZ4/edit?disco=AAABml8WQrY + // about why row1 is an empty but non 
null, value. + let values = BinaryViewArray::from(vec![ + None, // row 0 is shredded, so no value + Some(b"" as &[u8]), // row 1 is null, so empty value + Some(&string_value), // copy the string value "N/A" + None, // row 3 is shredded, so no value + ]); + + let typed_value = StringArray::from(vec![ + Some("hello"), // row 0 is shredded + None, // row 1 is null + None, // row 2 is a string + Some("world"), // row 3 is shredded + ]); + + let struct_array = StructArrayBuilder::new() + .with_field("metadata", Arc::new(metadata), true) + .with_field("typed_value", Arc::new(typed_value), true) + .with_field("value", Arc::new(values), true) + .with_nulls(nulls) + .build(); + + Arc::new( + VariantArray::try_new(Arc::new(struct_array)).expect("should create variant array"), + ) + } + + /// Return a VariantArray that represents a partially "shredded" variant for BinaryView + fn partially_shredded_binary_view_variant_array() -> ArrayRef { + let (metadata, string_value) = { + let mut builder = parquet_variant::VariantBuilder::new(); + builder.append_value("n/a"); + builder.finish() + }; + + // Create the null buffer for the overall array + let nulls = NullBuffer::from(vec![ + true, // row 0 non null + false, // row 1 is null + true, // row 2 non null + true, // row 3 non null + ]); + + // metadata is the same for all rows + let metadata = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 4)); + + // See https://docs.google.com/document/d/1pw0AWoMQY3SjD7R4LgbPvMjG_xSCtXp3rZHkVp9jpZ4/edit?disco=AAABml8WQrY + // about why row1 is an empty but non null, value. + let values = BinaryViewArray::from(vec![ + None, // row 0 is shredded, so no value + Some(b"" as &[u8]), // row 1 is null, so empty value + Some(&string_value), // copy the string value "N/A" + None, // row 3 is shredded, so no value + ]); + + let typed_value = BinaryViewArray::from(vec![ + Some(&[1u8, 2u8, 3u8][..]), // row 0 is shredded + None, // row 1 is null + None, // row 2 is a string + Some(&[4u8, 5u8, 6u8][..]), // row 3 is shredded + ]); + + let struct_array = StructArrayBuilder::new() + .with_field("metadata", Arc::new(metadata), true) + .with_field("typed_value", Arc::new(typed_value), true) + .with_field("value", Arc::new(values), true) + .with_nulls(nulls) + .build(); + + Arc::new( + VariantArray::try_new(Arc::new(struct_array)).expect("should create variant array"), + ) + } + + /// Return a VariantArray that represents an "all null" variant + /// for the following example (3 null values): + /// + /// ```text + /// null + /// null + /// null + /// ``` + /// + /// The schema of the corresponding `StructArray` would look like this: + /// + /// ```text + /// StructArray { + /// metadata: BinaryViewArray, + /// } + /// ``` + fn all_null_variant_array() -> ArrayRef { + let nulls = NullBuffer::from(vec![ + false, // row 0 is null + false, // row 1 is null + false, // row 2 is null + ]); + + // metadata is the same for all rows (though they're all null) + let metadata = + BinaryViewArray::from_iter_values(std::iter::repeat_n(EMPTY_VARIANT_METADATA_BYTES, 3)); + + let struct_array = StructArrayBuilder::new() + .with_field("metadata", Arc::new(metadata), false) + .with_nulls(nulls) + .build(); + + Arc::new( + VariantArray::try_new(Arc::new(struct_array)).expect("should create variant array"), + ) + } + /// This test manually constructs a shredded variant array representing objects + /// like {"x": 1, "y": "foo"} and {"x": 42} and tests extracting the "x" field + /// as VariantArray using variant_get. 
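+    /// A sketch of the assumed physical layout (the fixture below has the exact construction):
+    ///   metadata:    object metadata shared by both rows
+    ///   value:       {"y": "foo"} for row 0, {} for row 1 (the unshredded remainder)
+    ///   typed_value: Struct { x: ShreddedVariantFieldArray with typed_value = Int32 [1, 42] }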
+ #[test] + fn test_shredded_object_field_access() { + let array = shredded_object_with_x_field_variant_array(); + + // Test: Extract the "x" field as VariantArray first + let options = GetOptions::new_with_path(VariantPath::from("x")); + let result = variant_get(&array, options).unwrap(); + + let result_variant: &VariantArray = result.as_any().downcast_ref().unwrap(); + assert_eq!(result_variant.len(), 2); + + // Row 0: expect x=1 + assert_eq!(result_variant.value(0), Variant::Int32(1)); + // Row 1: expect x=42 + assert_eq!(result_variant.value(1), Variant::Int32(42)); + } + + /// Test extracting shredded object field with type conversion + #[test] + fn test_shredded_object_field_as_int32() { + let array = shredded_object_with_x_field_variant_array(); + + // Test: Extract the "x" field as Int32Array (type conversion) + let field = Field::new("x", DataType::Int32, false); + let options = GetOptions::new_with_path(VariantPath::from("x")) + .with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&array, options).unwrap(); + + // Should get Int32Array + let expected: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(42)])); + assert_eq!(&result, &expected); + } + + /// Helper function to create a shredded variant array representing objects + /// + /// This creates an array that represents: + /// Row 0: {"x": 1, "y": "foo"} (x is shredded, y is in value field) + /// Row 1: {"x": 42} (x is shredded, perfect shredding) + /// + /// The physical layout follows the shredding spec where: + /// - metadata: contains object metadata + /// - typed_value: StructArray with field "x" (ShreddedVariantFieldArray) + /// - value: contains fallback for unshredded fields like {"y": "foo"} + /// - The "x" field has typed_value=Int32Array and value=NULL (perfect shredding) + fn shredded_object_with_x_field_variant_array() -> ArrayRef { + // Create the base metadata for objects + let (metadata, y_field_value) = { + let mut builder = parquet_variant::VariantBuilder::new(); + let mut obj = builder.new_object(); + obj.insert("x", Variant::Int32(42)); + obj.insert("y", Variant::from("foo")); + obj.finish(); + builder.finish() + }; + + // Create metadata array (same for both rows) + let metadata_array = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 2)); + + // Create the main value field per the 3-step shredding spec: + // Step 2: If field not in shredding schema, check value field + // Row 0: {"y": "foo"} (y is not shredded, stays in value for step 2) + // Row 1: {} (empty object - no unshredded fields) + let empty_object_value = { + let mut builder = parquet_variant::VariantBuilder::new(); + let obj = builder.new_object(); + obj.finish(); + let (_, value) = builder.finish(); + value + }; + + let value_array = BinaryViewArray::from(vec![ + Some(y_field_value.as_slice()), // Row 0 has {"y": "foo"} + Some(empty_object_value.as_slice()), // Row 1 has {} + ]); + + // Create the "x" field as a ShreddedVariantFieldArray + // This represents the shredded Int32 values for the "x" field + let x_field_typed_value = Int32Array::from(vec![Some(1), Some(42)]); + + // For perfect shredding of the x field, no "value" column, only typed_value + let x_field_struct = StructArrayBuilder::new() + .with_field("typed_value", Arc::new(x_field_typed_value), true) + .build(); + + // Wrap the x field struct in a ShreddedVariantFieldArray + let x_field_shredded = ShreddedVariantFieldArray::try_new(Arc::new(x_field_struct)) + .expect("should create ShreddedVariantFieldArray"); + + // Create the main 
typed_value as a struct containing the "x" field + let typed_value_fields = Fields::from(vec![Field::new( + "x", + x_field_shredded.data_type().clone(), + true, + )]); + let typed_value_struct = StructArray::try_new( + typed_value_fields, + vec![Arc::new(x_field_shredded)], + None, // No nulls - both rows have the object structure + ) + .unwrap(); + + // Create the main VariantArray + let main_struct = StructArrayBuilder::new() + .with_field("metadata", Arc::new(metadata_array), false) + .with_field("value", Arc::new(value_array), true) + .with_field("typed_value", Arc::new(typed_value_struct), true) + .build(); + + Arc::new(VariantArray::try_new(Arc::new(main_struct)).expect("should create variant array")) + } + + /// Simple test to check if nested paths are supported by current implementation + #[test] + fn test_simple_nested_path_support() { + // Check: How does VariantPath parse different strings? + println!("Testing path parsing:"); + + let path_x = VariantPath::from("x"); + let elements_x: Vec<_> = path_x.iter().collect(); + println!(" 'x' -> {} elements: {:?}", elements_x.len(), elements_x); + + let path_ax = VariantPath::from("a.x"); + let elements_ax: Vec<_> = path_ax.iter().collect(); + println!( + " 'a.x' -> {} elements: {:?}", + elements_ax.len(), + elements_ax + ); + + let path_ax_alt = VariantPath::from("$.a.x"); + let elements_ax_alt: Vec<_> = path_ax_alt.iter().collect(); + println!( + " '$.a.x' -> {} elements: {:?}", + elements_ax_alt.len(), + elements_ax_alt + ); + + let path_nested = VariantPath::from("a").join("x"); + let elements_nested: Vec<_> = path_nested.iter().collect(); + println!( + " VariantPath::from('a').join('x') -> {} elements: {:?}", + elements_nested.len(), + elements_nested + ); + + // Use your existing simple test data but try "a.x" instead of "x" + let array = shredded_object_with_x_field_variant_array(); + + // Test if variant_get with REAL nested path throws not implemented error + let real_nested_path = VariantPath::from("a").join("x"); + let options = GetOptions::new_with_path(real_nested_path); + let result = variant_get(&array, options); + + match result { + Ok(_) => { + println!("Nested path 'a.x' works unexpectedly!"); + } + Err(e) => { + println!("Nested path 'a.x' error: {}", e); + if e.to_string().contains("not yet implemented") + || e.to_string().contains("NotYetImplemented") + { + println!("This is expected - nested paths are not implemented"); + return; + } + // Any other error is also expected for now + println!("This shows nested paths need implementation"); + } + } + } + + /// Test comprehensive variant_get scenarios with Int32 conversion + /// Test depth 0: Direct field access "x" with Int32 conversion + /// Covers shredded vs non-shredded VariantArrays for simple field access + #[test] + fn test_depth_0_int32_conversion() { + println!("=== Testing Depth 0: Direct field access ==="); + + // Non-shredded test data: [{"x": 42}, {"x": "foo"}, {"y": 10}] + let unshredded_array = create_depth_0_test_data(); + + let field = Field::new("result", DataType::Int32, true); + let path = VariantPath::from("x"); + let options = GetOptions::new_with_path(path).with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&unshredded_array, options).unwrap(); + + let expected: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(42), // {"x": 42} -> 42 + None, // {"x": "foo"} -> NULL (type mismatch) + None, // {"y": 10} -> NULL (field missing) + ])); + assert_eq!(&result, &expected); + println!("Depth 0 (unshredded) passed"); + + // Shredded 
test data: using simplified approach based on working pattern + let shredded_array = create_depth_0_shredded_test_data_simple(); + + let field = Field::new("result", DataType::Int32, true); + let path = VariantPath::from("x"); + let options = GetOptions::new_with_path(path).with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&shredded_array, options).unwrap(); + + let expected: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(42), // {"x": 42} -> 42 (from typed_value) + None, // {"x": "foo"} -> NULL (type mismatch, from value field) + ])); + assert_eq!(&result, &expected); + println!("Depth 0 (shredded) passed"); + } + + /// Test depth 1: Single nested field access "a.x" with Int32 conversion + /// Covers shredded vs non-shredded VariantArrays for nested field access + #[test] + fn test_depth_1_int32_conversion() { + println!("=== Testing Depth 1: Single nested field access ==="); + + // Non-shredded test data from the GitHub issue + let unshredded_array = create_nested_path_test_data(); + + let field = Field::new("result", DataType::Int32, true); + let path = VariantPath::from("a.x"); // Dot notation! + let options = GetOptions::new_with_path(path).with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&unshredded_array, options).unwrap(); + + let expected: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(55), // {"a": {"x": 55}} -> 55 + None, // {"a": {"x": "foo"}} -> NULL (type mismatch) + ])); + assert_eq!(&result, &expected); + println!("Depth 1 (unshredded) passed"); + + // Shredded test data: depth 1 nested shredding + let shredded_array = create_depth_1_shredded_test_data_working(); + + let field = Field::new("result", DataType::Int32, true); + let path = VariantPath::from("a.x"); // Dot notation! + let options = GetOptions::new_with_path(path).with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&shredded_array, options).unwrap(); + + let expected: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(55), // {"a": {"x": 55}} -> 55 (from nested shredded x) + None, // {"a": {"x": "foo"}} -> NULL (type mismatch in nested value) + ])); + assert_eq!(&result, &expected); + println!("Depth 1 (shredded) passed"); + } + + /// Test depth 2: Double nested field access "a.b.x" with Int32 conversion + /// Covers shredded vs non-shredded VariantArrays for deeply nested field access + #[test] + fn test_depth_2_int32_conversion() { + println!("=== Testing Depth 2: Double nested field access ==="); + + // Non-shredded test data: [{"a": {"b": {"x": 100}}}, {"a": {"b": {"x": "bar"}}}, {"a": {"b": {"y": 200}}}] + let unshredded_array = create_depth_2_test_data(); + + let field = Field::new("result", DataType::Int32, true); + let path = VariantPath::from("a.b.x"); // Double nested dot notation! + let options = GetOptions::new_with_path(path).with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&unshredded_array, options).unwrap(); + + let expected: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(100), // {"a": {"b": {"x": 100}}} -> 100 + None, // {"a": {"b": {"x": "bar"}}} -> NULL (type mismatch) + None, // {"a": {"b": {"y": 200}}} -> NULL (field missing) + ])); + assert_eq!(&result, &expected); + println!("Depth 2 (unshredded) passed"); + + // Shredded test data: depth 2 nested shredding + let shredded_array = create_depth_2_shredded_test_data_working(); + + let field = Field::new("result", DataType::Int32, true); + let path = VariantPath::from("a.b.x"); // Double nested dot notation! 
+ let options = GetOptions::new_with_path(path).with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&shredded_array, options).unwrap(); + + let expected: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(100), // {"a": {"b": {"x": 100}}} -> 100 (from deeply nested shredded x) + None, // {"a": {"b": {"x": "bar"}}} -> NULL (type mismatch in deep value) + None, // {"a": {"b": {"y": 200}}} -> NULL (field missing in deep structure) + ])); + assert_eq!(&result, &expected); + println!("Depth 2 (shredded) passed"); + } + + /// Test that demonstrates what CURRENTLY WORKS + /// + /// This shows that nested path functionality does work, but only when the + /// test data matches what the current implementation expects + #[test] + fn test_current_nested_path_functionality() { + let array = shredded_object_with_x_field_variant_array(); + + // Test: Extract the "x" field (single level) - this works + let single_path = VariantPath::from("x"); + let field = Field::new("result", DataType::Int32, true); + let options = + GetOptions::new_with_path(single_path).with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&array, options).unwrap(); + + println!("Single path 'x' works - result: {:?}", result); + + // Test: Try nested path "a.x" - this is what we need to implement + let nested_path = VariantPath::from("a").join("x"); + let field = Field::new("result", DataType::Int32, true); + let options = + GetOptions::new_with_path(nested_path).with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&array, options).unwrap(); + + println!("Nested path 'a.x' result: {:?}", result); + } + + /// Create test data for depth 0 (direct field access) + /// [{"x": 42}, {"x": "foo"}, {"y": 10}] + fn create_depth_0_test_data() -> ArrayRef { + let mut builder = crate::VariantArrayBuilder::new(3); + + // Row 1: {"x": 42} + { + let json_str = r#"{"x": 42}"#; + let string_array: ArrayRef = Arc::new(StringArray::from(vec![json_str])); + if let Ok(variant_array) = json_to_variant(&string_array) { + builder.append_variant(variant_array.value(0)); + } else { + builder.append_null(); + } + } + + // Row 2: {"x": "foo"} + { + let json_str = r#"{"x": "foo"}"#; + let string_array: ArrayRef = Arc::new(StringArray::from(vec![json_str])); + if let Ok(variant_array) = json_to_variant(&string_array) { + builder.append_variant(variant_array.value(0)); + } else { + builder.append_null(); + } + } + + // Row 3: {"y": 10} (missing "x" field) + { + let json_str = r#"{"y": 10}"#; + let string_array: ArrayRef = Arc::new(StringArray::from(vec![json_str])); + if let Ok(variant_array) = json_to_variant(&string_array) { + builder.append_variant(variant_array.value(0)); + } else { + builder.append_null(); + } + } + + Arc::new(builder.build()) + } + + /// Create test data for depth 1 (single nested field) + /// This represents the exact scenarios from the GitHub issue: "a.x" + fn create_nested_path_test_data() -> ArrayRef { + let mut builder = crate::VariantArrayBuilder::new(2); + + // Row 1: {"a": {"x": 55}, "b": 42} + { + let json_str = r#"{"a": {"x": 55}, "b": 42}"#; + let string_array: ArrayRef = Arc::new(StringArray::from(vec![json_str])); + if let Ok(variant_array) = json_to_variant(&string_array) { + builder.append_variant(variant_array.value(0)); + } else { + builder.append_null(); + } + } + + // Row 2: {"a": {"x": "foo"}, "b": 42} + { + let json_str = r#"{"a": {"x": "foo"}, "b": 42}"#; + let string_array: ArrayRef = Arc::new(StringArray::from(vec![json_str])); + if let Ok(variant_array) = 
json_to_variant(&string_array) { + builder.append_variant(variant_array.value(0)); + } else { + builder.append_null(); + } + } + + Arc::new(builder.build()) + } + + /// Create test data for depth 2 (double nested field) + /// [{"a": {"b": {"x": 100}}}, {"a": {"b": {"x": "bar"}}}, {"a": {"b": {"y": 200}}}] + fn create_depth_2_test_data() -> ArrayRef { + let mut builder = crate::VariantArrayBuilder::new(3); + + // Row 1: {"a": {"b": {"x": 100}}} + { + let json_str = r#"{"a": {"b": {"x": 100}}}"#; + let string_array: ArrayRef = Arc::new(StringArray::from(vec![json_str])); + if let Ok(variant_array) = json_to_variant(&string_array) { + builder.append_variant(variant_array.value(0)); + } else { + builder.append_null(); + } + } + + // Row 2: {"a": {"b": {"x": "bar"}}} + { + let json_str = r#"{"a": {"b": {"x": "bar"}}}"#; + let string_array: ArrayRef = Arc::new(StringArray::from(vec![json_str])); + if let Ok(variant_array) = json_to_variant(&string_array) { + builder.append_variant(variant_array.value(0)); + } else { + builder.append_null(); + } + } + + // Row 3: {"a": {"b": {"y": 200}}} (missing "x" field) + { + let json_str = r#"{"a": {"b": {"y": 200}}}"#; + let string_array: ArrayRef = Arc::new(StringArray::from(vec![json_str])); + if let Ok(variant_array) = json_to_variant(&string_array) { + builder.append_variant(variant_array.value(0)); + } else { + builder.append_null(); + } + } + + Arc::new(builder.build()) + } + + /// Create simple shredded test data for depth 0 using a simplified working pattern + /// Creates 2 rows: [{"x": 42}, {"x": "foo"}] with "x" shredded where possible + fn create_depth_0_shredded_test_data_simple() -> ArrayRef { + // Create base metadata using the working pattern + let (metadata, string_x_value) = { + let mut builder = parquet_variant::VariantBuilder::new(); + let mut obj = builder.new_object(); + obj.insert("x", Variant::from("foo")); + obj.finish(); + builder.finish() + }; + + // Metadata array (same for both rows) + let metadata_array = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 2)); + + // Value array following the 3-step shredding spec: + // Row 0: {} (x is shredded, no unshredded fields) + // Row 1: {"x": "foo"} (x is a string, can't be shredded to Int32) + let empty_object_value = { + let mut builder = parquet_variant::VariantBuilder::new(); + let obj = builder.new_object(); + obj.finish(); + let (_, value) = builder.finish(); + value + }; + + let value_array = BinaryViewArray::from(vec![ + Some(empty_object_value.as_slice()), // Row 0: {} (x shredded out) + Some(string_x_value.as_slice()), // Row 1: {"x": "foo"} (fallback) + ]); + + // Create the "x" field as a ShreddedVariantFieldArray + let x_field_typed_value = Int32Array::from(vec![Some(42), None]); + + // For the x field, only typed_value (perfect shredding when possible) + let x_field_struct = StructArrayBuilder::new() + .with_field("typed_value", Arc::new(x_field_typed_value), true) + .build(); + + let x_field_shredded = ShreddedVariantFieldArray::try_new(Arc::new(x_field_struct)) + .expect("should create ShreddedVariantFieldArray"); + + // Create the main typed_value as a struct containing the "x" field + let typed_value_fields = Fields::from(vec![Field::new( + "x", + x_field_shredded.data_type().clone(), + true, + )]); + let typed_value_struct = + StructArray::try_new(typed_value_fields, vec![Arc::new(x_field_shredded)], None) + .unwrap(); + + // Build final VariantArray + let struct_array = StructArrayBuilder::new() + .with_field("metadata", Arc::new(metadata_array), false) 
+ .with_field("value", Arc::new(value_array), true) + .with_field("typed_value", Arc::new(typed_value_struct), true) + .build(); + + Arc::new(VariantArray::try_new(Arc::new(struct_array)).expect("should create VariantArray")) + } + + /// Create working depth 1 shredded test data based on the existing working pattern + /// This creates a properly structured shredded variant for "a.x" where: + /// - Row 0: {"a": {"x": 55}, "b": 42} with a.x shredded into typed_value + /// - Row 1: {"a": {"x": "foo"}, "b": 42} with a.x fallback to value field due to type mismatch + fn create_depth_1_shredded_test_data_working() -> ArrayRef { + // Create metadata following the working pattern from shredded_object_with_x_field_variant_array + let (metadata, _) = { + // Create nested structure: {"a": {"x": 55}, "b": 42} + let mut builder = parquet_variant::VariantBuilder::new(); + let mut obj = builder.new_object(); + + // Create the nested "a" object + let mut a_obj = obj.new_object("a"); + a_obj.insert("x", Variant::Int32(55)); + a_obj.finish(); + + obj.insert("b", Variant::Int32(42)); + obj.finish(); + builder.finish() + }; + + let metadata_array = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 2)); + + // Create value arrays for the fallback case + // Following the spec: if field cannot be shredded, it stays in value + let empty_object_value = { + let mut builder = parquet_variant::VariantBuilder::new(); + let obj = builder.new_object(); + obj.finish(); + let (_, value) = builder.finish(); + value + }; + + // Row 1 fallback: use the working pattern from the existing shredded test + // This avoids metadata issues by using the simple fallback approach + let row1_fallback = { + let mut builder = parquet_variant::VariantBuilder::new(); + let mut obj = builder.new_object(); + obj.insert("fallback", Variant::from("data")); + obj.finish(); + let (_, value) = builder.finish(); + value + }; + + let value_array = BinaryViewArray::from(vec![ + Some(empty_object_value.as_slice()), // Row 0: {} (everything shredded except b in unshredded fields) + Some(row1_fallback.as_slice()), // Row 1: {"a": {"x": "foo"}, "b": 42} (a.x can't be shredded) + ]); + + // Create the nested shredded structure + // Level 2: x field (the deepest level) + let x_typed_value = Int32Array::from(vec![Some(55), None]); + let x_field_struct = StructArrayBuilder::new() + .with_field("typed_value", Arc::new(x_typed_value), true) + .build(); + let x_field_shredded = ShreddedVariantFieldArray::try_new(Arc::new(x_field_struct)) + .expect("should create ShreddedVariantFieldArray for x"); + + // Level 1: a field containing x field + value field for fallbacks + // The "a" field needs both typed_value (for shredded x) and value (for fallback cases) + + // Create the value field for "a" (for cases where a.x can't be shredded) + let a_value_data = { + let mut builder = parquet_variant::VariantBuilder::new(); + let obj = builder.new_object(); + obj.finish(); + let (_, value) = builder.finish(); + value + }; + let a_value_array = BinaryViewArray::from(vec![ + None, // Row 0: x is shredded, so no value fallback needed + Some(a_value_data.as_slice()), // Row 1: fallback for a.x="foo" (but logic will check typed_value first) + ]); + + let a_inner_fields = Fields::from(vec![Field::new( + "x", + x_field_shredded.data_type().clone(), + true, + )]); + let a_inner_struct = StructArrayBuilder::new() + .with_field( + "typed_value", + Arc::new( + StructArray::try_new(a_inner_fields, vec![Arc::new(x_field_shredded)], None) + .unwrap(), + ), + true, + ) + 
.with_field("value", Arc::new(a_value_array), true) + .build(); + let a_field_shredded = ShreddedVariantFieldArray::try_new(Arc::new(a_inner_struct)) + .expect("should create ShreddedVariantFieldArray for a"); + + // Level 0: main typed_value struct containing a field + let typed_value_fields = Fields::from(vec![Field::new( + "a", + a_field_shredded.data_type().clone(), + true, + )]); + let typed_value_struct = + StructArray::try_new(typed_value_fields, vec![Arc::new(a_field_shredded)], None) + .unwrap(); + + // Build final VariantArray + let struct_array = StructArrayBuilder::new() + .with_field("metadata", Arc::new(metadata_array), false) + .with_field("value", Arc::new(value_array), true) + .with_field("typed_value", Arc::new(typed_value_struct), true) + .build(); + + Arc::new(VariantArray::try_new(Arc::new(struct_array)).expect("should create VariantArray")) + } + + /// Create working depth 2 shredded test data for "a.b.x" paths + /// This creates a 3-level nested shredded structure where: + /// - Row 0: {"a": {"b": {"x": 100}}} with a.b.x shredded into typed_value + /// - Row 1: {"a": {"b": {"x": "bar"}}} with type mismatch fallback + /// - Row 2: {"a": {"b": {"y": 200}}} with missing field fallback + fn create_depth_2_shredded_test_data_working() -> ArrayRef { + // Create metadata following the working pattern + let (metadata, _) = { + // Create deeply nested structure: {"a": {"b": {"x": 100}}} + let mut builder = parquet_variant::VariantBuilder::new(); + let mut obj = builder.new_object(); + + // Create the nested "a.b" structure + let mut a_obj = obj.new_object("a"); + let mut b_obj = a_obj.new_object("b"); + b_obj.insert("x", Variant::Int32(100)); + b_obj.finish(); + a_obj.finish(); + + obj.finish(); + builder.finish() + }; + + let metadata_array = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 3)); + + // Create value arrays for fallback cases + let empty_object_value = { + let mut builder = parquet_variant::VariantBuilder::new(); + let obj = builder.new_object(); + obj.finish(); + let (_, value) = builder.finish(); + value + }; + + // Simple fallback values - avoiding complex nested metadata + let value_array = BinaryViewArray::from(vec![ + Some(empty_object_value.as_slice()), // Row 0: fully shredded + Some(empty_object_value.as_slice()), // Row 1: fallback (simplified) + Some(empty_object_value.as_slice()), // Row 2: fallback (simplified) + ]); + + // Create the deeply nested shredded structure: a.b.x + + // Level 3: x field (deepest level) + let x_typed_value = Int32Array::from(vec![Some(100), None, None]); + let x_field_struct = StructArrayBuilder::new() + .with_field("typed_value", Arc::new(x_typed_value), true) + .build(); + let x_field_shredded = ShreddedVariantFieldArray::try_new(Arc::new(x_field_struct)) + .expect("should create ShreddedVariantFieldArray for x"); + + // Level 2: b field containing x field + value field + let b_value_data = { + let mut builder = parquet_variant::VariantBuilder::new(); + let obj = builder.new_object(); + obj.finish(); + let (_, value) = builder.finish(); + value + }; + let b_value_array = BinaryViewArray::from(vec![ + None, // Row 0: x is shredded + Some(b_value_data.as_slice()), // Row 1: fallback for b.x="bar" + Some(b_value_data.as_slice()), // Row 2: fallback for b.y=200 + ]); + + let b_inner_fields = Fields::from(vec![Field::new( + "x", + x_field_shredded.data_type().clone(), + true, + )]); + let b_inner_struct = StructArrayBuilder::new() + .with_field( + "typed_value", + Arc::new( + 
StructArray::try_new(b_inner_fields, vec![Arc::new(x_field_shredded)], None) + .unwrap(), + ), + true, + ) + .with_field("value", Arc::new(b_value_array), true) + .build(); + let b_field_shredded = ShreddedVariantFieldArray::try_new(Arc::new(b_inner_struct)) + .expect("should create ShreddedVariantFieldArray for b"); + + // Level 1: a field containing b field + value field + let a_value_data = { + let mut builder = parquet_variant::VariantBuilder::new(); + let obj = builder.new_object(); + obj.finish(); + let (_, value) = builder.finish(); + value + }; + let a_value_array = BinaryViewArray::from(vec![ + None, // Row 0: b is shredded + Some(a_value_data.as_slice()), // Row 1: fallback for a.b.* + Some(a_value_data.as_slice()), // Row 2: fallback for a.b.* + ]); + + let a_inner_fields = Fields::from(vec![Field::new( + "b", + b_field_shredded.data_type().clone(), + true, + )]); + let a_inner_struct = StructArrayBuilder::new() + .with_field( + "typed_value", + Arc::new( + StructArray::try_new(a_inner_fields, vec![Arc::new(b_field_shredded)], None) + .unwrap(), + ), + true, + ) + .with_field("value", Arc::new(a_value_array), true) + .build(); + let a_field_shredded = ShreddedVariantFieldArray::try_new(Arc::new(a_inner_struct)) + .expect("should create ShreddedVariantFieldArray for a"); + + // Level 0: main typed_value struct containing a field + let typed_value_fields = Fields::from(vec![Field::new( + "a", + a_field_shredded.data_type().clone(), + true, + )]); + let typed_value_struct = + StructArray::try_new(typed_value_fields, vec![Arc::new(a_field_shredded)], None) + .unwrap(); + + // Build final VariantArray + let struct_array = StructArrayBuilder::new() + .with_field("metadata", Arc::new(metadata_array), false) + .with_field("value", Arc::new(value_array), true) + .with_field("typed_value", Arc::new(typed_value_struct), true) + .build(); + + Arc::new(VariantArray::try_new(Arc::new(struct_array)).expect("should create VariantArray")) + } + + #[test] + fn test_strict_cast_options_downcast_failure() { + use arrow::compute::CastOptions; + use arrow::datatypes::{DataType, Field}; + use arrow::error::ArrowError; + use parquet_variant::VariantPath; + use std::sync::Arc; + + // Use the existing simple test data that has Int32 as typed_value + let variant_array = perfectly_shredded_int32_variant_array(); + + // Try to access a field with safe cast options (should return NULLs) + let safe_options = GetOptions { + path: VariantPath::from("nonexistent_field"), + as_type: Some(Arc::new(Field::new("result", DataType::Int32, true))), + cast_options: CastOptions::default(), // safe = true + }; + + let variant_array_ref: Arc = variant_array.clone(); + let result = variant_get(&variant_array_ref, safe_options); + // Should succeed and return NULLs (safe behavior) + assert!(result.is_ok()); + let result_array = result.unwrap(); + assert_eq!(result_array.len(), 3); + assert!(result_array.is_null(0)); + assert!(result_array.is_null(1)); + assert!(result_array.is_null(2)); + + // Try to access a field with strict cast options (should error) + let strict_options = GetOptions { + path: VariantPath::from("nonexistent_field"), + as_type: Some(Arc::new(Field::new("result", DataType::Int32, true))), + cast_options: CastOptions { + safe: false, + ..Default::default() + }, + }; + + let result = variant_get(&variant_array_ref, strict_options); + // Should fail with a cast error + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(matches!(error, ArrowError::CastError(_))); + assert!(error + 
            .to_string()
+            .contains("Cannot access field 'nonexistent_field' on non-struct type"));
+    }
+
+    #[test]
+    fn test_null_buffer_union_for_shredded_paths() {
+        use arrow::compute::CastOptions;
+        use arrow::datatypes::{DataType, Field};
+        use parquet_variant::VariantPath;
+        use std::sync::Arc;
+
+        // Test that null buffers are properly unioned when traversing shredded paths
+        // This test verifies scovich's null buffer union requirement
+
+        // Create a depth-1 shredded variant array where:
+        // - The top-level variant array has some nulls
+        // - The nested typed_value also has some nulls
+        // - The result should be the union of both null buffers
+
+        let variant_array = create_depth_1_shredded_test_data_working();
+
+        // Get the field "x" which should union nulls from:
+        // 1. The top-level variant array nulls
+        // 2. The "a" field's typed_value nulls
+        // 3. The "x" field's typed_value nulls
+        let options = GetOptions {
+            path: VariantPath::from("a.x"),
+            as_type: Some(Arc::new(Field::new("result", DataType::Int32, true))),
+            cast_options: CastOptions::default(),
+        };
+
+        let variant_array_ref: Arc<dyn Array> = variant_array.clone();
+        let result = variant_get(&variant_array_ref, options).unwrap();
+
+        // Verify the result length matches input
+        assert_eq!(result.len(), variant_array.len());
+
+        // The null pattern should reflect the union of all ancestor nulls
+        // Row 0: Should have valid data (path exists and is shredded as Int32)
+        // Row 1: Should be null (due to type mismatch - "foo" can't cast to Int32)
+        assert!(!result.is_null(0), "Row 0 should have valid Int32 data");
+        assert!(
+            result.is_null(1),
+            "Row 1 should be null due to type casting failure"
+        );
+
+        // Verify the actual values
+        let int32_result = result
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        assert_eq!(int32_result.value(0), 55); // The valid Int32 value
+    }
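+
+    // Illustrative sketch (not part of the tests above): a typical caller extracts a
+    // shredded path as a typed array roughly like this, where `input` is any
+    // `Arc<dyn Array>` holding a `VariantArray`. Only APIs exercised elsewhere in this
+    // module are used; the variable names are made up for illustration.
+    //
+    //     let options = GetOptions {
+    //         path: VariantPath::from("a").join("x"),
+    //         as_type: Some(Arc::new(Field::new("x", DataType::Int32, true))),
+    //         cast_options: CastOptions::default(), // safe: failed casts become NULL
+    //     };
+    //     let typed: ArrayRef = variant_get(&input, options)?;
+    //
+    // With `CastOptions { safe: false, ..Default::default() }` the same call instead
+    // returns a `CastError` on the first row that cannot be converted, as the strict
+    // cast test above demonstrates.
+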
+    #[test]
+    fn test_struct_null_mask_union_from_children() {
+        use arrow::compute::CastOptions;
+        use arrow::datatypes::{DataType, Field, Fields};
+        use parquet_variant::VariantPath;
+        use std::sync::Arc;
+
+        use arrow::array::StringArray;
+
+        // Test that struct null masks properly union nulls from children field extractions
+        // This verifies scovich's concern about incomplete null masks in struct construction
+
+        // Create test data where some fields will fail type casting
+        let json_strings = vec![
+            r#"{"a": 42, "b": "hello"}"#, // Row 0: a=42 (castable to int), b="hello" (not castable to int)
+            r#"{"a": "world", "b": 100}"#, // Row 1: a="world" (not castable to int), b=100 (castable to int)
+            r#"{"a": 55, "b": 77}"#, // Row 2: a=55 (castable to int), b=77 (castable to int)
+        ];
+
+        let string_array: Arc<dyn Array> = Arc::new(StringArray::from(json_strings));
+        let variant_array = json_to_variant(&string_array).unwrap();
+
+        // Request extraction as a struct with both fields as Int32
+        // This should create child arrays where some fields are null due to casting failures
+        let struct_fields = Fields::from(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, true),
+        ]);
+        let struct_type = DataType::Struct(struct_fields);
+
+        let options = GetOptions {
+            path: VariantPath::default(), // Extract the whole object as struct
+            as_type: Some(Arc::new(Field::new("result", struct_type, true))),
+            cast_options: CastOptions::default(),
+        };
+
+        let variant_array_ref: Arc<dyn Array> = Arc::new(variant_array);
+        let result = variant_get(&variant_array_ref, options).unwrap();
+
+        // Verify the result is a StructArray
+        let struct_result = result
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .unwrap();
+        assert_eq!(struct_result.len(), 3);
+
+        // Get the individual field arrays
+        let field_a = struct_result
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        let field_b = struct_result
+            .column(1)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+
+        // Verify field values and nulls
+        // Row 0: a=42 (valid), b=null (casting failure)
+        assert!(!field_a.is_null(0));
+        assert_eq!(field_a.value(0), 42);
+        assert!(field_b.is_null(0)); // "hello" can't cast to int
+
+        // Row 1: a=null (casting failure), b=100 (valid)
+        assert!(field_a.is_null(1)); // "world" can't cast to int
+        assert!(!field_b.is_null(1));
+        assert_eq!(field_b.value(1), 100);
+
+        // Row 2: a=55 (valid), b=77 (valid)
+        assert!(!field_a.is_null(2));
+        assert_eq!(field_a.value(2), 55);
+        assert!(!field_b.is_null(2));
+        assert_eq!(field_b.value(2), 77);
+
+        // Verify the struct-level null mask properly unions child nulls
+        // The struct should NOT be null in any row because each row has at least one valid field
+        // (This tests that we're not incorrectly making the entire struct null when children fail)
+        assert!(!struct_result.is_null(0)); // Has valid field 'a'
+        assert!(!struct_result.is_null(1)); // Has valid field 'b'
+        assert!(!struct_result.is_null(2)); // Has both valid fields
+    }
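+
+    // A note on the null semantics verified above: the validity of an extracted value
+    // is the logical AND of the validity at every level of the shredded path. A minimal
+    // sketch of combining two adjacent levels with arrow's `NullBuffer::union`
+    // (`parent` and `child` are assumed to be `Option<NullBuffer>` validity masks):
+    //
+    //     use arrow::buffer::NullBuffer;
+    //     let combined: Option<NullBuffer> = NullBuffer::union(parent.as_ref(), child.as_ref());
+    //
+    // A row is valid in the result only if it is valid at every level; everything else
+    // surfaces as NULL (or as an error when `safe: false`).
+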
+    #[test]
+    fn test_field_nullability_preservation() {
+        use arrow::compute::CastOptions;
+        use arrow::datatypes::{DataType, Field};
+        use parquet_variant::VariantPath;
+        use std::sync::Arc;
+
+        use arrow::array::StringArray;
+
+        // Test that field nullability from GetOptions.as_type is preserved in the result
+
+        let json_strings = vec![
+            r#"{"x": 42}"#,             // Row 0: Valid int that should convert to Int32
+            r#"{"x": "not_a_number"}"#, // Row 1: String that can't cast to Int32
+            r#"{"x": null}"#,           // Row 2: Explicit null value
+            r#"{"x": "hello"}"#,        // Row 3: Another string (wrong type)
+            r#"{"y": 100}"#,            // Row 4: Missing "x" field (SQL NULL case)
+            r#"{"x": 127}"#,            // Row 5: Small int (could be Int8, widening cast candidate)
+            r#"{"x": 32767}"#,          // Row 6: Medium int (could be Int16, widening cast candidate)
+            r#"{"x": 2147483647}"#,     // Row 7: Max Int32 value (fits in Int32)
+            r#"{"x": 9223372036854775807}"#, // Row 8: Large Int64 value (cannot convert to Int32)
+        ];
+
+        let string_array: Arc<dyn Array> = Arc::new(StringArray::from(json_strings));
+        let variant_array = json_to_variant(&string_array).unwrap();
+
+        // Test 1: nullable field (should allow nulls from cast failures)
+        let nullable_field = Arc::new(Field::new("result", DataType::Int32, true));
+        let options_nullable = GetOptions {
+            path: VariantPath::from("x"),
+            as_type: Some(nullable_field.clone()),
+            cast_options: CastOptions::default(),
+        };
+
+        let variant_array_ref: Arc<dyn Array> = Arc::new(variant_array);
+        let result_nullable = variant_get(&variant_array_ref, options_nullable).unwrap();
+
+        // Verify we get an Int32Array with nulls for cast failures
+        let int32_result = result_nullable
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        assert_eq!(int32_result.len(), 9);
+
+        // Row 0: 42 converts successfully to Int32
+        assert!(!int32_result.is_null(0));
+        assert_eq!(int32_result.value(0), 42);
+
+        // Row 1: "not_a_number" fails to convert -> NULL
+        assert!(int32_result.is_null(1));
+
+        // Row 2: explicit null value -> NULL
+        assert!(int32_result.is_null(2));
+
+        // Row 3: "hello" (wrong type) fails to convert -> NULL
+        assert!(int32_result.is_null(3));
+
+        // Row 4: missing "x" field (SQL NULL case) -> NULL
+        assert!(int32_result.is_null(4));
+
+        // Row 5: 127 (small int, potential Int8 -> Int32 widening)
+        // Current behavior: JSON parses to Int8, should convert to Int32
+        assert!(!int32_result.is_null(5));
+        assert_eq!(int32_result.value(5), 127);
+
+        // Row 6: 32767 (medium int, potential Int16 -> Int32 widening)
+        // Current behavior: JSON parses to Int16, should convert to Int32
+        assert!(!int32_result.is_null(6));
+        assert_eq!(int32_result.value(6), 32767);
+
+        // Row 7: 2147483647 (max Int32, fits exactly)
+        // Current behavior: Should convert successfully
+        assert!(!int32_result.is_null(7));
+        assert_eq!(int32_result.value(7), 2147483647);
+
+        // Row 8: 9223372036854775807 (large Int64, cannot fit in Int32)
+        // Current behavior: Should fail conversion -> NULL
+        assert!(int32_result.is_null(8));
+
+        // Test 2: non-nullable field (behavior should be the same with safe casting)
+        let non_nullable_field = Arc::new(Field::new("result", DataType::Int32, false));
+        let options_non_nullable = GetOptions {
+            path: VariantPath::from("x"),
+            as_type: Some(non_nullable_field.clone()),
+            cast_options: CastOptions::default(), // safe=true by default
+        };
+
+        // Create variant array again since we moved it
+        let variant_array_2 = json_to_variant(&string_array).unwrap();
+        let variant_array_ref_2: Arc<dyn Array> = Arc::new(variant_array_2);
+        let result_non_nullable = variant_get(&variant_array_ref_2, options_non_nullable).unwrap();
+        let int32_result_2 = result_non_nullable
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+
+        // Even with a non-nullable field, safe casting should still produce nulls for failures
+        assert_eq!(int32_result_2.len(), 9);
+
+        // Row 0: 42 converts successfully to Int32
+        assert!(!int32_result_2.is_null(0));
+        assert_eq!(int32_result_2.value(0), 42);
+
+        // Rows 1-4: All should be null due to safe casting behavior
+        // (non-nullable field specification doesn't override safe casting behavior)
+        assert!(int32_result_2.is_null(1)); // "not_a_number"
+        assert!(int32_result_2.is_null(2)); // explicit null
+        assert!(int32_result_2.is_null(3)); // "hello"
+        assert!(int32_result_2.is_null(4)); // missing field
+
+        // Rows 5-7: These should also convert successfully (numeric widening/fitting)
+        assert!(!int32_result_2.is_null(5)); // 127 (Int8 -> Int32)
+        assert_eq!(int32_result_2.value(5), 127);
+        assert!(!int32_result_2.is_null(6)); // 32767 (Int16 -> Int32)
+        assert_eq!(int32_result_2.value(6), 32767);
+        assert!(!int32_result_2.is_null(7)); // 2147483647 (fits in Int32)
+        assert_eq!(int32_result_2.value(7), 2147483647);
+
+        // Row 8: Large Int64 should fail conversion -> NULL
+        assert!(int32_result_2.is_null(8)); // 9223372036854775807 (too large for Int32)
+    }
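+
+    // The Int32 overflow in row 8 above is a property of the requested output type, not
+    // of the data: requesting a wider type should keep the value. A hedged sketch (Int64
+    // is one of the primitive output types wired up in `variant_to_arrow.rs`; behavior
+    // is assumed, not exercised here):
+    //
+    //     let options = GetOptions {
+    //         path: VariantPath::from("x"),
+    //         as_type: Some(Arc::new(Field::new("result", DataType::Int64, true))),
+    //         cast_options: CastOptions::default(),
+    //     };
+    //     let as_i64 = variant_get(&variant_array_ref_2, options)?; // row 8 should now be non-NULL
+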
+    #[test]
+    fn test_struct_extraction_subset_superset_schema_perfectly_shredded() {
+        // Create variant with diverse null patterns and empty objects
+        let variant_array = create_comprehensive_shredded_variant();
+
+        // Request struct with fields "a", "b", "d" (skip existing "c", add missing "d")
+        let struct_fields = Fields::from(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, true),
+            Field::new("d", DataType::Int32, true),
+        ]);
+        let struct_type = DataType::Struct(struct_fields);
+
+        let options = GetOptions {
+            path: VariantPath::default(),
+            as_type: Some(Arc::new(Field::new("result", struct_type, true))),
+            cast_options: CastOptions::default(),
+        };
+
+        let result = variant_get(&variant_array, options).unwrap();
+
+        // Verify the result is a StructArray with 3 fields and 5 rows
+        let struct_result = result.as_any().downcast_ref::<StructArray>().unwrap();
+        assert_eq!(struct_result.len(), 5);
+        assert_eq!(struct_result.num_columns(), 3);
+
+        let field_a = struct_result
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        let field_b = struct_result
+            .column(1)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        let field_d = struct_result
+            .column(2)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+
+        // Row 0: Normal values {"a": 1, "b": 2, "c": 3} → {a: 1, b: 2, d: NULL}
+        assert!(!struct_result.is_null(0));
+        assert_eq!(field_a.value(0), 1);
+        assert_eq!(field_b.value(0), 2);
+        assert!(field_d.is_null(0)); // Missing field "d"
+
+        // Row 1: Top-level NULL → struct-level NULL
+        assert!(struct_result.is_null(1));
+
+        // Row 2: Field "a" missing → {a: NULL, b: 2, d: NULL}
+        assert!(!struct_result.is_null(2));
+        assert!(field_a.is_null(2)); // Missing field "a"
+        assert_eq!(field_b.value(2), 2);
+        assert!(field_d.is_null(2)); // Missing field "d"
+
+        // Row 3: Field "b" missing → {a: 1, b: NULL, d: NULL}
+        assert!(!struct_result.is_null(3));
+        assert_eq!(field_a.value(3), 1);
+        assert!(field_b.is_null(3)); // Missing field "b"
+        assert!(field_d.is_null(3)); // Missing field "d"
+
+        // Row 4: Empty object {} → {a: NULL, b: NULL, d: NULL}
+        assert!(!struct_result.is_null(4));
+        assert!(field_a.is_null(4)); // Empty object
+        assert!(field_b.is_null(4)); // Empty object
+        assert!(field_d.is_null(4)); // Missing field "d"
+    }
+
+    #[test]
+    fn test_nested_struct_extraction_perfectly_shredded() {
+        // Create nested variant with diverse null patterns
+        let variant_array = create_comprehensive_nested_shredded_variant();
+        println!("variant_array: {variant_array:?}");
+
+        // Request 3-level nested struct type {"outer": {"inner": INT}}
+        let inner_field = Field::new("inner", DataType::Int32, true);
+        let inner_type = DataType::Struct(Fields::from(vec![inner_field]));
+        let outer_field = Field::new("outer", inner_type, true);
+        let result_type = DataType::Struct(Fields::from(vec![outer_field]));
+
+        let options = GetOptions {
+            path: VariantPath::default(),
+            as_type: Some(Arc::new(Field::new("result", result_type, true))),
+            cast_options: CastOptions::default(),
+        };
+
+        let result = variant_get(&variant_array, options).unwrap();
+        println!("result: {result:?}");
+
+        // Verify the result is a StructArray with "outer" field and 4 rows
+        let outer_struct = result.as_any().downcast_ref::<StructArray>().unwrap();
+        assert_eq!(outer_struct.len(), 4);
+        assert_eq!(outer_struct.num_columns(), 1);
+
+        // Get the "inner" struct column
+        let inner_struct = outer_struct
+            .column(0)
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .unwrap();
+        assert_eq!(inner_struct.num_columns(), 1);
+
+        // Get the "leaf" field (Int32 values)
+        let leaf_field = inner_struct
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+
+        // Row 0: Normal nested {"outer": {"inner": {"leaf": 42}}}
+        assert!(!outer_struct.is_null(0));
+        assert!(!inner_struct.is_null(0));
+        assert_eq!(leaf_field.value(0), 42);
+
+        // Row 1: "inner" field missing → {outer: {inner: NULL}}
+        assert!(!outer_struct.is_null(1));
+        assert!(!inner_struct.is_null(1)); // outer exists, inner exists but leaf is NULL
+        assert!(leaf_field.is_null(1)); // leaf field is NULL
+
+        // Row 2: "outer" field missing → {outer: NULL}
+        assert!(!outer_struct.is_null(2));
+        assert!(inner_struct.is_null(2)); // outer field is NULL
+
+        // Row 3: Top-level NULL → struct-level NULL
+        assert!(outer_struct.is_null(3));
+    }
+
+    #[test]
+    fn test_path_based_null_masks_one_step() {
+        // Create nested variant with diverse null patterns
+        let variant_array =
create_comprehensive_nested_shredded_variant(); + + // Extract "outer" field using path-based variant_get + let path = VariantPath::from("outer"); + let inner_field = Field::new("inner", DataType::Int32, true); + let result_type = DataType::Struct(Fields::from(vec![inner_field])); + + let options = GetOptions { + path, + as_type: Some(Arc::new(Field::new("result", result_type, true))), + cast_options: CastOptions::default(), + }; + + let result = variant_get(&variant_array, options).unwrap(); + + // Verify the result is a StructArray with "inner" field and 4 rows + let outer_result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(outer_result.len(), 4); + assert_eq!(outer_result.num_columns(), 1); + + // Get the "inner" field (Int32 values) + let inner_field = outer_result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + // Row 0: Normal nested {"outer": {"inner": 42}} → {"inner": 42} + assert!(!outer_result.is_null(0)); + assert_eq!(inner_field.value(0), 42); + + // Row 1: Inner field null {"outer": {"inner": null}} → {"inner": null} + assert!(!outer_result.is_null(1)); + assert!(inner_field.is_null(1)); + + // Row 2: Outer field null {"outer": null} → null (entire struct is null) + assert!(outer_result.is_null(2)); + + // Row 3: Top-level null → null (entire struct is null) + assert!(outer_result.is_null(3)); + } + + #[test] + fn test_path_based_null_masks_two_steps() { + // Create nested variant with diverse null patterns + let variant_array = create_comprehensive_nested_shredded_variant(); + + // Extract "outer.inner" field using path-based variant_get + let path = VariantPath::from("outer").join("inner"); + + let options = GetOptions { + path, + as_type: Some(Arc::new(Field::new("result", DataType::Int32, true))), + cast_options: CastOptions::default(), + }; + + let result = variant_get(&variant_array, options).unwrap(); + + // Verify the result is an Int32Array with 4 rows + let int_result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(int_result.len(), 4); + + // Row 0: Normal nested {"outer": {"inner": 42}} → 42 + assert!(!int_result.is_null(0)); + assert_eq!(int_result.value(0), 42); + + // Row 1: Inner field null {"outer": {"inner": null}} → null + assert!(int_result.is_null(1)); + + // Row 2: Outer field null {"outer": null} → null (path traversal fails) + assert!(int_result.is_null(2)); + + // Row 3: Top-level null → null (path traversal fails) + assert!(int_result.is_null(3)); + } + + #[test] + fn test_struct_extraction_mixed_and_unshredded() { + // Create a partially shredded variant (x shredded, y not) + let variant_array = create_mixed_and_unshredded_variant(); + + // Request struct with both shredded and unshredded fields + let struct_fields = Fields::from(vec![ + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + ]); + let struct_type = DataType::Struct(struct_fields); + + let options = GetOptions { + path: VariantPath::default(), + as_type: Some(Arc::new(Field::new("result", struct_type, true))), + cast_options: CastOptions::default(), + }; + + let result = variant_get(&variant_array, options).unwrap(); + + // Verify the mixed shredding works (should succeed with current implementation) + let struct_result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_result.len(), 4); + assert_eq!(struct_result.num_columns(), 2); + + let field_x = struct_result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let field_y = struct_result + .column(1) + .as_any() + .downcast_ref::() + 
.unwrap(); + + // Row 0: {"x": 1, "y": 42} - x from shredded, y from value field + assert_eq!(field_x.value(0), 1); + assert_eq!(field_y.value(0), 42); + + // Row 1: {"x": 2} - x from shredded, y missing (perfect shredding) + assert_eq!(field_x.value(1), 2); + assert!(field_y.is_null(1)); + + // Row 2: {"x": 3, "y": null} - x from shredded, y explicitly null in value + assert_eq!(field_x.value(2), 3); + assert!(field_y.is_null(2)); + + // Row 3: top-level null - entire struct row should be null + assert!(struct_result.is_null(3)); + } + + /// Test that demonstrates the actual struct row builder gap + /// This test should fail because it hits unshredded nested structs + #[test] + fn test_struct_row_builder_gap_demonstration() { + // Create completely unshredded JSON variant (no typed_value at all) + let json_strings = vec![ + r#"{"outer": {"inner": 42}}"#, + r#"{"outer": {"inner": 100}}"#, + ]; + let string_array: Arc = Arc::new(StringArray::from(json_strings)); + let variant_array = json_to_variant(&string_array).unwrap(); + + // Request nested struct - this should fail at the row builder level + let inner_fields = Fields::from(vec![Field::new("inner", DataType::Int32, true)]); + let inner_struct_type = DataType::Struct(inner_fields); + let outer_fields = Fields::from(vec![Field::new("outer", inner_struct_type, true)]); + let outer_struct_type = DataType::Struct(outer_fields); + + let options = GetOptions { + path: VariantPath::default(), + as_type: Some(Arc::new(Field::new("result", outer_struct_type, true))), + cast_options: CastOptions::default(), + }; + + let variant_array_ref: Arc = Arc::new(variant_array); + let result = variant_get(&variant_array_ref, options); + + // Should fail with NotYetImplemented when the row builder tries to handle struct type + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.to_string().contains("not yet implemented")); + } + + /// Create comprehensive shredded variant with diverse null patterns and empty objects + /// Rows: normal values, top-level null, missing field a, missing field b, empty object + fn create_comprehensive_shredded_variant() -> ArrayRef { + let (metadata, _) = { + let mut builder = parquet_variant::VariantBuilder::new(); + let obj = builder.new_object(); + obj.finish(); + builder.finish() + }; + + // Create null buffer for top-level nulls + let nulls = NullBuffer::from(vec![ + true, // row 0: normal values + false, // row 1: top-level null + true, // row 2: missing field a + true, // row 3: missing field b + true, // row 4: empty object + ]); + + let metadata_array = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 5)); + + // Create shredded fields with different null patterns + // Field "a": present in rows 0,3 (missing in rows 1,2,4) + let a_field_typed_value = Int32Array::from(vec![Some(1), None, None, Some(1), None]); + let a_field_struct = StructArrayBuilder::new() + .with_field("typed_value", Arc::new(a_field_typed_value), true) + .build(); + let a_field_shredded = ShreddedVariantFieldArray::try_new(Arc::new(a_field_struct)) + .expect("should create ShreddedVariantFieldArray for a"); + + // Field "b": present in rows 0,2 (missing in rows 1,3,4) + let b_field_typed_value = Int32Array::from(vec![Some(2), None, Some(2), None, None]); + let b_field_struct = StructArrayBuilder::new() + .with_field("typed_value", Arc::new(b_field_typed_value), true) + .build(); + let b_field_shredded = ShreddedVariantFieldArray::try_new(Arc::new(b_field_struct)) + .expect("should create 
ShreddedVariantFieldArray for b"); + + // Field "c": present in row 0 only (missing in all other rows) + let c_field_typed_value = Int32Array::from(vec![Some(3), None, None, None, None]); + let c_field_struct = StructArrayBuilder::new() + .with_field("typed_value", Arc::new(c_field_typed_value), true) + .build(); + let c_field_shredded = ShreddedVariantFieldArray::try_new(Arc::new(c_field_struct)) + .expect("should create ShreddedVariantFieldArray for c"); + + // Create main typed_value struct + let typed_value_fields = Fields::from(vec![ + Field::new("a", a_field_shredded.data_type().clone(), true), + Field::new("b", b_field_shredded.data_type().clone(), true), + Field::new("c", c_field_shredded.data_type().clone(), true), + ]); + let typed_value_struct = StructArray::try_new( + typed_value_fields, + vec![ + Arc::new(a_field_shredded), + Arc::new(b_field_shredded), + Arc::new(c_field_shredded), + ], + None, + ) + .unwrap(); + + // Build final VariantArray with top-level nulls + let struct_array = StructArrayBuilder::new() + .with_field("metadata", Arc::new(metadata_array), false) + .with_field("typed_value", Arc::new(typed_value_struct), true) + .with_nulls(nulls) + .build(); + + Arc::new(VariantArray::try_new(Arc::new(struct_array)).expect("should create VariantArray")) + } + + /// Create comprehensive nested shredded variant with diverse null patterns + /// Represents 3-level structure: variant -> outer -> inner (INT value) + /// The shredding schema is: {"metadata": BINARY, "typed_value": {"outer": {"typed_value": {"inner": {"typed_value": INT}}}}} + /// Rows: normal nested value, inner field null, outer field null, top-level null + fn create_comprehensive_nested_shredded_variant() -> ArrayRef { + // Create the inner level: contains typed_value with Int32 values + // Row 0: has value 42, Row 1: inner null, Row 2: outer null, Row 3: top-level null + let inner_typed_value = Int32Array::from(vec![Some(42), None, None, None]); // dummy value for row 2 + let inner = StructArrayBuilder::new() + .with_field("typed_value", Arc::new(inner_typed_value), true) + .build(); + let inner = ShreddedVariantFieldArray::try_new(Arc::new(inner)).unwrap(); + + let outer_typed_value_nulls = NullBuffer::from(vec![ + true, // row 0: inner struct exists with typed_value=42 + false, // row 1: inner field NULL + false, // row 2: outer field NULL + false, // row 3: top-level NULL + ]); + let outer_typed_value = StructArrayBuilder::new() + .with_field("inner", Arc::new(inner), false) + .with_nulls(outer_typed_value_nulls) + .build(); + + let outer = StructArrayBuilder::new() + .with_field("typed_value", Arc::new(outer_typed_value), true) + .build(); + let outer = ShreddedVariantFieldArray::try_new(Arc::new(outer)).unwrap(); + + let typed_value_nulls = NullBuffer::from(vec![ + true, // row 0: inner struct exists with typed_value=42 + true, // row 1: inner field NULL + false, // row 2: outer field NULL + false, // row 3: top-level NULL + ]); + let typed_value = StructArrayBuilder::new() + .with_field("outer", Arc::new(outer), false) + .with_nulls(typed_value_nulls) + .build(); + + // Build final VariantArray with top-level nulls + let metadata_array = + BinaryViewArray::from_iter_values(std::iter::repeat_n(EMPTY_VARIANT_METADATA_BYTES, 4)); + let nulls = NullBuffer::from(vec![ + true, // row 0: inner struct exists with typed_value=42 + true, // row 1: inner field NULL + true, // row 2: outer field NULL + false, // row 3: top-level NULL + ]); + let struct_array = StructArrayBuilder::new() + .with_field("metadata", 
Arc::new(metadata_array), false) + .with_field("typed_value", Arc::new(typed_value), true) + .with_nulls(nulls) + .build(); + + Arc::new(VariantArray::try_new(Arc::new(struct_array)).expect("should create VariantArray")) + } + + /// Create variant with mixed shredding (spec-compliant) including null scenarios + /// Field "x" is globally shredded, field "y" is never shredded + fn create_mixed_and_unshredded_variant() -> ArrayRef { + // Create spec-compliant mixed shredding: + // - Field "x" is globally shredded (has typed_value column) + // - Field "y" is never shredded (only appears in value field when present) + + let (metadata, y_field_value) = { + let mut builder = parquet_variant::VariantBuilder::new(); + let mut obj = builder.new_object(); + obj.insert("y", Variant::from(42)); + obj.finish(); + builder.finish() + }; + + let metadata_array = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 4)); + + // Value field contains objects with unshredded fields only (never contains "x") + // Row 0: {"y": "foo"} - x is shredded out, y remains in value + // Row 1: {} - both x and y are absent (perfect shredding for x, y missing) + // Row 2: {"y": null} - x is shredded out, y explicitly null + // Row 3: top-level null (encoded in VariantArray's null mask, but fields contain valid data) + + let empty_object_value = { + let mut builder = parquet_variant::VariantBuilder::new(); + builder.new_object().finish(); + let (_, value) = builder.finish(); + value + }; + + let y_null_value = { + let mut builder = parquet_variant::VariantBuilder::new(); + builder.new_object().with_field("y", Variant::Null).finish(); + let (_, value) = builder.finish(); + value + }; + + let value_array = BinaryViewArray::from(vec![ + Some(y_field_value.as_slice()), // Row 0: {"y": 42} + Some(empty_object_value.as_slice()), // Row 1: {} + Some(y_null_value.as_slice()), // Row 2: {"y": null} + Some(empty_object_value.as_slice()), // Row 3: top-level null (but value field contains valid data) + ]); + + // Create shredded field "x" (globally shredded - never appears in value field) + // For top-level null row, the field still needs valid content (not null) + let x_field_typed_value = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(0)]); + let x_field_struct = StructArrayBuilder::new() + .with_field("typed_value", Arc::new(x_field_typed_value), true) + .build(); + let x_field_shredded = ShreddedVariantFieldArray::try_new(Arc::new(x_field_struct)) + .expect("should create ShreddedVariantFieldArray for x"); + + // Create main typed_value struct (only contains shredded fields) + let typed_value_struct = StructArrayBuilder::new() + .with_field("x", Arc::new(x_field_shredded), false) + .build(); + + // Build VariantArray with both value and typed_value (PartiallyShredded) + // Top-level null is encoded in the main StructArray's null mask + let variant_nulls = NullBuffer::from(vec![true, true, true, false]); // Row 3 is top-level null + let struct_array = StructArrayBuilder::new() + .with_field("metadata", Arc::new(metadata_array), false) + .with_field("value", Arc::new(value_array), true) + .with_field("typed_value", Arc::new(typed_value_struct), true) + .with_nulls(variant_nulls) + .build(); + + Arc::new(VariantArray::try_new(Arc::new(struct_array)).expect("should create VariantArray")) + } +} diff --git a/parquet-variant-compute/src/variant_get/mod.rs b/parquet-variant-compute/src/variant_get/mod.rs deleted file mode 100644 index cc852bbc32a2..000000000000 --- a/parquet-variant-compute/src/variant_get/mod.rs +++ 
/dev/null @@ -1,430 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -use arrow::{ - array::{Array, ArrayRef}, - compute::CastOptions, - error::Result, -}; -use arrow_schema::{ArrowError, FieldRef}; -use parquet_variant::VariantPath; - -use crate::variant_array::ShreddingState; -use crate::variant_get::output::instantiate_output_builder; -use crate::VariantArray; - -mod output; - -/// Returns an array with the specified path extracted from the variant values. -/// -/// The return array type depends on the `as_type` field of the options parameter -/// 1. `as_type: None`: a VariantArray is returned. The values in this new VariantArray will point -/// to the specified path. -/// 2. `as_type: Some()`: an array of the specified type is returned. -pub fn variant_get(input: &ArrayRef, options: GetOptions) -> Result { - let variant_array: &VariantArray = input.as_any().downcast_ref().ok_or_else(|| { - ArrowError::InvalidArgumentError( - "expected a VariantArray as the input for variant_get".to_owned(), - ) - })?; - - // Create the output writer based on the specified output options - let output_builder = instantiate_output_builder(options.clone())?; - - // Dispatch based on the shredding state of the input variant array - match variant_array.shredding_state() { - ShreddingState::PartiallyShredded { - metadata, - value, - typed_value, - } => output_builder.partially_shredded(variant_array, metadata, value, typed_value), - ShreddingState::Typed { - metadata, - typed_value, - } => output_builder.typed(variant_array, metadata, typed_value), - ShreddingState::Unshredded { metadata, value } => { - output_builder.unshredded(variant_array, metadata, value) - } - } -} - -/// Controls the action of the variant_get kernel. -#[derive(Debug, Clone, Default)] -pub struct GetOptions<'a> { - /// What path to extract - pub path: VariantPath<'a>, - /// if `as_type` is None, the returned array will itself be a VariantArray. - /// - /// if `as_type` is `Some(type)` the field is returned as the specified type. - pub as_type: Option, - /// Controls the casting behavior (e.g. error vs substituting null on cast error). - pub cast_options: CastOptions<'a>, -} - -impl<'a> GetOptions<'a> { - /// Construct default options to get the specified path as a variant. - pub fn new() -> Self { - Default::default() - } - - /// Construct options to get the specified path as a variant. - pub fn new_with_path(path: VariantPath<'a>) -> Self { - Self { - path, - as_type: None, - cast_options: Default::default(), - } - } - - /// Specify the type to return. - pub fn with_as_type(mut self, as_type: Option) -> Self { - self.as_type = as_type; - self - } - - /// Specify the cast options to use when casting to the specified type. 
- pub fn with_cast_options(mut self, cast_options: CastOptions<'a>) -> Self { - self.cast_options = cast_options; - self - } -} - -#[cfg(test)] -mod test { - use std::sync::Arc; - - use arrow::array::{Array, ArrayRef, BinaryViewArray, Int32Array, StringArray, StructArray}; - use arrow::buffer::NullBuffer; - use arrow::compute::CastOptions; - use arrow_schema::{DataType, Field, FieldRef, Fields}; - use parquet_variant::{Variant, VariantPath}; - - use crate::batch_json_string_to_variant; - use crate::VariantArray; - - use super::{variant_get, GetOptions}; - - fn single_variant_get_test(input_json: &str, path: VariantPath, expected_json: &str) { - // Create input array from JSON string - let input_array_ref: ArrayRef = Arc::new(StringArray::from(vec![Some(input_json)])); - let input_variant_array_ref: ArrayRef = - Arc::new(batch_json_string_to_variant(&input_array_ref).unwrap()); - - let result = - variant_get(&input_variant_array_ref, GetOptions::new_with_path(path)).unwrap(); - - // Create expected array from JSON string - let expected_array_ref: ArrayRef = Arc::new(StringArray::from(vec![Some(expected_json)])); - let expected_variant_array = batch_json_string_to_variant(&expected_array_ref).unwrap(); - - let result_array: &VariantArray = result.as_any().downcast_ref().unwrap(); - assert_eq!( - result_array.len(), - 1, - "Expected result array to have length 1" - ); - assert!( - result_array.nulls().is_none(), - "Expected no nulls in result array" - ); - let result_variant = result_array.value(0); - let expected_variant = expected_variant_array.value(0); - assert_eq!( - result_variant, expected_variant, - "Result variant does not match expected variant" - ); - } - - #[test] - fn get_primitive_variant_field() { - single_variant_get_test( - r#"{"some_field": 1234}"#, - VariantPath::from("some_field"), - "1234", - ); - } - - #[test] - fn get_primitive_variant_list_index() { - single_variant_get_test("[1234, 5678]", VariantPath::from(0), "1234"); - } - - #[test] - fn get_primitive_variant_inside_object_of_object() { - single_variant_get_test( - r#"{"top_level_field": {"inner_field": 1234}}"#, - VariantPath::from("top_level_field").join("inner_field"), - "1234", - ); - } - - #[test] - fn get_primitive_variant_inside_list_of_object() { - single_variant_get_test( - r#"[{"some_field": 1234}]"#, - VariantPath::from(0).join("some_field"), - "1234", - ); - } - - #[test] - fn get_primitive_variant_inside_object_of_list() { - single_variant_get_test( - r#"{"some_field": [1234]}"#, - VariantPath::from("some_field").join(0), - "1234", - ); - } - - #[test] - fn get_complex_variant() { - single_variant_get_test( - r#"{"top_level_field": {"inner_field": 1234}}"#, - VariantPath::from("top_level_field"), - r#"{"inner_field": 1234}"#, - ); - } - - /// Shredding: extract a value as a VariantArray - #[test] - fn get_variant_shredded_int32_as_variant() { - let array = shredded_int32_variant_array(); - let options = GetOptions::new(); - let result = variant_get(&array, options).unwrap(); - - // expect the result is a VariantArray - let result: &VariantArray = result.as_any().downcast_ref().unwrap(); - assert_eq!(result.len(), 4); - - // Expect the values are the same as the original values - assert_eq!(result.value(0), Variant::Int32(34)); - assert!(!result.is_valid(1)); - assert_eq!(result.value(2), Variant::from("n/a")); - assert_eq!(result.value(3), Variant::Int32(100)); - } - - /// Shredding: extract a value as an Int32Array - #[test] - fn get_variant_shredded_int32_as_int32_safe_cast() { - // Extract the typed 
value as Int32Array - let array = shredded_int32_variant_array(); - // specify we want the typed value as Int32 - let field = Field::new("typed_value", DataType::Int32, true); - let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); - let result = variant_get(&array, options).unwrap(); - let expected: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(34), - None, - None, // "n/a" is not an Int32 so converted to null - Some(100), - ])); - assert_eq!(&result, &expected) - } - - /// Shredding: extract a value as an Int32Array, unsafe cast (should error on "n/a") - - #[test] - fn get_variant_shredded_int32_as_int32_unsafe_cast() { - // Extract the typed value as Int32Array - let array = shredded_int32_variant_array(); - let field = Field::new("typed_value", DataType::Int32, true); - let cast_options = CastOptions { - safe: false, // unsafe cast - ..Default::default() - }; - let options = GetOptions::new() - .with_as_type(Some(FieldRef::from(field))) - .with_cast_options(cast_options); - - let err = variant_get(&array, options).unwrap_err(); - // TODO make this error message nicer (not Debug format) - assert_eq!(err.to_string(), "Cast error: Failed to extract primitive of type Int32 from variant ShortString(ShortString(\"n/a\")) at path VariantPath([])"); - } - - /// Perfect Shredding: extract the typed value as a VariantArray - #[test] - fn get_variant_perfectly_shredded_int32_as_variant() { - let array = perfectly_shredded_int32_variant_array(); - let options = GetOptions::new(); - let result = variant_get(&array, options).unwrap(); - - // expect the result is a VariantArray - let result: &VariantArray = result.as_any().downcast_ref().unwrap(); - assert_eq!(result.len(), 3); - - // Expect the values are the same as the original values - assert_eq!(result.value(0), Variant::Int32(1)); - assert_eq!(result.value(1), Variant::Int32(2)); - assert_eq!(result.value(2), Variant::Int32(3)); - } - - /// Shredding: Extract the typed value as Int32Array - #[test] - fn get_variant_perfectly_shredded_int32_as_int32() { - // Extract the typed value as Int32Array - let array = perfectly_shredded_int32_variant_array(); - // specify we want the typed value as Int32 - let field = Field::new("typed_value", DataType::Int32, true); - let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); - let result = variant_get(&array, options).unwrap(); - let expected: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])); - assert_eq!(&result, &expected) - } - - /// Return a VariantArray that represents a perfectly "shredded" variant - /// for the following example (3 Variant::Int32 values): - /// - /// ```text - /// 1 - /// 2 - /// 3 - /// ``` - /// - /// The schema of the corresponding `StructArray` would look like this: - /// - /// ```text - /// StructArray { - /// metadata: BinaryViewArray, - /// typed_value: Int32Array, - /// } - /// ``` - fn perfectly_shredded_int32_variant_array() -> ArrayRef { - // At the time of writing, the `VariantArrayBuilder` does not support shredding. - // so we must construct the array manually. 
see https://github.com/apache/arrow-rs/issues/7895 - let (metadata, _value) = { parquet_variant::VariantBuilder::new().finish() }; - - let metadata = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 3)); - let typed_value = Int32Array::from(vec![Some(1), Some(2), Some(3)]); - - let struct_array = StructArrayBuilder::new() - .with_field("metadata", Arc::new(metadata)) - .with_field("typed_value", Arc::new(typed_value)) - .build(); - - Arc::new( - VariantArray::try_new(Arc::new(struct_array)).expect("should create variant array"), - ) - } - - /// Return a VariantArray that represents a normal "shredded" variant - /// for the following example - /// - /// Based on the example from [the doc] - /// - /// [the doc]: https://docs.google.com/document/d/1pw0AWoMQY3SjD7R4LgbPvMjG_xSCtXp3rZHkVp9jpZ4/edit?tab=t.0 - /// - /// ```text - /// 34 - /// null (an Arrow NULL, not a Variant::Null) - /// "n/a" (a string) - /// 100 - /// ``` - /// - /// The schema of the corresponding `StructArray` would look like this: - /// - /// ```text - /// StructArray { - /// metadata: BinaryViewArray, - /// value: BinaryViewArray, - /// typed_value: Int32Array, - /// } - /// ``` - fn shredded_int32_variant_array() -> ArrayRef { - // At the time of writing, the `VariantArrayBuilder` does not support shredding. - // so we must construct the array manually. see https://github.com/apache/arrow-rs/issues/7895 - let (metadata, string_value) = { - let mut builder = parquet_variant::VariantBuilder::new(); - builder.append_value("n/a"); - builder.finish() - }; - - let nulls = NullBuffer::from(vec![ - true, // row 0 non null - false, // row 1 is null - true, // row 2 non null - true, // row 3 non null - ]); - - // metadata is the same for all rows - let metadata = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 4)); - - // See https://docs.google.com/document/d/1pw0AWoMQY3SjD7R4LgbPvMjG_xSCtXp3rZHkVp9jpZ4/edit?disco=AAABml8WQrY - // about why row1 is an empty but non null, value. - let values = BinaryViewArray::from(vec![ - None, // row 0 is shredded, so no value - Some(b"" as &[u8]), // row 1 is null, so empty value (why?) - Some(&string_value), // copy the string value "N/A" - None, // row 3 is shredded, so no value - ]); - - let typed_value = Int32Array::from(vec![ - Some(34), // row 0 is shredded, so it has a value - None, // row 1 is null, so no value - None, // row 2 is a string, so no typed value - Some(100), // row 3 is shredded, so it has a value - ]); - - let struct_array = StructArrayBuilder::new() - .with_field("metadata", Arc::new(metadata)) - .with_field("typed_value", Arc::new(typed_value)) - .with_field("value", Arc::new(values)) - .with_nulls(nulls) - .build(); - - Arc::new( - VariantArray::try_new(Arc::new(struct_array)).expect("should create variant array"), - ) - } - - /// Builds struct arrays from component fields - /// - /// TODO: move to arrow crate - #[derive(Debug, Default, Clone)] - struct StructArrayBuilder { - fields: Vec, - arrays: Vec, - nulls: Option, - } - - impl StructArrayBuilder { - fn new() -> Self { - Default::default() - } - - /// Add an array to this struct array as a field with the specified name. - fn with_field(mut self, field_name: &str, array: ArrayRef) -> Self { - let field = Field::new(field_name, array.data_type().clone(), true); - self.fields.push(Arc::new(field)); - self.arrays.push(array); - self - } - - /// Set the null buffer for this struct array. 
- fn with_nulls(mut self, nulls: NullBuffer) -> Self { - self.nulls = Some(nulls); - self - } - - pub fn build(self) -> StructArray { - let Self { - fields, - arrays, - nulls, - } = self; - StructArray::new(Fields::from(fields), arrays, nulls) - } - } -} diff --git a/parquet-variant-compute/src/variant_get/output/mod.rs b/parquet-variant-compute/src/variant_get/output/mod.rs deleted file mode 100644 index 245d73cce8db..000000000000 --- a/parquet-variant-compute/src/variant_get/output/mod.rs +++ /dev/null @@ -1,87 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -mod primitive; -mod variant; - -use crate::variant_get::output::primitive::PrimitiveOutputBuilder; -use crate::variant_get::output::variant::VariantOutputBuilder; -use crate::variant_get::GetOptions; -use crate::VariantArray; -use arrow::array::{ArrayRef, BinaryViewArray}; -use arrow::datatypes::Int32Type; -use arrow::error::Result; -use arrow_schema::{ArrowError, DataType}; - -/// This trait represents something that gets the output of the variant_get kernel. -/// -/// For example, there are specializations for writing the output as a VariantArray, -/// or as a specific type (e.g. Int32Array). -/// -/// See [`instantiate_output_builder`] to create an instance of this trait. 
-pub(crate) trait OutputBuilder { - /// create output for a shredded variant array - fn partially_shredded( - &self, - variant_array: &VariantArray, - metadata: &BinaryViewArray, - value_field: &BinaryViewArray, - typed_value: &ArrayRef, - ) -> Result; - - /// output for a perfectly shredded variant array - fn typed( - &self, - variant_array: &VariantArray, - metadata: &BinaryViewArray, - typed_value: &ArrayRef, - ) -> Result; - - /// write out an unshredded variant array - fn unshredded( - &self, - variant_array: &VariantArray, - metadata: &BinaryViewArray, - value_field: &BinaryViewArray, - ) -> Result; -} - -pub(crate) fn instantiate_output_builder<'a>( - options: GetOptions<'a>, -) -> Result> { - let GetOptions { - as_type, - path, - cast_options, - } = options; - - let Some(as_type) = as_type else { - return Ok(Box::new(VariantOutputBuilder::new(path))); - }; - - // handle typed output - match as_type.data_type() { - DataType::Int32 => Ok(Box::new(PrimitiveOutputBuilder::::new( - path, - as_type, - cast_options, - ))), - dt => Err(ArrowError::NotYetImplemented(format!( - "variant_get with as_type={dt} is not implemented yet", - ))), - } -} diff --git a/parquet-variant-compute/src/variant_get/output/primitive.rs b/parquet-variant-compute/src/variant_get/output/primitive.rs deleted file mode 100644 index 36e4221e3242..000000000000 --- a/parquet-variant-compute/src/variant_get/output/primitive.rs +++ /dev/null @@ -1,166 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::variant_get::output::OutputBuilder; -use crate::VariantArray; -use arrow::error::Result; - -use arrow::array::{ - Array, ArrayRef, ArrowPrimitiveType, AsArray, BinaryViewArray, NullBufferBuilder, - PrimitiveArray, -}; -use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::Int32Type; -use arrow_schema::{ArrowError, FieldRef}; -use parquet_variant::{Variant, VariantPath}; -use std::marker::PhantomData; -use std::sync::Arc; - -/// Trait for Arrow primitive types that can be used in the output builder -/// -/// This just exists to add a generic way to convert from Variant to the primitive type -pub(super) trait ArrowPrimitiveVariant: ArrowPrimitiveType { - /// Try to extract the primitive value from a Variant, returning None if it - /// cannot be converted - /// - /// TODO: figure out how to handle coercion/casting - fn from_variant(variant: &Variant) -> Option; -} - -/// Outputs Primitive arrays -pub(super) struct PrimitiveOutputBuilder<'a, T: ArrowPrimitiveVariant> { - /// What path to extract - path: VariantPath<'a>, - /// Returned output type - as_type: FieldRef, - /// Controls the casting behavior (e.g. error vs substituting null on cast error). 
- cast_options: CastOptions<'a>, - /// Phantom data for the primitive type - _phantom: PhantomData, -} - -impl<'a, T: ArrowPrimitiveVariant> PrimitiveOutputBuilder<'a, T> { - pub(super) fn new( - path: VariantPath<'a>, - as_type: FieldRef, - cast_options: CastOptions<'a>, - ) -> Self { - Self { - path, - as_type, - cast_options, - _phantom: PhantomData, - } - } -} - -impl<'a, T: ArrowPrimitiveVariant> OutputBuilder for PrimitiveOutputBuilder<'a, T> { - fn partially_shredded( - &self, - variant_array: &VariantArray, - _metadata: &BinaryViewArray, - _value_field: &BinaryViewArray, - typed_value: &ArrayRef, - ) -> arrow::error::Result { - // build up the output array element by element - let mut nulls = NullBufferBuilder::new(variant_array.len()); - let mut values = Vec::with_capacity(variant_array.len()); - let typed_value = - cast_with_options(typed_value, self.as_type.data_type(), &self.cast_options)?; - // downcast to the primitive array (e.g. Int32Array, Float64Array, etc) - let typed_value = typed_value.as_primitive::(); - - for i in 0..variant_array.len() { - if variant_array.is_null(i) { - nulls.append_null(); - values.push(T::default_value()); // not used, placeholder - continue; - } - - // if the typed value is null, decode the variant and extract the value - if typed_value.is_null(i) { - // todo follow path - let variant = variant_array.value(i); - let Some(value) = T::from_variant(&variant) else { - if self.cast_options.safe { - // safe mode: append null if we can't convert - nulls.append_null(); - values.push(T::default_value()); // not used, placeholder - continue; - } else { - return Err(ArrowError::CastError(format!( - "Failed to extract primitive of type {} from variant {:?} at path {:?}", - self.as_type.data_type(), - variant, - self.path - ))); - } - }; - - nulls.append_non_null(); - values.push(value) - } else { - // otherwise we have a typed value, so we can use it directly - nulls.append_non_null(); - values.push(typed_value.value(i)); - } - } - - let nulls = nulls.finish(); - let array = PrimitiveArray::::new(values.into(), nulls) - .with_data_type(self.as_type.data_type().clone()); - Ok(Arc::new(array)) - } - - fn typed( - &self, - _variant_array: &VariantArray, - _metadata: &BinaryViewArray, - typed_value: &ArrayRef, - ) -> arrow::error::Result { - // if the types match exactly, we can just return the typed_value - if typed_value.data_type() == self.as_type.data_type() { - Ok(typed_value.clone()) - } else { - // TODO: try to cast the typed_value to the desired type? 
- Err(ArrowError::NotYetImplemented(format!( - "variant_get fully_shredded as {:?} with typed_value={:?} is not implemented yet", - self.as_type.data_type(), - typed_value.data_type() - ))) - } - } - - fn unshredded( - &self, - _variant_array: &VariantArray, - _metadata: &BinaryViewArray, - _value_field: &BinaryViewArray, - ) -> Result { - Err(ArrowError::NotYetImplemented(String::from( - "variant_get unshredded to primitive types is not implemented yet", - ))) - } -} - -impl ArrowPrimitiveVariant for Int32Type { - fn from_variant(variant: &Variant) -> Option { - variant.as_int32() - } -} - -// todo for other primitive types diff --git a/parquet-variant-compute/src/variant_get/output/variant.rs b/parquet-variant-compute/src/variant_get/output/variant.rs deleted file mode 100644 index 2c04111a5306..000000000000 --- a/parquet-variant-compute/src/variant_get/output/variant.rs +++ /dev/null @@ -1,146 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::variant_get::output::OutputBuilder; -use crate::{VariantArray, VariantArrayBuilder}; -use arrow::array::{Array, ArrayRef, AsArray, BinaryViewArray}; -use arrow::datatypes::Int32Type; -use arrow_schema::{ArrowError, DataType}; -use parquet_variant::{Variant, VariantPath}; -use std::sync::Arc; - -/// Outputs VariantArrays -pub(super) struct VariantOutputBuilder<'a> { - /// What path to extract - path: VariantPath<'a>, -} - -impl<'a> VariantOutputBuilder<'a> { - pub(super) fn new(path: VariantPath<'a>) -> Self { - Self { path } - } -} - -impl<'a> OutputBuilder for VariantOutputBuilder<'a> { - fn partially_shredded( - &self, - variant_array: &VariantArray, - // TODO(perf): can reuse the metadata field here to avoid re-creating it - _metadata: &BinaryViewArray, - _value_field: &BinaryViewArray, - typed_value: &ArrayRef, - ) -> arrow::error::Result { - // in this case dispatch on the typed_value and - // TODO macro'ize this using downcast! 
to handle all other primitive types - // TODO(perf): avoid builders entirely (and write the raw variant directly as we know the metadata is the same) - let mut array_builder = VariantArrayBuilder::new(variant_array.len()); - match typed_value.data_type() { - DataType::Int32 => { - let primitive_array = typed_value.as_primitive::(); - for i in 0..variant_array.len() { - if variant_array.is_null(i) { - array_builder.append_null(); - continue; - } - - if typed_value.is_null(i) { - // fall back to the value (variant) field - // (TODO could copy the variant bytes directly) - let value = variant_array.value(i); - array_builder.append_variant(value); - continue; - } - - // otherwise we have a typed value, so we can use it directly - let int_value = primitive_array.value(i); - array_builder.append_variant(Variant::from(int_value)); - } - } - dt => { - return Err(ArrowError::NotYetImplemented(format!( - "variant_get fully_shredded with typed_value={dt} is not implemented yet", - ))); - } - }; - Ok(Arc::new(array_builder.build())) - } - - fn typed( - &self, - variant_array: &VariantArray, - // TODO(perf): can reuse the metadata field here to avoid re-creating it - _metadata: &BinaryViewArray, - typed_value: &ArrayRef, - ) -> arrow::error::Result { - // in this case dispatch on the typed_value and - // TODO macro'ize this using downcast! to handle all other primitive types - // TODO(perf): avoid builders entirely (and write the raw variant directly as we know the metadata is the same) - let mut array_builder = VariantArrayBuilder::new(variant_array.len()); - match typed_value.data_type() { - DataType::Int32 => { - let primitive_array = typed_value.as_primitive::(); - for i in 0..variant_array.len() { - if primitive_array.is_null(i) { - array_builder.append_null(); - continue; - } - - let int_value = primitive_array.value(i); - array_builder.append_variant(Variant::from(int_value)); - } - } - dt => { - return Err(ArrowError::NotYetImplemented(format!( - "variant_get fully_shredded with typed_value={dt} is not implemented yet", - ))); - } - }; - Ok(Arc::new(array_builder.build())) - } - - fn unshredded( - &self, - variant_array: &VariantArray, - _metadata: &BinaryViewArray, - _value_field: &BinaryViewArray, - ) -> arrow::error::Result { - let mut builder = VariantArrayBuilder::new(variant_array.len()); - for i in 0..variant_array.len() { - let new_variant = variant_array.value(i); - - // TODO: perf? - let Some(new_variant) = new_variant.get_path(&self.path) else { - // path not found, append null - builder.append_null(); - continue; - }; - - // TODO: we're decoding the value and doing a copy into a variant value - // again. This can be much faster by using the _metadata and _value_field - // to avoid decoding the entire variant: - // - // 1) reuse the metadata arrays as is - // - // 2) Create a new BinaryViewArray that uses the same underlying buffers - // that the original variant used, but whose views points to a new - // offset for the new path - builder.append_variant(new_variant); - } - - Ok(Arc::new(builder.build())) - } -} diff --git a/parquet-variant-compute/src/variant_to_arrow.rs b/parquet-variant-compute/src/variant_to_arrow.rs new file mode 100644 index 000000000000..115a6a42bebb --- /dev/null +++ b/parquet-variant-compute/src/variant_to_arrow.rs @@ -0,0 +1,313 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, PrimitiveBuilder}; +use arrow::compute::CastOptions; +use arrow::datatypes::{self, ArrowPrimitiveType, DataType}; +use arrow::error::{ArrowError, Result}; +use parquet_variant::{Variant, VariantPath}; + +use crate::type_conversion::VariantAsPrimitive; +use crate::VariantArrayBuilder; + +use std::sync::Arc; + +/// Builder for converting variant values into strongly typed Arrow arrays. +/// +/// Useful for variant_get kernels that need to extract specific paths from variant values, possibly +/// with casting of leaf values to specific types. +pub(crate) enum VariantToArrowRowBuilder<'a> { + // Direct builders (no path extraction) + Int8(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Int8Type>), + Int16(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Int16Type>), + Int32(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Int32Type>), + Int64(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Int64Type>), + Float16(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float16Type>), + Float32(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float32Type>), + Float64(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float64Type>), + UInt8(VariantToPrimitiveArrowRowBuilder<'a, datatypes::UInt8Type>), + UInt16(VariantToPrimitiveArrowRowBuilder<'a, datatypes::UInt16Type>), + UInt32(VariantToPrimitiveArrowRowBuilder<'a, datatypes::UInt32Type>), + UInt64(VariantToPrimitiveArrowRowBuilder<'a, datatypes::UInt64Type>), + BinaryVariant(VariantToBinaryVariantArrowRowBuilder), + + // Path extraction wrapper - contains a boxed enum for any of the above + WithPath(VariantPathRowBuilder<'a>), +} + +impl<'a> VariantToArrowRowBuilder<'a> { + pub fn append_null(&mut self) -> Result<()> { + use VariantToArrowRowBuilder::*; + match self { + Int8(b) => b.append_null(), + Int16(b) => b.append_null(), + Int32(b) => b.append_null(), + Int64(b) => b.append_null(), + UInt8(b) => b.append_null(), + UInt16(b) => b.append_null(), + UInt32(b) => b.append_null(), + UInt64(b) => b.append_null(), + Float16(b) => b.append_null(), + Float32(b) => b.append_null(), + Float64(b) => b.append_null(), + BinaryVariant(b) => b.append_null(), + WithPath(path_builder) => path_builder.append_null(), + } + } + + pub fn append_value(&mut self, value: &Variant<'_, '_>) -> Result { + use VariantToArrowRowBuilder::*; + match self { + Int8(b) => b.append_value(value), + Int16(b) => b.append_value(value), + Int32(b) => b.append_value(value), + Int64(b) => b.append_value(value), + UInt8(b) => b.append_value(value), + UInt16(b) => b.append_value(value), + UInt32(b) => b.append_value(value), + UInt64(b) => b.append_value(value), + Float16(b) => b.append_value(value), + Float32(b) => b.append_value(value), + Float64(b) => b.append_value(value), + BinaryVariant(b) => b.append_value(value), + WithPath(path_builder) => path_builder.append_value(value), + } + } + + pub fn finish(self) -> Result { + use VariantToArrowRowBuilder::*; + match self { + 
Int8(b) => b.finish(), + Int16(b) => b.finish(), + Int32(b) => b.finish(), + Int64(b) => b.finish(), + UInt8(b) => b.finish(), + UInt16(b) => b.finish(), + UInt32(b) => b.finish(), + UInt64(b) => b.finish(), + Float16(b) => b.finish(), + Float32(b) => b.finish(), + Float64(b) => b.finish(), + BinaryVariant(b) => b.finish(), + WithPath(path_builder) => path_builder.finish(), + } + } +} + +pub(crate) fn make_variant_to_arrow_row_builder<'a>( + //metadata: &BinaryViewArray, + path: VariantPath<'a>, + data_type: Option<&'a DataType>, + cast_options: &'a CastOptions, + capacity: usize, +) -> Result<VariantToArrowRowBuilder<'a>> { + use VariantToArrowRowBuilder::*; + + let mut builder = match data_type { + // If no data type was requested, build an unshredded VariantArray. + None => BinaryVariant(VariantToBinaryVariantArrowRowBuilder::new(capacity)), + Some(DataType::Int8) => Int8(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + Some(DataType::Int16) => Int16(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + Some(DataType::Int32) => Int32(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + Some(DataType::Int64) => Int64(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + Some(DataType::Float16) => Float16(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + Some(DataType::Float32) => Float32(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + Some(DataType::Float64) => Float64(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + Some(DataType::UInt8) => UInt8(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + Some(DataType::UInt16) => UInt16(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + Some(DataType::UInt32) => UInt32(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + Some(DataType::UInt64) => UInt64(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + _ => { + return Err(ArrowError::NotYetImplemented(format!( + "variant_get with path={:?} and data_type={:?} not yet implemented", + path, data_type + ))); + } + }; + + // Wrap with path extraction if needed + if !path.is_empty() { + builder = WithPath(VariantPathRowBuilder { + builder: Box::new(builder), + path, + }) + }; + + Ok(builder) +} + +/// A thin wrapper whose only job is to extract a specific path from a variant value and pass the +/// result to a nested builder.
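For orientation before the wrapper type itself: a minimal sketch, not part of this patch, of how a crate-internal caller (for example the `variant_get` kernel) might drive `make_variant_to_arrow_row_builder`. The helper name `variants_to_int32` and its signature are illustrative assumptions; only the builder calls mirror the API added above.

```rust
use arrow::array::ArrayRef;
use arrow::compute::CastOptions;
use arrow::datatypes::DataType;
use arrow::error::Result;
use parquet_variant::{Variant, VariantPath};

/// Illustrative only: turn a slice of optional variants into an Int32 array,
/// extracting `path` from each value before conversion.
fn variants_to_int32(
    variants: &[Option<Variant<'_, '_>>],
    path: VariantPath<'_>,
) -> Result<ArrayRef> {
    // `safe: true` (the default) maps failed conversions to nulls instead of errors.
    let cast_options = CastOptions::default();
    let mut builder = make_variant_to_arrow_row_builder(
        path,
        Some(&DataType::Int32),
        &cast_options,
        variants.len(),
    )?;
    for variant in variants {
        match variant {
            // The returned bool reports whether the conversion succeeded for this row.
            Some(v) => {
                builder.append_value(v)?;
            }
            None => builder.append_null()?,
        }
    }
    // Consumes the builder and yields the finished array.
    builder.finish()
}
```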
+pub(crate) struct VariantPathRowBuilder<'a> { + builder: Box>, + path: VariantPath<'a>, +} + +impl<'a> VariantPathRowBuilder<'a> { + fn append_null(&mut self) -> Result<()> { + self.builder.append_null() + } + + fn append_value(&mut self, value: &Variant<'_, '_>) -> Result { + if let Some(v) = value.get_path(&self.path) { + self.builder.append_value(&v) + } else { + self.builder.append_null()?; + Ok(false) + } + } + + fn finish(self) -> Result { + self.builder.finish() + } +} + +/// Helper function to get a user-friendly type name +fn get_type_name() -> &'static str { + match std::any::type_name::() { + "arrow_array::types::Int32Type" => "Int32", + "arrow_array::types::Int16Type" => "Int16", + "arrow_array::types::Int8Type" => "Int8", + "arrow_array::types::Int64Type" => "Int64", + "arrow_array::types::UInt32Type" => "UInt32", + "arrow_array::types::UInt16Type" => "UInt16", + "arrow_array::types::UInt8Type" => "UInt8", + "arrow_array::types::UInt64Type" => "UInt64", + "arrow_array::types::Float32Type" => "Float32", + "arrow_array::types::Float64Type" => "Float64", + "arrow_array::types::Float16Type" => "Float16", + _ => "Unknown", + } +} + +/// Builder for converting variant values to primitive values +pub(crate) struct VariantToPrimitiveArrowRowBuilder<'a, T: ArrowPrimitiveType> { + builder: arrow::array::PrimitiveBuilder, + cast_options: &'a CastOptions<'a>, +} + +impl<'a, T: ArrowPrimitiveType> VariantToPrimitiveArrowRowBuilder<'a, T> { + fn new(cast_options: &'a CastOptions<'a>, capacity: usize) -> Self { + Self { + builder: PrimitiveBuilder::::with_capacity(capacity), + cast_options, + } + } +} + +impl<'a, T> VariantToPrimitiveArrowRowBuilder<'a, T> +where + T: ArrowPrimitiveType, + for<'m, 'v> Variant<'m, 'v>: VariantAsPrimitive, +{ + fn append_null(&mut self) -> Result<()> { + self.builder.append_null(); + Ok(()) + } + + fn append_value(&mut self, value: &Variant<'_, '_>) -> Result { + if let Some(v) = value.as_primitive() { + self.builder.append_value(v); + Ok(true) + } else { + if !self.cast_options.safe { + // Unsafe casting: return error on conversion failure + return Err(ArrowError::CastError(format!( + "Failed to extract primitive of type {} from variant {:?} at path VariantPath([])", + get_type_name::(), + value + ))); + } + // Safe casting: append null on conversion failure + self.builder.append_null(); + Ok(false) + } + } + + fn finish(mut self) -> Result { + Ok(Arc::new(self.builder.finish())) + } +} + +/// Builder for creating VariantArray output (for path extraction without type conversion) +pub(crate) struct VariantToBinaryVariantArrowRowBuilder { + builder: VariantArrayBuilder, +} + +impl VariantToBinaryVariantArrowRowBuilder { + fn new(capacity: usize) -> Self { + Self { + builder: VariantArrayBuilder::new(capacity), + } + } +} + +impl VariantToBinaryVariantArrowRowBuilder { + fn append_null(&mut self) -> Result<()> { + self.builder.append_null(); + Ok(()) + } + + fn append_value(&mut self, value: &Variant<'_, '_>) -> Result { + // TODO: We need a way to convert a Variant directly to bytes. In particular, we want to + // just copy across the underlying value byte slice of a `Variant::Object` or + // `Variant::List`, without any interaction with a `VariantMetadata` (because the shredding + // spec requires us to reuse the existing metadata when unshredding). + // + // One could _probably_ emulate this with parquet_variant::VariantBuilder, but it would do a + // lot of unnecessary work and would also create a new metadata column we don't need. 
+ self.builder.append_variant(value.clone()); + Ok(true) + } + + fn finish(self) -> Result { + Ok(Arc::new(self.builder.build())) + } +} diff --git a/parquet-variant-json/Cargo.toml b/parquet-variant-json/Cargo.toml index 76255f0681cd..5d8e02546b09 100644 --- a/parquet-variant-json/Cargo.toml +++ b/parquet-variant-json/Cargo.toml @@ -37,6 +37,7 @@ parquet-variant = { path = "../parquet-variant" } chrono = { workspace = true } serde_json = "1.0" base64 = "0.22" +uuid = "1.18.0" [lib] diff --git a/parquet-variant-json/src/from_json.rs b/parquet-variant-json/src/from_json.rs index 134bafe953a4..3a6e869ec1fc 100644 --- a/parquet-variant-json/src/from_json.rs +++ b/parquet-variant-json/src/from_json.rs @@ -18,10 +18,10 @@ //! Module for parsing JSON strings as Variant use arrow_schema::ArrowError; -use parquet_variant::{ListBuilder, ObjectBuilder, Variant, VariantBuilderExt}; +use parquet_variant::{ObjectFieldBuilder, Variant, VariantBuilderExt}; use serde_json::{Number, Value}; -/// Converts a JSON string to Variant to a [`VariantBuilderExt`], such as +/// Converts a JSON string to Variant using a [`VariantBuilderExt`], such as /// [`VariantBuilder`]. /// /// The resulting `value` and `metadata` buffers can be @@ -29,9 +29,6 @@ use serde_json::{Number, Value}; /// /// # Arguments /// * `json` - The JSON string to parse as Variant. -/// * `variant_builder` - Object of type `VariantBuilder` used to build the variant from the JSON -/// string -/// /// /// # Returns /// @@ -42,25 +39,23 @@ use serde_json::{Number, Value}; /// /// ```rust /// # use parquet_variant::VariantBuilder; -/// # use parquet_variant_json::{ -/// # json_to_variant, variant_to_json_string, variant_to_json, variant_to_json_value -/// # }; +/// # use parquet_variant_json::{JsonToVariant, VariantToJson}; /// /// let mut variant_builder = VariantBuilder::new(); /// let person_string = "{\"name\":\"Alice\", \"age\":30, ".to_string() /// + "\"email\":\"alice@example.com\", \"is_active\": true, \"score\": 95.7," /// + "\"additional_info\": null}"; -/// json_to_variant(&person_string, &mut variant_builder)?; +/// variant_builder.append_json(&person_string)?; /// /// let (metadata, value) = variant_builder.finish(); /// /// let variant = parquet_variant::Variant::try_new(&metadata, &value)?; /// -/// let json_result = variant_to_json_string(&variant)?; -/// let json_value = variant_to_json_value(&variant)?; +/// let json_result = variant.to_json_string()?; +/// let json_value = variant.to_json_value()?; /// /// let mut buffer = Vec::new(); -/// variant_to_json(&mut buffer, &variant)?; +/// variant.to_json(&mut buffer)?; /// let buffer_result = String::from_utf8(buffer)?; /// assert_eq!(json_result, "{\"additional_info\":null,\"age\":30,".to_string() + /// "\"email\":\"alice@example.com\",\"is_active\":true,\"name\":\"Alice\",\"score\":95.7}"); @@ -68,17 +63,19 @@ use serde_json::{Number, Value}; /// assert_eq!(json_result, serde_json::to_string(&json_value)?); /// # Ok::<(), Box>(()) /// ``` -pub fn json_to_variant(json: &str, builder: &mut impl VariantBuilderExt) -> Result<(), ArrowError> { - let json: Value = serde_json::from_str(json) - .map_err(|e| ArrowError::InvalidArgumentError(format!("JSON format error: {e}")))?; - - build_json(&json, builder)?; - Ok(()) +pub trait JsonToVariant { + /// Create a Variant from a JSON string + fn append_json(&mut self, json: &str) -> Result<(), ArrowError>; } -fn build_json(json: &Value, builder: &mut impl VariantBuilderExt) -> Result<(), ArrowError> { - append_json(json, builder)?; - Ok(()) 
+impl JsonToVariant for T { + fn append_json(&mut self, json: &str) -> Result<(), ArrowError> { + let json: Value = serde_json::from_str(json) + .map_err(|e| ArrowError::InvalidArgumentError(format!("JSON format error: {e}")))?; + + append_json(&json, self)?; + Ok(()) + } } fn variant_from_number<'m, 'v>(n: &Number) -> Result, ArrowError> { @@ -114,50 +111,28 @@ fn append_json(json: &Value, builder: &mut impl VariantBuilderExt) -> Result<(), } Value::String(s) => builder.append_value(s.as_str()), Value::Array(arr) => { - let mut list_builder = builder.new_list(); + let mut list_builder = builder.try_new_list()?; for val in arr { append_json(val, &mut list_builder)?; } list_builder.finish(); } Value::Object(obj) => { - let mut obj_builder = builder.new_object(); + let mut obj_builder = builder.try_new_object()?; for (key, value) in obj.iter() { - let mut field_builder = ObjectFieldBuilder { - key, - builder: &mut obj_builder, - }; + let mut field_builder = ObjectFieldBuilder::new(key, &mut obj_builder); append_json(value, &mut field_builder)?; } - obj_builder.finish()?; + obj_builder.finish(); } }; Ok(()) } -struct ObjectFieldBuilder<'o, 'v, 's> { - key: &'s str, - builder: &'o mut ObjectBuilder<'v>, -} - -impl VariantBuilderExt for ObjectFieldBuilder<'_, '_, '_> { - fn append_value<'m, 'v>(&mut self, value: impl Into>) { - self.builder.insert(self.key, value); - } - - fn new_list(&mut self) -> ListBuilder<'_> { - self.builder.new_list(self.key) - } - - fn new_object(&mut self) -> ObjectBuilder<'_> { - self.builder.new_object(self.key) - } -} - #[cfg(test)] mod test { use super::*; - use crate::variant_to_json_string; + use crate::VariantToJson; use arrow_schema::ArrowError; use parquet_variant::{ ShortString, Variant, VariantBuilder, VariantDecimal16, VariantDecimal4, VariantDecimal8, @@ -171,7 +146,7 @@ mod test { impl JsonToVariantTest<'_> { fn run(self) -> Result<(), ArrowError> { let mut variant_builder = VariantBuilder::new(); - json_to_variant(self.json, &mut variant_builder)?; + variant_builder.append_json(self.json)?; let (metadata, value) = variant_builder.finish(); let variant = Variant::try_new(&metadata, &value)?; assert_eq!(variant, self.expected); @@ -492,7 +467,7 @@ mod test { let mut list_builder = variant_builder.new_list(); let mut object_builder_inner = list_builder.new_object(); object_builder_inner.insert("age", Variant::Int8(32)); - object_builder_inner.finish().unwrap(); + object_builder_inner.finish(); list_builder.append_value(Variant::Int16(128)); list_builder.append_value(Variant::BooleanFalse); list_builder.finish(); @@ -556,7 +531,7 @@ mod test { let mut object_builder = variant_builder.new_object(); object_builder.insert("a", Variant::Int8(3)); object_builder.insert("b", Variant::Int8(2)); - object_builder.finish().unwrap(); + object_builder.finish(); let (metadata, value) = variant_builder.finish(); let variant = Variant::try_new(&metadata, &value)?; JsonToVariantTest { @@ -580,7 +555,7 @@ mod test { inner_list_builder.append_value(Variant::Double(-3e0)); inner_list_builder.append_value(Variant::Double(1001e-3)); inner_list_builder.finish(); - object_builder.finish().unwrap(); + object_builder.finish(); let (metadata, value) = variant_builder.finish(); let variant = Variant::try_new(&metadata, &value)?; JsonToVariantTest { @@ -622,18 +597,18 @@ mod test { ); // Manually verify raw JSON value size let mut variant_builder = VariantBuilder::new(); - json_to_variant(&json, &mut variant_builder)?; + variant_builder.append_json(&json)?; let (metadata, value) = 
variant_builder.finish(); let v = Variant::try_new(&metadata, &value)?; - let output_string = variant_to_json_string(&v)?; + let output_string = v.to_json_string()?; assert_eq!(output_string, json); // Verify metadata size = 1 + 2 + 2 * 497 + 3 * 496 assert_eq!(metadata.len(), 2485); // Verify value size. - // Size of innermost_list: 1 + 1 + 258 + 256 = 516 - // Size of inner object: 1 + 4 + 256 + 257 * 3 + 256 * 516 = 133128 - // Size of json: 1 + 4 + 512 + 1028 + 256 * 133128 = 34082313 - assert_eq!(value.len(), 34082313); + // Size of innermost_list: 1 + 1 + 2*(128 + 1) + 2*128 = 516 + // Size of inner object: 1 + 4 + 2*256 + 3*(256 + 1) + 256 * 516 = 133384 + // Size of json: 1 + 4 + 2*256 + 4*(256 + 1) + 256 * 133384 = 34147849 + assert_eq!(value.len(), 34147849); let mut variant_builder = VariantBuilder::new(); let mut object_builder = variant_builder.new_object(); @@ -646,9 +621,9 @@ mod test { } list_builder.finish(); }); - inner_object_builder.finish().unwrap(); + inner_object_builder.finish(); }); - object_builder.finish().unwrap(); + object_builder.finish(); let (metadata, value) = variant_builder.finish(); let variant = Variant::try_new(&metadata, &value)?; @@ -663,16 +638,16 @@ mod test { fn test_json_to_variant_unicode() -> Result<(), ArrowError> { let json = "{\"爱\":\"अ\",\"a\":1}"; let mut variant_builder = VariantBuilder::new(); - json_to_variant(json, &mut variant_builder)?; + variant_builder.append_json(json)?; let (metadata, value) = variant_builder.finish(); let v = Variant::try_new(&metadata, &value)?; - let output_string = variant_to_json_string(&v)?; + let output_string = v.to_json_string()?; assert_eq!(output_string, "{\"a\":1,\"爱\":\"अ\"}"); let mut variant_builder = VariantBuilder::new(); let mut object_builder = variant_builder.new_object(); object_builder.insert("a", Variant::Int8(1)); object_builder.insert("爱", Variant::ShortString(ShortString::try_new("अ")?)); - object_builder.finish().unwrap(); + object_builder.finish(); let (metadata, value) = variant_builder.finish(); let variant = Variant::try_new(&metadata, &value)?; diff --git a/parquet-variant-json/src/lib.rs b/parquet-variant-json/src/lib.rs index bb774c05c135..f24c740818be 100644 --- a/parquet-variant-json/src/lib.rs +++ b/parquet-variant-json/src/lib.rs @@ -21,8 +21,8 @@ //! [Variant Binary Encoding]: https://github.com/apache/parquet-format/blob/master/VariantEncoding.md //! [Apache Parquet]: https://parquet.apache.org/ //! -//! * See [`json_to_variant`] for converting a JSON string to a Variant. -//! * See [`variant_to_json`] for converting a Variant to a JSON string. +//! * See [`JsonToVariant`] trait for converting a JSON string to a Variant. +//! * See [`VariantToJson`] trait for converting a Variant to a JSON string. //! //! ## 🚧 Work In Progress //! @@ -34,5 +34,5 @@ mod from_json; mod to_json; -pub use from_json::json_to_variant; -pub use to_json::{variant_to_json, variant_to_json_string, variant_to_json_value}; +pub use from_json::JsonToVariant; +pub use to_json::VariantToJson; diff --git a/parquet-variant-json/src/to_json.rs b/parquet-variant-json/src/to_json.rs index a3ff04bcc99a..b9f5364cf5b6 100644 --- a/parquet-variant-json/src/to_json.rs +++ b/parquet-variant-json/src/to_json.rs @@ -18,128 +18,354 @@ //! 
Module for converting Variant data to JSON format use arrow_schema::ArrowError; use base64::{engine::general_purpose, Engine as _}; +use chrono::Timelike; +use parquet_variant::{Variant, VariantList, VariantObject}; use serde_json::Value; use std::io::Write; -use parquet_variant::{Variant, VariantList, VariantObject}; +/// Extension trait for converting Variants to JSON +pub trait VariantToJson { + /// + /// This function writes JSON directly to any type that implements [`Write`], + /// making it efficient for streaming or when you want to control the output destination. + /// + /// See [`VariantToJson::to_json_string`] for a convenience function that returns a + /// JSON string. + /// + /// # Arguments + /// + /// * `writer` - Writer to output JSON to + /// * `variant` - The Variant value to convert + /// + /// # Returns + /// + /// * `Ok(())` if successful + /// * `Err` with error details if conversion fails + /// + /// # Examples + /// + /// + /// ```rust + /// # use parquet_variant::{Variant}; + /// # use parquet_variant_json::VariantToJson; + /// # use arrow_schema::ArrowError; + /// let variant = Variant::from("Hello, World!"); + /// let mut buffer = Vec::new(); + /// variant.to_json(&mut buffer)?; + /// assert_eq!(String::from_utf8(buffer).unwrap(), "\"Hello, World!\""); + /// # Ok::<(), ArrowError>(()) + /// ``` + /// + /// # Example: Create a [`Variant::Object`] and convert to JSON + /// ```rust + /// # use parquet_variant::{Variant, VariantBuilder}; + /// # use parquet_variant_json::VariantToJson; + /// # use arrow_schema::ArrowError; + /// let mut builder = VariantBuilder::new(); + /// // Create an object builder that will write fields to the object + /// let mut object_builder = builder.new_object(); + /// object_builder.insert("first_name", "Jiaying"); + /// object_builder.insert("last_name", "Li"); + /// object_builder.finish(); + /// // Finish the builder to get the metadata and value + /// let (metadata, value) = builder.finish(); + /// // Create the Variant and convert to JSON + /// let variant = Variant::try_new(&metadata, &value)?; + /// let mut writer = Vec::new(); + /// variant.to_json(&mut writer)?; + /// assert_eq!(br#"{"first_name":"Jiaying","last_name":"Li"}"#, writer.as_slice()); + /// # Ok::<(), ArrowError>(()) + /// ``` + fn to_json(&self, buffer: &mut impl Write) -> Result<(), ArrowError>; + + /// Convert [`Variant`] to JSON [`String`] + /// + /// This is a convenience function that converts a Variant to a JSON string. + /// This is the same as calling [`VariantToJson::to_json`] with a [`Vec`]. + /// It's the simplest way to get a JSON representation when you just need a String result. 
+ /// + /// # Arguments + /// + /// * `variant` - The Variant value to convert + /// + /// # Returns + /// + /// * `Ok(String)` containing the JSON representation + /// * `Err` with error details if conversion fails + /// + /// # Examples + /// + /// ```rust + /// # use parquet_variant::{Variant}; + /// # use parquet_variant_json::VariantToJson; + /// # use arrow_schema::ArrowError; + /// let variant = Variant::Int32(42); + /// let json = variant.to_json_string()?; + /// assert_eq!(json, "42"); + /// # Ok::<(), ArrowError>(()) + /// ``` + /// + /// # Example: Create a [`Variant::Object`] and convert to JSON + /// + /// This example shows how to create an object with two fields and convert it to JSON: + /// ```json + /// { + /// "first_name": "Jiaying", + /// "last_name": "Li" + /// } + /// ``` + /// + /// ```rust + /// # use parquet_variant::{Variant, VariantBuilder}; + /// # use parquet_variant_json::VariantToJson; + /// # use arrow_schema::ArrowError; + /// let mut builder = VariantBuilder::new(); + /// // Create an object builder that will write fields to the object + /// let mut object_builder = builder.new_object(); + /// object_builder.insert("first_name", "Jiaying"); + /// object_builder.insert("last_name", "Li"); + /// object_builder.finish(); + /// // Finish the builder to get the metadata and value + /// let (metadata, value) = builder.finish(); + /// // Create the Variant and convert to JSON + /// let variant = Variant::try_new(&metadata, &value)?; + /// let json = variant.to_json_string()?; + /// assert_eq!(r#"{"first_name":"Jiaying","last_name":"Li"}"#, json); + /// # Ok::<(), ArrowError>(()) + /// ``` + fn to_json_string(&self) -> Result; + + /// Convert [`Variant`] to [`serde_json::Value`] + /// + /// This function converts a Variant to a [`serde_json::Value`], which is useful + /// when you need to work with the JSON data programmatically or integrate with + /// other serde-based JSON processing. + /// + /// # Arguments + /// + /// * `variant` - The Variant value to convert + /// + /// # Returns + /// + /// * `Ok(Value)` containing the JSON value + /// * `Err` with error details if conversion fails + /// + /// # Examples + /// + /// ```rust + /// # use parquet_variant::{Variant}; + /// # use parquet_variant_json::VariantToJson; + /// # use serde_json::Value; + /// # use arrow_schema::ArrowError; + /// let variant = Variant::from("hello"); + /// let json_value = variant.to_json_value()?; + /// assert_eq!(json_value, Value::String("hello".to_string())); + /// # Ok::<(), ArrowError>(()) + /// ``` + fn to_json_value(&self) -> Result; +} + +impl<'m, 'v> VariantToJson for Variant<'m, 'v> { + fn to_json(&self, buffer: &mut impl Write) -> Result<(), ArrowError> { + match self { + Variant::Null => write!(buffer, "null")?, + Variant::BooleanTrue => write!(buffer, "true")?, + Variant::BooleanFalse => write!(buffer, "false")?, + Variant::Int8(i) => write!(buffer, "{i}")?, + Variant::Int16(i) => write!(buffer, "{i}")?, + Variant::Int32(i) => write!(buffer, "{i}")?, + Variant::Int64(i) => write!(buffer, "{i}")?, + Variant::Float(f) => write!(buffer, "{f}")?, + Variant::Double(f) => write!(buffer, "{f}")?, + Variant::Decimal4(decimal) => write!(buffer, "{decimal}")?, + Variant::Decimal8(decimal) => write!(buffer, "{decimal}")?, + Variant::Decimal16(decimal) => write!(buffer, "{decimal}")?, + Variant::Date(date) => write!(buffer, "\"{}\"", format_date_string(date))?, + Variant::TimestampMicros(ts) | Variant::TimestampNanos(ts) => { + write!(buffer, "\"{}\"", ts.to_rfc3339())? 
+ } + Variant::TimestampNtzMicros(ts) => { + write!(buffer, "\"{}\"", format_timestamp_ntz_string(ts, 6))? + } + Variant::TimestampNtzNanos(ts) => { + write!(buffer, "\"{}\"", format_timestamp_ntz_string(ts, 9))? + } + Variant::Time(time) => write!(buffer, "\"{}\"", format_time_ntz_str(time))?, + Variant::Binary(bytes) => { + // Encode binary as base64 string + let base64_str = format_binary_base64(bytes); + let json_str = serde_json::to_string(&base64_str).map_err(|e| { + ArrowError::InvalidArgumentError(format!("JSON encoding error: {e}")) + })?; + write!(buffer, "{json_str}")? + } + Variant::String(s) => { + // Use serde_json to properly escape the string + let json_str = serde_json::to_string(s).map_err(|e| { + ArrowError::InvalidArgumentError(format!("JSON encoding error: {e}")) + })?; + write!(buffer, "{json_str}")? + } + Variant::ShortString(s) => { + // Use serde_json to properly escape the string + let json_str = serde_json::to_string(s.as_str()).map_err(|e| { + ArrowError::InvalidArgumentError(format!("JSON encoding error: {e}")) + })?; + write!(buffer, "{json_str}")? + } + Variant::Uuid(uuid) => { + write!(buffer, "\"{uuid}\"")?; + } + Variant::Object(obj) => { + convert_object_to_json(buffer, obj)?; + } + Variant::List(arr) => { + convert_array_to_json(buffer, arr)?; + } + } + Ok(()) + } + + fn to_json_string(&self) -> Result { + let mut buffer = Vec::new(); + self.to_json(&mut buffer)?; + String::from_utf8(buffer) + .map_err(|e| ArrowError::InvalidArgumentError(format!("UTF-8 conversion error: {e}"))) + } + + fn to_json_value(&self) -> Result { + match self { + Variant::Null => Ok(Value::Null), + Variant::BooleanTrue => Ok(Value::Bool(true)), + Variant::BooleanFalse => Ok(Value::Bool(false)), + Variant::Int8(i) => Ok(Value::Number((*i).into())), + Variant::Int16(i) => Ok(Value::Number((*i).into())), + Variant::Int32(i) => Ok(Value::Number((*i).into())), + Variant::Int64(i) => Ok(Value::Number((*i).into())), + Variant::Float(f) => serde_json::Number::from_f64((*f).into()) + .map(Value::Number) + .ok_or_else(|| ArrowError::InvalidArgumentError("Invalid float value".to_string())), + Variant::Double(f) => serde_json::Number::from_f64(*f) + .map(Value::Number) + .ok_or_else(|| { + ArrowError::InvalidArgumentError("Invalid double value".to_string()) + }), + Variant::Decimal4(decimal4) => { + let scale = decimal4.scale(); + let integer = decimal4.integer(); + + let integer = if scale == 0 { + integer + } else { + let divisor = 10_i32.pow(scale as u32); + if integer % divisor != 0 { + // fall back to floating point + return Ok(Value::from(integer as f64 / divisor as f64)); + } + integer / divisor + }; + Ok(Value::from(integer)) + } + Variant::Decimal8(decimal8) => { + let scale = decimal8.scale(); + let integer = decimal8.integer(); + + let integer = if scale == 0 { + integer + } else { + let divisor = 10_i64.pow(scale as u32); + if integer % divisor != 0 { + // fall back to floating point + return Ok(Value::from(integer as f64 / divisor as f64)); + } + integer / divisor + }; + Ok(Value::from(integer)) + } + Variant::Decimal16(decimal16) => { + let scale = decimal16.scale(); + let integer = decimal16.integer(); + + let integer = if scale == 0 { + integer + } else { + let divisor = 10_i128.pow(scale as u32); + if integer % divisor != 0 { + // fall back to floating point + return Ok(Value::from(integer as f64 / divisor as f64)); + } + integer / divisor + }; + // i128 has higher precision than any 64-bit type. 
Try a lossless narrowing cast to + // i64 or u64 first, falling back to a lossy narrowing cast to f64 if necessary. + let value = i64::try_from(integer) + .map(Value::from) + .or_else(|_| u64::try_from(integer).map(Value::from)) + .unwrap_or_else(|_| Value::from(integer as f64)); + Ok(value) + } + Variant::Date(date) => Ok(Value::String(format_date_string(date))), + Variant::TimestampMicros(ts) | Variant::TimestampNanos(ts) => { + Ok(Value::String(ts.to_rfc3339())) + } + Variant::TimestampNtzMicros(ts) => { + Ok(Value::String(format_timestamp_ntz_string(ts, 6))) + } + Variant::TimestampNtzNanos(ts) => Ok(Value::String(format_timestamp_ntz_string(ts, 9))), + Variant::Time(time) => Ok(Value::String(format_time_ntz_str(time))), + Variant::Binary(bytes) => Ok(Value::String(format_binary_base64(bytes))), + Variant::String(s) => Ok(Value::String(s.to_string())), + Variant::ShortString(s) => Ok(Value::String(s.to_string())), + Variant::Uuid(uuid) => Ok(Value::String(uuid.to_string())), + Variant::Object(obj) => { + let map = obj + .iter() + .map(|(k, v)| v.to_json_value().map(|json_val| (k.to_string(), json_val))) + .collect::>()?; + Ok(Value::Object(map)) + } + Variant::List(arr) => { + let vec = arr + .iter() + .map(|element| element.to_json_value()) + .collect::>()?; + Ok(Value::Array(vec)) + } + } + } +} // Format string constants to avoid duplication and reduce errors const DATE_FORMAT: &str = "%Y-%m-%d"; -const TIMESTAMP_NTZ_FORMAT: &str = "%Y-%m-%dT%H:%M:%S%.6f"; // Helper functions for consistent formatting fn format_date_string(date: &chrono::NaiveDate) -> String { date.format(DATE_FORMAT).to_string() } -fn format_timestamp_ntz_string(ts: &chrono::NaiveDateTime) -> String { - ts.format(TIMESTAMP_NTZ_FORMAT).to_string() +fn format_timestamp_ntz_string(ts: &chrono::NaiveDateTime, precision: usize) -> String { + let format_str = format!( + "{}", + ts.format(&format!("%Y-%m-%dT%H:%M:%S%.{}f", precision)) + ); + ts.format(format_str.as_str()).to_string() } fn format_binary_base64(bytes: &[u8]) -> String { general_purpose::STANDARD.encode(bytes) } -/// -/// This function writes JSON directly to any type that implements [`Write`], -/// making it efficient for streaming or when you want to control the output destination. -/// -/// See [`variant_to_json_string`] for a convenience function that returns a -/// JSON string. 
-/// -/// # Arguments -/// -/// * `writer` - Writer to output JSON to -/// * `variant` - The Variant value to convert -/// -/// # Returns -/// -/// * `Ok(())` if successful -/// * `Err` with error details if conversion fails -/// -/// # Examples -/// -/// -/// ```rust -/// # use parquet_variant::{Variant}; -/// # use parquet_variant_json::variant_to_json; -/// # use arrow_schema::ArrowError; -/// let variant = Variant::from("Hello, World!"); -/// let mut buffer = Vec::new(); -/// variant_to_json(&mut buffer, &variant)?; -/// assert_eq!(String::from_utf8(buffer).unwrap(), "\"Hello, World!\""); -/// # Ok::<(), ArrowError>(()) -/// ``` -/// -/// # Example: Create a [`Variant::Object`] and convert to JSON -/// ```rust -/// # use parquet_variant::{Variant, VariantBuilder}; -/// # use parquet_variant_json::variant_to_json; -/// # use arrow_schema::ArrowError; -/// let mut builder = VariantBuilder::new(); -/// // Create an object builder that will write fields to the object -/// let mut object_builder = builder.new_object(); -/// object_builder.insert("first_name", "Jiaying"); -/// object_builder.insert("last_name", "Li"); -/// object_builder.finish(); -/// // Finish the builder to get the metadata and value -/// let (metadata, value) = builder.finish(); -/// // Create the Variant and convert to JSON -/// let variant = Variant::try_new(&metadata, &value)?; -/// let mut writer = Vec::new(); -/// variant_to_json(&mut writer, &variant,)?; -/// assert_eq!(br#"{"first_name":"Jiaying","last_name":"Li"}"#, writer.as_slice()); -/// # Ok::<(), ArrowError>(()) -/// ``` -pub fn variant_to_json(json_buffer: &mut impl Write, variant: &Variant) -> Result<(), ArrowError> { - match variant { - Variant::Null => write!(json_buffer, "null")?, - Variant::BooleanTrue => write!(json_buffer, "true")?, - Variant::BooleanFalse => write!(json_buffer, "false")?, - Variant::Int8(i) => write!(json_buffer, "{i}")?, - Variant::Int16(i) => write!(json_buffer, "{i}")?, - Variant::Int32(i) => write!(json_buffer, "{i}")?, - Variant::Int64(i) => write!(json_buffer, "{i}")?, - Variant::Float(f) => write!(json_buffer, "{f}")?, - Variant::Double(f) => write!(json_buffer, "{f}")?, - Variant::Decimal4(decimal) => write!(json_buffer, "{decimal}")?, - Variant::Decimal8(decimal) => write!(json_buffer, "{decimal}")?, - Variant::Decimal16(decimal) => write!(json_buffer, "{decimal}")?, - Variant::Date(date) => write!(json_buffer, "\"{}\"", format_date_string(date))?, - Variant::TimestampMicros(ts) => write!(json_buffer, "\"{}\"", ts.to_rfc3339())?, - Variant::TimestampNtzMicros(ts) => { - write!(json_buffer, "\"{}\"", format_timestamp_ntz_string(ts))? - } - Variant::Binary(bytes) => { - // Encode binary as base64 string - let base64_str = format_binary_base64(bytes); - let json_str = serde_json::to_string(&base64_str).map_err(|e| { - ArrowError::InvalidArgumentError(format!("JSON encoding error: {e}")) - })?; - write!(json_buffer, "{json_str}")? - } - Variant::String(s) => { - // Use serde_json to properly escape the string - let json_str = serde_json::to_string(s).map_err(|e| { - ArrowError::InvalidArgumentError(format!("JSON encoding error: {e}")) - })?; - write!(json_buffer, "{json_str}")? - } - Variant::ShortString(s) => { - // Use serde_json to properly escape the string - let json_str = serde_json::to_string(s.as_str()).map_err(|e| { - ArrowError::InvalidArgumentError(format!("JSON encoding error: {e}")) - })?; - write!(json_buffer, "{json_str}")? 
- } - Variant::Object(obj) => { - convert_object_to_json(json_buffer, obj)?; - } - Variant::List(arr) => { - convert_array_to_json(json_buffer, arr)?; +fn format_time_ntz_str(time: &chrono::NaiveTime) -> String { + let base = time.format("%H:%M:%S").to_string(); + let micros = time.nanosecond() / 1000; + match micros { + 0 => format!("{}.{}", base, 0), + _ => { + let micros_str = format!("{:06}", micros); + let micros_str_trimmed = micros_str.trim_matches('0'); + format!("{}.{}", base, micros_str_trimmed) } } - Ok(()) } /// Convert object fields to JSON @@ -162,7 +388,7 @@ fn convert_object_to_json(buffer: &mut impl Write, obj: &VariantObject) -> Resul write!(buffer, "{json_key}:")?; // Recursively convert the value - variant_to_json(buffer, &value)?; + value.to_json(buffer)?; } write!(buffer, "}}")?; @@ -180,210 +406,29 @@ fn convert_array_to_json(buffer: &mut impl Write, arr: &VariantList) -> Result<( } first = false; - variant_to_json(buffer, &element)?; + element.to_json(buffer)?; } write!(buffer, "]")?; Ok(()) } -/// Convert [`Variant`] to JSON [`String`] -/// -/// This is a convenience function that converts a Variant to a JSON string. -/// This is the same as calling [`variant_to_json`] with a [`Vec`]. -/// It's the simplest way to get a JSON representation when you just need a String result. -/// -/// # Arguments -/// -/// * `variant` - The Variant value to convert -/// -/// # Returns -/// -/// * `Ok(String)` containing the JSON representation -/// * `Err` with error details if conversion fails -/// -/// # Examples -/// -/// ```rust -/// # use parquet_variant::{Variant}; -/// # use parquet_variant_json::variant_to_json_string; -/// # use arrow_schema::ArrowError; -/// let variant = Variant::Int32(42); -/// let json = variant_to_json_string(&variant)?; -/// assert_eq!(json, "42"); -/// # Ok::<(), ArrowError>(()) -/// ``` -/// -/// # Example: Create a [`Variant::Object`] and convert to JSON -/// -/// This example shows how to create an object with two fields and convert it to JSON: -/// ```json -/// { -/// "first_name": "Jiaying", -/// "last_name": "Li" -/// } -/// ``` -/// -/// ```rust -/// # use parquet_variant::{Variant, VariantBuilder}; -/// # use parquet_variant_json::variant_to_json_string; -/// # use arrow_schema::ArrowError; -/// let mut builder = VariantBuilder::new(); -/// // Create an object builder that will write fields to the object -/// let mut object_builder = builder.new_object(); -/// object_builder.insert("first_name", "Jiaying"); -/// object_builder.insert("last_name", "Li"); -/// object_builder.finish(); -/// // Finish the builder to get the metadata and value -/// let (metadata, value) = builder.finish(); -/// // Create the Variant and convert to JSON -/// let variant = Variant::try_new(&metadata, &value)?; -/// let json = variant_to_json_string(&variant)?; -/// assert_eq!(r#"{"first_name":"Jiaying","last_name":"Li"}"#, json); -/// # Ok::<(), ArrowError>(()) -/// ``` -pub fn variant_to_json_string(variant: &Variant) -> Result { - let mut buffer = Vec::new(); - variant_to_json(&mut buffer, variant)?; - String::from_utf8(buffer) - .map_err(|e| ArrowError::InvalidArgumentError(format!("UTF-8 conversion error: {e}"))) -} - -/// Convert [`Variant`] to [`serde_json::Value`] -/// -/// This function converts a Variant to a [`serde_json::Value`], which is useful -/// when you need to work with the JSON data programmatically or integrate with -/// other serde-based JSON processing. 
-/// -/// # Arguments -/// -/// * `variant` - The Variant value to convert -/// -/// # Returns -/// -/// * `Ok(Value)` containing the JSON value -/// * `Err` with error details if conversion fails -/// -/// # Examples -/// -/// ```rust -/// # use parquet_variant::{Variant}; -/// # use parquet_variant_json::variant_to_json_value; -/// # use serde_json::Value; -/// # use arrow_schema::ArrowError; -/// let variant = Variant::from("hello"); -/// let json_value = variant_to_json_value(&variant)?; -/// assert_eq!(json_value, Value::String("hello".to_string())); -/// # Ok::<(), ArrowError>(()) -/// ``` -pub fn variant_to_json_value(variant: &Variant) -> Result { - match variant { - Variant::Null => Ok(Value::Null), - Variant::BooleanTrue => Ok(Value::Bool(true)), - Variant::BooleanFalse => Ok(Value::Bool(false)), - Variant::Int8(i) => Ok(Value::Number((*i).into())), - Variant::Int16(i) => Ok(Value::Number((*i).into())), - Variant::Int32(i) => Ok(Value::Number((*i).into())), - Variant::Int64(i) => Ok(Value::Number((*i).into())), - Variant::Float(f) => serde_json::Number::from_f64((*f).into()) - .map(Value::Number) - .ok_or_else(|| ArrowError::InvalidArgumentError("Invalid float value".to_string())), - Variant::Double(f) => serde_json::Number::from_f64(*f) - .map(Value::Number) - .ok_or_else(|| ArrowError::InvalidArgumentError("Invalid double value".to_string())), - Variant::Decimal4(decimal4) => { - let scale = decimal4.scale(); - let integer = decimal4.integer(); - - let integer = if scale == 0 { - integer - } else { - let divisor = 10_i32.pow(scale as u32); - if integer % divisor != 0 { - // fall back to floating point - return Ok(Value::from(integer as f64 / divisor as f64)); - } - integer / divisor - }; - Ok(Value::from(integer)) - } - Variant::Decimal8(decimal8) => { - let scale = decimal8.scale(); - let integer = decimal8.integer(); - - let integer = if scale == 0 { - integer - } else { - let divisor = 10_i64.pow(scale as u32); - if integer % divisor != 0 { - // fall back to floating point - return Ok(Value::from(integer as f64 / divisor as f64)); - } - integer / divisor - }; - Ok(Value::from(integer)) - } - Variant::Decimal16(decimal16) => { - let scale = decimal16.scale(); - let integer = decimal16.integer(); - - let integer = if scale == 0 { - integer - } else { - let divisor = 10_i128.pow(scale as u32); - if integer % divisor != 0 { - // fall back to floating point - return Ok(Value::from(integer as f64 / divisor as f64)); - } - integer / divisor - }; - // i128 has higher precision than any 64-bit type. Try a lossless narrowing cast to - // i64 or u64 first, falling back to a lossy narrowing cast to f64 if necessary. 
- let value = i64::try_from(integer) - .map(Value::from) - .or_else(|_| u64::try_from(integer).map(Value::from)) - .unwrap_or_else(|_| Value::from(integer as f64)); - Ok(value) - } - Variant::Date(date) => Ok(Value::String(format_date_string(date))), - Variant::TimestampMicros(ts) => Ok(Value::String(ts.to_rfc3339())), - Variant::TimestampNtzMicros(ts) => Ok(Value::String(format_timestamp_ntz_string(ts))), - Variant::Binary(bytes) => Ok(Value::String(format_binary_base64(bytes))), - Variant::String(s) => Ok(Value::String(s.to_string())), - Variant::ShortString(s) => Ok(Value::String(s.to_string())), - Variant::Object(obj) => { - let map = obj - .iter() - .map(|(k, v)| variant_to_json_value(&v).map(|json_val| (k.to_string(), json_val))) - .collect::>()?; - Ok(Value::Object(map)) - } - Variant::List(arr) => { - let vec = arr - .iter() - .map(|element| variant_to_json_value(&element)) - .collect::>()?; - Ok(Value::Array(vec)) - } - } -} - #[cfg(test)] mod tests { use super::*; - use chrono::{DateTime, NaiveDate, Utc}; + use chrono::{DateTime, NaiveDate, NaiveTime, Utc}; use parquet_variant::{VariantDecimal16, VariantDecimal4, VariantDecimal8}; #[test] fn test_decimal_edge_cases() -> Result<(), ArrowError> { // Test negative decimal let negative_variant = Variant::from(VariantDecimal4::try_new(-12345, 3)?); - let negative_json = variant_to_json_string(&negative_variant)?; + let negative_json = negative_variant.to_json_string()?; assert_eq!(negative_json, "-12.345"); // Test large scale decimal let large_scale_variant = Variant::from(VariantDecimal8::try_new(123456789, 6)?); - let large_scale_json = variant_to_json_string(&large_scale_variant)?; + let large_scale_json = large_scale_variant.to_json_string()?; assert_eq!(large_scale_json, "123.456789"); Ok(()) @@ -392,15 +437,15 @@ mod tests { #[test] fn test_decimal16_to_json() -> Result<(), ArrowError> { let variant = Variant::from(VariantDecimal16::try_new(123456789012345, 4)?); - let json = variant_to_json_string(&variant)?; + let json = variant.to_json_string()?; assert_eq!(json, "12345678901.2345"); - let json_value = variant_to_json_value(&variant)?; + let json_value = variant.to_json_value()?; assert!(matches!(json_value, Value::Number(_))); // Test very large number let large_variant = Variant::from(VariantDecimal16::try_new(999999999999999999, 2)?); - let large_json = variant_to_json_string(&large_variant)?; + let large_json = large_variant.to_json_string()?; // Due to f64 precision limits, very large numbers may lose precision assert!( large_json.starts_with("9999999999999999") @@ -413,16 +458,16 @@ mod tests { fn test_date_to_json() -> Result<(), ArrowError> { let date = NaiveDate::from_ymd_opt(2023, 12, 25).unwrap(); let variant = Variant::Date(date); - let json = variant_to_json_string(&variant)?; + let json = variant.to_json_string()?; assert_eq!(json, "\"2023-12-25\""); - let json_value = variant_to_json_value(&variant)?; + let json_value = variant.to_json_value()?; assert_eq!(json_value, Value::String("2023-12-25".to_string())); // Test leap year date let leap_date = NaiveDate::from_ymd_opt(2024, 2, 29).unwrap(); let leap_variant = Variant::Date(leap_date); - let leap_json = variant_to_json_string(&leap_variant)?; + let leap_json = leap_variant.to_json_string()?; assert_eq!(leap_json, "\"2024-02-29\""); Ok(()) } @@ -433,11 +478,11 @@ mod tests { .unwrap() .with_timezone(&Utc); let variant = Variant::TimestampMicros(timestamp); - let json = variant_to_json_string(&variant)?; + let json = variant.to_json_string()?; 
assert!(json.contains("2023-12-25T10:30:45")); assert!(json.starts_with('"') && json.ends_with('"')); - let json_value = variant_to_json_value(&variant)?; + let json_value = variant.to_json_value()?; assert!(matches!(json_value, Value::String(_))); Ok(()) } @@ -448,11 +493,51 @@ mod tests { .unwrap() .naive_utc(); let variant = Variant::TimestampNtzMicros(naive_timestamp); - let json = variant_to_json_string(&variant)?; + let json = variant.to_json_string()?; assert!(json.contains("2023-12-25")); assert!(json.starts_with('"') && json.ends_with('"')); - let json_value = variant_to_json_value(&variant)?; + let json_value = variant.to_json_value()?; + assert!(matches!(json_value, Value::String(_))); + Ok(()) + } + + #[test] + fn test_time_to_json() -> Result<(), ArrowError> { + let naive_time = NaiveTime::from_num_seconds_from_midnight_opt(12345, 123460708).unwrap(); + let variant = Variant::Time(naive_time); + let json = variant.to_json_string()?; + assert_eq!("\"03:25:45.12346\"", json); + + let json_value = variant.to_json_value()?; + assert!(matches!(json_value, Value::String(_))); + Ok(()) + } + + #[test] + fn test_timestamp_nanos_to_json() -> Result<(), ArrowError> { + let timestamp = DateTime::parse_from_rfc3339("2023-12-25T10:30:45.123456789Z") + .unwrap() + .with_timezone(&Utc); + let variant = Variant::TimestampNanos(timestamp); + let json = variant.to_json_string()?; + assert_eq!(json, "\"2023-12-25T10:30:45.123456789+00:00\""); + + let json_value = variant.to_json_value()?; + assert!(matches!(json_value, Value::String(_))); + Ok(()) + } + + #[test] + fn test_timestamp_ntz_nanos_to_json() -> Result<(), ArrowError> { + let naive_timestamp = DateTime::from_timestamp(1703505045, 123456789) + .unwrap() + .naive_utc(); + let variant = Variant::TimestampNtzNanos(naive_timestamp); + let json = variant.to_json_string()?; + assert_eq!(json, "\"2023-12-25T11:50:45.123456789\""); + + let json_value = variant.to_json_value()?; assert!(matches!(json_value, Value::String(_))); Ok(()) } @@ -461,23 +546,23 @@ mod tests { fn test_binary_to_json() -> Result<(), ArrowError> { let binary_data = b"Hello, World!"; let variant = Variant::Binary(binary_data); - let json = variant_to_json_string(&variant)?; + let json = variant.to_json_string()?; // Should be base64 encoded and quoted assert!(json.starts_with('"') && json.ends_with('"')); assert!(json.len() > 2); // Should have content - let json_value = variant_to_json_value(&variant)?; + let json_value = variant.to_json_value()?; assert!(matches!(json_value, Value::String(_))); // Test empty binary let empty_variant = Variant::Binary(b""); - let empty_json = variant_to_json_string(&empty_variant)?; + let empty_json = empty_variant.to_json_string()?; assert_eq!(empty_json, "\"\""); // Test binary with special bytes let special_variant = Variant::Binary(&[0, 255, 128, 64]); - let special_json = variant_to_json_string(&special_variant)?; + let special_json = special_variant.to_json_string()?; assert!(special_json.starts_with('"') && special_json.ends_with('"')); Ok(()) } @@ -485,10 +570,10 @@ mod tests { #[test] fn test_string_to_json() -> Result<(), ArrowError> { let variant = Variant::from("hello world"); - let json = variant_to_json_string(&variant)?; + let json = variant.to_json_string()?; assert_eq!(json, "\"hello world\""); - let json_value = variant_to_json_value(&variant)?; + let json_value = variant.to_json_value()?; assert_eq!(json_value, Value::String("hello world".to_string())); Ok(()) } @@ -498,21 +583,36 @@ mod tests { use 
parquet_variant::ShortString; let short_string = ShortString::try_new("short")?; let variant = Variant::ShortString(short_string); - let json = variant_to_json_string(&variant)?; + let json = variant.to_json_string()?; assert_eq!(json, "\"short\""); - let json_value = variant_to_json_value(&variant)?; + let json_value = variant.to_json_value()?; assert_eq!(json_value, Value::String("short".to_string())); Ok(()) } + #[test] + fn test_uuid_to_json() -> Result<(), ArrowError> { + let uuid = uuid::Uuid::parse_str("123e4567-e89b-12d3-a456-426614174000").unwrap(); + let variant = Variant::Uuid(uuid); + let json = variant.to_json_string()?; + assert_eq!(json, "\"123e4567-e89b-12d3-a456-426614174000\""); + + let json_value = variant.to_json_value()?; + assert_eq!( + json_value, + Value::String("123e4567-e89b-12d3-a456-426614174000".to_string()) + ); + Ok(()) + } + #[test] fn test_string_escaping() -> Result<(), ArrowError> { let variant = Variant::from("hello\nworld\t\"quoted\""); - let json = variant_to_json_string(&variant)?; + let json = variant.to_json_string()?; assert_eq!(json, "\"hello\\nworld\\t\\\"quoted\\\"\""); - let json_value = variant_to_json_value(&variant)?; + let json_value = variant.to_json_value()?; assert_eq!( json_value, Value::String("hello\nworld\t\"quoted\"".to_string()) @@ -524,7 +624,7 @@ mod tests { fn test_json_buffer_writing() -> Result<(), ArrowError> { let variant = Variant::Int8(123); let mut buffer = Vec::new(); - variant_to_json(&mut buffer, &variant)?; + variant.to_json(&mut buffer)?; let result = String::from_utf8(buffer) .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; @@ -541,7 +641,9 @@ mod tests { impl JsonTest { fn run(self) { - let json_string = variant_to_json_string(&self.variant) + let json_string = self + .variant + .to_json_string() .expect("variant_to_json_string should succeed"); assert_eq!( json_string, self.expected_json, @@ -549,8 +651,10 @@ mod tests { self.variant ); - let json_value = - variant_to_json_value(&self.variant).expect("variant_to_json_value should succeed"); + let json_value = self + .variant + .to_json_value() + .expect("variant_to_json_value should succeed"); // For floating point numbers, we need special comparison due to JSON number representation match (&json_value, &self.expected_value) { @@ -830,20 +934,18 @@ mod tests { #[test] fn test_buffer_writing_variants() -> Result<(), ArrowError> { - use crate::variant_to_json; - let variant = Variant::from("test buffer writing"); // Test writing to a Vec let mut buffer = Vec::new(); - variant_to_json(&mut buffer, &variant)?; + variant.to_json(&mut buffer)?; let result = String::from_utf8(buffer) .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; assert_eq!(result, "\"test buffer writing\""); // Test writing to vec![] let mut buffer = vec![]; - variant_to_json(&mut buffer, &variant)?; + variant.to_json(&mut buffer)?; let result = String::from_utf8(buffer) .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; assert_eq!(result, "\"test buffer writing\""); @@ -864,12 +966,11 @@ mod tests { .with_field("age", 30i32) .with_field("active", true) .with_field("score", 95.5f64) - .finish() - .unwrap(); + .finish(); let (metadata, value) = builder.finish(); let variant = Variant::try_new(&metadata, &value)?; - let json = variant_to_json_string(&variant)?; + let json = variant.to_json_string()?; // Parse the JSON to verify structure - handle JSON parsing errors manually let parsed: Value = serde_json::from_str(&json).unwrap(); @@ -881,7 +982,7 @@ mod tests 
{ assert_eq!(obj.len(), 4); // Test variant_to_json_value as well - let json_value = variant_to_json_value(&variant)?; + let json_value = variant.to_json_value()?; assert!(matches!(json_value, Value::Object(_))); Ok(()) @@ -895,15 +996,15 @@ mod tests { { let obj = builder.new_object(); - obj.finish().unwrap(); + obj.finish(); } let (metadata, value) = builder.finish(); let variant = Variant::try_new(&metadata, &value)?; - let json = variant_to_json_string(&variant)?; + let json = variant.to_json_string()?; assert_eq!(json, "{}"); - let json_value = variant_to_json_value(&variant)?; + let json_value = variant.to_json_value()?; assert_eq!(json_value, Value::Object(serde_json::Map::new())); Ok(()) @@ -920,12 +1021,11 @@ mod tests { .with_field("message", "Hello \"World\"\nWith\tTabs") .with_field("path", "C:\\Users\\Alice\\Documents") .with_field("unicode", "😀 Smiley") - .finish() - .unwrap(); + .finish(); let (metadata, value) = builder.finish(); let variant = Variant::try_new(&metadata, &value)?; - let json = variant_to_json_string(&variant)?; + let json = variant.to_json_string()?; // Verify that special characters are properly escaped assert!(json.contains("Hello \\\"World\\\"\\nWith\\tTabs")); @@ -956,10 +1056,10 @@ mod tests { let (metadata, value) = builder.finish(); let variant = Variant::try_new(&metadata, &value)?; - let json = variant_to_json_string(&variant)?; + let json = variant.to_json_string()?; assert_eq!(json, "[1,2,3,4,5]"); - let json_value = variant_to_json_value(&variant)?; + let json_value = variant.to_json_value()?; let arr = json_value.as_array().expect("expected JSON array"); assert_eq!(arr.len(), 5); assert_eq!(arr[0], Value::Number(1.into())); @@ -981,10 +1081,10 @@ mod tests { let (metadata, value) = builder.finish(); let variant = Variant::try_new(&metadata, &value)?; - let json = variant_to_json_string(&variant)?; + let json = variant.to_json_string()?; assert_eq!(json, "[]"); - let json_value = variant_to_json_value(&variant)?; + let json_value = variant.to_json_value()?; assert_eq!(json_value, Value::Array(vec![])); Ok(()) @@ -1007,7 +1107,7 @@ mod tests { let (metadata, value) = builder.finish(); let variant = Variant::try_new(&metadata, &value)?; - let json = variant_to_json_string(&variant)?; + let json = variant.to_json_string()?; let parsed: Value = serde_json::from_str(&json).unwrap(); let arr = parsed.as_array().expect("expected JSON array"); @@ -1033,12 +1133,12 @@ mod tests { obj.insert("zebra", "last"); obj.insert("alpha", "first"); obj.insert("beta", "second"); - obj.finish().unwrap(); + obj.finish(); } let (metadata, value) = builder.finish(); let variant = Variant::try_new(&metadata, &value)?; - let json = variant_to_json_string(&variant)?; + let json = variant.to_json_string()?; // Parse and verify all fields are present let parsed: Value = serde_json::from_str(&json).unwrap(); @@ -1070,7 +1170,7 @@ mod tests { let (metadata, value) = builder.finish(); let variant = Variant::try_new(&metadata, &value)?; - let json = variant_to_json_string(&variant)?; + let json = variant.to_json_string()?; let parsed: Value = serde_json::from_str(&json).unwrap(); let arr = parsed.as_array().expect("expected JSON array"); @@ -1100,12 +1200,12 @@ mod tests { obj.insert("float_field", 2.71f64); obj.insert("null_field", ()); obj.insert("long_field", 999i64); - obj.finish().unwrap(); + obj.finish(); } let (metadata, value) = builder.finish(); let variant = Variant::try_new(&metadata, &value)?; - let json = variant_to_json_string(&variant)?; + let json = 
variant.to_json_string()?; let parsed: Value = serde_json::from_str(&json).unwrap(); let obj = parsed.as_object().expect("expected JSON object"); @@ -1132,8 +1232,8 @@ mod tests { 6, )?); - let json_string = variant_to_json_string(&high_precision_decimal8)?; - let json_value = variant_to_json_value(&high_precision_decimal8)?; + let json_string = high_precision_decimal8.to_json_string()?; + let json_value = high_precision_decimal8.to_json_value()?; // Due to f64 precision limits, we expect precision loss for values > 2^53 // Both functions should produce consistent results (even if not exact) @@ -1146,7 +1246,7 @@ mod tests { 6, )?); - let json_string_exact = variant_to_json_string(&exact_decimal)?; + let json_string_exact = exact_decimal.to_json_string()?; assert_eq!(json_string_exact, "1234567.89"); // Test integer case (should be exact) @@ -1155,7 +1255,7 @@ mod tests { 6, )?); - let json_string_integer = variant_to_json_string(&integer_decimal)?; + let json_string_integer = integer_decimal.to_json_string()?; assert_eq!(json_string_integer, "42"); Ok(()) @@ -1165,7 +1265,7 @@ mod tests { fn test_float_nan_inf_handling() -> Result<(), ArrowError> { // Test NaN handling - should return an error since JSON doesn't support NaN let nan_variant = Variant::Float(f32::NAN); - let nan_result = variant_to_json_value(&nan_variant); + let nan_result = nan_variant.to_json_value(); assert!(nan_result.is_err()); assert!(nan_result .unwrap_err() @@ -1174,7 +1274,7 @@ mod tests { // Test positive infinity - should return an error since JSON doesn't support Infinity let pos_inf_variant = Variant::Float(f32::INFINITY); - let pos_inf_result = variant_to_json_value(&pos_inf_variant); + let pos_inf_result = pos_inf_variant.to_json_value(); assert!(pos_inf_result.is_err()); assert!(pos_inf_result .unwrap_err() @@ -1183,7 +1283,7 @@ mod tests { // Test negative infinity - should return an error since JSON doesn't support -Infinity let neg_inf_variant = Variant::Float(f32::NEG_INFINITY); - let neg_inf_result = variant_to_json_value(&neg_inf_variant); + let neg_inf_result = neg_inf_variant.to_json_value(); assert!(neg_inf_result.is_err()); assert!(neg_inf_result .unwrap_err() @@ -1192,7 +1292,7 @@ mod tests { // Test the same for Double variants let nan_double_variant = Variant::Double(f64::NAN); - let nan_double_result = variant_to_json_value(&nan_double_variant); + let nan_double_result = nan_double_variant.to_json_value(); assert!(nan_double_result.is_err()); assert!(nan_double_result .unwrap_err() @@ -1200,7 +1300,7 @@ mod tests { .contains("Invalid double value")); let pos_inf_double_variant = Variant::Double(f64::INFINITY); - let pos_inf_double_result = variant_to_json_value(&pos_inf_double_variant); + let pos_inf_double_result = pos_inf_double_variant.to_json_value(); assert!(pos_inf_double_result.is_err()); assert!(pos_inf_double_result .unwrap_err() @@ -1208,7 +1308,7 @@ mod tests { .contains("Invalid double value")); let neg_inf_double_variant = Variant::Double(f64::NEG_INFINITY); - let neg_inf_double_result = variant_to_json_value(&neg_inf_double_variant); + let neg_inf_double_result = neg_inf_double_variant.to_json_value(); assert!(neg_inf_double_result.is_err()); assert!(neg_inf_double_result .unwrap_err() @@ -1217,11 +1317,11 @@ mod tests { // Test normal float values still work let normal_float = Variant::Float(std::f32::consts::PI); - let normal_result = variant_to_json_value(&normal_float)?; + let normal_result = normal_float.to_json_value()?; assert!(matches!(normal_result, Value::Number(_))); 
let normal_double = Variant::Double(std::f64::consts::E); - let normal_double_result = variant_to_json_value(&normal_double)?; + let normal_double_result = normal_double.to_json_value()?; assert!(matches!(normal_double_result, Value::Number(_))); Ok(()) diff --git a/parquet-variant/Cargo.toml b/parquet-variant/Cargo.toml index 51fa4cc23311..6e88bff6bd3a 100644 --- a/parquet-variant/Cargo.toml +++ b/parquet-variant/Cargo.toml @@ -33,7 +33,9 @@ rust-version = { workspace = true } [dependencies] arrow-schema = { workspace = true } chrono = { workspace = true } +half = { version = "2.1", default-features = false } indexmap = "2.10.0" +uuid = { version = "1.18.0", features = ["v4"]} simdutf8 = { workspace = true , optional = true } diff --git a/parquet-variant/benches/variant_builder.rs b/parquet-variant/benches/variant_builder.rs index a42327fe1335..5d00cc054e55 100644 --- a/parquet-variant/benches/variant_builder.rs +++ b/parquet-variant/benches/variant_builder.rs @@ -77,7 +77,7 @@ fn bench_object_field_names_reverse_order(c: &mut Criterion) { object_builder.insert(format!("{}", 1000 - i).as_str(), string_table.next()); } - object_builder.finish().unwrap(); + object_builder.finish(); hint::black_box(variant.finish()); }) }); @@ -113,7 +113,7 @@ fn bench_object_same_schema(c: &mut Criterion) { inner_list_builder.append_value(string_table.next()); inner_list_builder.finish(); - object_builder.finish().unwrap(); + object_builder.finish(); hint::black_box(variant.finish()); } @@ -154,7 +154,7 @@ fn bench_object_list_same_schema(c: &mut Criterion) { list_builder.append_value(string_table.next()); list_builder.finish(); - object_builder.finish().unwrap(); + object_builder.finish(); } list_builder.finish(); @@ -189,7 +189,7 @@ fn bench_object_unknown_schema(c: &mut Criterion) { let key = string_table.next(); inner_object_builder.insert(key, key); } - inner_object_builder.finish().unwrap(); + inner_object_builder.finish(); continue; } @@ -202,7 +202,7 @@ fn bench_object_unknown_schema(c: &mut Criterion) { inner_list_builder.finish(); } - object_builder.finish().unwrap(); + object_builder.finish(); hint::black_box(variant.finish()); } }) @@ -241,7 +241,7 @@ fn bench_object_list_unknown_schema(c: &mut Criterion) { let key = string_table.next(); inner_object_builder.insert(key, key); } - inner_object_builder.finish().unwrap(); + inner_object_builder.finish(); continue; } @@ -254,7 +254,7 @@ fn bench_object_list_unknown_schema(c: &mut Criterion) { inner_list_builder.finish(); } - object_builder.finish().unwrap(); + object_builder.finish(); } list_builder.finish(); @@ -314,10 +314,10 @@ fn bench_object_partially_same_schema(c: &mut Criterion) { let key = string_table.next(); inner_object_builder.insert(key, key); } - inner_object_builder.finish().unwrap(); + inner_object_builder.finish(); } - object_builder.finish().unwrap(); + object_builder.finish(); hint::black_box(variant.finish()); } }) @@ -376,10 +376,10 @@ fn bench_object_list_partially_same_schema(c: &mut Criterion) { let key = string_table.next(); inner_object_builder.insert(key, key); } - inner_object_builder.finish().unwrap(); + inner_object_builder.finish(); } - object_builder.finish().unwrap(); + object_builder.finish(); } list_builder.finish(); @@ -408,7 +408,7 @@ fn bench_validation_validated_vs_unvalidated(c: &mut Criterion) { } list.finish(); - obj.finish().unwrap(); + obj.finish(); test_data.push(builder.finish()); } @@ -462,7 +462,7 @@ fn bench_iteration_performance(c: &mut Criterion) { let mut obj = list.new_object(); 
obj.insert(&format!("field_{i}"), rng.random::()); obj.insert("nested_data", format!("data_{i}").as_str()); - obj.finish().unwrap(); + obj.finish(); } list.finish(); diff --git a/parquet-variant/benches/variant_validation.rs b/parquet-variant/benches/variant_validation.rs index 0ccc10117898..dcf7681a76ed 100644 --- a/parquet-variant/benches/variant_validation.rs +++ b/parquet-variant/benches/variant_validation.rs @@ -40,9 +40,9 @@ fn generate_large_object() -> (Vec, Vec) { } list_builder.finish(); } - inner_object.finish().unwrap(); + inner_object.finish(); } - outer_object.finish().unwrap(); + outer_object.finish(); variant_builder.finish() } @@ -72,9 +72,9 @@ fn generate_complex_object() -> (Vec, Vec) { let key = format!("{}", 1024 - i); inner_object_builder.insert(&key, i); } - inner_object_builder.finish().unwrap(); + inner_object_builder.finish(); - object_builder.finish().unwrap(); + object_builder.finish(); variant_builder.finish() } diff --git a/parquet-variant/src/builder.rs b/parquet-variant/src/builder.rs index b1607f8f306d..93e736285853 100644 --- a/parquet-variant/src/builder.rs +++ b/parquet-variant/src/builder.rs @@ -20,8 +20,11 @@ use crate::{ VariantMetadata, VariantObject, }; use arrow_schema::ArrowError; +use chrono::Timelike; use indexmap::{IndexMap, IndexSet}; -use std::collections::HashSet; +use uuid::Uuid; + +use std::collections::HashMap; const BASIC_TYPE_BITS: u8 = 2; const UNIX_EPOCH_DATE: chrono::NaiveDate = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); @@ -85,28 +88,49 @@ fn append_packed_u32(dest: &mut Vec, value: u32, value_size: usize) { /// /// You can reuse an existing `Vec` by using the `from` impl #[derive(Debug, Default)] -struct ValueBuffer(Vec); +pub struct ValueBuilder(Vec); -impl ValueBuffer { +impl ValueBuilder { /// Construct a ValueBuffer that will write to a new underlying `Vec` - fn new() -> Self { + pub fn new() -> Self { Default::default() } } -impl From> for ValueBuffer { - fn from(value: Vec) -> Self { - Self(value) - } -} - -impl From for Vec { - fn from(value_buffer: ValueBuffer) -> Self { - value_buffer.0 - } +/// Macro to generate the match statement for each append_variant, try_append_variant, and +/// append_variant_bytes -- they each handle objects and lists slightly differently. +macro_rules! 
variant_append_value { + ($builder:expr, $value:expr, $object_pat:pat => $object_arm:expr, $list_pat:pat => $list_arm:expr) => { + match $value { + Variant::Null => $builder.append_null(), + Variant::BooleanTrue => $builder.append_bool(true), + Variant::BooleanFalse => $builder.append_bool(false), + Variant::Int8(v) => $builder.append_int8(v), + Variant::Int16(v) => $builder.append_int16(v), + Variant::Int32(v) => $builder.append_int32(v), + Variant::Int64(v) => $builder.append_int64(v), + Variant::Date(v) => $builder.append_date(v), + Variant::Time(v) => $builder.append_time_micros(v), + Variant::TimestampMicros(v) => $builder.append_timestamp_micros(v), + Variant::TimestampNtzMicros(v) => $builder.append_timestamp_ntz_micros(v), + Variant::TimestampNanos(v) => $builder.append_timestamp_nanos(v), + Variant::TimestampNtzNanos(v) => $builder.append_timestamp_ntz_nanos(v), + Variant::Decimal4(decimal4) => $builder.append_decimal4(decimal4), + Variant::Decimal8(decimal8) => $builder.append_decimal8(decimal8), + Variant::Decimal16(decimal16) => $builder.append_decimal16(decimal16), + Variant::Float(v) => $builder.append_float(v), + Variant::Double(v) => $builder.append_double(v), + Variant::Binary(v) => $builder.append_binary(v), + Variant::String(s) => $builder.append_string(s), + Variant::ShortString(s) => $builder.append_short_string(s), + Variant::Uuid(v) => $builder.append_uuid(v), + $object_pat => $object_arm, + $list_pat => $list_arm, + } + }; } -impl ValueBuffer { +impl ValueBuilder { fn append_u8(&mut self, term: u8) { self.0.push(term); } @@ -119,8 +143,9 @@ impl ValueBuffer { self.0.push(primitive_header(primitive_type)); } - fn into_inner(self) -> Vec { - self.into() + /// Returns the underlying buffer, consuming self + pub fn into_inner(self) -> Vec { + self.0 } fn inner_mut(&mut self) -> &mut Vec { @@ -190,6 +215,30 @@ impl ValueBuffer { self.append_slice(µs.to_le_bytes()); } + fn append_time_micros(&mut self, value: chrono::NaiveTime) { + self.append_primitive_header(VariantPrimitiveType::Time); + let micros_from_midnight = value.num_seconds_from_midnight() as u64 * 1_000_000 + + value.nanosecond() as u64 / 1_000; + self.append_slice(µs_from_midnight.to_le_bytes()); + } + + fn append_timestamp_nanos(&mut self, value: chrono::DateTime) { + self.append_primitive_header(VariantPrimitiveType::TimestampNanos); + let nanos = value.timestamp_nanos_opt().unwrap(); + self.append_slice(&nanos.to_le_bytes()); + } + + fn append_timestamp_ntz_nanos(&mut self, value: chrono::NaiveDateTime) { + self.append_primitive_header(VariantPrimitiveType::TimestampNtzNanos); + let nanos = value.and_utc().timestamp_nanos_opt().unwrap(); + self.append_slice(&nanos.to_le_bytes()); + } + + fn append_uuid(&mut self, value: Uuid) { + self.append_primitive_header(VariantPrimitiveType::Uuid); + self.append_slice(&value.into_bytes()); + } + fn append_decimal4(&mut self, decimal4: VariantDecimal4) { self.append_primitive_header(VariantPrimitiveType::Decimal4); self.append_u8(decimal4.scale()); @@ -226,47 +275,44 @@ impl ValueBuffer { self.append_slice(value.as_bytes()); } - fn append_object(&mut self, metadata_builder: &mut MetadataBuilder, obj: VariantObject) { - let mut object_builder = self.new_object(metadata_builder); + fn append_object(state: ParentState<'_, S>, obj: VariantObject) { + let mut object_builder = ObjectBuilder::new(state, false); for (field_name, value) in obj.iter() { object_builder.insert(field_name, value); } - object_builder.finish().unwrap(); + object_builder.finish(); } - fn 
try_append_object( - &mut self, - metadata_builder: &mut MetadataBuilder, + fn try_append_object( + state: ParentState<'_, S>, obj: VariantObject, ) -> Result<(), ArrowError> { - let mut object_builder = self.new_object(metadata_builder); + let mut object_builder = ObjectBuilder::new(state, false); for res in obj.iter_try() { let (field_name, value) = res?; object_builder.try_insert(field_name, value)?; } - object_builder.finish()?; - + object_builder.finish(); Ok(()) } - fn append_list(&mut self, metadata_builder: &mut MetadataBuilder, list: VariantList) { - let mut list_builder = self.new_list(metadata_builder); + fn append_list(state: ParentState<'_, S>, list: VariantList) { + let mut list_builder = ListBuilder::new(state, false); for value in list.iter() { list_builder.append_value(value); } list_builder.finish(); } - fn try_append_list( - &mut self, - metadata_builder: &mut MetadataBuilder, + fn try_append_list( + state: ParentState<'_, S>, list: VariantList, ) -> Result<(), ArrowError> { - let mut list_builder = self.new_list(metadata_builder); + let mut list_builder = ListBuilder::new(state, false); for res in list.iter_try() { let value = res?; list_builder.try_append_value(value)?; @@ -277,100 +323,72 @@ impl ValueBuffer { Ok(()) } - fn offset(&self) -> usize { + /// Returns the current size of the underlying buffer + pub fn offset(&self) -> usize { self.0.len() } - fn new_object<'a>( - &'a mut self, - metadata_builder: &'a mut MetadataBuilder, - ) -> ObjectBuilder<'a> { - let parent_state = ParentState::Variant { - buffer: self, - metadata_builder, - }; - let validate_unique_fields = false; - ObjectBuilder::new(parent_state, validate_unique_fields) - } - - fn new_list<'a>(&'a mut self, metadata_builder: &'a mut MetadataBuilder) -> ListBuilder<'a> { - let parent_state = ParentState::Variant { - buffer: self, - metadata_builder, - }; - let validate_unique_fields = false; - ListBuilder::new(parent_state, validate_unique_fields) - } - - /// Appends a variant to the buffer. + /// Appends a variant to the builder. /// /// # Panics /// /// This method will panic if the variant contains duplicate field names in objects - /// when validation is enabled. For a fallible version, use [`ValueBuffer::try_append_variant`] - fn append_variant<'m, 'd>( - &mut self, - variant: Variant<'m, 'd>, - metadata_builder: &mut MetadataBuilder, + /// when validation is enabled. 
For a fallible version, use [`ValueBuilder::try_append_variant`] + pub fn append_variant( + mut state: ParentState<'_, S>, + variant: Variant<'_, '_>, ) { - match variant { - Variant::Null => self.append_null(), - Variant::BooleanTrue => self.append_bool(true), - Variant::BooleanFalse => self.append_bool(false), - Variant::Int8(v) => self.append_int8(v), - Variant::Int16(v) => self.append_int16(v), - Variant::Int32(v) => self.append_int32(v), - Variant::Int64(v) => self.append_int64(v), - Variant::Date(v) => self.append_date(v), - Variant::TimestampMicros(v) => self.append_timestamp_micros(v), - Variant::TimestampNtzMicros(v) => self.append_timestamp_ntz_micros(v), - Variant::Decimal4(decimal4) => self.append_decimal4(decimal4), - Variant::Decimal8(decimal8) => self.append_decimal8(decimal8), - Variant::Decimal16(decimal16) => self.append_decimal16(decimal16), - Variant::Float(v) => self.append_float(v), - Variant::Double(v) => self.append_double(v), - Variant::Binary(v) => self.append_binary(v), - Variant::String(s) => self.append_string(s), - Variant::ShortString(s) => self.append_short_string(s), - Variant::Object(obj) => self.append_object(metadata_builder, obj), - Variant::List(list) => self.append_list(metadata_builder, list), - } + variant_append_value!( + state.value_builder(), + variant, + Variant::Object(obj) => return Self::append_object(state, obj), + Variant::List(list) => return Self::append_list(state, list) + ); + state.finish(); } - /// Appends a variant to the buffer - fn try_append_variant<'m, 'd>( - &mut self, - variant: Variant<'m, 'd>, - metadata_builder: &mut MetadataBuilder, + /// Tries to append a variant to the provided [`ParentState`] instance. + /// + /// The attempt fails if the variant contains duplicate field names in objects when validation + /// is enabled. + pub fn try_append_variant( + mut state: ParentState<'_, S>, + variant: Variant<'_, '_>, ) -> Result<(), ArrowError> { - match variant { - Variant::Null => self.append_null(), - Variant::BooleanTrue => self.append_bool(true), - Variant::BooleanFalse => self.append_bool(false), - Variant::Int8(v) => self.append_int8(v), - Variant::Int16(v) => self.append_int16(v), - Variant::Int32(v) => self.append_int32(v), - Variant::Int64(v) => self.append_int64(v), - Variant::Date(v) => self.append_date(v), - Variant::TimestampMicros(v) => self.append_timestamp_micros(v), - Variant::TimestampNtzMicros(v) => self.append_timestamp_ntz_micros(v), - Variant::Decimal4(decimal4) => self.append_decimal4(decimal4), - Variant::Decimal8(decimal8) => self.append_decimal8(decimal8), - Variant::Decimal16(decimal16) => self.append_decimal16(decimal16), - Variant::Float(v) => self.append_float(v), - Variant::Double(v) => self.append_double(v), - Variant::Binary(v) => self.append_binary(v), - Variant::String(s) => self.append_string(s), - Variant::ShortString(s) => self.append_short_string(s), - Variant::Object(obj) => self.try_append_object(metadata_builder, obj)?, - Variant::List(list) => self.try_append_list(metadata_builder, list)?, - } - + variant_append_value!( + state.value_builder(), + variant, + Variant::Object(obj) => return Self::try_append_object(state, obj), + Variant::List(list) => return Self::try_append_list(state, list) + ); + state.finish(); Ok(()) } + /// Appends a variant to the buffer by copying raw bytes when possible. + /// + /// For objects and lists, this directly copies their underlying byte representation instead of + /// performing a logical copy and without touching the metadata builder. 
For other variant + /// types, this falls back to the standard append behavior. + /// + /// The caller must ensure that the metadata dictionary is already built and correct for + /// any objects or lists being appended. + pub fn append_variant_bytes( + mut state: ParentState<'_, S>, + variant: Variant<'_, '_>, + ) { + let builder = state.value_builder(); + variant_append_value!( + builder, + variant, + Variant::Object(obj) => builder.append_slice(obj.value), + Variant::List(list) => builder.append_slice(list.value) + ); + state.finish(); + } + /// Writes out the header byte for a variant object or list, from the starting position - /// of the buffer, will return the position after this write + /// of the builder, will return the position after this write fn append_header_start_from_buf_pos( &mut self, start_pos: usize, // the start position where the header will be inserted @@ -427,13 +445,111 @@ impl ValueBuffer { } } +/// A trait for building variant metadata dictionaries, to be used in conjunction with a +/// [`ValueBuilder`]. The trait provides methods for managing field names and their IDs, as well as +/// rolling back a failed builder operation that might have created new field ids. +pub trait MetadataBuilder: std::fmt::Debug { + /// Attempts to register a field name, returning the corresponding (possibly newly-created) + /// field id on success. Attempting to register the same field name twice will _generally_ + /// produce the same field id both times, but the variant spec does not actually require it. + fn try_upsert_field_name(&mut self, field_name: &str) -> Result; + + /// Retrieves the field name for a given field id, which must be less than + /// [`Self::num_field_names`]. Panics if the field id is out of bounds. + fn field_name(&self, field_id: usize) -> &str; + + /// Returns the number of field names stored in this metadata builder. Any number less than this + /// is a valid field id. The builder can be reverted back to this size later on (discarding any + /// newer/higher field ids) by calling [`Self::truncate_field_names`]. + fn num_field_names(&self) -> usize; + + /// Reverts the field names to a previous size, discarding any newly out of bounds field ids. + fn truncate_field_names(&mut self, new_size: usize); + + /// Finishes the current metadata dictionary, returning the new size of the underlying buffer. + fn finish(&mut self) -> usize; +} + +impl MetadataBuilder for WritableMetadataBuilder { + fn try_upsert_field_name(&mut self, field_name: &str) -> Result { + Ok(self.upsert_field_name(field_name)) + } + fn field_name(&self, field_id: usize) -> &str { + self.field_name(field_id) + } + fn num_field_names(&self) -> usize { + self.num_field_names() + } + fn truncate_field_names(&mut self, new_size: usize) { + self.field_names.truncate(new_size) + } + fn finish(&mut self) -> usize { + self.finish() + } +} + +/// A metadata builder that cannot register new field names, and merely returns the field id +/// associated with a known field name. This is useful for variant unshredding operations, where the +/// metadata column is fixed and -- per variant shredding spec -- already contains all field names +/// from the typed_value column. It is also useful when projecting a subset of fields from a variant +/// object value, since the bytes can be copied across directly without re-encoding their field ids. +/// +/// NOTE: [`Self::finish`] is a no-op. If the intent is to make a copy of the underlying bytes each +/// time `finish` is called, a different trait impl will be needed. 
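Loosely sketched below is how the read-only dictionary described here (and defined just below) might be combined with `ValueBuilder::append_variant_bytes` to re-serialize a value against an existing, already complete metadata dictionary. The import paths and crate-level visibility of these types, as well as `VariantMetadata::try_new`, are assumptions:

```rust
use arrow_schema::ArrowError;
use parquet_variant::{
    ParentState, ReadOnlyMetadataBuilder, ValueBuilder, Variant, VariantMetadata,
};

// Hypothetical: copy an existing value into a fresh value buffer while reusing
// its metadata dictionary, without registering any new field names.
fn recopy_value(metadata: &[u8], value: &[u8]) -> Result<Vec<u8>, ArrowError> {
    let variant = Variant::try_new(metadata, value)?;
    let mut metadata_builder = ReadOnlyMetadataBuilder::new(VariantMetadata::try_new(metadata)?);
    let mut value_builder = ValueBuilder::new();
    let state = ParentState::variant(&mut value_builder, &mut metadata_builder);
    // Objects and lists are copied byte-for-byte; their field ids stay valid
    // because the dictionary is unchanged.
    ValueBuilder::append_variant_bytes(state, variant);
    Ok(value_builder.into_inner())
}
```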
+#[derive(Debug)] +pub struct ReadOnlyMetadataBuilder<'m> { + metadata: VariantMetadata<'m>, + // A cache that tracks field names this builder has already seen, because finding the field id + // for a given field name is expensive -- O(n) for a large and unsorted metadata dictionary. + known_field_names: HashMap<&'m str, u32>, +} + +impl<'m> ReadOnlyMetadataBuilder<'m> { + /// Creates a new read-only metadata builder from the given metadata dictionary. + pub fn new(metadata: VariantMetadata<'m>) -> Self { + Self { + metadata, + known_field_names: HashMap::new(), + } + } +} + +impl MetadataBuilder for ReadOnlyMetadataBuilder<'_> { + fn try_upsert_field_name(&mut self, field_name: &str) -> Result { + if let Some(field_id) = self.known_field_names.get(field_name) { + return Ok(*field_id); + } + + let Some((field_id, field_name)) = self.metadata.get_entry(field_name) else { + return Err(ArrowError::InvalidArgumentError(format!( + "Field name '{field_name}' not found in metadata dictionary" + ))); + }; + + self.known_field_names.insert(field_name, field_id); + Ok(field_id) + } + fn field_name(&self, field_id: usize) -> &str { + &self.metadata[field_id] + } + fn num_field_names(&self) -> usize { + self.metadata.len() + } + fn truncate_field_names(&mut self, new_size: usize) { + debug_assert_eq!(self.metadata.len(), new_size); + } + fn finish(&mut self) -> usize { + self.metadata.bytes.len() + } +} + /// Builder for constructing metadata for [`Variant`] values. /// /// This is used internally by the [`VariantBuilder`] to construct the metadata /// /// You can use an existing `Vec` as the metadata buffer by using the `from` impl. #[derive(Default, Debug)] -struct MetadataBuilder { +pub struct WritableMetadataBuilder { // Field names -- field_ids are assigned in insert order field_names: IndexSet, @@ -444,17 +560,7 @@ struct MetadataBuilder { metadata_buffer: Vec, } -/// Create a new MetadataBuilder that will write to the specified metadata buffer -impl From> for MetadataBuilder { - fn from(metadata_buffer: Vec) -> Self { - Self { - metadata_buffer, - ..Default::default() - } - } -} - -impl MetadataBuilder { +impl WritableMetadataBuilder { /// Upsert field name to dictionary, return its ID fn upsert_field_name(&mut self, field_name: &str) -> u32 { let (id, new_entry) = self.field_names.insert_full(field_name.to_string()); @@ -473,6 +579,11 @@ impl MetadataBuilder { id as u32 } + /// The current length of the underlying metadata buffer + pub fn offset(&self) -> usize { + self.metadata_buffer.len() + } + /// Returns the number of field names stored in the metadata builder. /// Note: this method should be the only place to call `self.field_names.len()` /// @@ -494,17 +605,18 @@ impl MetadataBuilder { self.field_names.iter().map(|k| k.len()).sum() } - fn finish(self) -> Vec { + /// Finalizes the metadata dictionary and appends its serialized bytes to the underlying buffer, + /// returning the resulting [`Self::offset`]. The builder state is reset and ready to start + /// building a new metadata dictionary. 
+ pub fn finish(&mut self) -> usize { let nkeys = self.num_field_names(); // Calculate metadata size let total_dict_size: usize = self.metadata_size(); - let Self { - field_names, - is_sorted, - mut metadata_buffer, - } = self; + let metadata_buffer = &mut self.metadata_buffer; + let is_sorted = std::mem::take(&mut self.is_sorted); + let field_names = std::mem::take(&mut self.field_names); // Determine appropriate offset size based on the larger of dict size or total string size let max_offset = std::cmp::max(total_dict_size, nkeys); @@ -520,32 +632,32 @@ impl MetadataBuilder { metadata_buffer.push(0x01 | (is_sorted as u8) << 4 | ((offset_size - 1) << 6)); // Write dictionary size - write_offset(&mut metadata_buffer, nkeys, offset_size); + write_offset(metadata_buffer, nkeys, offset_size); // Write offsets let mut cur_offset = 0; for key in field_names.iter() { - write_offset(&mut metadata_buffer, cur_offset, offset_size); + write_offset(metadata_buffer, cur_offset, offset_size); cur_offset += key.len(); } // Write final offset - write_offset(&mut metadata_buffer, cur_offset, offset_size); + write_offset(metadata_buffer, cur_offset, offset_size); // Write string data for key in field_names { metadata_buffer.extend_from_slice(key.as_bytes()); } - metadata_buffer + metadata_buffer.len() } - /// Return the inner buffer, without finalizing any in progress metadata. - pub(crate) fn take_buffer(self) -> Vec { + /// Returns the inner buffer, consuming self without finalizing any in progress metadata. + pub fn into_inner(self) -> Vec { self.metadata_buffer } } -impl> FromIterator for MetadataBuilder { +impl> FromIterator for WritableMetadataBuilder { fn from_iter>(iter: T) -> Self { let mut this = Self::default(); this.extend(iter); @@ -554,7 +666,7 @@ impl> FromIterator for MetadataBuilder { } } -impl> Extend for MetadataBuilder { +impl> Extend for WritableMetadataBuilder { fn extend>(&mut self, iter: T) { let iter = iter.into_iter(); let (min, _) = iter.size_hint(); @@ -567,7 +679,64 @@ impl> Extend for MetadataBuilder { } } -/// Tracks information needed to correctly finalize a nested builder, for each parent builder type. +/// A trait for managing state specific to different builder types. +pub trait BuilderSpecificState: std::fmt::Debug { + /// Called by [`ParentState::finish`] to apply any pending builder-specific changes. + /// + /// The provided implementation does nothing by default. + /// + /// Parameters: + /// - `metadata_builder`: The metadata builder that was used + /// - `value_builder`: The value builder that was used + fn finish( + &mut self, + _metadata_builder: &mut dyn MetadataBuilder, + _value_builder: &mut ValueBuilder, + ) { + } + + /// Called by [`ParentState::drop`] to revert any changes that were eagerly applied, if + /// [`ParentState::finish`] was never invoked. + /// + /// The provided implementation does nothing by default. + /// + /// The base [`ParentState`] will handle rolling back the value and metadata builders, + /// but builder-specific state may need to revert its own changes. + fn rollback(&mut self) {} +} + +/// Empty no-op implementation for top-level variant building +impl BuilderSpecificState for () {} + +/// Internal state for list building +#[derive(Debug)] +pub struct ListState<'a> { + offsets: &'a mut Vec, + saved_offsets_size: usize, +} + +// `ListBuilder::finish()` eagerly updates the list offsets, which we should rollback on failure. 
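Because the reworked `WritableMetadataBuilder::finish` above appends the encoded dictionary in place and returns the end offset (instead of consuming the builder), several dictionaries can be written back to back into one buffer. A rough sketch, with import paths assumed:

```rust
use arrow_schema::ArrowError;
use parquet_variant::{MetadataBuilder, WritableMetadataBuilder};

// Hypothetical: two metadata dictionaries serialized into the same buffer.
fn two_dictionaries() -> Result<(Vec<u8>, usize, usize), ArrowError> {
    let mut builder = WritableMetadataBuilder::default();
    builder.try_upsert_field_name("a")?; // via the MetadataBuilder trait
    let first_end = builder.finish(); // bytes [0..first_end) hold dictionary #1
    builder.try_upsert_field_name("b")?;
    let second_end = builder.finish(); // bytes [first_end..second_end) hold dictionary #2
    Ok((builder.into_inner(), first_end, second_end))
}
```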
+impl BuilderSpecificState for ListState<'_> { + fn rollback(&mut self) { + self.offsets.truncate(self.saved_offsets_size); + } +} + +/// Internal state for object building +#[derive(Debug)] +pub struct ObjectState<'a> { + fields: &'a mut IndexMap, + saved_fields_size: usize, +} + +// `ObjectBuilder::finish()` eagerly updates the field offsets, which we should rollback on failure. +impl BuilderSpecificState for ObjectState<'_> { + fn rollback(&mut self) { + self.fields.truncate(self.saved_fields_size); + } +} + +/// Tracks information needed to correctly finalize a nested builder. /// /// A child builder has no effect on its parent unless/until its `finalize` method is called, at /// which point the child appends the new value to the parent. As a (desirable) side effect, @@ -575,122 +744,162 @@ impl> Extend for MetadataBuilder { /// rendering the parent object completely unusable until the parent state goes out of scope. This /// ensures that at most one child builder can exist at a time. /// -/// The redundancy in buffer and metadata_builder is because all the references come from the -/// parent, and we cannot "split" a mutable reference across two objects (parent state and the child -/// builder that uses it). So everything has to be here. Rust layout optimizations should treat the -/// variants as a union, so that accessing a `buffer` or `metadata_builder` is branch-free. -enum ParentState<'a> { - Variant { - buffer: &'a mut ValueBuffer, - metadata_builder: &'a mut MetadataBuilder, - }, - List { - buffer: &'a mut ValueBuffer, - metadata_builder: &'a mut MetadataBuilder, - parent_value_offset_base: usize, - offsets: &'a mut Vec, - }, - Object { - buffer: &'a mut ValueBuffer, - metadata_builder: &'a mut MetadataBuilder, - fields: &'a mut IndexMap, - field_name: &'a str, - parent_value_offset_base: usize, - }, +/// The redundancy in `value_builder` and `metadata_builder` is because all the references come from +/// the parent, and we cannot "split" a mutable reference across two objects (parent state and the +/// child builder that uses it). So everything has to be here. +#[derive(Debug)] +pub struct ParentState<'a, S: BuilderSpecificState> { + value_builder: &'a mut ValueBuilder, + saved_value_builder_offset: usize, + metadata_builder: &'a mut dyn MetadataBuilder, + saved_metadata_builder_dict_size: usize, + builder_state: S, + finished: bool, } -impl ParentState<'_> { - fn buffer(&mut self) -> &mut ValueBuffer { - match self { - ParentState::Variant { buffer, .. } => buffer, - ParentState::List { buffer, .. } => buffer, - ParentState::Object { buffer, .. } => buffer, +impl<'a, S: BuilderSpecificState> ParentState<'a, S> { + /// Creates a new ParentState instance. The value and metadata builder + /// state is checkpointed and will roll back on drop, unless [`Self::finish`] is called. The + /// builder-specific state is governed by its own `finish` and `rollback` calls. + pub fn new( + value_builder: &'a mut ValueBuilder, + metadata_builder: &'a mut dyn MetadataBuilder, + builder_state: S, + ) -> Self { + Self { + saved_value_builder_offset: value_builder.offset(), + value_builder, + saved_metadata_builder_dict_size: metadata_builder.num_field_names(), + metadata_builder, + builder_state, + finished: false, } } - fn metadata_builder(&mut self) -> &mut MetadataBuilder { - match self { - ParentState::Variant { - metadata_builder, .. - } => metadata_builder, - ParentState::List { - metadata_builder, .. - } => metadata_builder, - ParentState::Object { - metadata_builder, .. 
- } => metadata_builder, - } + /// Marks the insertion as having succeeded and invokes + /// [`BuilderSpecificState::finish`]. Internal state will no longer roll back on drop. + pub fn finish(&mut self) { + self.builder_state + .finish(self.metadata_builder, self.value_builder); + self.finished = true } - // Performs any parent-specific aspects of finishing, after the child has appended all necessary - // bytes to the parent's value buffer. ListBuilder records the new value's starting offset; - // ObjectBuilder associates the new value's starting offset with its field id; VariantBuilder - // doesn't need anything special. - fn finish(&mut self, starting_offset: usize) { - match self { - ParentState::Variant { .. } => (), - ParentState::List { - offsets, - parent_value_offset_base, - .. - } => offsets.push(starting_offset - *parent_value_offset_base), - ParentState::Object { - metadata_builder, - fields, - field_name, - parent_value_offset_base, - .. - } => { - let field_id = metadata_builder.upsert_field_name(field_name); - let shifted_start_offset = starting_offset - *parent_value_offset_base; - fields.insert(field_id, shifted_start_offset); - } + // Rolls back value and metadata builder changes and invokes [`BuilderSpecificState::rollback`]. + fn rollback(&mut self) { + if self.finished { + return; } + + self.value_builder + .inner_mut() + .truncate(self.saved_value_builder_offset); + self.metadata_builder + .truncate_field_names(self.saved_metadata_builder_dict_size); + self.builder_state.rollback(); } - /// Return mutable references to the buffer and metadata builder that this - /// parent state is using. - fn buffer_and_metadata_builder(&mut self) -> (&mut ValueBuffer, &mut MetadataBuilder) { - match self { - ParentState::Variant { - buffer, - metadata_builder, - } - | ParentState::List { - buffer, - metadata_builder, - .. - } - | ParentState::Object { - buffer, - metadata_builder, - .. - } => (buffer, metadata_builder), - } + // Useful because e.g. `let b = self.value_builder;` fails compilation. + fn value_builder(&mut self) -> &mut ValueBuilder { + self.value_builder + } + + // Useful because e.g. `let b = self.metadata_builder;` fails compilation. + fn metadata_builder(&mut self) -> &mut dyn MetadataBuilder { + self.metadata_builder } +} - // Return the offset of the underlying buffer at the time of calling this method. - fn buffer_current_offset(&self) -> usize { - match self { - ParentState::Variant { buffer, .. } - | ParentState::Object { buffer, .. } - | ParentState::List { buffer, .. } => buffer.offset(), +impl<'a> ParentState<'a, ()> { + /// Creates a new instance suitable for a top-level variant builder + /// (e.g. [`VariantBuilder`]). The value and metadata builder state is checkpointed and will + /// roll back on drop, unless [`Self::finish`] is called. + pub fn variant( + value_builder: &'a mut ValueBuilder, + metadata_builder: &'a mut dyn MetadataBuilder, + ) -> Self { + Self::new(value_builder, metadata_builder, ()) + } +} + +impl<'a> ParentState<'a, ListState<'a>> { + /// Creates a new instance suitable for a [`ListBuilder`]. The value and metadata builder state + /// is checkpointed and will roll back on drop, unless [`Self::finish`] is called. The new + /// element's offset is also captured eagerly and will also roll back if not finished. 
+ pub fn list( + value_builder: &'a mut ValueBuilder, + metadata_builder: &'a mut dyn MetadataBuilder, + offsets: &'a mut Vec, + saved_parent_value_builder_offset: usize, + ) -> Self { + // The saved_parent_buffer_offset is the buffer size as of when the parent builder was + // constructed. The saved_buffer_offset is the buffer size as of now (when a child builder + // is created). The variant field_offset entry for this list element is their difference. + let saved_value_builder_offset = value_builder.offset(); + let saved_offsets_size = offsets.len(); + offsets.push(saved_value_builder_offset - saved_parent_value_builder_offset); + + let builder_state = ListState { + offsets, + saved_offsets_size, + }; + Self { + saved_metadata_builder_dict_size: metadata_builder.num_field_names(), + saved_value_builder_offset, + metadata_builder, + value_builder, + builder_state, + finished: false, } } +} - // Return the current index of the undelying metadata buffer at the time of calling this method. - fn metadata_current_offset(&self) -> usize { - match self { - ParentState::Variant { - metadata_builder, .. - } - | ParentState::Object { - metadata_builder, .. - } - | ParentState::List { - metadata_builder, .. - } => metadata_builder.metadata_buffer.len(), +impl<'a> ParentState<'a, ObjectState<'a>> { + /// Creates a new instance suitable for an [`ObjectBuilder`]. The value and metadata builder state + /// is checkpointed and will roll back on drop, unless [`Self::finish`] is called. The new + /// field's name and offset are also captured eagerly and will also roll back if not finished. + /// + /// The call fails if the field name is invalid (e.g. because it duplicates an existing field). + pub fn try_object( + value_builder: &'a mut ValueBuilder, + metadata_builder: &'a mut dyn MetadataBuilder, + fields: &'a mut IndexMap, + saved_parent_value_builder_offset: usize, + field_name: &str, + validate_unique_fields: bool, + ) -> Result { + // The saved_parent_buffer_offset is the buffer size as of when the parent builder was + // constructed. The saved_buffer_offset is the buffer size as of now (when a child builder + // is created). The variant field_offset entry for this field is their difference. + let saved_value_builder_offset = value_builder.offset(); + let saved_fields_size = fields.len(); + let saved_metadata_builder_dict_size = metadata_builder.num_field_names(); + let field_id = metadata_builder.try_upsert_field_name(field_name)?; + let field_start = saved_value_builder_offset - saved_parent_value_builder_offset; + if fields.insert(field_id, field_start).is_some() && validate_unique_fields { + return Err(ArrowError::InvalidArgumentError(format!( + "Duplicate field name: {field_name}" + ))); } + + let builder_state = ObjectState { + fields, + saved_fields_size, + }; + Ok(Self { + saved_metadata_builder_dict_size, + saved_value_builder_offset, + value_builder, + metadata_builder, + builder_state, + finished: false, + }) + } +} + +/// Automatically rolls back any unfinished `ParentState`. 
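The checkpoint/rollback contract means an abandoned child builder now cleans up through the `Drop` impl that follows; a small sketch of the observable behavior (not taken from the PR's tests):

```rust
use parquet_variant::{Variant, VariantBuilder};

// Dropping a nested builder without calling `finish` rolls back its bytes and
// any field names it registered, leaving the parent builder untouched.
let mut builder = VariantBuilder::new();
{
    let mut obj = builder.new_object();
    obj.insert("ignored", 1i32);
    // `obj` is dropped here without `finish()`: the ParentState checkpoint rolls back.
}
builder.append_value(42i32);
let (metadata, value) = builder.finish();
assert_eq!(Variant::try_new(&metadata, &value).unwrap(), Variant::Int32(42));
```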
+impl Drop for ParentState<'_, S> { + fn drop(&mut self) { + self.rollback() } } @@ -874,56 +1083,20 @@ impl ParentState<'_> { /// ); /// /// ``` -/// # Example: Reusing Buffers -/// -/// You can use the [`VariantBuilder`] to write into existing buffers (for -/// example to write multiple variants back to back in the same buffer) -/// -/// ``` -/// // we will write two variants back to back -/// use parquet_variant::{Variant, VariantBuilder}; -/// // Append 12345 -/// let mut builder = VariantBuilder::new(); -/// builder.append_value(12345); -/// let (metadata, value) = builder.finish(); -/// // remember where the first variant ends -/// let (first_meta_offset, first_meta_len) = (0, metadata.len()); -/// let (first_value_offset, first_value_len) = (0, value.len()); -/// -/// // now, append a second variant to the same buffers -/// let mut builder = VariantBuilder::new_with_buffers(metadata, value); -/// builder.append_value("Foo"); -/// let (metadata, value) = builder.finish(); -/// -/// // The variants can be referenced in their appropriate location -/// let variant1 = Variant::new( -/// &metadata[first_meta_offset..first_meta_len], -/// &value[first_value_offset..first_value_len] -/// ); -/// assert_eq!(variant1, Variant::Int32(12345)); -/// -/// let variant2 = Variant::new( -/// &metadata[first_meta_len..], -/// &value[first_value_len..] -/// ); -/// assert_eq!(variant2, Variant::from("Foo")); -/// ``` -/// /// # Example: Unique Field Validation /// /// This example shows how enabling unique field validation will cause an error /// if the same field is inserted more than once. /// ``` -/// use parquet_variant::VariantBuilder; -/// +/// # use parquet_variant::VariantBuilder; +/// # /// let mut builder = VariantBuilder::new().with_validate_unique_fields(true); -/// let mut obj = builder.new_object(); -/// -/// obj.insert("a", 1); -/// obj.insert("a", 2); // duplicate field /// -/// // When validation is enabled, finish will return an error -/// let result = obj.finish(); // returns Err +/// // When validation is enabled, try_with_field will return an error +/// let result = builder +/// .new_object() +/// .with_field("a", 1) +/// .try_with_field("a", 2); /// assert!(result.is_err()); /// ``` /// @@ -942,7 +1115,7 @@ impl ParentState<'_> { /// obj.insert("name", "Alice"); /// obj.insert("age", 30); /// obj.insert("score", 95.5); -/// obj.finish().unwrap(); +/// obj.finish(); /// /// let (metadata, value) = builder.finish(); /// let variant = Variant::try_new(&metadata, &value).unwrap(); @@ -960,24 +1133,24 @@ impl ParentState<'_> { /// obj.insert("name", "Bob"); // field id = 3 /// obj.insert("age", 25); /// obj.insert("score", 88.0); -/// obj.finish().unwrap(); +/// obj.finish(); /// /// let (metadata, value) = builder.finish(); /// let variant = Variant::try_new(&metadata, &value).unwrap(); /// ``` #[derive(Default, Debug)] pub struct VariantBuilder { - buffer: ValueBuffer, - metadata_builder: MetadataBuilder, + value_builder: ValueBuilder, + metadata_builder: WritableMetadataBuilder, validate_unique_fields: bool, } impl VariantBuilder { - /// Create a new VariantBuilder with new underlying buffer + /// Create a new VariantBuilder with new underlying buffers pub fn new() -> Self { Self { - buffer: ValueBuffer::new(), - metadata_builder: MetadataBuilder::default(), + value_builder: ValueBuilder::new(), + metadata_builder: WritableMetadataBuilder::default(), validate_unique_fields: false, } } @@ -989,16 +1162,6 @@ impl VariantBuilder { self } - /// Create a new VariantBuilder that will 
write the metadata and values to - /// the specified buffers. - pub fn new_with_buffers(metadata_buffer: Vec, value_buffer: Vec) -> Self { - Self { - buffer: ValueBuffer::from(value_buffer), - metadata_builder: MetadataBuilder::from(metadata_buffer), - validate_unique_fields: false, - } - } - /// Enables validation of unique field keys in nested objects. /// /// This setting is propagated to all [`ObjectBuilder`]s created through this [`VariantBuilder`] @@ -1015,12 +1178,34 @@ /// You can use this to pre-populate a [`VariantBuilder`] with a sorted dictionary if you /// know the field names beforehand. Sorted dictionaries can accelerate field access when /// reading [`Variant`]s. - pub fn with_field_names<'a>(mut self, field_names: impl Iterator) -> Self { + pub fn with_field_names<'a>(mut self, field_names: impl IntoIterator) -> Self { self.metadata_builder.extend(field_names); self } + /// Builder-style API for appending a value to the builder and returning self to enable method chaining. + /// + /// # Panics + /// + /// This method will panic if the variant contains duplicate field names in objects + /// when validation is enabled. For a fallible version, use [`VariantBuilder::try_with_value`]. + pub fn with_value<'m, 'd, T: Into>>(mut self, value: T) -> Self { + self.append_value(value); + self + } + + /// Builder-style API for appending a value to the builder, returning self for method chaining. + /// + /// This is the fallible version of [`VariantBuilder::with_value`]. + pub fn try_with_value<'m, 'd, T: Into>>( + mut self, + value: T, + ) -> Result { + self.try_append_value(value)?; + Ok(self) + } + /// This method reserves capacity for field names in the Variant metadata, /// which can improve performance when you know the approximate number of unique field /// names that will be used across all objects in the [`Variant`]. @@ -1035,29 +1220,22 @@ impl VariantBuilder { self.metadata_builder.upsert_field_name(field_name); } - // Returns validate_unique_fields because we can no longer reference self once this method returns. - fn parent_state(&mut self) -> (ParentState<'_>, bool) { - let state = ParentState::Variant { - buffer: &mut self.buffer, - metadata_builder: &mut self.metadata_builder, - }; - (state, self.validate_unique_fields) - } - /// Create an [`ListBuilder`] for creating [`Variant::List`] values. /// /// See the examples on [`VariantBuilder`] for usage. - pub fn new_list(&mut self) -> ListBuilder<'_> { - let (parent_state, validate_unique_fields) = self.parent_state(); - ListBuilder::new(parent_state, validate_unique_fields) + pub fn new_list(&mut self) -> ListBuilder<'_, ()> { + let parent_state = + ParentState::variant(&mut self.value_builder, &mut self.metadata_builder); + ListBuilder::new(parent_state, self.validate_unique_fields) } /// Create an [`ObjectBuilder`] for creating [`Variant::Object`] values. /// /// See the examples on [`VariantBuilder`] for usage. - pub fn new_object(&mut self) -> ObjectBuilder<'_> { - let (parent_state, validate_unique_fields) = self.parent_state(); - ObjectBuilder::new(parent_state, validate_unique_fields) + pub fn new_object(&mut self) -> ObjectBuilder<'_, ()> { + let parent_state = + ParentState::variant(&mut self.value_builder, &mut self.metadata_builder); + ObjectBuilder::new(parent_state, self.validate_unique_fields) } /// Append a value to the builder. 
@@ -1075,9 +1253,8 @@ impl VariantBuilder { /// builder.append_value(42i8); /// ``` pub fn append_value<'m, 'd, T: Into>>(&mut self, value: T) { - let variant = value.into(); - self.buffer - .append_variant(variant, &mut self.metadata_builder); + let state = ParentState::variant(&mut self.value_builder, &mut self.metadata_builder); + ValueBuilder::append_variant(state, value.into()) } /// Append a value to the builder. @@ -1085,27 +1262,29 @@ impl VariantBuilder { &mut self, value: T, ) -> Result<(), ArrowError> { - let variant = value.into(); - self.buffer - .try_append_variant(variant, &mut self.metadata_builder)?; - - Ok(()) + let state = ParentState::variant(&mut self.value_builder, &mut self.metadata_builder); + ValueBuilder::try_append_variant(state, value.into()) } - /// Finish the builder and return the metadata and value buffers. - pub fn finish(self) -> (Vec, Vec) { - (self.metadata_builder.finish(), self.buffer.into_inner()) + /// Appends a variant value to the builder by copying raw bytes when possible. + /// + /// For objects and lists, this directly copies their underlying byte representation instead of + /// performing a logical copy and without touching the metadata builder. For other variant + /// types, this falls back to the standard append behavior. + /// + /// The caller must ensure that the metadata dictionary entries are already built and correct for + /// any objects or lists being appended. + pub fn append_value_bytes<'m, 'd>(&mut self, value: impl Into>) { + let state = ParentState::variant(&mut self.value_builder, &mut self.metadata_builder); + ValueBuilder::append_variant_bytes(state, value.into()); } - /// Return the inner metadata buffers and value buffer. - /// - /// This can be used to get the underlying buffers provided via - /// [`VariantBuilder::new_with_buffers`] without finalizing the metadata or - /// values (for rolling back changes). - pub fn into_buffers(self) -> (Vec, Vec) { + /// Finish the builder and return the metadata and value buffers. + pub fn finish(mut self) -> (Vec, Vec) { + self.metadata_builder.finish(); ( - self.metadata_builder.take_buffer(), - self.buffer.into_inner(), + self.metadata_builder.into_inner(), + self.value_builder.into_inner(), ) } } @@ -1113,30 +1292,19 @@ impl VariantBuilder { /// A builder for creating [`Variant::List`] values. /// /// See the examples on [`VariantBuilder`] for usage. -pub struct ListBuilder<'a> { - parent_state: ParentState<'a>, +#[derive(Debug)] +pub struct ListBuilder<'a, S: BuilderSpecificState> { + parent_state: ParentState<'a, S>, offsets: Vec, - /// The starting offset in the parent's buffer where this list starts - parent_value_offset_base: usize, - /// The starting offset in the parent's metadata buffer where this list starts - /// used to truncate the written fields in `drop` if the current list has not been finished - parent_metadata_offset_base: usize, - /// Whether the list has been finished, the written content of the current list - /// will be truncated in `drop` if `has_been_finished` is false - has_been_finished: bool, validate_unique_fields: bool, } -impl<'a> ListBuilder<'a> { - fn new(parent_state: ParentState<'a>, validate_unique_fields: bool) -> Self { - let parent_value_offset_base = parent_state.buffer_current_offset(); - let parent_metadata_offset_base = parent_state.metadata_current_offset(); +impl<'a, S: BuilderSpecificState> ListBuilder<'a, S> { + /// Creates a new list builder, nested on top of the given parent state. 
+ pub fn new(parent_state: ParentState<'a, S>, validate_unique_fields: bool) -> Self { Self { parent_state, offsets: vec![], - parent_value_offset_base, - has_been_finished: false, - parent_metadata_offset_base, validate_unique_fields, } } @@ -1151,22 +1319,20 @@ impl<'a> ListBuilder<'a> { } // Returns validate_unique_fields because we can no longer reference self once this method returns. - fn parent_state(&mut self) -> (ParentState<'_>, bool) { - let (buffer, metadata_builder) = self.parent_state.buffer_and_metadata_builder(); - - let state = ParentState::List { - buffer, - metadata_builder, - parent_value_offset_base: self.parent_value_offset_base, - offsets: &mut self.offsets, - }; + fn parent_state(&mut self) -> (ParentState<'_, ListState<'_>>, bool) { + let state = ParentState::list( + self.parent_state.value_builder, + self.parent_state.metadata_builder, + &mut self.offsets, + self.parent_state.saved_value_builder_offset, + ); (state, self.validate_unique_fields) } /// Returns an object builder that can be used to append a new (nested) object to this list. /// /// WARNING: The builder will have no effect unless/until [`ObjectBuilder::finish`] is called. - pub fn new_object(&mut self) -> ObjectBuilder<'_> { + pub fn new_object(&mut self) -> ObjectBuilder<'_, ListState<'_>> { let (parent_state, validate_unique_fields) = self.parent_state(); ObjectBuilder::new(parent_state, validate_unique_fields) } @@ -1174,7 +1340,7 @@ impl<'a> ListBuilder<'a> { /// Returns a list builder that can be used to append a new (nested) list to this list. /// /// WARNING: The builder will have no effect unless/until [`ListBuilder::finish`] is called. - pub fn new_list(&mut self) -> ListBuilder<'_> { + pub fn new_list(&mut self) -> ListBuilder<'_, ListState<'_>> { let (parent_state, validate_unique_fields) = self.parent_state(); ListBuilder::new(parent_state, validate_unique_fields) } @@ -1186,7 +1352,8 @@ impl<'a> ListBuilder<'a> { /// This method will panic if the variant contains duplicate field names in objects /// when validation is enabled. For a fallible version, use [`ListBuilder::try_append_value`]. pub fn append_value<'m, 'd, T: Into>>(&mut self, value: T) { - self.try_append_value(value).unwrap(); + let (state, _) = self.parent_state(); + ValueBuilder::append_variant(state, value.into()) } /// Appends a new primitive value to this list @@ -1194,14 +1361,21 @@ impl<'a> ListBuilder<'a> { &mut self, value: T, ) -> Result<(), ArrowError> { - let (buffer, metadata_builder) = self.parent_state.buffer_and_metadata_builder(); - - let offset = buffer.offset() - self.parent_value_offset_base; - self.offsets.push(offset); - - buffer.try_append_variant(value.into(), metadata_builder)?; + let (state, _) = self.parent_state(); + ValueBuilder::try_append_variant(state, value.into()) + } - Ok(()) + /// Appends a variant value to this list by copying raw bytes when possible. + /// + /// For objects and lists, this directly copies their underlying byte representation instead of + /// performing a logical copy. For other variant types, this falls back to the standard append + /// behavior. + /// + /// The caller must ensure that the metadata dictionary is already built and correct for + /// any objects or lists being appended. + pub fn append_value_bytes<'m, 'd>(&mut self, value: impl Into>) { + let (state, _) = self.parent_state(); + ValueBuilder::append_variant_bytes(state, value.into()) } /// Builder-style API for appending a value to the list and returning self to enable method chaining. 
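The nested builders now carry a `BuilderSpecificState` type parameter (`ListBuilder<'_, ()>` at the top level, `ListBuilder<'_, ListState<'_>>` when nested), but the calling code reads the same as before. A brief sketch, not taken from the PR:

```rust
use parquet_variant::{Variant, VariantBuilder};

let mut builder = VariantBuilder::new();
let mut list = builder.new_list(); // ListBuilder<'_, ()>
list.append_value(1i32);
{
    let mut inner = list.new_list(); // ListBuilder<'_, ListState<'_>>
    inner.append_value("a");
    inner.append_value("b");
    inner.finish();
}
list.finish();
let (metadata, value) = builder.finish();
assert!(matches!(Variant::try_new(&metadata, &value).unwrap(), Variant::List(_)));
```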
@@ -1228,19 +1402,18 @@ impl<'a> ListBuilder<'a> { /// Finalizes this list and appends it to its parent, which otherwise remains unmodified. pub fn finish(mut self) { - let buffer = self.parent_state.buffer(); + let starting_offset = self.parent_state.saved_value_builder_offset; + let value_builder = self.parent_state.value_builder(); - let data_size = buffer + let data_size = value_builder .offset() - .checked_sub(self.parent_value_offset_base) + .checked_sub(starting_offset) .expect("Data size overflowed usize"); let num_elements = self.offsets.len(); let is_large = num_elements > u8::MAX as usize; let offset_size = int_size(data_size); - let starting_offset = self.parent_value_offset_base; - let num_elements_size = if is_large { 4 } else { 1 }; // is_large: 4 bytes, else 1 byte. let num_elements = self.offsets.len(); let header_size = 1 + // header (i.e., `array_header`) @@ -1262,65 +1435,31 @@ impl<'a> ListBuilder<'a> { append_packed_u32(&mut bytes_to_splice, data_size as u32, offset_size as usize); - buffer + value_builder .inner_mut() .splice(starting_offset..starting_offset, bytes_to_splice); - self.parent_state.finish(starting_offset); - self.has_been_finished = true; - } -} - -/// Drop implementation for ListBuilder does nothing -/// as the `finish` method must be called to finalize the list. -/// This is to ensure that the list is always finalized before its parent builder -/// is finalized. -impl Drop for ListBuilder<'_> { - fn drop(&mut self) { - if !self.has_been_finished { - self.parent_state - .buffer() - .inner_mut() - .truncate(self.parent_value_offset_base); - self.parent_state - .metadata_builder() - .field_names - .truncate(self.parent_metadata_offset_base); - } + self.parent_state.finish(); } } /// A builder for creating [`Variant::Object`] values. /// /// See the examples on [`VariantBuilder`] for usage. -pub struct ObjectBuilder<'a> { - parent_state: ParentState<'a>, +#[derive(Debug)] +pub struct ObjectBuilder<'a, S: BuilderSpecificState> { + parent_state: ParentState<'a, S>, fields: IndexMap, // (field_id, offset) - /// The starting offset in the parent's buffer where this object starts - parent_value_offset_base: usize, - /// The starting offset in the parent's metadata buffer where this object starts - /// used to truncate the written fields in `drop` if the current object has not been finished - parent_metadata_offset_base: usize, - /// Whether the object has been finished, the written content of the current object - /// will be truncated in `drop` if `has_been_finished` is false - has_been_finished: bool, validate_unique_fields: bool, - /// Set of duplicate fields to report for errors - duplicate_fields: HashSet, } -impl<'a> ObjectBuilder<'a> { - fn new(parent_state: ParentState<'a>, validate_unique_fields: bool) -> Self { - let offset_base = parent_state.buffer_current_offset(); - let meta_offset_base = parent_state.metadata_current_offset(); +impl<'a, S: BuilderSpecificState> ObjectBuilder<'a, S> { + /// Creates a new object builder, nested on top of the given parent state. + pub fn new(parent_state: ParentState<'a, S>, validate_unique_fields: bool) -> Self { Self { parent_state, fields: IndexMap::new(), - parent_value_offset_base: offset_base, - has_been_finished: false, - parent_metadata_offset_base: meta_offset_base, validate_unique_fields, - duplicate_fields: HashSet::new(), } } @@ -1335,33 +1474,65 @@ impl<'a> ObjectBuilder<'a> { /// This method will panic if the variant contains duplicate field names in objects /// when validation is enabled. 
For a fallible version, use [`ObjectBuilder::try_insert`] pub fn insert<'m, 'd, T: Into>>(&mut self, key: &str, value: T) { - self.try_insert(key, value).unwrap(); + let (state, _) = self.parent_state(key).unwrap(); + ValueBuilder::append_variant(state, value.into()) } /// Add a field with key and value to the object /// /// # See Also - /// - [`ObjectBuilder::insert`] for a infallabel version + /// - [`ObjectBuilder::insert`] for an infallible version that panics /// - [`ObjectBuilder::try_with_field`] for a builder-style API. /// /// # Note - /// When inserting duplicate keys, the new value overwrites the previous mapping, - /// but the old value remains in the buffer, resulting in a larger variant + /// Attempting to insert a duplicate field name produces an error if unique field + /// validation is enabled. Otherwise, the new value overwrites the previous field mapping + /// without erasing the old value, resulting in a larger variant pub fn try_insert<'m, 'd, T: Into>>( &mut self, key: &str, value: T, ) -> Result<(), ArrowError> { - let (buffer, metadata_builder) = self.parent_state.buffer_and_metadata_builder(); - - let field_id = metadata_builder.upsert_field_name(key); - let field_start = buffer.offset() - self.parent_value_offset_base; + let (state, _) = self.parent_state(key)?; + ValueBuilder::try_append_variant(state, value.into()) + } - if self.fields.insert(field_id, field_start).is_some() && self.validate_unique_fields { - self.duplicate_fields.insert(field_id); - } + /// Add a field with key and value to the object by copying raw bytes when possible. + /// + /// For objects and lists, this directly copies their underlying byte representation instead of + /// performing a logical copy, and without touching the metadata builder. For other variant + /// types, this falls back to the standard append behavior. + /// + /// The caller must ensure that the metadata dictionary is already built and correct for + /// any objects or lists being appended, but the value's new field name is handled normally. + /// + /// # Panics + /// + /// This method will panic if the variant contains duplicate field names in objects + /// when validation is enabled. For a fallible version, use [`ObjectBuilder::try_insert_bytes`] + pub fn insert_bytes<'m, 'd>(&mut self, key: &str, value: impl Into>) { + self.try_insert_bytes(key, value).unwrap() + } - buffer.try_append_variant(value.into(), metadata_builder)?; + /// Add a field with key and value to the object by copying raw bytes when possible. + /// + /// For objects and lists, this directly copies their underlying byte representation instead of + /// performing a logical copy, and without touching the metadata builder. For other variant + /// types, this falls back to the standard append behavior. + /// + /// The caller must ensure that the metadata dictionary is already built and correct for + /// any objects or lists being appended, but the value's new field name is handled normally. + /// + /// # Note + /// When inserting duplicate keys, the new value overwrites the previous mapping, + /// but the old value remains in the buffer, resulting in a larger variant + pub fn try_insert_bytes<'m, 'd>( + &mut self, + key: &str, + value: impl Into>, + ) -> Result<(), ArrowError> { + let (state, _) = self.parent_state(key)?; + ValueBuilder::append_variant_bytes(state, value.into()); Ok(()) } @@ -1395,54 +1566,69 @@ impl<'a> ObjectBuilder<'a> { } // Returns validate_unique_fields because we can no longer reference self once this method returns. 
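With the deferred `duplicate_fields` check removed, duplicate keys are now reported at insertion time and `ObjectBuilder::finish` is infallible. A small sketch of the new behavior when validation is enabled (not from the PR's tests):

```rust
use parquet_variant::VariantBuilder;

let mut builder = VariantBuilder::new().with_validate_unique_fields(true);
let mut obj = builder.new_object();
obj.insert("a", 1i32);
// The duplicate is rejected immediately instead of at finish() time.
assert!(obj.try_insert("a", 2i32).is_err());
obj.finish(); // no longer returns a Result
let (_metadata, _value) = builder.finish();
```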
- fn parent_state<'b>(&'b mut self, key: &'b str) -> (ParentState<'b>, bool) { + fn parent_state<'b>( + &'b mut self, + field_name: &str, + ) -> Result<(ParentState<'b, ObjectState<'b>>, bool), ArrowError> { let validate_unique_fields = self.validate_unique_fields; + let state = ParentState::try_object( + self.parent_state.value_builder, + self.parent_state.metadata_builder, + &mut self.fields, + self.parent_state.saved_value_builder_offset, + field_name, + validate_unique_fields, + )?; + Ok((state, validate_unique_fields)) + } - let (buffer, metadata_builder) = self.parent_state.buffer_and_metadata_builder(); - - let state = ParentState::Object { - buffer, - metadata_builder, - fields: &mut self.fields, - field_name: key, - parent_value_offset_base: self.parent_value_offset_base, - }; - (state, validate_unique_fields) + /// Returns an object builder that can be used to append a new (nested) object to this object. + /// + /// Panics if the proposed key was a duplicate + /// + /// WARNING: The builder will have no effect unless/until [`ObjectBuilder::finish`] is called. + pub fn new_object<'b>(&'b mut self, key: &'b str) -> ObjectBuilder<'b, ObjectState<'b>> { + self.try_new_object(key).unwrap() } /// Returns an object builder that can be used to append a new (nested) object to this object. /// + /// Fails if the proposed key was a duplicate + /// /// WARNING: The builder will have no effect unless/until [`ObjectBuilder::finish`] is called. - pub fn new_object<'b>(&'b mut self, key: &'b str) -> ObjectBuilder<'b> { - let (parent_state, validate_unique_fields) = self.parent_state(key); - ObjectBuilder::new(parent_state, validate_unique_fields) + pub fn try_new_object<'b>( + &'b mut self, + key: &str, + ) -> Result>, ArrowError> { + let (parent_state, validate_unique_fields) = self.parent_state(key)?; + Ok(ObjectBuilder::new(parent_state, validate_unique_fields)) } /// Returns a list builder that can be used to append a new (nested) list to this object. /// + /// Panics if the proposed key was a duplicate + /// /// WARNING: The builder will have no effect unless/until [`ListBuilder::finish`] is called. - pub fn new_list<'b>(&'b mut self, key: &'b str) -> ListBuilder<'b> { - let (parent_state, validate_unique_fields) = self.parent_state(key); - ListBuilder::new(parent_state, validate_unique_fields) + pub fn new_list<'b>(&'b mut self, key: &str) -> ListBuilder<'b, ObjectState<'b>> { + self.try_new_list(key).unwrap() + } + + /// Returns a list builder that can be used to append a new (nested) list to this object. + /// + /// Fails if the proposed key was a duplicate + /// + /// WARNING: The builder will have no effect unless/until [`ListBuilder::finish`] is called. + pub fn try_new_list<'b>( + &'b mut self, + key: &str, + ) -> Result>, ArrowError> { + let (parent_state, validate_unique_fields) = self.parent_state(key)?; + Ok(ListBuilder::new(parent_state, validate_unique_fields)) } /// Finalizes this object and appends it to its parent, which otherwise remains unmodified. 
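The same up-front validation applies when opening nested builders on an object: `new_object` / `new_list` now panic on a duplicate key, while the new `try_new_object` / `try_new_list` surface the error instead. Roughly, as a sketch:

```rust
use parquet_variant::VariantBuilder;

let mut builder = VariantBuilder::new().with_validate_unique_fields(true);
let mut obj = builder.new_object();
obj.insert("id", 7i64);
// Reusing the key is rejected before any nested builder is created.
assert!(obj.try_new_object("id").is_err());
let mut nested = obj.try_new_list("tags").unwrap();
nested.append_value("x");
nested.finish();
obj.finish();
let (_metadata, _value) = builder.finish();
```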
- pub fn finish(mut self) -> Result<(), ArrowError> { + pub fn finish(mut self) { let metadata_builder = self.parent_state.metadata_builder(); - if self.validate_unique_fields && !self.duplicate_fields.is_empty() { - let mut names = self - .duplicate_fields - .iter() - .map(|id| metadata_builder.field_name(*id as usize)) - .collect::>(); - - names.sort_unstable(); - - let joined = names.join(", "); - return Err(ArrowError::InvalidArgumentError(format!( - "Duplicate field keys detected: [{joined}]", - ))); - } self.fields.sort_by(|&field_a_id, _, &field_b_id, _| { let field_a_name = metadata_builder.field_name(field_a_id as usize); @@ -1453,10 +1639,11 @@ impl<'a> ObjectBuilder<'a> { let max_id = self.fields.iter().map(|(i, _)| *i).max().unwrap_or(0); let id_size = int_size(max_id as usize); - let parent_buffer = self.parent_state.buffer(); - let current_offset = parent_buffer.offset(); + let starting_offset = self.parent_state.saved_value_builder_offset; + let value_builder = self.parent_state.value_builder(); + let current_offset = value_builder.offset(); // Current object starts from `object_start_offset` - let data_size = current_offset - self.parent_value_offset_base; + let data_size = current_offset - starting_offset; let offset_size = int_size(data_size); let num_fields = self.fields.len(); @@ -1467,11 +1654,8 @@ impl<'a> ObjectBuilder<'a> { (num_fields * id_size as usize) + // field IDs ((num_fields + 1) * offset_size as usize); // field offsets + data_size - let starting_offset = self.parent_value_offset_base; - // Shift existing data to make room for the header - let buffer = parent_buffer.inner_mut(); - buffer.splice( + value_builder.inner_mut().splice( starting_offset..starting_offset, std::iter::repeat_n(0u8, header_size), ); @@ -1484,12 +1668,12 @@ impl<'a> ObjectBuilder<'a> { header_pos = self .parent_state - .buffer() + .value_builder() .append_header_start_from_buf_pos(header_pos, header, is_large, num_fields); header_pos = self .parent_state - .buffer() + .value_builder() .append_offset_array_start_from_buf_pos( header_pos, self.fields.keys().copied().map(|id| id as usize), @@ -1498,40 +1682,14 @@ impl<'a> ObjectBuilder<'a> { ); self.parent_state - .buffer() + .value_builder() .append_offset_array_start_from_buf_pos( header_pos, self.fields.values().copied(), Some(data_size), offset_size, ); - self.parent_state.finish(starting_offset); - - // Mark that this object has been finished - self.has_been_finished = true; - - Ok(()) - } -} - -/// Drop implementation for ObjectBuilder does nothing -/// as the `finish` method must be called to finalize the object. -/// This is to ensure that the object is always finalized before its parent builder -/// is finalized. -impl Drop for ObjectBuilder<'_> { - fn drop(&mut self) { - // Truncate the buffer if the `finish` method has not been called. - if !self.has_been_finished { - self.parent_state - .buffer() - .inner_mut() - .truncate(self.parent_value_offset_base); - - self.parent_state - .metadata_builder() - .field_names - .truncate(self.parent_metadata_offset_base); - } + self.parent_state.finish(); } } @@ -1540,38 +1698,116 @@ impl Drop for ObjectBuilder<'_> { /// Allows users to append values to a [`VariantBuilder`], [`ListBuilder`] or /// [`ObjectBuilder`]. using the same interface. pub trait VariantBuilderExt { + /// The builder specific state used by nested builders + type State<'a>: BuilderSpecificState + 'a + where + Self: 'a; + + /// Appends a NULL value to this builder. 
The semantics depend on the implementation, but will + /// often translate to appending a [`Variant::Null`] value. + fn append_null(&mut self); + + /// Appends a new variant value to this builder. See e.g. [`VariantBuilder::append_value`]. fn append_value<'m, 'v>(&mut self, value: impl Into>); - fn new_list(&mut self) -> ListBuilder<'_>; + /// Creates a nested list builder. See e.g. [`VariantBuilder::new_list`]. Panics if the nested + /// builder cannot be created, see e.g. [`ObjectBuilder::new_list`]. + fn new_list(&mut self) -> ListBuilder<'_, Self::State<'_>> { + self.try_new_list().unwrap() + } + + /// Creates a nested object builder. See e.g. [`VariantBuilder::new_object`]. Panics if the + /// nested builder cannot be created, see e.g. [`ObjectBuilder::new_object`]. + fn new_object(&mut self) -> ObjectBuilder<'_, Self::State<'_>> { + self.try_new_object().unwrap() + } + + /// Creates a nested list builder. See e.g. [`VariantBuilder::new_list`]. Returns an error if + /// the nested builder cannot be created, see e.g. [`ObjectBuilder::try_new_list`]. + fn try_new_list(&mut self) -> Result>, ArrowError>; - fn new_object(&mut self) -> ObjectBuilder<'_>; + /// Creates a nested object builder. See e.g. [`VariantBuilder::new_object`]. Returns an error + /// if the nested builder cannot be created, see e.g. [`ObjectBuilder::try_new_object`]. + fn try_new_object(&mut self) -> Result>, ArrowError>; } -impl VariantBuilderExt for ListBuilder<'_> { +impl<'a, S: BuilderSpecificState> VariantBuilderExt for ListBuilder<'a, S> { + type State<'s> + = ListState<'s> + where + Self: 's; + + /// Variant arrays cannot encode NULL values, only `Variant::Null`. + fn append_null(&mut self) { + self.append_value(Variant::Null); + } fn append_value<'m, 'v>(&mut self, value: impl Into>) { self.append_value(value); } - fn new_list(&mut self) -> ListBuilder<'_> { - self.new_list() + fn try_new_list(&mut self) -> Result>, ArrowError> { + Ok(self.new_list()) } - fn new_object(&mut self) -> ObjectBuilder<'_> { - self.new_object() + fn try_new_object(&mut self) -> Result>, ArrowError> { + Ok(self.new_object()) } } impl VariantBuilderExt for VariantBuilder { + type State<'a> + = () + where + Self: 'a; + + /// Variant values cannot encode NULL, only [`Variant::Null`]. This is different from the column + /// that holds variant values being NULL at some positions. + fn append_null(&mut self) { + self.append_value(Variant::Null); + } fn append_value<'m, 'v>(&mut self, value: impl Into>) { self.append_value(value); } - fn new_list(&mut self) -> ListBuilder<'_> { - self.new_list() + fn try_new_list(&mut self) -> Result>, ArrowError> { + Ok(self.new_list()) + } + + fn try_new_object(&mut self) -> Result>, ArrowError> { + Ok(self.new_object()) + } +} + +/// A [`VariantBuilderExt`] that inserts a new field into a variant object. +pub struct ObjectFieldBuilder<'o, 'v, 's, S: BuilderSpecificState> { + key: &'s str, + builder: &'o mut ObjectBuilder<'v, S>, +} + +impl<'o, 'v, 's, S: BuilderSpecificState> ObjectFieldBuilder<'o, 'v, 's, S> { + pub fn new(key: &'s str, builder: &'o mut ObjectBuilder<'v, S>) -> Self { + Self { key, builder } + } +} + +impl VariantBuilderExt for ObjectFieldBuilder<'_, '_, '_, S> { + type State<'a> + = ObjectState<'a> + where + Self: 'a; + + /// A NULL object field is interpreted as missing, so nothing gets inserted at all. 
+ fn append_null(&mut self) {} + fn append_value<'m, 'v>(&mut self, value: impl Into>) { + self.builder.insert(self.key, value); + } + + fn try_new_list(&mut self) -> Result>, ArrowError> { + self.builder.try_new_list(self.key) } - fn new_object(&mut self) -> ObjectBuilder<'_> { - self.new_object() + fn try_new_object(&mut self) -> Result>, ArrowError> { + self.builder.try_new_object(self.key) } } @@ -1722,8 +1958,7 @@ mod tests { .new_object() .with_field("name", "John") .with_field("age", 42i8) - .finish() - .unwrap(); + .finish(); let (metadata, value) = builder.finish(); assert!(!metadata.is_empty()); @@ -1739,8 +1974,7 @@ mod tests { .with_field("zebra", "stripes") .with_field("apple", "red") .with_field("banana", "yellow") - .finish() - .unwrap(); + .finish(); let (_, value) = builder.finish(); @@ -1764,8 +1998,7 @@ mod tests { .new_object() .with_field("name", "Ron Artest") .with_field("name", "Metta World Peace") // Duplicate field - .finish() - .unwrap(); + .finish(); let (metadata, value) = builder.finish(); let variant = Variant::try_new(&metadata, &value).unwrap(); @@ -1884,15 +2117,13 @@ mod tests { .new_object() .with_field("id", 1) .with_field("type", "Cauliflower") - .finish() - .unwrap(); + .finish(); list_builder .new_object() .with_field("id", 2) .with_field("type", "Beets") - .finish() - .unwrap(); + .finish(); list_builder.finish(); @@ -1929,17 +2160,9 @@ mod tests { let mut list_builder = builder.new_list(); - list_builder - .new_object() - .with_field("a", 1) - .finish() - .unwrap(); + list_builder.new_object().with_field("a", 1).finish(); - list_builder - .new_object() - .with_field("b", 2) - .finish() - .unwrap(); + list_builder.new_object().with_field("b", 2).finish(); list_builder.finish(); @@ -1985,7 +2208,7 @@ mod tests { { let mut object_builder = list_builder.new_object(); object_builder.insert("a", 1); - let _ = object_builder.finish(); + object_builder.finish(); } list_builder.append_value(2); @@ -1993,7 +2216,7 @@ mod tests { { let mut object_builder = list_builder.new_object(); object_builder.insert("b", 2); - let _ = object_builder.finish(); + object_builder.finish(); } list_builder.append_value(3); @@ -2043,10 +2266,10 @@ mod tests { { let mut inner_object_builder = outer_object_builder.new_object("c"); inner_object_builder.insert("b", "a"); - let _ = inner_object_builder.finish(); + inner_object_builder.finish(); } - let _ = outer_object_builder.finish(); + outer_object_builder.finish(); } let (metadata, value) = builder.finish(); @@ -2085,11 +2308,11 @@ mod tests { inner_object_builder.insert("b", false); inner_object_builder.insert("c", "a"); - let _ = inner_object_builder.finish(); + inner_object_builder.finish(); } outer_object_builder.insert("b", false); - let _ = outer_object_builder.finish(); + outer_object_builder.finish(); } let (metadata, value) = builder.finish(); @@ -2133,10 +2356,10 @@ mod tests { .with_value(false) .finish(); - let _ = inner_object_builder.finish(); + inner_object_builder.finish(); } - let _ = outer_object_builder.finish(); + outer_object_builder.finish(); } let (metadata, value) = builder.finish(); @@ -2196,15 +2419,15 @@ mod tests { { let mut inner_inner_object_builder = inner_object_builder.new_object("c"); inner_inner_object_builder.insert("aa", "bb"); - let _ = inner_inner_object_builder.finish(); + inner_inner_object_builder.finish(); } { let mut inner_inner_object_builder = inner_object_builder.new_object("d"); inner_inner_object_builder.insert("cc", "dd"); - let _ = inner_inner_object_builder.finish(); + 
inner_inner_object_builder.finish(); } - let _ = inner_object_builder.finish(); + inner_object_builder.finish(); } outer_object_builder.insert("b", true); @@ -2228,10 +2451,10 @@ mod tests { inner_list_builder.finish(); } - let _ = inner_object_builder.finish(); + inner_object_builder.finish(); } - let _ = outer_object_builder.finish(); + outer_object_builder.finish(); } let (metadata, value) = builder.finish(); @@ -2331,7 +2554,7 @@ mod tests { let mut inner_object_builder = inner_list_builder.new_object(); inner_object_builder.insert("a", "b"); inner_object_builder.insert("b", "c"); - let _ = inner_object_builder.finish(); + inner_object_builder.finish(); } { @@ -2340,7 +2563,7 @@ mod tests { let mut inner_object_builder = inner_list_builder.new_object(); inner_object_builder.insert("c", "d"); inner_object_builder.insert("d", "e"); - let _ = inner_object_builder.finish(); + inner_object_builder.finish(); } inner_list_builder.finish(); @@ -2426,15 +2649,29 @@ mod tests { let mut obj = builder.new_object(); obj.insert("a", 1); obj.insert("a", 2); - assert!(obj.finish().is_ok()); + obj.finish(); // Deeply nested list structure with duplicates + let mut builder = VariantBuilder::new(); let mut outer_list = builder.new_list(); let mut inner_list = outer_list.new_list(); let mut nested_obj = inner_list.new_object(); nested_obj.insert("x", 1); nested_obj.insert("x", 2); - assert!(nested_obj.finish().is_ok()); + nested_obj.new_list("x").with_value(3).finish(); + nested_obj.new_object("x").with_field("y", 4).finish(); + nested_obj.finish(); + inner_list.finish(); + outer_list.finish(); + + // Verify the nested object is built correctly -- the nested object "x" should have "won" + let (metadata, value) = builder.finish(); + let variant = Variant::try_new(&metadata, &value).unwrap(); + let outer_element = variant.get_list_element(0).unwrap(); + let inner_element = outer_element.get_list_element(0).unwrap(); + let outer_field = inner_element.get_object_field("x").unwrap(); + let inner_field = outer_field.get_object_field("y").unwrap(); + assert_eq!(inner_field, Variant::from(4)); } #[test] @@ -2442,31 +2679,38 @@ mod tests { let mut builder = VariantBuilder::new().with_validate_unique_fields(true); // Root-level object with duplicates - let mut root_obj = builder.new_object(); - root_obj.insert("a", 1); - root_obj.insert("b", 2); - root_obj.insert("a", 3); - root_obj.insert("b", 4); - - let result = root_obj.finish(); + let result = builder + .new_object() + .with_field("a", 1) + .with_field("b", 2) + .try_with_field("a", 3); assert_eq!( result.unwrap_err().to_string(), - "Invalid argument error: Duplicate field keys detected: [a, b]" + "Invalid argument error: Duplicate field name: a" ); // Deeply nested list -> list -> object with duplicate let mut outer_list = builder.new_list(); let mut inner_list = outer_list.new_list(); - let mut nested_obj = inner_list.new_object(); - nested_obj.insert("x", 1); - nested_obj.insert("x", 2); + let mut object = inner_list.new_object().with_field("x", 1); + let nested_result = object.try_insert("x", 2); + assert_eq!( + nested_result.unwrap_err().to_string(), + "Invalid argument error: Duplicate field name: x" + ); + let nested_result = object.try_new_list("x"); + assert_eq!( + nested_result.unwrap_err().to_string(), + "Invalid argument error: Duplicate field name: x" + ); - let nested_result = nested_obj.finish(); + let nested_result = object.try_new_object("x"); assert_eq!( nested_result.unwrap_err().to_string(), - "Invalid argument error: Duplicate field keys 
detected: [x]" + "Invalid argument error: Duplicate field name: x" ); + drop(object); inner_list.finish(); outer_list.finish(); @@ -2476,14 +2720,14 @@ mod tests { valid_obj.insert("m", 1); valid_obj.insert("n", 2); - let valid_result = valid_obj.finish(); - assert!(valid_result.is_ok()); + valid_obj.finish(); + list.finish(); } #[test] fn test_sorted_dictionary() { // check if variant metadatabuilders are equivalent from different ways of constructing them - let mut variant1 = VariantBuilder::new().with_field_names(["b", "c", "d"].into_iter()); + let mut variant1 = VariantBuilder::new().with_field_names(["b", "c", "d"]); let mut variant2 = { let mut builder = VariantBuilder::new(); @@ -2533,7 +2777,7 @@ mod tests { #[test] fn test_object_sorted_dictionary() { // predefine the list of field names - let mut variant1 = VariantBuilder::new().with_field_names(["a", "b", "c"].into_iter()); + let mut variant1 = VariantBuilder::new().with_field_names(["a", "b", "c"]); let mut obj = variant1.new_object(); obj.insert("c", true); @@ -2546,7 +2790,7 @@ mod tests { // add a field name that wasn't pre-defined but doesn't break the sort order obj.insert("d", 2); - obj.finish().unwrap(); + obj.finish(); let (metadata, value) = variant1.finish(); let variant = Variant::try_new(&metadata, &value).unwrap(); @@ -2567,7 +2811,7 @@ mod tests { #[test] fn test_object_not_sorted_dictionary() { // predefine the list of field names - let mut variant1 = VariantBuilder::new().with_field_names(["b", "c", "d"].into_iter()); + let mut variant1 = VariantBuilder::new().with_field_names(["b", "c", "d"]); let mut obj = variant1.new_object(); obj.insert("c", true); @@ -2580,7 +2824,7 @@ mod tests { // add a field name that wasn't pre-defined but breaks the sort order obj.insert("a", 2); - obj.finish().unwrap(); + obj.finish(); let (metadata, value) = variant1.finish(); let variant = Variant::try_new(&metadata, &value).unwrap(); @@ -2609,40 +2853,40 @@ mod tests { assert!(builder.metadata_builder.is_sorted); assert_eq!(builder.metadata_builder.num_field_names(), 1); - let builder = builder.with_field_names(["b", "c", "d"].into_iter()); + let builder = builder.with_field_names(["b", "c", "d"]); assert!(builder.metadata_builder.is_sorted); assert_eq!(builder.metadata_builder.num_field_names(), 4); - let builder = builder.with_field_names(["z", "y"].into_iter()); + let builder = builder.with_field_names(["z", "y"]); assert!(!builder.metadata_builder.is_sorted); assert_eq!(builder.metadata_builder.num_field_names(), 6); } #[test] fn test_metadata_builder_from_iter() { - let metadata = MetadataBuilder::from_iter(vec!["apple", "banana", "cherry"]); + let metadata = WritableMetadataBuilder::from_iter(vec!["apple", "banana", "cherry"]); assert_eq!(metadata.num_field_names(), 3); assert_eq!(metadata.field_name(0), "apple"); assert_eq!(metadata.field_name(1), "banana"); assert_eq!(metadata.field_name(2), "cherry"); assert!(metadata.is_sorted); - let metadata = MetadataBuilder::from_iter(["zebra", "apple", "banana"]); + let metadata = WritableMetadataBuilder::from_iter(["zebra", "apple", "banana"]); assert_eq!(metadata.num_field_names(), 3); assert_eq!(metadata.field_name(0), "zebra"); assert_eq!(metadata.field_name(1), "apple"); assert_eq!(metadata.field_name(2), "banana"); assert!(!metadata.is_sorted); - let metadata = MetadataBuilder::from_iter(Vec::<&str>::new()); + let metadata = WritableMetadataBuilder::from_iter(Vec::<&str>::new()); assert_eq!(metadata.num_field_names(), 0); assert!(!metadata.is_sorted); } #[test] fn 
test_metadata_builder_extend() { - let mut metadata = MetadataBuilder::default(); + let mut metadata = WritableMetadataBuilder::default(); assert_eq!(metadata.num_field_names(), 0); assert!(!metadata.is_sorted); @@ -2667,7 +2911,7 @@ mod tests { #[test] fn test_metadata_builder_extend_sort_order() { - let mut metadata = MetadataBuilder::default(); + let mut metadata = WritableMetadataBuilder::default(); metadata.extend(["middle"]); assert!(metadata.is_sorted); @@ -2683,95 +2927,23 @@ mod tests { #[test] fn test_metadata_builder_from_iter_with_string_types() { // &str - let metadata = MetadataBuilder::from_iter(["a", "b", "c"]); + let metadata = WritableMetadataBuilder::from_iter(["a", "b", "c"]); assert_eq!(metadata.num_field_names(), 3); // string - let metadata = - MetadataBuilder::from_iter(vec!["a".to_string(), "b".to_string(), "c".to_string()]); + let metadata = WritableMetadataBuilder::from_iter(vec![ + "a".to_string(), + "b".to_string(), + "c".to_string(), + ]); assert_eq!(metadata.num_field_names(), 3); // mixed types (anything that implements AsRef) let field_names: Vec> = vec!["a".into(), "b".into(), "c".into()]; - let metadata = MetadataBuilder::from_iter(field_names); + let metadata = WritableMetadataBuilder::from_iter(field_names); assert_eq!(metadata.num_field_names(), 3); } - /// Test reusing buffers with nested objects - #[test] - fn test_with_existing_buffers_nested() { - let mut builder = VariantBuilder::new(); - append_test_list(&mut builder); - let (m1, v1) = builder.finish(); - let variant1 = Variant::new(&m1, &v1); - - let mut builder = VariantBuilder::new(); - append_test_object(&mut builder); - let (m2, v2) = builder.finish(); - let variant2 = Variant::new(&m2, &v2); - - let mut builder = VariantBuilder::new(); - builder.append_value("This is a string"); - let (m3, v3) = builder.finish(); - let variant3 = Variant::new(&m3, &v3); - - // Now, append those three variants to the a new buffer that is reused - let mut builder = VariantBuilder::new(); - append_test_list(&mut builder); - let (metadata, value) = builder.finish(); - let (meta1_offset, meta1_end) = (0, metadata.len()); - let (value1_offset, value1_end) = (0, value.len()); - - // reuse same buffer - let mut builder = VariantBuilder::new_with_buffers(metadata, value); - append_test_object(&mut builder); - let (metadata, value) = builder.finish(); - let (meta2_offset, meta2_end) = (meta1_end, metadata.len()); - let (value2_offset, value2_end) = (value1_end, value.len()); - - // Append a string - let mut builder = VariantBuilder::new_with_buffers(metadata, value); - builder.append_value("This is a string"); - let (metadata, value) = builder.finish(); - let (meta3_offset, meta3_end) = (meta2_end, metadata.len()); - let (value3_offset, value3_end) = (value2_end, value.len()); - - // verify we can read the variants back correctly - let roundtrip1 = Variant::new( - &metadata[meta1_offset..meta1_end], - &value[value1_offset..value1_end], - ); - assert_eq!(roundtrip1, variant1,); - - let roundtrip2 = Variant::new( - &metadata[meta2_offset..meta2_end], - &value[value2_offset..value2_end], - ); - assert_eq!(roundtrip2, variant2,); - - let roundtrip3 = Variant::new( - &metadata[meta3_offset..meta3_end], - &value[value3_offset..value3_end], - ); - assert_eq!(roundtrip3, variant3); - } - - /// append a simple List variant - fn append_test_list(builder: &mut VariantBuilder) { - builder - .new_list() - .with_value(1234) - .with_value("a string value") - .finish(); - } - - /// append an object variant - fn 
append_test_object(builder: &mut VariantBuilder) { - let mut obj = builder.new_object(); - obj.insert("a", true); - obj.finish().unwrap(); - } - #[test] fn test_variant_builder_to_list_builder_no_finish() { // Create a list builder but never finish it @@ -2896,7 +3068,7 @@ mod tests { // Create a nested object builder and finish it let mut nested_object_builder = list_builder.new_object(); nested_object_builder.insert("name", "unknown"); - nested_object_builder.finish().unwrap(); + nested_object_builder.finish(); // Drop the outer list builder without finishing it drop(list_builder); @@ -2926,15 +3098,18 @@ mod tests { object_builder.insert("second", 2i8); // The parent object should only contain the original fields - object_builder.finish().unwrap(); + object_builder.finish(); let (metadata, value) = builder.finish(); + let metadata = VariantMetadata::try_new(&metadata).unwrap(); - assert_eq!(metadata.len(), 1); - assert_eq!(&metadata[0], "second"); + assert_eq!(metadata.len(), 2); + assert_eq!(&metadata[0], "first"); + assert_eq!(&metadata[1], "second"); let variant = Variant::try_new_with_metadata(metadata, &value).unwrap(); let obj = variant.as_object().unwrap(); - assert_eq!(obj.len(), 1); + assert_eq!(obj.len(), 2); + assert_eq!(obj.get("first"), Some(Variant::Int8(1))); assert_eq!(obj.get("second"), Some(Variant::Int8(2))); } @@ -2977,15 +3152,18 @@ mod tests { object_builder.insert("second", 2i8); // The parent object should only contain the original fields - object_builder.finish().unwrap(); + object_builder.finish(); let (metadata, value) = builder.finish(); + let metadata = VariantMetadata::try_new(&metadata).unwrap(); - assert_eq!(metadata.len(), 1); // the fields of nested_object_builder has been rolled back - assert_eq!(&metadata[0], "second"); + assert_eq!(metadata.len(), 2); // the fields of nested_object_builder has been rolled back + assert_eq!(&metadata[0], "first"); + assert_eq!(&metadata[1], "second"); let variant = Variant::try_new_with_metadata(metadata, &value).unwrap(); let obj = variant.as_object().unwrap(); - assert_eq!(obj.len(), 1); + assert_eq!(obj.len(), 2); + assert_eq!(obj.get("first"), Some(Variant::Int8(1))); assert_eq!(obj.get("second"), Some(Variant::Int8(2))); } @@ -2998,7 +3176,7 @@ mod tests { // Create a nested object builder and finish it let mut nested_object_builder = object_builder.new_object("nested"); nested_object_builder.insert("name", "unknown"); - nested_object_builder.finish().unwrap(); + nested_object_builder.finish(); // Drop the outer object builder without finishing it drop(object_builder); @@ -3036,7 +3214,7 @@ mod tests { obj.insert("b", true); obj.insert("a", false); - obj.finish().unwrap(); + obj.finish(); builder.finish() } @@ -3065,10 +3243,10 @@ mod tests { { let mut inner_obj = outer_obj.new_object("b"); inner_obj.insert("a", "inner_value"); - inner_obj.finish().unwrap(); + inner_obj.finish(); } - outer_obj.finish().unwrap(); + outer_obj.finish(); } builder.finish() @@ -3121,4 +3299,457 @@ mod tests { builder.finish() } + + // Make sure that we can correctly build deeply nested objects even when some of the nested + // builders don't finish. 
+ #[test] + fn test_append_list_object_list_object() { + // An infinite counter + let mut counter = 0..; + let mut take = move |i| (&mut counter).take(i).collect::>(); + let mut builder = VariantBuilder::new(); + let skip = 5; + { + let mut list = builder.new_list(); + for i in take(4) { + let mut object = list.new_object(); + for i in take(4) { + let field_name = format!("field{i}"); + let mut list = object.new_list(&field_name); + for i in take(3) { + let mut object = list.new_object(); + for i in take(3) { + if i % skip != 0 { + object.insert(&format!("field{i}"), i); + } + } + if i % skip != 0 { + object.finish(); + } + } + if i % skip != 0 { + list.finish(); + } + } + if i % skip != 0 { + object.finish(); + } + } + list.finish(); + } + let (metadata, value) = builder.finish(); + let v1 = Variant::try_new(&metadata, &value).unwrap(); + + let (metadata, value) = VariantBuilder::new().with_value(v1.clone()).finish(); + let v2 = Variant::try_new(&metadata, &value).unwrap(); + + assert_eq!(format!("{v1:?}"), format!("{v2:?}")); + } + + #[test] + fn test_read_only_metadata_builder() { + // First create some metadata with a few field names + let mut default_builder = VariantBuilder::new(); + default_builder.add_field_name("name"); + default_builder.add_field_name("age"); + default_builder.add_field_name("active"); + let (metadata_bytes, _) = default_builder.finish(); + + // Use the metadata to build new variant values + let metadata = VariantMetadata::try_new(&metadata_bytes).unwrap(); + let mut metadata_builder = ReadOnlyMetadataBuilder::new(metadata); + let mut value_builder = ValueBuilder::new(); + + { + let state = ParentState::variant(&mut value_builder, &mut metadata_builder); + let mut obj = ObjectBuilder::new(state, false); + + // These should succeed because the fields exist in the metadata + obj.insert("name", "Alice"); + obj.insert("age", 30i8); + obj.insert("active", true); + obj.finish(); + } + + let value = value_builder.into_inner(); + + // Verify the variant was built correctly + let variant = Variant::try_new(&metadata_bytes, &value).unwrap(); + let obj = variant.as_object().unwrap(); + assert_eq!(obj.get("name"), Some(Variant::from("Alice"))); + assert_eq!(obj.get("age"), Some(Variant::Int8(30))); + assert_eq!(obj.get("active"), Some(Variant::from(true))); + } + + #[test] + fn test_read_only_metadata_builder_fails_on_unknown_field() { + // Create metadata with only one field + let mut default_builder = VariantBuilder::new(); + default_builder.add_field_name("known_field"); + let (metadata_bytes, _) = default_builder.finish(); + + // Use the metadata to build new variant values + let metadata = VariantMetadata::try_new(&metadata_bytes).unwrap(); + let mut metadata_builder = ReadOnlyMetadataBuilder::new(metadata); + let mut value_builder = ValueBuilder::new(); + + { + let state = ParentState::variant(&mut value_builder, &mut metadata_builder); + let mut obj = ObjectBuilder::new(state, false); + + // This should succeed + obj.insert("known_field", "value"); + + // This should fail because "unknown_field" is not in the metadata + let result = obj.try_insert("unknown_field", "value"); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Field name 'unknown_field' not found")); + } + } + + #[test] + fn test_append_variant_bytes_round_trip() { + // Create a complex variant with the normal builder + let mut builder = VariantBuilder::new(); + { + let mut obj = builder.new_object(); + obj.insert("name", "Alice"); + obj.insert("age", 30i32); + { + 
let mut scores_list = obj.new_list("scores"); + scores_list.append_value(95i32); + scores_list.append_value(87i32); + scores_list.append_value(92i32); + scores_list.finish(); + } + { + let mut address = obj.new_object("address"); + address.insert("street", "123 Main St"); + address.insert("city", "Anytown"); + address.finish(); + } + obj.finish(); + } + let (metadata, value1) = builder.finish(); + let variant1 = Variant::try_new(&metadata, &value1).unwrap(); + + // Copy using the new bytes API + let metadata = VariantMetadata::new(&metadata); + let mut metadata = ReadOnlyMetadataBuilder::new(metadata); + let mut builder2 = ValueBuilder::new(); + let state = ParentState::variant(&mut builder2, &mut metadata); + ValueBuilder::append_variant_bytes(state, variant1.clone()); + let value2 = builder2.into_inner(); + + // The bytes should be identical, we merely copied them across. + assert_eq!(value1, value2); + } + + #[test] + fn test_object_insert_bytes_subset() { + // Create an original object, making sure to inject the field names we'll add later. + let mut builder = VariantBuilder::new().with_field_names(["new_field", "another_field"]); + { + let mut obj = builder.new_object(); + obj.insert("field1", "value1"); + obj.insert("field2", 42i32); + obj.insert("field3", true); + obj.insert("field4", "value4"); + obj.finish(); + } + let (metadata1, value1) = builder.finish(); + let original_variant = Variant::try_new(&metadata1, &value1).unwrap(); + let original_obj = original_variant.as_object().unwrap(); + + // Create a new object copying subset of fields interleaved with new ones + let metadata2 = VariantMetadata::new(&metadata1); + let mut metadata2 = ReadOnlyMetadataBuilder::new(metadata2); + let mut builder2 = ValueBuilder::new(); + let state = ParentState::variant(&mut builder2, &mut metadata2); + { + let mut obj = ObjectBuilder::new(state, true); + + // Copy field1 using bytes API + obj.insert_bytes("field1", original_obj.get("field1").unwrap()); + + // Add new field + obj.insert("new_field", "new_value"); + + // Copy field3 using bytes API + obj.insert_bytes("field3", original_obj.get("field3").unwrap()); + + // Add another new field + obj.insert("another_field", 99i32); + + // Copy field2 using bytes API + obj.insert_bytes("field2", original_obj.get("field2").unwrap()); + + obj.finish(); + } + let value2 = builder2.into_inner(); + let result_variant = Variant::try_new(&metadata1, &value2).unwrap(); + let result_obj = result_variant.as_object().unwrap(); + + // Verify the object contains expected fields + assert_eq!(result_obj.len(), 5); + assert_eq!( + result_obj.get("field1").unwrap().as_string().unwrap(), + "value1" + ); + assert_eq!(result_obj.get("field2").unwrap().as_int32().unwrap(), 42); + assert!(result_obj.get("field3").unwrap().as_boolean().unwrap()); + assert_eq!( + result_obj.get("new_field").unwrap().as_string().unwrap(), + "new_value" + ); + assert_eq!( + result_obj.get("another_field").unwrap().as_int32().unwrap(), + 99 + ); + } + + #[test] + fn test_list_append_bytes_subset() { + // Create an original list + let mut builder = VariantBuilder::new(); + { + let mut list = builder.new_list(); + list.append_value("item1"); + list.append_value(42i32); + list.append_value(true); + list.append_value("item4"); + list.append_value(1.234f64); + list.finish(); + } + let (metadata1, value1) = builder.finish(); + let original_variant = Variant::try_new(&metadata1, &value1).unwrap(); + let original_list = original_variant.as_list().unwrap(); + + // Create a new list copying subset of 
elements interleaved with new ones + let metadata2 = VariantMetadata::new(&metadata1); + let mut metadata2 = ReadOnlyMetadataBuilder::new(metadata2); + let mut builder2 = ValueBuilder::new(); + let state = ParentState::variant(&mut builder2, &mut metadata2); + { + let mut list = ListBuilder::new(state, true); + + // Copy first element using bytes API + list.append_value_bytes(original_list.get(0).unwrap()); + + // Add new element + list.append_value("new_item"); + + // Copy third element using bytes API + list.append_value_bytes(original_list.get(2).unwrap()); + + // Add another new element + list.append_value(99i32); + + // Copy last element using bytes API + list.append_value_bytes(original_list.get(4).unwrap()); + + list.finish(); + } + let value2 = builder2.into_inner(); + let result_variant = Variant::try_new(&metadata1, &value2).unwrap(); + let result_list = result_variant.as_list().unwrap(); + + // Verify the list contains expected elements + assert_eq!(result_list.len(), 5); + assert_eq!(result_list.get(0).unwrap().as_string().unwrap(), "item1"); + assert_eq!(result_list.get(1).unwrap().as_string().unwrap(), "new_item"); + assert!(result_list.get(2).unwrap().as_boolean().unwrap()); + assert_eq!(result_list.get(3).unwrap().as_int32().unwrap(), 99); + assert_eq!(result_list.get(4).unwrap().as_f64().unwrap(), 1.234); + } + + #[test] + fn test_complex_nested_filtering_injection() { + // Create a complex nested structure: object -> list -> objects. Make sure to pre-register + // the extra field names we'll need later while manipulating variant bytes. + let mut builder = VariantBuilder::new().with_field_names([ + "active_count", + "active_users", + "computed_score", + "processed_at", + "status", + ]); + + { + let mut root_obj = builder.new_object(); + root_obj.insert("metadata", "original"); + + { + let mut users_list = root_obj.new_list("users"); + + // User 1 + { + let mut user1 = users_list.new_object(); + user1.insert("id", 1i32); + user1.insert("name", "Alice"); + user1.insert("active", true); + user1.finish(); + } + + // User 2 + { + let mut user2 = users_list.new_object(); + user2.insert("id", 2i32); + user2.insert("name", "Bob"); + user2.insert("active", false); + user2.finish(); + } + + // User 3 + { + let mut user3 = users_list.new_object(); + user3.insert("id", 3i32); + user3.insert("name", "Charlie"); + user3.insert("active", true); + user3.finish(); + } + + users_list.finish(); + } + + root_obj.insert("total_count", 3i32); + root_obj.finish(); + } + let (metadata1, value1) = builder.finish(); + let original_variant = Variant::try_new(&metadata1, &value1).unwrap(); + let original_obj = original_variant.as_object().unwrap(); + let original_users = original_obj.get("users").unwrap(); + let original_users = original_users.as_list().unwrap(); + + // Create filtered/modified version: only copy active users and inject new data + let metadata2 = VariantMetadata::new(&metadata1); + let mut metadata2 = ReadOnlyMetadataBuilder::new(metadata2); + let mut builder2 = ValueBuilder::new(); + let state = ParentState::variant(&mut builder2, &mut metadata2); + { + let mut root_obj = ObjectBuilder::new(state, true); + + // Copy metadata using bytes API + root_obj.insert_bytes("metadata", original_obj.get("metadata").unwrap()); + + // Add processing timestamp + root_obj.insert("processed_at", "2024-01-01T00:00:00Z"); + + { + let mut filtered_users = root_obj.new_list("active_users"); + + // Copy only active users and inject additional data + for i in 0..original_users.len() { + let user = 
original_users.get(i).unwrap(); + let user = user.as_object().unwrap(); + if user.get("active").unwrap().as_boolean().unwrap() { + { + let mut new_user = filtered_users.new_object(); + + // Copy existing fields using bytes API + new_user.insert_bytes("id", user.get("id").unwrap()); + new_user.insert_bytes("name", user.get("name").unwrap()); + + // Inject new computed field + let user_id = user.get("id").unwrap().as_int32().unwrap(); + new_user.insert("computed_score", user_id * 10); + + // Add status transformation (don't copy the 'active' field) + new_user.insert("status", "verified"); + + new_user.finish(); + } + } + } + + // Inject a completely new user + { + let mut new_user = filtered_users.new_object(); + new_user.insert("id", 999i32); + new_user.insert("name", "System User"); + new_user.insert("computed_score", 0i32); + new_user.insert("status", "system"); + new_user.finish(); + } + + filtered_users.finish(); + } + + // Update count + root_obj.insert("active_count", 3i32); // 2 active + 1 new + + root_obj.finish(); + } + let value2 = builder2.into_inner(); + let result_variant = Variant::try_new(&metadata1, &value2).unwrap(); + let result_obj = result_variant.as_object().unwrap(); + + // Verify the filtered/modified structure + assert_eq!( + result_obj.get("metadata").unwrap().as_string().unwrap(), + "original" + ); + assert_eq!( + result_obj.get("processed_at").unwrap().as_string().unwrap(), + "2024-01-01T00:00:00Z" + ); + assert_eq!( + result_obj.get("active_count").unwrap().as_int32().unwrap(), + 3 + ); + + let active_users = result_obj.get("active_users").unwrap(); + let active_users = active_users.as_list().unwrap(); + assert_eq!(active_users.len(), 3); + + // Verify Alice (id=1, was active) + let alice = active_users.get(0).unwrap(); + let alice = alice.as_object().unwrap(); + assert_eq!(alice.get("id").unwrap().as_int32().unwrap(), 1); + assert_eq!(alice.get("name").unwrap().as_string().unwrap(), "Alice"); + assert_eq!(alice.get("computed_score").unwrap().as_int32().unwrap(), 10); + assert_eq!( + alice.get("status").unwrap().as_string().unwrap(), + "verified" + ); + assert!(alice.get("active").is_none()); // This field was not copied + + // Verify Charlie (id=3, was active) - Bob (id=2) was not active so not included + let charlie = active_users.get(1).unwrap(); + let charlie = charlie.as_object().unwrap(); + assert_eq!(charlie.get("id").unwrap().as_int32().unwrap(), 3); + assert_eq!(charlie.get("name").unwrap().as_string().unwrap(), "Charlie"); + assert_eq!( + charlie.get("computed_score").unwrap().as_int32().unwrap(), + 30 + ); + assert_eq!( + charlie.get("status").unwrap().as_string().unwrap(), + "verified" + ); + + // Verify injected system user + let system_user = active_users.get(2).unwrap(); + let system_user = system_user.as_object().unwrap(); + assert_eq!(system_user.get("id").unwrap().as_int32().unwrap(), 999); + assert_eq!( + system_user.get("name").unwrap().as_string().unwrap(), + "System User" + ); + assert_eq!( + system_user + .get("computed_score") + .unwrap() + .as_int32() + .unwrap(), + 0 + ); + assert_eq!( + system_user.get("status").unwrap().as_string().unwrap(), + "system" + ); + } } diff --git a/parquet-variant/src/decoder.rs b/parquet-variant/src/decoder.rs index 21069cdc02fc..26b4e204fa69 100644 --- a/parquet-variant/src/decoder.rs +++ b/parquet-variant/src/decoder.rs @@ -20,7 +20,8 @@ use crate::utils::{ use crate::ShortString; use arrow_schema::ArrowError; -use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, Utc}; +use chrono::{DateTime, 
Duration, NaiveDate, NaiveDateTime, NaiveTime, Utc}; +use uuid::Uuid; /// The basic type of a [`Variant`] value, encoded in the first two bits of the /// header byte. @@ -63,6 +64,10 @@ pub enum VariantPrimitiveType { Float = 14, Binary = 15, String = 16, + Time = 17, + TimestampNanos = 18, + TimestampNtzNanos = 19, + Uuid = 20, } /// Extracts the basic type from a header byte @@ -104,6 +109,10 @@ impl TryFrom for VariantPrimitiveType { 14 => Ok(VariantPrimitiveType::Float), 15 => Ok(VariantPrimitiveType::Binary), 16 => Ok(VariantPrimitiveType::String), + 17 => Ok(VariantPrimitiveType::Time), + 18 => Ok(VariantPrimitiveType::TimestampNanos), + 19 => Ok(VariantPrimitiveType::TimestampNtzNanos), + 20 => Ok(VariantPrimitiveType::Uuid), _ => Err(ArrowError::InvalidArgumentError(format!( "unknown primitive type: {value}", ))), @@ -295,6 +304,44 @@ pub(crate) fn decode_timestampntz_micros(data: &[u8]) -> Result Result { + let micros_since_epoch = u64::from_le_bytes(array_from_slice(data, 0)?); + + let case_error = ArrowError::CastError(format!( + "Could not cast {micros_since_epoch} microseconds into a NaiveTime" + )); + + if micros_since_epoch >= 86_400_000_000 { + return Err(case_error); + } + + let nanos_since_midnight = micros_since_epoch * 1_000; + NaiveTime::from_num_seconds_from_midnight_opt( + (nanos_since_midnight / 1_000_000_000) as u32, + (nanos_since_midnight % 1_000_000_000) as u32, + ) + .ok_or(case_error) +} + +/// Decodes a TimestampNanos from the value section of a variant. +pub(crate) fn decode_timestamp_nanos(data: &[u8]) -> Result, ArrowError> { + let nanos_since_epoch = i64::from_le_bytes(array_from_slice(data, 0)?); + + // DateTime::from_timestamp_nanos would never fail + Ok(DateTime::from_timestamp_nanos(nanos_since_epoch)) +} + +/// Decodes a TimestampNtzNanos from the value section of a variant. +pub(crate) fn decode_timestampntz_nanos(data: &[u8]) -> Result { + decode_timestamp_nanos(data).map(|v| v.naive_utc()) +} + +/// Decodes a UUID from the value section of a variant. +pub(crate) fn decode_uuid(data: &[u8]) -> Result { + Uuid::from_slice(&data[0..16]) + .map_err(|_| ArrowError::CastError(format!("Cant decode uuid from {:?}", &data[0..16]))) +} + /// Decodes a Binary from the value section of a variant. pub(crate) fn decode_binary(data: &[u8]) -> Result<&[u8], ArrowError> { let len = u32::from_le_bytes(array_from_slice(data, 0)?) 
as usize; @@ -439,6 +486,80 @@ mod tests { .and_hms_milli_opt(16, 34, 56, 780) .unwrap() ); + + test_decoder_bounds!( + test_timestamp_nanos, + [0x15, 0x41, 0xa2, 0x5a, 0x36, 0xa2, 0x5b, 0x18], + decode_timestamp_nanos, + NaiveDate::from_ymd_opt(2025, 8, 14) + .unwrap() + .and_hms_nano_opt(12, 33, 54, 123456789) + .unwrap() + .and_utc() + ); + + test_decoder_bounds!( + test_timestamp_nanos_before_epoch, + [0x15, 0x41, 0x52, 0xd4, 0x94, 0xe5, 0xad, 0xfa], + decode_timestamp_nanos, + NaiveDate::from_ymd_opt(1957, 11, 7) + .unwrap() + .and_hms_nano_opt(12, 33, 54, 123456789) + .unwrap() + .and_utc() + ); + + test_decoder_bounds!( + test_timestampntz_nanos, + [0x15, 0x41, 0xa2, 0x5a, 0x36, 0xa2, 0x5b, 0x18], + decode_timestampntz_nanos, + NaiveDate::from_ymd_opt(2025, 8, 14) + .unwrap() + .and_hms_nano_opt(12, 33, 54, 123456789) + .unwrap() + ); + + test_decoder_bounds!( + test_timestampntz_nanos_before_epoch, + [0x15, 0x41, 0x52, 0xd4, 0x94, 0xe5, 0xad, 0xfa], + decode_timestampntz_nanos, + NaiveDate::from_ymd_opt(1957, 11, 7) + .unwrap() + .and_hms_nano_opt(12, 33, 54, 123456789) + .unwrap() + ); + } + + #[test] + fn test_uuid() { + let data = [ + 0xf2, 0x4f, 0x9b, 0x64, 0x81, 0xfa, 0x49, 0xd1, 0xb7, 0x4e, 0x8c, 0x09, 0xa6, 0xe3, + 0x1c, 0x56, + ]; + let result = decode_uuid(&data).unwrap(); + assert_eq!( + Uuid::parse_str("f24f9b64-81fa-49d1-b74e-8c09a6e31c56").unwrap(), + result + ); + } + + mod time { + use super::*; + + test_decoder_bounds!( + test_timentz, + [0x53, 0x1f, 0x8e, 0xdf, 0x2, 0, 0, 0], + decode_time_ntz, + NaiveTime::from_num_seconds_from_midnight_opt(12340, 567_891_000).unwrap() + ); + + #[test] + fn test_decode_time_ntz_invalid() { + let invalid_second = u64::MAX; + let data = invalid_second.to_le_bytes(); + let result = decode_time_ntz(&data); + assert!(matches!(result, Err(ArrowError::CastError(_)))); + } } #[test] diff --git a/parquet-variant/src/path.rs b/parquet-variant/src/path.rs index 3ba50da3285e..794636ef4092 100644 --- a/parquet-variant/src/path.rs +++ b/parquet-variant/src/path.rs @@ -95,10 +95,10 @@ impl<'a> From>> for VariantPath<'a> { } } -/// Create from &str +/// Create from &str with support for dot notation impl<'a> From<&'a str> for VariantPath<'a> { fn from(path: &'a str) -> Self { - VariantPath::new(vec![path.into()]) + VariantPath::new(path.split('.').map(Into::into).collect()) } } @@ -109,6 +109,12 @@ impl<'a> From for VariantPath<'a> { } } +impl<'a> From<&[VariantPathElement<'a>]> for VariantPath<'a> { + fn from(elements: &[VariantPathElement<'a>]) -> Self { + VariantPath::new(elements.to_vec()) + } +} + /// Create from iter impl<'a> FromIterator> for VariantPath<'a> { fn from_iter>>(iter: T) -> Self { diff --git a/parquet-variant/src/utils.rs b/parquet-variant/src/utils.rs index 8374105e0af8..d28b8685baa2 100644 --- a/parquet-variant/src/utils.rs +++ b/parquet-variant/src/utils.rs @@ -18,6 +18,7 @@ use std::{array::TryFromSliceError, ops::Range, str}; use arrow_schema::ArrowError; +use std::cmp::Ordering; use std::fmt::Debug; use std::slice::SliceIndex; @@ -115,23 +116,20 @@ pub(crate) fn string_from_slice( /// * `Some(Ok(index))` - Element found at the given index /// * `Some(Err(index))` - Element not found, but would be inserted at the given index /// * `None` - Key extraction failed -pub(crate) fn try_binary_search_range_by( +pub(crate) fn try_binary_search_range_by( range: Range, - target: &K, - key_extractor: F, + cmp: F, ) -> Option> where - K: Ord, - F: Fn(usize) -> Option, + F: Fn(usize) -> Option, { let Range { mut start, mut end } = 
range; while start < end { let mid = start + (end - start) / 2; - let key = key_extractor(mid)?; - match key.cmp(target) { - std::cmp::Ordering::Equal => return Some(Ok(mid)), - std::cmp::Ordering::Greater => end = mid, - std::cmp::Ordering::Less => start = mid + 1, + match cmp(mid)? { + Ordering::Equal => return Some(Ok(mid)), + Ordering::Greater => end = mid, + Ordering::Less => start = mid + 1, } } @@ -146,3 +144,20 @@ pub(crate) const fn expect_size_of(expected: usize) { let _ = [""; 0][size]; } } + +pub(crate) fn fits_precision(n: impl Into) -> bool { + n.into().unsigned_abs().leading_zeros() >= (i64::BITS - N) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_fits_precision() { + assert!(fits_precision::<10>(1023)); + assert!(!fits_precision::<10>(1024)); + assert!(fits_precision::<10>(-1023)); + assert!(!fits_precision::<10>(-1024)); + } +} diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs index 82de637b0697..38ef5ba30a45 100644 --- a/parquet-variant/src/variant.rs +++ b/parquet-variant/src/variant.rs @@ -17,17 +17,22 @@ pub use self::decimal::{VariantDecimal16, VariantDecimal4, VariantDecimal8}; pub use self::list::VariantList; -pub use self::metadata::VariantMetadata; +pub use self::metadata::{VariantMetadata, EMPTY_VARIANT_METADATA, EMPTY_VARIANT_METADATA_BYTES}; pub use self::object::VariantObject; + +// Publically export types used in the API +pub use half::f16; +pub use uuid::Uuid; + use crate::decoder::{ self, get_basic_type, get_primitive_type, VariantBasicType, VariantPrimitiveType, }; use crate::path::{VariantPath, VariantPathElement}; -use crate::utils::{first_byte_from_slice, slice_from_slice}; +use crate::utils::{first_byte_from_slice, fits_precision, slice_from_slice}; use std::ops::Deref; use arrow_schema::ArrowError; -use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc}; +use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; mod decimal; mod list; @@ -211,7 +216,7 @@ impl Deref for ShortString<'_> { /// [metadata]: VariantMetadata#Validation /// [object]: VariantObject#Validation /// [array]: VariantList#Validation -#[derive(Debug, Clone, PartialEq)] +#[derive(Clone, PartialEq)] pub enum Variant<'m, 'v> { /// Primitive type: Null Null, @@ -229,6 +234,10 @@ pub enum Variant<'m, 'v> { TimestampMicros(DateTime), /// Primitive (type_id=1): TIMESTAMP(isAdjustedToUTC=false, MICROS) TimestampNtzMicros(NaiveDateTime), + /// Primitive (type_id=1): TIMESTAMP(isAdjustedToUTC=true, NANOS) + TimestampNanos(DateTime), + /// Primitive (type_id=1): TIMESTAMP(isAdjustedToUTC=false, NANOS) + TimestampNtzNanos(NaiveDateTime), /// Primitive (type_id=1): DECIMAL(precision, scale) 32-bits Decimal4(VariantDecimal4), /// Primitive (type_id=1): DECIMAL(precision, scale) 64-bits @@ -248,6 +257,10 @@ pub enum Variant<'m, 'v> { Binary(&'v [u8]), /// Primitive (type_id=1): STRING String(&'v str), + /// Primitive (type_id=1): TIME(isAdjustedToUTC=false, MICROS) + Time(NaiveTime), + /// Primitive (type_id=1): UUID + Uuid(Uuid), /// Short String (type_id=2): STRING ShortString(ShortString<'v>), // need both metadata & value @@ -379,12 +392,20 @@ impl<'m, 'v> Variant<'m, 'v> { VariantPrimitiveType::TimestampNtzMicros => { Variant::TimestampNtzMicros(decoder::decode_timestampntz_micros(value_data)?) } + VariantPrimitiveType::TimestampNanos => { + Variant::TimestampNanos(decoder::decode_timestamp_nanos(value_data)?) 
+ } + VariantPrimitiveType::TimestampNtzNanos => { + Variant::TimestampNtzNanos(decoder::decode_timestampntz_nanos(value_data)?) + } + VariantPrimitiveType::Uuid => Variant::Uuid(decoder::decode_uuid(value_data)?), VariantPrimitiveType::Binary => { Variant::Binary(decoder::decode_binary(value_data)?) } VariantPrimitiveType::String => { Variant::String(decoder::decode_long_string(value_data)?) } + VariantPrimitiveType::Time => Variant::Time(decoder::decode_time_ntz(value_data)?), }, VariantBasicType::ShortString => { Variant::ShortString(decoder::decode_short_string(value_metadata, value_data)?) @@ -525,11 +546,9 @@ impl<'m, 'v> Variant<'m, 'v> { /// let datetime = NaiveDate::from_ymd_opt(2025, 4, 16).unwrap().and_hms_milli_opt(12, 34, 56, 780).unwrap().and_utc(); /// let v1 = Variant::from(datetime); /// assert_eq!(v1.as_datetime_utc(), Some(datetime)); - /// - /// // or a non-UTC-adjusted variant - /// let datetime = NaiveDate::from_ymd_opt(2025, 4, 16).unwrap().and_hms_milli_opt(12, 34, 56, 780).unwrap(); - /// let v2 = Variant::from(datetime); - /// assert_eq!(v2.as_datetime_utc(), Some(datetime.and_utc())); + /// let datetime_nanos = NaiveDate::from_ymd_opt(2025, 8, 14).unwrap().and_hms_nano_opt(12, 33, 54, 123456789).unwrap().and_utc(); + /// let v2 = Variant::from(datetime_nanos); + /// assert_eq!(v2.as_datetime_utc(), Some(datetime_nanos)); /// /// // but not from other variants /// let v3 = Variant::from("hello!"); @@ -537,8 +556,7 @@ impl<'m, 'v> Variant<'m, 'v> { /// ``` pub fn as_datetime_utc(&self) -> Option> { match *self { - Variant::TimestampMicros(d) => Some(d), - Variant::TimestampNtzMicros(d) => Some(d.and_utc()), + Variant::TimestampMicros(d) | Variant::TimestampNanos(d) => Some(d), _ => None, } } @@ -560,9 +578,9 @@ impl<'m, 'v> Variant<'m, 'v> { /// assert_eq!(v1.as_naive_datetime(), Some(datetime)); /// /// // or a UTC-adjusted variant - /// let datetime = NaiveDate::from_ymd_opt(2025, 4, 16).unwrap().and_hms_milli_opt(12, 34, 56, 780).unwrap().and_utc(); + /// let datetime = NaiveDate::from_ymd_opt(2025, 4, 16).unwrap().and_hms_nano_opt(12, 34, 56, 123456789).unwrap(); /// let v2 = Variant::from(datetime); - /// assert_eq!(v2.as_naive_datetime(), Some(datetime.naive_utc())); + /// assert_eq!(v2.as_naive_datetime(), Some(datetime)); /// /// // but not from other variants /// let v3 = Variant::from("hello!"); @@ -570,8 +588,7 @@ impl<'m, 'v> Variant<'m, 'v> { /// ``` pub fn as_naive_datetime(&self) -> Option { match *self { - Variant::TimestampNtzMicros(d) => Some(d), - Variant::TimestampMicros(d) => Some(d.naive_utc()), + Variant::TimestampNtzMicros(d) | Variant::TimestampNtzNanos(d) => Some(d), _ => None, } } @@ -629,6 +646,32 @@ impl<'m, 'v> Variant<'m, 'v> { } } + /// Converts this variant to a `uuid hyphenated string` if possible. + /// + /// Returns `Some(String)` for UUID variants, `None` for non-UUID variants. 
+ /// + /// # Examples + /// + /// ``` + /// use parquet_variant::Variant; + /// + /// // You can extract a UUID from a UUID variant + /// let s = uuid::Uuid::parse_str("67e55044-10b1-426f-9247-bb680e5fe0c8").unwrap(); + /// let v1 = Variant::Uuid(s); + /// assert_eq!(s, v1.as_uuid().unwrap()); + /// assert_eq!("67e55044-10b1-426f-9247-bb680e5fe0c8", v1.as_uuid().unwrap().to_string()); + /// + /// //but not from other variants + /// let v2 = Variant::from(1234); + /// assert_eq!(None, v2.as_uuid()) + /// ``` + pub fn as_uuid(&self) -> Option { + match self { + Variant::Uuid(u) => Some(*u), + _ => None, + } + } + /// Converts this variant to an `i8` if possible. /// /// Returns `Some(i8)` for integer variants that fit in `i8` range, @@ -765,6 +808,166 @@ impl<'m, 'v> Variant<'m, 'v> { } } + fn generic_convert_unsigned_primitive(&self) -> Option + where + T: TryFrom + TryFrom + TryFrom + TryFrom + TryFrom, + { + match *self { + Variant::Int8(i) => i.try_into().ok(), + Variant::Int16(i) => i.try_into().ok(), + Variant::Int32(i) => i.try_into().ok(), + Variant::Int64(i) => i.try_into().ok(), + Variant::Decimal4(d) if d.scale() == 0 => d.integer().try_into().ok(), + Variant::Decimal8(d) if d.scale() == 0 => d.integer().try_into().ok(), + Variant::Decimal16(d) if d.scale() == 0 => d.integer().try_into().ok(), + _ => None, + } + } + + /// Converts this variant to a `u8` if possible. + /// + /// Returns `Some(u8)` for integer variants that fit in `u8` + /// `None` for non-integer variants or values that would overflow. + /// + /// # Examples + /// + /// ``` + /// use parquet_variant::{Variant, VariantDecimal4}; + /// + /// // you can read an int64 variant into an u8 + /// let v1 = Variant::from(123i64); + /// assert_eq!(v1.as_u8(), Some(123u8)); + /// + /// // or a Decimal4 with scale 0 into u8 + /// let d = VariantDecimal4::try_new(26, 0).unwrap(); + /// let v2 = Variant::from(d); + /// assert_eq!(v2.as_u8(), Some(26u8)); + /// + /// // but not a variant that can't fit into the range + /// let v3 = Variant::from(-1); + /// assert_eq!(v3.as_u8(), None); + /// + /// // not a variant that decimal with scale not equal to zero + /// let d = VariantDecimal4::try_new(1, 2).unwrap(); + /// let v4 = Variant::from(d); + /// assert_eq!(v4.as_u8(), None); + /// + /// // or not a variant that cannot be cast into an integer + /// let v5 = Variant::from("hello!"); + /// assert_eq!(v5.as_u8(), None); + /// ``` + pub fn as_u8(&self) -> Option { + self.generic_convert_unsigned_primitive::() + } + + /// Converts this variant to an `u16` if possible. + /// + /// Returns `Some(u16)` for integer variants that fit in `u16` + /// `None` for non-integer variants or values that would overflow. 
+ /// + /// # Examples + /// + /// ``` + /// use parquet_variant::{Variant, VariantDecimal4}; + /// + /// // you can read an int64 variant into an u16 + /// let v1 = Variant::from(123i64); + /// assert_eq!(v1.as_u16(), Some(123u16)); + /// + /// // or a Decimal4 with scale 0 into u8 + /// let d = VariantDecimal4::try_new(u16::MAX as i32, 0).unwrap(); + /// let v2 = Variant::from(d); + /// assert_eq!(v2.as_u16(), Some(u16::MAX)); + /// + /// // but not a variant that can't fit into the range + /// let v3 = Variant::from(-1); + /// assert_eq!(v3.as_u16(), None); + /// + /// // not a variant that decimal with scale not equal to zero + /// let d = VariantDecimal4::try_new(1, 2).unwrap(); + /// let v4 = Variant::from(d); + /// assert_eq!(v4.as_u16(), None); + /// + /// // or not a variant that cannot be cast into an integer + /// let v5 = Variant::from("hello!"); + /// assert_eq!(v5.as_u16(), None); + /// ``` + pub fn as_u16(&self) -> Option { + self.generic_convert_unsigned_primitive::() + } + + /// Converts this variant to an `u32` if possible. + /// + /// Returns `Some(u32)` for integer variants that fit in `u32` + /// `None` for non-integer variants or values that would overflow. + /// + /// # Examples + /// + /// ``` + /// use parquet_variant::{Variant, VariantDecimal8}; + /// + /// // you can read an int64 variant into an u32 + /// let v1 = Variant::from(123i64); + /// assert_eq!(v1.as_u32(), Some(123u32)); + /// + /// // or a Decimal4 with scale 0 into u8 + /// let d = VariantDecimal8::try_new(u32::MAX as i64, 0).unwrap(); + /// let v2 = Variant::from(d); + /// assert_eq!(v2.as_u32(), Some(u32::MAX)); + /// + /// // but not a variant that can't fit into the range + /// let v3 = Variant::from(-1); + /// assert_eq!(v3.as_u32(), None); + /// + /// // not a variant that decimal with scale not equal to zero + /// let d = VariantDecimal8::try_new(1, 2).unwrap(); + /// let v4 = Variant::from(d); + /// assert_eq!(v4.as_u32(), None); + /// + /// // or not a variant that cannot be cast into an integer + /// let v5 = Variant::from("hello!"); + /// assert_eq!(v5.as_u32(), None); + /// ``` + pub fn as_u32(&self) -> Option { + self.generic_convert_unsigned_primitive::() + } + + /// Converts this variant to an `u64` if possible. + /// + /// Returns `Some(u64)` for integer variants that fit in `u64` + /// `None` for non-integer variants or values that would overflow. + /// + /// # Examples + /// + /// ``` + /// use parquet_variant::{Variant, VariantDecimal16}; + /// + /// // you can read an int64 variant into an u64 + /// let v1 = Variant::from(123i64); + /// assert_eq!(v1.as_u64(), Some(123u64)); + /// + /// // or a Decimal16 with scale 0 into u8 + /// let d = VariantDecimal16::try_new(u64::MAX as i128, 0).unwrap(); + /// let v2 = Variant::from(d); + /// assert_eq!(v2.as_u64(), Some(u64::MAX)); + /// + /// // but not a variant that can't fit into the range + /// let v3 = Variant::from(-1); + /// assert_eq!(v3.as_u64(), None); + /// + /// // not a variant that decimal with scale not equal to zero + /// let d = VariantDecimal16::try_new(1, 2).unwrap(); + /// let v4 = Variant::from(d); + /// assert_eq!(v4.as_u64(), None); + /// + /// // or not a variant that cannot be cast into an integer + /// let v5 = Variant::from("hello!"); + /// assert_eq!(v5.as_u64(), None); + /// ``` + pub fn as_u64(&self) -> Option { + self.generic_convert_unsigned_primitive::() + } + /// Converts this variant to tuple with a 4-byte unscaled value if possible. 
/// /// Returns `Some((i32, u8))` for decimal variants where the unscaled value @@ -876,10 +1079,49 @@ impl<'m, 'v> Variant<'m, 'v> { _ => None, } } + + /// Converts this variant to an `f16` if possible. + /// + /// Returns `Some(f16)` for floating point values, and integers with up to 11 bits of + /// precision. `None` otherwise. + /// + /// # Example + /// + /// ``` + /// use parquet_variant::Variant; + /// use half::f16; + /// + /// // you can extract an f16 from a float variant + /// let v1 = Variant::from(std::f32::consts::PI); + /// assert_eq!(v1.as_f16(), Some(f16::from_f32(std::f32::consts::PI))); + /// + /// // and from a double variant (with loss of precision to nearest f16) + /// let v2 = Variant::from(std::f64::consts::PI); + /// assert_eq!(v2.as_f16(), Some(f16::from_f64(std::f64::consts::PI))); + /// + /// // and from integers with no more than 11 bits of precision + /// let v3 = Variant::from(2047); + /// assert_eq!(v3.as_f16(), Some(f16::from_f32(2047.0))); + /// + /// // but not from other variants + /// let v4 = Variant::from("hello!"); + /// assert_eq!(v4.as_f16(), None); + pub fn as_f16(&self) -> Option { + match *self { + Variant::Float(i) => Some(f16::from_f32(i)), + Variant::Double(i) => Some(f16::from_f64(i)), + Variant::Int8(i) => Some(i.into()), + Variant::Int16(i) if fits_precision::<11>(i) => Some(f16::from_f32(i as _)), + Variant::Int32(i) if fits_precision::<11>(i) => Some(f16::from_f32(i as _)), + Variant::Int64(i) if fits_precision::<11>(i) => Some(f16::from_f32(i as _)), + _ => None, + } + } + /// Converts this variant to an `f32` if possible. /// - /// Returns `Some(f32)` for float and double variants, - /// `None` for non-floating-point variants. + /// Returns `Some(f32)` for floating point values, and integer values with up to 24 bits of + /// precision. `None` otherwise. /// /// # Examples /// @@ -894,23 +1136,31 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v2 = Variant::from(std::f64::consts::PI); /// assert_eq!(v2.as_f32(), Some(std::f32::consts::PI)); /// + /// // and from integers with no more than 24 bits of precision + /// let v3 = Variant::from(16777215i64); + /// assert_eq!(v3.as_f32(), Some(16777215.0)); + /// /// // but not from other variants - /// let v3 = Variant::from("hello!"); - /// assert_eq!(v3.as_f32(), None); + /// let v4 = Variant::from("hello!"); + /// assert_eq!(v4.as_f32(), None); /// ``` #[allow(clippy::cast_possible_truncation)] pub fn as_f32(&self) -> Option { match *self { Variant::Float(i) => Some(i), Variant::Double(i) => Some(i as f32), + Variant::Int8(i) => Some(i.into()), + Variant::Int16(i) => Some(i.into()), + Variant::Int32(i) if fits_precision::<24>(i) => Some(i as _), + Variant::Int64(i) if fits_precision::<24>(i) => Some(i as _), _ => None, } } /// Converts this variant to an `f64` if possible. /// - /// Returns `Some(f64)` for float and double variants, - /// `None` for non-floating-point variants. + /// Returns `Some(f64)` for floating point values, and integer values with up to 53 bits of + /// precision. `None` otherwise. 
/// /// # Examples /// /// @@ -925,14 +1175,22 @@ impl<'m, 'v> Variant<'m, 'v> { /// let v2 = Variant::from(std::f64::consts::PI); /// assert_eq!(v2.as_f64(), Some(std::f64::consts::PI)); /// + /// // and from integers with no more than 53 bits of precision + /// let v3 = Variant::from(9007199254740991i64); + /// assert_eq!(v3.as_f64(), Some(9007199254740991.0)); + /// /// // but not from other variants - /// let v3 = Variant::from("hello!"); - /// assert_eq!(v3.as_f64(), None); + /// let v4 = Variant::from("hello!"); + /// assert_eq!(v4.as_f64(), None); /// ``` pub fn as_f64(&self) -> Option<f64> { match *self { Variant::Float(i) => Some(i.into()), Variant::Double(i) => Some(i), + Variant::Int8(i) => Some(i.into()), + Variant::Int16(i) => Some(i.into()), + Variant::Int32(i) => Some(i.into()), + Variant::Int64(i) if fits_precision::<53>(i) => Some(i as _), _ => None, } } @@ -1030,6 +1288,34 @@ impl<'m, 'v> Variant<'m, 'v> { } } + + /// Converts this variant to a `NaiveTime` if possible. + /// + /// Returns `Some(NaiveTime)` for `Variant::Time`, + /// `None` for non-Time variants. + /// + /// # Example + /// + /// ``` + /// use chrono::NaiveTime; + /// use parquet_variant::Variant; + /// + /// // you can extract a `NaiveTime` from a `Variant::Time` + /// let time = NaiveTime::from_hms_micro_opt(1, 2, 3, 4).unwrap(); + /// let v1 = Variant::from(time); + /// assert_eq!(Some(time), v1.as_time_utc()); + /// + /// // but not from other variants. + /// let v2 = Variant::from("Hello"); + /// assert_eq!(None, v2.as_time_utc()); + /// ``` + pub fn as_time_utc(&'m self) -> Option<NaiveTime> { + if let Variant::Time(time) = self { + Some(*time) + } else { + None + } + } + /// If this is a list and the requested index is in bounds, retrieves the corresponding /// element. Otherwise, returns None. /// @@ -1082,7 +1368,7 @@ impl<'m, 'v> Variant<'m, 'v> { /// # list.append_value("bar"); /// # list.append_value("baz"); /// # list.finish(); - /// # obj.finish().unwrap(); + /// # obj.finish(); /// # let (metadata, value) = builder.finish(); /// // given a variant like `{"foo": ["bar", "baz"]}` /// let variant = Variant::new(&metadata, &value); @@ -1211,6 +1497,12 @@ impl From for Variant<'_, '_> { } } +impl From<half::f16> for Variant<'_, '_> { + fn from(value: half::f16) -> Self { + Variant::Float(value.into()) + } +} + impl From<f32> for Variant<'_, '_> { fn from(value: f32) -> Self { Variant::Float(value) @@ -1231,12 +1523,21 @@ impl From for Variant<'_, '_> { impl From<DateTime<Utc>> for Variant<'_, '_> { fn from(value: DateTime<Utc>) -> Self { - Variant::TimestampMicros(value) + if value.nanosecond() % 1000 > 0 { + Variant::TimestampNanos(value) + } else { + Variant::TimestampMicros(value) + } } } + impl From<NaiveDateTime> for Variant<'_, '_> { fn from(value: NaiveDateTime) -> Self { - Variant::TimestampNtzMicros(value) + if value.nanosecond() % 1000 > 0 { + Variant::TimestampNtzNanos(value) + } else { + Variant::TimestampNtzMicros(value) + } } } @@ -1246,6 +1547,18 @@ impl<'v> From<&'v [u8]> for Variant<'_, 'v> { } } +impl From<NaiveTime> for Variant<'_, '_> { + fn from(value: NaiveTime) -> Self { + Variant::Time(value) + } +} + +impl From<Uuid> for Variant<'_, '_> { + fn from(value: Uuid) -> Self { + Variant::Uuid(value) + } +} + impl<'v> From<&'v str> for Variant<'_, 'v> { fn from(value: &'v str) -> Self { if value.len() > MAX_SHORT_STRING_BYTES { @@ -1286,6 +1599,81 @@ impl TryFrom<(i128, u8)> for Variant<'_, '_> { } } +// helper used by the Debug impls below to print a placeholder when a VariantObject or VariantList contains invalid values.
+struct InvalidVariant; + +impl std::fmt::Debug for InvalidVariant { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "") + } +} + +// helper to print binary data in hex format in debug mode, as space-separated hex byte values. +struct HexString<'a>(&'a [u8]); + +impl<'a> std::fmt::Debug for HexString<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some((first, rest)) = self.0.split_first() { + write!(f, "{:02x}", first)?; + for b in rest { + write!(f, " {:02x}", b)?; + } + } + Ok(()) + } +} + +impl std::fmt::Debug for Variant<'_, '_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Variant::Null => write!(f, "Null"), + Variant::BooleanTrue => write!(f, "BooleanTrue"), + Variant::BooleanFalse => write!(f, "BooleanFalse"), + Variant::Int8(v) => f.debug_tuple("Int8").field(v).finish(), + Variant::Int16(v) => f.debug_tuple("Int16").field(v).finish(), + Variant::Int32(v) => f.debug_tuple("Int32").field(v).finish(), + Variant::Int64(v) => f.debug_tuple("Int64").field(v).finish(), + Variant::Float(v) => f.debug_tuple("Float").field(v).finish(), + Variant::Double(v) => f.debug_tuple("Double").field(v).finish(), + Variant::Decimal4(d) => f.debug_tuple("Decimal4").field(d).finish(), + Variant::Decimal8(d) => f.debug_tuple("Decimal8").field(d).finish(), + Variant::Decimal16(d) => f.debug_tuple("Decimal16").field(d).finish(), + Variant::Date(d) => f.debug_tuple("Date").field(d).finish(), + Variant::TimestampMicros(ts) => f.debug_tuple("TimestampMicros").field(ts).finish(), + Variant::TimestampNtzMicros(ts) => { + f.debug_tuple("TimestampNtzMicros").field(ts).finish() + } + Variant::TimestampNanos(ts) => f.debug_tuple("TimestampNanos").field(ts).finish(), + Variant::TimestampNtzNanos(ts) => f.debug_tuple("TimestampNtzNanos").field(ts).finish(), + Variant::Binary(bytes) => write!(f, "Binary({:?})", HexString(bytes)), + Variant::String(s) => f.debug_tuple("String").field(s).finish(), + Variant::Time(s) => f.debug_tuple("Time").field(s).finish(), + Variant::ShortString(s) => f.debug_tuple("ShortString").field(s).finish(), + Variant::Uuid(uuid) => f.debug_tuple("Uuid").field(&uuid).finish(), + Variant::Object(obj) => { + let mut map = f.debug_map(); + for res in obj.iter_try() { + match res { + Ok((k, v)) => map.entry(&k, &v), + Err(_) => map.entry(&InvalidVariant, &InvalidVariant), + }; + } + map.finish() + } + Variant::List(arr) => { + let mut list = f.debug_list(); + for res in arr.iter_try() { + match res { + Ok(v) => list.entry(&v), + Err(_) => list.entry(&InvalidVariant), + }; + } + list.finish() + } + } + } +} + #[cfg(test)] mod tests { @@ -1326,4 +1714,258 @@ mod tests { let variant = Variant::from(decimal16); assert_eq!(variant.as_decimal16(), Some(decimal16)); } + + #[test] + fn test_variant_all_subtypes_debug() { + use crate::VariantBuilder; + + let mut builder = VariantBuilder::new(); + + // Create a root object that contains one of every variant subtype + let mut root_obj = builder.new_object(); + + // Add primitive types + root_obj.insert("null", ()); + root_obj.insert("boolean_true", true); + root_obj.insert("boolean_false", false); + root_obj.insert("int8", 42i8); + root_obj.insert("int16", 1234i16); + root_obj.insert("int32", 123456i32); + root_obj.insert("int64", 1234567890123456789i64); + root_obj.insert("float", 1.234f32); + root_obj.insert("double", 1.23456789f64); + + // Add date and timestamp types + let date = chrono::NaiveDate::from_ymd_opt(2024, 12, 25).unwrap(); + 
root_obj.insert("date", date); + + let timestamp_utc = chrono::NaiveDate::from_ymd_opt(2024, 12, 25) + .unwrap() + .and_hms_milli_opt(15, 30, 45, 123) + .unwrap() + .and_utc(); + root_obj.insert("timestamp_micros", Variant::TimestampMicros(timestamp_utc)); + + let timestamp_ntz = chrono::NaiveDate::from_ymd_opt(2024, 12, 25) + .unwrap() + .and_hms_milli_opt(15, 30, 45, 123) + .unwrap(); + root_obj.insert( + "timestamp_ntz_micros", + Variant::TimestampNtzMicros(timestamp_ntz), + ); + + let timestamp_nanos_utc = chrono::NaiveDate::from_ymd_opt(2025, 8, 15) + .unwrap() + .and_hms_nano_opt(12, 3, 4, 123456789) + .unwrap() + .and_utc(); + root_obj.insert( + "timestamp_nanos", + Variant::TimestampNanos(timestamp_nanos_utc), + ); + + let timestamp_ntz_nanos = chrono::NaiveDate::from_ymd_opt(2025, 8, 15) + .unwrap() + .and_hms_nano_opt(12, 3, 4, 123456789) + .unwrap(); + root_obj.insert( + "timestamp_ntz_nanos", + Variant::TimestampNtzNanos(timestamp_ntz_nanos), + ); + + // Add decimal types + let decimal4 = VariantDecimal4::try_new(1234i32, 2).unwrap(); + root_obj.insert("decimal4", decimal4); + + let decimal8 = VariantDecimal8::try_new(123456789i64, 3).unwrap(); + root_obj.insert("decimal8", decimal8); + + let decimal16 = VariantDecimal16::try_new(123456789012345678901234567890i128, 4).unwrap(); + root_obj.insert("decimal16", decimal16); + + // Add binary and string types + let binary_data = b"\x01\x02\x03\x04\xde\xad\xbe\xef"; + root_obj.insert("binary", binary_data.as_slice()); + + let long_string = + "This is a long string that exceeds the short string limit and contains emoji 🦀"; + root_obj.insert("string", long_string); + root_obj.insert("short_string", "Short string with emoji 🎉"); + let time = NaiveTime::from_hms_micro_opt(1, 2, 3, 4).unwrap(); + root_obj.insert("time", time); + + // Add uuid + let uuid = Uuid::parse_str("67e55044-10b1-426f-9247-bb680e5fe0c8").unwrap(); + root_obj.insert("uuid", Variant::Uuid(uuid)); + + // Add nested object + let mut nested_obj = root_obj.new_object("nested_object"); + nested_obj.insert("inner_key1", "inner_value1"); + nested_obj.insert("inner_key2", 999i32); + nested_obj.finish(); + + // Add list with mixed types + let mut mixed_list = root_obj.new_list("mixed_list"); + mixed_list.append_value(1i32); + mixed_list.append_value("two"); + mixed_list.append_value(true); + mixed_list.append_value(4.0f32); + mixed_list.append_value(()); + + // Add nested list inside the mixed list + let mut nested_list = mixed_list.new_list(); + nested_list.append_value("nested"); + nested_list.append_value(10i8); + nested_list.finish(); + + mixed_list.finish(); + + root_obj.finish(); + + let (metadata, value) = builder.finish(); + let variant = Variant::try_new(&metadata, &value).unwrap(); + + // Test Debug formatter (?) 
+ let debug_output = format!("{:?}", variant); + + // Verify that the debug output contains all the expected types + assert!(debug_output.contains("\"null\": Null")); + assert!(debug_output.contains("\"boolean_true\": BooleanTrue")); + assert!(debug_output.contains("\"boolean_false\": BooleanFalse")); + assert!(debug_output.contains("\"int8\": Int8(42)")); + assert!(debug_output.contains("\"int16\": Int16(1234)")); + assert!(debug_output.contains("\"int32\": Int32(123456)")); + assert!(debug_output.contains("\"int64\": Int64(1234567890123456789)")); + assert!(debug_output.contains("\"float\": Float(1.234)")); + assert!(debug_output.contains("\"double\": Double(1.23456789")); + assert!(debug_output.contains("\"date\": Date(2024-12-25)")); + assert!(debug_output.contains("\"timestamp_micros\": TimestampMicros(")); + assert!(debug_output.contains("\"timestamp_ntz_micros\": TimestampNtzMicros(")); + assert!(debug_output.contains("\"timestamp_nanos\": TimestampNanos(")); + assert!(debug_output.contains("\"timestamp_ntz_nanos\": TimestampNtzNanos(")); + assert!(debug_output.contains("\"decimal4\": Decimal4(")); + assert!(debug_output.contains("\"decimal8\": Decimal8(")); + assert!(debug_output.contains("\"decimal16\": Decimal16(")); + assert!(debug_output.contains("\"binary\": Binary(01 02 03 04 de ad be ef)")); + assert!(debug_output.contains("\"string\": String(")); + assert!(debug_output.contains("\"short_string\": ShortString(")); + assert!(debug_output.contains("\"uuid\": Uuid(67e55044-10b1-426f-9247-bb680e5fe0c8)")); + assert!(debug_output.contains("\"time\": Time(01:02:03.000004)")); + assert!(debug_output.contains("\"nested_object\":")); + assert!(debug_output.contains("\"mixed_list\":")); + + let expected = r#"{"binary": Binary(01 02 03 04 de ad be ef), "boolean_false": BooleanFalse, "boolean_true": BooleanTrue, "date": Date(2024-12-25), "decimal16": Decimal16(VariantDecimal16 { integer: 123456789012345678901234567890, scale: 4 }), "decimal4": Decimal4(VariantDecimal4 { integer: 1234, scale: 2 }), "decimal8": Decimal8(VariantDecimal8 { integer: 123456789, scale: 3 }), "double": Double(1.23456789), "float": Float(1.234), "int16": Int16(1234), "int32": Int32(123456), "int64": Int64(1234567890123456789), "int8": Int8(42), "mixed_list": [Int32(1), ShortString(ShortString("two")), BooleanTrue, Float(4.0), Null, [ShortString(ShortString("nested")), Int8(10)]], "nested_object": {"inner_key1": ShortString(ShortString("inner_value1")), "inner_key2": Int32(999)}, "null": Null, "short_string": ShortString(ShortString("Short string with emoji 🎉")), "string": String("This is a long string that exceeds the short string limit and contains emoji 🦀"), "time": Time(01:02:03.000004), "timestamp_micros": TimestampMicros(2024-12-25T15:30:45.123Z), "timestamp_nanos": TimestampNanos(2025-08-15T12:03:04.123456789Z), "timestamp_ntz_micros": TimestampNtzMicros(2024-12-25T15:30:45.123), "timestamp_ntz_nanos": TimestampNtzNanos(2025-08-15T12:03:04.123456789), "uuid": Uuid(67e55044-10b1-426f-9247-bb680e5fe0c8)}"#; + assert_eq!(debug_output, expected); + + // Test alternate Debug formatter (#?) 
+ let alt_debug_output = format!("{:#?}", variant); + let expected = r#"{ + "binary": Binary(01 02 03 04 de ad be ef), + "boolean_false": BooleanFalse, + "boolean_true": BooleanTrue, + "date": Date( + 2024-12-25, + ), + "decimal16": Decimal16( + VariantDecimal16 { + integer: 123456789012345678901234567890, + scale: 4, + }, + ), + "decimal4": Decimal4( + VariantDecimal4 { + integer: 1234, + scale: 2, + }, + ), + "decimal8": Decimal8( + VariantDecimal8 { + integer: 123456789, + scale: 3, + }, + ), + "double": Double( + 1.23456789, + ), + "float": Float( + 1.234, + ), + "int16": Int16( + 1234, + ), + "int32": Int32( + 123456, + ), + "int64": Int64( + 1234567890123456789, + ), + "int8": Int8( + 42, + ), + "mixed_list": [ + Int32( + 1, + ), + ShortString( + ShortString( + "two", + ), + ), + BooleanTrue, + Float( + 4.0, + ), + Null, + [ + ShortString( + ShortString( + "nested", + ), + ), + Int8( + 10, + ), + ], + ], + "nested_object": { + "inner_key1": ShortString( + ShortString( + "inner_value1", + ), + ), + "inner_key2": Int32( + 999, + ), + }, + "null": Null, + "short_string": ShortString( + ShortString( + "Short string with emoji 🎉", + ), + ), + "string": String( + "This is a long string that exceeds the short string limit and contains emoji 🦀", + ), + "time": Time( + 01:02:03.000004, + ), + "timestamp_micros": TimestampMicros( + 2024-12-25T15:30:45.123Z, + ), + "timestamp_nanos": TimestampNanos( + 2025-08-15T12:03:04.123456789Z, + ), + "timestamp_ntz_micros": TimestampNtzMicros( + 2024-12-25T15:30:45.123, + ), + "timestamp_ntz_nanos": TimestampNtzNanos( + 2025-08-15T12:03:04.123456789, + ), + "uuid": Uuid( + 67e55044-10b1-426f-9247-bb680e5fe0c8, + ), +}"#; + assert_eq!(alt_debug_output, expected); + } } diff --git a/parquet-variant/src/variant/list.rs b/parquet-variant/src/variant/list.rs index e3053ce9100e..438faddffe15 100644 --- a/parquet-variant/src/variant/list.rs +++ b/parquet-variant/src/variant/list.rs @@ -697,7 +697,7 @@ mod tests { // list3 (10..20) let (metadata3, value3) = make_listi32(10i32..20i32); object_builder.insert("list3", Variant::new(&metadata3, &value3)); - object_builder.finish().unwrap(); + object_builder.finish(); builder.finish() }; diff --git a/parquet-variant/src/variant/metadata.rs b/parquet-variant/src/variant/metadata.rs index 0e356e34c41e..604ee0e890e6 100644 --- a/parquet-variant/src/variant/metadata.rs +++ b/parquet-variant/src/variant/metadata.rs @@ -16,7 +16,10 @@ // under the License. use crate::decoder::{map_bytes_to_offsets, OffsetSizeBytes}; -use crate::utils::{first_byte_from_slice, overflow_error, slice_from_slice, string_from_slice}; +use crate::utils::{ + first_byte_from_slice, overflow_error, slice_from_slice, string_from_slice, + try_binary_search_range_by, +}; use arrow_schema::ArrowError; @@ -127,6 +130,7 @@ impl VariantMetadataHeader { /// [Variant Spec]: https://github.com/apache/parquet-format/blob/master/VariantEncoding.md#metadata-encoding #[derive(Debug, Clone, PartialEq)] pub struct VariantMetadata<'m> { + /// (Only) the bytes that make up this metadata instance. pub(crate) bytes: &'m [u8], header: VariantMetadataHeader, dictionary_size: u32, @@ -138,6 +142,39 @@ pub struct VariantMetadata<'m> { // could increase the size of Variant. All those size increases could hurt performance. const _: () = crate::utils::expect_size_of::(32); +/// The canonical byte slice corresponding to an empty metadata dictionary. 
+/// +/// ``` +/// # use parquet_variant::{EMPTY_VARIANT_METADATA_BYTES, VariantMetadata, WritableMetadataBuilder}; +/// let mut metadata_builder = WritableMetadataBuilder::default(); +/// metadata_builder.finish(); +/// let metadata_bytes = metadata_builder.into_inner(); +/// assert_eq!(&metadata_bytes, EMPTY_VARIANT_METADATA_BYTES); +/// ``` +pub const EMPTY_VARIANT_METADATA_BYTES: &[u8] = &[1, 0, 0]; + +/// The empty metadata dictionary. +/// +/// ``` +/// # use parquet_variant::{EMPTY_VARIANT_METADATA, VariantMetadata, WritableMetadataBuilder}; +/// let mut metadata_builder = WritableMetadataBuilder::default(); +/// metadata_builder.finish(); +/// let metadata_bytes = metadata_builder.into_inner(); +/// let empty_metadata = VariantMetadata::try_new(&metadata_bytes).unwrap(); +/// assert_eq!(empty_metadata, EMPTY_VARIANT_METADATA); +/// ``` +pub const EMPTY_VARIANT_METADATA: VariantMetadata = VariantMetadata { + bytes: EMPTY_VARIANT_METADATA_BYTES, + header: VariantMetadataHeader { + version: CORRECT_VERSION_VALUE, + is_sorted: false, + offset_size: OffsetSizeBytes::One, + }, + dictionary_size: 0, + first_value_byte: 3, + validated: true, +}; + impl<'m> VariantMetadata<'m> { /// Attempts to interpret `bytes` as a variant metadata instance, with full [validation] of all /// dictionary entries. @@ -296,7 +333,7 @@ impl<'m> VariantMetadata<'m> { self.header.version } - /// Gets an offset array entry by index. + /// Gets an offset into the dictionary entry by index. /// /// This offset is an index into the dictionary, at the boundary between string `i-1` and string /// `i`. See [`Self::get`] to retrieve a specific dictionary entry. @@ -306,6 +343,15 @@ impl<'m> VariantMetadata<'m> { self.header.offset_size.unpack_u32(bytes, i) } + /// Returns the total size, in bytes, of the metadata. + /// + /// Note this value may be smaller than what was passed to [`Self::new`] or + /// [`Self::try_new`] if the input was larger than necessary to encode the + /// metadata dictionary. + pub fn size(&self) -> usize { + self.bytes.len() + } + /// Attempts to retrieve a dictionary entry by index, failing if out of bounds or if the /// underlying bytes are [invalid]. /// @@ -315,6 +361,32 @@ impl<'m> VariantMetadata<'m> { string_from_slice(self.bytes, self.first_value_byte as _, byte_range) } + // Helper method used by our `impl Index` and also by `get_entry`. Panics if the underlying + // bytes are invalid. Needed because the `Index` trait forces the returned result to have the + // lifetime of `self` instead of the string's own (longer) lifetime `'m`. + fn get_impl(&self, i: usize) -> &'m str { + self.get(i).expect("Invalid metadata dictionary entry") + } + + /// Attempts to retrieve a dictionary entry and its field id, returning None if the requested field + /// name is not present. The search cost is logarithmic if [`Self::is_sorted`] and linear + /// otherwise. + /// + /// WARNING: This method panics if the underlying bytes are [invalid]. + /// + /// [invalid]: Self#Validation + pub fn get_entry(&self, field_name: &str) -> Option<(u32, &'m str)> { + let field_id = if self.is_sorted() && self.len() > 10 { + // Binary search is faster for a not-tiny sorted metadata dictionary + let cmp = |i| Some(self.get_impl(i).cmp(field_name)); + try_binary_search_range_by(0..self.len(), cmp)?.ok()? + } else { + // Fall back to Linear search for tiny or unsorted dictionary + (0..self.len()).find(|i| self.get_impl(*i) == field_name)? 
+ }; + Some((field_id as u32, self.get_impl(field_id))) + } + /// Returns an iterator that attempts to visit all dictionary entries, producing `Err` if the /// iterator encounters [invalid] data. /// @@ -341,7 +413,7 @@ impl std::ops::Index for VariantMetadata<'_> { type Output = str; fn index(&self, i: usize) -> &str { - self.get(i).expect("Invalid metadata dictionary entry") + self.get_impl(i) } } @@ -544,7 +616,7 @@ mod tests { o.insert("a", false); o.insert("b", false); - o.finish().unwrap(); + o.finish(); let (m, _) = b.finish(); @@ -579,7 +651,7 @@ mod tests { o.insert("a", false); o.insert("b", false); - o.finish().unwrap(); + o.finish(); let (m, _) = b.finish(); diff --git a/parquet-variant/src/variant/object.rs b/parquet-variant/src/variant/object.rs index b809fe278cb4..df1857846302 100644 --- a/parquet-variant/src/variant/object.rs +++ b/parquet-variant/src/variant/object.rs @@ -397,8 +397,8 @@ impl<'m, 'v> VariantObject<'m, 'v> { // NOTE: This does not require a sorted metadata dictionary, because the variant spec // requires object field ids to be lexically sorted by their corresponding string values, // and probing the dictionary for a field id is always O(1) work. - let i = try_binary_search_range_by(0..self.len(), &name, |i| self.field_name(i))?.ok()?; - + let cmp = |i| Some(self.field_name(i)?.cmp(name)); + let i = try_binary_search_range_by(0..self.len(), cmp)?.ok()?; self.field(i) } } @@ -550,7 +550,7 @@ mod tests { #[test] fn test_variant_object_empty_fields() { let mut builder = VariantBuilder::new(); - builder.new_object().with_field("", 42).finish().unwrap(); + builder.new_object().with_field("", 42).finish(); let (metadata, value) = builder.finish(); // Resulting object is valid and has a single empty field @@ -676,7 +676,7 @@ mod tests { obj.insert(&field_names[i as usize], i); } - obj.finish().unwrap(); + obj.finish(); let (metadata, value) = builder.finish(); let variant = Variant::new(&metadata, &value); @@ -737,7 +737,7 @@ mod tests { obj.insert(&key, str_val.as_str()); } - obj.finish().unwrap(); + obj.finish(); let (metadata, value) = builder.finish(); let variant = Variant::new(&metadata, &value); @@ -783,7 +783,7 @@ mod tests { o.insert("c", ()); o.insert("a", ()); - o.finish().unwrap(); + o.finish(); let (m, v) = b.finish(); @@ -801,7 +801,7 @@ mod tests { o.insert("a", ()); o.insert("b", false); - o.finish().unwrap(); + o.finish(); let (m, v) = b.finish(); let v1 = Variant::try_new(&m, &v).unwrap(); @@ -812,7 +812,7 @@ mod tests { o.insert("a", ()); o.insert("b", false); - o.finish().unwrap(); + o.finish(); let (m, v) = b.finish(); let v2 = Variant::try_new(&m, &v).unwrap(); @@ -828,7 +828,7 @@ mod tests { o.insert("a", ()); o.insert("b", 4.3); - o.finish().unwrap(); + o.finish(); let (m, v) = b.finish(); @@ -841,8 +841,8 @@ mod tests { o.insert("a", ()); let mut inner_o = o.new_object("b"); inner_o.insert("a", 3.3); - inner_o.finish().unwrap(); - o.finish().unwrap(); + inner_o.finish(); + o.finish(); let (m, v) = b.finish(); @@ -866,7 +866,7 @@ mod tests { o.insert("a", ()); o.insert("b", 4.3); - o.finish().unwrap(); + o.finish(); let (m, v) = b.finish(); @@ -879,7 +879,7 @@ mod tests { o.insert("aardvark", ()); o.insert("barracuda", 3.3); - o.finish().unwrap(); + o.finish(); let (m, v) = b.finish(); let v2 = Variant::try_new(&m, &v).unwrap(); @@ -895,7 +895,7 @@ mod tests { o.insert("b", false); o.insert("a", ()); - o.finish().unwrap(); + o.finish(); let (m, v) = b.finish(); @@ -904,13 +904,13 @@ mod tests { // create another object pre-filled with field 
names, b and a // but insert the fields in the order of a, b - let mut b = VariantBuilder::new().with_field_names(["b", "a"].into_iter()); + let mut b = VariantBuilder::new().with_field_names(["b", "a"]); let mut o = b.new_object(); o.insert("a", ()); o.insert("b", false); - o.finish().unwrap(); + o.finish(); let (m, v) = b.finish(); @@ -930,7 +930,7 @@ mod tests { o.insert("a", ()); o.insert("b", 4.3); - o.finish().unwrap(); + o.finish(); let (meta1, value1) = b.finish(); @@ -939,13 +939,13 @@ mod tests { assert!(v1.metadata().unwrap().is_sorted()); // create a second object with different insertion order - let mut b = VariantBuilder::new().with_field_names(["d", "c", "b", "a"].into_iter()); + let mut b = VariantBuilder::new().with_field_names(["d", "c", "b", "a"]); let mut o = b.new_object(); o.insert("b", 4.3); o.insert("a", ()); - o.finish().unwrap(); + o.finish(); let (meta2, value2) = b.finish(); @@ -969,7 +969,7 @@ mod tests { o.insert("a", false); o.insert("b", false); - o.finish().unwrap(); + o.finish(); let (m, v) = b.finish(); diff --git a/parquet-variant/tests/variant_interop.rs b/parquet-variant/tests/variant_interop.rs index e37172a7d568..00c326c06406 100644 --- a/parquet-variant/tests/variant_interop.rs +++ b/parquet-variant/tests/variant_interop.rs @@ -21,13 +21,14 @@ use std::path::{Path, PathBuf}; use std::{env, fs}; -use chrono::NaiveDate; +use chrono::{DateTime, NaiveDate, NaiveTime}; use parquet_variant::{ ShortString, Variant, VariantBuilder, VariantDecimal16, VariantDecimal4, VariantDecimal8, }; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; +use uuid::Uuid; /// Returns a directory path for the parquet variant test data. /// @@ -112,9 +113,9 @@ fn get_primitive_cases() -> Vec<(&'static str, Variant<'static, 'static>)> { ("primitive_boolean_false", Variant::BooleanFalse), ("primitive_boolean_true", Variant::BooleanTrue), ("primitive_date", Variant::Date(NaiveDate::from_ymd_opt(2025, 4 , 16).unwrap())), - ("primitive_decimal4", Variant::from(VariantDecimal4::try_new(1234i32, 2u8).unwrap())), + ("primitive_decimal4", Variant::from(VariantDecimal4::try_new(1234i32, 2u8).unwrap())), // ("primitive_decimal8", Variant::Decimal8{integer: 1234567890, scale: 2}), - ("primitive_decimal8", Variant::Decimal8(VariantDecimal8::try_new(1234567890,2).unwrap())), + ("primitive_decimal8", Variant::Decimal8(VariantDecimal8::try_new(1234567890,2).unwrap())), ("primitive_decimal16", Variant::Decimal16(VariantDecimal16::try_new(1234567891234567890, 2).unwrap())), ("primitive_float", Variant::Float(1234567890.1234)), ("primitive_double", Variant::Double(1234567890.1234)), @@ -126,7 +127,11 @@ fn get_primitive_cases() -> Vec<(&'static str, Variant<'static, 'static>)> { ("primitive_string", Variant::String("This string is longer than 64 bytes and therefore does not fit in a short_string and it also includes several non ascii characters such as 🐢, 💖, ♥\u{fe0f}, 🎣 and 🤦!!")), ("primitive_timestamp", Variant::TimestampMicros(NaiveDate::from_ymd_opt(2025, 4, 16).unwrap().and_hms_milli_opt(16, 34, 56, 780).unwrap().and_utc())), ("primitive_timestampntz", Variant::TimestampNtzMicros(NaiveDate::from_ymd_opt(2025, 4, 16).unwrap().and_hms_milli_opt(12, 34, 56, 780).unwrap())), + ("primitive_timestamp_nanos", Variant::TimestampNanos(NaiveDate::from_ymd_opt(2024, 11, 7).unwrap().and_hms_nano_opt(12, 33, 54, 123456789).unwrap().and_utc())), + ("primitive_timestampntz_nanos", Variant::TimestampNtzNanos(NaiveDate::from_ymd_opt(2024, 11, 7).unwrap().and_hms_nano_opt(12, 33, 54, 123456789).unwrap())), 
+ ("primitive_uuid", Variant::Uuid(Uuid::parse_str("f24f9b64-81fa-49d1-b74e-8c09a6e31c56").unwrap())), ("short_string", Variant::ShortString(ShortString::try_new("Less than 64 bytes (❤\u{fe0f} with utf8)").unwrap())), + ("primitive_time", Variant::Time(NaiveTime::from_hms_micro_opt(12, 33, 54, 123456).unwrap())), ] } #[test] @@ -267,7 +272,7 @@ fn variant_object_builder() { obj.insert("null_field", ()); obj.insert("timestamp_field", "2025-04-16T12:34:56.78"); - obj.finish().unwrap(); + obj.finish(); let (built_metadata, built_value) = builder.finish(); let actual = Variant::try_new(&built_metadata, &built_value).unwrap(); @@ -318,7 +323,7 @@ fn generate_random_value(rng: &mut StdRng, builder: &mut VariantBuilder, max_dep return; } - match rng.random_range(0..15) { + match rng.random_range(0..18) { 0 => builder.append_value(()), 1 => builder.append_value(rng.random::()), 2 => builder.append_value(rng.random::()), @@ -328,11 +333,13 @@ fn generate_random_value(rng: &mut StdRng, builder: &mut VariantBuilder, max_dep 6 => builder.append_value(rng.random::()), 7 => builder.append_value(rng.random::()), 8 => { + // String let len = rng.random_range(0..50); let s: String = (0..len).map(|_| rng.random::()).collect(); builder.append_value(s.as_str()); } 9 => { + // Binary let len = rng.random_range(0..50); let bytes: Vec = (0..len).map(|_| rng.random()).collect(); builder.append_value(bytes.as_slice()); @@ -377,7 +384,35 @@ fn generate_random_value(rng: &mut StdRng, builder: &mut VariantBuilder, max_dep let key = format!("field_{i}"); object_builder.insert(&key, rng.random::()); } - object_builder.finish().unwrap(); + object_builder.finish(); + } + 15 => { + // Time + builder.append_value( + NaiveTime::from_num_seconds_from_midnight_opt( + // make the argument always valid + rng.random_range(0..86_400), + rng.random_range(0..1_000_000_000), + ) + .unwrap(), + ) + } + 16 => { + let data_time = DateTime::from_timestamp( + // make the argument always valid + rng.random_range(0..86_400), + rng.random_range(0..1_000_000_000), + ) + .unwrap(); + + // timestamp w/o timezone + builder.append_value(data_time.naive_local()); + + // timestamp with timezone + builder.append_value(data_time.naive_utc().and_utc()); + } + 17 => { + builder.append_value(Uuid::new_v4()); } _ => unreachable!(), } diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml index 05557069aa7d..5dbd4b5b39dd 100644 --- a/parquet/Cargo.toml +++ b/parquet/Cargo.toml @@ -45,6 +45,10 @@ arrow-data = { workspace = true, optional = true } arrow-schema = { workspace = true, optional = true } arrow-select = { workspace = true, optional = true } arrow-ipc = { workspace = true, optional = true } +parquet-variant = { workspace = true, optional = true } +parquet-variant-json = { workspace = true, optional = true } +parquet-variant-compute = { workspace = true, optional = true } + object_store = { version = "0.12.0", default-features = false, optional = true } bytes = { version = "1.1", default-features = false, features = ["std"] } @@ -65,7 +69,7 @@ serde_json = { version = "1.0", default-features = false, features = ["std"], op seq-macro = { version = "0.3", default-features = false } futures = { version = "0.3", default-features = false, features = ["std"], optional = true } tokio = { version = "1.0", optional = true, default-features = false, features = ["macros", "rt", "io-util"] } -hashbrown = { version = "0.15", default-features = false } +hashbrown = { version = "0.16", default-features = false } twox-hash = { version = "2.0", default-features = 
false, features = ["xxhash64"] } paste = { version = "1.0" } half = { version = "2.1", default-features = false, features = ["num-traits"] } @@ -78,6 +82,7 @@ base64 = { version = "0.22", default-features = false, features = ["std"] } criterion = { version = "0.5", default-features = false, features = ["async_futures"] } snap = { version = "1.0", default-features = false } tempfile = { version = "3.0", default-features = false } +insta = "1.43.1" brotli = { version = "8.0", default-features = false, features = ["std"] } flate2 = { version = "1.0", default-features = false, features = ["rust_backend"] } lz4_flex = { version = "0.11", default-features = false, features = ["std", "frame"] } @@ -107,7 +112,7 @@ json = ["serde_json", "base64"] # Enable internal testing APIs test_common = ["arrow/test_utils"] # Experimental, unstable functionality primarily used for testing -experimental = [] +experimental = ["variant_experimental"] # Enable async APIs async = ["futures", "tokio"] # Enable object_store integration @@ -123,6 +128,8 @@ encryption = ["dep:ring"] # Explicitely enabling rust_backend and zlib-rs features for flate2 flate2-rust_backened = ["flate2/rust_backend"] flate2-zlib-rs = ["flate2/zlib-rs"] +# Enable parquet variant support +variant_experimental = ["parquet-variant", "parquet-variant-json", "parquet-variant-compute"] [[example]] @@ -164,6 +171,11 @@ name = "encryption" required-features = ["arrow"] path = "./tests/encryption/mod.rs" +[[test]] +name = "variant_integration" +required-features = ["arrow", "variant_experimental", "serde"] +path = "./tests/variant_integration.rs" + [[bin]] name = "parquet-read" required-features = ["cli"] diff --git a/parquet/README.md b/parquet/README.md index 8fc72bfbc32a..5e087ac6a929 100644 --- a/parquet/README.md +++ b/parquet/README.md @@ -64,9 +64,11 @@ The `parquet` crate provides the following features which may be enabled in your - `experimental` - Experimental APIs which may change, even between minor releases - `simdutf8` (default) - Use the [`simdutf8`] crate for SIMD-accelerated UTF-8 validation - `encryption` - support for reading / writing encrypted Parquet files +- `variant_experimental` - ⚠️ Experimental [Parquet Variant] support, which may change, even between minor releases. 
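A minimal, illustrative sketch (not part of this patch) of the accessor behaviour documented in the parquet-variant hunks above. It assumes a downstream crate that depends on `parquet-variant` directly, or on `parquet` with the new `variant_experimental` feature enabled; only APIs shown elsewhere in this diff are used.

```rust
// Illustrative only: the numeric accessors added above convert a Variant to an
// unsigned integer when the value fits losslessly, and convert decimals only
// when their scale is zero.
use parquet_variant::{Variant, VariantDecimal4};

fn main() {
    // An Int64 variant narrows to u16 when it is in range...
    assert_eq!(Variant::from(123i64).as_u16(), Some(123u16));
    // ...but a negative value does not.
    assert_eq!(Variant::from(-1).as_u16(), None);

    // A Decimal4 with a non-zero scale is not an integer, so it does not convert.
    let d = VariantDecimal4::try_new(1, 2).unwrap();
    assert_eq!(Variant::from(d).as_u16(), None);
}
```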
[`arrow`]: https://crates.io/crates/arrow [`simdutf8`]: https://crates.io/crates/simdutf8 +[parquet variant]: https://github.com/apache/parquet-format/blob/master/VariantEncoding.md ## Parquet Feature Status diff --git a/parquet/benches/arrow_reader_row_filter.rs b/parquet/benches/arrow_reader_row_filter.rs index 33427a37b59a..0ef40ac8237c 100644 --- a/parquet/benches/arrow_reader_row_filter.rs +++ b/parquet/benches/arrow_reader_row_filter.rs @@ -70,7 +70,7 @@ use parquet::arrow::arrow_reader::{ use parquet::arrow::async_reader::AsyncFileReader; use parquet::arrow::{ArrowWriter, ParquetRecordBatchStreamBuilder, ProjectionMask}; use parquet::basic::Compression; -use parquet::file::metadata::{ParquetMetaData, ParquetMetaDataReader}; +use parquet::file::metadata::{PageIndexPolicy, ParquetMetaData, ParquetMetaDataReader}; use parquet::file::properties::WriterProperties; use rand::{rngs::StdRng, Rng, SeedableRng}; use std::ops::Range; @@ -550,7 +550,8 @@ struct InMemoryReader { impl InMemoryReader { fn try_new(inner: &Bytes) -> parquet::errors::Result { - let mut metadata_reader = ParquetMetaDataReader::new().with_page_indexes(true); + let mut metadata_reader = + ParquetMetaDataReader::new().with_page_index_policy(PageIndexPolicy::Required); metadata_reader.try_parse(inner)?; let metadata = metadata_reader.finish().map(Arc::new)?; diff --git a/parquet/benches/metadata.rs b/parquet/benches/metadata.rs index 949e0d98ea39..8c886e4d5eea 100644 --- a/parquet/benches/metadata.rs +++ b/parquet/benches/metadata.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use parquet::file::metadata::ParquetMetaDataReader; use rand::Rng; use thrift::protocol::TCompactOutputProtocol; @@ -25,7 +26,7 @@ use parquet::file::reader::SerializedFileReader; use parquet::file::serialized_reader::ReadOptionsBuilder; use parquet::format::{ ColumnChunk, ColumnMetaData, CompressionCodec, Encoding, FieldRepetitionType, FileMetaData, - RowGroup, SchemaElement, Type, + PageEncodingStats, PageType, RowGroup, SchemaElement, Type, }; use parquet::thrift::TSerializable; @@ -93,7 +94,18 @@ fn encoded_meta() -> Vec { index_page_offset: Some(rng.random()), dictionary_page_offset: Some(rng.random()), statistics: Some(stats.clone()), - encoding_stats: None, + encoding_stats: Some(vec![ + PageEncodingStats { + page_type: PageType::DICTIONARY_PAGE, + encoding: Encoding::PLAIN, + count: 1, + }, + PageEncodingStats { + page_type: PageType::DATA_PAGE, + encoding: Encoding::RLE_DICTIONARY, + count: 10, + }, + ]), bloom_filter_offset: None, bloom_filter_length: None, size_statistics: None, @@ -151,6 +163,36 @@ fn get_footer_bytes(data: Bytes) -> Bytes { data.slice(meta_start..meta_end) } +#[cfg(feature = "arrow")] +fn rewrite_file(bytes: Bytes) -> (Bytes, FileMetaData) { + use arrow::array::RecordBatchReader; + use parquet::arrow::{arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter}; + use parquet::file::properties::{EnabledStatistics, WriterProperties}; + + let parquet_reader = ParquetRecordBatchReaderBuilder::try_new(bytes) + .expect("parquet open") + .build() + .expect("parquet open"); + let writer_properties = WriterProperties::builder() + .set_statistics_enabled(EnabledStatistics::Page) + .set_write_page_header_statistics(true) + .build(); + let mut output = Vec::new(); + let mut parquet_writer = ArrowWriter::try_new( + &mut output, + parquet_reader.schema(), + Some(writer_properties), + ) + .expect("create arrow writer"); + + for maybe_batch in parquet_reader { + let batch = 
maybe_batch.expect("reading batch"); + parquet_writer.write(&batch).expect("writing data"); + } + let file_meta = parquet_writer.close().expect("finalizing file"); + (output.into(), file_meta) +} + fn criterion_benchmark(c: &mut Criterion) { // Read file into memory to isolate filesystem performance let file = "../parquet-testing/data/alltypes_tiny_pages.parquet"; @@ -168,19 +210,54 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - let meta_data = get_footer_bytes(data); - c.bench_function("decode file metadata", |b| { + let meta_data = get_footer_bytes(data.clone()); + c.bench_function("decode parquet metadata", |b| { + b.iter(|| { + ParquetMetaDataReader::decode_metadata(&meta_data).unwrap(); + }) + }); + + c.bench_function("decode thrift file metadata", |b| { b.iter(|| { parquet::thrift::bench_file_metadata(&meta_data); }) }); - let buf = black_box(encoded_meta()).into(); - c.bench_function("decode file metadata (wide)", |b| { + let buf: Bytes = black_box(encoded_meta()).into(); + c.bench_function("decode parquet metadata (wide)", |b| { + b.iter(|| { + ParquetMetaDataReader::decode_metadata(&buf).unwrap(); + }) + }); + + c.bench_function("decode thrift file metadata (wide)", |b| { b.iter(|| { parquet::thrift::bench_file_metadata(&buf); }) }); + + // rewrite file with page statistics. then read page headers. + #[cfg(feature = "arrow")] + let (file_bytes, metadata) = rewrite_file(data.clone()); + #[cfg(feature = "arrow")] + c.bench_function("page headers", |b| { + b.iter(|| { + metadata.row_groups.iter().for_each(|rg| { + rg.columns.iter().for_each(|col| { + if let Some(col_meta) = &col.meta_data { + if let Some(dict_offset) = col_meta.dictionary_page_offset { + parquet::thrift::bench_page_header( + &file_bytes.slice(dict_offset as usize..), + ); + } + parquet::thrift::bench_page_header( + &file_bytes.slice(col_meta.data_page_offset as usize..), + ); + } + }); + }); + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/parquet/examples/external_metadata.rs b/parquet/examples/external_metadata.rs index 2710251e5569..9370016049e1 100644 --- a/parquet/examples/external_metadata.rs +++ b/parquet/examples/external_metadata.rs @@ -20,7 +20,9 @@ use arrow_cast::pretty::pretty_format_batches; use futures::TryStreamExt; use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions}; use parquet::arrow::{ArrowWriter, ParquetRecordBatchStreamBuilder}; -use parquet::file::metadata::{ParquetMetaData, ParquetMetaDataReader, ParquetMetaDataWriter}; +use parquet::file::metadata::{ + PageIndexPolicy, ParquetMetaData, ParquetMetaDataReader, ParquetMetaDataWriter, +}; use parquet::file::properties::{EnabledStatistics, WriterProperties}; use std::fs::File; use std::path::{Path, PathBuf}; @@ -111,7 +113,7 @@ async fn get_metadata_from_remote_parquet_file( // tell the reader to read the page index ParquetMetaDataReader::new() - .with_page_indexes(true) + .with_page_index_policy(PageIndexPolicy::Required) .load_and_finish(remote_file, file_size) .await .unwrap() @@ -160,7 +162,7 @@ fn write_metadata_to_local_file(metadata: ParquetMetaData, file: impl AsRef) -> ParquetMetaData { let file = File::open(file).unwrap(); ParquetMetaDataReader::new() - .with_page_indexes(true) + .with_page_index_policy(PageIndexPolicy::Required) .parse_and_finish(&file) .unwrap() } diff --git a/parquet/src/arrow/array_reader/builder.rs b/parquet/src/arrow/array_reader/builder.rs index d5e36fbcb486..1ee7cc50acc2 100644 --- a/parquet/src/arrow/array_reader/builder.rs +++ 
b/parquet/src/arrow/array_reader/builder.rs @@ -44,12 +44,12 @@ pub struct CacheOptionsBuilder<'a> { /// Projection mask to apply to the cache pub projection_mask: &'a ProjectionMask, /// Cache to use for storing row groups - pub cache: Arc>, + pub cache: &'a Arc>, } impl<'a> CacheOptionsBuilder<'a> { /// create a new cache options builder - pub fn new(projection_mask: &'a ProjectionMask, cache: Arc>) -> Self { + pub fn new(projection_mask: &'a ProjectionMask, cache: &'a Arc>) -> Self { Self { projection_mask, cache, @@ -79,7 +79,7 @@ impl<'a> CacheOptionsBuilder<'a> { #[derive(Clone)] pub struct CacheOptions<'a> { pub projection_mask: &'a ProjectionMask, - pub cache: Arc>, + pub cache: &'a Arc>, pub role: CacheRole, } @@ -144,7 +144,7 @@ impl<'a> ArrayReaderBuilder<'a> { if cache_options.projection_mask.leaf_included(col_idx) { Ok(Some(Box::new(CachedArrayReader::new( reader, - Arc::clone(&cache_options.cache), + Arc::clone(cache_options.cache), col_idx, cache_options.role, self.metrics.clone(), // cheap clone diff --git a/parquet/src/arrow/arrow_reader/filter.rs b/parquet/src/arrow/arrow_reader/filter.rs index 3a897c05444b..4fbe45748b88 100644 --- a/parquet/src/arrow/arrow_reader/filter.rs +++ b/parquet/src/arrow/arrow_reader/filter.rs @@ -186,4 +186,12 @@ impl RowFilter { pub fn new(predicates: Vec>) -> Self { Self { predicates } } + /// Returns the inner predicates + pub fn predicates(&self) -> &Vec> { + &self.predicates + } + /// Returns the inner predicates, consuming self + pub fn into_predicates(self) -> Vec> { + self.predicates + } } diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs index 3d20fa0a220c..81765a800edd 100644 --- a/parquet/src/arrow/arrow_reader/mod.rs +++ b/parquet/src/arrow/arrow_reader/mod.rs @@ -37,7 +37,7 @@ use crate::column::page::{PageIterator, PageReader}; #[cfg(feature = "encryption")] use crate::encryption::decrypt::FileDecryptionProperties; use crate::errors::{ParquetError, Result}; -use crate::file::metadata::{ParquetMetaData, ParquetMetaDataReader}; +use crate::file::metadata::{PageIndexPolicy, ParquetMetaData, ParquetMetaDataReader}; use crate::file::reader::{ChunkReader, SerializedPageReader}; use crate::format::{BloomFilterAlgorithm, BloomFilterCompression, BloomFilterHash}; use crate::schema::types::SchemaDescriptor; @@ -383,8 +383,8 @@ pub struct ArrowReaderOptions { /// /// [ARROW_SCHEMA_META_KEY]: crate::arrow::ARROW_SCHEMA_META_KEY supplied_schema: Option, - /// If true, attempt to read `OffsetIndex` and `ColumnIndex` - pub(crate) page_index: bool, + /// Policy for reading offset and column indexes. + pub(crate) page_index_policy: PageIndexPolicy, /// If encryption is enabled, the file decryption properties can be provided #[cfg(feature = "encryption")] pub(crate) file_decryption_properties: Option, @@ -486,7 +486,20 @@ impl ArrowReaderOptions { /// [`ParquetMetaData::column_index`]: crate::file::metadata::ParquetMetaData::column_index /// [`ParquetMetaData::offset_index`]: crate::file::metadata::ParquetMetaData::offset_index pub fn with_page_index(self, page_index: bool) -> Self { - Self { page_index, ..self } + let page_index_policy = PageIndexPolicy::from(page_index); + + Self { + page_index_policy, + ..self + } + } + + /// Set the [`PageIndexPolicy`] to determine how page indexes should be read. 
+ pub fn with_page_index_policy(self, policy: PageIndexPolicy) -> Self { + Self { + page_index_policy: policy, + ..self + } } /// Provide the file decryption properties to use when reading encrypted parquet files. @@ -507,7 +520,7 @@ impl ArrowReaderOptions { /// /// This can be set via [`with_page_index`][Self::with_page_index]. pub fn page_index(&self) -> bool { - self.page_index + self.page_index_policy != PageIndexPolicy::Skip } /// Retrieve the currently set file decryption properties. @@ -556,7 +569,8 @@ impl ArrowReaderMetadata { /// `Self::metadata` is missing the page index, this function will attempt /// to load the page index by making an object store request. pub fn load(reader: &T, options: ArrowReaderOptions) -> Result { - let metadata = ParquetMetaDataReader::new().with_page_indexes(options.page_index); + let metadata = + ParquetMetaDataReader::new().with_page_index_policy(options.page_index_policy); #[cfg(feature = "encryption")] let metadata = metadata.with_decryption_properties(options.file_decryption_properties.as_ref()); diff --git a/parquet/src/arrow/arrow_writer/levels.rs b/parquet/src/arrow/arrow_writer/levels.rs index 1956394ac50e..3c283bcbe3d2 100644 --- a/parquet/src/arrow/arrow_writer/levels.rs +++ b/parquet/src/arrow/arrow_writer/levels.rs @@ -550,13 +550,41 @@ impl LevelInfoBuilder { /// and the other is a native array, the dictionary values must have the same type as the /// native array fn types_compatible(a: &DataType, b: &DataType) -> bool { + // if the Arrow data types are the same, the types are clearly compatible if a == b { return true; } - match (a, b) { - (DataType::Dictionary(_, v), b) => v.as_ref() == b, - (a, DataType::Dictionary(_, v)) => a == v.as_ref(), + // get the values out of the dictionaries + let (a, b) = match (a, b) { + (DataType::Dictionary(_, va), DataType::Dictionary(_, vb)) => { + (va.as_ref(), vb.as_ref()) + } + (DataType::Dictionary(_, v), b) => (v.as_ref(), b), + (a, DataType::Dictionary(_, v)) => (a, v.as_ref()), + _ => (a, b), + }; + + // now that we've got the values from one/both dictionaries, if the values + // have the same Arrow data type, they're compatible + if a == b { + return true; + } + + // here we have different Arrow data types, but if the array contains the same type of data + // then we consider the type compatible + match a { + // String, StringView and LargeString are compatible + DataType::Utf8 => matches!(b, DataType::LargeUtf8 | DataType::Utf8View), + DataType::Utf8View => matches!(b, DataType::LargeUtf8 | DataType::Utf8), + DataType::LargeUtf8 => matches!(b, DataType::Utf8 | DataType::Utf8View), + + // Binary, BinaryView and LargeBinary are compatible + DataType::Binary => matches!(b, DataType::LargeBinary | DataType::BinaryView), + DataType::BinaryView => matches!(b, DataType::LargeBinary | DataType::Binary), + DataType::LargeBinary => matches!(b, DataType::Binary | DataType::BinaryView), + + // otherwise we have incompatible types _ => false, } } diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index d235f5fcab64..864c1bf2da45 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -134,16 +134,21 @@ mod levels; /// a given column, the writer can accept multiple Arrow [`DataType`]s that contain the same /// value type. /// -/// Currently, only compatibility between Arrow dictionary and native arrays are supported. 
-/// Additional type compatibility may be added in future (see [issue #8012](https://github.com/apache/arrow-rs/issues/8012)) +/// For example, the following [`DataType`]s are all logically equivalent and can be written +/// to the same column: +/// * String, LargeString, StringView +/// * Binary, LargeBinary, BinaryView +/// +/// The writer can will also accept both native and dictionary encoded arrays if the dictionaries +/// contain compatible values. /// ``` /// # use std::sync::Arc; -/// # use arrow_array::{DictionaryArray, RecordBatch, StringArray, UInt8Array}; +/// # use arrow_array::{DictionaryArray, LargeStringArray, RecordBatch, StringArray, UInt8Array}; /// # use arrow_schema::{DataType, Field, Schema}; /// # use parquet::arrow::arrow_writer::ArrowWriter; /// let record_batch1 = RecordBatch::try_new( -/// Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, false)])), -/// vec![Arc::new(StringArray::from_iter_values(vec!["a", "b"]))] +/// Arc::new(Schema::new(vec![Field::new("col", DataType::LargeUtf8, false)])), +/// vec![Arc::new(LargeStringArray::from_iter_values(vec!["a", "b"]))] /// ) /// .unwrap(); /// @@ -566,7 +571,7 @@ impl PageWriter for ArrowPageWriter { None => page, }; - let page_header = page.to_thrift_header(); + let page_header = page.to_thrift_header()?; let header = { let mut header = Vec::with_capacity(1024); @@ -3092,106 +3097,188 @@ mod tests { } #[test] - fn arrow_writer_dict_and_native_compatibility() { - let schema = Arc::new(Schema::new(vec![Field::new( - "a", - DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), - false, - )])); + fn arrow_writer_test_type_compatibility() { + fn ensure_compatible_write(array1: T1, array2: T2, expected_result: T1) + where + T1: Array + 'static, + T2: Array + 'static, + { + let schema1 = Arc::new(Schema::new(vec![Field::new( + "a", + array1.data_type().clone(), + false, + )])); + + let file = tempfile().unwrap(); + let mut writer = + ArrowWriter::try_new(file.try_clone().unwrap(), schema1.clone(), None).unwrap(); - let rb1 = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(DictionaryArray::new( - UInt8Array::from_iter_values(vec![0, 1, 0]), + let rb1 = RecordBatch::try_new(schema1.clone(), vec![Arc::new(array1)]).unwrap(); + writer.write(&rb1).unwrap(); + + let schema2 = Arc::new(Schema::new(vec![Field::new( + "a", + array2.data_type().clone(), + false, + )])); + let rb2 = RecordBatch::try_new(schema2, vec![Arc::new(array2)]).unwrap(); + writer.write(&rb2).unwrap(); + + writer.close().unwrap(); + + let mut record_batch_reader = + ParquetRecordBatchReader::try_new(file.try_clone().unwrap(), 1024).unwrap(); + let actual_batch = record_batch_reader.next().unwrap().unwrap(); + + let expected_batch = + RecordBatch::try_new(schema1, vec![Arc::new(expected_result)]).unwrap(); + assert_eq!(actual_batch, expected_batch); + } + + // check compatibility between native and dictionaries + + ensure_compatible_write( + DictionaryArray::new( + UInt8Array::from_iter_values(vec![0]), + Arc::new(StringArray::from_iter_values(vec!["parquet"])), + ), + StringArray::from_iter_values(vec!["barquet"]), + DictionaryArray::new( + UInt8Array::from_iter_values(vec![0, 1]), Arc::new(StringArray::from_iter_values(vec!["parquet", "barquet"])), - ))], - ) - .unwrap(); + ), + ); - let file = tempfile().unwrap(); - let mut writer = - ArrowWriter::try_new(file.try_clone().unwrap(), rb1.schema(), None).unwrap(); - writer.write(&rb1).unwrap(); - - // check can append another record batch where the field has the same type 
- // as the dictionary values from the first batch - let schema2 = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, false)])); - let rb2 = RecordBatch::try_new( - schema2, - vec![Arc::new(StringArray::from_iter_values(vec![ - "barquet", "curious", - ]))], - ) - .unwrap(); - writer.write(&rb2).unwrap(); + ensure_compatible_write( + StringArray::from_iter_values(vec!["parquet"]), + DictionaryArray::new( + UInt8Array::from_iter_values(vec![0]), + Arc::new(StringArray::from_iter_values(vec!["barquet"])), + ), + StringArray::from_iter_values(vec!["parquet", "barquet"]), + ); - writer.close().unwrap(); + // check compatibility between dictionaries with different key types - let mut record_batch_reader = - ParquetRecordBatchReader::try_new(file.try_clone().unwrap(), 1024).unwrap(); - let actual_batch = record_batch_reader.next().unwrap().unwrap(); + ensure_compatible_write( + DictionaryArray::new( + UInt8Array::from_iter_values(vec![0]), + Arc::new(StringArray::from_iter_values(vec!["parquet"])), + ), + DictionaryArray::new( + UInt16Array::from_iter_values(vec![0]), + Arc::new(StringArray::from_iter_values(vec!["barquet"])), + ), + DictionaryArray::new( + UInt8Array::from_iter_values(vec![0, 1]), + Arc::new(StringArray::from_iter_values(vec!["parquet", "barquet"])), + ), + ); - let expected_batch = RecordBatch::try_new( - schema, - vec![Arc::new(DictionaryArray::new( - UInt8Array::from_iter_values(vec![0, 1, 0, 1, 2]), - Arc::new(StringArray::from_iter_values(vec![ - "parquet", "barquet", "curious", - ])), - ))], - ) - .unwrap(); + // check compatibility between dictionaries with different value types + ensure_compatible_write( + DictionaryArray::new( + UInt8Array::from_iter_values(vec![0]), + Arc::new(StringArray::from_iter_values(vec!["parquet"])), + ), + DictionaryArray::new( + UInt8Array::from_iter_values(vec![0]), + Arc::new(LargeStringArray::from_iter_values(vec!["barquet"])), + ), + DictionaryArray::new( + UInt8Array::from_iter_values(vec![0, 1]), + Arc::new(StringArray::from_iter_values(vec!["parquet", "barquet"])), + ), + ); - assert_eq!(actual_batch, expected_batch) - } + // check compatibility between a dictionary and a native array with a different type + ensure_compatible_write( + DictionaryArray::new( + UInt8Array::from_iter_values(vec![0]), + Arc::new(StringArray::from_iter_values(vec!["parquet"])), + ), + LargeStringArray::from_iter_values(vec!["barquet"]), + DictionaryArray::new( + UInt8Array::from_iter_values(vec![0, 1]), + Arc::new(StringArray::from_iter_values(vec!["parquet", "barquet"])), + ), + ); - #[test] - fn arrow_writer_native_and_dict_compatibility() { - let schema1 = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, false)])); - let rb1 = RecordBatch::try_new( - schema1.clone(), - vec![Arc::new(StringArray::from_iter_values(vec![ - "parquet", "barquet", - ]))], - ) - .unwrap(); + // check compatibility for string types - let file = tempfile().unwrap(); - let mut writer = - ArrowWriter::try_new(file.try_clone().unwrap(), rb1.schema(), None).unwrap(); - writer.write(&rb1).unwrap(); + ensure_compatible_write( + StringArray::from_iter_values(vec!["parquet"]), + LargeStringArray::from_iter_values(vec!["barquet"]), + StringArray::from_iter_values(vec!["parquet", "barquet"]), + ); - let schema2 = Arc::new(Schema::new(vec![Field::new( - "a", - DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), - false, - )])); + ensure_compatible_write( + LargeStringArray::from_iter_values(vec!["parquet"]), + 
StringArray::from_iter_values(vec!["barquet"]), + LargeStringArray::from_iter_values(vec!["parquet", "barquet"]), + ); - let rb2 = RecordBatch::try_new( - schema2.clone(), - vec![Arc::new(DictionaryArray::new( - UInt8Array::from_iter_values(vec![0, 1, 0]), - Arc::new(StringArray::from_iter_values(vec!["barquet", "curious"])), - ))], - ) - .unwrap(); - writer.write(&rb2).unwrap(); + ensure_compatible_write( + StringArray::from_iter_values(vec!["parquet"]), + StringViewArray::from_iter_values(vec!["barquet"]), + StringArray::from_iter_values(vec!["parquet", "barquet"]), + ); - writer.close().unwrap(); + ensure_compatible_write( + StringViewArray::from_iter_values(vec!["parquet"]), + StringArray::from_iter_values(vec!["barquet"]), + StringViewArray::from_iter_values(vec!["parquet", "barquet"]), + ); - let mut record_batch_reader = - ParquetRecordBatchReader::try_new(file.try_clone().unwrap(), 1024).unwrap(); - let actual_batch = record_batch_reader.next().unwrap().unwrap(); + ensure_compatible_write( + LargeStringArray::from_iter_values(vec!["parquet"]), + StringViewArray::from_iter_values(vec!["barquet"]), + LargeStringArray::from_iter_values(vec!["parquet", "barquet"]), + ); - let expected_batch = RecordBatch::try_new( - schema1, - vec![Arc::new(StringArray::from_iter_values(vec![ - "parquet", "barquet", "barquet", "curious", "barquet", - ]))], - ) - .unwrap(); + ensure_compatible_write( + StringViewArray::from_iter_values(vec!["parquet"]), + LargeStringArray::from_iter_values(vec!["barquet"]), + StringViewArray::from_iter_values(vec!["parquet", "barquet"]), + ); - assert_eq!(actual_batch, expected_batch) + // check compatibility for binary types + + ensure_compatible_write( + BinaryArray::from_iter_values(vec![b"parquet"]), + LargeBinaryArray::from_iter_values(vec![b"barquet"]), + BinaryArray::from_iter_values(vec![b"parquet", b"barquet"]), + ); + + ensure_compatible_write( + LargeBinaryArray::from_iter_values(vec![b"parquet"]), + BinaryArray::from_iter_values(vec![b"barquet"]), + LargeBinaryArray::from_iter_values(vec![b"parquet", b"barquet"]), + ); + + ensure_compatible_write( + BinaryArray::from_iter_values(vec![b"parquet"]), + BinaryViewArray::from_iter_values(vec![b"barquet"]), + BinaryArray::from_iter_values(vec![b"parquet", b"barquet"]), + ); + + ensure_compatible_write( + BinaryViewArray::from_iter_values(vec![b"parquet"]), + BinaryArray::from_iter_values(vec![b"barquet"]), + BinaryViewArray::from_iter_values(vec![b"parquet", b"barquet"]), + ); + + ensure_compatible_write( + BinaryViewArray::from_iter_values(vec![b"parquet"]), + LargeBinaryArray::from_iter_values(vec![b"barquet"]), + BinaryViewArray::from_iter_values(vec![b"parquet", b"barquet"]), + ); + + ensure_compatible_write( + LargeBinaryArray::from_iter_values(vec![b"parquet"]), + BinaryViewArray::from_iter_values(vec![b"barquet"]), + LargeBinaryArray::from_iter_values(vec![b"parquet", b"barquet"]), + ); } #[test] diff --git a/parquet/src/arrow/async_reader/mod.rs b/parquet/src/arrow/async_reader/mod.rs index eea6176b766b..33b03fbbca95 100644 --- a/parquet/src/arrow/async_reader/mod.rs +++ b/parquet/src/arrow/async_reader/mod.rs @@ -52,7 +52,7 @@ use crate::bloom_filter::{ }; use crate::column::page::{PageIterator, PageReader}; use crate::errors::{ParquetError, Result}; -use crate::file::metadata::{ParquetMetaData, ParquetMetaDataReader}; +use crate::file::metadata::{PageIndexPolicy, ParquetMetaData, ParquetMetaDataReader}; use crate::file::page_index::offset_index::OffsetIndexMetaData; use crate::file::reader::{ChunkReader, 
Length, SerializedPageReader}; use crate::format::{BloomFilterAlgorithm, BloomFilterCompression, BloomFilterHash}; @@ -175,8 +175,9 @@ impl AsyncFileReader for T { options: Option<&'a ArrowReaderOptions>, ) -> BoxFuture<'a, Result>> { async move { - let metadata_reader = ParquetMetaDataReader::new() - .with_page_indexes(options.is_some_and(|o| o.page_index)); + let metadata_reader = ParquetMetaDataReader::new().with_page_index_policy( + PageIndexPolicy::from(options.is_some_and(|o| o.page_index())), + ); #[cfg(feature = "encryption")] let metadata_reader = metadata_reader.with_decryption_properties( @@ -618,8 +619,7 @@ where metadata: self.metadata.as_ref(), }; - let cache_options_builder = - CacheOptionsBuilder::new(&cache_projection, row_group_cache.clone()); + let cache_options_builder = CacheOptionsBuilder::new(&cache_projection, &row_group_cache); let filter = self.filter.as_mut(); let mut plan_builder = ReadPlanBuilder::new(batch_size).with_selection(selection); @@ -1262,8 +1262,9 @@ mod tests { &'a mut self, options: Option<&'a ArrowReaderOptions>, ) -> BoxFuture<'a, Result>> { - let metadata_reader = ParquetMetaDataReader::new() - .with_page_indexes(options.is_some_and(|o| o.page_index)); + let metadata_reader = ParquetMetaDataReader::new().with_page_index_policy( + PageIndexPolicy::from(options.is_some_and(|o| o.page_index())), + ); self.metadata = Some(Arc::new( metadata_reader.parse_and_finish(&self.data).unwrap(), )); @@ -1931,6 +1932,7 @@ mod tests { } #[tokio::test] + #[allow(deprecated)] async fn test_in_memory_row_group_sparse() { let testdata = arrow::util::test_util::parquet_test_data(); let path = format!("{testdata}/alltypes_tiny_pages.parquet"); @@ -2458,6 +2460,7 @@ mod tests { } #[tokio::test] + #[allow(deprecated)] async fn empty_offset_index_doesnt_panic_in_read_row_group() { use tokio::fs::File; let testdata = arrow::util::test_util::parquet_test_data(); @@ -2483,6 +2486,7 @@ mod tests { } #[tokio::test] + #[allow(deprecated)] async fn non_empty_offset_index_doesnt_panic_in_read_row_group() { use tokio::fs::File; let testdata = arrow::util::test_util::parquet_test_data(); @@ -2507,6 +2511,7 @@ mod tests { } #[tokio::test] + #[allow(deprecated)] async fn empty_offset_index_doesnt_panic_in_column_chunks() { use tempfile::TempDir; use tokio::fs::File; diff --git a/parquet/src/arrow/async_reader/store.rs b/parquet/src/arrow/async_reader/store.rs index 51dc368bc9ea..ce1398b56d37 100644 --- a/parquet/src/arrow/async_reader/store.rs +++ b/parquet/src/arrow/async_reader/store.rs @@ -20,7 +20,7 @@ use std::{ops::Range, sync::Arc}; use crate::arrow::arrow_reader::ArrowReaderOptions; use crate::arrow::async_reader::{AsyncFileReader, MetadataSuffixFetch}; use crate::errors::{ParquetError, Result}; -use crate::file::metadata::{ParquetMetaData, ParquetMetaDataReader}; +use crate::file::metadata::{PageIndexPolicy, ParquetMetaData, ParquetMetaDataReader}; use bytes::Bytes; use futures::{future::BoxFuture, FutureExt, TryFutureExt}; use object_store::{path::Path, ObjectStore}; @@ -200,8 +200,8 @@ impl AsyncFileReader for ParquetObjectReader { ) -> BoxFuture<'a, Result>> { Box::pin(async move { let mut metadata = ParquetMetaDataReader::new() - .with_column_indexes(self.preload_column_index) - .with_offset_indexes(self.preload_offset_index) + .with_column_index_policy(PageIndexPolicy::from(self.preload_column_index)) + .with_offset_index_policy(PageIndexPolicy::from(self.preload_offset_index)) .with_prefetch_hint(self.metadata_size_hint); #[cfg(feature = "encryption")] diff --git 
a/parquet/src/arrow/async_writer/mod.rs b/parquet/src/arrow/async_writer/mod.rs index 3a74aa7c9c20..4547f71274b7 100644 --- a/parquet/src/arrow/async_writer/mod.rs +++ b/parquet/src/arrow/async_writer/mod.rs @@ -61,7 +61,7 @@ mod store; pub use store::*; use crate::{ - arrow::arrow_writer::ArrowWriterOptions, + arrow::arrow_writer::{ArrowColumnChunk, ArrowColumnWriter, ArrowWriterOptions}, arrow::ArrowWriter, errors::{ParquetError, Result}, file::{metadata::RowGroupMetaData, properties::WriterProperties}, @@ -288,6 +288,22 @@ impl AsyncArrowWriter { Ok(()) } + + /// Create a new row group writer and return its column writers. + pub async fn get_column_writers(&mut self) -> Result> { + let before = self.sync_writer.flushed_row_groups().len(); + let writers = self.sync_writer.get_column_writers()?; + if before != self.sync_writer.flushed_row_groups().len() { + self.do_write().await?; + } + Ok(writers) + } + + /// Append the given column chunks to the file as a new row group. + pub async fn append_row_group(&mut self, chunks: Vec) -> Result<()> { + self.sync_writer.append_row_group(chunks)?; + self.do_write().await + } } #[cfg(test)] @@ -298,6 +314,7 @@ mod tests { use std::sync::Arc; use crate::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder}; + use crate::arrow::arrow_writer::compute_leaves; use super::*; @@ -332,6 +349,51 @@ mod tests { assert_eq!(to_write, read); } + #[tokio::test] + async fn test_async_arrow_group_writer() { + let col = Arc::new(Int64Array::from_iter_values([4, 5, 6])) as ArrayRef; + let to_write_record = RecordBatch::try_from_iter([("col", col)]).unwrap(); + + let mut buffer = Vec::new(); + let mut writer = + AsyncArrowWriter::try_new(&mut buffer, to_write_record.schema(), None).unwrap(); + + // Use classic API + writer.write(&to_write_record).await.unwrap(); + + let mut writers = writer.get_column_writers().await.unwrap(); + let col = Arc::new(Int64Array::from_iter_values([1, 2, 3])) as ArrayRef; + let to_write_arrow_group = RecordBatch::try_from_iter([("col", col)]).unwrap(); + + for (field, column) in to_write_arrow_group + .schema() + .fields() + .iter() + .zip(to_write_arrow_group.columns()) + { + for leaf in compute_leaves(field.as_ref(), column).unwrap() { + writers[0].write(&leaf).unwrap(); + } + } + + let columns: Vec<_> = writers.into_iter().map(|w| w.close().unwrap()).collect(); + // Append the arrow group as a new row group. Flush in progress + writer.append_row_group(columns).await.unwrap(); + writer.close().await.unwrap(); + + let buffer = Bytes::from(buffer); + let mut reader = ParquetRecordBatchReaderBuilder::try_new(buffer) + .unwrap() + .build() + .unwrap(); + + let col = Arc::new(Int64Array::from_iter_values([4, 5, 6, 1, 2, 3])) as ArrayRef; + let expected = RecordBatch::try_from_iter([("col", col)]).unwrap(); + + let read = reader.next().unwrap().unwrap(); + assert_eq!(expected, read); + } + // Read the data from the test file and write it by the async writer and sync writer. // And then compares the results of the two writers. 
#[tokio::test] diff --git a/parquet/src/arrow/mod.rs b/parquet/src/arrow/mod.rs index 72626d70e0e5..3ec0d0272f94 100644 --- a/parquet/src/arrow/mod.rs +++ b/parquet/src/arrow/mod.rs @@ -467,6 +467,7 @@ mod test { use super::ProjectionMask; #[test] + #[allow(deprecated)] // Reproducer for https://github.com/apache/arrow-rs/issues/6464 fn test_metadata_read_write_partial_offset() { let parquet_bytes = create_parquet_file(); @@ -514,6 +515,7 @@ mod test { } #[test] + #[allow(deprecated)] fn test_metadata_read_write_roundtrip_page_index() { let parquet_bytes = create_parquet_file(); diff --git a/parquet/src/bloom_filter/mod.rs b/parquet/src/bloom_filter/mod.rs index 384a4a10486e..09302bab8fec 100644 --- a/parquet/src/bloom_filter/mod.rs +++ b/parquet/src/bloom_filter/mod.rs @@ -119,6 +119,13 @@ impl Block { Self(result) } + #[inline] + #[cfg(not(target_endian = "little"))] + fn to_ne_bytes(self) -> [u8; 32] { + // SAFETY: [u32; 8] and [u8; 32] have the same size and neither has invalid bit patterns. + unsafe { std::mem::transmute(self.0) } + } + #[inline] #[cfg(not(target_endian = "little"))] fn to_le_bytes(self) -> [u8; 32] { diff --git a/parquet/src/column/page.rs b/parquet/src/column/page.rs index 1dabe6794f07..a2f683d71f4e 100644 --- a/parquet/src/column/page.rs +++ b/parquet/src/column/page.rs @@ -196,9 +196,21 @@ impl CompressedPage { } /// Returns the thrift page header - pub(crate) fn to_thrift_header(&self) -> PageHeader { + pub(crate) fn to_thrift_header(&self) -> Result { let uncompressed_size = self.uncompressed_size(); let compressed_size = self.compressed_size(); + if uncompressed_size > i32::MAX as usize { + return Err(general_err!( + "Page uncompressed size overflow: {}", + uncompressed_size + )); + } + if compressed_size > i32::MAX as usize { + return Err(general_err!( + "Page compressed size overflow: {}", + compressed_size + )); + } let num_values = self.num_values(); let encoding = self.encoding(); let page_type = self.page_type(); @@ -261,7 +273,7 @@ impl CompressedPage { page_header.dictionary_page_header = Some(dictionary_page_header); } } - page_header + Ok(page_header) } /// Update the compressed buffer for a page. @@ -491,4 +503,28 @@ mod tests { assert_eq!(cpage.encoding(), Encoding::PLAIN); assert_eq!(cpage.data(), &[0, 1, 2]); } + + #[test] + fn test_compressed_page_uncompressed_size_overflow() { + // Test that to_thrift_header fails when uncompressed size exceeds i32::MAX + let data_page = Page::DataPage { + buf: Bytes::from(vec![0, 1, 2]), + num_values: 10, + encoding: Encoding::PLAIN, + def_level_encoding: Encoding::RLE, + rep_level_encoding: Encoding::RLE, + statistics: None, + }; + + // Create a CompressedPage with uncompressed size larger than i32::MAX + let uncompressed_size = (i32::MAX as usize) + 1; + let cpage = CompressedPage::new(data_page, uncompressed_size); + + // Verify that to_thrift_header returns an error + let result = cpage.to_thrift_header(); + assert!(result.is_err()); + + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Page uncompressed size overflow")); + } } diff --git a/parquet/src/column/writer/mod.rs b/parquet/src/column/writer/mod.rs index 9374e226b87f..82b8ba166f14 100644 --- a/parquet/src/column/writer/mod.rs +++ b/parquet/src/column/writer/mod.rs @@ -1104,12 +1104,23 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { rep_levels_byte_len + def_levels_byte_len + values_data.buf.len(); // Data Page v2 compresses values only. 
- match self.compressor { + let is_compressed = match self.compressor { Some(ref mut cmpr) => { + let buffer_len = buffer.len(); cmpr.compress(&values_data.buf, &mut buffer)?; + if uncompressed_size <= buffer.len() - buffer_len { + buffer.truncate(buffer_len); + buffer.extend_from_slice(&values_data.buf); + false + } else { + true + } } - None => buffer.extend_from_slice(&values_data.buf), - } + None => { + buffer.extend_from_slice(&values_data.buf); + false + } + }; let data_page = Page::DataPageV2 { buf: buffer.into(), @@ -1119,7 +1130,7 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { num_rows: self.page_metrics.num_buffered_rows, def_levels_byte_len: def_levels_byte_len as u32, rep_levels_byte_len: rep_levels_byte_len as u32, - is_compressed: self.compressor.is_some(), + is_compressed, statistics: page_statistics, }; @@ -4236,4 +4247,33 @@ mod tests { .unwrap(); ColumnDescriptor::new(Arc::new(tpe), max_def_level, max_rep_level, path) } + + #[test] + fn test_page_v2_snappy_compression_fallback() { + // Test that PageV2 sets is_compressed to false when Snappy compression increases data size + let page_writer = TestPageWriter {}; + + // Create WriterProperties with PageV2 and Snappy compression + let props = WriterProperties::builder() + .set_writer_version(WriterVersion::PARQUET_2_0) + // Disable dictionary to ensure data is written directly + .set_dictionary_enabled(false) + .set_compression(Compression::SNAPPY) + .build(); + + let mut column_writer = + get_test_column_writer::(Box::new(page_writer), 0, 0, Arc::new(props)); + + // Create small, simple data that Snappy compression will likely increase in size + // due to compression overhead for very small data + let values = vec![ByteArray::from("a")]; + + column_writer.write_batch(&values, None, None).unwrap(); + + let result = column_writer.close().unwrap(); + assert_eq!( + result.metadata.uncompressed_size(), + result.metadata.compressed_size() + ); + } } diff --git a/parquet/src/errors.rs b/parquet/src/errors.rs index 93b2c1b7e028..be08245e956c 100644 --- a/parquet/src/errors.rs +++ b/parquet/src/errors.rs @@ -52,6 +52,9 @@ pub enum ParquetError { /// Returned when a function needs more data to complete properly. The `usize` field indicates /// the total number of bytes required, not the number of additional bytes. NeedMoreData(usize), + /// Returned when a function needs more data to complete properly. + /// The `Range` indicates the range of bytes that are needed. + NeedMoreDataRange(std::ops::Range), } impl std::fmt::Display for ParquetError { @@ -69,6 +72,9 @@ impl std::fmt::Display for ParquetError { } ParquetError::External(e) => write!(fmt, "External: {e}"), ParquetError::NeedMoreData(needed) => write!(fmt, "NeedMoreData: {needed}"), + ParquetError::NeedMoreDataRange(range) => { + write!(fmt, "NeedMoreDataRange: {}..{}", range.start, range.end) + } } } } diff --git a/parquet/src/file/metadata/mod.rs b/parquet/src/file/metadata/mod.rs index 04129c6aa482..f90143104ce2 100644 --- a/parquet/src/file/metadata/mod.rs +++ b/parquet/src/file/metadata/mod.rs @@ -40,11 +40,10 @@ //! metadata into parquet files. To work with metadata directly, //! the following APIs are available: //! -//! * [`ParquetMetaDataReader`] for reading +//! * [`ParquetMetaDataReader`] for reading from a reader for I/O +//! * [`ParquetMetaDataPushDecoder`] for decoding from bytes without I/O //! * [`ParquetMetaDataWriter`] for writing. //! -//! 
[`ParquetMetaDataReader`]: https://docs.rs/parquet/latest/parquet/file/metadata/struct.ParquetMetaDataReader.html -//! [`ParquetMetaDataWriter`]: https://docs.rs/parquet/latest/parquet/file/metadata/struct.ParquetMetaDataWriter.html //! //! # Examples //! @@ -92,6 +91,7 @@ //! * Same name, different struct //! ``` mod memory; +mod push_decoder; pub(crate) mod reader; mod writer; @@ -120,7 +120,8 @@ use crate::schema::types::{ }; #[cfg(feature = "encryption")] use crate::thrift::{TCompactSliceInputProtocol, TSerializable}; -pub use reader::{FooterTail, ParquetMetaDataReader}; +pub use push_decoder::ParquetMetaDataPushDecoder; +pub use reader::{FooterTail, PageIndexPolicy, ParquetMetaDataReader}; use std::ops::Range; use std::sync::Arc; pub use writer::ParquetMetaDataWriter; diff --git a/parquet/src/file/metadata/push_decoder.rs b/parquet/src/file/metadata/push_decoder.rs new file mode 100644 index 000000000000..811caf4fd46c --- /dev/null +++ b/parquet/src/file/metadata/push_decoder.rs @@ -0,0 +1,559 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::errors::ParquetError; +use crate::file::metadata::{PageIndexPolicy, ParquetMetaData, ParquetMetaDataReader}; +use crate::DecodeResult; + +/// A push decoder for [`ParquetMetaData`]. +/// +/// This structure implements a push API based version of the [`ParquetMetaDataReader`], which +/// decouples the IO from the metadata decoding logic. +/// +/// You can use this decoder to customize your IO operations, as shown in the +/// examples below for minimizing bytes read, prefetching data, or +/// using async IO. +/// +/// # Example +/// +/// The most basic usage is to feed the decoder with the necessary byte ranges +/// as requested as shown below. 
+/// +/// ```rust +/// # use std::ops::Range; +/// # use bytes::Bytes; +/// # use arrow_array::record_batch; +/// # use parquet::DecodeResult; +/// # use parquet::arrow::ArrowWriter; +/// # use parquet::errors::ParquetError; +/// # use parquet::file::metadata::{ParquetMetaData, ParquetMetaDataPushDecoder}; +/// # +/// # fn decode_metadata() -> Result { +/// # let file_bytes = { +/// # let mut buffer = vec![0]; +/// # let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap(); +/// # let mut writer = ArrowWriter::try_new(&mut buffer, batch.schema(), None).unwrap(); +/// # writer.write(&batch).unwrap(); +/// # writer.close().unwrap(); +/// # Bytes::from(buffer) +/// # }; +/// # // mimic IO by returning a function that returns the bytes for a given range +/// # let get_range = |range: &Range| -> Bytes { +/// # let start = range.start as usize; +/// # let end = range.end as usize; +/// # file_bytes.slice(start..end) +/// # }; +/// # +/// # let file_len = file_bytes.len() as u64; +/// // The `ParquetMetaDataPushDecoder` needs to know the file length. +/// let mut decoder = ParquetMetaDataPushDecoder::try_new(file_len).unwrap(); +/// // try to decode the metadata. If more data is needed, the decoder will tell you what ranges +/// loop { +/// match decoder.try_decode() { +/// Ok(DecodeResult::Data(metadata)) => { return Ok(metadata); } // decode successful +/// Ok(DecodeResult::NeedsData(ranges)) => { +/// // The decoder needs more data +/// // +/// // In this example, we call a function that returns the bytes for each given range. +/// // In a real application, you would likely read the data from a file or network. +/// let data = ranges.iter().map(|range| get_range(range)).collect(); +/// // Push the data into the decoder and try to decode again on the next iteration. +/// decoder.push_ranges(ranges, data).unwrap(); +/// } +/// Ok(DecodeResult::Finished) => { unreachable!("returned metadata in previous match arm") } +/// Err(e) => return Err(e), +/// } +/// } +/// # } +/// ``` +/// +/// # Example with "prefetching" +/// +/// By default, the [`ParquetMetaDataPushDecoder`] will request only the exact byte +/// ranges it needs. This minimizes the number of bytes read, however it +/// requires at least two IO operations to read the metadata - one to read the +/// footer and then one to read the metadata. +/// +/// If the file has a "Page Index" (see [Self::with_page_index_policy]), three +/// IO operations are required to read the metadata, as the page index is +/// not part of the normal metadata footer. +/// +/// To reduce the number of IO operations in systems with high per operation +/// overhead (e.g. cloud storage), you can "prefetch" the data and then push +/// the data into the decoder before calling [`Self::try_decode`]. If you do +/// not push enough bytes, the decoder will return the ranges that are still +/// needed. +/// +/// This approach can also be used when you have the entire file already in memory +/// for other reasons. 
+/// +/// ```rust +/// # use std::ops::Range; +/// # use bytes::Bytes; +/// # use arrow_array::record_batch; +/// # use parquet::DecodeResult; +/// # use parquet::arrow::ArrowWriter; +/// # use parquet::errors::ParquetError; +/// # use parquet::file::metadata::{ParquetMetaData, ParquetMetaDataPushDecoder}; +/// # +/// # fn decode_metadata() -> Result { +/// # let file_bytes = { +/// # let mut buffer = vec![0]; +/// # let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap(); +/// # let mut writer = ArrowWriter::try_new(&mut buffer, batch.schema(), None).unwrap(); +/// # writer.write(&batch).unwrap(); +/// # writer.close().unwrap(); +/// # Bytes::from(buffer) +/// # }; +/// # +/// let file_len = file_bytes.len() as u64; +/// // For this example, we "prefetch" all the bytes which we have in memory, +/// // but in a real application, you would likely read a chunk from the end +/// // for example 1MB. +/// let prefetched_bytes = file_bytes.clone(); +/// let mut decoder = ParquetMetaDataPushDecoder::try_new(file_len).unwrap(); +/// // push the prefetched bytes into the decoder +/// decoder.push_ranges(vec![0..file_len], vec![prefetched_bytes]).unwrap(); +/// // The decoder will now be able to decode the metadata. Note in a real application, +/// // unless you can guarantee that the pushed data is enough to decode the metadata, +/// // you still need to call `try_decode` in a loop until it returns `DecodeResult::Data` +/// // as shown in the previous example +/// match decoder.try_decode() { +/// Ok(DecodeResult::Data(metadata)) => { return Ok(metadata); } // decode successful +/// other => { panic!("expected DecodeResult::Data, got: {other:?}") } +/// } +/// # } +/// ``` +/// +/// # Example using [`AsyncRead`] +/// +/// [`ParquetMetaDataPushDecoder`] is designed to work with any data source that can +/// provide byte ranges, including async IO sources. However, it does not +/// implement async IO itself. To use async IO, you simply write an async +/// wrapper around it that reads the required byte ranges and pushes them into the +/// decoder. +/// +/// ```rust +/// # use std::ops::Range; +/// # use bytes::Bytes; +/// use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt}; +/// # use arrow_array::record_batch; +/// # use parquet::DecodeResult; +/// # use parquet::arrow::ArrowWriter; +/// # use parquet::errors::ParquetError; +/// # use parquet::file::metadata::{ParquetMetaData, ParquetMetaDataPushDecoder}; +/// # +/// // This function decodes Parquet Metadata from anything that implements +/// // [`AsyncRead`] and [`AsyncSeek`] such as a tokio::fs::File +/// async fn decode_metadata( +/// file_len: u64, +/// mut async_source: impl AsyncRead + AsyncSeek + Unpin +/// ) -> Result { +/// // We need a ParquetMetaDataPushDecoder to decode the metadata. +/// let mut decoder = ParquetMetaDataPushDecoder::try_new(file_len).unwrap(); +/// loop { +/// match decoder.try_decode() { +/// Ok(DecodeResult::Data(metadata)) => { return Ok(metadata); } // decode successful +/// Ok(DecodeResult::NeedsData(ranges)) => { +/// // The decoder needs more data +/// // +/// // In this example we use the AsyncRead and AsyncSeek traits to read the +/// // required ranges from the async source. 
+/// let mut data = Vec::with_capacity(ranges.len()); +/// for range in &ranges { +/// let mut buffer = vec![0; (range.end - range.start) as usize]; +/// async_source.seek(std::io::SeekFrom::Start(range.start)).await?; +/// async_source.read_exact(&mut buffer).await?; +/// data.push(Bytes::from(buffer)); +/// } +/// // Push the data into the decoder and try to decode again on the next iteration. +/// decoder.push_ranges(ranges, data).unwrap(); +/// } +/// Ok(DecodeResult::Finished) => { unreachable!("returned metadata in previous match arm") } +/// Err(e) => return Err(e), +/// } +/// } +/// } +/// ``` +/// [`AsyncRead`]: tokio::io::AsyncRead +#[derive(Debug)] +pub struct ParquetMetaDataPushDecoder { + done: bool, + metadata_reader: ParquetMetaDataReader, + buffers: crate::util::push_buffers::PushBuffers, +} + +impl ParquetMetaDataPushDecoder { + /// Create a new `ParquetMetaDataPushDecoder` with the given file length. + /// + /// By default, this will read page indexes and column indexes. See + /// [`ParquetMetaDataPushDecoder::with_page_index_policy`] for more detail. + /// + /// See examples on [`ParquetMetaDataPushDecoder`]. + pub fn try_new(file_len: u64) -> Result { + if file_len < 8 { + return Err(ParquetError::General(format!( + "Parquet files are at least 8 bytes long, but file length is {file_len}" + ))); + }; + + let metadata_reader = + ParquetMetaDataReader::new().with_page_index_policy(PageIndexPolicy::Optional); + + Ok(Self { + done: false, + metadata_reader, + buffers: crate::util::push_buffers::PushBuffers::new(file_len), + }) + } + + /// Enable or disable reading the page index structures described in + /// "[Parquet page index] Layout to Support Page Skipping". + /// + /// Defaults to [`PageIndexPolicy::Optional`] + /// + /// This requires + /// 1. The Parquet file to have been written with page indexes + /// 2. Additional data to be pushed into the decoder (as the page indexes are not part of the thrift footer) + /// + /// [Parquet page index]: https://github.com/apache/parquet-format/blob/master/PageIndex.md + pub fn with_page_index_policy(mut self, page_index_policy: PageIndexPolicy) -> Self { + self.metadata_reader = self + .metadata_reader + .with_page_index_policy(page_index_policy); + self + } + + /// Push the data into the decoder's buffer. + /// + /// The decoder does not immediately attempt to decode the metadata + /// after pushing data. Instead, it accumulates the pushed data until you + /// call [`Self::try_decode`]. + /// + /// # Determining required data: + /// + /// To determine what ranges are required to decode the metadata, you can + /// either: + /// + /// 1. Call [`Self::try_decode`] first to get the exact ranges required (see + /// example on [`Self`]) + /// + /// 2. Speculatively push any data that you have available, which may + /// include more than the footer data or requested bytes. + /// + /// Speculatively pushing data can be used when "prefetching" data. See + /// example on [`Self`] + pub fn push_ranges( + &mut self, + ranges: Vec>, + buffers: Vec, + ) -> std::result::Result<(), String> { + if self.done { + return Err( + "ParquetMetaDataPushDecoder: cannot push data after decoding is finished" + .to_string(), + ); + } + self.buffers.push_ranges(ranges, buffers); + Ok(()) + } + + /// Try to decode the metadata from the pushed data, returning the + /// decoded metadata or an error if not enough data is available. 
+ pub fn try_decode( + &mut self, + ) -> std::result::Result, ParquetError> { + if self.done { + return Ok(DecodeResult::Finished); + } + + // need to have the last 8 bytes of the file to decode the metadata + let file_len = self.buffers.file_len(); + if !self.buffers.has_range(&(file_len - 8..file_len)) { + #[expect(clippy::single_range_in_vec_init)] + return Ok(DecodeResult::NeedsData(vec![file_len - 8..file_len])); + } + + // Try to parse the metadata from the buffers we have. + // + // If we don't have enough data, returns a `ParquetError::NeedMoreData` + // with the number of bytes needed to complete the metadata parsing. + // + // If we have enough data, returns `Ok(())` and we can complete + // the metadata parsing. + let maybe_metadata = self + .metadata_reader + .try_parse_sized(&self.buffers, self.buffers.file_len()); + + match maybe_metadata { + Ok(()) => { + // Metadata successfully parsed, proceed to decode the row groups + let metadata = self.metadata_reader.finish()?; + self.done = true; + Ok(DecodeResult::Data(metadata)) + } + + Err(ParquetError::NeedMoreData(needed)) => { + let needed = needed as u64; + let Some(start_offset) = file_len.checked_sub(needed) else { + return Err(ParquetError::General(format!( + "Parquet metadata reader needs at least {needed} bytes, but file length is only {file_len}" + ))); + }; + let needed_range = start_offset..start_offset + needed; + // needs `needed_range` bytes at the end of the file + Ok(DecodeResult::NeedsData(vec![needed_range])) + } + Err(ParquetError::NeedMoreDataRange(range)) => Ok(DecodeResult::NeedsData(vec![range])), + + Err(e) => Err(e), // some other error, pass back + } + } +} + +// These tests use the arrow writer to create a parquet file in memory +// so they need the arrow feature and the test feature +#[cfg(all(test, feature = "arrow"))] +mod tests { + use super::*; + use crate::arrow::ArrowWriter; + use crate::file::properties::WriterProperties; + use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringViewArray}; + use bytes::Bytes; + use std::fmt::Debug; + use std::ops::Range; + use std::sync::{Arc, LazyLock}; + + /// It is possible to decode the metadata from the entire file at once before being asked + #[test] + fn test_metadata_decoder_all_data() { + let file_len = test_file_len(); + let mut metadata_decoder = ParquetMetaDataPushDecoder::try_new(file_len).unwrap(); + // Push the entire file data into the metadata decoder + push_ranges_to_metadata_decoder(&mut metadata_decoder, vec![test_file_range()]); + + // should be able to decode the metadata without needing more data + let metadata = expect_data(metadata_decoder.try_decode()); + + assert_eq!(metadata.num_row_groups(), 2); + assert_eq!(metadata.row_group(0).num_rows(), 200); + assert_eq!(metadata.row_group(1).num_rows(), 200); + assert!(metadata.column_index().is_some()); + assert!(metadata.offset_index().is_some()); + } + + /// It is possible to feed some, but not all, of the footer into the metadata decoder + /// before asked. 
This avoids multiple IO requests + #[test] + fn test_metadata_decoder_prefetch_success() { + let file_len = test_file_len(); + let mut metadata_decoder = ParquetMetaDataPushDecoder::try_new(file_len).unwrap(); + // simulate pre-fetching the last 2k bytes of the file without asking the decoder + let prefetch_range = (file_len - 2 * 1024)..file_len; + push_ranges_to_metadata_decoder(&mut metadata_decoder, vec![prefetch_range]); + + // expect the decoder has enough data to decode the metadata + let metadata = expect_data(metadata_decoder.try_decode()); + expect_finished(metadata_decoder.try_decode()); + assert_eq!(metadata.num_row_groups(), 2); + assert_eq!(metadata.row_group(0).num_rows(), 200); + assert_eq!(metadata.row_group(1).num_rows(), 200); + assert!(metadata.column_index().is_some()); + assert!(metadata.offset_index().is_some()); + } + + /// It is possible to pre-fetch some, but not all, of the necessary + /// data + #[test] + fn test_metadata_decoder_prefetch_retry() { + let file_len = test_file_len(); + let mut metadata_decoder = ParquetMetaDataPushDecoder::try_new(file_len).unwrap(); + // simulate pre-fetching the last 1500 bytes of the file. + // this is enough to read the footer thrift metadata, but not the offset indexes + let prefetch_range = (file_len - 1500)..file_len; + push_ranges_to_metadata_decoder(&mut metadata_decoder, vec![prefetch_range]); + + // expect another request is needed to read the offset indexes (note + // try_decode only returns NeedsData once, whereas without any prefetching it would + // return NeedsData three times) + let ranges = expect_needs_data(metadata_decoder.try_decode()); + push_ranges_to_metadata_decoder(&mut metadata_decoder, ranges); + + // expect the decoder has enough data to decode the metadata + let metadata = expect_data(metadata_decoder.try_decode()); + expect_finished(metadata_decoder.try_decode()); + + assert_eq!(metadata.num_row_groups(), 2); + assert_eq!(metadata.row_group(0).num_rows(), 200); + assert_eq!(metadata.row_group(1).num_rows(), 200); + assert!(metadata.column_index().is_some()); + assert!(metadata.offset_index().is_some()); + } + + /// Decode the metadata incrementally, simulating a scenario where exactly the data needed + /// is read in each step + #[test] + fn test_metadata_decoder_incremental() { + let file_len = TEST_FILE_DATA.len() as u64; + let mut metadata_decoder = ParquetMetaDataPushDecoder::try_new(file_len).unwrap(); + let ranges = expect_needs_data(metadata_decoder.try_decode()); + assert_eq!(ranges.len(), 1); + assert_eq!(ranges[0], test_file_len() - 8..test_file_len()); + push_ranges_to_metadata_decoder(&mut metadata_decoder, ranges); + + // expect the first request to read the footer + let ranges = expect_needs_data(metadata_decoder.try_decode()); + push_ranges_to_metadata_decoder(&mut metadata_decoder, ranges); + + // expect the second request to read the offset indexes + let ranges = expect_needs_data(metadata_decoder.try_decode()); + push_ranges_to_metadata_decoder(&mut metadata_decoder, ranges); + + // expect the third request to read the actual data + let metadata = expect_data(metadata_decoder.try_decode()); + expect_finished(metadata_decoder.try_decode()); + + assert_eq!(metadata.num_row_groups(), 2); + assert_eq!(metadata.row_group(0).num_rows(), 200); + assert_eq!(metadata.row_group(1).num_rows(), 200); + assert!(metadata.column_index().is_some()); + assert!(metadata.offset_index().is_some()); + } + + /// Decode the metadata incrementally, but without reading the page indexes + /// (so only 
two requests) + #[test] + fn test_metadata_decoder_incremental_no_page_index() { + let file_len = TEST_FILE_DATA.len() as u64; + let mut metadata_decoder = ParquetMetaDataPushDecoder::try_new(file_len) + .unwrap() + .with_page_index_policy(PageIndexPolicy::Skip); + let ranges = expect_needs_data(metadata_decoder.try_decode()); + assert_eq!(ranges.len(), 1); + assert_eq!(ranges[0], test_file_len() - 8..test_file_len()); + push_ranges_to_metadata_decoder(&mut metadata_decoder, ranges); + + // expect the first request to read the footer + let ranges = expect_needs_data(metadata_decoder.try_decode()); + push_ranges_to_metadata_decoder(&mut metadata_decoder, ranges); + + // expect NO second request to read the offset indexes, should just cough up the metadata + let metadata = expect_data(metadata_decoder.try_decode()); + expect_finished(metadata_decoder.try_decode()); + + assert_eq!(metadata.num_row_groups(), 2); + assert_eq!(metadata.row_group(0).num_rows(), 200); + assert_eq!(metadata.row_group(1).num_rows(), 200); + assert!(metadata.column_index().is_none()); // of course, we did not read the column index + assert!(metadata.offset_index().is_none()); // or the offset index + } + + static TEST_BATCH: LazyLock = LazyLock::new(|| { + // Input batch has 400 rows, with 3 columns: "a", "b", "c" + // Note c is a different types (so the data page sizes will be different) + let a: ArrayRef = Arc::new(Int64Array::from_iter_values(0..400)); + let b: ArrayRef = Arc::new(Int64Array::from_iter_values(400..800)); + let c: ArrayRef = Arc::new(StringViewArray::from_iter_values((0..400).map(|i| { + if i % 2 == 0 { + format!("string_{i}") + } else { + format!("A string larger than 12 bytes and thus not inlined {i}") + } + }))); + + RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap() + }); + + /// Create a parquet file in memory for testing. See [`test_file_range`] for details. 
+ static TEST_FILE_DATA: LazyLock = LazyLock::new(|| { + let input_batch = &TEST_BATCH; + let mut output = Vec::new(); + + let writer_options = WriterProperties::builder() + .set_max_row_group_size(200) + .set_data_page_row_count_limit(100) + .build(); + let mut writer = + ArrowWriter::try_new(&mut output, input_batch.schema(), Some(writer_options)).unwrap(); + + // since the limits are only enforced on batch boundaries, write the input + // batch in chunks of 50 + let mut row_remain = input_batch.num_rows(); + while row_remain > 0 { + let chunk_size = row_remain.min(50); + let chunk = input_batch.slice(input_batch.num_rows() - row_remain, chunk_size); + writer.write(&chunk).unwrap(); + row_remain -= chunk_size; + } + writer.close().unwrap(); + Bytes::from(output) + }); + + /// Return the length of the test file in bytes + fn test_file_len() -> u64 { + TEST_FILE_DATA.len() as u64 + } + + /// Return the range of the entire test file + fn test_file_range() -> Range { + 0..test_file_len() + } + + /// Return a slice of the test file data from the given range + pub fn test_file_slice(range: Range) -> Bytes { + let start: usize = range.start.try_into().unwrap(); + let end: usize = range.end.try_into().unwrap(); + TEST_FILE_DATA.slice(start..end) + } + + /// Push the given ranges to the metadata decoder, simulating reading from a file + fn push_ranges_to_metadata_decoder( + metadata_decoder: &mut ParquetMetaDataPushDecoder, + ranges: Vec>, + ) { + let data = ranges + .iter() + .map(|range| test_file_slice(range.clone())) + .collect::>(); + metadata_decoder.push_ranges(ranges, data).unwrap(); + } + + /// Expect that the [`DecodeResult`] is a [`DecodeResult::Data`] and return the corresponding element + fn expect_data(result: Result, ParquetError>) -> T { + match result.expect("Expected Ok(DecodeResult::Data(T))") { + DecodeResult::Data(data) => data, + result => panic!("Expected DecodeResult::Data, got {result:?}"), + } + } + + /// Expect that the [`DecodeResult`] is a [`DecodeResult::NeedsData`] and return the corresponding ranges + fn expect_needs_data( + result: Result, ParquetError>, + ) -> Vec> { + match result.expect("Expected Ok(DecodeResult::NeedsData{ranges})") { + DecodeResult::NeedsData(ranges) => ranges, + result => panic!("Expected DecodeResult::NeedsData, got {result:?}"), + } + } + + fn expect_finished(result: Result, ParquetError>) { + match result.expect("Expected Ok(DecodeResult::Finished)") { + DecodeResult::Finished => {} + result => panic!("Expected DecodeResult::Finished, got {result:?}"), + } + } +} diff --git a/parquet/src/file/metadata/reader.rs b/parquet/src/file/metadata/reader.rs index 356713837530..8d92d1e0aa8d 100644 --- a/parquet/src/file/metadata/reader.rs +++ b/parquet/src/file/metadata/reader.rs @@ -69,11 +69,11 @@ use crate::file::page_index::offset_index::OffsetIndexMetaData; /// assert!(metadata.column_index().is_some()); /// assert!(metadata.offset_index().is_some()); /// ``` -#[derive(Default)] +#[derive(Default, Debug)] pub struct ParquetMetaDataReader { metadata: Option, - column_index: bool, - offset_index: bool, + column_index: PageIndexPolicy, + offset_index: PageIndexPolicy, prefetch_hint: Option, // Size of the serialized thrift metadata plus the 8 byte footer. Only set if // `self.parse_metadata` is called. 
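The `column_index` and `offset_index` fields above now carry a `PageIndexPolicy` instead of a `bool`; the enum itself and the `with_*_policy` builders follow in the next hunk. A minimal usage sketch of the policy-based API, using only the types this change re-exports from `parquet::file::metadata`:

```rust
use parquet::file::metadata::{PageIndexPolicy, ParquetMetaDataReader};

fn main() {
    // Load the page index when the file has one, without erroring when it is absent.
    let _reader = ParquetMetaDataReader::new()
        .with_page_index_policy(PageIndexPolicy::Optional);

    // The deprecated boolean setters map onto the policy:
    // `true` becomes `Required` and `false` becomes `Skip`.
    assert_eq!(PageIndexPolicy::from(true), PageIndexPolicy::Required);
    assert_eq!(PageIndexPolicy::from(false), PageIndexPolicy::Skip);
}
```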
@@ -82,6 +82,27 @@ pub struct ParquetMetaDataReader { file_decryption_properties: Option, } +/// Describes the policy for reading page indexes +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum PageIndexPolicy { + /// Do not read the page index. + #[default] + Skip, + /// Read the page index if it exists, otherwise do not error. + Optional, + /// Require the page index to exist, and error if it does not. + Required, +} + +impl From for PageIndexPolicy { + fn from(value: bool) -> Self { + match value { + true => Self::Required, + false => Self::Skip, + } + } +} + /// Describes how the footer metadata is stored /// /// This is parsed from the last 8 bytes of the Parquet file @@ -118,27 +139,49 @@ impl ParquetMetaDataReader { } /// Enable or disable reading the page index structures described in - /// "[Parquet page index]: Layout to Support Page Skipping". Equivalent to: - /// `self.with_column_indexes(val).with_offset_indexes(val)` + /// "[Parquet page index]: Layout to Support Page Skipping". /// /// [Parquet page index]: https://github.com/apache/parquet-format/blob/master/PageIndex.md + #[deprecated(since = "56.1.0", note = "Use `with_page_index_policy` instead")] pub fn with_page_indexes(self, val: bool) -> Self { - self.with_column_indexes(val).with_offset_indexes(val) + let policy = PageIndexPolicy::from(val); + self.with_column_index_policy(policy) + .with_offset_index_policy(policy) } /// Enable or disable reading the Parquet [ColumnIndex] structure. /// /// [ColumnIndex]: https://github.com/apache/parquet-format/blob/master/PageIndex.md - pub fn with_column_indexes(mut self, val: bool) -> Self { - self.column_index = val; - self + #[deprecated(since = "56.1.0", note = "Use `with_column_index_policy` instead")] + pub fn with_column_indexes(self, val: bool) -> Self { + let policy = PageIndexPolicy::from(val); + self.with_column_index_policy(policy) } /// Enable or disable reading the Parquet [OffsetIndex] structure. 
/// /// [OffsetIndex]: https://github.com/apache/parquet-format/blob/master/PageIndex.md - pub fn with_offset_indexes(mut self, val: bool) -> Self { - self.offset_index = val; + #[deprecated(since = "56.1.0", note = "Use `with_offset_index_policy` instead")] + pub fn with_offset_indexes(self, val: bool) -> Self { + let policy = PageIndexPolicy::from(val); + self.with_offset_index_policy(policy) + } + + /// Sets the [`PageIndexPolicy`] for the column and offset indexes + pub fn with_page_index_policy(self, policy: PageIndexPolicy) -> Self { + self.with_column_index_policy(policy) + .with_offset_index_policy(policy) + } + + /// Sets the [`PageIndexPolicy`] for the column index + pub fn with_column_index_policy(mut self, policy: PageIndexPolicy) -> Self { + self.column_index = policy; + self + } + + /// Sets the [`PageIndexPolicy`] for the offset index + pub fn with_offset_index_policy(mut self, policy: PageIndexPolicy) -> Self { + self.offset_index = policy; self } @@ -277,7 +320,7 @@ impl ParquetMetaDataReader { /// bytes = get_bytes(&file, len - needed as u64..len); /// // If file metadata was read only read page indexes, otherwise continue loop /// if reader.has_metadata() { - /// reader.read_page_indexes_sized(&bytes, len); + /// reader.read_page_indexes_sized(&bytes, len).unwrap(); /// break; /// } /// } @@ -307,7 +350,8 @@ impl ParquetMetaDataReader { }; // we can return if page indexes aren't requested - if !self.column_index && !self.offset_index { + if self.column_index == PageIndexPolicy::Skip && self.offset_index == PageIndexPolicy::Skip + { return Ok(()); } @@ -348,8 +392,7 @@ impl ParquetMetaDataReader { // Requested range starts beyond EOF if range.end > file_size { return Err(eof_err!( - "Parquet file too small. Range {:?} is beyond file bounds {file_size}", - range + "Parquet file too small. Range {range:?} is beyond file bounds {file_size}", )); } else { // Ask for a larger buffer @@ -365,9 +408,7 @@ impl ParquetMetaDataReader { let metadata_range = file_size.saturating_sub(metadata_size as u64)..file_size; if range.end > metadata_range.start { return Err(eof_err!( - "Parquet file too small. Page index range {:?} overlaps with file metadata {:?}", - range, - metadata_range + "Parquet file too small. 
Page index range {range:?} overlaps with file metadata {metadata_range:?}" , )); } } @@ -424,7 +465,8 @@ impl ParquetMetaDataReader { self.metadata = Some(metadata); // we can return if page indexes aren't requested - if !self.column_index && !self.offset_index { + if self.column_index == PageIndexPolicy::Skip && self.offset_index == PageIndexPolicy::Skip + { return Ok(()); } @@ -446,7 +488,8 @@ impl ParquetMetaDataReader { self.metadata = Some(metadata); // we can return if page indexes aren't requested - if !self.column_index && !self.offset_index { + if self.column_index == PageIndexPolicy::Skip && self.offset_index == PageIndexPolicy::Skip + { return Ok(()); } @@ -500,7 +543,7 @@ impl ParquetMetaDataReader { fn parse_column_index(&mut self, bytes: &Bytes, start_offset: u64) -> Result<()> { let metadata = self.metadata.as_mut().unwrap(); - if self.column_index { + if self.column_index != PageIndexPolicy::Skip { let index = metadata .row_groups() .iter() @@ -526,6 +569,7 @@ impl ParquetMetaDataReader { .collect::>>() }) .collect::>>()?; + metadata.set_column_index(Some(index)); } Ok(()) @@ -572,34 +616,44 @@ impl ParquetMetaDataReader { fn parse_offset_index(&mut self, bytes: &Bytes, start_offset: u64) -> Result<()> { let metadata = self.metadata.as_mut().unwrap(); - if self.offset_index { - let index = metadata - .row_groups() - .iter() - .enumerate() - .map(|(rg_idx, x)| { - x.columns() - .iter() - .enumerate() - .map(|(col_idx, c)| match c.offset_index_range() { - Some(r) => { - let r_start = usize::try_from(r.start - start_offset)?; - let r_end = usize::try_from(r.end - start_offset)?; - Self::parse_single_offset_index( - &bytes[r_start..r_end], - metadata, - c, - rg_idx, - col_idx, - ) + if self.offset_index != PageIndexPolicy::Skip { + let row_groups = metadata.row_groups(); + let mut all_indexes = Vec::with_capacity(row_groups.len()); + for (rg_idx, x) in row_groups.iter().enumerate() { + let mut row_group_indexes = Vec::with_capacity(x.columns().len()); + for (col_idx, c) in x.columns().iter().enumerate() { + let result = match c.offset_index_range() { + Some(r) => { + let r_start = usize::try_from(r.start - start_offset)?; + let r_end = usize::try_from(r.end - start_offset)?; + Self::parse_single_offset_index( + &bytes[r_start..r_end], + metadata, + c, + rg_idx, + col_idx, + ) + } + None => Err(general_err!("missing offset index")), + }; + + match result { + Ok(index) => row_group_indexes.push(index), + Err(e) => { + if self.offset_index == PageIndexPolicy::Required { + return Err(e); + } else { + // Invalidate and return + metadata.set_column_index(None); + metadata.set_offset_index(None); + return Ok(()); } - None => Err(general_err!("missing offset index")), - }) - .collect::>>() - }) - .collect::>>()?; - - metadata.set_offset_index(Some(index)); + } + } + } + all_indexes.push(row_group_indexes); + } + metadata.set_offset_index(Some(all_indexes)); } Ok(()) } @@ -651,10 +705,10 @@ impl ParquetMetaDataReader { let mut range = None; let metadata = self.metadata.as_ref().unwrap(); for c in metadata.row_groups().iter().flat_map(|r| r.columns()) { - if self.column_index { + if self.column_index != PageIndexPolicy::Skip { range = acc_range(range, c.column_index_range()); } - if self.offset_index { + if self.offset_index != PageIndexPolicy::Skip { range = acc_range(range, c.offset_index_range()); } } @@ -1185,6 +1239,7 @@ mod tests { } #[test] + #[allow(deprecated)] fn test_try_parse() { let file = get_test_file("alltypes_tiny_pages.parquet"); let len = file.len(); @@ -1302,6 
+1357,10 @@ mod tests { #[cfg(all(feature = "async", feature = "arrow", test))] mod async_tests { use super::*; + + use arrow::{array::Int32Array, datatypes::DataType}; + use arrow_array::RecordBatch; + use arrow_schema::{Field, Schema}; use bytes::Bytes; use futures::future::BoxFuture; use futures::FutureExt; @@ -1310,7 +1369,10 @@ mod async_tests { use std::io::{Read, Seek, SeekFrom}; use std::ops::Range; use std::sync::atomic::{AtomicUsize, Ordering}; + use tempfile::NamedTempFile; + use crate::arrow::ArrowWriter; + use crate::file::properties::WriterProperties; use crate::file::reader::Length; use crate::util::test_common::file_util::get_test_file; @@ -1562,6 +1624,7 @@ mod async_tests { } #[tokio::test] + #[allow(deprecated)] async fn test_page_index() { let mut file = get_test_file("alltypes_tiny_pages.parquet"); let len = file.len(); @@ -1648,4 +1711,50 @@ mod async_tests { assert_eq!(fetch_count.load(Ordering::SeqCst), 1); assert!(metadata.offset_index().is_some() && metadata.column_index().is_some()); } + + fn write_parquet_file(offset_index_disabled: bool) -> Result { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; + + let file = NamedTempFile::new().unwrap(); + + // Write properties with page index disabled + let props = WriterProperties::builder() + .set_offset_index_disabled(offset_index_disabled) + .build(); + + let mut writer = ArrowWriter::try_new(file.reopen()?, schema, Some(props))?; + writer.write(&batch)?; + writer.close()?; + + Ok(file) + } + + fn read_and_check(file: &File, policy: PageIndexPolicy) -> Result { + let mut reader = ParquetMetaDataReader::new().with_page_index_policy(policy); + reader.try_parse(file)?; + reader.finish() + } + + #[test] + fn test_page_index_policy() { + // With page index + let f = write_parquet_file(false).unwrap(); + read_and_check(f.as_file(), PageIndexPolicy::Required).unwrap(); + read_and_check(f.as_file(), PageIndexPolicy::Optional).unwrap(); + read_and_check(f.as_file(), PageIndexPolicy::Skip).unwrap(); + + // Without page index + let f = write_parquet_file(true).unwrap(); + let res = read_and_check(f.as_file(), PageIndexPolicy::Required); + assert!(matches!( + res, + Err(ParquetError::General(e)) if e == "missing offset index" + )); + read_and_check(f.as_file(), PageIndexPolicy::Optional).unwrap(); + read_and_check(f.as_file(), PageIndexPolicy::Skip).unwrap(); + } } diff --git a/parquet/src/file/properties.rs b/parquet/src/file/properties.rs index 96e3706e27d7..603db6660f45 100644 --- a/parquet/src/file/properties.rs +++ b/parquet/src/file/properties.rs @@ -193,6 +193,12 @@ impl WriterProperties { WriterPropertiesBuilder::default() } + /// Converts this [`WriterProperties`] into a [`WriterPropertiesBuilder`] + /// Used for mutating existing property settings + pub fn into_builder(self) -> WriterPropertiesBuilder { + self.into() + } + /// Returns data page size limit. /// /// Note: this is a best effort limit based on the write batch size @@ -435,6 +441,7 @@ impl WriterProperties { /// Builder for [`WriterProperties`] Parquet writer configuration. 
/// /// See example on [`WriterProperties`] +#[derive(Debug, Clone)] pub struct WriterPropertiesBuilder { data_page_size_limit: usize, data_page_row_count_limit: usize, @@ -934,6 +941,30 @@ impl WriterPropertiesBuilder { } } +impl From for WriterPropertiesBuilder { + fn from(props: WriterProperties) -> Self { + WriterPropertiesBuilder { + data_page_size_limit: props.data_page_size_limit, + data_page_row_count_limit: props.data_page_row_count_limit, + write_batch_size: props.write_batch_size, + max_row_group_size: props.max_row_group_size, + bloom_filter_position: props.bloom_filter_position, + writer_version: props.writer_version, + created_by: props.created_by, + offset_index_disabled: props.offset_index_disabled, + key_value_metadata: props.key_value_metadata, + default_column_properties: props.default_column_properties, + column_properties: props.column_properties, + sorting_columns: props.sorting_columns, + column_index_truncate_length: props.column_index_truncate_length, + statistics_truncate_length: props.statistics_truncate_length, + coerce_types: props.coerce_types, + #[cfg(feature = "encryption")] + file_encryption_properties: props.file_encryption_properties, + } + } +} + /// Controls the level of statistics to be computed by the writer and stored in /// the parquet file. /// @@ -1377,50 +1408,59 @@ mod tests { .set_column_bloom_filter_fpp(ColumnPath::from("col"), 0.1) .build(); - assert_eq!(props.writer_version(), WriterVersion::PARQUET_2_0); - assert_eq!(props.data_page_size_limit(), 10); - assert_eq!(props.dictionary_page_size_limit(), 20); - assert_eq!(props.write_batch_size(), 30); - assert_eq!(props.max_row_group_size(), 40); - assert_eq!(props.created_by(), "default"); - assert_eq!( - props.key_value_metadata(), - Some(&vec![ - KeyValue::new("key".to_string(), "value".to_string(),) - ]) - ); + fn test_props(props: &WriterProperties) { + assert_eq!(props.writer_version(), WriterVersion::PARQUET_2_0); + assert_eq!(props.data_page_size_limit(), 10); + assert_eq!(props.dictionary_page_size_limit(), 20); + assert_eq!(props.write_batch_size(), 30); + assert_eq!(props.max_row_group_size(), 40); + assert_eq!(props.created_by(), "default"); + assert_eq!( + props.key_value_metadata(), + Some(&vec![ + KeyValue::new("key".to_string(), "value".to_string(),) + ]) + ); - assert_eq!( - props.encoding(&ColumnPath::from("a")), - Some(Encoding::DELTA_BINARY_PACKED) - ); - assert_eq!( - props.compression(&ColumnPath::from("a")), - Compression::GZIP(Default::default()) - ); - assert!(!props.dictionary_enabled(&ColumnPath::from("a"))); - assert_eq!( - props.statistics_enabled(&ColumnPath::from("a")), - EnabledStatistics::None - ); + assert_eq!( + props.encoding(&ColumnPath::from("a")), + Some(Encoding::DELTA_BINARY_PACKED) + ); + assert_eq!( + props.compression(&ColumnPath::from("a")), + Compression::GZIP(Default::default()) + ); + assert!(!props.dictionary_enabled(&ColumnPath::from("a"))); + assert_eq!( + props.statistics_enabled(&ColumnPath::from("a")), + EnabledStatistics::None + ); - assert_eq!( - props.encoding(&ColumnPath::from("col")), - Some(Encoding::RLE) - ); - assert_eq!( - props.compression(&ColumnPath::from("col")), - Compression::SNAPPY - ); - assert!(props.dictionary_enabled(&ColumnPath::from("col"))); - assert_eq!( - props.statistics_enabled(&ColumnPath::from("col")), - EnabledStatistics::Chunk - ); - assert_eq!( - props.bloom_filter_properties(&ColumnPath::from("col")), - Some(&BloomFilterProperties { fpp: 0.1, ndv: 100 }) - ); + assert_eq!( + 
props.encoding(&ColumnPath::from("col")), + Some(Encoding::RLE) + ); + assert_eq!( + props.compression(&ColumnPath::from("col")), + Compression::SNAPPY + ); + assert!(props.dictionary_enabled(&ColumnPath::from("col"))); + assert_eq!( + props.statistics_enabled(&ColumnPath::from("col")), + EnabledStatistics::Chunk + ); + assert_eq!( + props.bloom_filter_properties(&ColumnPath::from("col")), + Some(&BloomFilterProperties { fpp: 0.1, ndv: 100 }) + ); + } + + // Test direct build of properties + test_props(&props); + + // Test that into_builder() gives the same result + let props_into_builder_and_back = props.into_builder().build(); + test_props(&props_into_builder_and_back); } #[test] diff --git a/parquet/src/file/reader.rs b/parquet/src/file/reader.rs index 7e2b149ad3fb..61af21a68ec1 100644 --- a/parquet/src/file/reader.rs +++ b/parquet/src/file/reader.rs @@ -48,11 +48,12 @@ pub trait Length { /// Generates [`Read`]ers to read chunks of a Parquet data source. /// /// The Parquet reader uses [`ChunkReader`] to access Parquet data, allowing -/// multiple decoders to read concurrently from different locations in the same file. +/// multiple decoders to read concurrently from different locations in the same +/// file. /// -/// The trait provides: -/// * random access (via [`Self::get_bytes`]) -/// * sequential (via [`Self::get_read`]) +/// The trait functions both as a reader and a factory for readers. +/// * random access via [`Self::get_bytes`] +/// * sequential access via the reader returned via factory method [`Self::get_read`] /// /// # Provided Implementations /// * [`File`] for reading from local file system diff --git a/parquet/src/file/serialized_reader.rs b/parquet/src/file/serialized_reader.rs index d198a34227fa..b36a76f472f5 100644 --- a/parquet/src/file/serialized_reader.rs +++ b/parquet/src/file/serialized_reader.rs @@ -191,6 +191,7 @@ impl SerializedFileReader { /// Creates file reader from a Parquet file with read options. /// Returns an error if the Parquet file does not exist or is corrupt. + #[allow(deprecated)] pub fn new_with_options(chunk_reader: R, options: ReadOptions) -> Result { let mut metadata_builder = ParquetMetaDataReader::new() .parse_and_finish(&chunk_reader)? diff --git a/parquet/src/file/writer.rs b/parquet/src/file/writer.rs index 690efb36f281..9adf67e68bee 100644 --- a/parquet/src/file/writer.rs +++ b/parquet/src/file/writer.rs @@ -958,7 +958,7 @@ impl PageWriter for SerializedPageWriter<'_, W> { let page_type = page.page_type(); let start_pos = self.sink.bytes_written() as u64; - let page_header = page.to_thrift_header(); + let page_header = page.to_thrift_header()?; let header_size = self.serialize_page_header(page_header)?; self.sink.write_all(page.data())?; diff --git a/parquet/src/lib.rs b/parquet/src/lib.rs index 07a673c295bc..b1100c4bc440 100644 --- a/parquet/src/lib.rs +++ b/parquet/src/lib.rs @@ -86,6 +86,14 @@ //! [`ParquetRecordBatchStreamBuilder`]: arrow::async_reader::ParquetRecordBatchStreamBuilder //! [`ParquetObjectReader`]: arrow::async_reader::ParquetObjectReader //! +//! ## Variant Logical Type (`variant_experimental` feature) +//! +//! The [`variant`] module supports reading and writing Parquet files +//! with the [Variant Binary Encoding] logical type, which can represent +//! semi-structured data such as JSON efficiently. +//! +//! [Variant Binary Encoding]: https://github.com/apache/parquet-format/blob/master/VariantEncoding.md +//! //! ## Read/Write Parquet Directly //! //! 
Workloads needing finer-grained control, or to avoid a dependence on arrow, @@ -155,6 +163,8 @@ pub mod format; #[macro_use] pub mod data_type; +use std::fmt::Debug; +use std::ops::Range; // Exported for external use, such as benchmarks #[cfg(feature = "experimental")] #[doc(hidden)] @@ -179,3 +189,21 @@ pub mod record; pub mod schema; pub mod thrift; + +/// What data is needed to read the next item from a decoder. +/// +/// This is used to communicate between the decoder and the caller +/// to indicate what data is needed next, or what the result of decoding is. +#[derive(Debug)] +pub enum DecodeResult { + /// The ranges of data necessary to proceed + // TODO: distinguish between minimim needed to make progress and what could be used? + NeedsData(Vec>), + /// The decoder produced an output item + Data(T), + /// The decoder finished processing + Finished, +} + +#[cfg(feature = "variant_experimental")] +pub mod variant; diff --git a/parquet/src/record/api.rs b/parquet/src/record/api.rs index 04325576a8bc..ebf933f33e60 100644 --- a/parquet/src/record/api.rs +++ b/parquet/src/record/api.rs @@ -928,29 +928,22 @@ fn convert_date_to_string(value: i32) -> String { format!("{}", dt.format("%Y-%m-%d")) } -/// Helper method to convert Parquet timestamp into a string. -/// Input `value` is a number of seconds since the epoch in UTC. -/// Datetime is displayed in local timezone. -#[inline] -fn convert_timestamp_secs_to_string(value: i64) -> String { - let dt = Utc.timestamp_opt(value, 0).unwrap(); - format!("{}", dt.format("%Y-%m-%d %H:%M:%S %:z")) -} - /// Helper method to convert Parquet timestamp into a string. /// Input `value` is a number of milliseconds since the epoch in UTC. -/// Datetime is displayed in local timezone. +/// Datetime is displayed in UTC timezone. #[inline] fn convert_timestamp_millis_to_string(value: i64) -> String { - convert_timestamp_secs_to_string(value / 1000) + let dt = Utc.timestamp_millis_opt(value).unwrap(); + format!("{}", dt.format("%Y-%m-%d %H:%M:%S%.3f %:z")) } /// Helper method to convert Parquet timestamp into a string. /// Input `value` is a number of microseconds since the epoch in UTC. -/// Datetime is displayed in local timezone. +/// Datetime is displayed in UTC timezone. #[inline] fn convert_timestamp_micros_to_string(value: i64) -> String { - convert_timestamp_secs_to_string(value / 1000000) + let dt = Utc.timestamp_micros(value).unwrap(); + format!("{}", dt.format("%Y-%m-%d %H:%M:%S%.6f %:z")) } /// Helper method to convert Parquet time (milliseconds since midnight) into a string. 
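The reworked helpers above format timestamps in UTC and keep their sub-second precision instead of truncating to whole seconds. A standalone sketch of the same chrono formatting, not code from this change, reusing the epoch offsets exercised in the tests that follow:

```rust
use chrono::{TimeZone, Utc};

fn main() {
    // Milliseconds since the epoch, rendered with millisecond precision.
    let dt = Utc.timestamp_millis_opt(12_345_678).unwrap();
    assert_eq!(
        dt.format("%Y-%m-%d %H:%M:%S%.3f %:z").to_string(),
        "1970-01-01 03:25:45.678 +00:00"
    );

    // Microseconds since the epoch, rendered with microsecond precision.
    let dt = Utc.timestamp_micros(12_345_678_901).unwrap();
    assert_eq!(
        dt.format("%Y-%m-%d %H:%M:%S%.6f %:z").to_string(),
        "1970-01-01 03:25:45.678901 +00:00"
    );
}
```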
@@ -1278,44 +1271,75 @@ mod tests { #[test] fn test_convert_timestamp_millis_to_string() { - fn check_datetime_conversion(y: u32, m: u32, d: u32, h: u32, mi: u32, s: u32) { + fn check_datetime_conversion( + (y, m, d, h, mi, s, milli): (u32, u32, u32, u32, u32, u32, u32), + exp: &str, + ) { let datetime = chrono::NaiveDate::from_ymd_opt(y as i32, m, d) .unwrap() - .and_hms_opt(h, mi, s) + .and_hms_milli_opt(h, mi, s, milli) .unwrap(); let dt = Utc.from_utc_datetime(&datetime); let res = convert_timestamp_millis_to_string(dt.timestamp_millis()); - let exp = format!("{}", dt.format("%Y-%m-%d %H:%M:%S %:z")); assert_eq!(res, exp); } - check_datetime_conversion(1969, 9, 10, 1, 2, 3); - check_datetime_conversion(2010, 1, 2, 13, 12, 54); - check_datetime_conversion(2011, 1, 3, 8, 23, 1); - check_datetime_conversion(2012, 4, 5, 11, 6, 32); - check_datetime_conversion(2013, 5, 12, 16, 38, 0); - check_datetime_conversion(2014, 11, 28, 21, 15, 12); + check_datetime_conversion((1969, 9, 10, 1, 2, 3, 4), "1969-09-10 01:02:03.004 +00:00"); + check_datetime_conversion( + (2010, 1, 2, 13, 12, 54, 42), + "2010-01-02 13:12:54.042 +00:00", + ); + check_datetime_conversion((2011, 1, 3, 8, 23, 1, 27), "2011-01-03 08:23:01.027 +00:00"); + check_datetime_conversion((2012, 4, 5, 11, 6, 32, 0), "2012-04-05 11:06:32.000 +00:00"); + check_datetime_conversion( + (2013, 5, 12, 16, 38, 0, 15), + "2013-05-12 16:38:00.015 +00:00", + ); + check_datetime_conversion( + (2014, 11, 28, 21, 15, 12, 59), + "2014-11-28 21:15:12.059 +00:00", + ); } #[test] fn test_convert_timestamp_micros_to_string() { - fn check_datetime_conversion(y: u32, m: u32, d: u32, h: u32, mi: u32, s: u32) { + fn check_datetime_conversion( + (y, m, d, h, mi, s, micro): (u32, u32, u32, u32, u32, u32, u32), + exp: &str, + ) { let datetime = chrono::NaiveDate::from_ymd_opt(y as i32, m, d) .unwrap() - .and_hms_opt(h, mi, s) + .and_hms_micro_opt(h, mi, s, micro) .unwrap(); let dt = Utc.from_utc_datetime(&datetime); let res = convert_timestamp_micros_to_string(dt.timestamp_micros()); - let exp = format!("{}", dt.format("%Y-%m-%d %H:%M:%S %:z")); assert_eq!(res, exp); } - check_datetime_conversion(1969, 9, 10, 1, 2, 3); - check_datetime_conversion(2010, 1, 2, 13, 12, 54); - check_datetime_conversion(2011, 1, 3, 8, 23, 1); - check_datetime_conversion(2012, 4, 5, 11, 6, 32); - check_datetime_conversion(2013, 5, 12, 16, 38, 0); - check_datetime_conversion(2014, 11, 28, 21, 15, 12); + check_datetime_conversion( + (1969, 9, 10, 1, 2, 3, 4), + "1969-09-10 01:02:03.000004 +00:00", + ); + check_datetime_conversion( + (2010, 1, 2, 13, 12, 54, 42), + "2010-01-02 13:12:54.000042 +00:00", + ); + check_datetime_conversion( + (2011, 1, 3, 8, 23, 1, 27), + "2011-01-03 08:23:01.000027 +00:00", + ); + check_datetime_conversion( + (2012, 4, 5, 11, 6, 32, 0), + "2012-04-05 11:06:32.000000 +00:00", + ); + check_datetime_conversion( + (2013, 5, 12, 16, 38, 0, 15), + "2013-05-12 16:38:00.000015 +00:00", + ); + check_datetime_conversion( + (2014, 11, 28, 21, 15, 12, 59), + "2014-11-28 21:15:12.000059 +00:00", + ); } #[test] @@ -2000,11 +2024,11 @@ mod tests { ); assert_eq!( Field::TimestampMillis(12345678).to_json_value(), - Value::String("1970-01-01 03:25:45 +00:00".to_string()) + Value::String("1970-01-01 03:25:45.678 +00:00".to_string()) ); assert_eq!( Field::TimestampMicros(12345678901).to_json_value(), - Value::String(convert_timestamp_micros_to_string(12345678901)) + Value::String("1970-01-01 03:25:45.678901 +00:00".to_string()) ); assert_eq!( 
Field::TimeMillis(47445123).to_json_value(), diff --git a/parquet/src/thrift.rs b/parquet/src/thrift.rs index fc391abe87d7..e16e394be2bb 100644 --- a/parquet/src/thrift.rs +++ b/parquet/src/thrift.rs @@ -33,12 +33,20 @@ pub trait TSerializable: Sized { fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()>; } -/// Public function to aid benchmarking. +// Public function to aid benchmarking. Reads Parquet `FileMetaData` encoded in `bytes`. +#[doc(hidden)] pub fn bench_file_metadata(bytes: &bytes::Bytes) { let mut input = TCompactSliceInputProtocol::new(bytes); crate::format::FileMetaData::read_from_in_protocol(&mut input).unwrap(); } +// Public function to aid benchmarking. Reads Parquet `PageHeader` encoded in `bytes`. +#[doc(hidden)] +pub fn bench_page_header(bytes: &bytes::Bytes) { + let mut prot = TCompactSliceInputProtocol::new(bytes); + crate::format::PageHeader::read_from_in_protocol(&mut prot).unwrap(); +} + /// A more performant implementation of [`TCompactInputProtocol`] that reads a slice /// /// [`TCompactInputProtocol`]: thrift::protocol::TCompactInputProtocol diff --git a/parquet/src/util/mod.rs b/parquet/src/util/mod.rs index 1431132473e9..145cdd693e59 100644 --- a/parquet/src/util/mod.rs +++ b/parquet/src/util/mod.rs @@ -20,6 +20,7 @@ pub mod bit_util; mod bit_pack; pub(crate) mod interner; +pub mod push_buffers; #[cfg(any(test, feature = "test_common"))] pub(crate) mod test_common; pub mod utf8; diff --git a/parquet/src/util/push_buffers.rs b/parquet/src/util/push_buffers.rs new file mode 100644 index 000000000000..b30f91a81b70 --- /dev/null +++ b/parquet/src/util/push_buffers.rs @@ -0,0 +1,197 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::errors::ParquetError; +use crate::file::reader::{ChunkReader, Length}; +use bytes::Bytes; +use std::fmt::Display; +use std::ops::Range; + +/// Holds multiple buffers of data +/// +/// This is the in-memory buffer for the ParquetDecoder and ParquetMetadataDecoders +/// +/// Features: +/// 1. Zero copy +/// 2. non contiguous ranges of bytes +/// +/// # Non Coalescing +/// +/// This buffer does not coalesce (merging adjacent ranges of bytes into a +/// single range). Coalescing at this level would require copying the data but +/// the caller may already have the needed data in a single buffer which would +/// require no copying. +/// +/// Thus, the implementation defers to the caller to coalesce subsequent requests +/// if desired. 
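Because `PushBuffers` never merges ranges itself, a caller that wants fewer, larger fetches has to coalesce before pushing. A hypothetical caller-side helper (not part of this patch) sketching what that coalescing could look like for the `Range<u64>` requests a decoder reports:

```rust
use std::ops::Range;

/// Merge adjacent or overlapping byte ranges so each merged range can be
/// fetched once and pushed into the buffer as a single `Bytes` value.
fn coalesce(mut ranges: Vec<Range<u64>>) -> Vec<Range<u64>> {
    ranges.sort_by_key(|r| r.start);
    let mut merged: Vec<Range<u64>> = Vec::new();
    for r in ranges {
        match merged.last_mut() {
            // the new range touches or overlaps the previous one: extend it
            Some(last) if r.start <= last.end => last.end = last.end.max(r.end),
            _ => merged.push(r),
        }
    }
    merged
}

fn main() {
    let merged = coalesce(vec![0..10, 10..20, 40..50]);
    assert_eq!(merged, vec![0..20, 40..50]);
}
```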
+#[derive(Debug, Clone)] +pub(crate) struct PushBuffers { + /// the virtual "offset" of this buffers (added to any request) + offset: u64, + /// The total length of the file being decoded + file_len: u64, + /// The ranges of data that are available for decoding (not adjusted for offset) + ranges: Vec>, + /// The buffers of data that can be used to decode the Parquet file + buffers: Vec, +} + +impl Display for PushBuffers { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!( + f, + "Buffers (offset: {}, file_len: {})", + self.offset, self.file_len + )?; + writeln!(f, "Available Ranges (w/ offset):")?; + for range in &self.ranges { + writeln!( + f, + " {}..{} ({}..{}): {} bytes", + range.start, + range.end, + range.start + self.offset, + range.end + self.offset, + range.end - range.start + )?; + } + + Ok(()) + } +} + +impl PushBuffers { + /// Create a new Buffers instance with the given file length + pub fn new(file_len: u64) -> Self { + Self { + offset: 0, + file_len, + ranges: Vec::new(), + buffers: Vec::new(), + } + } + + /// Push all the ranges and buffers + pub fn push_ranges(&mut self, ranges: Vec>, buffers: Vec) { + assert_eq!( + ranges.len(), + buffers.len(), + "Number of ranges must match number of buffers" + ); + for (range, buffer) in ranges.into_iter().zip(buffers.into_iter()) { + self.push_range(range, buffer); + } + } + + /// Push a new range and its associated buffer + pub fn push_range(&mut self, range: Range, buffer: Bytes) { + assert_eq!( + (range.end - range.start) as usize, + buffer.len(), + "Range length must match buffer length" + ); + self.ranges.push(range); + self.buffers.push(buffer); + } + + /// Returns true if the Buffers contains data for the given range + pub fn has_range(&self, range: &Range) -> bool { + self.ranges + .iter() + .any(|r| r.start <= range.start && r.end >= range.end) + } + + fn iter(&self) -> impl Iterator, &Bytes)> { + self.ranges.iter().zip(self.buffers.iter()) + } + + /// return the file length of the Parquet file being read + pub fn file_len(&self) -> u64 { + self.file_len + } + + /// Specify a new offset + pub fn with_offset(mut self, offset: u64) -> Self { + self.offset = offset; + self + } +} + +impl Length for PushBuffers { + fn len(&self) -> u64 { + self.file_len + } +} + +/// less efficient implementation of Read for Buffers +impl std::io::Read for PushBuffers { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + // Find the range that contains the start offset + let mut found = false; + for (range, data) in self.iter() { + if range.start <= self.offset && range.end >= self.offset + buf.len() as u64 { + // Found the range, figure out the starting offset in the buffer + let start_offset = (self.offset - range.start) as usize; + let end_offset = start_offset + buf.len(); + let slice = data.slice(start_offset..end_offset); + buf.copy_from_slice(slice.as_ref()); + found = true; + break; + } + } + if found { + // If we found the range, we can return the number of bytes read + // advance our offset + self.offset += buf.len() as u64; + Ok(buf.len()) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "No data available in Buffers", + )) + } + } +} + +impl ChunkReader for PushBuffers { + type T = Self; + + fn get_read(&self, start: u64) -> Result { + Ok(self.clone().with_offset(self.offset + start)) + } + + fn get_bytes(&self, start: u64, length: usize) -> Result { + if start > self.file_len { + return Err(ParquetError::General(format!( + "Requested start {start} is beyond the end 
of the file (file length: {})", + self.file_len + ))); + } + + // find the range that contains the start offset + for (range, data) in self.iter() { + if range.start <= start && range.end >= start + length as u64 { + // Found the range, figure out the starting offset in the buffer + let start_offset = (start - range.start) as usize; + return Ok(data.slice(start_offset..start_offset + length)); + } + } + // Signal that we need more data + let requested_end = start + length as u64; + Err(ParquetError::NeedMoreDataRange(start..requested_end)) + } +} diff --git a/parquet/src/variant.rs b/parquet/src/variant.rs new file mode 100644 index 000000000000..b5902c02ed8e --- /dev/null +++ b/parquet/src/variant.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! ⚠️ Experimental Support for reading and writing [`Variant`]s to / from Parquet files ⚠️ +//! +//! This is a 🚧 Work In Progress +//! +//! Note: Requires the `variant_experimental` feature of the `parquet` crate to be enabled. +//! +//! # Features +//! * [`Variant`] represents variant value, which can be an object, list, or primitive. +//! * [`VariantBuilder`] for building `Variant` values. +//! * [`VariantArray`] for representing a column of Variant values. +//! * [`compute`] module with functions for manipulating Variants, such as +//! [`variant_get`] to extracting a value by path and functions to convert +//! between `Variant` and JSON. +//! +//! [Variant Logical Type]: Variant +//! [`VariantArray`]: compute::VariantArray +//! [`variant_get`]: compute::variant_get +//! +//! # Example: Writing a Parquet file with Variant column +//! ```rust +//! # use parquet::variant::compute::{VariantArray, VariantArrayBuilder}; +//! # use parquet::variant::VariantBuilderExt; +//! # use std::sync::Arc; +//! # use arrow_array::{ArrayRef, RecordBatch}; +//! # use parquet::arrow::ArrowWriter; +//! # fn main() -> Result<(), parquet::errors::ParquetError> { +//! // Use the VariantArrayBuilder to build a VariantArray +//! let mut builder = VariantArrayBuilder::new(3); +//! // row 1: {"name": "Alice"} +//! builder.new_object().with_field("name", "Alice").finish(); +//! let array = builder.build(); +//! +//! // TODO support writing VariantArray directly +//! // at the moment it panics when trying to downcast to a struct array +//! // https://github.com/apache/arrow-rs/issues/8296 +//! // let array: ArrayRef = Arc::new(array); +//! let array: ArrayRef = Arc::new(array.into_inner()); +//! +//! // create a RecordBatch with the VariantArray +//! let batch = RecordBatch::try_from_iter(vec![("data", array)])?; +//! +//! // write the RecordBatch to a Parquet file +//! let file = std::fs::File::create("variant.parquet")?; +//! let mut writer = ArrowWriter::try_new(file, batch.schema(), None)?; +//! 
writer.write(&batch)?; +//! writer.close()?; +//! +//! # std::fs::remove_file("variant.parquet")?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Example: Writing JSON with a Parquet file with Variant column +//! ```rust +//! # use std::sync::Arc; +//! # use arrow_array::{ArrayRef, RecordBatch, StringArray}; +//! # use parquet::variant::compute::json_to_variant; +//! # use parquet::variant::compute::VariantArray; +//! # use parquet::arrow::ArrowWriter; +//! # fn main() -> Result<(), parquet::errors::ParquetError> { +//! // Create an array of JSON strings, simulating a column of JSON data +//! // TODO use StringViewArray when available +//! let input_array = StringArray::from(vec![ +//! Some(r#"{"name": "Alice", "age": 30}"#), +//! Some(r#"{"name": "Bob", "age": 25, "address": {"city": "New York"}}"#), +//! None, +//! Some("{}"), +//! ]); +//! let input_array: ArrayRef = Arc::new(input_array); +//! +//! // Convert the JSON strings to a VariantArray +//! let array: VariantArray = json_to_variant(&input_array)?; +//! +//! // TODO support writing VariantArray directly +//! // at the moment it panics when trying to downcast to a struct array +//! // https://github.com/apache/arrow-rs/issues/8296 +//! // let array: ArrayRef = Arc::new(array); +//! let array: ArrayRef = Arc::new(array.into_inner()); +//! +//! // create a RecordBatch with the VariantArray +//! let batch = RecordBatch::try_from_iter(vec![("data", array)])?; +//! +//! // write the RecordBatch to a Parquet file +//! let file = std::fs::File::create("variant-json.parquet")?; +//! let mut writer = ArrowWriter::try_new(file, batch.schema(), None)?; +//! writer.write(&batch)?; +//! writer.close()?; +//! # std::fs::remove_file("variant-json.parquet")?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Example: Reading a Parquet file with Variant column +//! (TODO: add example) +pub use parquet_variant::*; +pub use parquet_variant_compute as compute; diff --git a/parquet/tests/arrow_reader/bad_data.rs b/parquet/tests/arrow_reader/bad_data.rs index ba50e738f6cf..c767115eaa7b 100644 --- a/parquet/tests/arrow_reader/bad_data.rs +++ b/parquet/tests/arrow_reader/bad_data.rs @@ -150,6 +150,7 @@ fn read_file(name: &str) -> Result { #[cfg(feature = "async")] #[tokio::test] +#[allow(deprecated)] async fn bad_metadata_err() { use bytes::Bytes; use parquet::file::metadata::ParquetMetaDataReader; diff --git a/parquet/tests/arrow_reader/io/async_reader.rs b/parquet/tests/arrow_reader/io/async_reader.rs new file mode 100644 index 000000000000..f2d3ce07234b --- /dev/null +++ b/parquet/tests/arrow_reader/io/async_reader.rs @@ -0,0 +1,430 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Tests for the async reader ([`ParquetRecordBatchStreamBuilder`]) + +use crate::io::{ + filter_a_175_b_625, filter_b_575_625, filter_b_false, test_file, test_options, LogEntry, + OperationLog, TestParquetFile, +}; +use bytes::Bytes; +use futures::future::BoxFuture; +use futures::{FutureExt, StreamExt}; +use parquet::arrow::arrow_reader::{ArrowReaderOptions, RowSelection, RowSelector}; +use parquet::arrow::async_reader::AsyncFileReader; +use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; +use parquet::errors::Result; +use parquet::file::metadata::ParquetMetaData; +use std::ops::Range; +use std::sync::Arc; + +#[tokio::test] +async fn test_read_entire_file() { + // read entire file without any filtering or projection + let test_file = test_file(); + let builder = async_builder(&test_file, test_options()).await; + insta::assert_debug_snapshot!(run( + &test_file, + builder).await, @r#" + [ + "Get Provided Metadata", + "Event: Builder Configured", + "Event: Reader Built", + "Read Multi:", + " Row Group 0, column 'a': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + " Row Group 0, column 'b': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + " Row Group 0, column 'c': MultiPage(dictionary_page: true, data_pages: [0, 1]) (7346 bytes, 1 requests) [data]", + "Read Multi:", + " Row Group 1, column 'a': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + " Row Group 1, column 'b': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + " Row Group 1, column 'c': MultiPage(dictionary_page: true, data_pages: [0, 1]) (7456 bytes, 1 requests) [data]", + ] + "#); +} + +#[tokio::test] +async fn test_read_single_group() { + let test_file = test_file(); + let builder = async_builder(&test_file, test_options()) + .await + // read only second row group + .with_row_groups(vec![1]); + + // Expect to see only IO for Row Group 1. Should see no IO for Row Group 0. + insta::assert_debug_snapshot!(run( + &test_file, + builder).await, @r#" + [ + "Get Provided Metadata", + "Event: Builder Configured", + "Event: Reader Built", + "Read Multi:", + " Row Group 1, column 'a': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + " Row Group 1, column 'b': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + " Row Group 1, column 'c': MultiPage(dictionary_page: true, data_pages: [0, 1]) (7456 bytes, 1 requests) [data]", + ] + "#); +} + +#[tokio::test] +async fn test_read_single_column() { + let test_file = test_file(); + let builder = async_builder(&test_file, test_options()).await; + let schema_descr = builder.metadata().file_metadata().schema_descr_ptr(); + let builder = builder.with_projection(ProjectionMask::columns(&schema_descr, ["b"])); + // Expect to see only IO for column "b". Should see no IO for columns "a" or "c". 
+ insta::assert_debug_snapshot!(run( + &test_file, + builder).await, @r#" + [ + "Get Provided Metadata", + "Event: Builder Configured", + "Event: Reader Built", + "Read Multi:", + " Row Group 0, column 'b': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + "Read Multi:", + " Row Group 1, column 'b': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + ] + "#); +} + +#[tokio::test] +async fn test_read_row_selection() { + // There are 400 total rows spread across 4 data pages (100 rows each) + // select rows 175..225 (i.e. DataPage(1) of row group 0 and DataPage(0) of row group 1) + let test_file = test_file(); + let builder = async_builder(&test_file, test_options()).await; + let schema_descr = builder.metadata().file_metadata().schema_descr_ptr(); + let builder = builder + .with_projection(ProjectionMask::columns(&schema_descr, ["a", "b"])) + .with_row_selection(RowSelection::from(vec![ + RowSelector::skip(175), + RowSelector::select(50), + ])); + + // Expect to see only data IO for one page for each column for each row group + insta::assert_debug_snapshot!(run( + &test_file, + builder).await, @r#" + [ + "Get Provided Metadata", + "Event: Builder Configured", + "Event: Reader Built", + "Read Multi:", + " Row Group 0, column 'a': DictionaryPage (1617 bytes, 1 requests) [data]", + " Row Group 0, column 'a': DataPage(1) (126 bytes , 1 requests) [data]", + " Row Group 0, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + " Row Group 0, column 'b': DataPage(1) (126 bytes , 1 requests) [data]", + "Read Multi:", + " Row Group 1, column 'a': DictionaryPage (1617 bytes, 1 requests) [data]", + " Row Group 1, column 'a': DataPage(0) (113 bytes , 1 requests) [data]", + " Row Group 1, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + " Row Group 1, column 'b': DataPage(0) (113 bytes , 1 requests) [data]", + ] + "#); +} + +#[tokio::test] +async fn test_read_limit() { + // There are 400 total rows spread across 4 data pages (100 rows each) + // a limit of 125 rows should only fetch the first two data pages (DataPage(0) and DataPage(1)) from row group 0 + let test_file = test_file(); + let builder = async_builder(&test_file, test_options()).await; + let schema_descr = builder.metadata().file_metadata().schema_descr_ptr(); + let builder = builder + .with_projection(ProjectionMask::columns(&schema_descr, ["a"])) + .with_limit(125); + + insta::assert_debug_snapshot!(run( + &test_file, + builder).await, @r#" + [ + "Get Provided Metadata", + "Event: Builder Configured", + "Event: Reader Built", + "Read Multi:", + " Row Group 0, column 'a': DictionaryPage (1617 bytes, 1 requests) [data]", + " Row Group 0, column 'a': DataPage(0) (113 bytes , 1 requests) [data]", + " Row Group 0, column 'a': DataPage(1) (126 bytes , 1 requests) [data]", + ] + "#); +} + +#[tokio::test] +async fn test_read_single_row_filter() { + // Values from column "b" range 400..799 + // filter "b" > 575 and < than 625 + // (last data page in Row Group 0 and first DataPage in Row Group 1) + let test_file = test_file(); + let builder = async_builder(&test_file, test_options()).await; + let schema_descr = builder.metadata().file_metadata().schema_descr_ptr(); + + let builder = builder + .with_projection(ProjectionMask::columns(&schema_descr, ["a", "b"])) + .with_row_filter(filter_b_575_625(&schema_descr)); + + // Expect to see I/O for column b in both row groups to evaluate filter, + // then a single pages for the "a" column in each row group + 
insta::assert_debug_snapshot!(run( + &test_file, + builder).await, @r#" + [ + "Get Provided Metadata", + "Event: Builder Configured", + "Event: Reader Built", + "Read Multi:", + " Row Group 0, column 'b': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + "Read Multi:", + " Row Group 0, column 'a': DictionaryPage (1617 bytes, 1 requests) [data]", + " Row Group 0, column 'a': DataPage(1) (126 bytes , 1 requests) [data]", + "Read Multi:", + " Row Group 1, column 'b': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + "Read Multi:", + " Row Group 1, column 'a': DictionaryPage (1617 bytes, 1 requests) [data]", + " Row Group 1, column 'a': DataPage(0) (113 bytes , 1 requests) [data]", + ] + "#); +} + +#[tokio::test] +async fn test_read_single_row_filter_no_page_index() { + // Values from column "b" range 400..799 + // Apply a filter "b" > 575 and than 625 + // (last data page in Row Group 0 and first DataPage in Row Group 1) + let test_file = test_file(); + let options = test_options().with_page_index(false); + let builder = async_builder(&test_file, options).await; + let schema_descr = builder.metadata().file_metadata().schema_descr_ptr(); + + let builder = builder + .with_projection(ProjectionMask::columns(&schema_descr, ["a", "b"])) + .with_row_filter(filter_b_575_625(&schema_descr)); + + // Since we don't have the page index, expect to see: + // 1. I/O for all pages of column b to evaluate the filter + // 2. IO for all pages of column a as the reader doesn't know where the page + // boundaries are so needs to scan them. + insta::assert_debug_snapshot!(run( + &test_file, + builder).await, @r#" + [ + "Get Provided Metadata", + "Event: Builder Configured", + "Event: Reader Built", + "Read Multi:", + " Row Group 0, column 'b': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + "Read Multi:", + " Row Group 0, column 'a': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + "Read Multi:", + " Row Group 1, column 'b': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + "Read Multi:", + " Row Group 1, column 'a': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + ] + "#); +} + +#[tokio::test] +async fn test_read_multiple_row_filter() { + // Values in column "a" range 0..399 + // Values in column "b" range 400..799 + // First filter: "a" > 175 (last data page in Row Group 0) + // Second filter: "b" < 625 (last data page in Row Group 0 and first DataPage in RowGroup 1) + // Read column "c" + let test_file = test_file(); + let builder = async_builder(&test_file, test_options()).await; + let schema_descr = builder.metadata().file_metadata().schema_descr_ptr(); + + let builder = builder + .with_projection(ProjectionMask::columns(&schema_descr, ["c"])) + .with_row_filter(filter_a_175_b_625(&schema_descr)); + + // Expect that we will see + // 1. IO for all pages of column A (to evaluate the first filter) + // 2. IO for pages of column b that passed the first filter (to evaluate the second filter) + // 3. 
IO after reader is built only for column c for the rows that passed both filters + insta::assert_debug_snapshot!(run( + &test_file, + builder).await, @r#" + [ + "Get Provided Metadata", + "Event: Builder Configured", + "Event: Reader Built", + "Read Multi:", + " Row Group 0, column 'a': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + "Read Multi:", + " Row Group 0, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + " Row Group 0, column 'b': DataPage(1) (126 bytes , 1 requests) [data]", + "Read Multi:", + " Row Group 0, column 'c': DictionaryPage (7107 bytes, 1 requests) [data]", + " Row Group 0, column 'c': DataPage(1) (126 bytes , 1 requests) [data]", + "Read Multi:", + " Row Group 1, column 'a': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + "Read Multi:", + " Row Group 1, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + " Row Group 1, column 'b': DataPage(0) (113 bytes , 1 requests) [data]", + " Row Group 1, column 'b': DataPage(1) (126 bytes , 1 requests) [data]", + "Read Multi:", + " Row Group 1, column 'c': DictionaryPage (7217 bytes, 1 requests) [data]", + " Row Group 1, column 'c': DataPage(0) (113 bytes , 1 requests) [data]", + ] + "#); +} + +#[tokio::test] +async fn test_read_single_row_filter_all() { + // Apply a filter that filters out all rows + + let test_file = test_file(); + let builder = async_builder(&test_file, test_options()).await; + let schema_descr = builder.metadata().file_metadata().schema_descr_ptr(); + + let builder = builder + .with_projection(ProjectionMask::columns(&schema_descr, ["a", "b"])) + .with_row_filter(filter_b_false(&schema_descr)); + + // Expect to see reads for column "b" to evaluate the filter, but no reads + // for column "a" as no rows pass the filter + insta::assert_debug_snapshot!(run( + &test_file, + builder).await, @r#" + [ + "Get Provided Metadata", + "Event: Builder Configured", + "Event: Reader Built", + "Read Multi:", + " Row Group 0, column 'b': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + "Read Multi:", + " Row Group 1, column 'b': MultiPage(dictionary_page: true, data_pages: [0, 1]) (1856 bytes, 1 requests) [data]", + ] + "#); +} + +/// Return a [`ParquetRecordBatchStreamBuilder`] for reading this file +async fn async_builder( + test_file: &TestParquetFile, + options: ArrowReaderOptions, +) -> ParquetRecordBatchStreamBuilder { + let parquet_meta_data = if options.page_index() { + Arc::clone(test_file.parquet_metadata()) + } else { + // strip out the page index from the metadata + let metadata = test_file + .parquet_metadata() + .as_ref() + .clone() + .into_builder() + .set_column_index(None) + .set_offset_index(None) + .build(); + Arc::new(metadata) + }; + + let reader = RecordingAsyncFileReader { + bytes: test_file.bytes().clone(), + ops: Arc::clone(test_file.ops()), + parquet_meta_data, + }; + + ParquetRecordBatchStreamBuilder::new_with_options(reader, options) + .await + .unwrap() +} + +/// Build the reader from the specified builder and read all batches from it, +/// and return the operations log. 
+async fn run( + test_file: &TestParquetFile, + builder: ParquetRecordBatchStreamBuilder, +) -> Vec { + let ops = test_file.ops(); + ops.add_entry(LogEntry::event("Builder Configured")); + let mut stream = builder.build().unwrap(); + ops.add_entry(LogEntry::event("Reader Built")); + while let Some(batch) = stream.next().await { + match batch { + Ok(_) => {} + Err(e) => panic!("Error reading batch: {e}"), + } + } + ops.snapshot() +} + +struct RecordingAsyncFileReader { + bytes: Bytes, + ops: Arc, + parquet_meta_data: Arc, +} + +impl AsyncFileReader for RecordingAsyncFileReader { + fn get_bytes(&mut self, range: Range) -> BoxFuture<'_, parquet::errors::Result> { + let ops = Arc::clone(&self.ops); + let data = self + .bytes + .slice(range.start as usize..range.end as usize) + .clone(); + + // translate to usize from u64 + let logged_range = Range { + start: range.start as usize, + end: range.end as usize, + }; + async move { + ops.add_entry_for_range(&logged_range); + Ok(data) + } + .boxed() + } + + fn get_byte_ranges(&mut self, ranges: Vec>) -> BoxFuture<'_, Result>> { + let ops = Arc::clone(&self.ops); + let datas = ranges + .iter() + .map(|range| { + self.bytes + .slice(range.start as usize..range.end as usize) + .clone() + }) + .collect::>(); + // translate to usize from u64 + let logged_ranges = ranges + .into_iter() + .map(|r| Range { + start: r.start as usize, + end: r.end as usize, + }) + .collect::>(); + + async move { + ops.add_entry_for_ranges(&logged_ranges); + Ok(datas) + } + .boxed() + } + + fn get_metadata<'a>( + &'a mut self, + _options: Option<&'a ArrowReaderOptions>, + ) -> BoxFuture<'a, Result>> { + let ops = Arc::clone(&self.ops); + let parquet_meta_data = Arc::clone(&self.parquet_meta_data); + async move { + ops.add_entry(LogEntry::GetProvidedMetadata); + Ok(parquet_meta_data) + } + .boxed() + } +} diff --git a/parquet/tests/arrow_reader/io/mod.rs b/parquet/tests/arrow_reader/io/mod.rs new file mode 100644 index 000000000000..b31f295755b0 --- /dev/null +++ b/parquet/tests/arrow_reader/io/mod.rs @@ -0,0 +1,703 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for IO read patterns in the Parquet Reader +//! +//! Each test: +//! 1. Creates a temporary Parquet file with a known row group structure +//! 2. Reads data from that file using the Arrow Parquet Reader, recording the IO operations +//! 3. Asserts the expected IO patterns based on the read operations +//! +//! Note this module contains test infrastructure only. The actual tests are in the +//! sub-modules [`sync_reader`] and [`async_reader`]. +//! +//! Key components: +//! - [`TestParquetFile`] - Represents a Parquet file and its layout +//! - [`OperationLog`] - Records IO operations performed on the file +//! 
- [`LogEntry`] - Represents a single IO operation in the log + +mod sync_reader; + +#[cfg(feature = "async")] +mod async_reader; + +use arrow::compute::and; +use arrow::compute::kernels::cmp::{gt, lt}; +use arrow_array::cast::AsArray; +use arrow_array::types::Int64Type; +use arrow_array::{ArrayRef, BooleanArray, Int64Array, RecordBatch, StringViewArray}; +use bytes::Bytes; +use parquet::arrow::arrow_reader::{ + ArrowPredicateFn, ArrowReaderOptions, ParquetRecordBatchReaderBuilder, RowFilter, +}; +use parquet::arrow::{ArrowWriter, ProjectionMask}; +use parquet::data_type::AsBytes; +use parquet::file::metadata::{ParquetMetaData, ParquetMetaDataReader, ParquetOffsetIndex}; +use parquet::file::properties::WriterProperties; +use parquet::file::FOOTER_SIZE; +use parquet::format::PageLocation; +use parquet::schema::types::SchemaDescriptor; +use std::collections::BTreeMap; +use std::fmt::Display; +use std::ops::Range; +use std::sync::{Arc, LazyLock, Mutex}; + +/// Create a new `TestParquetFile` with: +/// 3 columns: "a", "b", "c" +/// +/// 2 row groups, each with 200 rows +/// each data page has 100 rows +/// +/// Values of column "a" are 0..399 +/// Values of column "b" are 400..799 +/// Values of column "c" are alternating strings of length 12 and longer +fn test_file() -> TestParquetFile { + TestParquetFile::new(TEST_FILE_DATA.clone()) +} + +/// Default options for tests +/// +/// Note these tests use the PageIndex to reduce IO +fn test_options() -> ArrowReaderOptions { + ArrowReaderOptions::default().with_page_index(true) +} + +/// Return a row filter that evaluates "b > 575" AND "b < 625" +/// +/// last data page in Row Group 0 and first DataPage in Row Group 1 +fn filter_b_575_625(schema_descr: &SchemaDescriptor) -> RowFilter { + // "b" > 575 and "b" < 625 + let predicate = ArrowPredicateFn::new( + ProjectionMask::columns(schema_descr, ["b"]), + |batch: RecordBatch| { + let scalar_575 = Int64Array::new_scalar(575); + let scalar_625 = Int64Array::new_scalar(625); + let column = batch.column(0).as_primitive::(); + and(>(column, &scalar_575)?, <(column, &scalar_625)?) + }, + ); + RowFilter::new(vec![Box::new(predicate)]) +} + +/// Filter a > 175 and b < 625 +/// First filter: "a" > 175 (last data page in Row Group 0) +/// Second filter: "b" < 625 (last data page in Row Group 0 and first DataPage in RowGroup 1) +fn filter_a_175_b_625(schema_descr: &SchemaDescriptor) -> RowFilter { + // "a" > 175 and "b" < 625 + let predicate_a = ArrowPredicateFn::new( + ProjectionMask::columns(schema_descr, ["a"]), + |batch: RecordBatch| { + let scalar_175 = Int64Array::new_scalar(175); + let column = batch.column(0).as_primitive::(); + gt(column, &scalar_175) + }, + ); + + let predicate_b = ArrowPredicateFn::new( + ProjectionMask::columns(schema_descr, ["b"]), + |batch: RecordBatch| { + let scalar_625 = Int64Array::new_scalar(625); + let column = batch.column(0).as_primitive::(); + lt(column, &scalar_625) + }, + ); + + RowFilter::new(vec![Box::new(predicate_a), Box::new(predicate_b)]) +} + +/// Filter FALSE (no rows) with b +/// Entirely filters out both row groups +/// Note it selects "b" +fn filter_b_false(schema_descr: &SchemaDescriptor) -> RowFilter { + // "false" + let predicate = ArrowPredicateFn::new( + ProjectionMask::columns(schema_descr, ["b"]), + |batch: RecordBatch| { + let result = + BooleanArray::from_iter(std::iter::repeat_n(Some(false), batch.num_rows())); + Ok(result) + }, + ); + RowFilter::new(vec![Box::new(predicate)]) +} + +/// Create a parquet file in memory for testing. 
See [`test_file`] for details. +static TEST_FILE_DATA: LazyLock = LazyLock::new(|| { + // Input batch has 400 rows, with 3 columns: "a", "b", "c" + // Note c is a different types (so the data page sizes will be different) + let a: ArrayRef = Arc::new(Int64Array::from_iter_values(0..400)); + let b: ArrayRef = Arc::new(Int64Array::from_iter_values(400..800)); + let c: ArrayRef = Arc::new(StringViewArray::from_iter_values((0..400).map(|i| { + if i % 2 == 0 { + format!("string_{i}") + } else { + format!("A string larger than 12 bytes and thus not inlined {i}") + } + }))); + + let input_batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + let mut output = Vec::new(); + + let writer_options = WriterProperties::builder() + .set_max_row_group_size(200) + .set_data_page_row_count_limit(100) + .build(); + let mut writer = + ArrowWriter::try_new(&mut output, input_batch.schema(), Some(writer_options)).unwrap(); + + // since the limits are only enforced on batch boundaries, write the input + // batch in chunks of 50 + let mut row_remain = input_batch.num_rows(); + while row_remain > 0 { + let chunk_size = row_remain.min(50); + let chunk = input_batch.slice(input_batch.num_rows() - row_remain, chunk_size); + writer.write(&chunk).unwrap(); + row_remain -= chunk_size; + } + writer.close().unwrap(); + Bytes::from(output) +}); + +/// A test parquet file and its layout. +struct TestParquetFile { + bytes: Bytes, + /// The operation log for IO operations performed on this file + ops: Arc, + /// The (pre-parsed) parquet metadata for this file + parquet_metadata: Arc, +} + +impl TestParquetFile { + /// Create a new `TestParquetFile` with the specified temporary directory and path + /// and determines the row group layout. + fn new(bytes: Bytes) -> Self { + // Read the parquet file to determine its layout + let builder = ParquetRecordBatchReaderBuilder::try_new_with_options( + bytes.clone(), + ArrowReaderOptions::default().with_page_index(true), + ) + .unwrap(); + + let parquet_metadata = Arc::clone(builder.metadata()); + + let offset_index = parquet_metadata + .offset_index() + .expect("Parquet metadata should have a page index"); + + let row_groups = TestRowGroups::new(&parquet_metadata, offset_index); + + // figure out the footer location in the file + let footer_location = bytes.len() - FOOTER_SIZE..bytes.len(); + let footer = bytes.slice(footer_location.clone()); + let footer: &[u8; FOOTER_SIZE] = footer + .as_bytes() + .try_into() // convert to a fixed size array + .unwrap(); + + // figure out the metadata location + let footer = ParquetMetaDataReader::decode_footer_tail(footer).unwrap(); + let metadata_len = footer.metadata_length(); + let metadata_location = footer_location.start - metadata_len..footer_location.start; + + let ops = Arc::new(OperationLog::new( + footer_location, + metadata_location, + row_groups, + )); + + TestParquetFile { + bytes, + ops, + parquet_metadata, + } + } + + /// Return the internal bytes of the parquet file + fn bytes(&self) -> &Bytes { + &self.bytes + } + + /// Return the operation log for this file + fn ops(&self) -> &Arc { + &self.ops + } + + /// Return the parquet metadata for this file + fn parquet_metadata(&self) -> &Arc { + &self.parquet_metadata + } +} + +/// Information about a column chunk +#[derive(Debug)] +struct TestColumnChunk { + /// The name of the column + name: String, + + /// The location of the entire column chunk in the file including dictionary pages + /// and data pages. 
+ location: Range, + + /// The offset of the start of of the dictionary page if any + dictionary_page_location: Option, + + /// The location of the data pages in the file + page_locations: Vec, +} + +/// Information about the pages in a single row group +#[derive(Debug)] +struct TestRowGroup { + /// Maps column_name -> Information about the column chunk + columns: BTreeMap, +} + +/// Information about all the row groups in a Parquet file, extracted from its metadata +#[derive(Debug)] +struct TestRowGroups { + /// List of row groups, each containing information about its columns and page locations + row_groups: Vec, +} + +impl TestRowGroups { + fn new(parquet_metadata: &ParquetMetaData, offset_index: &ParquetOffsetIndex) -> Self { + let row_groups = parquet_metadata + .row_groups() + .iter() + .enumerate() + .map(|(rg_index, rg_meta)| { + let columns = rg_meta + .columns() + .iter() + .enumerate() + .map(|(col_idx, col_meta)| { + let column_name = col_meta.column_descr().name().to_string(); + let page_locations = + offset_index[rg_index][col_idx].page_locations().to_vec(); + let dictionary_page_location = col_meta.dictionary_page_offset(); + + // We can find the byte range of the entire column chunk + let (start_offset, length) = col_meta.byte_range(); + let start_offset = start_offset as usize; + let end_offset = start_offset + length as usize; + + TestColumnChunk { + name: column_name.clone(), + location: start_offset..end_offset, + dictionary_page_location, + page_locations, + } + }) + .map(|test_column_chunk| { + // make key=value pairs to insert into the BTreeMap + (test_column_chunk.name.clone(), test_column_chunk) + }) + .collect::>(); + TestRowGroup { columns } + }) + .collect(); + + Self { row_groups } + } + + fn iter(&self) -> impl Iterator { + self.row_groups.iter() + } +} + +/// Type of data read +#[derive(Debug, PartialEq)] +enum PageType { + /// The data page with the specified index + Data { + data_page_index: usize, + }, + Dictionary, + /// Multiple pages read together + Multi { + /// Was the dictionary page included? 
+ dictionary_page: bool, + /// The data pages included + data_page_indices: Vec, + }, +} + +impl Display for PageType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PageType::Data { data_page_index } => { + write!(f, "DataPage({data_page_index})") + } + PageType::Dictionary => write!(f, "DictionaryPage"), + PageType::Multi { + dictionary_page, + data_page_indices, + } => { + let dictionary_page = if *dictionary_page { + "dictionary_page: true, " + } else { + "" + }; + write!( + f, + "MultiPage({dictionary_page}data_pages: {data_page_indices:?})", + ) + } + } + } +} + +/// Read single logical data object (data page or dictionary page) +/// in one or more requests +#[derive(Debug)] +struct ReadInfo { + row_group_index: usize, + column_name: String, + range: Range, + read_type: PageType, + /// Number of distinct requests (function calls) that were used + num_requests: usize, +} + +impl Display for ReadInfo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let Self { + row_group_index, + column_name, + range, + read_type, + num_requests, + } = self; + + // If the average read size is less than 10 bytes, assume it is the thrift + // decoder reading the page headers and add an annotation + let annotation = if (range.len() / num_requests) < 10 { + " [header]" + } else { + " [data]" + }; + + // align the read type to 20 characters for better readability, not sure why + // this does not work inline with write! macro below + write!( + f, + "Row Group {row_group_index}, column '{column_name}': {:15} ({:10}, {:8}){annotation}", + // convert to strings so alignment works + format!("{read_type}"), + format!("{} bytes", range.len()), + format!("{num_requests} requests"), + ) + } +} + +/// Store structured entries in the log to make it easier to combine multiple entries +#[derive(Debug)] +enum LogEntry { + /// Read the footer (last 8 bytes) of the parquet file + ReadFooter(Range), + /// Read the metadata of the parquet file + ReadMetadata(Range), + /// Access previously parsed metadata + GetProvidedMetadata, + /// Read a single logical data object + ReadData(ReadInfo), + /// Read one or more logical data objects in a single operation + ReadMultipleData(Vec), + /// Not known where the read came from + Unknown(Range), + /// A user defined event + Event(String), +} + +impl LogEntry { + fn event(event: impl Into) -> Self { + LogEntry::Event(event.into()) + } + + /// Appends a string representation of this log entry to the output vector + fn append_string(&self, output: &mut Vec, indent: usize) { + let indent_str = " ".repeat(indent); + match self { + LogEntry::ReadFooter(range) => { + output.push(format!("{indent_str}Footer: {} bytes", range.len())) + } + LogEntry::ReadMetadata(range) => { + output.push(format!("{indent_str}Metadata: {}", range.len())) + } + LogEntry::GetProvidedMetadata => { + output.push(format!("{indent_str}Get Provided Metadata")) + } + LogEntry::ReadData(read_info) => output.push(format!("{indent_str}{read_info}")), + LogEntry::ReadMultipleData(read_infos) => { + output.push(format!("{indent_str}Read Multi:")); + for read_info in read_infos { + let new_indent = indent + 2; + read_info.append_string(output, new_indent); + } + } + LogEntry::Unknown(range) => { + output.push(format!("{indent_str}UNKNOWN: {range:?} (maybe Page Index)")) + } + LogEntry::Event(event) => output.push(format!("Event: {event}")), + } + } +} + +#[derive(Debug)] +struct OperationLog { + /// The operations performed on the file + ops: Mutex>, + + /// 
Footer location in the parquet file + footer_location: Range, + + /// Metadata location in the parquet file + metadata_location: Range, + + /// Information about the row group layout in the parquet file, used to + /// translate read operations into human understandable IO operations + /// Path to the parquet file + row_groups: TestRowGroups, +} + +impl OperationLog { + fn new( + footer_location: Range, + metadata_location: Range, + row_groups: TestRowGroups, + ) -> Self { + OperationLog { + ops: Mutex::new(Vec::new()), + metadata_location, + footer_location, + row_groups, + } + } + + /// Add an operation to the log + fn add_entry(&self, entry: LogEntry) { + let mut ops = self.ops.lock().unwrap(); + ops.push(entry); + } + + /// Adds an entry to the operation log for the interesting object that is + /// accessed by the specified range + /// + /// This function checks the ranges in order against possible locations + /// and adds the appropriate operation to the log for the first match found. + fn add_entry_for_range(&self, range: &Range) { + self.add_entry(self.entry_for_range(range)); + } + + /// Adds entries to the operation log for each interesting object that is + /// accessed by the specified range + /// + /// It behaves the same as [`add_entry_for_range`] but for multiple ranges. + fn add_entry_for_ranges<'a>(&self, ranges: impl IntoIterator>) { + let entries = ranges + .into_iter() + .map(|range| self.entry_for_range(range)) + .collect::>(); + self.add_entry(LogEntry::ReadMultipleData(entries)); + } + + /// Create an appropriate LogEntry for the specified range + fn entry_for_range(&self, range: &Range) -> LogEntry { + let start = range.start as i64; + let end = range.end as i64; + + // figure out what logical part of the file this range corresponds to + if self.metadata_location.contains(&range.start) + || self.metadata_location.contains(&(range.end - 1)) + { + return LogEntry::ReadMetadata(range.clone()); + } + + if self.footer_location.contains(&range.start) + || self.footer_location.contains(&(range.end - 1)) + { + return LogEntry::ReadFooter(range.clone()); + } + + // Search for the location in each column chunk. + // + // The actual parquet reader must in general decode the page headers + // and determine the byte ranges of the pages. However, for this test + // we assume the following layout: + // + // ```text + // (Dictionary Page) + // (Data Page) + // ... + // (Data Page) + // ``` + // + // We also assume that `self.page_locations` holds the location of all + // data pages, so any read operation that overlaps with a data page + // location is considered a read of that page, and any other read must + // be a dictionary page read. + for (row_group_index, row_group) in self.row_groups.iter().enumerate() { + for (column_name, test_column_chunk) in &row_group.columns { + // Check if the range overlaps with any data page locations + let page_locations = test_column_chunk.page_locations.iter(); + + // What data pages does this range overlap with? 
+ let mut data_page_indices = vec![]; + + for (data_page_index, page_location) in page_locations.enumerate() { + let page_offset = page_location.offset; + let page_end = page_offset + page_location.compressed_page_size as i64; + + // if the range fully contains the page, consider it a read of that page + if start >= page_offset && end <= page_end { + let read_info = ReadInfo { + row_group_index, + column_name: column_name.clone(), + range: range.clone(), + read_type: PageType::Data { data_page_index }, + num_requests: 1, + }; + return LogEntry::ReadData(read_info); + } + + // if the range overlaps with the page, add it to the list of overlapping pages + if start < page_end && end > page_offset { + data_page_indices.push(data_page_index); + } + } + + // was the dictionary page read? + let mut dictionary_page = false; + + // Check if the range overlaps with the dictionary page location + if let Some(dict_page_offset) = test_column_chunk.dictionary_page_location { + let dict_page_end = dict_page_offset + test_column_chunk.location.len() as i64; + if start >= dict_page_offset && end < dict_page_end { + let read_info = ReadInfo { + row_group_index, + column_name: column_name.clone(), + range: range.clone(), + read_type: PageType::Dictionary, + num_requests: 1, + }; + + return LogEntry::ReadData(read_info); + } + + // if the range overlaps with the dictionary page, add it to the list of overlapping pages + if start < dict_page_end && end > dict_page_offset { + dictionary_page = true; + } + } + + // If we can't find a page, but the range overlaps with the + // column chunk location, use the column chunk location + let column_byte_range = &test_column_chunk.location; + if column_byte_range.contains(&range.start) + && column_byte_range.contains(&(range.end - 1)) + { + let read_data_entry = ReadInfo { + row_group_index, + column_name: column_name.clone(), + range: range.clone(), + read_type: PageType::Multi { + data_page_indices, + dictionary_page, + }, + num_requests: 1, + }; + + return LogEntry::ReadData(read_data_entry); + } + } + } + + // If we reach here, the range does not match any known logical part of the file + LogEntry::Unknown(range.clone()) + } + + // Combine entries in the log that are similar to reduce noise in the log. + fn coalesce_entries(&self) { + let mut ops = self.ops.lock().unwrap(); + + // Coalesce entries with the same read type + let prev_ops = std::mem::take(&mut *ops); + for entry in prev_ops { + let Some(last) = ops.last_mut() else { + ops.push(entry); + continue; + }; + + let LogEntry::ReadData(ReadInfo { + row_group_index: last_rg_index, + column_name: last_column_name, + range: last_range, + read_type: last_read_type, + num_requests: last_num_reads, + }) = last + else { + // If the last entry is not a ReadColumnChunk, just push it + ops.push(entry); + continue; + }; + + // If the entry is not a ReadColumnChunk, just push it + let LogEntry::ReadData(ReadInfo { + row_group_index, + column_name, + range, + read_type, + num_requests: num_reads, + }) = &entry + else { + ops.push(entry); + continue; + }; + + // Combine the entries if they are the same and this read is less than 10b. + // + // This heuristic is used to combine small reads (typically 1-2 + // byte) made by the thrift decoder when reading the data/dictionary + // page headers. 
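+ // For example, two consecutive 2-byte page header reads such as 100..102
+ // followed by 102..104 for the same page are folded into a single
+ // 100..104 entry and `num_requests` becomes 2.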
+ if *row_group_index != *last_rg_index + || column_name != last_column_name + || read_type != last_read_type + || (range.start > last_range.end) + || (range.end < last_range.start) + || range.len() > 10 + { + ops.push(entry); + continue; + } + // combine + *last_range = last_range.start.min(range.start)..last_range.end.max(range.end); + *last_num_reads += num_reads; + } + } + + /// return a snapshot of the current operations in the log. + fn snapshot(&self) -> Vec { + self.coalesce_entries(); + let ops = self.ops.lock().unwrap(); + let mut actual = vec![]; + let indent = 0; + ops.iter() + .for_each(|s| s.append_string(&mut actual, indent)); + actual + } +} diff --git a/parquet/tests/arrow_reader/io/sync_reader.rs b/parquet/tests/arrow_reader/io/sync_reader.rs new file mode 100644 index 000000000000..685f251a9e2b --- /dev/null +++ b/parquet/tests/arrow_reader/io/sync_reader.rs @@ -0,0 +1,443 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for the sync reader - [`ParquetRecordBatchReaderBuilder`] + +use crate::io::{ + filter_a_175_b_625, filter_b_575_625, filter_b_false, test_file, test_options, LogEntry, + OperationLog, TestParquetFile, +}; + +use bytes::Bytes; +use parquet::arrow::arrow_reader::{ + ArrowReaderOptions, ParquetRecordBatchReaderBuilder, RowSelection, RowSelector, +}; +use parquet::arrow::ProjectionMask; +use parquet::file::reader::{ChunkReader, Length}; +use std::io::Read; +use std::sync::Arc; + +#[test] +fn test_read_entire_file() { + // read entire file without any filtering or projection + let test_file = test_file(); + // Expect to see IO for all data pages for each row group and column + let builder = sync_builder(&test_file, test_options()); + insta::assert_debug_snapshot!(run(&test_file, builder), + @r#" + [ + "Footer: 8 bytes", + "Metadata: 1162", + "UNKNOWN: 22230..22877 (maybe Page Index)", + "Event: Builder Configured", + "Event: Reader Built", + "Row Group 0, column 'a': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 0, column 'a': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 0, column 'a': DataPage(1) (126 bytes , 1 requests) [data]", + "Row Group 1, column 'a': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 1, column 'a': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 1, column 'a': DataPage(1) (126 bytes , 1 requests) [data]", + "Row Group 0, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 0, column 'b': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 0, column 'b': DataPage(1) (126 bytes , 1 requests) [data]", + "Row Group 1, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 1, column 'b': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 1, column 'b': DataPage(1) (126 
bytes , 1 requests) [data]", + "Row Group 0, column 'c': DictionaryPage (7107 bytes, 1 requests) [data]", + "Row Group 0, column 'c': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 0, column 'c': DataPage(1) (126 bytes , 1 requests) [data]", + "Row Group 1, column 'c': DictionaryPage (7217 bytes, 1 requests) [data]", + "Row Group 1, column 'c': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 1, column 'c': DataPage(1) (126 bytes , 1 requests) [data]", + ] + "#); +} + +#[test] +fn test_read_single_group() { + let test_file = test_file(); + let builder = sync_builder(&test_file, test_options()).with_row_groups(vec![1]); // read only second row group + + // Expect to see only IO for Row Group 1. Should see no IO for Row Group 0. + insta::assert_debug_snapshot!(run(&test_file, builder), + @r#" + [ + "Footer: 8 bytes", + "Metadata: 1162", + "UNKNOWN: 22230..22877 (maybe Page Index)", + "Event: Builder Configured", + "Event: Reader Built", + "Row Group 1, column 'a': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 1, column 'a': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 1, column 'a': DataPage(1) (126 bytes , 1 requests) [data]", + "Row Group 1, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 1, column 'b': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 1, column 'b': DataPage(1) (126 bytes , 1 requests) [data]", + "Row Group 1, column 'c': DictionaryPage (7217 bytes, 1 requests) [data]", + "Row Group 1, column 'c': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 1, column 'c': DataPage(1) (126 bytes , 1 requests) [data]", + ] + "#); +} + +#[test] +fn test_read_single_column() { + let test_file = test_file(); + let builder = sync_builder(&test_file, test_options()); + let schema_descr = builder.metadata().file_metadata().schema_descr_ptr(); + let builder = builder.with_projection(ProjectionMask::columns(&schema_descr, ["b"])); + // Expect to see only IO for column "b". Should see no IO for columns "a" or "c". + insta::assert_debug_snapshot!(run(&test_file, builder), + @r#" + [ + "Footer: 8 bytes", + "Metadata: 1162", + "UNKNOWN: 22230..22877 (maybe Page Index)", + "Event: Builder Configured", + "Event: Reader Built", + "Row Group 0, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 0, column 'b': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 0, column 'b': DataPage(1) (126 bytes , 1 requests) [data]", + "Row Group 1, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 1, column 'b': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 1, column 'b': DataPage(1) (126 bytes , 1 requests) [data]", + ] + "#); +} + +#[test] +fn test_read_single_column_no_page_index() { + let test_file = test_file(); + let options = test_options().with_page_index(false); + let builder = sync_builder(&test_file, options); + let schema_descr = builder.metadata().file_metadata().schema_descr_ptr(); + let builder = builder.with_projection(ProjectionMask::columns(&schema_descr, ["b"])); + // Expect to see only IO for column "b", should see no IO for columns "a" or "c". 
+ // + // Note that we need to read all data page headers to find the pages for column b + // so there are many more small reads than in the test_read_single_column test above + insta::assert_debug_snapshot!(run(&test_file, builder), + @r#" + [ + "Footer: 8 bytes", + "Metadata: 1162", + "Event: Builder Configured", + "Event: Reader Built", + "Row Group 0, column 'b': DictionaryPage (17 bytes , 17 requests) [header]", + "Row Group 0, column 'b': DictionaryPage (1600 bytes, 1 requests) [data]", + "Row Group 0, column 'b': DataPage(0) (20 bytes , 20 requests) [header]", + "Row Group 0, column 'b': DataPage(0) (93 bytes , 1 requests) [data]", + "Row Group 0, column 'b': DataPage(1) (20 bytes , 20 requests) [header]", + "Row Group 0, column 'b': DataPage(1) (106 bytes , 1 requests) [data]", + "Row Group 1, column 'b': DictionaryPage (17 bytes , 17 requests) [header]", + "Row Group 1, column 'b': DictionaryPage (1600 bytes, 1 requests) [data]", + "Row Group 1, column 'b': DataPage(0) (20 bytes , 20 requests) [header]", + "Row Group 1, column 'b': DataPage(0) (93 bytes , 1 requests) [data]", + "Row Group 1, column 'b': DataPage(1) (20 bytes , 20 requests) [header]", + "Row Group 1, column 'b': DataPage(1) (106 bytes , 1 requests) [data]", + ] + "#); +} + +#[test] +fn test_read_row_selection() { + // There are 400 total rows spread across 4 data pages (100 rows each) + // select rows 175..225 (i.e. DataPage(1) of row group 0 and DataPage(0) of row group 1) + let test_file = test_file(); + let builder = sync_builder(&test_file, test_options()); + let schema_descr = builder.metadata().file_metadata().schema_descr_ptr(); + let builder = builder + .with_projection( + // read both "a" and "b" + ProjectionMask::columns(&schema_descr, ["a", "b"]), + ) + .with_row_selection(RowSelection::from(vec![ + RowSelector::skip(175), + RowSelector::select(50), + ])); + + // Expect to see only data IO for one page for each column for each row group + // Note the data page headers for all pages need to be read to find the correct pages + insta::assert_debug_snapshot!(run(&test_file, builder), + @r#" + [ + "Footer: 8 bytes", + "Metadata: 1162", + "UNKNOWN: 22230..22877 (maybe Page Index)", + "Event: Builder Configured", + "Event: Reader Built", + "Row Group 0, column 'a': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 0, column 'a': DataPage(1) (126 bytes , 1 requests) [data]", + "Row Group 0, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 0, column 'b': DataPage(1) (126 bytes , 1 requests) [data]", + "Row Group 1, column 'a': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 1, column 'a': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 1, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 1, column 'b': DataPage(0) (113 bytes , 1 requests) [data]", + ] + "#); +} + +#[test] +fn test_read_limit() { + // There are 400 total rows spread across 4 data pages (100 rows each) + // a limit of 125 rows should only fetch the first two data pages (DataPage(0) and DataPage(1)) from row group 0 + let test_file = test_file(); + let builder = sync_builder(&test_file, test_options()); + let schema_descr = builder.metadata().file_metadata().schema_descr_ptr(); + let builder = builder + .with_projection(ProjectionMask::columns(&schema_descr, ["a"])) + .with_limit(125); + + insta::assert_debug_snapshot!(run(&test_file, builder), + @r#" + [ + "Footer: 8 bytes", + "Metadata: 1162", + "UNKNOWN: 22230..22877 (maybe Page Index)", + "Event: Builder 
Configured", + "Event: Reader Built", + "Row Group 0, column 'a': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 0, column 'a': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 0, column 'a': DataPage(1) (126 bytes , 1 requests) [data]", + ] + "#); +} + +#[test] +fn test_read_single_row_filter() { + // Values from column "b" range 400..799 + // filter "b" > 575 and < 625 + // (last data page in Row Group 0 and first DataPage in Row Group 1) + let test_file = test_file(); + let builder = sync_builder(&test_file, test_options()); + let schema_descr = builder.metadata().file_metadata().schema_descr_ptr(); + + let builder = builder + .with_projection( + // read both "a" and "b" + ProjectionMask::columns(&schema_descr, ["a", "b"]), + ) + // "b" > 575 and "b" < 625 + .with_row_filter(filter_b_575_625(&schema_descr)); + + // Expect to see I/O for column b in both row groups and then reading just a + // single pages for a in each row group + // + // Note there is significant IO that happens during the construction of the + // reader (between "Builder Configured" and "Reader Built") + insta::assert_debug_snapshot!(run(&test_file, builder), + @r#" + [ + "Footer: 8 bytes", + "Metadata: 1162", + "UNKNOWN: 22230..22877 (maybe Page Index)", + "Event: Builder Configured", + "Row Group 0, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 0, column 'b': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 0, column 'b': DataPage(1) (126 bytes , 1 requests) [data]", + "Row Group 1, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 1, column 'b': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 1, column 'b': DataPage(1) (126 bytes , 1 requests) [data]", + "Event: Reader Built", + "Row Group 0, column 'a': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 0, column 'a': DataPage(1) (126 bytes , 1 requests) [data]", + "Row Group 0, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 0, column 'b': DataPage(1) (126 bytes , 1 requests) [data]", + "Row Group 1, column 'a': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 1, column 'a': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 1, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 1, column 'b': DataPage(0) (113 bytes , 1 requests) [data]", + ] + "#); +} + +#[test] +fn test_read_multiple_row_filter() { + // Values in column "a" range 0..399 + // Values in column "b" range 400..799 + // First filter: "a" > 175 (last data page in Row Group 0) + // Second filter: "b" < 625 (last data page in Row Group 0 and first DataPage in RowGroup 1) + // Read column "c" + let test_file = test_file(); + let builder = sync_builder(&test_file, test_options()); + let schema_descr = builder.metadata().file_metadata().schema_descr_ptr(); + + let builder = builder + .with_projection( + ProjectionMask::columns(&schema_descr, ["c"]), // read "c" + ) + // a > 175 and b < 625 + .with_row_filter(filter_a_175_b_625(&schema_descr)); + + // Expect that we will see + // 1. IO for all pages of column A + // 2. IO for pages of column b that passed 1. + // 3. 
IO after reader is built only for column c + // + // Note there is significant IO that happens during the construction of the + // reader (between "Builder Configured" and "Reader Built") + insta::assert_debug_snapshot!(run(&test_file, builder), + @r#" + [ + "Footer: 8 bytes", + "Metadata: 1162", + "UNKNOWN: 22230..22877 (maybe Page Index)", + "Event: Builder Configured", + "Row Group 0, column 'a': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 0, column 'a': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 0, column 'a': DataPage(1) (126 bytes , 1 requests) [data]", + "Row Group 1, column 'a': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 1, column 'a': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 1, column 'a': DataPage(1) (126 bytes , 1 requests) [data]", + "Row Group 0, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 0, column 'b': DataPage(1) (126 bytes , 1 requests) [data]", + "Row Group 1, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]", + "Row Group 1, column 'b': DataPage(0) (113 bytes , 1 requests) [data]", + "Row Group 1, column 'b': DataPage(1) (126 bytes , 1 requests) [data]", + "Event: Reader Built", + "Row Group 0, column 'c': DictionaryPage (7107 bytes, 1 requests) [data]", + "Row Group 0, column 'c': DataPage(1) (126 bytes , 1 requests) [data]", + "Row Group 1, column 'c': DictionaryPage (7217 bytes, 1 requests) [data]", + "Row Group 1, column 'c': DataPage(0) (113 bytes , 1 requests) [data]", + ] + "#); +} + +#[test] +fn test_read_single_row_filter_all() { + // Apply a filter that entirely filters out rows based on a predicate from one column + // should not read any data pages for any other column + + let test_file = test_file(); + let builder = sync_builder(&test_file, test_options()); + let schema_descr = builder.metadata().file_metadata().schema_descr_ptr(); + + let builder = builder + .with_projection(ProjectionMask::columns(&schema_descr, ["a", "b"])) + .with_row_filter(filter_b_false(&schema_descr)); + + // Expect to see the Footer and Metadata, then I/O for column b + // in both row groups but then nothing for column "a" + // since the row filter entirely filters out all rows. 
+ //
+ // Note that all IO happens during the construction of the reader
+ // (between "Builder Configured" and "Reader Built")
+ insta::assert_debug_snapshot!(run(&test_file, builder),
+ @r#"
+ [
+ "Footer: 8 bytes",
+ "Metadata: 1162",
+ "UNKNOWN: 22230..22877 (maybe Page Index)",
+ "Event: Builder Configured",
+ "Row Group 0, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]",
+ "Row Group 0, column 'b': DataPage(0) (113 bytes , 1 requests) [data]",
+ "Row Group 0, column 'b': DataPage(1) (126 bytes , 1 requests) [data]",
+ "Row Group 1, column 'b': DictionaryPage (1617 bytes, 1 requests) [data]",
+ "Row Group 1, column 'b': DataPage(0) (113 bytes , 1 requests) [data]",
+ "Row Group 1, column 'b': DataPage(1) (126 bytes , 1 requests) [data]",
+ "Event: Reader Built",
+ ]
+ "#);
+}
+
+/// Return a [`ParquetRecordBatchReaderBuilder`] for reading this file
+fn sync_builder(
+ test_file: &TestParquetFile,
+ options: ArrowReaderOptions,
+) -> ParquetRecordBatchReaderBuilder<RecordingChunkReader> {
+ let reader = RecordingChunkReader {
+ inner: test_file.bytes().clone(),
+ ops: Arc::clone(test_file.ops()),
+ };
+ ParquetRecordBatchReaderBuilder::try_new_with_options(reader, options)
+ .expect("ParquetRecordBatchReaderBuilder")
+}
+
+/// Build the reader, and read all batches from it, returning the recorded IO operations
+fn run(
+ test_file: &TestParquetFile,
+ builder: ParquetRecordBatchReaderBuilder<RecordingChunkReader>,
+) -> Vec<String> {
+ let ops = test_file.ops();
+ ops.add_entry(LogEntry::event("Builder Configured"));
+ let reader = builder.build().unwrap();
+ ops.add_entry(LogEntry::event("Reader Built"));
+ for batch in reader {
+ match batch {
+ Ok(_) => {}
+ Err(e) => panic!("Error reading batch: {e}"),
+ }
+ }
+ ops.snapshot()
+}
+
+/// Records IO operations on an in-memory chunk reader
+struct RecordingChunkReader {
+ inner: Bytes,
+ ops: Arc<OperationLog>,
+}
+
+impl Length for RecordingChunkReader {
+ fn len(&self) -> u64 {
+ self.inner.len() as u64
+ }
+}
+
+impl ChunkReader for RecordingChunkReader {
+ type T = RecordingStdIoReader;
+
+ fn get_read(&self, start: u64) -> parquet::errors::Result<Self::T> {
+ let reader = RecordingStdIoReader {
+ start: start as usize,
+ inner: self.inner.clone(),
+ ops: Arc::clone(&self.ops),
+ };
+ Ok(reader)
+ }
+
+ fn get_bytes(&self, start: u64, length: usize) -> parquet::errors::Result<Bytes> {
+ let start = start as usize;
+ let range = start..start + length;
+ self.ops.add_entry_for_range(&range);
+ Ok(self.inner.slice(start..start + length))
+ }
+}
+
+/// Wrapper around a `Bytes` object that implements `Read`
+struct RecordingStdIoReader {
+ /// current offset in the inner `Bytes` that this reader is reading from
+ start: usize,
+ inner: Bytes,
+ ops: Arc<OperationLog>,
+}
+
+impl Read for RecordingStdIoReader {
+ fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
+ let remain = self.inner.len() - self.start;
+ let start = self.start;
+ let read_length = buf.len().min(remain);
+ let read_range = start..start + read_length;
+
+ self.ops.add_entry_for_range(&read_range);
+
+ buf[..read_length].copy_from_slice(self.inner.slice(read_range).as_ref());
+ // Update the inner position
+ self.start += read_length;
+ Ok(read_length)
+ }
+}
diff --git a/parquet/tests/arrow_reader/mod.rs b/parquet/tests/arrow_reader/mod.rs
index 8d72d1def17a..510d62786077 100644
--- a/parquet/tests/arrow_reader/mod.rs
+++ b/parquet/tests/arrow_reader/mod.rs
@@ -42,6 +42,7 @@ mod bad_data;
 #[cfg(feature = "crc")]
 mod checksum;
 mod int96_stats_roundtrip;
+mod io;
 #[cfg(feature = "async")]
 mod predicate_cache;
 mod statistics;
@@ -336,9 +337,9 @@ fn make_uint_batches(start: u8, end: u8) -> RecordBatch {
 Field::new("u64", DataType::UInt64, true),
 ]));
 let v8: Vec<u8> = (start..end).collect();
- let v16: Vec<u16> = (start as _..end as _).collect();
- let v32: Vec<u32> = (start as _..end as _).collect();
- let v64: Vec<u64> = (start as _..end as _).collect();
+ let v16: Vec<u16> = (start as _..end as u16).collect();
+ let v32: Vec<u32> = (start as _..end as u32).collect();
+ let v64: Vec<u64> = (start as _..end as u64).collect();
 RecordBatch::try_new(
 schema,
 vec![
diff --git a/parquet/tests/arrow_reader/predicate_cache.rs b/parquet/tests/arrow_reader/predicate_cache.rs
index 44d43113cbf5..15fd7c9e4f2d 100644
--- a/parquet/tests/arrow_reader/predicate_cache.rs
+++ b/parquet/tests/arrow_reader/predicate_cache.rs
@@ -32,7 +32,7 @@ use parquet::arrow::arrow_reader::{ArrowPredicateFn, ArrowReaderOptions, RowFilt
 use parquet::arrow::arrow_reader::{ArrowReaderBuilder, ParquetRecordBatchReaderBuilder};
 use parquet::arrow::async_reader::AsyncFileReader;
 use parquet::arrow::{ArrowWriter, ParquetRecordBatchStreamBuilder, ProjectionMask};
-use parquet::file::metadata::{ParquetMetaData, ParquetMetaDataReader};
+use parquet::file::metadata::{PageIndexPolicy, ParquetMetaData, ParquetMetaDataReader};
 use parquet::file::properties::WriterProperties;
 use std::ops::Range;
 use std::sync::Arc;
@@ -269,8 +269,9 @@ impl AsyncFileReader for TestReader {
 &'a mut self,
 options: Option<&'a ArrowReaderOptions>,
 ) -> BoxFuture<'a, parquet::errors::Result<Arc<ParquetMetaData>>> {
- let metadata_reader =
- ParquetMetaDataReader::new().with_page_indexes(options.is_some_and(|o| o.page_index()));
+ let metadata_reader = ParquetMetaDataReader::new().with_page_index_policy(
+ PageIndexPolicy::from(options.is_some_and(|o| o.page_index())),
+ );
 self.metadata = Some(Arc::new(
 metadata_reader.parse_and_finish(&self.data).unwrap(),
 ));
diff --git a/parquet/tests/variant_integration.rs b/parquet/tests/variant_integration.rs
new file mode 100644
index 000000000000..97fb6b880108
--- /dev/null
+++ b/parquet/tests/variant_integration.rs
@@ -0,0 +1,504 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Comprehensive integration tests for Parquet files with Variant columns
+//!
+//! This test harness reads test case definitions from cases.json, loads expected
+//! Variant values from .variant.bin files, reads Parquet files, converts StructArray
+//! to VariantArray, and verifies that extracted values match expected results.
+//!
+//!
Inspired by the arrow-go implementation: + +use arrow::util::test_util::parquet_test_data; +use arrow_array::{Array, ArrayRef}; +use arrow_cast::cast; +use arrow_schema::{DataType, Fields}; +use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; +use parquet_variant::{Variant, VariantMetadata}; +use parquet_variant_compute::VariantArray; +use serde::Deserialize; +use std::path::Path; +use std::sync::{Arc, LazyLock}; +use std::{fs, path::PathBuf}; + +type Result = std::result::Result; + +/// Creates a test function for a given case number +/// +/// Note the index is zero-based, while the case number is one-based +macro_rules! variant_test_case { + ($case_num:literal) => { + paste::paste! { + #[test] + fn []() { + all_cases()[$case_num - 1].run() + } + } + }; + + // Generates an error test case, where the expected result is an error message + ($case_num:literal, $expected_error:literal) => { + paste::paste! { + #[test] + #[should_panic(expected = $expected_error)] + fn []() { + all_cases()[$case_num - 1].run() + } + } + }; +} + +// Generate test functions for each case +// Notes +// - case 3 is empty in cases.json for some reason +// - cases 40, 42, 87, 127 and 128 are expected to fail always (they include invalid variants) +// - the remaining cases are expected to (eventually) pass + +variant_test_case!(1, "Unsupported typed_value type: List("); +variant_test_case!(2, "Unsupported typed_value type: List("); +// case 3 is empty in cases.json 🤷 +// ```json +// { +// "case_number" : 3 +// }, +// ``` +variant_test_case!(3, "parquet_file must be set"); +// https://github.com/apache/arrow-rs/issues/8329 +variant_test_case!(4); +variant_test_case!(5); +variant_test_case!(6); +variant_test_case!(7); +variant_test_case!(8); +variant_test_case!(9); +variant_test_case!(10); +variant_test_case!(11); +variant_test_case!(12); +variant_test_case!(13); +variant_test_case!(14); +variant_test_case!(15); +variant_test_case!(16); +variant_test_case!(17); +// https://github.com/apache/arrow-rs/issues/8330 +variant_test_case!(18, "Unsupported typed_value type: Date32"); +variant_test_case!(19, "Unsupported typed_value type: Date32"); +// https://github.com/apache/arrow-rs/issues/8331 +variant_test_case!( + 20, + "Unsupported typed_value type: Timestamp(Microsecond, Some(\"UTC\"))" +); +variant_test_case!( + 21, + "Unsupported typed_value type: Timestamp(Microsecond, Some(\"UTC\"))" +); +variant_test_case!( + 22, + "Unsupported typed_value type: Timestamp(Microsecond, None)" +); +variant_test_case!( + 23, + "Unsupported typed_value type: Timestamp(Microsecond, None)" +); +// https://github.com/apache/arrow-rs/issues/8332 +variant_test_case!(24, "Unsupported typed_value type: Decimal128(9, 4)"); +variant_test_case!(25, "Unsupported typed_value type: Decimal128(9, 4)"); +variant_test_case!(26, "Unsupported typed_value type: Decimal128(18, 9)"); +variant_test_case!(27, "Unsupported typed_value type: Decimal128(18, 9)"); +variant_test_case!(28, "Unsupported typed_value type: Decimal128(38, 9)"); +variant_test_case!(29, "Unsupported typed_value type: Decimal128(38, 9)"); +variant_test_case!(30); +variant_test_case!(31); +// https://github.com/apache/arrow-rs/issues/8334 +variant_test_case!(32, "Unsupported typed_value type: Time64(Microsecond)"); +// https://github.com/apache/arrow-rs/issues/8331 +variant_test_case!( + 33, + "Unsupported typed_value type: Timestamp(Nanosecond, Some(\"UTC\"))" +); +variant_test_case!( + 34, + "Unsupported typed_value type: Timestamp(Nanosecond, Some(\"UTC\"))" +); 
+variant_test_case!( + 35, + "Unsupported typed_value type: Timestamp(Nanosecond, None)" +); +variant_test_case!( + 36, + "Unsupported typed_value type: Timestamp(Nanosecond, None)" +); +variant_test_case!(37); +// https://github.com/apache/arrow-rs/issues/8336 +variant_test_case!(38, "Unsupported typed_value type: Struct("); +variant_test_case!(39); +// Is an error case (should be failing as the expected error message indicates) +variant_test_case!(40, "Unsupported typed_value type: List("); +variant_test_case!(41, "Unsupported typed_value type: List(Field"); +// Is an error case (should be failing as the expected error message indicates) +variant_test_case!( + 42, + "Expected an error 'Invalid variant, conflicting value and typed_value`, but got no error" +); +// https://github.com/apache/arrow-rs/issues/8336 +variant_test_case!(43, "Unsupported typed_value type: Struct([Field"); +variant_test_case!(44, "Unsupported typed_value type: Struct([Field"); +// https://github.com/apache/arrow-rs/issues/8337 +variant_test_case!(45, "Unsupported typed_value type: List(Field"); +variant_test_case!(46, "Unsupported typed_value type: Struct([Field"); +variant_test_case!(47); +variant_test_case!(48); +variant_test_case!(49); +variant_test_case!(50); +variant_test_case!(51); +variant_test_case!(52); +variant_test_case!(53); +variant_test_case!(54); +variant_test_case!(55); +variant_test_case!(56); +variant_test_case!(57); +variant_test_case!(58); +variant_test_case!(59); +variant_test_case!(60); +variant_test_case!(61); +variant_test_case!(62); +variant_test_case!(63); +variant_test_case!(64); +variant_test_case!(65); +variant_test_case!(66); +variant_test_case!(67); +variant_test_case!(68); +variant_test_case!(69); +variant_test_case!(70); +variant_test_case!(71); +variant_test_case!(72); +variant_test_case!(73); +variant_test_case!(74); +variant_test_case!(75); +variant_test_case!(76); +variant_test_case!(77); +variant_test_case!(78); +variant_test_case!(79); +variant_test_case!(80); +variant_test_case!(81); +variant_test_case!(82); +// https://github.com/apache/arrow-rs/issues/8336 +variant_test_case!(83, "Unsupported typed_value type: Struct([Field"); +variant_test_case!(84, "Unsupported typed_value type: Struct([Field"); +// https://github.com/apache/arrow-rs/issues/8337 +variant_test_case!(85, "Unsupported typed_value type: List(Field"); +variant_test_case!(86, "Unsupported typed_value type: List(Field"); +// Is an error case (should be failing as the expected error message indicates) +variant_test_case!(87, "Unsupported typed_value type: Struct([Field"); +variant_test_case!(88, "Unsupported typed_value type: List(Field"); +variant_test_case!(89); +variant_test_case!(90); +variant_test_case!(91); +variant_test_case!(92); +variant_test_case!(93); +variant_test_case!(94); +variant_test_case!(95); +variant_test_case!(96); +variant_test_case!(97); +variant_test_case!(98); +variant_test_case!(99); +variant_test_case!(100); +variant_test_case!(101); +variant_test_case!(102); +variant_test_case!(103); +variant_test_case!(104); +variant_test_case!(105); +variant_test_case!(106); +variant_test_case!(107); +variant_test_case!(108); +variant_test_case!(109); +variant_test_case!(110); +variant_test_case!(111); +variant_test_case!(112); +variant_test_case!(113); +variant_test_case!(114); +variant_test_case!(115); +variant_test_case!(116); +variant_test_case!(117); +variant_test_case!(118); +variant_test_case!(119); +variant_test_case!(120); +variant_test_case!(121); +variant_test_case!(122); 
+variant_test_case!(123);
+variant_test_case!(124);
+variant_test_case!(125, "Unsupported typed_value type: Struct");
+variant_test_case!(126, "Unsupported typed_value type: List(");
+// Is an error case (should be failing as the expected error message indicates)
+variant_test_case!(
+ 127,
+ "Invalid variant data: InvalidArgumentError(\"Received empty bytes\")"
+);
+// Is an error case (should be failing as the expected error message indicates)
+variant_test_case!(128, "Unsupported typed_value type: Struct([Field");
+variant_test_case!(129, "Invalid variant data: InvalidArgumentError(");
+variant_test_case!(130, "Unsupported typed_value type: Struct([Field");
+variant_test_case!(131);
+variant_test_case!(132, "Unsupported typed_value type: Struct([Field");
+variant_test_case!(133, "Unsupported typed_value type: Struct([Field");
+variant_test_case!(134, "Unsupported typed_value type: Struct([Field");
+variant_test_case!(135);
+variant_test_case!(136, "Unsupported typed_value type: List(Field ");
+variant_test_case!(137, "Invalid variant data: InvalidArgumentError(");
+variant_test_case!(138, "Unsupported typed_value type: Struct([Field");
+
+/// Test case definition structure matching the format from
+/// `parquet-testing/shredded_variant/cases.json`
+///
+/// See [README] for details.
+///
+/// [README]: https://github.com/apache/parquet-testing/blob/master/shredded_variant/README.md
+///
+/// Example JSON
+/// ```json
+/// {
+/// "case_number" : 5,
+/// "test" : "testShreddedVariantPrimitives",
+/// "parquet_file" : "case-005.parquet",
+/// "variant_file" : "case-005_row-0.variant.bin",
+/// "variant" : "Variant(metadata=VariantMetadata(dict={}), value=Variant(type=BOOLEAN_FALSE, value=false))"
+/// },
+/// ```
+#[allow(dead_code)] // some fields are not used except when printing the struct
+#[derive(Debug, Clone, Deserialize)]
+struct VariantTestCase {
+ /// Case number (e.g., 1, 2, 4, etc. - note: case 3 is missing any data)
+ pub case_number: u32,
+ /// Test method name (e.g., "testSimpleArray")
+ pub test: Option<String>,
+ /// Name of the parquet file (e.g., "case-001.parquet")
+ pub parquet_file: Option<String>,
+
+ /// Expected variant binary file (e.g., "case-001_row-0.variant.bin") - None for error cases
+ pub variant_file: Option<String>,
+ /// Multiple expected variant binary files, for multi row inputs. If there
+ /// is no variant, there is no file
+ pub variant_files: Option<Vec<Option<String>>>,
+ /// Expected error message for negative test cases
+ ///
+ /// (this is the message from the cases.json file, which is from the Iceberg
+ /// implementation, so it is not guaranteed to match the actual Rust error message)
+ pub error_message: Option<String>,
+ /// Description of the variant value (for debugging)
+ pub variant_description: Option<String>,
+}
+
+/// Run a single test case
+impl VariantTestCase {
+ /// Run a test case. Panics on unexpected error
+ fn run(&self) {
+ println!("{self:#?}");
+
+ let variant_data = self.load_variants();
+ let variant_array = self.load_parquet();
+
+ // if this is an error case, the expected error message should be set
+ if let Some(expected_error) = &self.error_message {
+ // just accessing the variant_array should trigger the error
+ for i in 0..variant_array.len() {
+ let _ = variant_array.value(i);
+ }
+ panic!("Expected an error '{expected_error}`, but got no error");
+ }
+
+ assert_eq!(
+ variant_array.len(),
+ variant_data.len(),
+ "Number of variants in parquet file does not match expected number"
+ );
+ for (i, expected) in variant_data.iter().enumerate() {
+ if variant_array.is_null(i) {
+ assert!(
+ expected.is_none(),
+ "Expected null variant at index {i}, but got {:?}",
+ variant_array.value(i)
+ );
+ continue;
+ }
+ let actual = variant_array.value(i);
+ let expected = variant_data[i]
+ .as_ref()
+ .expect("Expected non-null variant data");
+
+ let expected = expected.as_variant();
+
+ // compare the variants (is this the right way to compare?)
+ assert_eq!(actual, expected, "Variant data mismatch at index {}\n\nactual\n{actual:#?}\n\nexpected\n{expected:#?}", i);
+ }
+ }
+
+ /// Parses the expected variant files, returning a vector of `ExpectedVariant` or None
+ /// if the corresponding entry in `variant_files` is null
+ fn load_variants(&self) -> Vec<Option<ExpectedVariant>> {
+ let variant_files: Box<dyn Iterator<Item = Option<&String>>> =
+ match (&self.variant_files, &self.variant_file) {
+ (Some(files), None) => Box::new(files.iter().map(|f| f.as_ref())),
+ (None, Some(file)) => Box::new(std::iter::once(Some(file))),
+ // error cases may not have any variant files
+ _ => Box::new(std::iter::empty()),
+ };
+
+ // load each file
+ variant_files
+ .map(|f| {
+ let v = ExpectedVariant::try_load(&TEST_CASE_DIR.join(f?))
+ .expect("Failed to load expected variant");
+ Some(v)
+ })
+ .collect()
+ }
+
+ /// Load the parquet file, extract the Variant column, and return as a VariantArray
+ fn load_parquet(&self) -> VariantArray {
+ let parquet_file = self
+ .parquet_file
+ .as_ref()
+ .expect("parquet_file must be set");
+ let path = TEST_CASE_DIR.join(parquet_file);
+ let file = fs::File::open(&path)
+ .unwrap_or_else(|e| panic!("cannot open parquet file {path:?}: {e}"));
+
+ let reader = ParquetRecordBatchReaderBuilder::try_new(file)
+ .and_then(|b| b.build())
+ .unwrap_or_else(|e| panic!("Error reading parquet reader for {path:?}: {e}"));
+
+ let mut batches: Vec<_> = reader
+ .collect::<Result<Vec<_>, _>>()
+ .unwrap_or_else(|e| panic!("Error reading parquet batches for {path:?}: {e}"));
+
+ if batches.is_empty() {
+ panic!("No parquet batches were found in file {path:?}");
+ }
+ if batches.len() > 1 {
+ panic!(
+ "Multiple parquet batches were found in file {path:?}, only single batch supported"
+ );
+ }
+ let batch = batches.swap_remove(0);
+
+ // The schema is "id", "var" for the id and variant columns
+ // TODO: support the actual parquet logical type annotation somehow
+ let var = batch
+ .column_by_name("var")
+ .unwrap_or_else(|| panic!("No 'var' column found in parquet file {path:?}"));
+
+ // the values are read as
+ // * StructArray (with Binary columns)
+ // but VariantArray needs them as
+ // * StructArray (with BinaryView columns)
+ //
+ // So cast them to get the right type. Hack Alert: the parquet reader
+ // should read them directly as BinaryView
+ let var = cast_to_binary_view_arrays(var);
+
+ VariantArray::try_new(var).unwrap_or_else(|e| {
+ panic!("Error converting StructArray to VariantArray for {path:?}: {e}")
+ })
+ }
+}
+
+fn cast_to_binary_view_arrays(array: &ArrayRef) -> ArrayRef {
+ let new_type = map_type(array.data_type());
+ cast(array, &new_type).unwrap_or_else(|e| {
+ panic!(
+ "Error casting array from {:?} to {:?}: {e}",
+ array.data_type(),
+ new_type
+ )
+ })
+}
+
+/// replaces all instances of Binary with BinaryView in a DataType
+fn map_type(data_type: &DataType) -> DataType {
+ match data_type {
+ DataType::Binary => DataType::BinaryView,
+ DataType::List(field) => {
+ let new_field = field
+ .as_ref()
+ .clone()
+ .with_data_type(map_type(field.data_type()));
+ DataType::List(Arc::new(new_field))
+ }
+ DataType::Struct(fields) => {
+ let new_fields: Fields = fields
+ .iter()
+ .map(|f| {
+ let new_field = f.as_ref().clone().with_data_type(map_type(f.data_type()));
+ Arc::new(new_field)
+ })
+ .collect();
+ DataType::Struct(new_fields)
+ }
+ _ => data_type.clone(),
+ }
+}
+
+/// Variant value loaded from .variant.bin file
+#[derive(Debug, Clone)]
+struct ExpectedVariant {
+ data: Vec<u8>,
+ data_offset: usize,
+}
+
+impl ExpectedVariant {
+ fn try_load(path: &Path) -> Result<Self> {
+ // "Each `*.variant.bin` file contains a single variant serialized
+ // by concatenating the serialized bytes of the variant metadata
+ // followed by the serialized bytes of the variant value."
+ let data = fs::read(path).map_err(|e| format!("cannot read variant file {path:?}: {e}"))?;
+ let metadata = VariantMetadata::try_new(&data)
+ .map_err(|e| format!("cannot parse variant metadata from {path:?}: {e}"))?;
+
+ let data_offset = metadata.size();
+ Ok(Self { data, data_offset })
+ }
+
+ fn as_variant(&self) -> Variant<'_, '_> {
+ let metadata = &self.data[0..self.data_offset];
+ let value = &self.data[self.data_offset..];
+ Variant::try_new(metadata, value).expect("Invalid variant data")
+ }
+}
+
+static TEST_CASE_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
+ PathBuf::from(parquet_test_data())
+ .join("..")
+ .join("shredded_variant")
+});
+
+/// All tests
+static ALL_CASES: LazyLock<Result<Vec<VariantTestCase>>> = LazyLock::new(|| {
+ let cases_file = TEST_CASE_DIR.join("cases.json");
+
+ if !cases_file.exists() {
+ return Err(format!("cases.json not found at {}", cases_file.display()));
+ }
+
+ let content = fs::read_to_string(&cases_file)
+ .map_err(|e| format!("cannot read cases file {cases_file:?}: {e}"))?;
+
+ serde_json::from_str::<Vec<VariantTestCase>>(content.as_str())
+ .map_err(|e| format!("cannot parse json from {cases_file:?}: {e}"))
+});
+
+// return a reference to the static ALL_CASES, or panic if loading failed
+fn all_cases() -> &'static [VariantTestCase] {
+ ALL_CASES.as_ref().unwrap()
+}
diff --git a/rust-toolchain.toml b/rust-toolchain.toml
new file mode 100644
index 000000000000..4ac629d201c5
--- /dev/null
+++ b/rust-toolchain.toml
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[toolchain] +channel = "1.89" +components = ["rustfmt", "clippy"]
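
For readers skimming the patch, the following is a rough sketch of what one test generated by the `variant_test_case!` macro in `parquet/tests/variant_integration.rs` looks like after expansion. The generated function names are illustrative only; the exact identifiers produced by `paste!` are not shown in this excerpt, and everything else follows the macro body above.

```rust
// Hypothetical expansion of `variant_test_case!(5)` (function name is assumed)
#[test]
fn variant_case_5() {
    // case numbers are 1-based, the slice returned by all_cases() is 0-based
    all_cases()[5 - 1].run()
}

// Hypothetical expansion of the two-argument form, used for cases that are
// expected to fail with a known error message
#[test]
#[should_panic(expected = "Unsupported typed_value type: List(")]
fn variant_case_40() {
    all_cases()[40 - 1].run()
}
```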