diff --git a/.github/workflows/autodocs.yaml b/.github/workflows/autodocs.yaml
index a768f263ca7..4e79993978c 100644
--- a/.github/workflows/autodocs.yaml
+++ b/.github/workflows/autodocs.yaml
@@ -41,5 +41,5 @@ jobs:
- name: Check that documentation is up-to-date
run: |
- npm install -g @redocly/cli
+ npm install -g @redocly/cli@1.34.2
python update_doc.py --check
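
Pinning @redocly/cli to an exact version presumably keeps this check stable against upstream CLI releases. The same check can be reproduced locally with the two commands from the workflow; a minimal sketch, assuming Node.js/npm and the repository's update_doc.py are available:

    # Reproduce the autodocs check locally with the pinned Redocly CLI.
    npm install -g @redocly/cli@1.34.2
    # --check fails if the committed OpenAPI documentation is out of date.
    python update_doc.py --check
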
diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 3e94f730213..5b292890fda 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -6,10 +6,11 @@ on:
hardware:
type: string
description: Hardware
- # options:
- # - cuda
- # - rocm
- # - intel
+ # options:
+ # - cuda
+ # - cuda-trtllm
+ # - rocm
+ # - intel
required: true
release-tests:
description: "Run release integration tests"
@@ -24,22 +25,34 @@ jobs:
docker_volume: ${{ steps.final.outputs.docker_volume }}
docker_devices: ${{ steps.final.outputs.docker_devices }}
runs_on: ${{ steps.final.outputs.runs_on }}
- label: ${{ steps.final.outputs.label }}
+ label_extension: ${{ steps.final.outputs.label_extension }}
extra_pytest: ${{ steps.final.outputs.extra_pytest }}
concurrency:
group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on:
- group: aws-highmemory-32-plus-priv
+ group: aws-highmemory-64-plus-priv
permissions:
contents: write
packages: write
+ id-token: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- - name: Construct harware variables
+ - name: Inject required variables for sccache to interact with Github Actions Cache
+ uses: actions/github-script@v7
+ with:
+ script: |
+ core.exportVariable('ACTIONS_RESULTS_URL', process.env.ACTIONS_RESULTS_URL || '');
+ core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
+
+ - name: Extract TensorRT-LLM version
+ run: |
+ echo "TENSORRT_LLM_VERSION=$(grep -oP '([a-z,0-9]{40})' $GITHUB_WORKSPACE/backends/trtllm/cmake/trtllm.cmake)" >> $GITHUB_ENV
+ echo "TensorRT-LLM version: ${{ env.TENSORRT_LLM_VERSION }}"
+ - name: Construct hardware variables
shell: bash
run: |
case ${{ inputs.hardware }} in
@@ -51,15 +64,34 @@ jobs:
export runs_on="aws-g6-12xl-plus-priv-cache"
export platform=""
export extra_pytest=""
+ export target=""
+ ;;
+ cuda-trtllm)
+ export dockerfile="Dockerfile_trtllm"
+ export label_extension="-trtllm"
+ export docker_volume="/mnt/cache"
+ export docker_devices=""
+ export runs_on="ubuntu-latest"
+ export platform=""
+ export extra_pytest=""
+ if [[ "${GITHUB_REF}" == refs/tags/* ]]; then
+ export build_type="release";
+ export target="";
+ else
+ export build_type="dev";
+ export target="ci-runtime";
+ fi
;;
rocm)
export dockerfile="Dockerfile_amd"
export label_extension="-rocm"
export docker_devices="/dev/kfd,/dev/dri"
export docker_volume="/mnt"
- export runs_on="amd-gpu-runners"
+ # This runner was deactivated.
+ export runs_on="ubuntu-latest"
export platform=""
export extra_pytest="-k test_flash_gemma_gptq_load"
+ export target=""
;;
intel-xpu)
export dockerfile="Dockerfile_intel"
@@ -69,6 +101,7 @@ jobs:
export runs_on="ubuntu-latest"
export platform="xpu"
export extra_pytest=""
+ export target=""
;;
intel-cpu)
export dockerfile="Dockerfile_intel"
@@ -79,7 +112,27 @@ jobs:
export runs_on="aws-highmemory-32-plus-priv"
export platform="cpu"
export extra_pytest="-k test_flash_gemma_simple"
+ export target=""
+ ;;
+ neuron)
+ export dockerfile="Dockerfile.neuron"
+ export label_extension="-neuron"
+ export docker_devices="/dev/neuron0"
+ export docker_volume="/mnt/cache"
+ export runs_on="aws-inf2-8xlarge"
+ export platform="cpu"
+ export extra_pytest="--neuron"
+ export target=""
;;
+ gaudi)
+ export dockerfile="Dockerfile_gaudi"
+ export label_extension="-gaudi"
+ export docker_volume="/mnt/cache"
+ export docker_devices=""
+ export runs_on="itac-bm-emr-gaudi3-dell-2gaudi"
+ export platform=""
+ export extra_pytest="--gaudi"
+ export target=""
esac
echo $dockerfile
echo "Dockerfile=${dockerfile}"
@@ -88,19 +141,22 @@ jobs:
echo $runs_on
echo $platform
echo "DOCKERFILE=${dockerfile}" >> $GITHUB_ENV
- echo "LABEL=${label_extension}" >> $GITHUB_ENV
+ echo "LABEL_EXTENSION=${label_extension}" >> $GITHUB_ENV
echo "PLATFORM=${platform}" >> $GITHUB_ENV
echo "DOCKER_VOLUME=${docker_volume}" >> $GITHUB_ENV
echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV
echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV
echo "EXTRA_PYTEST=${extra_pytest}" >> $GITHUB_ENV
echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV
+ echo "TARGET=${target}" >> $GITHUB_ENV
+ echo "BUILD_TYPE=${build_type}" >> $GITHUB_ENV
- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v3
with:
install: true
buildkitd-config: /tmp/buildkitd.toml
- name: Login to internal Container Registry
+ if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
username: ${{ secrets.REGISTRY_USERNAME }}
@@ -113,13 +169,12 @@ jobs:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- - name: Login to Azure Container Registry
- if: github.event_name != 'pull_request'
+ - name: Login to Docker Hub Container Registry
uses: docker/login-action@v3
with:
- username: ${{ secrets.AZURE_DOCKER_USERNAME }}
- password: ${{ secrets.AZURE_DOCKER_PASSWORD }}
- registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io
+ registry: docker.io
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_PASSWORD }}
# If pull request
- name: Extract metadata (tags, labels) for Docker
if: ${{ github.event_name == 'pull_request' }}
@@ -127,9 +182,9 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
- registry.internal.huggingface.tech/api-inference/community/text-generation-inference
+ docker.io/huggingface/text-generation-inference-ci
tags: |
- type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
+ type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
# If main, release or tag
- name: Extract metadata (tags, labels) for Docker
if: ${{ github.event_name != 'pull_request' }}
@@ -137,16 +192,15 @@ jobs:
uses: docker/metadata-action@v4.3.0
with:
flavor: |
- latest=auto
+ latest=false
images: |
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
ghcr.io/huggingface/text-generation-inference
- db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
tags: |
- type=semver,pattern={{version}}${{ env.LABEL }}
- type=semver,pattern={{major}}.{{minor}}${{ env.LABEL }}
- type=raw,value=latest${{ env.LABEL }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
- type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
+ type=semver,pattern={{version}}${{ env.LABEL_EXTENSION }}
+ type=semver,pattern={{major}}.{{minor}}${{ env.LABEL_EXTENSION }}
+ type=raw,value=latest${{ env.LABEL_EXTENSION }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
+ type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
- name: Build and push Docker image
id: build-and-push
uses: docker/build-push-action@v4
@@ -157,27 +211,66 @@ jobs:
platforms: 'linux/amd64'
build-args: |
GIT_SHA=${{ env.GITHUB_SHA }}
- DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
+ DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
PLATFORM=${{ env.PLATFORM }}
+ build_type=${{ env.BUILD_TYPE }}
+ sccache_gha_enabled=on
+ actions_results_url=${{ env.ACTIONS_RESULTS_URL }}
+ actions_runtime_token=${{ env.ACTIONS_RUNTIME_TOKEN }}
+ target: ${{ env.TARGET }}
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
- cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
- cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
+ cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL_EXTENSION }},mode=max,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }}
+ cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL_EXTENSION }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }}
- name: Final
id: final
run: |
- echo "docker_image=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
+
+ if [ "${{ github.event_name }}" = "pull_request" ]; then
+ echo "docker_image=docker.io/huggingface/text-generation-inference-ci:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
+ else
+ echo "docker_image=ghcr.io/huggingface/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
+ fi
echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
echo "docker_volume=${{ env.DOCKER_VOLUME }}" >> "$GITHUB_OUTPUT"
echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"
- echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
+ echo "label_extension=${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
echo "extra_pytest=${{ env.EXTRA_PYTEST }}" >> "$GITHUB_OUTPUT"
- integration_tests:
+ precompile_neuron_models:
concurrency:
- group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label }}-${{ github.head_ref || github.run_id }}
+ group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label_extension }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
needs: build-and-push
- if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
+ if: needs.build-and-push.outputs.label_extension == '-neuron'
+ runs-on:
+ group: ${{ needs.build-and-push.outputs.runs_on }}
+ env:
+ PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }}
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ - name: Inject slug/short variables
+ uses: rlespinasse/github-slug-action@v4.4.1
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.11"
+ - name: Install
+ run: |
+ make install-integration-tests
+ - name: Export neuron models
+ run: |
+ export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
+ echo $DOCKER_IMAGE
+ docker pull $DOCKER_IMAGE
+ export HF_TOKEN=${{ secrets.HF_TOKEN_NEURON }}
+ python integration-tests/fixtures/neuron/export_models.py
+ integration_tests:
+ concurrency:
+ group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label_extension }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+ needs: [precompile_neuron_models, build-and-push]
+ if: ${{ always() && !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') && needs.build-and-push.outputs.runs_on != 'ubuntu-latest' }}
runs-on:
group: ${{ needs.build-and-push.outputs.runs_on }}
env:
@@ -204,3 +297,23 @@ jobs:
echo $DOCKER_IMAGE
docker pull $DOCKER_IMAGE
pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST}
+
+ backend_trtllm_cxx_tests:
+ needs: build-and-push
+ if: needs.build-and-push.outputs.label_extension == '-trtllm'
+ concurrency:
+ group: ${{ github.workflow }}-${{ github.job }}-trtllm-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+ runs-on:
+ group: aws-g6-12xl-plus-priv-cache
+ container:
+ image: ${{ needs.build-and-push.outputs.docker_image }}
+ credentials:
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_PASSWORD }}
+ options: --gpus all --shm-size=8g
+
+ steps:
+ - name: Run C++/CUDA tests
+ if: ${{ env.TARGET == 'ci-runtime' }}
+ run: /usr/local/tgi/bin/tgi_trtllm_backend_tests
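
The heart of this build.yaml refactor is the "Construct hardware variables" step: a shell case statement that maps the hardware input onto a Dockerfile, label suffix (LABEL_EXTENSION, formerly LABEL), runner group, and Docker build target, then exports everything through $GITHUB_ENV for the metadata, build-and-push, and "Final" steps. A condensed sketch of that pattern, covering only the cuda and cuda-trtllm branches; write_env is a hypothetical stand-in for the workflow's GITHUB_ENV writes, and the Dockerfile/label values for plain cuda are assumptions since this hunk does not show them:

    #!/usr/bin/env bash
    # Condensed sketch of the hardware -> build-variable mapping used in build.yaml.
    set -euo pipefail

    hardware="${1:-cuda}"

    # Hypothetical helper standing in for `echo "KEY=value" >> $GITHUB_ENV`.
    write_env() { echo "$1=$2" >> "${GITHUB_ENV:-/dev/stdout}"; }

    case "$hardware" in
      cuda)
        # Dockerfile/label for plain cuda are assumed; the diff only shows its runner.
        dockerfile="Dockerfile"
        label_extension=""
        runs_on="aws-g6-12xl-plus-priv-cache"
        target=""
        ;;
      cuda-trtllm)
        dockerfile="Dockerfile_trtllm"
        label_extension="-trtllm"
        runs_on="ubuntu-latest"
        # Tag builds produce the full release image; everything else builds the ci-runtime target.
        if [[ "${GITHUB_REF:-}" == refs/tags/* ]]; then
          build_type="release"; target=""
        else
          build_type="dev"; target="ci-runtime"
        fi
        ;;
      *)
        echo "unsupported hardware: $hardware" >&2
        exit 1
        ;;
    esac

    write_env DOCKERFILE "$dockerfile"
    write_env LABEL_EXTENSION "$label_extension"
    write_env RUNS_ON "$runs_on"
    write_env TARGET "$target"
    write_env BUILD_TYPE "${build_type:-}"

Downstream steps then only ever read the normalized environment variables, which is why the rename from LABEL to LABEL_EXTENSION has to be carried through the metadata tags, cache names, and job outputs in the rest of this file.
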
diff --git a/.github/workflows/ci_build.yaml b/.github/workflows/ci_build.yaml
index 5190f321754..f0d39399b0c 100644
--- a/.github/workflows/ci_build.yaml
+++ b/.github/workflows/ci_build.yaml
@@ -20,6 +20,8 @@ on:
- "Dockerfile"
- "Dockerfile_amd"
- "Dockerfile_intel"
+ - "Dockerfile.neuron"
+ - "Dockerfile_gaudi"
branches:
- "main"
workflow_dispatch:
@@ -37,11 +39,12 @@ jobs:
# fail-fast is true by default
fail-fast: false
matrix:
- hardware: ["cuda", "rocm", "intel-xpu", "intel-cpu"]
+ hardware: ["cuda", "cuda-trtllm", "rocm", "intel-xpu", "intel-cpu", "neuron", "gaudi"]
uses: ./.github/workflows/build.yaml # calls the one above ^
permissions:
contents: write
packages: write
+ id-token: write
with:
hardware: ${{ matrix.hardware }}
# https://github.com/actions/runner/issues/2206
diff --git a/.github/workflows/nix_build.yaml b/.github/workflows/nix_build.yaml
new file mode 100644
index 00000000000..b8a10f65d78
--- /dev/null
+++ b/.github/workflows/nix_build.yaml
@@ -0,0 +1,53 @@
+name: "Nix Build Docker image"
+on:
+ pull_request:
+ push:
+ branches:
+ - 'main'
+ tags:
+ - 'v*'
+concurrency:
+ group: nix-image-${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ build_nix_image:
+ runs-on:
+ group: aws-highmemory-32-plus-priv
+ steps:
+ - uses: actions/checkout@v4
+ - uses: cachix/install-nix-action@v27
+ with:
+ nix_path: nixpkgs=channel:nixos-unstable
+ - uses: cachix/cachix-action@v14
+ with:
+ name: huggingface
+ # If you chose signing key for write access
+ # authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
+ env:
+ USER: github_runner
+ - name: Build
+ run: nix build .#dockerImage
+ - name: Initialize Docker Buildx
+ uses: docker/setup-buildx-action@v3
+ with:
+ install: true
+ buildkitd-config: /tmp/buildkitd.toml
+ - name: Inject slug/short variables
+ uses: rlespinasse/github-slug-action@v4.4.1
+ - name: Login to internal Container Registry
+ # if: github.event_name != 'pull_request'
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.REGISTRY_USERNAME }}
+ password: ${{ secrets.REGISTRY_PASSWORD }}
+ registry: registry.internal.huggingface.tech
+ - name: Push to docker
+ run: |
+ if [ "${{ github.event_name }}" = "pull_request" ]; then
+ export TAG=nix-sha-${{ env.GITHUB_SHA_SHORT }}
+ else
+ export TAG=${{ github.ref_name }}-nix
+ fi
+ export IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:$TAG
+ nix-shell -p skopeo --command "skopeo --insecure-policy copy docker-archive:$(readlink -f ./result) docker://$IMAGE --dest-compress-format zstd"
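
The push step above derives the image tag from the event type and then copies the Nix-built docker archive straight to the registry with skopeo rather than loading it into the Docker daemon. A minimal standalone sketch of that flow, assuming ./result is the symlink produced by nix build .#dockerImage and that GITHUB_SHA_SHORT comes from the slug action as in the workflow:

    # Pick the tag: PR builds get a short-SHA tag, everything else uses the git ref name.
    if [ "${GITHUB_EVENT_NAME:-}" = "pull_request" ]; then
      TAG="nix-sha-${GITHUB_SHA_SHORT}"
    else
      TAG="${GITHUB_REF_NAME}-nix"
    fi
    IMAGE="registry.internal.huggingface.tech/api-inference/community/text-generation-inference:${TAG}"

    # Copy the docker-archive produced by `nix build .#dockerImage` directly to the
    # registry, recompressing layers with zstd on the way.
    nix-shell -p skopeo --command \
      "skopeo --insecure-policy copy docker-archive:$(readlink -f ./result) docker://${IMAGE} --dest-compress-format zstd"
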
diff --git a/.github/workflows/nix_cache.yaml b/.github/workflows/nix_cache.yaml
index 967a5982e05..9a76e7c18e9 100644
--- a/.github/workflows/nix_cache.yaml
+++ b/.github/workflows/nix_cache.yaml
@@ -20,9 +20,9 @@ jobs:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
- name: text-generation-inference
+ name: huggingface
# If you chose signing key for write access
- authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
+ #authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
env:
USER: github_runner
- name: Build impure devshell
diff --git a/.github/workflows/nix_tests.yaml b/.github/workflows/nix_tests.yaml
index f2209f8a453..72d75f53bee 100644
--- a/.github/workflows/nix_tests.yaml
+++ b/.github/workflows/nix_tests.yaml
@@ -7,6 +7,7 @@ on:
- "proto/**"
- "router/**"
- "launcher/**"
+ - "backends/**"
- "Cargo.lock"
- "rust-toolchain.toml"
concurrency:
@@ -24,11 +25,13 @@ jobs:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
- name: text-generation-inference
+ name: huggingface
# If you chose signing key for write access
- authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
+ #authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
env:
USER: github_runner
+ - name: Nix info
+ run: nix-shell -p nix-info --run "nix-info -m"
- name: Build
run: nix develop .#test --command echo "Ok"
- name: Pre-commit tests.
diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 4eeca3348ae..3e431c86182 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -8,6 +8,7 @@ on:
- "proto/**"
- "router/**"
- "launcher/**"
+ - "backends/**"
- "Cargo.lock"
- "rust-toolchain.toml"
@@ -20,19 +21,14 @@ jobs:
runs-on:
group: aws-highmemory-32-plus-priv
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
id: python
with:
python-version: 3.11
- - name: Install Rust
- uses: actions-rs/toolchain@v1
+ - uses: dtolnay/rust-toolchain@1.85.0
with:
- # Released on: 02 May, 2024
- # https://releases.rs/docs/1.78.0/
- toolchain: 1.80.0
- override: true
components: rustfmt, clippy
- name: Install Protoc
uses: arduino/setup-protoc@v1
@@ -44,10 +40,18 @@ jobs:
run: |
sudo apt update
sudo apt install python3.11-dev -y
+ pip install -U pip uv
+ uv venv
+ source ./.venv/bin/activate
make install-cpu
+ - name: Download locked kernels
+ run: |
+ source ./.venv/bin/activate
+ kernels download server
- name: Run server tests
run: |
- pip install pytest
+ source ./.venv/bin/activate
+ uv pip install pytest
export HF_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv server/tests
- name: Pre-commit checks
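
The server test job now creates a uv-managed virtualenv, installs the CPU build, pre-downloads the locked kernels, and runs pytest inside that environment. Roughly the same sequence can be reproduced locally; a sketch, assuming Python 3.11 and protoc are installed and that the kernels CLI behaves as in the workflow step above:

    # Local approximation of the CI server-test flow.
    pip install -U pip uv
    uv venv
    source ./.venv/bin/activate
    make install-cpu
    # Fetch the locked kernel binaries for the server package, as CI does.
    kernels download server
    uv pip install pytest
    # export HF_TOKEN=<your token>   # CI exports this before running the tests
    pytest -s -vv server/tests
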
diff --git a/.github/workflows/trufflehog.yaml b/.github/workflows/trufflehog.yaml
index b406d43b8f0..9f1c5f36dad 100644
--- a/.github/workflows/trufflehog.yaml
+++ b/.github/workflows/trufflehog.yaml
@@ -10,9 +10,12 @@ jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- - name: Checkout code
- uses: actions/checkout@v4
- with:
- fetch-depth: 0
- - name: Secret Scanning
- uses: trufflesecurity/trufflehog@main
+ - name: Checkout code
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ - name: Secret Scanning
+ uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d
+ with:
+ # exclude buggy postgres detector that is causing false positives and not relevant to our codebase
+ extra_args: --results=verified,unknown --exclude-detectors=postgres
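
The extra_args added above can also be passed when running TruffleHog by hand to match the CI behaviour; a sketch, assuming a local trufflehog binary and that the `git file://.` invocation form is appropriate for scanning the working clone:

    # Scan the local clone with the same filters as CI:
    # report only verified/unknown results and skip the noisy postgres detector.
    trufflehog git file://. --results=verified,unknown --exclude-detectors=postgres
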
diff --git a/.gitignore b/.gitignore
index 9434d75ca17..8a6bda722d1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,3 +23,9 @@ server/fbgemmm
.direnv/
.venv/
+
+# Gaudi auto-generated files
+hl-smi_log*.txt
+.graph_dumps
+out
+hqt_output
diff --git a/Cargo.lock b/Cargo.lock
index 9551ae2d9b2..cfe19dcdef3 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
-version = 3
+version = 4
[[package]]
name = "addr2line"
@@ -24,11 +24,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
dependencies = [
"cfg-if",
- "getrandom",
+ "getrandom 0.2.15",
"once_cell",
"serde",
"version_check",
- "zerocopy",
+ "zerocopy 0.7.35",
]
[[package]]
@@ -48,9 +48,24 @@ checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1"
[[package]]
name = "allocator-api2"
-version = "0.2.20"
+version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9"
+checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
+
+[[package]]
+name = "android-tzdata"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
+
+[[package]]
+name = "android_system_properties"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
+dependencies = [
+ "libc",
+]
[[package]]
name = "anstream"
@@ -93,19 +108,20 @@ dependencies = [
[[package]]
name = "anstyle-wincon"
-version = "3.0.6"
+version = "3.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125"
+checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e"
dependencies = [
"anstyle",
+ "once_cell",
"windows-sys 0.59.0",
]
[[package]]
name = "anyhow"
-version = "1.0.93"
+version = "1.0.97"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775"
+checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f"
[[package]]
name = "arbitrary"
@@ -127,7 +143,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -166,18 +182,18 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
name = "async-trait"
-version = "0.1.83"
+version = "0.1.88"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
+checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -230,9 +246,9 @@ dependencies = [
[[package]]
name = "avif-serialize"
-version = "0.8.2"
+version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e335041290c43101ca215eed6f43ec437eb5a42125573f600fc3fa42b9bddd62"
+checksum = "98922d6a4cfbcb08820c69d8eeccc05bb1f29bfa06b4f5b1dbfe9a868bd7608e"
dependencies = [
"arrayvec",
]
@@ -251,29 +267,25 @@ dependencies = [
[[package]]
name = "aws-lc-rs"
-version = "1.11.0"
+version = "1.12.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fe7c2840b66236045acd2607d5866e274380afd87ef99d6226e961e2cb47df45"
+checksum = "dabb68eb3a7aa08b46fddfd59a3d55c978243557a90ab804769f7e20e67d2b01"
dependencies = [
"aws-lc-sys",
- "mirai-annotations",
- "paste",
"zeroize",
]
[[package]]
name = "aws-lc-sys"
-version = "0.23.0"
+version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ad3a619a9de81e1d7de1f1186dcba4506ed661a0e483d84410fdef0ee87b2f96"
+checksum = "77926887776171ced7d662120a75998e444d3750c951abfe07f90da130514b1f"
dependencies = [
- "bindgen",
+ "bindgen 0.69.5",
"cc",
"cmake",
"dunce",
"fs_extra",
- "libc",
- "paste",
]
[[package]]
@@ -289,7 +301,7 @@ dependencies = [
"futures-util",
"http 0.2.12",
"http-body 0.4.6",
- "hyper 0.14.31",
+ "hyper 0.14.32",
"itoa",
"matchit",
"memchr",
@@ -318,10 +330,10 @@ dependencies = [
"axum-core 0.4.5",
"bytes",
"futures-util",
- "http 1.1.0",
+ "http 1.3.1",
"http-body 1.0.1",
"http-body-util",
- "hyper 1.5.1",
+ "hyper 1.6.0",
"hyper-util",
"itoa",
"matchit",
@@ -336,7 +348,7 @@ dependencies = [
"serde_urlencoded",
"sync_wrapper 1.0.2",
"tokio",
- "tower 0.5.1",
+ "tower 0.5.2",
"tower-layer",
"tower-service",
"tracing",
@@ -368,7 +380,7 @@ dependencies = [
"async-trait",
"bytes",
"futures-util",
- "http 1.1.0",
+ "http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"mime",
@@ -389,7 +401,7 @@ dependencies = [
"axum 0.7.9",
"futures-core",
"futures-util",
- "http 1.1.0",
+ "http 1.3.1",
"opentelemetry 0.21.0",
"pin-project-lite",
"tower 0.4.13",
@@ -437,7 +449,7 @@ version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
@@ -448,26 +460,46 @@ dependencies = [
"proc-macro2",
"quote",
"regex",
- "rustc-hash",
+ "rustc-hash 1.1.0",
"shlex",
- "syn 2.0.89",
+ "syn 2.0.100",
"which",
]
+[[package]]
+name = "bindgen"
+version = "0.71.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3"
+dependencies = [
+ "bitflags 2.9.0",
+ "cexpr",
+ "clang-sys",
+ "itertools 0.13.0",
+ "log",
+ "prettyplease",
+ "proc-macro2",
+ "quote",
+ "regex",
+ "rustc-hash 2.1.1",
+ "shlex",
+ "syn 2.0.100",
+]
+
[[package]]
name = "bit-set"
-version = "0.5.3"
+version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
+checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3"
dependencies = [
"bit-vec",
]
[[package]]
name = "bit-vec"
-version = "0.6.3"
+version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
+checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7"
[[package]]
name = "bit_field"
@@ -483,9 +515,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
-version = "2.6.0"
+version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
+checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd"
[[package]]
name = "bitstream-io"
@@ -502,17 +534,23 @@ dependencies = [
"generic-array",
]
+[[package]]
+name = "borrow-or-share"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3eeab4423108c5d7c744f4d234de88d18d636100093ae04caf4825134b9c3a32"
+
[[package]]
name = "built"
-version = "0.7.5"
+version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c360505aed52b7ec96a3636c3f039d99103c37d1d9b4f7a8c743d3ea9ffcd03b"
+checksum = "56ed6191a7e78c36abdb16ab65341eefd73d64d303fffccdbb00d51e4205967b"
[[package]]
name = "bumpalo"
-version = "3.16.0"
+version = "3.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
+checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf"
[[package]]
name = "bytecount"
@@ -522,9 +560,9 @@ checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce"
[[package]]
name = "bytemuck"
-version = "1.20.0"
+version = "1.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a"
+checksum = "b6b1fc10dbac614ebc03540c9dbd60e83887fda27794998c6528f1782047d540"
[[package]]
name = "byteorder"
@@ -540,9 +578,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495"
[[package]]
name = "bytes"
-version = "1.8.0"
+version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da"
+checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
[[package]]
name = "camino"
@@ -555,9 +593,9 @@ dependencies = [
[[package]]
name = "cargo-platform"
-version = "0.1.8"
+version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "24b1f0365a6c6bb4020cd05806fd0d33c44d38046b8bd7f0e40814b9763cabfc"
+checksum = "e35af189006b9c0f00a064685c727031e3ed2d8020f7ba284d78cc2671bd36ea"
dependencies = [
"serde",
]
@@ -573,7 +611,7 @@ dependencies = [
"semver",
"serde",
"serde_json",
- "thiserror",
+ "thiserror 1.0.69",
]
[[package]]
@@ -599,9 +637,9 @@ dependencies = [
[[package]]
name = "cc"
-version = "1.2.1"
+version = "1.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47"
+checksum = "1fcb57c740ae1daf453ae85f16e37396f672b039e00d9d866e07ddb24e328e3a"
dependencies = [
"jobserver",
"libc",
@@ -645,6 +683,20 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
+[[package]]
+name = "chrono"
+version = "0.4.40"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1a7964611d71df112cb1730f2ee67324fcf4d0fc6606acbbe9bfe06df124637c"
+dependencies = [
+ "android-tzdata",
+ "iana-time-zone",
+ "js-sys",
+ "num-traits",
+ "wasm-bindgen",
+ "windows-link",
+]
+
[[package]]
name = "clang-sys"
version = "1.8.1"
@@ -669,9 +721,9 @@ dependencies = [
[[package]]
name = "clap"
-version = "4.5.21"
+version = "4.5.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f"
+checksum = "6088f3ae8c3608d19260cd7445411865a485688711b78b5be70d78cd96136f83"
dependencies = [
"clap_builder",
"clap_derive",
@@ -679,9 +731,9 @@ dependencies = [
[[package]]
name = "clap_builder"
-version = "4.5.21"
+version = "4.5.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec"
+checksum = "22a7ef7f676155edfb82daa97f99441f3ebf4a58d5e32f295a56259f1b6facc8"
dependencies = [
"anstream",
"anstyle",
@@ -691,39 +743,40 @@ dependencies = [
[[package]]
name = "clap_derive"
-version = "4.5.18"
+version = "4.5.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab"
+checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
name = "clap_lex"
-version = "0.7.3"
+version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7"
+checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
[[package]]
name = "cmake"
-version = "0.1.51"
+version = "0.1.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a"
+checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
dependencies = [
"cc",
]
[[package]]
name = "codespan-reporting"
-version = "0.11.1"
+version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e"
+checksum = "fe6d2e5af09e8c8ad56c969f2157a3d4238cebc7c55f0a517728c38f7b200f81"
dependencies = [
+ "serde",
"termcolor",
- "unicode-width 0.1.14",
+ "unicode-width 0.2.0",
]
[[package]]
@@ -740,9 +793,9 @@ checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990"
[[package]]
name = "compact_str"
-version = "0.8.0"
+version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6050c3a16ddab2e412160b31f2c871015704239bca62f72f6e5f0be631d3f644"
+checksum = "3b79c4069c6cad78e2e0cdfcbd26275770669fb39fd308a752dc110e83b9af32"
dependencies = [
"castaway",
"cfg-if",
@@ -754,15 +807,15 @@ dependencies = [
[[package]]
name = "console"
-version = "0.15.8"
+version = "0.15.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb"
+checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8"
dependencies = [
"encode_unicode",
- "lazy_static",
"libc",
- "unicode-width 0.1.14",
- "windows-sys 0.52.0",
+ "once_cell",
+ "unicode-width 0.2.0",
+ "windows-sys 0.59.0",
]
[[package]]
@@ -793,9 +846,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
[[package]]
name = "cpufeatures"
-version = "0.2.16"
+version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3"
+checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280"
dependencies = [
"libc",
]
@@ -847,18 +900,18 @@ dependencies = [
[[package]]
name = "crossbeam-channel"
-version = "0.5.13"
+version = "0.5.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2"
+checksum = "06ba6d68e24814cb8de6bb986db8222d3a027d15872cabc0d18817bc3c0e4471"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-deque"
-version = "0.8.5"
+version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
+checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
@@ -875,9 +928,9 @@ dependencies = [
[[package]]
name = "crossbeam-utils"
-version = "0.8.20"
+version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
+checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "crossterm"
@@ -885,11 +938,11 @@ version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"crossterm_winapi",
"mio",
"parking_lot",
- "rustix",
+ "rustix 0.38.44",
"signal-hook",
"signal-hook-mio",
"winapi",
@@ -906,9 +959,9 @@ dependencies = [
[[package]]
name = "crunchy"
-version = "0.2.2"
+version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
+checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929"
[[package]]
name = "crypto-common"
@@ -934,9 +987,9 @@ dependencies = [
[[package]]
name = "csv-core"
-version = "0.1.11"
+version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70"
+checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d"
dependencies = [
"memchr",
]
@@ -953,46 +1006,61 @@ dependencies = [
[[package]]
name = "cxx"
-version = "1.0.130"
+version = "1.0.150"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "23c042a0ba58aaff55299632834d1ea53ceff73d62373f62c9ae60890ad1b942"
+checksum = "6d1cf22155cf6a8e0b0536efc30c775eadd7a481c376d2d7e30daf0825a42ef9"
dependencies = [
"cc",
+ "cxxbridge-cmd",
"cxxbridge-flags",
"cxxbridge-macro",
+ "foldhash",
"link-cplusplus",
]
[[package]]
name = "cxx-build"
-version = "1.0.130"
+version = "1.0.150"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "45dc1c88d0fdac57518a9b1f6c4f4fb2aca8f3c30c0d03d7d8518b47ca0bcea6"
+checksum = "db4e07e3a69db032f03450594e53785a5d6b1d787c2ad5b901d9347f0064af94"
dependencies = [
"cc",
"codespan-reporting",
"proc-macro2",
"quote",
"scratch",
- "syn 2.0.89",
+ "syn 2.0.100",
+]
+
+[[package]]
+name = "cxxbridge-cmd"
+version = "1.0.150"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48e9ff9c627d3abe06190462f7db81fb6cc12f3424ea081c2a8c9ed7a8cc167a"
+dependencies = [
+ "clap 4.5.32",
+ "codespan-reporting",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.100",
]
[[package]]
name = "cxxbridge-flags"
-version = "1.0.130"
+version = "1.0.150"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "aa7ed7d30b289e2592cc55bc2ccd89803a63c913e008e6eb59f06cddf45bb52f"
+checksum = "2e6417f4e1518ded330e088d5a66f50fbae9bbc96840e147058ae44970a2b51a"
[[package]]
name = "cxxbridge-macro"
-version = "1.0.130"
+version = "1.0.150"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0b8c465d22de46b851c04630a5fc749a26005b263632ed2e0d9cc81518ead78d"
+checksum = "856ff0dba6e023dd78189c8f4667126842dfe88392b5d4e94118bd18b8f2afbf"
dependencies = [
"proc-macro2",
"quote",
"rustversion",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -1016,7 +1084,7 @@ dependencies = [
"proc-macro2",
"quote",
"strsim",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -1027,14 +1095,14 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
dependencies = [
"darling_core",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
name = "deranged"
-version = "0.3.11"
+version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4"
+checksum = "28cfac68e08048ae1883171632c2aef3ebc555621ae56fbccce1cbf22dd7f058"
dependencies = [
"powerfmt",
]
@@ -1057,7 +1125,7 @@ dependencies = [
"darling",
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -1067,15 +1135,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
dependencies = [
"derive_builder_core",
- "syn 2.0.89",
+ "syn 2.0.100",
]
-[[package]]
-name = "diff"
-version = "0.1.13"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8"
-
[[package]]
name = "digest"
version = "0.10.7"
@@ -1115,7 +1177,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -1126,24 +1188,33 @@ checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]]
name = "easy-cast"
-version = "0.5.2"
+version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "10936778145f3bea71fd9bf61332cce28c28e96a380714f7ab34838b80733fd6"
+checksum = "72852736692ec862655eca398c9bb1b476161b563c9f80f45f4808b9629750d6"
dependencies = [
"libm",
]
[[package]]
name = "either"
-version = "1.13.0"
+version = "1.15.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
+
+[[package]]
+name = "email_address"
+version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
+checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449"
+dependencies = [
+ "serde",
+]
[[package]]
name = "encode_unicode"
-version = "0.3.6"
+version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f"
+checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
[[package]]
name = "encoding_rs"
@@ -1156,18 +1227,18 @@ dependencies = [
[[package]]
name = "equivalent"
-version = "1.0.1"
+version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
+checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "errno"
-version = "0.3.9"
+version = "0.3.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba"
+checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d"
dependencies = [
"libc",
- "windows-sys 0.52.0",
+ "windows-sys 0.59.0",
]
[[package]]
@@ -1186,7 +1257,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83197f59927b46c04a183a619b7c29df34e63e63c7869320862268c0ef687e0"
dependencies = [
"bit_field",
- "half 2.4.1",
+ "half 2.5.0",
"lebe",
"miniz_oxide",
"rayon-core",
@@ -1196,25 +1267,26 @@ dependencies = [
[[package]]
name = "fancy-regex"
-version = "0.11.0"
+version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2"
+checksum = "6e24cb5a94bcae1e5408b0effca5cd7172ea3c5755049c5f3af4cd283a165298"
dependencies = [
"bit-set",
- "regex",
+ "regex-automata 0.4.9",
+ "regex-syntax 0.8.5",
]
[[package]]
name = "fastrand"
-version = "2.2.0"
+version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4"
+checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "fdeflate"
-version = "0.3.6"
+version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "07c6f4c64c1d33a3111c4466f7365ebdcc37c5bd1ea0d62aae2e3d722aacbedb"
+checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c"
dependencies = [
"simd-adler32",
]
@@ -1227,9 +1299,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]]
name = "flate2"
-version = "1.0.35"
+version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c"
+checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc"
dependencies = [
"crc32fast",
"miniz_oxide",
@@ -1247,6 +1319,17 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853"
+[[package]]
+name = "fluent-uri"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1918b65d96df47d3591bed19c5cca17e3fa5d0707318e4b5ef2eae01764df7e5"
+dependencies = [
+ "borrow-or-share",
+ "ref-cast",
+ "serde",
+]
+
[[package]]
name = "fnv"
version = "1.0.7"
@@ -1255,9 +1338,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foldhash"
-version = "0.1.3"
+version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2"
+checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "foreign-types"
@@ -1285,9 +1368,9 @@ dependencies = [
[[package]]
name = "fraction"
-version = "0.13.1"
+version = "0.15.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3027ae1df8d41b4bed2241c8fdad4acc1e7af60c8e17743534b545e77182d678"
+checksum = "0f158e3ff0a1b334408dc9fb811cd99b446986f4d8b741bb08f9df1604085ae7"
dependencies = [
"lazy_static",
"num",
@@ -1355,7 +1438,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -1414,10 +1497,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if",
- "js-sys",
"libc",
- "wasi",
- "wasm-bindgen",
+ "wasi 0.11.0+wasi-snapshot-preview1",
+]
+
+[[package]]
+name = "getrandom"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "73fea8450eea4bac3940448fb7ae50d91f034f941199fcd9d909a5a07aa455f0"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "r-efi",
+ "wasi 0.14.2+wasi-0.2.4",
]
[[package]]
@@ -1438,9 +1531,9 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "glob"
-version = "0.3.1"
+version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
+checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
[[package]]
name = "grpc-metadata"
@@ -1464,7 +1557,7 @@ dependencies = [
"futures-sink",
"futures-util",
"http 0.2.12",
- "indexmap 2.6.0",
+ "indexmap 2.8.0",
"slab",
"tokio",
"tokio-util",
@@ -1473,17 +1566,17 @@ dependencies = [
[[package]]
name = "h2"
-version = "0.4.7"
+version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e"
+checksum = "5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2"
dependencies = [
"atomic-waker",
"bytes",
"fnv",
"futures-core",
"futures-sink",
- "http 1.1.0",
- "indexmap 2.6.0",
+ "http 1.3.1",
+ "indexmap 2.8.0",
"slab",
"tokio",
"tokio-util",
@@ -1498,9 +1591,9 @@ checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"
[[package]]
name = "half"
-version = "2.4.1"
+version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
+checksum = "7db2ff139bba50379da6aa0766b52fdcb62cb5b263009b09ed58ba604e14bbd1"
dependencies = [
"cfg-if",
"crunchy",
@@ -1519,14 +1612,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [
"ahash",
- "allocator-api2",
]
[[package]]
name = "hashbrown"
-version = "0.15.1"
+version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3"
+checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
dependencies = [
"allocator-api2",
"equivalent",
@@ -1565,29 +1657,49 @@ name = "hf-hub"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
+dependencies = [
+ "dirs",
+ "indicatif",
+ "log",
+ "native-tls",
+ "rand 0.8.5",
+ "serde",
+ "serde_json",
+ "thiserror 1.0.69",
+ "ureq",
+]
+
+[[package]]
+name = "hf-hub"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cc03dcb0b0a83ae3f3363ec811014ae669f083e4e499c66602f447c4828737a1"
dependencies = [
"dirs",
"futures",
+ "http 1.3.1",
"indicatif",
+ "libc",
"log",
"native-tls",
"num_cpus",
- "rand",
- "reqwest",
+ "rand 0.8.5",
+ "reqwest 0.12.15",
"serde",
"serde_json",
- "thiserror",
+ "thiserror 2.0.12",
"tokio",
"ureq",
+ "windows-sys 0.59.0",
]
[[package]]
name = "home"
-version = "0.5.9"
+version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5"
+checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
- "windows-sys 0.52.0",
+ "windows-sys 0.59.0",
]
[[package]]
@@ -1614,9 +1726,9 @@ dependencies = [
[[package]]
name = "http"
-version = "1.1.0"
+version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258"
+checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565"
dependencies = [
"bytes",
"fnv",
@@ -1641,27 +1753,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
dependencies = [
"bytes",
- "http 1.1.0",
+ "http 1.3.1",
]
[[package]]
name = "http-body-util"
-version = "0.1.2"
+version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f"
+checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
dependencies = [
"bytes",
- "futures-util",
- "http 1.1.0",
+ "futures-core",
+ "http 1.3.1",
"http-body 1.0.1",
"pin-project-lite",
]
[[package]]
name = "httparse"
-version = "1.9.5"
+version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946"
+checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
[[package]]
name = "httpdate"
@@ -1671,9 +1783,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hyper"
-version = "0.14.31"
+version = "0.14.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85"
+checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7"
dependencies = [
"bytes",
"futures-channel",
@@ -1695,15 +1807,15 @@ dependencies = [
[[package]]
name = "hyper"
-version = "1.5.1"
+version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f"
+checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
- "h2 0.4.7",
- "http 1.1.0",
+ "h2 0.4.8",
+ "http 1.3.1",
"http-body 1.0.1",
"httparse",
"httpdate",
@@ -1716,16 +1828,16 @@ dependencies = [
[[package]]
name = "hyper-rustls"
-version = "0.27.3"
+version = "0.27.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333"
+checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2"
dependencies = [
"futures-util",
- "http 1.1.0",
- "hyper 1.5.1",
+ "http 1.3.1",
+ "hyper 1.6.0",
"hyper-util",
"log",
- "rustls 0.23.17",
+ "rustls 0.23.25",
"rustls-native-certs",
"rustls-pki-types",
"tokio",
@@ -1739,7 +1851,7 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1"
dependencies = [
- "hyper 0.14.31",
+ "hyper 0.14.32",
"pin-project-lite",
"tokio",
"tokio-io-timeout",
@@ -1752,10 +1864,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905"
dependencies = [
"bytes",
- "hyper 0.14.31",
+ "hyper 0.14.32",
+ "native-tls",
+ "tokio",
+ "tokio-native-tls",
+]
+
+[[package]]
+name = "hyper-tls"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
+dependencies = [
+ "bytes",
+ "http-body-util",
+ "hyper 1.6.0",
+ "hyper-util",
"native-tls",
"tokio",
"tokio-native-tls",
+ "tower-service",
]
[[package]]
@@ -1767,9 +1895,9 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
- "http 1.1.0",
+ "http 1.3.1",
"http-body 1.0.1",
- "hyper 1.5.1",
+ "hyper 1.6.0",
"pin-project-lite",
"socket2",
"tokio",
@@ -1777,6 +1905,29 @@ dependencies = [
"tracing",
]
+[[package]]
+name = "iana-time-zone"
+version = "0.1.61"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220"
+dependencies = [
+ "android_system_properties",
+ "core-foundation-sys",
+ "iana-time-zone-haiku",
+ "js-sys",
+ "wasm-bindgen",
+ "windows-core",
+]
+
+[[package]]
+name = "iana-time-zone-haiku"
+version = "0.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
+dependencies = [
+ "cc",
+]
+
[[package]]
name = "icu_collections"
version = "1.5.0"
@@ -1892,7 +2043,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -1947,9 +2098,9 @@ dependencies = [
[[package]]
name = "image-webp"
-version = "0.2.0"
+version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e031e8e3d94711a9ccb5d6ea357439ef3dcbed361798bd4071dc4d9793fbe22f"
+checksum = "b77d01e822461baa8409e156015a1d91735549f0f2c17691bd2d996bef238f7f"
dependencies = [
"byteorder-lite",
"quick-error",
@@ -1973,20 +2124,20 @@ dependencies = [
[[package]]
name = "indexmap"
-version = "2.6.0"
+version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da"
+checksum = "3954d50fe15b02142bf25d3b8bdadb634ec3948f103d04ffe3031bc8fe9d7058"
dependencies = [
"equivalent",
- "hashbrown 0.15.1",
+ "hashbrown 0.15.2",
"serde",
]
[[package]]
name = "indicatif"
-version = "0.17.9"
+version = "0.17.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281"
+checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235"
dependencies = [
"console",
"number_prefix",
@@ -1997,9 +2148,9 @@ dependencies = [
[[package]]
name = "indoc"
-version = "2.0.5"
+version = "2.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
+checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"
[[package]]
name = "init-tracing-opentelemetry"
@@ -2009,23 +2160,22 @@ checksum = "94bd26b1b737bc11f183620072e188d1c6ede67e0e78682228d66b49ec510e17"
dependencies = [
"opentelemetry 0.20.0",
"opentelemetry-otlp",
- "thiserror",
+ "thiserror 1.0.69",
"tracing",
"tracing-opentelemetry 0.21.0",
]
[[package]]
name = "instability"
-version = "0.3.3"
+version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b829f37dead9dc39df40c2d3376c179fdfd2ac771f53f55d3c30dc096a3c0c6e"
+checksum = "0bf9fed6d91cfb734e7476a06bde8300a1b94e217e1b523b6f0cd1a01998c71d"
dependencies = [
"darling",
"indoc",
- "pretty_assertions",
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -2036,14 +2186,14 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
name = "ipnet"
-version = "2.10.1"
+version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708"
+checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "is_terminal_polyfill"
@@ -2051,15 +2201,6 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
-[[package]]
-name = "iso8601"
-version = "0.6.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "924e5d73ea28f59011fec52a0d12185d496a9b075d360657aed2a5707f701153"
-dependencies = [
- "nom",
-]
-
[[package]]
name = "itertools"
version = "0.10.5"
@@ -2098,9 +2239,9 @@ dependencies = [
[[package]]
name = "itoa"
-version = "1.0.13"
+version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "540654e97a3f4470a492cd30ff187bc95d89557a903a2bbf112e2fae98104ef2"
+checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "jobserver"
@@ -2119,41 +2260,37 @@ checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0"
[[package]]
name = "js-sys"
-version = "0.3.72"
+version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9"
+checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f"
dependencies = [
+ "once_cell",
"wasm-bindgen",
]
[[package]]
name = "jsonschema"
-version = "0.17.1"
+version = "0.28.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2a071f4f7efc9a9118dfb627a0a94ef247986e1ab8606a4c806ae2b3aa3b6978"
+checksum = "4b8f66fe41fa46a5c83ed1c717b7e0b4635988f427083108c8cf0a882cc13441"
dependencies = [
"ahash",
- "anyhow",
- "base64 0.21.7",
+ "base64 0.22.1",
"bytecount",
- "clap 4.5.21",
+ "email_address",
"fancy-regex",
"fraction",
- "getrandom",
- "iso8601",
+ "idna",
"itoa",
- "memchr",
"num-cmp",
"once_cell",
- "parking_lot",
"percent-encoding",
- "regex",
- "reqwest",
+ "referencing",
+ "regex-syntax 0.8.5",
+ "reqwest 0.12.15",
"serde",
"serde_json",
- "time",
- "url",
- "uuid",
+ "uuid-simd",
]
[[package]]
@@ -2176,15 +2313,15 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
[[package]]
name = "libc"
-version = "0.2.164"
+version = "0.2.171"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f"
+checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6"
[[package]]
name = "libfuzzer-sys"
-version = "0.4.8"
+version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9b9569d2f74e257076d8c6bfa73fb505b46b851e51ddaecc825944aa3bed17fa"
+checksum = "cf78f52d400cf2d84a3a973a78a592b4adc535739e0a5597a0da6f0c357adc75"
dependencies = [
"arbitrary",
"cc",
@@ -2192,9 +2329,9 @@ dependencies = [
[[package]]
name = "libloading"
-version = "0.8.5"
+version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4"
+checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if",
"windows-targets 0.52.6",
@@ -2212,30 +2349,36 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"libc",
]
[[package]]
name = "link-cplusplus"
-version = "1.0.9"
+version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9d240c6f7e1ba3a28b0249f774e6a9dd0175054b52dfbb61b16eb8505c3785c9"
+checksum = "4a6f6da007f968f9def0d65a05b187e2960183de70c160204ecfccf0ee330212"
dependencies = [
"cc",
]
[[package]]
name = "linux-raw-sys"
-version = "0.4.14"
+version = "0.4.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
+
+[[package]]
+name = "linux-raw-sys"
+version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
+checksum = "fe7db12097d22ec582439daf8618b8fdd1a7bef6270e9af3b1ebcd30893cf413"
[[package]]
name = "litemap"
-version = "0.7.3"
+version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704"
+checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856"
[[package]]
name = "lock_api"
@@ -2249,9 +2392,9 @@ dependencies = [
[[package]]
name = "log"
-version = "0.4.22"
+version = "0.4.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
+checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e"
[[package]]
name = "loop9"
@@ -2268,7 +2411,7 @@ version = "0.12.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38"
dependencies = [
- "hashbrown 0.15.1",
+ "hashbrown 0.15.2",
]
[[package]]
@@ -2351,15 +2494,15 @@ checksum = "b4f0c8427b39666bf970460908b213ec09b3b350f20c0c2eabcbba51704a08e6"
dependencies = [
"base64 0.22.1",
"http-body-util",
- "hyper 1.5.1",
+ "hyper 1.6.0",
"hyper-rustls",
"hyper-util",
- "indexmap 2.6.0",
+ "indexmap 2.8.0",
"ipnet",
"metrics",
"metrics-util",
"quanta",
- "thiserror",
+ "thiserror 1.0.69",
"tokio",
"tracing",
]
@@ -2397,9 +2540,9 @@ dependencies = [
[[package]]
name = "minijinja"
-version = "2.5.0"
+version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2c37e1b517d1dcd0e51dc36c4567b9d5a29262b3ec8da6cb5d35e27a8fb529b5"
+checksum = "6e36f1329330bb1614c94b78632b9ce45dd7d761f3304a1bed07b2990a7c5097"
dependencies = [
"serde",
"serde_json",
@@ -2407,9 +2550,9 @@ dependencies = [
[[package]]
name = "minijinja-contrib"
-version = "2.5.0"
+version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7fe51f1a6a8285f03fcd1544d834234fe8db285f29e1c2253600c93b3ae19242"
+checksum = "8e807b6b15e36a4c808e92f78c2ac1f6776519a50d9cf6649819c759a8e7133c"
dependencies = [
"minijinja",
"serde",
@@ -2423,9 +2566,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
-version = "0.8.0"
+version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1"
+checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5"
dependencies = [
"adler2",
"simd-adler32",
@@ -2433,28 +2576,21 @@ dependencies = [
[[package]]
name = "mio"
-version = "1.0.2"
+version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec"
+checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd"
dependencies = [
- "hermit-abi 0.3.9",
"libc",
"log",
- "wasi",
+ "wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.52.0",
]
-[[package]]
-name = "mirai-annotations"
-version = "1.12.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1"
-
[[package]]
name = "monostate"
-version = "0.1.13"
+version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0d208407d7552cd041d8cdb69a1bc3303e029c598738177a3d87082004dc0e1e"
+checksum = "aafe1be9d0c75642e3e50fedc7ecadf1ef1cbce6eb66462153fc44245343fbee"
dependencies = [
"monostate-impl",
"serde",
@@ -2462,13 +2598,13 @@ dependencies = [
[[package]]
name = "monostate-impl"
-version = "0.1.13"
+version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0"
+checksum = "c402a4092d5e204f32c9e155431046831fa712637043c58cb73bc6bc6c9663b5"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -2489,8 +2625,8 @@ dependencies = [
"bytes",
"futures",
"pin-project",
- "rand",
- "thiserror",
+ "rand 0.8.5",
+ "thiserror 1.0.69",
"tokio",
"tokio-util",
"tracing",
@@ -2498,9 +2634,9 @@ dependencies = [
[[package]]
name = "native-tls"
-version = "0.2.12"
+version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466"
+checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e"
dependencies = [
"libc",
"log",
@@ -2534,15 +2670,15 @@ dependencies = [
"bytes",
"futures",
"hostname",
- "hyper 0.14.31",
+ "hyper 0.14.32",
"muxado",
"once_cell",
"parking_lot",
"regex",
- "rustls-pemfile",
+ "rustls-pemfile 1.0.4",
"serde",
"serde_json",
- "thiserror",
+ "thiserror 1.0.69",
"tokio",
"tokio-retry",
"tokio-util",
@@ -2556,7 +2692,7 @@ version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"cfg-if",
"cfg_aliases 0.1.1",
"libc",
@@ -2568,7 +2704,7 @@ version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"cfg-if",
"cfg_aliases 0.2.1",
"libc",
@@ -2668,7 +2804,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -2739,18 +2875,18 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]]
name = "object"
-version = "0.36.5"
+version = "0.36.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e"
+checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87"
dependencies = [
"memchr",
]
[[package]]
name = "once_cell"
-version = "1.20.2"
+version = "1.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
+checksum = "d75b0bedcc4fe52caa0e03d9f1151a323e4aa5e2d78ba3580400cd3c9e2bc4bc"
[[package]]
name = "onig"
@@ -2776,17 +2912,17 @@ dependencies = [
[[package]]
name = "oorandom"
-version = "11.1.4"
+version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9"
+checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]]
name = "openssl"
-version = "0.10.68"
+version = "0.10.71"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5"
+checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"cfg-if",
"foreign-types",
"libc",
@@ -2803,20 +2939,20 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
name = "openssl-probe"
-version = "0.1.5"
+version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
+checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-sys"
-version = "0.9.104"
+version = "0.9.106"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741"
+checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd"
dependencies = [
"cc",
"libc",
@@ -2842,35 +2978,21 @@ checksum = "1e32339a5dc40459130b3bd269e9892439f55b33e772d2a9d402a789baaf4e8a"
dependencies = [
"futures-core",
"futures-sink",
- "indexmap 2.6.0",
+ "indexmap 2.8.0",
"js-sys",
"once_cell",
"pin-project-lite",
- "thiserror",
+ "thiserror 1.0.69",
"urlencoding",
]
[[package]]
-name = "opentelemetry"
-version = "0.24.0"
+name = "opentelemetry-otlp"
+version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96"
+checksum = "7e5e5a5c4135864099f3faafbe939eb4d7f9b80ebf68a8448da961b32a7c1275"
dependencies = [
- "futures-core",
- "futures-sink",
- "js-sys",
- "once_cell",
- "pin-project-lite",
- "thiserror",
-]
-
-[[package]]
-name = "opentelemetry-otlp"
-version = "0.13.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7e5e5a5c4135864099f3faafbe939eb4d7f9b80ebf68a8448da961b32a7c1275"
-dependencies = [
- "async-trait",
+ "async-trait",
"futures-core",
"http 0.2.12",
"opentelemetry-proto",
@@ -2878,7 +3000,7 @@ dependencies = [
"opentelemetry_api",
"opentelemetry_sdk 0.20.0",
"prost 0.11.9",
- "thiserror",
+ "thiserror 1.0.69",
"tokio",
"tonic 0.9.2",
]
@@ -2916,7 +3038,7 @@ dependencies = [
"js-sys",
"once_cell",
"pin-project-lite",
- "thiserror",
+ "thiserror 1.0.69",
"urlencoding",
]
@@ -2935,10 +3057,10 @@ dependencies = [
"opentelemetry_api",
"ordered-float 3.9.2",
"percent-encoding",
- "rand",
+ "rand 0.8.5",
"regex",
"serde_json",
- "thiserror",
+ "thiserror 1.0.69",
"tokio",
"tokio-stream",
]
@@ -2957,28 +3079,10 @@ dependencies = [
"glob",
"once_cell",
"opentelemetry 0.21.0",
- "ordered-float 4.5.0",
- "percent-encoding",
- "rand",
- "thiserror",
-]
-
-[[package]]
-name = "opentelemetry_sdk"
-version = "0.24.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df"
-dependencies = [
- "async-trait",
- "futures-channel",
- "futures-executor",
- "futures-util",
- "glob",
- "once_cell",
- "opentelemetry 0.24.0",
+ "ordered-float 4.6.0",
"percent-encoding",
- "rand",
- "thiserror",
+ "rand 0.8.5",
+ "thiserror 1.0.69",
]
[[package]]
@@ -2998,9 +3102,9 @@ dependencies = [
[[package]]
name = "ordered-float"
-version = "4.5.0"
+version = "4.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c65ee1f9701bf938026630b455d5315f490640234259037edb259798b3bcf85e"
+checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951"
dependencies = [
"num-traits",
]
@@ -3016,6 +3120,12 @@ dependencies = [
"serde_json",
]
+[[package]]
+name = "outref"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e"
+
[[package]]
name = "overload"
version = "0.1.1"
@@ -3075,34 +3185,34 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
dependencies = [
"fixedbitset",
- "indexmap 2.6.0",
+ "indexmap 2.8.0",
]
[[package]]
name = "pin-project"
-version = "1.1.7"
+version = "1.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "be57f64e946e500c8ee36ef6331845d40a93055567ec57e8fae13efd33759b95"
+checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
-version = "1.1.7"
+version = "1.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c"
+checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
name = "pin-project-lite"
-version = "0.2.15"
+version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff"
+checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
[[package]]
name = "pin-utils"
@@ -3112,9 +3222,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "pkg-config"
-version = "0.3.31"
+version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2"
+checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "plotters"
@@ -3146,9 +3256,9 @@ dependencies = [
[[package]]
name = "png"
-version = "0.17.14"
+version = "0.17.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "52f9d46a34a05a6a57566bc2bfae066ef07585a6e3fa30fbbdff5936380623f0"
+checksum = "82151a2fc869e011c153adc57cf2789ccb8d9906ce52c0b39a6b5697749d7526"
dependencies = [
"bitflags 1.3.2",
"crc32fast",
@@ -3159,9 +3269,9 @@ dependencies = [
[[package]]
name = "portable-atomic"
-version = "1.9.0"
+version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2"
+checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e"
[[package]]
name = "powerfmt"
@@ -3171,31 +3281,21 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
[[package]]
name = "ppv-lite86"
-version = "0.2.20"
+version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04"
+checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
dependencies = [
- "zerocopy",
-]
-
-[[package]]
-name = "pretty_assertions"
-version = "1.4.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d"
-dependencies = [
- "diff",
- "yansi",
+ "zerocopy 0.8.24",
]
[[package]]
name = "prettyplease"
-version = "0.2.25"
+version = "0.2.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033"
+checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb"
dependencies = [
"proc-macro2",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -3224,9 +3324,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
-version = "1.0.92"
+version = "1.0.94"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0"
+checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84"
dependencies = [
"unicode-ident",
]
@@ -3247,7 +3347,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a65f2e60fbf1063868558d69c6beacf412dc755f9fc020f514b7955fc914fe30"
dependencies = [
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -3287,7 +3387,7 @@ dependencies = [
"prost 0.12.6",
"prost-types",
"regex",
- "syn 2.0.89",
+ "syn 2.0.100",
"tempfile",
]
@@ -3314,7 +3414,7 @@ dependencies = [
"itertools 0.12.1",
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -3373,7 +3473,7 @@ dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -3386,7 +3486,7 @@ dependencies = [
"proc-macro2",
"pyo3-build-config",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -3400,15 +3500,15 @@ dependencies = [
[[package]]
name = "quanta"
-version = "0.12.3"
+version = "0.12.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5"
+checksum = "3bd1fe6824cea6538803de3ff1bc0cf3949024db3d43c9643024bfb33a807c0e"
dependencies = [
"crossbeam-utils",
"libc",
"once_cell",
"raw-cpuid",
- "wasi",
+ "wasi 0.11.0+wasi-snapshot-preview1",
"web-sys",
"winapi",
]
@@ -3421,13 +3521,19 @@ checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
[[package]]
name = "quote"
-version = "1.0.37"
+version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
+checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
dependencies = [
"proc-macro2",
]
+[[package]]
+name = "r-efi"
+version = "5.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5"
+
[[package]]
name = "rand"
version = "0.8.5"
@@ -3435,8 +3541,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
- "rand_chacha",
- "rand_core",
+ "rand_chacha 0.3.1",
+ "rand_core 0.6.4",
+]
+
+[[package]]
+name = "rand"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94"
+dependencies = [
+ "rand_chacha 0.9.0",
+ "rand_core 0.9.3",
+ "zerocopy 0.8.24",
]
[[package]]
@@ -3446,7 +3563,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
- "rand_core",
+ "rand_core 0.6.4",
+]
+
+[[package]]
+name = "rand_chacha"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
+dependencies = [
+ "ppv-lite86",
+ "rand_core 0.9.3",
]
[[package]]
@@ -3455,7 +3582,16 @@ version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
- "getrandom",
+ "getrandom 0.2.15",
+]
+
+[[package]]
+name = "rand_core"
+version = "0.9.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
+dependencies = [
+ "getrandom 0.3.2",
]
[[package]]
@@ -3464,7 +3600,7 @@ version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdef7f9be5c0122f890d58bdf4d964349ba6a6161f705907526d891efabba57d"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"cassowary",
"compact_str",
"crossterm",
@@ -3505,11 +3641,11 @@ dependencies = [
"once_cell",
"paste",
"profiling",
- "rand",
- "rand_chacha",
+ "rand 0.8.5",
+ "rand_chacha 0.3.1",
"simd_helpers",
"system-deps",
- "thiserror",
+ "thiserror 1.0.69",
"v_frame",
"wasm-bindgen",
]
@@ -3531,11 +3667,11 @@ dependencies = [
[[package]]
name = "raw-cpuid"
-version = "11.2.0"
+version = "11.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0"
+checksum = "c6df7ab838ed27997ba19a4664507e6f82b41fe6e20be42929332156e5e85146"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
]
[[package]]
@@ -3571,11 +3707,11 @@ dependencies = [
[[package]]
name = "redox_syscall"
-version = "0.5.7"
+version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f"
+checksum = "0b8c0c260b63a8219631167be35e6a988e9554dbd323f8bd08439c8ed1302bd1"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
]
[[package]]
@@ -3584,9 +3720,42 @@ version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
dependencies = [
- "getrandom",
+ "getrandom 0.2.15",
"libredox",
- "thiserror",
+ "thiserror 1.0.69",
+]
+
+[[package]]
+name = "ref-cast"
+version = "1.0.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf"
+dependencies = [
+ "ref-cast-impl",
+]
+
+[[package]]
+name = "ref-cast-impl"
+version = "1.0.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.100",
+]
+
+[[package]]
+name = "referencing"
+version = "0.28.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d0dcb5ab28989ad7c91eb1b9531a37a1a137cc69a0499aee4117cae4a107c464"
+dependencies = [
+ "ahash",
+ "fluent-uri",
+ "once_cell",
+ "percent-encoding",
+ "serde_json",
]
[[package]]
@@ -3647,8 +3816,8 @@ dependencies = [
"h2 0.3.26",
"http 0.2.12",
"http-body 0.4.6",
- "hyper 0.14.31",
- "hyper-tls",
+ "hyper 0.14.32",
+ "hyper-tls 0.5.0",
"ipnet",
"js-sys",
"log",
@@ -3657,12 +3826,12 @@ dependencies = [
"once_cell",
"percent-encoding",
"pin-project-lite",
- "rustls-pemfile",
+ "rustls-pemfile 1.0.4",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper 0.1.2",
- "system-configuration",
+ "system-configuration 0.5.1",
"tokio",
"tokio-native-tls",
"tower-service",
@@ -3673,6 +3842,53 @@ dependencies = [
"winreg",
]
+[[package]]
+name = "reqwest"
+version = "0.12.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb"
+dependencies = [
+ "base64 0.22.1",
+ "bytes",
+ "encoding_rs",
+ "futures-channel",
+ "futures-core",
+ "futures-util",
+ "h2 0.4.8",
+ "http 1.3.1",
+ "http-body 1.0.1",
+ "http-body-util",
+ "hyper 1.6.0",
+ "hyper-rustls",
+ "hyper-tls 0.6.0",
+ "hyper-util",
+ "ipnet",
+ "js-sys",
+ "log",
+ "mime",
+ "native-tls",
+ "once_cell",
+ "percent-encoding",
+ "pin-project-lite",
+ "rustls-pemfile 2.2.0",
+ "serde",
+ "serde_json",
+ "serde_urlencoded",
+ "sync_wrapper 1.0.2",
+ "system-configuration 0.6.1",
+ "tokio",
+ "tokio-native-tls",
+ "tokio-util",
+ "tower 0.5.2",
+ "tower-service",
+ "url",
+ "wasm-bindgen",
+ "wasm-bindgen-futures",
+ "wasm-streams",
+ "web-sys",
+ "windows-registry",
+]
+
[[package]]
name = "rgb"
version = "0.8.50"
@@ -3688,7 +3904,7 @@ dependencies = [
"cc",
"libc",
"once_cell",
- "spin 0.5.2",
+ "spin",
"untrusted 0.7.1",
"web-sys",
"winapi",
@@ -3696,24 +3912,23 @@ dependencies = [
[[package]]
name = "ring"
-version = "0.17.8"
+version = "0.17.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
+checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7"
dependencies = [
"cc",
"cfg-if",
- "getrandom",
+ "getrandom 0.2.15",
"libc",
- "spin 0.9.8",
"untrusted 0.9.0",
"windows-sys 0.52.0",
]
[[package]]
name = "rust-embed"
-version = "8.5.0"
+version = "8.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fa66af4a4fdd5e7ebc276f115e895611a34739a9c1c01028383d612d550953c0"
+checksum = "0b3aba5104622db5c9fc61098de54708feb732e7763d7faa2fa625899f00bf6f"
dependencies = [
"rust-embed-impl",
"rust-embed-utils",
@@ -3722,22 +3937,22 @@ dependencies = [
[[package]]
name = "rust-embed-impl"
-version = "8.5.0"
+version = "8.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6125dbc8867951125eec87294137f4e9c2c96566e61bf72c45095a7c77761478"
+checksum = "1f198c73be048d2c5aa8e12f7960ad08443e56fd39cc26336719fdb4ea0ebaae"
dependencies = [
"proc-macro2",
"quote",
"rust-embed-utils",
- "syn 2.0.89",
+ "syn 2.0.100",
"walkdir",
]
[[package]]
name = "rust-embed-utils"
-version = "8.5.0"
+version = "8.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2e5347777e9aacb56039b0e1f28785929a8a3b709e87482e7442c72e7c12529d"
+checksum = "5a2fcdc9f40c8dc2922842ca9add611ad19f332227fc651d015881ad1552bd9a"
dependencies = [
"sha2",
"walkdir",
@@ -3755,6 +3970,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
+[[package]]
+name = "rustc-hash"
+version = "2.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
+
[[package]]
name = "rustc_version"
version = "0.4.1"
@@ -3766,15 +3987,28 @@ dependencies = [
[[package]]
name = "rustix"
-version = "0.38.41"
+version = "0.38.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6"
+checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"errno",
"libc",
- "linux-raw-sys",
- "windows-sys 0.52.0",
+ "linux-raw-sys 0.4.15",
+ "windows-sys 0.59.0",
+]
+
+[[package]]
+name = "rustix"
+version = "1.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e56a18552996ac8d29ecc3b190b4fdbb2d91ca4ec396de7bbffaf43f3d637e96"
+dependencies = [
+ "bitflags 2.9.0",
+ "errno",
+ "libc",
+ "linux-raw-sys 0.9.3",
+ "windows-sys 0.59.0",
]
[[package]]
@@ -3796,24 +4030,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432"
dependencies = [
"log",
- "ring 0.17.8",
+ "ring 0.17.14",
"rustls-pki-types",
- "rustls-webpki",
+ "rustls-webpki 0.102.8",
"subtle",
"zeroize",
]
[[package]]
name = "rustls"
-version = "0.23.17"
+version = "0.23.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7f1a745511c54ba6d4465e8d5dfbd81b45791756de28d4981af70d6dca128f1e"
+checksum = "822ee9188ac4ec04a2f0531e55d035fb2de73f18b41a63c70c2712503b6fb13c"
dependencies = [
"aws-lc-rs",
"log",
"once_cell",
"rustls-pki-types",
- "rustls-webpki",
+ "rustls-webpki 0.103.0",
"subtle",
"zeroize",
]
@@ -3827,7 +4061,7 @@ dependencies = [
"openssl-probe",
"rustls-pki-types",
"schannel",
- "security-framework 3.0.1",
+ "security-framework 3.2.0",
]
[[package]]
@@ -3839,35 +4073,55 @@ dependencies = [
"base64 0.21.7",
]
+[[package]]
+name = "rustls-pemfile"
+version = "2.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50"
+dependencies = [
+ "rustls-pki-types",
+]
+
[[package]]
name = "rustls-pki-types"
-version = "1.10.0"
+version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b"
+checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c"
[[package]]
name = "rustls-webpki"
version = "0.102.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
+dependencies = [
+ "ring 0.17.14",
+ "rustls-pki-types",
+ "untrusted 0.9.0",
+]
+
+[[package]]
+name = "rustls-webpki"
+version = "0.103.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0aa4eeac2588ffff23e9d7a7e9b3f971c5fb5b7ebc9452745e0c232c64f83b2f"
dependencies = [
"aws-lc-rs",
- "ring 0.17.8",
+ "ring 0.17.14",
"rustls-pki-types",
"untrusted 0.9.0",
]
[[package]]
name = "rustversion"
-version = "1.0.18"
+version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248"
+checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2"
[[package]]
name = "ryu"
-version = "1.0.18"
+version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
+checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
[[package]]
name = "same-file"
@@ -3895,9 +4149,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "scratch"
-version = "1.0.7"
+version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152"
+checksum = "9f6280af86e5f559536da57a45ebc84948833b3bee313a7dd25232e09c878a52"
[[package]]
name = "sct"
@@ -3905,7 +4159,7 @@ version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
dependencies = [
- "ring 0.17.8",
+ "ring 0.17.14",
"untrusted 0.9.0",
]
@@ -3915,7 +4169,7 @@ version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"core-foundation 0.9.4",
"core-foundation-sys",
"libc",
@@ -3924,11 +4178,11 @@ dependencies = [
[[package]]
name = "security-framework"
-version = "3.0.1"
+version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e1415a607e92bec364ea2cf9264646dcce0f91e6d65281bd6f2819cca3bf39c8"
+checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"core-foundation 0.10.0",
"core-foundation-sys",
"libc",
@@ -3937,9 +4191,9 @@ dependencies = [
[[package]]
name = "security-framework-sys"
-version = "2.12.1"
+version = "2.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fa39c7303dc58b5543c94d22c1766b0d31f2ee58306363ea622b10bbc075eaa2"
+checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32"
dependencies = [
"core-foundation-sys",
"libc",
@@ -3947,18 +4201,18 @@ dependencies = [
[[package]]
name = "semver"
-version = "1.0.23"
+version = "1.0.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
+checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0"
dependencies = [
"serde",
]
[[package]]
name = "serde"
-version = "1.0.215"
+version = "1.0.219"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f"
+checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
dependencies = [
"serde_derive",
]
@@ -3985,22 +4239,22 @@ dependencies = [
[[package]]
name = "serde_derive"
-version = "1.0.215"
+version = "1.0.219"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0"
+checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
name = "serde_json"
-version = "1.0.133"
+version = "1.0.140"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377"
+checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
dependencies = [
- "indexmap 2.6.0",
+ "indexmap 2.8.0",
"itoa",
"memchr",
"ryu",
@@ -4009,9 +4263,9 @@ dependencies = [
[[package]]
name = "serde_path_to_error"
-version = "0.1.16"
+version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6"
+checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a"
dependencies = [
"itoa",
"serde",
@@ -4135,31 +4389,36 @@ dependencies = [
[[package]]
name = "smallvec"
-version = "1.13.2"
+version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
+checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd"
[[package]]
name = "socket2"
-version = "0.5.7"
+version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c"
+checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8"
dependencies = [
"libc",
"windows-sys 0.52.0",
]
[[package]]
-name = "spin"
-version = "0.5.2"
+name = "socks"
+version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
+checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b"
+dependencies = [
+ "byteorder",
+ "libc",
+ "winapi",
+]
[[package]]
name = "spin"
-version = "0.9.8"
+version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
+checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]]
name = "spm_precompiled"
@@ -4210,7 +4469,7 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -4232,9 +4491,9 @@ dependencies = [
[[package]]
name = "syn"
-version = "2.0.89"
+version = "2.0.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e"
+checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0"
dependencies = [
"proc-macro2",
"quote",
@@ -4252,6 +4511,9 @@ name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
+dependencies = [
+ "futures-core",
+]
[[package]]
name = "synstructure"
@@ -4261,7 +4523,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -4287,7 +4549,18 @@ checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
dependencies = [
"bitflags 1.3.2",
"core-foundation 0.9.4",
- "system-configuration-sys",
+ "system-configuration-sys 0.5.0",
+]
+
+[[package]]
+name = "system-configuration"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b"
+dependencies = [
+ "bitflags 2.9.0",
+ "core-foundation 0.9.4",
+ "system-configuration-sys 0.6.0",
]
[[package]]
@@ -4300,6 +4573,16 @@ dependencies = [
"libc",
]
+[[package]]
+name = "system-configuration-sys"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
+dependencies = [
+ "core-foundation-sys",
+ "libc",
+]
+
[[package]]
name = "system-deps"
version = "6.2.2"
@@ -4345,14 +4628,14 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]]
name = "tempfile"
-version = "3.14.0"
+version = "3.19.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c"
+checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf"
dependencies = [
- "cfg-if",
"fastrand",
+ "getrandom 0.3.2",
"once_cell",
- "rustix",
+ "rustix 1.0.3",
"windows-sys 0.59.0",
]
@@ -4367,42 +4650,39 @@ dependencies = [
[[package]]
name = "text-generation-backends-trtllm"
-version = "3.0.1-dev0"
+version = "3.3.6-dev0"
dependencies = [
- "async-stream",
"async-trait",
- "clap 4.5.21",
+ "clap 4.5.32",
"cmake",
"cxx",
"cxx-build",
- "hashbrown 0.14.5",
- "hf-hub",
- "log",
+ "hashbrown 0.15.2",
+ "hf-hub 0.4.2",
"pkg-config",
+ "pyo3",
"text-generation-router",
- "thiserror",
+ "thiserror 1.0.69",
"tokenizers",
"tokio",
"tokio-stream",
"tracing",
- "tracing-opentelemetry 0.25.0",
- "tracing-subscriber",
]
[[package]]
name = "text-generation-benchmark"
-version = "3.0.1-dev0"
+version = "3.3.6-dev0"
dependencies = [
"average",
- "clap 4.5.21",
+ "clap 4.5.32",
"float-ord",
- "hf-hub",
+ "hf-hub 0.4.2",
"ratatui",
"serde",
"serde_json",
"tabled",
"text-generation-client",
- "thiserror",
+ "thiserror 1.0.69",
"tokenizers",
"tokio",
"tracing",
@@ -4411,7 +4691,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
-version = "3.0.1-dev0"
+version = "3.3.6-dev0"
dependencies = [
"async-trait",
"base64 0.22.1",
@@ -4419,7 +4699,7 @@ dependencies = [
"grpc-metadata",
"prost 0.12.6",
"prost-build",
- "thiserror",
+ "thiserror 1.0.69",
"tokio",
"tonic 0.10.2",
"tonic-build",
@@ -4429,20 +4709,20 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
-version = "3.0.1-dev0"
+version = "3.3.6-dev0"
dependencies = [
- "clap 4.5.21",
+ "clap 4.5.32",
"ctrlc",
"float_eq",
- "hf-hub",
+ "hf-hub 0.4.2",
"nix 0.28.0",
"once_cell",
"pyo3",
"regex",
- "reqwest",
+ "reqwest 0.11.27",
"serde",
"serde_json",
- "thiserror",
+ "thiserror 1.0.69",
"tracing",
"tracing-subscriber",
"vergen",
@@ -4450,7 +4730,7 @@ dependencies = [
[[package]]
name = "text-generation-router"
-version = "3.0.1-dev0"
+version = "3.3.6-dev0"
dependencies = [
"anyhow",
"async-stream",
@@ -4458,11 +4738,12 @@ dependencies = [
"axum 0.7.9",
"axum-tracing-opentelemetry",
"base64 0.22.1",
- "clap 4.5.21",
+ "chrono",
+ "clap 4.5.32",
"csv",
"futures",
"futures-util",
- "hf-hub",
+ "hf-hub 0.4.2",
"image",
"init-tracing-opentelemetry",
"itertools 0.10.5",
@@ -4478,13 +4759,13 @@ dependencies = [
"opentelemetry-otlp",
"outlines-core",
"pyo3",
- "rand",
+ "rand 0.8.5",
"regex",
- "reqwest",
+ "reqwest 0.11.27",
"serde",
"serde_json",
"sysinfo",
- "thiserror",
+ "thiserror 1.0.69",
"tokenizers",
"tokio",
"tokio-stream",
@@ -4499,20 +4780,38 @@ dependencies = [
"vergen",
]
+[[package]]
+name = "text-generation-router-llamacpp"
+version = "3.3.6-dev0"
+dependencies = [
+ "async-trait",
+ "bindgen 0.71.1",
+ "clap 4.5.32",
+ "hf-hub 0.4.2",
+ "num_cpus",
+ "pkg-config",
+ "text-generation-router",
+ "thiserror 2.0.12",
+ "tokenizers",
+ "tokio",
+ "tokio-stream",
+ "tracing",
+]
+
[[package]]
name = "text-generation-router-v2"
-version = "3.0.1-dev0"
+version = "3.3.6-dev0"
dependencies = [
"async-stream",
"async-trait",
"axum 0.7.9",
"axum-tracing-opentelemetry",
"base64 0.22.1",
- "clap 4.5.21",
+ "clap 4.5.32",
"futures",
"futures-util",
"grpc-metadata",
- "hf-hub",
+ "hf-hub 0.4.2",
"image",
"init-tracing-opentelemetry",
"jsonschema",
@@ -4526,14 +4825,14 @@ dependencies = [
"opentelemetry-otlp",
"prost 0.12.6",
"prost-build",
- "rand",
+ "rand 0.8.5",
"regex",
- "reqwest",
+ "reqwest 0.11.27",
"serde",
"serde_json",
"slotmap",
"text-generation-router",
- "thiserror",
+ "thiserror 1.0.69",
"tokenizers",
"tokio",
"tokio-stream",
@@ -4550,19 +4849,19 @@ dependencies = [
[[package]]
name = "text-generation-router-v3"
-version = "3.0.1-dev0"
+version = "3.3.6-dev0"
dependencies = [
"async-stream",
"async-trait",
"axum 0.7.9",
"axum-tracing-opentelemetry",
"base64 0.22.1",
- "clap 4.5.21",
+ "clap 4.5.32",
"criterion",
"futures",
"futures-util",
"grpc-metadata",
- "hf-hub",
+ "hf-hub 0.4.2",
"image",
"init-tracing-opentelemetry",
"itertools 0.13.0",
@@ -4577,14 +4876,15 @@ dependencies = [
"opentelemetry-otlp",
"prost 0.12.6",
"prost-build",
- "rand",
+ "rand 0.8.5",
"regex",
- "reqwest",
+ "reqwest 0.11.27",
+ "rustc-hash 2.1.1",
"serde",
"serde_json",
"slotmap",
"text-generation-router",
- "thiserror",
+ "thiserror 1.0.69",
"tokenizers",
"tokio",
"tokio-stream",
@@ -4614,7 +4914,16 @@ version = "1.0.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
dependencies = [
- "thiserror-impl",
+ "thiserror-impl 1.0.69",
+]
+
+[[package]]
+name = "thiserror"
+version = "2.0.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708"
+dependencies = [
+ "thiserror-impl 2.0.12",
]
[[package]]
@@ -4625,7 +4934,18 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
+]
+
+[[package]]
+name = "thiserror-impl"
+version = "2.0.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.100",
]
[[package]]
@@ -4651,9 +4971,9 @@ dependencies = [
[[package]]
name = "time"
-version = "0.3.36"
+version = "0.3.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885"
+checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40"
dependencies = [
"deranged",
"itoa",
@@ -4668,15 +4988,15 @@ dependencies = [
[[package]]
name = "time-core"
-version = "0.1.2"
+version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3"
+checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c"
[[package]]
name = "time-macros"
-version = "0.2.18"
+version = "0.2.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf"
+checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49"
dependencies = [
"num-conv",
"time-core",
@@ -4704,15 +5024,15 @@ dependencies = [
[[package]]
name = "tokenizers"
-version = "0.20.3"
+version = "0.20.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "67b67c92f6d705e2a1d106fb0b28c696f9074901a9c656ee5d9f5de204c39bf7"
+checksum = "3b08cc37428a476fc9e20ac850132a513a2e1ce32b6a31addf2b74fa7033b905"
dependencies = [
"aho-corasick",
"derive_builder",
"esaxx-rs",
- "getrandom",
- "hf-hub",
+ "getrandom 0.2.15",
+ "hf-hub 0.3.2",
"indicatif",
"itertools 0.12.1",
"lazy_static",
@@ -4721,7 +5041,7 @@ dependencies = [
"monostate",
"onig",
"paste",
- "rand",
+ "rand 0.8.5",
"rayon",
"rayon-cond",
"regex",
@@ -4729,7 +5049,7 @@ dependencies = [
"serde",
"serde_json",
"spm_precompiled",
- "thiserror",
+ "thiserror 1.0.69",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
@@ -4737,9 +5057,9 @@ dependencies = [
[[package]]
name = "tokio"
-version = "1.41.1"
+version = "1.44.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33"
+checksum = "f382da615b842244d4b8738c82ed1275e6c5dd90c459a30941cd07080b06c91a"
dependencies = [
"backtrace",
"bytes",
@@ -4765,13 +5085,13 @@ dependencies = [
[[package]]
name = "tokio-macros"
-version = "2.4.0"
+version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
+checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -4791,26 +5111,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f"
dependencies = [
"pin-project",
- "rand",
+ "rand 0.8.5",
"tokio",
]
[[package]]
name = "tokio-rustls"
-version = "0.26.0"
+version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4"
+checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b"
dependencies = [
- "rustls 0.23.17",
- "rustls-pki-types",
+ "rustls 0.23.25",
"tokio",
]
[[package]]
name = "tokio-stream"
-version = "0.1.16"
+version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1"
+checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047"
dependencies = [
"futures-core",
"pin-project-lite",
@@ -4819,9 +5138,9 @@ dependencies = [
[[package]]
name = "tokio-util"
-version = "0.7.12"
+version = "0.7.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a"
+checksum = "6b9590b93e6fcc1739458317cccd391ad3955e2bde8913edf6f95f9e65a8f034"
dependencies = [
"bytes",
"futures-core",
@@ -4833,9 +5152,9 @@ dependencies = [
[[package]]
name = "toml"
-version = "0.8.19"
+version = "0.8.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e"
+checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148"
dependencies = [
"serde",
"serde_spanned",
@@ -4854,11 +5173,11 @@ dependencies = [
[[package]]
name = "toml_edit"
-version = "0.22.22"
+version = "0.22.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5"
+checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474"
dependencies = [
- "indexmap 2.6.0",
+ "indexmap 2.8.0",
"serde",
"serde_spanned",
"toml_datetime",
@@ -4880,7 +5199,7 @@ dependencies = [
"h2 0.3.26",
"http 0.2.12",
"http-body 0.4.6",
- "hyper 0.14.31",
+ "hyper 0.14.32",
"hyper-timeout",
"percent-encoding",
"pin-project",
@@ -4907,7 +5226,7 @@ dependencies = [
"h2 0.3.26",
"http 0.2.12",
"http-body 0.4.6",
- "hyper 0.14.31",
+ "hyper 0.14.32",
"hyper-timeout",
"percent-encoding",
"pin-project",
@@ -4930,7 +5249,7 @@ dependencies = [
"proc-macro2",
"prost-build",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -4944,7 +5263,7 @@ dependencies = [
"indexmap 1.9.3",
"pin-project",
"pin-project-lite",
- "rand",
+ "rand 0.8.5",
"slab",
"tokio",
"tokio-util",
@@ -4955,14 +5274,14 @@ dependencies = [
[[package]]
name = "tower"
-version = "0.5.1"
+version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f"
+checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9"
dependencies = [
"futures-core",
"futures-util",
"pin-project-lite",
- "sync_wrapper 0.1.2",
+ "sync_wrapper 1.0.2",
"tokio",
"tower-layer",
"tower-service",
@@ -4975,9 +5294,9 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"bytes",
- "http 1.1.0",
+ "http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"pin-project-lite",
@@ -4999,9 +5318,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
name = "tracing"
-version = "0.1.40"
+version = "0.1.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
+checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
dependencies = [
"log",
"pin-project-lite",
@@ -5011,20 +5330,20 @@ dependencies = [
[[package]]
name = "tracing-attributes"
-version = "0.1.27"
+version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
+checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
name = "tracing-core"
-version = "0.1.32"
+version = "0.1.33"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
+checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c"
dependencies = [
"once_cell",
"valuable",
@@ -5086,31 +5405,13 @@ dependencies = [
"web-time 0.2.4",
]
-[[package]]
-name = "tracing-opentelemetry"
-version = "0.25.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b"
-dependencies = [
- "js-sys",
- "once_cell",
- "opentelemetry 0.24.0",
- "opentelemetry_sdk 0.24.1",
- "smallvec",
- "tracing",
- "tracing-core",
- "tracing-log 0.2.0",
- "tracing-subscriber",
- "web-time 1.1.0",
-]
-
[[package]]
name = "tracing-opentelemetry-instrumentation-sdk"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9920abb6a3ee3a2af7d30c9ff02900f8481935d36723c3da95cf807468218e8c"
dependencies = [
- "http 1.1.0",
+ "http 1.3.1",
"opentelemetry 0.21.0",
"tracing",
"tracing-opentelemetry 0.22.0",
@@ -5118,9 +5419,9 @@ dependencies = [
[[package]]
name = "tracing-serde"
-version = "0.1.3"
+version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1"
+checksum = "704b1aeb7be0d0a84fc9828cae51dab5970fee5088f83d1dd7ee6f6246fc6ff1"
dependencies = [
"serde",
"tracing-core",
@@ -5128,9 +5429,9 @@ dependencies = [
[[package]]
name = "tracing-subscriber"
-version = "0.3.18"
+version = "0.3.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
+checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008"
dependencies = [
"matchers",
"nu-ansi-term",
@@ -5155,21 +5456,21 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "typenum"
-version = "1.17.0"
+version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
+checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f"
[[package]]
name = "unicase"
-version = "2.8.0"
+version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7e51b68083f157f853b6379db119d1c1be0e6e4dec98101079dec41f6f5cf6df"
+checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-ident"
-version = "1.0.14"
+version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83"
+checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
[[package]]
name = "unicode-normalization-alignments"
@@ -5217,9 +5518,9 @@ checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unindent"
-version = "0.2.3"
+version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
+checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
[[package]]
name = "untrusted"
@@ -5246,18 +5547,19 @@ dependencies = [
"once_cell",
"rustls 0.22.4",
"rustls-pki-types",
- "rustls-webpki",
+ "rustls-webpki 0.102.8",
"serde",
"serde_json",
+ "socks",
"url",
"webpki-roots",
]
[[package]]
name = "url"
-version = "2.5.3"
+version = "2.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8d157f1b96d14500ffdc1f10ba712e780825526c03d9a49b4d0324b0d9113ada"
+checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60"
dependencies = [
"form_urlencoded",
"idna",
@@ -5294,7 +5596,7 @@ version = "4.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23"
dependencies = [
- "indexmap 2.6.0",
+ "indexmap 2.8.0",
"serde",
"serde_json",
"utoipa-gen",
@@ -5310,7 +5612,7 @@ dependencies = [
"proc-macro2",
"quote",
"regex",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -5331,24 +5633,35 @@ dependencies = [
[[package]]
name = "uuid"
-version = "1.11.0"
+version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a"
+checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9"
dependencies = [
- "getrandom",
- "rand",
+ "getrandom 0.3.2",
+ "rand 0.9.0",
"uuid-macro-internal",
]
[[package]]
name = "uuid-macro-internal"
-version = "1.11.0"
+version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6b91f57fe13a38d0ce9e28a03463d8d3c2468ed03d75375110ec71d93b449a08"
+checksum = "72dcd78c4f979627a754f5522cea6e6a25e55139056535fe6e69c506cd64a862"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
+]
+
+[[package]]
+name = "uuid-simd"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "23b082222b4f6619906941c17eb2297fff4c2fb96cb60164170522942a200bd8"
+dependencies = [
+ "outref",
+ "uuid",
+ "vsimd",
]
[[package]]
@@ -5364,9 +5677,9 @@ dependencies = [
[[package]]
name = "valuable"
-version = "0.1.0"
+version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
+checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]]
name = "vcpkg"
@@ -5402,6 +5715,12 @@ version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
+[[package]]
+name = "vsimd"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64"
+
[[package]]
name = "walkdir"
version = "2.5.0"
@@ -5427,49 +5746,59 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
+[[package]]
+name = "wasi"
+version = "0.14.2+wasi-0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
+dependencies = [
+ "wit-bindgen-rt",
+]
+
[[package]]
name = "wasm-bindgen"
-version = "0.2.95"
+version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e"
+checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5"
dependencies = [
"cfg-if",
"once_cell",
+ "rustversion",
"wasm-bindgen-macro",
]
[[package]]
name = "wasm-bindgen-backend"
-version = "0.2.95"
+version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358"
+checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6"
dependencies = [
"bumpalo",
"log",
- "once_cell",
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
-version = "0.4.45"
+version = "0.4.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b"
+checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61"
dependencies = [
"cfg-if",
"js-sys",
+ "once_cell",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "wasm-bindgen-macro"
-version = "0.2.95"
+version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56"
+checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
@@ -5477,28 +5806,44 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro-support"
-version = "0.2.95"
+version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68"
+checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
-version = "0.2.95"
+version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d"
+checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d"
+dependencies = [
+ "unicode-ident",
+]
+
+[[package]]
+name = "wasm-streams"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
+dependencies = [
+ "futures-util",
+ "js-sys",
+ "wasm-bindgen",
+ "wasm-bindgen-futures",
+ "web-sys",
+]
[[package]]
name = "web-sys"
-version = "0.3.72"
+version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112"
+checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2"
dependencies = [
"js-sys",
"wasm-bindgen",
@@ -5530,15 +5875,15 @@ version = "0.22.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53"
dependencies = [
- "ring 0.17.8",
+ "ring 0.17.14",
"untrusted 0.9.0",
]
[[package]]
name = "webpki-roots"
-version = "0.26.7"
+version = "0.26.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e"
+checksum = "2210b291f7ea53617fbafcc4939f10914214ec15aace5ba62293a668f322c5c9"
dependencies = [
"rustls-pki-types",
]
@@ -5558,7 +5903,7 @@ dependencies = [
"either",
"home",
"once_cell",
- "rustix",
+ "rustix 0.38.44",
]
[[package]]
@@ -5611,6 +5956,41 @@ dependencies = [
"windows-targets 0.52.6",
]
+[[package]]
+name = "windows-link"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38"
+
+[[package]]
+name = "windows-registry"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3"
+dependencies = [
+ "windows-result",
+ "windows-strings",
+ "windows-targets 0.53.0",
+]
+
+[[package]]
+name = "windows-result"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c64fd11a4fd95df68efcfee5f44a294fe71b8bc6a91993e2791938abcc712252"
+dependencies = [
+ "windows-link",
+]
+
+[[package]]
+name = "windows-strings"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319"
+dependencies = [
+ "windows-link",
+]
+
[[package]]
name = "windows-sys"
version = "0.45.0"
@@ -5686,13 +6066,29 @@ dependencies = [
"windows_aarch64_gnullvm 0.52.6",
"windows_aarch64_msvc 0.52.6",
"windows_i686_gnu 0.52.6",
- "windows_i686_gnullvm",
+ "windows_i686_gnullvm 0.52.6",
"windows_i686_msvc 0.52.6",
"windows_x86_64_gnu 0.52.6",
"windows_x86_64_gnullvm 0.52.6",
"windows_x86_64_msvc 0.52.6",
]
+[[package]]
+name = "windows-targets"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b"
+dependencies = [
+ "windows_aarch64_gnullvm 0.53.0",
+ "windows_aarch64_msvc 0.53.0",
+ "windows_i686_gnu 0.53.0",
+ "windows_i686_gnullvm 0.53.0",
+ "windows_i686_msvc 0.53.0",
+ "windows_x86_64_gnu 0.53.0",
+ "windows_x86_64_gnullvm 0.53.0",
+ "windows_x86_64_msvc 0.53.0",
+]
+
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.42.2"
@@ -5711,6 +6107,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
+[[package]]
+name = "windows_aarch64_gnullvm"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764"
+
[[package]]
name = "windows_aarch64_msvc"
version = "0.42.2"
@@ -5729,6 +6131,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
+[[package]]
+name = "windows_aarch64_msvc"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c"
+
[[package]]
name = "windows_i686_gnu"
version = "0.42.2"
@@ -5747,12 +6155,24 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
+[[package]]
+name = "windows_i686_gnu"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3"
+
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
+[[package]]
+name = "windows_i686_gnullvm"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11"
+
[[package]]
name = "windows_i686_msvc"
version = "0.42.2"
@@ -5771,6 +6191,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
+[[package]]
+name = "windows_i686_msvc"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d"
+
[[package]]
name = "windows_x86_64_gnu"
version = "0.42.2"
@@ -5789,6 +6215,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
+[[package]]
+name = "windows_x86_64_gnu"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba"
+
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.42.2"
@@ -5807,6 +6239,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
+[[package]]
+name = "windows_x86_64_gnullvm"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57"
+
[[package]]
name = "windows_x86_64_msvc"
version = "0.42.2"
@@ -5825,11 +6263,17 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
+[[package]]
+name = "windows_x86_64_msvc"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
+
[[package]]
name = "winnow"
-version = "0.6.20"
+version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b"
+checksum = "0e97b544156e9bebe1a0ffbc03484fc1ffe3100cbce3ffb17eac35f7cdd7ab36"
dependencies = [
"memchr",
]
@@ -5844,6 +6288,15 @@ dependencies = [
"windows-sys 0.48.0",
]
+[[package]]
+name = "wit-bindgen-rt"
+version = "0.39.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
+dependencies = [
+ "bitflags 2.9.0",
+]
+
[[package]]
name = "write16"
version = "1.0.0"
@@ -5856,17 +6309,11 @@ version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51"
-[[package]]
-name = "yansi"
-version = "1.0.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
-
[[package]]
name = "yoke"
-version = "0.7.4"
+version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5"
+checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40"
dependencies = [
"serde",
"stable_deref_trait",
@@ -5876,13 +6323,13 @@ dependencies = [
[[package]]
name = "yoke-derive"
-version = "0.7.4"
+version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95"
+checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
"synstructure",
]
@@ -5892,8 +6339,16 @@ version = "0.7.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0"
dependencies = [
- "byteorder",
- "zerocopy-derive",
+ "zerocopy-derive 0.7.35",
+]
+
+[[package]]
+name = "zerocopy"
+version = "0.8.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2586fea28e186957ef732a5f8b3be2da217d65c5969d4b1e17f973ebbe876879"
+dependencies = [
+ "zerocopy-derive 0.8.24",
]
[[package]]
@@ -5904,27 +6359,38 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
+]
+
+[[package]]
+name = "zerocopy-derive"
+version = "0.8.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a996a8f63c5c4448cd959ac1bab0aaa3306ccfd060472f85943ee0750f0169be"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.100",
]
[[package]]
name = "zerofrom"
-version = "0.1.4"
+version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55"
+checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5"
dependencies = [
"zerofrom-derive",
]
[[package]]
name = "zerofrom-derive"
-version = "0.1.4"
+version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5"
+checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
"synstructure",
]
@@ -5953,7 +6419,7 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -5985,9 +6451,9 @@ dependencies = [
[[package]]
name = "zune-jpeg"
-version = "0.4.13"
+version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "16099418600b4d8f028622f73ff6e3deaabdff330fb9a2a131dea781ee8b0768"
+checksum = "99a5bab8d7dedf81405c4bb1f2b83ea057643d9cb28778cea9eecddeedd2e028"
dependencies = [
"zune-core",
]
diff --git a/Cargo.toml b/Cargo.toml
index 806c94a0485..f985d0a1356 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,26 +1,27 @@
[workspace]
members = [
- "benchmark",
- "backends/v2",
- "backends/v3",
- "backends/grpc-metadata",
- "backends/trtllm",
- "launcher",
- "router"
+ "benchmark",
+ "backends/v2",
+ "backends/v3",
+ "backends/grpc-metadata",
+ "backends/trtllm",
+ "backends/llamacpp",
+ "launcher",
+ "router"
]
default-members = [
- "benchmark",
- "backends/v2",
- "backends/v3",
- "backends/grpc-metadata",
- # "backends/trtllm",
- "launcher",
- "router"
+ "benchmark",
+ "backends/v2",
+ "backends/v3",
+ "backends/grpc-metadata",
+ # "backends/trtllm",
+ "launcher",
+ "router"
]
resolver = "2"
[workspace.package]
-version = "3.0.1-dev0"
+version = "3.3.6-dev0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"
@@ -28,7 +29,7 @@ homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies]
base64 = "0.22.0"
tokenizers = { version = "0.20.0", features = ["http"] }
-hf-hub = { version = "0.3.1", features = ["tokio"] }
+hf-hub = { version = "0.4.2", features = ["tokio"] }
metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
minijinja = { version = "2.2.0", features = ["json"] }
diff --git a/Dockerfile b/Dockerfile
index 0c08d48f6e4..869596d0f21 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,5 @@
# Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -45,21 +45,16 @@ RUN cargo build --profile release-opt --frozen
# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
+WORKDIR /usr/src/
# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
-ARG PYTORCH_VERSION=2.4.0
-
+ARG PYTORCH_VERSION=2.7
ARG PYTHON_VERSION=3.11
+
# Keep in sync with `server/pyproject.toml
-ARG CUDA_VERSION=12.4
-ARG MAMBA_VERSION=24.3.0-0
-ARG CUDA_CHANNEL=nvidia
-ARG INSTALL_CHANNEL=pytorch
# Automatically set by buildx
ARG TARGETPLATFORM
-ENV PATH /opt/conda/bin:$PATH
-
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
ca-certificates \
@@ -67,26 +62,12 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \
git && \
rm -rf /var/lib/apt/lists/*
-
-# Install conda
-# translating Docker's TARGETPLATFORM into mamba arches
-RUN case ${TARGETPLATFORM} in \
- "linux/arm64") MAMBA_ARCH=aarch64 ;; \
- *) MAMBA_ARCH=x86_64 ;; \
- esac && \
- curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/$%7BMAMBA_VERSION%7D/Mambaforge-$%7BMAMBA_VERSION%7D-Linux-$%7BMAMBA_ARCH%7D.sh"
-RUN chmod +x ~/mambaforge.sh && \
- bash ~/mambaforge.sh -b -p /opt/conda && \
- rm ~/mambaforge.sh
-
-# Install pytorch
-# On arm64 we exit with an error code
-RUN case ${TARGETPLATFORM} in \
- "linux/arm64") exit 1 ;; \
- *) /opt/conda/bin/conda update -y conda && \
- /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
- esac && \
- /opt/conda/bin/conda clean -ya
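+# Install uv and set up a Python virtual environment with PyTorch (replaces the previous conda-based setup)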
+COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
+ENV PATH="$PATH:/root/.local/bin"
+RUN uv python install ${PYTHON_VERSION}
+RUN uv venv --python ${PYTHON_VERSION} && uv pip install torch==${PYTORCH_VERSION} torchvision pip setuptools packaging
+ENV VIRTUAL_ENV=/usr/src/.venv/
+ENV PATH="$PATH:/usr/src/.venv/bin/"
# CUDA kernels builder image
FROM pytorch-install AS kernel-builder
@@ -106,7 +87,7 @@ WORKDIR /usr/src
COPY server/Makefile-flash-att Makefile
# Build specific version of flash attention
-RUN make build-flash-attention
+RUN . .venv/bin/activate && make build-flash-attention
# Build Flash Attention v2 CUDA kernels
FROM kernel-builder AS flash-att-v2-builder
@@ -116,14 +97,14 @@ WORKDIR /usr/src
COPY server/Makefile-flash-att-v2 Makefile
# Build specific version of flash attention v2
-RUN make build-flash-attention-v2-cuda
+RUN . .venv/bin/activate && make build-flash-attention-v2-cuda
# Build Transformers exllama kernels
FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .
-RUN python setup.py build
+RUN . .venv/bin/activate && python setup.py build
# Build Transformers exllama kernels
FROM kernel-builder AS exllamav2-kernels-builder
@@ -131,54 +112,36 @@ WORKDIR /usr/src
COPY server/Makefile-exllamav2/ Makefile
# Build specific version of transformers
-RUN make build-exllamav2
+RUN . .venv/bin/activate && make build-exllamav2
# Build Transformers awq kernels
FROM kernel-builder AS awq-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-awq Makefile
# Build specific version of transformers
-RUN make build-awq
-
-# Build eetq kernels
-FROM kernel-builder AS eetq-kernels-builder
-WORKDIR /usr/src
-COPY server/Makefile-eetq Makefile
-# Build specific version of transformers
-RUN make build-eetq
-
-# Build Lorax Punica kernels
-FROM kernel-builder AS lorax-punica-builder
-WORKDIR /usr/src
-COPY server/Makefile-lorax-punica Makefile
-# Build specific version of transformers
-RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
+RUN . .venv/bin/activate && make build-awq
# Build Transformers CUDA kernels
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
# Build specific version of transformers
-RUN python setup.py build
+RUN . .venv/bin/activate && python setup.py build
# Build mamba kernels
FROM kernel-builder AS mamba-builder
WORKDIR /usr/src
COPY server/Makefile-selective-scan Makefile
-RUN make build-all
+RUN . .venv/bin/activate && make build-all
# Build flashinfer
FROM kernel-builder AS flashinfer-builder
WORKDIR /usr/src
COPY server/Makefile-flashinfer Makefile
-RUN make install-flashinfer
+RUN . .venv/bin/activate && make install-flashinfer
# Text Generation Inference base image
-FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base
-
-# Conda env
-ENV PATH=/opt/conda/bin:$PATH \
- CONDA_PREFIX=/opt/conda
+FROM nvidia/cuda:12.4.0-base-ubuntu22.04 AS base
# Text Generation Inference base env
ENV HF_HOME=/data \
@@ -195,50 +158,59 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
git \
&& rm -rf /var/lib/apt/lists/*
-# Copy conda with PyTorch installed
-COPY --from=pytorch-install /opt/conda /opt/conda
+# RUN curl -LsSf https://astral.sh/uv/install.sh | sh
+# ENV PATH="$PATH:/root/.local/bin"
+COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
+# Install flash-attention dependencies
+# RUN pip install einops --no-cache-dir
+
+# Copy env with PyTorch installed
+COPY --from=pytorch-install /usr/src/.venv /usr/src/.venv
+ENV PYTHON_VERSION=3.11
+RUN uv python install ${PYTHON_VERSION}
+ENV VIRTUAL_ENV=/usr/src/.venv/
+ENV PATH="$PATH:/usr/src/.venv/bin/"
+
+# Install server
+COPY proto proto
+COPY server server
+COPY server/Makefile server/Makefile
+ENV HF_KERNELS_CACHE=/kernels
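+# Resolve and install dependencies first (without the project) so kernels can be pre-downloaded; the second sync below installs the server itself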
+RUN cd server && \
+ uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --no-install-project --active && \
+ make gen-server-raw && \
+ kernels download .
+
+RUN cd server && \
+ uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --active --python=${PYTHON_VERSION} && \
+ uv pip install nvidia-nccl-cu12==2.25.1 && \
+ pwd && \
+ text-generation-server --help
# Copy build artifacts from flash attention builder
-COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
+COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
+COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from flash attention v2 builder
-COPY --from=flash-att-v2-builder /opt/conda/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /opt/conda/lib/python3.11/site-packages
+COPY --from=flash-att-v2-builder /usr/src/.venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from custom kernels builder
-COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from exllama kernels builder
-COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from exllamav2 kernels builder
-COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from awq kernels builder
-COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-# Copy build artifacts from eetq kernels builder
-COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-# Copy build artifacts from lorax punica kernels builder
-COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from mamba builder
-COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
-COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
-COPY --from=flashinfer-builder /opt/conda/lib/python3.11/site-packages/flashinfer/ /opt/conda/lib/python3.11/site-packages/flashinfer/
+COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
+COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
+COPY --from=flashinfer-builder /usr/src/.venv/lib/python3.11/site-packages/flashinfer/ /usr/src/.venv/lib/python3.11/site-packages/flashinfer/
-# Install flash-attention dependencies
-RUN pip install einops --no-cache-dir
-
-# Install server
-COPY proto proto
-COPY server server
-COPY server/Makefile server/Makefile
-RUN cd server && \
- make gen-server && \
- pip install -r requirements_cuda.txt && \
- pip install ".[attention, bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \
- pip install nvidia-nccl-cu12==2.22.3
-ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
+# ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
# Required to find libpython within the rust binaries
-ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# This is needed because exl2 tries to load flash-attn
# And fails with our builds.
ENV EXLLAMA_NO_FLASH_ATTN=1
@@ -273,5 +245,6 @@ FROM base
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
+ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/root/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/"
ENTRYPOINT ["/tgi-entrypoint.sh"]
# CMD ["--json-output"]
diff --git a/Dockerfile.neuron b/Dockerfile.neuron
new file mode 100644
index 00000000000..9ca7aad9a0c
--- /dev/null
+++ b/Dockerfile.neuron
@@ -0,0 +1,166 @@
+# Fetch and extract the TGI sources
+FROM alpine AS tgi
+RUN mkdir -p /tgi
+
+# Fetch the optimum-neuron sources directly to avoid relying on PyPI deployments
+FROM alpine AS optimum-neuron
+RUN mkdir -p /optimum-neuron
+ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.3.0.tar.gz /optimum-neuron/sources.tar.gz
+RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1
+
+# Build cargo components (adapted from the original TGI Dockerfile)
+# Note: we cannot use the cargo-chef base image as it uses Python 3.11
+FROM ubuntu:22.04 AS chef
+
+RUN apt-get update -y \
+ && apt-get install -y --no-install-recommends \
+ curl ca-certificates build-essential \
+ && rm -rf /var/lib/apt/lists/* \
+ && apt-get clean
+
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y
+ENV PATH="/root/.cargo/bin:${PATH}"
+RUN cargo install cargo-chef --locked
+
+WORKDIR /usr/src
+
+FROM chef AS planner
+COPY backends/neuron/Cargo.toml Cargo.toml
+COPY Cargo.lock Cargo.lock
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY router router
+COPY backends backends
+COPY launcher launcher
+RUN cargo chef prepare --recipe-path recipe.json
+
+FROM chef AS builder
+
+RUN apt-get update -y \
+ && apt-get install -y --no-install-recommends \
+ unzip python3-dev libssl-dev pkg-config \
+ && rm -rf /var/lib/apt/lists/* \
+ && apt-get clean
+
+RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
+ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
+ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
+ unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
+ rm -f $PROTOC_ZIP
+
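+# Cook the dependency recipe so Rust dependencies are cached independently of source changes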
+COPY backends/neuron/Cargo.toml Cargo.toml
+COPY --from=planner /usr/src/recipe.json recipe.json
+RUN cargo chef cook --release --recipe-path recipe.json
+
+COPY Cargo.lock Cargo.lock
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY router router
+COPY backends backends
+COPY launcher launcher
+RUN cargo build --release
+
+# Python base image
+FROM ubuntu:22.04 AS base
+
+RUN apt-get update -y \
+ && apt-get install -y --no-install-recommends \
+ python3-pip \
+ python3-setuptools \
+ python-is-python3 \
+ && rm -rf /var/lib/apt/lists/* \
+ && apt-get clean
+RUN pip3 --no-cache-dir install --upgrade pip
+
+# Python server build image
+FROM base AS pyserver
+
+RUN apt-get update -y \
+ && apt-get install -y --no-install-recommends \
+ make \
+ python3-venv \
+ && rm -rf /var/lib/apt/lists/* \
+ && apt-get clean
+
+RUN install -d /pyserver
+WORKDIR /pyserver
+COPY backends/neuron/server server
+COPY proto proto
+RUN pip3 install -r server/build-requirements.txt
+RUN VERBOSE=1 BUILDDIR=/pyserver/build PROTODIR=/pyserver/proto make -C server package
+
+# Neuron base image (used for deployment)
+FROM base AS neuron
+
+# Install system prerequisites
+RUN apt-get update -y \
+ && apt-get install -y --no-install-recommends \
+ gnupg2 \
+ wget \
+ python3-dev \
+ libexpat1 \
+ && rm -rf /var/lib/apt/lists/* \
+ && apt-get clean
+
+RUN echo "deb https://apt.repos.neuron.amazonaws.com jammy main" > /etc/apt/sources.list.d/neuron.list
+RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add -
+
+# Install neuronx packages
+RUN apt-get update -y \
+ && apt-get install -y --no-install-recommends \
+ aws-neuronx-dkms=2.22.2.0 \
+ aws-neuronx-collectives=2.26.43.0-47cc904ea \
+ aws-neuronx-runtime-lib=2.26.42.0-2ff3b5c7d \
+ aws-neuronx-tools=2.24.54.0 \
+ libxml2 \
+ && rm -rf /var/lib/apt/lists/* \
+ && apt-get clean
+
+ENV PATH="/opt/bin/:/opt/aws/neuron/bin:${PATH}"
+
+# Manually install the CPU build of torch to avoid pulling in CUDA
+RUN pip3 install \
+ torch==2.7.0 \
+ torchvision==0.22.0 \
+ --index-url https://download.pytorch.org/whl/cpu
+
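+# Install the pinned Neuron SDK Python packages from the AWS Neuron pip repository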
+RUN pip3 install \
+ neuronx-cc==2.19.8089.0+8ab9f450 \
+ torch-neuronx==2.7.0.2.8.6734+ac864f72 \
+ neuronx-distributed==0.13.14393+b8569585 \
+ libneuronxla==2.2.4410.0+835a67fb \
+ --extra-index-url=https://pip.repos.neuron.amazonaws.com
+
+# Install HuggingFace packages
+RUN pip3 install \
+ hf_transfer huggingface_hub
+
+# Install optimum-neuron
+COPY --from=optimum-neuron /optimum-neuron optimum-neuron
+RUN pip3 install ./optimum-neuron
+
+# TGI base env
+ENV HUGGINGFACE_HUB_CACHE=/tmp \
+ HF_HUB_ENABLE_HF_TRANSFER=1 \
+ PORT=80
+
+# Disable color logs as they are not supported by CloudWatch
+ENV LOGURU_COLORIZE=NO
+ENV LOG_COLORIZE=0
+
+# Install router
+COPY --from=builder /usr/src/target/release/text-generation-router-v2 /usr/local/bin/text-generation-router
+# Install launcher
+COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+# Install python server
+COPY --from=pyserver /pyserver/build/dist dist
+RUN pip install dist/text_generation_server*.tar.gz
+
+# Final image
+FROM neuron
+
+COPY backends/neuron/tgi_entry_point.py /tgi_entry_point.py
+COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh
+RUN chmod +x /tgi-entrypoint.sh
+
+ENTRYPOINT ["/tgi-entrypoint.sh"]
diff --git a/Dockerfile.nix b/Dockerfile.nix
index f1e7e0f553e..90390de6917 100644
--- a/Dockerfile.nix
+++ b/Dockerfile.nix
@@ -6,7 +6,7 @@
FROM nixos/nix:2.18.8 AS builder
RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf
RUN nix profile install nixpkgs#cachix
-RUN cachix use text-generation-inference
+RUN cachix use huggingface
WORKDIR /root
ADD . .
RUN nix build .
diff --git a/Dockerfile_amd b/Dockerfile_amd
index 7638947a5c7..e3e9efda8a2 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -1,5 +1,5 @@
# Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -41,262 +41,237 @@ COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt --frozen
-# Text Generation Inference base image for RoCm
-FROM rocm/dev-ubuntu-22.04:6.2 AS base
+FROM rocm/dev-ubuntu-22.04:6.3.1-complete AS base
+ARG HIPBLASLT_BRANCH="4d40e36"
+ARG HIPBLAS_COMMON_BRANCH="7c1566b"
+ARG LEGACY_HIPBLASLT_OPTION=
+ARG RCCL_BRANCH="648a58d"
+ARG RCCL_REPO="https://github.com/ROCm/rccl"
+ARG TRITON_BRANCH="e5be006"
+ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
+ARG PYTORCH_BRANCH="3a585126"
+ARG PYTORCH_VISION_BRANCH="v0.19.1"
+ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
+ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
+ARG FA_BRANCH="b7d29fb"
+ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
+ARG AITER_BRANCH="21d47a9"
+ARG AITER_REPO="https://github.com/ROCm/aiter.git"
+
+ENV PATH=/opt/rocm/llvm/bin:$PATH
+ENV ROCM_PATH=/opt/rocm
+ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
+ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942
+ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
+
+ARG PYTHON_VERSION=3.11
+
+RUN mkdir -p /app
+WORKDIR /app
+ENV DEBIAN_FRONTEND=noninteractive
+
+# Install Python and other dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
- build-essential \
- ca-certificates \
- ccache \
- curl \
- git \
- make \
- libmsgpack-dev \
- libssl-dev \
- llvm-dev \
- g++ \
- # Needed to build VLLM & flash.
- rocthrust-dev \
- hipsparse-dev \
- hipblas-dev \
- hipcub-dev \
- rocblas-dev \
- hiprand-dev \
- hipfft-dev \
- rocrand-dev \
- miopen-hip-dev \
- hipsolver-dev \
- rccl-dev \
- cmake \
- python3.11-venv && \
- rm -rf /var/lib/apt/lists/*
-
-# Keep in sync with `server/pyproject.toml
-ARG MAMBA_VERSION=23.1.0-1
-ARG PYTHON_VERSION='3.11.10'
-# Automatically set by buildx
-ARG TARGETPLATFORM
-ENV PATH=/opt/conda/bin:$PATH
-
-ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
-
-# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
-# Install mamba
-# translating Docker's TARGETPLATFORM into mamba arches
-RUN case ${TARGETPLATFORM} in \
- "linux/arm64") MAMBA_ARCH=aarch64 ;; \
- *) MAMBA_ARCH=x86_64 ;; \
- esac && \
- curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/$%7BMAMBA_VERSION%7D/Mambaforge-$%7BMAMBA_VERSION%7D-Linux-$%7BMAMBA_ARCH%7D.sh"
-RUN chmod +x ~/mambaforge.sh && \
- bash ~/mambaforge.sh -b -p /opt/conda && \
- mamba init && \
- rm ~/mambaforge.sh
-
-# RUN conda install intel::mkl-static intel::mkl-include
-# Install pytorch
-# On arm64 we exit with an error code
-RUN case ${TARGETPLATFORM} in \
- "linux/arm64") exit 1 ;; \
- *) /opt/conda/bin/conda update -y conda && \
- /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
- esac && \
- /opt/conda/bin/conda clean -ya
-
-# Install flash-attention, torch dependencies
-RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*
-
-RUN conda install mkl=2021
-ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/
-
-
-ARG COMMON_WORKDIR=/
-WORKDIR ${COMMON_WORKDIR}
-
-
-# Install HIPBLASLt
+ build-essential \
+ ca-certificates \
+ ccache \
+ curl \
+ git \
+ ninja-build \
+ cmake \
+ software-properties-common \
+ python3.11-dev \
+ python3.11-venv && \
+ rm -rf /var/lib/apt/lists/*
+
+COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
+ENV PATH="$PATH:/root/.local/bin"
+RUN uv python install ${PYTHON_VERSION}
+RUN uv venv --python ${PYTHON_VERSION} && uv pip install pip setuptools packaging
+ENV VIRTUAL_ENV=/usr/src/.venv/
+ENV PATH="$PATH:/usr/src/.venv/bin/"
+
+RUN . .venv/bin/activate && pip install -U packaging cmake ninja wheel setuptools pybind11 Cython
+
FROM base AS build_hipblaslt
-ARG HIPBLASLT_BRANCH="e6da924"
-RUN git clone https://github.com/ROCm/hipBLASLt.git \
- && cd hipBLASLt \
+ARG HIPBLASLT_BRANCH
+ARG HIPBLAS_COMMON_BRANCH
+# Set to "--legacy_hipblas_direct" for ROCm<=6.2
+ARG LEGACY_HIPBLASLT_OPTION
+RUN git clone https://github.com/ROCm/hipBLAS-common.git
+RUN . .venv/bin/activate && cd hipBLAS-common \
+ && git checkout ${HIPBLAS_COMMON_BRANCH} \
+ && mkdir build \
+ && cd build \
+ && cmake .. \
+ && make package \
+ && dpkg -i ./*.deb
+RUN git clone https://github.com/ROCm/hipBLASLt
+RUN . .venv/bin/activate && cd hipBLASLt \
&& git checkout ${HIPBLASLT_BRANCH} \
- && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
+ && ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
&& cd build/release \
&& make package
+RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install
-FROM scratch AS export_hipblaslt
-ARG COMMON_WORKDIR
-COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
-
-# RCCL build stages
FROM base AS build_rccl
-ARG RCCL_BRANCH="rocm-6.2.0"
-RUN git clone https://github.com/ROCm/rccl \
- && cd rccl \
+ARG RCCL_BRANCH
+ARG RCCL_REPO
+RUN git clone ${RCCL_REPO}
+RUN . .venv/bin/activate && cd rccl \
&& git checkout ${RCCL_BRANCH} \
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
-FROM scratch AS export_rccl
-ARG COMMON_WORKDIR
-COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
+RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install
-# Triton build stages
FROM base AS build_triton
-ARG TRITON_BRANCH="e192dba"
-ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
-RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
- && cd triton \
+ARG TRITON_BRANCH
+ARG TRITON_REPO
+RUN git clone ${TRITON_REPO}
+RUN . .venv/bin/activate && cd triton \
&& git checkout ${TRITON_BRANCH} \
&& cd python \
&& python3 setup.py bdist_wheel --dist-dir=dist
-FROM scratch AS export_triton
-ARG COMMON_WORKDIR
-COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
+RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install
-# # AMD-SMI build stages
FROM base AS build_amdsmi
-RUN cd /opt/rocm/share/amd_smi \
+RUN . .venv/bin/activate && cd /opt/rocm/share/amd_smi \
&& pip wheel . --wheel-dir=dist
-FROM scratch AS export_amdsmi
-COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /
-
-
-FROM base as build_pytorch
-
-RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
- if ls /install/*.deb; then \
- dpkg -i /install/*.deb \
- && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
- && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
- fi
-
-ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
-ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
-
-# A commit to fix the output scaling factor issue in _scaled_mm
-# Not yet in 2.5.0-rc1
-ARG PYTORCH_BRANCH="cedc116"
-ARG PYTORCH_VISION_BRANCH="v0.19.1"
-ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
-
-RUN git clone ${PYTORCH_REPO} pytorch \
- && cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
- && pip install -r requirements.txt --no-cache-dir \
- && python tools/amd_build/build_amd.py \
- && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist
-FROM scratch as export_pytorch
-ARG COMMON_WORKDIR
-COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
-
-FROM base AS install_deps
-
-ARG COMMON_WORKDIR
-
-# Install hipblaslt
-RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
- if ls /install/*.deb; then \
- dpkg -i /install/*.deb \
- && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
- && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
- fi
-
-RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
- if ls /install/*.deb; then \
- dpkg -i /install/*.deb \
- # RCCL needs to be installed twice
- && dpkg -i /install/*.deb \
- && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
- && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
- fi
-
-RUN --mount=type=bind,from=export_triton,src=/,target=/install \
- if ls /install/*.whl; then \
- # Preemptively uninstall to prevent pip same-version no-installs
- pip uninstall -y triton \
- && pip install /install/*.whl; \
- fi
-
-RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
- # Preemptively uninstall to prevent pip same-version no-installs
- pip uninstall -y amdsmi \
- && pip install /install/*.whl;
-
-RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
- if ls /install/*.whl; then \
- # Preemptively uninstall to prevent pip same-version no-installs
- pip uninstall -y torch torchvision \
- && pip install /install/*.whl; \
- fi
-
-FROM install_deps AS kernel-builder
-
+RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install
+
+FROM base AS build_pytorch
+ARG PYTORCH_BRANCH
+ARG PYTORCH_VISION_BRANCH
+ARG PYTORCH_REPO
+ARG PYTORCH_VISION_REPO
+ARG FA_BRANCH
+ARG FA_REPO
+RUN git clone ${PYTORCH_REPO} pytorch
+RUN . .venv/bin/activate && cd pytorch && git checkout ${PYTORCH_BRANCH} && \
+ pip install -r requirements.txt && git submodule update --init --recursive \
+ && python3 tools/amd_build/build_amd.py \
+ && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
+ && pip install dist/*.whl
+RUN git clone ${PYTORCH_VISION_REPO} vision
+RUN . .venv/bin/activate && cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
+ && python3 setup.py bdist_wheel --dist-dir=dist \
+ && pip install dist/*.whl
+RUN git clone ${FA_REPO}
+RUN . .venv/bin/activate && cd flash-attention \
+ && git checkout ${FA_BRANCH} \
+ && git submodule update --init \
+ && MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
+RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
+ && cp /app/vision/dist/*.whl /app/install \
+ && cp /app/flash-attention/dist/*.whl /app/install
+
+FROM base AS final
+RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
+ dpkg -i /install/*deb \
+ && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
+ && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
+RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
+ dpkg -i /install/*deb \
+ && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
+ && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
+RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
+ . .venv/bin/activate && \
+ pip install /install/*.whl
+RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
+ . .venv/bin/activate && \
+ pip install /install/*.whl
+RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
+ . .venv/bin/activate && \
+ pip install /install/*.whl
+
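+# Build and install the ROCm AITER kernels (prebuilt for gfx942)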
+ARG AITER_REPO
+ARG AITER_BRANCH
+RUN git clone --recursive ${AITER_REPO}
+RUN . .venv/bin/activate && cd aiter \
+ && git checkout ${AITER_BRANCH} \
+ && git submodule update --init --recursive \
+ && pip install -r requirements.txt \
+ && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter
+
+RUN rm -rf /var/lib/apt/lists/*
+
+FROM final AS kernel-builder
# # Build vllm kernels
FROM kernel-builder AS vllm-builder
-WORKDIR /usr/src
COPY server/Makefile-vllm Makefile
+RUN . .venv/bin/activate && pip install setuptools_scm
# Build specific version of vllm
-RUN make build-vllm-rocm
-
-# Build Flash Attention v2 kernels
-FROM kernel-builder AS flash-att-v2-builder
-WORKDIR /usr/src
-
-COPY server/Makefile-flash-att-v2 Makefile
-
-# Build specific version of flash attention v2
-RUN make build-flash-attention-v2-rocm
+RUN . .venv/bin/activate && make build-vllm-rocm
# Build Transformers CUDA kernels (gpt-neox and bloom)
FROM kernel-builder AS custom-kernels-builder
-WORKDIR /usr/src
COPY server/custom_kernels/ .
-RUN python setup.py build
+RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
# Build exllama kernels
FROM kernel-builder AS exllama-kernels-builder
-WORKDIR /usr/src
COPY server/exllama_kernels/ .
-
-RUN python setup.py build
+RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
# Build exllama v2 kernels
FROM kernel-builder AS exllamav2-kernels-builder
-WORKDIR /usr/src
COPY server/exllamav2_kernels/ .
-
-RUN python setup.py build
-
-FROM install_deps AS base-copy
+RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
+
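+# Build the marlin-kernels wheel for ROCm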
+FROM kernel-builder AS marlin-kernels
+ENV MARLIN_KERNELS_BRANCH=v0.3.6
+ENV VLLM_TARGET_DEVICE=rocm
+RUN . .venv/bin/activate && git clone https://github.com/danieldk/marlin-kernels.git && \
+ cd marlin-kernels && \
+ git checkout ${MARLIN_KERNELS_BRANCH} && \
+ python3 setup.py bdist_wheel --dist-dir=dist
+
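+# Build the moe-kernels wheel for ROCm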
+FROM kernel-builder AS moe-kernels
+ENV MOE_KERNELS_BRANCH=v0.8.2
+ENV VLLM_TARGET_DEVICE=rocm
+RUN . .venv/bin/activate && git clone https://github.com/danieldk/moe-kernels.git && \
+ cd moe-kernels && \
+ git checkout ${MOE_KERNELS_BRANCH} && \
+ python3 setup.py bdist_wheel --dist-dir=dist
+
+FROM final AS base-copy
# Text Generation Inference base env
ENV HF_HOME=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80
-# Copy builds artifacts from vllm builder
-COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-
-# Copy build artifacts from flash attention v2 builder
-COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-
-# Copy build artifacts from custom kernels builder
-COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-
-# Copy build artifacts from exllama kernels builder
-COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-
-# Copy build artifacts from exllamav2 kernels builder
-COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+ENV VIRTUAL_ENV=/app/.venv/
+ENV PATH="$PATH:/app/.venv/bin/"
# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
- make gen-server && \
- pip install -r requirements_rocm.txt && \
- pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
+ uv pip install grpcio-tools mypy-protobuf && \
+ uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir && \
+ make gen-server-raw
+RUN cd server && \
+ pwd && \
+ text-generation-server --help
+
+RUN --mount=type=bind,from=vllm-builder,src=/app/vllm/dist,target=/install \
+ uv pip install /install/*.whl
+RUN --mount=type=bind,from=custom-kernels-builder,src=/app/dist,target=/install \
+ uv pip install /install/*.whl
+RUN --mount=type=bind,from=exllama-kernels-builder,src=/app/dist,target=/install \
+ uv pip install /install/*.whl
+RUN --mount=type=bind,from=exllamav2-kernels-builder,src=/app/dist,target=/install \
+ uv pip install /install/*.whl
+RUN --mount=type=bind,from=marlin-kernels,src=/app/marlin-kernels/dist,target=/install \
+ uv pip install /install/*.whl
+RUN --mount=type=bind,from=moe-kernels,src=/app/moe-kernels/dist,target=/install \
+ uv pip install /install/*.whl
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -304,7 +279,6 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/l
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
-ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# AWS Sagemaker compatible image
FROM base AS sagemaker
@@ -335,4 +309,6 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
ENTRYPOINT ["/tgi-entrypoint.sh"]
-CMD ["--json-output"]
+ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/root/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib"
+ENV PYTHONPATH=/app/.venv/lib/python3.11/site-packages
+# CMD ["--json-output"]
diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi
new file mode 100644
index 00000000000..bd09c91e0e4
--- /dev/null
+++ b/Dockerfile_gaudi
@@ -0,0 +1,129 @@
+# These arguments are required to build the image
+ARG HABANA_VERSION=1.21.0
+ARG PYTORCH_VERSION=2.6.0
+
+# Rust builder
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
+WORKDIR /usr/src
+
+ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
+
+FROM chef AS planner
+COPY Cargo.lock Cargo.lock
+COPY Cargo.toml Cargo.toml
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY benchmark benchmark
+COPY router router
+COPY backends backends
+COPY launcher launcher
+RUN cargo chef prepare --recipe-path recipe.json
+
+FROM chef AS builder
+
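+# Point PyO3 at the uv-managed Python 3.10 interpreter installed in the next step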
+ENV PYO3_PYTHON="/root/.local/bin/python" \
+ PYTHON_SYS_EXECUTABLE="/root/.local/bin/python" \
+ PYO3_PYTHON_VERSION="3.10"
+
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh \
+ && . $HOME/.local/bin/env \
+ && uv python install 3.10 --default --preview \
+ && test -f /root/.local/bin/python || (echo "Python 3.10 not found at /root/.local/bin/python" && exit 1)
+
+RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
+ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
+ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
+ unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
+ rm -f $PROTOC_ZIP
+
+COPY --from=planner /usr/src/recipe.json recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json
+
+ARG GIT_SHA
+ARG DOCKER_LABEL
+
+COPY Cargo.toml Cargo.toml
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY benchmark benchmark
+COPY router router
+COPY backends backends
+COPY launcher launcher
+RUN cargo build --profile release-opt
+
+# Text Generation Inference base image
+ARG HABANA_VERSION
+ARG PYTORCH_VERSION
+
+FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base
+
+ENV ATTENTION=paged
+ENV PREFIX_CACHING=0
+ENV PREFILL_CHUNKING=0
+ENV PT_HPU_LAZY_MODE=1
+ENV PT_HPU_WEIGHT_SHARING=0
+ENV VLLM_EXPONENTIAL_BUCKETING=true
+
+# Text Generation Inference base env
+ENV HF_HOME=/data \
+ HF_HUB_ENABLE_HF_TRANSFER=1 \
+ PORT=80
+
+# Assert that Python 3.10 is installed as the launcher is compiled with Python 3.10
+RUN python3.10 --version || (echo "Python 3.10 is not installed" && exit 1)
+
+# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, so install it manually
+RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
+ dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
+
+WORKDIR /usr/src
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+ libssl-dev \
+ ca-certificates \
+ make \
+ curl \
+ git \
+ && rm -rf /var/lib/apt/lists/*
+
+# Install server
+COPY proto proto
+COPY backends/gaudi/server server
+COPY backends/gaudi/server/Makefile server/Makefile
+ARG HABANA_VERSION
+RUN cd server && \
+ make gen-server && \
+ pip install --no-deps -r requirements.txt && \
+ bash ./dill-0.3.8-patch.sh && \
+ pip install . --no-cache-dir
+RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix
+RUN pip install compressed-tensors==0.9.1
+
+# Install benchmarker
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+# Install router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
+# Install launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
+
+
+# AWS Sagemaker compatible image
+FROM base AS sagemaker
+
+COPY sagemaker-entrypoint.sh entrypoint.sh
+RUN chmod +x entrypoint.sh
+
+ENTRYPOINT ["./entrypoint.sh"]
+
+# Final image
+FROM base
+
+ENV HF_HUB_ENABLE_HF_TRANSFER=1
+ENV HABANA_VISIBLE_DEVICES=all
+ENV OMPI_MCA_btl_vader_single_copy_mechanism=NONE
+
+COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
+RUN chmod +x /tgi-entrypoint.sh
+
+ENTRYPOINT ["/tgi-entrypoint.sh"]
+CMD ["--json-output"]
diff --git a/Dockerfile_intel b/Dockerfile_intel
index e024f31a563..9eb746256a3 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -1,6 +1,6 @@
ARG PLATFORM=xpu
-FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -45,7 +45,7 @@ RUN cargo build --profile release-opt --frozen
# Text Generation Inference base image for Intel
-FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04 AS xpu
+FROM intel/oneapi-basekit:2025.1.3-0-devel-ubuntu22.04 AS xpu
USER root
@@ -87,7 +87,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https:/
RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-pti-dev-0.9
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200
# Text Generation Inference base env
ENV HF_HOME=/data \
@@ -96,29 +96,28 @@ ENV HF_HOME=/data \
+
WORKDIR /usr/src
-RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir
-RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir
-RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir
-RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp311-cp311-linux_x86_64.whl --no-cache-dir
-RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir
-RUN pip install triton-xpu==3.0.0b2 --no-cache-dir
+RUN pip install torch==2.8.0 torchvision==0.23.0 --index-url https://download.pytorch.org/whl/xpu
# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
+ENV UV_SYSTEM_PYTHON=1
RUN cd server && \
make gen-server && \
- pip install -r requirements_intel.txt && \
- pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
+ pip install -U pip uv && \
+ uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
-ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/lib
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
ENV CCL_ZE_IPC_EXCHANGE=sockets
-#ENV TORCH_LLM_ALLREDUCE=1
-#ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
+ENV TORCH_LLM_ALLREDUCE=1
+ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
+ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.8.10%2Bxpu-cp311-cp311-linux_x86_64.whl
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
@@ -158,7 +157,7 @@ ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
-ENV PATH /opt/conda/bin:$PATH
+ENV PATH=/opt/conda/bin:$PATH
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
@@ -181,22 +180,14 @@ RUN case ${TARGETPLATFORM} in \
RUN conda install -c conda-forge gperftools mkl
-
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
-
-RUN pip install triton py-libnuma
+RUN pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cpu
+RUN pip install triton==3.2.0 py-libnuma
WORKDIR /usr/src
-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout b7b552baf64283b594665b8687430fe92990e497
-RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
-
-RUN sed -i 's/VERSION_MINOR 6/VERSION_MINOR 5/' intel-extension-for-pytorch/version.txt
-RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/intel_extension_for_pytorch-2.7.0%2Bcpu-cp311-cp311-linux_x86_64.whl
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/oneccl_bind_pt-2.7.0%2Bcpu-cp311-cp311-linux_x86_64.whl
-RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
ENV CCL_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
@@ -209,10 +200,11 @@ ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
+ENV UV_SYSTEM_PYTHON=1
RUN cd server && \
make gen-server && \
- pip install -r requirements_intel.txt && \
- pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
+ pip install -U pip uv && \
+ uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -222,9 +214,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
FROM ${PLATFORM} AS final
-ENV ATTENTION=paged
-ENV PREFIX_CACHING=0
-ENV PREFILL_CHUNKING=0
+ENV ATTENTION=flashdecoding-ipex
+ENV PREFIX_CACHING=1
+ENV PREFILL_CHUNKING=1
ENV CUDA_GRAPHS=0
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
diff --git a/Dockerfile_llamacpp b/Dockerfile_llamacpp
new file mode 100644
index 00000000000..291ae88cb60
--- /dev/null
+++ b/Dockerfile_llamacpp
@@ -0,0 +1,88 @@
+FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04 AS deps
+
+ARG llamacpp_version=b4827
+ARG llamacpp_cuda=OFF
+ARG llamacpp_native=ON
+ARG llamacpp_cpu_arm_arch=native
+ARG cuda_arch=75-real;80-real;86-real;89-real;90-real
+
+WORKDIR /opt/src
+
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt update && apt upgrade -y && apt install -y \
+ clang \
+ cmake \
+ curl \
+ git \
+ python3-dev \
+ libssl-dev \
+ pkg-config \
+ tar
+
+ADD https://github.com/ggml-org/llama.cpp/archive/refs/tags/${llamacpp_version}.tar.gz /opt/src/
+RUN mkdir -p llama.cpp \
+ && tar -xzf ${llamacpp_version}.tar.gz -C llama.cpp --strip-components=1 \
+ && cd llama.cpp \
+ && cmake -B build \
+ -DCMAKE_INSTALL_PREFIX=/usr \
+ -DCMAKE_INSTALL_LIBDIR=/usr/lib \
+ -DCMAKE_C_COMPILER=clang \
+ -DCMAKE_CXX_COMPILER=clang++ \
+ -DCMAKE_CUDA_ARCHITECTURES=${cuda_arch} \
+ -DGGML_CUDA=${llamacpp_cuda} \
+ -DGGML_NATIVE=${llamacpp_native} \
+ -DGGML_CPU_ARM_ARCH=${llamacpp_cpu_arm_arch} \
+ -DLLAMA_BUILD_COMMON=OFF \
+ -DLLAMA_BUILD_TESTS=OFF \
+ -DLLAMA_BUILD_EXAMPLES=OFF \
+ -DLLAMA_BUILD_SERVER=OFF \
+ && cmake --build build --parallel --config Release \
+ && cmake --install build
+
+WORKDIR /app
+COPY rust-toolchain.toml rust-toolchain.toml
+RUN curl -sSf https://sh.rustup.rs | sh -s -- --no-modify-path --default-toolchain 1.85.1 --profile minimal -y
+ENV PATH="/root/.cargo/bin:$PATH"
+RUN cargo install cargo-chef --locked
+
+FROM deps AS planner
+COPY . .
+RUN cargo chef prepare --recipe-path recipe.json
+
+FROM deps AS builder
+COPY --from=planner /app/recipe.json recipe.json
+RUN cargo chef cook \
+ --recipe-path recipe.json \
+ --profile release \
+ --package text-generation-router-llamacpp
+COPY . .
+RUN cargo build \
+ --profile release \
+ --package text-generation-router-llamacpp --frozen
+
+FROM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04
+WORKDIR /app
+
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt update && apt upgrade -y && apt install -y \
+ python3-venv \
+ python3-pip
+
+RUN python3 -m venv /venv
+ENV PATH="/venv/bin:$PATH"
+
+COPY backends/llamacpp/requirements.txt requirements.txt
+COPY --from=builder /opt/src/llama.cpp/gguf-py gguf-py
+COPY --from=builder /opt/src/llama.cpp/convert_hf_to_gguf.py /bin/
+
+RUN pip3 install --no-cache-dir \
+ -r requirements.txt \
+ -e gguf-py
+
+COPY --from=builder /usr/lib/libllama.so /usr/lib/
+COPY --from=builder /usr/lib/libggml*.so /usr/lib/
+COPY --from=builder /app/target/release/text-generation-router-llamacpp /usr/bin/
+
+ENV HF_HUB_ENABLE_HF_TRANSFER=1
+
+ENTRYPOINT ["text-generation-router-llamacpp"]
diff --git a/Dockerfile_trtllm b/Dockerfile_trtllm
index 3ccb0310bea..c0cf90335cd 100644
--- a/Dockerfile_trtllm
+++ b/Dockerfile_trtllm
@@ -1,52 +1,55 @@
-ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real"
-ARG OMPI_VERSION="4.1.6"
-
-# Build dependencies resolver stage
-FROM lukemathwalker/cargo-chef:latest AS chef
-WORKDIR /usr/src/text-generation-inference/backends/trtllm
-
-FROM chef AS planner
-COPY . .
-RUN cargo chef prepare --recipe-path recipe.json
+ARG cuda_arch_list="75-real;80-real;86-real;89-real;90-real;100-real;120-real"
+ARG cuda_base=12.8.0
+ARG build_type=release
+ARG ompi_version=4.1.7
+ARG sccache_gha_enabled=off
+ARG actions_results_url=""
+ARG actions_runtime_token=""
# CUDA dependent dependencies resolver stage
-FROM nvidia/cuda:12.6.1-cudnn-devel-ubuntu22.04 AS cuda-builder
+FROM nvidia/cuda:${cuda_base}-cudnn-devel-ubuntu24.04 AS cuda-builder
-RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
- --mount=type=cache,target=/var/lib/apt,sharing=locked \
- apt update && apt install -y \
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
build-essential \
cmake \
curl \
- gcc \
- g++ \
+ gcc-14 \
+ g++-14 \
git \
git-lfs \
+ lld \
libssl-dev \
+ libucx-dev \
+ libasan8 \
+ libubsan1 \
ninja-build \
pkg-config \
+ pipx \
python3 \
python3-dev \
python3-setuptools \
tar \
- wget
+ wget --no-install-recommends && \
+ pipx ensurepath
ENV TGI_INSTALL_PREFIX=/usr/local/tgi
ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
# Install OpenMPI
FROM cuda-builder AS mpi-builder
-ARG OMPI_VERSION
+WORKDIR /opt/src/mpi
+
+ARG ompi_version
+ENV OMPI_VERSION=${ompi_version}
+ENV OMPI_TARBALL_FILENAME=openmpi-${OMPI_VERSION}.tar.bz2
+ADD --checksum=sha256:54a33cb7ad81ff0976f15a6cc8003c3922f0f3d8ceed14e1813ef3603f22cd34 \
+ https://download.open-mpi.org/release/open-mpi/v4.1/${OMPI_TARBALL_FILENAME} .
-ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2"
-RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \
- mkdir /usr/src/mpi && \
- tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
- cd /usr/src/mpi && \
+RUN tar --strip-components=1 -xf ${OMPI_TARBALL_FILENAME} &&\
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
make -j all && \
make install && \
- rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
+ rm -rf ${OMPI_TARBALL_FILENAME}/..
# Install TensorRT
FROM cuda-builder AS trt-builder
@@ -58,38 +61,62 @@ RUN chmod +x /opt/install_tensorrt.sh && \
FROM cuda-builder AS tgi-builder
WORKDIR /usr/src/text-generation-inference
-# Install Rust
-RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
- chmod -R a+w /root/.rustup && \
- chmod -R a+w /root/.cargo
+# Scoped global args reuse
+ARG cuda_arch_list
+ARG build_type
+ARG sccache_gha_enabled
+ARG actions_results_url
+ARG actions_runtime_token
+# Install Rust
ENV PATH="/root/.cargo/bin:$PATH"
-RUN cargo install cargo-chef
-
-# Cache dependencies
-COPY --from=planner /usr/src/text-generation-inference/backends/trtllm/recipe.json .
-RUN cargo chef cook --release --recipe-path recipe.json
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y && \
+ chmod -R a+w /root/.rustup && \
+ chmod -R a+w /root/.cargo && \
+ cargo install sccache --version ">=0.10.0" --locked
-# Build actual TGI
-ARG CUDA_ARCH_LIST
-ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH"
ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
-ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH"
-
-COPY . .
+ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"
+ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt"
+
+ENV USE_LLD_LINKER=ON
+ENV CUDA_ARCH_LIST=${cuda_arch_list}
+
+# sccache-specific args - until a better, more generic way is found...
+ENV SCCACHE_GHA_ENABLED=${sccache_gha_enabled}
+ENV ACTIONS_RESULTS_URL=${actions_results_url}
+ENV ACTIONS_RUNTIME_TOKEN=${actions_runtime_token}
+
+COPY Cargo.lock Cargo.lock
+COPY Cargo.toml Cargo.toml
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY router router
+COPY backends backends
+COPY benchmark benchmark
+COPY launcher launcher
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
-RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
- cd backends/trtllm && \
- CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release
-FROM nvidia/cuda:12.6.1-cudnn-runtime-ubuntu22.04 AS runtime
-RUN apt update && apt install -y python3-minimal python3-dev python3-pip && \
+ENV RUSTC_WRAPPER=sccache
+ENV CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX
+RUN export CC=gcc-14 \
+ export CXX=g++-14 \
+ export CMAKE_C_COMPILER_LAUNCHER=sccache && \
+ export CMAKE_CXX_COMPILER_LAUNCHER=sccache && \
+ export CMAKE_CUDA_COMPILER_LAUNCHER=sccache && \
+ mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
+ cargo build --profile ${build_type} --package text-generation-backends-trtllm --bin text-generation-backends-trtllm && \
+ sccache --show-stats
+
+FROM nvidia/cuda:${cuda_base}-cudnn-runtime-ubuntu24.04 AS runtime
+RUN apt update && apt install -y libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
- python3 -m pip install transformers tokenizers
+ pipx ensurepath && \
+ pipx install --include-deps transformers tokenizers
WORKDIR /usr/local/tgi/bin
+ENV PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
ENV TOKENIZERS_PARALLELISM=false
ENV OMPI_MCA_plm_rsh_agent=""
@@ -99,10 +126,33 @@ COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
+# This stage is only used for CI/CD
+FROM nvidia/cuda:${cuda_base}-cudnn-runtime-ubuntu24.04 AS ci-runtime
+RUN apt update && apt install -y libasan8 libubsan1 libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
+ rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
+ pipx ensurepath && \
+ pipx install --include-deps transformers tokenizers
+
+WORKDIR /usr/local/tgi/bin
+
+ENV PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH
+ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
+ENV TOKENIZERS_PARALLELISM=false
+ENV OMPI_MCA_plm_rsh_agent=""
+
+COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
+COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
+COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
+
+# Basically we copy from target/debug instead of target/release
+COPY --from=tgi-builder /usr/src/text-generation-inference/target/debug/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
+
+# This is the final image
FROM runtime
LABEL co.huggingface.vendor="Hugging Face Inc."
LABEL org.opencontainers.image.authors="hardware@hf.co"
+LABEL org.opencontainers.title="Text-Generation-Inference TensorRT-LLM Backend"
ENTRYPOINT ["./text-generation-launcher"]
CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"]
diff --git a/Makefile b/Makefile
index 3068a06f41b..2ecdd45ca93 100644
--- a/Makefile
+++ b/Makefile
@@ -53,3 +53,6 @@ run-falcon-7b-instruct-quantize:
clean:
rm -rf target aml
+
+preview_doc:
+ doc-builder preview text-generation-inference docs/source --not_python_module
diff --git a/README.md b/README.md
index 631a97a2ddc..0890d9c6cde 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
-
+
# Text Generation Inference
@@ -14,7 +14,7 @@
A Rust, Python and gRPC server for text generation inference. Used in production at [Hugging Face](https://huggingface.co)
-to power Hugging Chat, the Inference API and Inference Endpoint.
+to power Hugging Chat, the Inference API and Inference Endpoints.
@@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model
+ ghcr.io/huggingface/text-generation-inference:3.3.5 --model-id $model
```
And then you can make requests like
@@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`. Please note that CPU is not the intended platform for this project, so performance might be subpar.
-**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0-rocm --model-id $model` instead of the command above.
+**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.5-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```
@@ -141,8 +141,8 @@ You have the option to utilize the `HF_TOKEN` environment variable for configuri
For example, if you want to serve the gated Llama V2 model variants:
1. Go to https://huggingface.co/settings/tokens
-2. Copy your cli READ token
-3. Export `HF_TOKEN=<your cli READ token>`
+2. Copy your CLI READ token
+3. Export `HF_TOKEN=<your CLI READ token>`
or with Docker:
@@ -151,13 +151,14 @@ model=meta-llama/Meta-Llama-3.1-8B-Instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
token=<your cli READ token>
-docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model
+docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
+ ghcr.io/huggingface/text-generation-inference:3.3.5 --model-id $model
```
### A note on Shared Memory (shm)
[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by
-`PyTorch` to do distributed training/inference. `text-generation-inference` make
+`PyTorch` to do distributed training/inference. `text-generation-inference` makes
use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models.
In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if
@@ -196,14 +197,26 @@ Detailed blogpost by Adyen on TGI inner workings: [LLM inference at scale with T
You can also opt to install `text-generation-inference` locally.
-First [install Rust](https://rustup.rs/) and create a Python virtual environment with at least
-Python 3.9, e.g. using `conda`:
+First clone the repository and change directory into it:
+
+```shell
+git clone https://github.com/huggingface/text-generation-inference
+cd text-generation-inference
+```
+
+Then [install Rust](https://rustup.rs/) and create a Python virtual environment with at least
+Python 3.9, e.g. using `conda` or `python venv`:
```shell
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
+# using conda
conda create -n text-generation-inference python=3.11
conda activate text-generation-inference
+
+# using python venv
+python3 -m venv .venv
+source .venv/bin/activate
```
You may also need to install Protoc.
@@ -243,14 +256,15 @@ Another option is to install `text-generation-inference` locally using [Nix](htt
we only support Nix on x86_64 Linux with CUDA GPUs. When using Nix, all dependencies can
be pulled from a binary cache, removing the need to build them locally.
-First follow the instructions to [install Cachix and enable the TGI cache](https://app.cachix.org/cache/text-generation-inference).
+First follow the instructions to [install Cachix and enable the Hugging Face cache](https://app.cachix.org/cache/huggingface).
Setting up the cache is important, otherwise Nix will build many of the dependencies
locally, which can take hours.
After that you can run TGI with `nix run`:
```shell
-nix run . -- --model-id meta-llama/Llama-3.1-8B-Instruct
+cd text-generation-inference
+nix run --extra-experimental-features nix-command --extra-experimental-features flakes . -- --model-id meta-llama/Llama-3.1-8B-Instruct
```
**Note:** when you are using Nix on a non-NixOS system, you have to [make some symlinks](https://danieldk.eu/Nix-CUDA-on-non-NixOS-systems#make-runopengl-driverlib-and-symlink-the-driver-library)
diff --git a/backends/gaudi/Makefile b/backends/gaudi/Makefile
new file mode 100644
index 00000000000..40d17f61253
--- /dev/null
+++ b/backends/gaudi/Makefile
@@ -0,0 +1,67 @@
+mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
+mkfile_dir := $(dir $(mkfile_path))
+root_dir := ${mkfile_dir}/../..
+
+HABANA_VERSION := 1.21.0
+PYTORCH_VERSION := 2.6.0
+
+.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
+
+image:
+ docker build --ulimit nofile=4096 -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
+
+run-local-dev-container:
+ docker run -it \
+ --runtime=habana \
+ --ipc=host \
+ --cap-add=sys_nice \
+ --net=host \
+ -e HABANA_VISIBLE_DEVICES=all \
+ -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
+ -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
+ -e HF_TOKEN=`cat /home/ubuntu/.cache/huggingface/token` \
+ -e LOG_LEVEL=debug \
+ -e PORT=8080 \
+ -v /home/ubuntu/.cache/huggingface:/data \
+ -v $(PWD):/text-generation-inference \
+ -w /text-generation-inference \
+ vault.habana.ai/gaudi-docker/$(HABANA_VERSION)/ubuntu22.04/habanalabs/pytorch-installer-$(PYTORCH_VERSION):latest
+
+install-dependencies:
+ pip install git+https://github.com/HabanaAI/DeepSpeed.git@$(HABANA_VERSION)
+ pip install outlines~=0.0.34
+ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
+
+install-server:
+ make -C ${root_dir}/backends/gaudi/server install PROTO_PATH=../../../proto/v3
+
+install-router:
+ make -C ${root_dir} install-router
+
+install-launcher:
+ make -C ${root_dir} install-launcher
+
+# use source to load the Rust toolchain into PATH
+local-dev-install: install-dependencies
+ bash -c 'source "$$HOME/.cargo/env" && \
+ make install-server && \
+ make install-router && \
+ make install-launcher'
+
+# In order to run the integration tests, you need to first build the image (make -C backends/gaudi image)
+run-integration-tests:
+ DOCKER_VOLUME=${root_dir}/data \
+ HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
+ pytest --durations=0 -s -vv ${root_dir}/integration-tests --gaudi
+
+run-integration-tests-with-all-models:
+ DOCKER_VOLUME=${root_dir}/data \
+ HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
+ pytest --durations=0 -s -vv ${root_dir}/integration-tests --gaudi --gaudi-all-models
+
+# This is used to capture the expected outputs for the integration tests, offering an easy way to add more models to them
+capture-expected-outputs-for-integration-tests:
+ pip install -U pip uv
+ DOCKER_VOLUME=${root_dir}/data \
+ HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
+ uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests/capture_expected_outputs.py
diff --git a/backends/gaudi/README.md b/backends/gaudi/README.md
new file mode 100644
index 00000000000..7713040f76c
--- /dev/null
+++ b/backends/gaudi/README.md
@@ -0,0 +1,152 @@
+# Text-generation-inference - Gaudi backend
+
+## Description
+
+This is the TGI backend for Intel Gaudi. It consists of the TGI server optimized for Gaudi hardware.
+
+## Build your own image
+
+The simplest way to build TGI with the Gaudi backend is to use the provided `Makefile`:
+
+Option 1: From the project root directory:
+```bash
+make -C backends/gaudi image
+```
+
+Option 2: From the Gaudi backend directory:
+```bash
+cd backends/gaudi
+make image
+```
+
+You can now run the server with the following command:
+
+Option 1: Sharded:
+```bash
+model=meta-llama/Llama-3.1-8B-Instruct
+hf_token=$(cat ${HOME}/.cache/huggingface/token)
+volume=${HOME}/.cache/huggingface
+
+docker run --runtime=habana --ipc=host --cap-add=sys_nice \
+ -p 8080:80 -v $volume:/data \
+ -e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
+ tgi-gaudi --model-id $model \
+ --sharded true --num-shard 8 \
+ --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 8 --max-batch-prefill-tokens 2048
+```
+
+Option 2: Non-sharded:
+```bash
+model=meta-llama/Llama-3.1-8B-Instruct
+hf_token=$(cat ${HOME}/.cache/huggingface/token)
+volume=${HOME}/.cache/huggingface
+
+docker run --runtime=habana --ipc=host --cap-add=sys_nice \
+ -p 8080:80 -v $volume:/data \
+ -e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
+ tgi-gaudi --model-id $model \
+ --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048
+```
+
+## Contributing
+
+### Local Development
+
+This is useful if you want to run the server locally for better debugging.
+```bash
+make -C backends/gaudi run-local-dev-container
+```
+
+Then run the following command inside the container to install TGI for Gaudi:
+```bash
+make -C backends/gaudi local-dev-install
+```
+
+Add Rust to the path:
+```bash
+. "$HOME/.cargo/env"
+```
+
+Option 1: Run the server (sharded model):
+```bash
+LOG_LEVEL=debug text-generation-launcher \
+ --model-id meta-llama/Llama-3.1-8B-Instruct \
+ --sharded true \
+ --num-shard 8 \
+ --max-input-tokens 512 \
+ --max-total-tokens 1024 \
+ --max-batch-size 8 \
+ --max-batch-prefill-tokens 2048
+```
+
+Option 2: Run the server (non-sharded model):
+```bash
+LOG_LEVEL=debug text-generation-launcher \
+ --model-id meta-llama/Llama-3.1-8B-Instruct \
+ --max-input-tokens 512 \
+ --max-total-tokens 1024 \
+ --max-batch-size 4 \
+ --max-batch-prefill-tokens 2048
+```
+
+You can then test the server with the following curl command from another terminal (can be outside the container):
+```bash
+curl 127.0.0.1:8080/generate \
+ -X POST \
+ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
+ -H 'Content-Type: application/json'
+```
+
+### Integration tests
+
+Install the dependencies:
+```bash
+pip install -r integration-tests/requirements.txt
+```
+
+To run the integration tests, you need to first build the image:
+```bash
+make -C backends/gaudi image
+```
+
+Then run the integration tests (CI tests) with the following command:
+```bash
+make -C backends/gaudi run-integration-tests
+```
+
+To run the integration tests with all models, you can run the following command:
+```bash
+make -C backends/gaudi run-integration-tests-with-all-models
+```
+
+To capture the expected outputs for the integration tests, you can run the following command:
+```bash
+make -C backends/gaudi capture-expected-outputs-for-integration-tests
+```
+
+#### How the integration tests work
+The integration tests work as follows:
+
+1. Start a TGI server in a container, similar to the following command:
+```bash
+docker run --runtime=habana --ipc=host --cap-add=sys_nice \
+ -p 8080:80 -v $volume:/data \
+ -e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
+ tgi-gaudi --model-id $model \
+ --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048
+```
+
+2. Send a `/generate` request to the server, similar to the following command:
+```bash
+curl 127.0.0.1:8080/generate \
+ -X POST \
+ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
+ -H 'Content-Type: application/json'
+```
+
+3. Check the output of the server against the expected output:
+```python
+assert curl_output == expected_output
+```
+
+This is then repeated for a set of models and configurations.
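+
+For illustration, here is a minimal sketch of such a check, assuming a server is already running on port 8080 and a hypothetical `expected_outputs.json` file mapping prompts to expected generations (the real harness lives in `integration-tests/`):
+
+```bash
+#!/usr/bin/env bash
+# Illustrative sketch only, not the actual test harness.
+set -euo pipefail
+
+prompt="What is Deep Learning?"
+expected=$(jq -r --arg p "$prompt" '.[$p]' expected_outputs.json)
+
+actual=$(curl -s 127.0.0.1:8080/generate \
+    -X POST \
+    -d "{\"inputs\":\"$prompt\",\"parameters\":{\"max_new_tokens\":20}}" \
+    -H 'Content-Type: application/json' | jq -r '.generated_text')
+
+# Compare the generation against the captured expected output
+if [ "$actual" = "$expected" ]; then
+    echo "OK"
+else
+    echo "Mismatch: got '$actual', expected '$expected'" >&2
+    exit 1
+fi
+```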
diff --git a/backends/gaudi/examples/docker_commands/docker_commands.md b/backends/gaudi/examples/docker_commands/docker_commands.md
new file mode 100644
index 00000000000..ccacfbdb574
--- /dev/null
+++ b/backends/gaudi/examples/docker_commands/docker_commands.md
@@ -0,0 +1,112 @@
+# Examples of Docker Commands for Gaudi Backend
+
+This page lists example `docker run` commands for some of the most popular models.
+
+> **Note:** The parameters are chosen to maximize performance on Gaudi2 hardware; please adjust them based on your hardware. For example, if you are using Gaudi3, you may want to increase the batch size, as sketched just below.
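+
+As a purely illustrative starting point (not tuned values), a Gaudi3 run of the Llama3.1-8B example below could raise the batch-related limits and then be benchmarked on your own workload:
+
+```bash
+# Illustrative only: same command as the Gaudi2 BF16 example below,
+# with larger batch-related limits as a Gaudi3 starting point.
+model=meta-llama/Meta-Llama-3.1-8B-Instruct
+hf_token=YOUR_ACCESS_TOKEN
+volume=$PWD/data
+
+docker run -p 8080:80 \
+    --runtime=habana \
+    --cap-add=sys_nice \
+    --ipc=host \
+    -v $volume:/data \
+    -e HF_TOKEN=$hf_token \
+    ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
+    --model-id $model \
+    --max-input-tokens 1024 --max-total-tokens 2048 \
+    --max-batch-prefill-tokens 4096 --max-batch-size 64 \
+    --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 128
+```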
+
+## Default Precision (BF16)
+
+### Llama3.1-8B on 1 card (BF16)
+
+```bash
+model=meta-llama/Meta-Llama-3.1-8B-Instruct
+hf_token=YOUR_ACCESS_TOKEN
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -e HF_TOKEN=$hf_token \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
+ --model-id $model \
+ --max-input-tokens 1024 --max-total-tokens 2048 \
+ --max-batch-prefill-tokens 2048 --max-batch-size 32 \
+ --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
+```
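+
+Once the container is up, you can sanity-check the deployment with a simple request (illustrative; it assumes the `8080:80` port mapping used above):
+
+```bash
+curl 127.0.0.1:8080/generate \
+    -X POST \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
+    -H 'Content-Type: application/json'
+```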
+
+### Llama3.1-70B 8 cards (BF16)
+
+```bash
+model=meta-llama/Meta-Llama-3.1-70B-Instruct
+hf_token=YOUR_ACCESS_TOKEN
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -e HF_TOKEN=$hf_token \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
+ --model-id $model \
+ --sharded true --num-shard 8 \
+ --max-input-tokens 1024 --max-total-tokens 2048 \
+ --max-batch-prefill-tokens 4096 --max-batch-size 256 \
+ --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
+```
+
+### Llava-v1.6-Mistral-7B on 1 card (BF16)
+
+```bash
+model=llava-hf/llava-v1.6-mistral-7b-hf
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
+ --model-id $model \
+ --max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
+ --max-total-tokens 8192 --max-batch-size 4
+```
+
+## FP8 Precision
+
+You can also set the KV cache dtype to FP8 when launching the server; `fp8_e4m3fn` is supported on Gaudi.
+
+### Llama3-8B on 1 card (FP8)
+
+```bash
+model=RedHatAI/Meta-Llama-3-8B-Instruct-FP8-KV
+hf_token=YOUR_ACCESS_TOKEN
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -e HF_TOKEN=$hf_token \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
+ --model-id $model \
+ --kv-cache-dtype fp8_e4m3fn \
+ --max-input-tokens 1024 --max-total-tokens 2048 \
+ --max-batch-prefill-tokens 2048 --max-batch-size 32 \
+ --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
+```
+
+### Llama3-70B on 8 cards (FP8)
+
+```bash
+model=RedHatAI/Meta-Llama-3-70B-Instruct-FP8
+hf_token=YOUR_ACCESS_TOKEN
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -e HF_TOKEN=$hf_token \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
+ --model-id $model \
+ --kv-cache-dtype fp8_e4m3fn \
+ --sharded true --num-shard 8 \
+ --max-input-tokens 1024 --max-total-tokens 2048 \
+ --max-batch-prefill-tokens 4096 --max-batch-size 256 \
+ --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
+```
diff --git a/backends/gaudi/server/.gitignore b/backends/gaudi/server/.gitignore
new file mode 100644
index 00000000000..576746eec14
--- /dev/null
+++ b/backends/gaudi/server/.gitignore
@@ -0,0 +1,164 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+text_generation_server/__pycache__/
+text_generation_server/pb/__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+transformers
+safetensors
+flash-attention/
+flash-attention-v2/
+vllm/
+llm-awq/
+eetq/
+mamba/
diff --git a/backends/gaudi/server/Makefile b/backends/gaudi/server/Makefile
new file mode 100644
index 00000000000..b5b843387cd
--- /dev/null
+++ b/backends/gaudi/server/Makefile
@@ -0,0 +1,38 @@
+include Makefile-flash-att
+include Makefile-flash-att-v2
+include Makefile-vllm
+include Makefile-awq
+include Makefile-eetq
+include Makefile-selective-scan
+
+PROTO_PATH ?= ../proto/v3
+
+unit-tests:
+ pytest -s -vv -m "not private" tests
+
+gen-server:
+ # Compile protos
+ pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir
+ mkdir text_generation_server/pb || true
+ python -m grpc_tools.protoc -I$(PROTO_PATH) --python_out=text_generation_server/pb \
+ --grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb $(PROTO_PATH)/generate.proto
+ find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+ touch text_generation_server/pb/__init__.py
+
+install: gen-server
+ pip install pip --upgrade
+ pip install --no-deps -r requirements.txt
+ pip install -e "."
+
+run-dev:
+ SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
+
+install-poetry:
+ curl -sSL https://install.python-poetry.org | python3 -
+
+update-lock:
+ rm poetry.lock
+ poetry lock --no-update
+
+export-requirements:
+ poetry export -o requirements.txt --without-hashes
diff --git a/backends/gaudi/server/Makefile-awq b/backends/gaudi/server/Makefile-awq
new file mode 100644
index 00000000000..4e074a13324
--- /dev/null
+++ b/backends/gaudi/server/Makefile-awq
@@ -0,0 +1,15 @@
+# Fork that adds only the correct stream to this kernel in order
+# to make cuda graphs work.
+awq_commit := bd1dc2d5254345cc76ab71894651fb821275bdd4
+
+awq:
+ rm -rf llm-awq
+ git clone https://github.com/huggingface/llm-awq
+
+build-awq: awq
+ cd llm-awq/ && git fetch && git checkout $(awq_commit)
+ cd llm-awq/awq/kernels && python setup.py build
+
+install-awq: build-awq
+ pip uninstall awq_inference_engine -y || true
+ cd llm-awq/awq/kernels && python setup.py install
diff --git a/backends/gaudi/server/Makefile-eetq b/backends/gaudi/server/Makefile-eetq
new file mode 100644
index 00000000000..726e47b5729
--- /dev/null
+++ b/backends/gaudi/server/Makefile-eetq
@@ -0,0 +1,13 @@
+eetq_commit := 1657b1504faa359e2ce0ac02999439d7ac8c74c0
+
+eetq:
+ # Clone eetq
+ pip install packaging
+ git clone https://github.com/NetEase-FuXi/EETQ.git eetq
+
+build-eetq: eetq
+ cd eetq && git fetch && git checkout $(eetq_commit) && git submodule update --init --recursive
+ cd eetq && python setup.py build
+
+install-eetq: build-eetq
+ cd eetq && python setup.py install
diff --git a/backends/gaudi/server/Makefile-fbgemm b/backends/gaudi/server/Makefile-fbgemm
new file mode 100644
index 00000000000..3b8061a1fc4
--- /dev/null
+++ b/backends/gaudi/server/Makefile-fbgemm
@@ -0,0 +1,15 @@
+fbgemm_commit := v0.8.0
+
+build-fbgemm:
+ @if [ ! -d "fbgemm" ]; then \
+ git clone https://github.com/pytorch/FBGEMM.git fbgemm; \
+ fi
+ cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \
+ git submodule update --init --recursive && \
+ cd fbgemm_gpu && \
+ pip install -r requirements.txt && \
+ CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build
+
+install-fbgemm: build-fbgemm
+ cd fbgemm/fbgemm_gpu && \
+ CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install
diff --git a/backends/gaudi/server/Makefile-flash-att b/backends/gaudi/server/Makefile-flash-att
new file mode 100644
index 00000000000..29e75bc4810
--- /dev/null
+++ b/backends/gaudi/server/Makefile-flash-att
@@ -0,0 +1,12 @@
+flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
+
+build-flash-attention:
+ if [ ! -d 'flash-attention' ]; then \
+ pip install -U packaging ninja --no-cache-dir && \
+ git clone https://github.com/HazyResearch/flash-attention.git; \
+ fi
+ cd flash-attention && git fetch && git checkout $(flash_att_commit) && \
+ MAX_JOBS=8 python setup.py build && cd csrc/layer_norm && python setup.py build && cd ../rotary && python setup.py build
+
+install-flash-attention: build-flash-attention
+ cd flash-attention && git checkout $(flash_att_commit) && MAX_JOBS=8 python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
diff --git a/backends/gaudi/server/Makefile-flash-att-v2 b/backends/gaudi/server/Makefile-flash-att-v2
new file mode 100644
index 00000000000..a9cdf782270
--- /dev/null
+++ b/backends/gaudi/server/Makefile-flash-att-v2
@@ -0,0 +1,21 @@
+flash_att_v2_commit_cuda := v2.6.1
+flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4
+
+build-flash-attention-v2-cuda:
+ pip install -U packaging wheel
+ pip install flash-attn==$(flash_att_v2_commit_cuda)
+
+install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
+ echo "Flash v2 installed"
+
+build-flash-attention-v2-rocm:
+ if [ ! -d 'flash-attention-v2' ]; then \
+ pip install -U packaging ninja --no-cache-dir && \
+ git clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \
+ cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
+ git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
+ fi
+
+install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
+ cd flash-attention-v2 && \
+ GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install
diff --git a/backends/gaudi/server/Makefile-selective-scan b/backends/gaudi/server/Makefile-selective-scan
new file mode 100644
index 00000000000..b93b517d607
--- /dev/null
+++ b/backends/gaudi/server/Makefile-selective-scan
@@ -0,0 +1,28 @@
+selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137
+
+causal-conv1d:
+ rm -rf causal-conv1d
+ git clone https://github.com/Dao-AILab/causal-conv1d.git
+
+build-causal-conv1d: causal-conv1d
+ cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag
+ cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build
+
+install-causal-conv1d: build-causal-conv1d
+ pip uninstall causal-conv1d -y || true
+ cd causal-conv1d/ && pip install .
+
+# selective-scan depends on causal-conv1d
+selective-scan:
+ rm -rf mamba
+ git clone https://github.com/state-spaces/mamba.git mamba
+
+build-selective-scan: selective-scan
+ cd mamba/ && git fetch && git checkout $(selective_scan_commit)
+ cd mamba && python setup.py build
+
+install-selective-scan: install-causal-conv1d build-selective-scan
+ pip uninstall selective-scan-cuda -y || true
+ cd mamba && pip install .
+
+build-all: build-causal-conv1d build-selective-scan
diff --git a/backends/gaudi/server/Makefile-vllm b/backends/gaudi/server/Makefile-vllm
new file mode 100644
index 00000000000..18dcc4a0c53
--- /dev/null
+++ b/backends/gaudi/server/Makefile-vllm
@@ -0,0 +1,23 @@
+commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
+commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
+build-vllm-cuda:
+ if [ ! -d 'vllm' ]; then \
+ pip install -U ninja packaging --no-cache-dir && \
+ git clone https://github.com/Narsil/vllm.git vllm; \
+ fi
+ cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build
+
+install-vllm-cuda: build-vllm-cuda
+ cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e .
+
+build-vllm-rocm:
+ if [ ! -d 'vllm' ]; then \
+ pip install -U ninja packaging --no-cache-dir && \
+ git clone https://github.com/mht-sharma/vllm.git vllm; \
+ fi
+ cd vllm && git fetch && git checkout $(commit_rocm) && \
+ PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
+
+install-vllm-rocm: build-vllm-rocm
+ cd vllm && git fetch && git checkout $(commit_rocm) && \
+ PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .
diff --git a/backends/gaudi/server/README.md b/backends/gaudi/server/README.md
new file mode 100644
index 00000000000..b8208f9eac4
--- /dev/null
+++ b/backends/gaudi/server/README.md
@@ -0,0 +1,15 @@
+# Text Generation Inference Python gRPC Server
+
+A Python gRPC server for Text Generation Inference
+
+## Install
+
+```shell
+make install
+```
+
+## Run
+
+```shell
+make run-dev
+```
diff --git a/backends/gaudi/server/dill-0.3.7-patch.sh b/backends/gaudi/server/dill-0.3.7-patch.sh
new file mode 100644
index 00000000000..5efd6c54b8a
--- /dev/null
+++ b/backends/gaudi/server/dill-0.3.7-patch.sh
@@ -0,0 +1,91 @@
+#!/bin/bash
+git clone -b dill-0.3.7 https://github.com/uqfoundation/dill.git
+pushd dill
+cat <<EOF > dill-0.3.7.patch
+diff --git a/dill/_dill.py b/dill/_dill.py
+index d0cf543..f6eb662 100644
+--- a/dill/_dill.py
++++ b/dill/_dill.py
+@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
+ XRangeType = range
+ from types import MappingProxyType as DictProxyType, new_class
+ from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
+-import __main__ as _main_module
++class _LazyMainModule(object):
++ _module = None
++ @property
++ def module(self):
++ if self._module is None:
++ import __main__ as _m_module
++ self._module = _m_module
++ return self._module
++_main_module = _LazyMainModule()
+ import marshal
+ import gc
+ # import zlib
+@@ -353,7 +361,7 @@ class Pickler(StockPickler):
+ _fmode = kwds.pop('fmode', None)
+ _recurse = kwds.pop('recurse', None)
+ StockPickler.__init__(self, file, *args, **kwds)
+- self._main = _main_module
++ self._main = _main_module.module
+ self._diff_cache = {}
+ self._byref = settings['byref'] if _byref is None else _byref
+ self._strictio = False #_strictio
+@@ -435,12 +443,12 @@ class Unpickler(StockUnpickler):
+ settings = Pickler.settings
+ _ignore = kwds.pop('ignore', None)
+ StockUnpickler.__init__(self, *args, **kwds)
+- self._main = _main_module
++ self._main = _main_module.module
+ self._ignore = settings['ignore'] if _ignore is None else _ignore
+
+ def load(self): #NOTE: if settings change, need to update attributes
+ obj = StockUnpickler.load(self)
+- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
++ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
+ if not self._ignore:
+ # point obj class to main
+ try: obj.__class__ = getattr(self._main, type(obj).__name__)
+@@ -1194,11 +1202,11 @@ def save_module_dict(pickler, obj):
+ logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
+ pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
+ logger.trace(pickler, "# D1")
+- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
++ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
+ logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
+ pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
+ logger.trace(pickler, "# D3")
+- elif '__name__' in obj and obj != _main_module.__dict__ \\
++ elif '__name__' in obj and obj != _main_module.module.__dict__ \\
+ and type(obj['__name__']) is str \\
+ and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
+ logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
+diff --git a/dill/session.py b/dill/session.py
+index 74234ab..1be8d89 100644
+--- a/dill/session.py
++++ b/dill/session.py
+@@ -233,7 +233,7 @@ def dump_module(
+ protocol = settings['protocol']
+ main = module
+ if main is None:
+- main = _main_module
++ main = _main_module.module
+ elif isinstance(main, str):
+ main = _import_module(main)
+ if not isinstance(main, ModuleType):
+@@ -501,7 +501,7 @@ def load_module(
+ pass
+ assert loaded is main
+ _restore_modules(unpickler, main)
+- if main is _main_module or main is module:
++ if main is _main_module.module or main is module:
+ return None
+ else:
+ return main
+
+EOF
+git apply dill-0.3.7.patch
+python -m pip install .
+popd
+rm -fr dill
diff --git a/backends/gaudi/server/dill-0.3.8-patch.sh b/backends/gaudi/server/dill-0.3.8-patch.sh
new file mode 100644
index 00000000000..414790e7bae
--- /dev/null
+++ b/backends/gaudi/server/dill-0.3.8-patch.sh
@@ -0,0 +1,91 @@
+#!/bin/bash
+git clone -b 0.3.8 https://github.com/uqfoundation/dill.git
+pushd dill
+cat <<EOF > dill-0.3.8.patch
+diff --git a/dill/_dill.py b/dill/_dill.py
+index d42432f..1d251e6 100644
+--- a/dill/_dill.py
++++ b/dill/_dill.py
+@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
+ XRangeType = range
+ from types import MappingProxyType as DictProxyType, new_class
+ from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
+-import __main__ as _main_module
++class _LazyMainModule(object):
++ _module = None
++ @property
++ def module(self):
++ if self._module is None:
++ import __main__ as _m_module
++ self._module = _m_module
++ return self._module
++_main_module = _LazyMainModule()
+ import marshal
+ import gc
+ # import zlib
+@@ -355,7 +363,7 @@ class Pickler(StockPickler):
+ _fmode = kwds.pop('fmode', None)
+ _recurse = kwds.pop('recurse', None)
+ StockPickler.__init__(self, file, *args, **kwds)
+- self._main = _main_module
++ self._main = _main_module.module
+ self._diff_cache = {}
+ self._byref = settings['byref'] if _byref is None else _byref
+ self._strictio = False #_strictio
+@@ -437,12 +445,12 @@ class Unpickler(StockUnpickler):
+ settings = Pickler.settings
+ _ignore = kwds.pop('ignore', None)
+ StockUnpickler.__init__(self, *args, **kwds)
+- self._main = _main_module
++ self._main = _main_module.module
+ self._ignore = settings['ignore'] if _ignore is None else _ignore
+
+ def load(self): #NOTE: if settings change, need to update attributes
+ obj = StockUnpickler.load(self)
+- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
++ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
+ if not self._ignore:
+ # point obj class to main
+ try: obj.__class__ = getattr(self._main, type(obj).__name__)
+@@ -1199,11 +1207,11 @@ def save_module_dict(pickler, obj):
+ logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
+ pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
+ logger.trace(pickler, "# D1")
+- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
++ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
+ logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
+ pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
+ logger.trace(pickler, "# D3")
+- elif '__name__' in obj and obj != _main_module.__dict__ \\
++ elif '__name__' in obj and obj != _main_module.module.__dict__ \\
+ and type(obj['__name__']) is str \\
+ and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
+ logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
+diff --git a/dill/session.py b/dill/session.py
+index e91068a..a921b43 100644
+--- a/dill/session.py
++++ b/dill/session.py
+@@ -233,7 +233,7 @@ def dump_module(
+ protocol = settings['protocol']
+ main = module
+ if main is None:
+- main = _main_module
++ main = _main_module.module
+ elif isinstance(main, str):
+ main = _import_module(main)
+ if not isinstance(main, ModuleType):
+@@ -501,7 +501,7 @@ def load_module(
+ pass
+ assert loaded is main
+ _restore_modules(unpickler, main)
+- if main is _main_module or main is module:
++ if main is _main_module.module or main is module:
+ return None
+ else:
+ return main
+
+EOF
+git apply dill-0.3.8.patch
+python -m pip install .
+popd
+rm -fr dill
diff --git a/backends/gaudi/server/poetry.lock b/backends/gaudi/server/poetry.lock
new file mode 100644
index 00000000000..c6cace66c4f
--- /dev/null
+++ b/backends/gaudi/server/poetry.lock
@@ -0,0 +1,2764 @@
+# This file is automatically @generated by Poetry 2.0.0 and should not be changed by hand.
+
+[[package]]
+name = "accelerate"
+version = "0.33.0"
+description = "Accelerate"
+optional = false
+python-versions = ">=3.8.0"
+groups = ["main"]
+files = [
+ {file = "accelerate-0.33.0-py3-none-any.whl", hash = "sha256:0a7f33d60ba09afabd028d4f0856dd19c5a734b7a596d637d9dd6e3d0eadbaf3"},
+ {file = "accelerate-0.33.0.tar.gz", hash = "sha256:11ba481ed6ea09191775df55ce464aeeba67a024bd0261a44b77b30fb439e26a"},
+]
+
+[package.dependencies]
+huggingface-hub = ">=0.21.0"
+numpy = ">=1.17,<2.0.0"
+packaging = ">=20.0"
+psutil = "*"
+pyyaml = "*"
+safetensors = ">=0.3.1"
+torch = ">=1.10.0"
+
+[package.extras]
+deepspeed = ["deepspeed (<=0.14.0)"]
+dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "diffusers", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.2.1,<0.3.0)", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
+quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.2.1,<0.3.0)"]
+rich = ["rich"]
+sagemaker = ["sagemaker"]
+test-dev = ["bitsandbytes", "datasets", "diffusers", "evaluate", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
+test-prod = ["parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist"]
+test-trackers = ["comet-ml", "dvclive", "tensorboard", "wandb"]
+testing = ["bitsandbytes", "datasets", "diffusers", "evaluate", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
+
+[[package]]
+name = "annotated-types"
+version = "0.7.0"
+description = "Reusable constraint types to use with typing.Annotated"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"},
+ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
+]
+
+[[package]]
+name = "attrs"
+version = "25.3.0"
+description = "Classes Without Boilerplate"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3"},
+ {file = "attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b"},
+]
+
+[package.extras]
+benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier"]
+tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"]
+
+[[package]]
+name = "certifi"
+version = "2025.1.31"
+description = "Python package for providing Mozilla's CA Bundle."
+optional = false
+python-versions = ">=3.6"
+groups = ["main"]
+files = [
+ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"},
+ {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"},
+]
+
+[[package]]
+name = "charset-normalizer"
+version = "3.4.1"
+description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e218488cd232553829be0664c2292d3af2eeeb94b32bea483cf79ac6a694e037"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80ed5e856eb7f30115aaf94e4a08114ccc8813e6ed1b5efa74f9f82e8509858f"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b010a7a4fd316c3c484d482922d13044979e78d1861f0e0650423144c616a46a"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4532bff1b8421fd0a320463030c7520f56a79c9024a4e88f01c537316019005a"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d973f03c0cb71c5ed99037b870f2be986c3c05e63622c017ea9816881d2dd247"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3a3bd0dcd373514dcec91c411ddb9632c0d7d92aed7093b8c3bbb6d69ca74408"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:d9c3cdf5390dcd29aa8056d13e8e99526cda0305acc038b96b30352aff5ff2bb"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:2bdfe3ac2e1bbe5b59a1a63721eb3b95fc9b6817ae4a46debbb4e11f6232428d"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:eab677309cdb30d047996b36d34caeda1dc91149e4fdca0b1a039b3f79d9a807"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-win32.whl", hash = "sha256:c0429126cf75e16c4f0ad00ee0eae4242dc652290f940152ca8c75c3a4b6ee8f"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:9f0b8b1c6d84c8034a44893aba5e767bf9c7a211e313a9605d9c617d7083829f"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8bfa33f4f2672964266e940dd22a195989ba31669bd84629f05fab3ef4e2d125"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28bf57629c75e810b6ae989f03c0828d64d6b26a5e205535585f96093e405ed1"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f08ff5e948271dc7e18a35641d2f11a4cd8dfd5634f55228b691e62b37125eb3"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:234ac59ea147c59ee4da87a0c0f098e9c8d169f4dc2a159ef720f1a61bbe27cd"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd4ec41f914fa74ad1b8304bbc634b3de73d2a0889bd32076342a573e0779e00"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eea6ee1db730b3483adf394ea72f808b6e18cf3cb6454b4d86e04fa8c4327a12"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c96836c97b1238e9c9e3fe90844c947d5afbf4f4c92762679acfe19927d81d77"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4d86f7aff21ee58f26dcf5ae81a9addbd914115cdebcbb2217e4f0ed8982e146"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:09b5e6733cbd160dcc09589227187e242a30a49ca5cefa5a7edd3f9d19ed53fd"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:5777ee0881f9499ed0f71cc82cf873d9a0ca8af166dfa0af8ec4e675b7df48e6"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:237bdbe6159cff53b4f24f397d43c6336c6b0b42affbe857970cefbb620911c8"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-win32.whl", hash = "sha256:8417cb1f36cc0bc7eaba8ccb0e04d55f0ee52df06df3ad55259b9a323555fc8b"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:d7f50a1f8c450f3925cb367d011448c39239bb3eb4117c36a6d354794de4ce76"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:73d94b58ec7fecbc7366247d3b0b10a21681004153238750bb67bd9012414545"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dad3e487649f498dd991eeb901125411559b22e8d7ab25d3aeb1af367df5efd7"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c30197aa96e8eed02200a83fba2657b4c3acd0f0aa4bdc9f6c1af8e8962e0757"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2369eea1ee4a7610a860d88f268eb39b95cb588acd7235e02fd5a5601773d4fa"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc2722592d8998c870fa4e290c2eec2c1569b87fe58618e67d38b4665dfa680d"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffc9202a29ab3920fa812879e95a9e78b2465fd10be7fcbd042899695d75e616"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:804a4d582ba6e5b747c625bf1255e6b1507465494a40a2130978bda7b932c90b"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0f55e69f030f7163dffe9fd0752b32f070566451afe180f99dbeeb81f511ad8d"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c4c3e6da02df6fa1410a7680bd3f63d4f710232d3139089536310d027950696a"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:5df196eb874dae23dcfb968c83d4f8fdccb333330fe1fc278ac5ceeb101003a9"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e358e64305fe12299a08e08978f51fc21fac060dcfcddd95453eabe5b93ed0e1"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-win32.whl", hash = "sha256:9b23ca7ef998bc739bf6ffc077c2116917eabcc901f88da1b9856b210ef63f35"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:6ff8a4a60c227ad87030d76e99cd1698345d4491638dfa6673027c48b3cd395f"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:aabfa34badd18f1da5ec1bc2715cadc8dca465868a4e73a0173466b688f29dda"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22e14b5d70560b8dd51ec22863f370d1e595ac3d024cb8ad7d308b4cd95f8313"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8436c508b408b82d87dc5f62496973a1805cd46727c34440b0d29d8a2f50a6c9"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d074908e1aecee37a7635990b2c6d504cd4766c7bc9fc86d63f9c09af3fa11b"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:955f8851919303c92343d2f66165294848d57e9bba6cf6e3625485a70a038d11"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:44ecbf16649486d4aebafeaa7ec4c9fed8b88101f4dd612dcaf65d5e815f837f"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0924e81d3d5e70f8126529951dac65c1010cdf117bb75eb02dd12339b57749dd"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2967f74ad52c3b98de4c3b32e1a44e32975e008a9cd2a8cc8966d6a5218c5cb2"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c75cb2a3e389853835e84a2d8fb2b81a10645b503eca9bcb98df6b5a43eb8886"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:09b26ae6b1abf0d27570633b2b078a2a20419c99d66fb2823173d73f188ce601"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fa88b843d6e211393a37219e6a1c1df99d35e8fd90446f1118f4216e307e48cd"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-win32.whl", hash = "sha256:eb8178fe3dba6450a3e024e95ac49ed3400e506fd4e9e5c32d30adda88cbd407"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:b1ac5992a838106edb89654e0aebfc24f5848ae2547d22c2c3f66454daa11971"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f30bf9fd9be89ecb2360c7d94a711f00c09b976258846efe40db3d05828e8089"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:97f68b8d6831127e4787ad15e6757232e14e12060bec17091b85eb1486b91d8d"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7974a0b5ecd505609e3b19742b60cee7aa2aa2fb3151bc917e6e2646d7667dcf"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc54db6c8593ef7d4b2a331b58653356cf04f67c960f584edb7c3d8c97e8f39e"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:311f30128d7d333eebd7896965bfcfbd0065f1716ec92bd5638d7748eb6f936a"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:7d053096f67cd1241601111b698f5cad775f97ab25d81567d3f59219b5f1adbd"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:807f52c1f798eef6cf26beb819eeb8819b1622ddfeef9d0977a8502d4db6d534"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:dccbe65bd2f7f7ec22c4ff99ed56faa1e9f785482b9bbd7c717e26fd723a1d1e"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:2fb9bd477fdea8684f78791a6de97a953c51831ee2981f8e4f583ff3b9d9687e"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:01732659ba9b5b873fc117534143e4feefecf3b2078b0a6a2e925271bb6f4cfa"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-win32.whl", hash = "sha256:7a4f97a081603d2050bfaffdefa5b02a9ec823f8348a572e39032caa8404a487"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:7b1bef6280950ee6c177b326508f86cad7ad4dff12454483b51d8b7d673a2c5d"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ecddf25bee22fe4fe3737a399d0d177d72bc22be6913acfab364b40bce1ba83c"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c60ca7339acd497a55b0ea5d506b2a2612afb2826560416f6894e8b5770d4a9"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b7b2d86dd06bfc2ade3312a83a5c364c7ec2e3498f8734282c6c3d4b07b346b8"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dd78cfcda14a1ef52584dbb008f7ac81c1328c0f58184bf9a84c49c605002da6"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e27f48bcd0957c6d4cb9d6fa6b61d192d0b13d5ef563e5f2ae35feafc0d179c"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:01ad647cdd609225c5350561d084b42ddf732f4eeefe6e678765636791e78b9a"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:619a609aa74ae43d90ed2e89bdd784765de0a25ca761b93e196d938b8fd1dbbd"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:89149166622f4db9b4b6a449256291dc87a99ee53151c74cbd82a53c8c2f6ccd"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:7709f51f5f7c853f0fb938bcd3bc59cdfdc5203635ffd18bf354f6967ea0f824"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:345b0426edd4e18138d6528aed636de7a9ed169b4aaf9d61a8c19e39d26838ca"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0907f11d019260cdc3f94fbdb23ff9125f6b5d1039b76003b5b0ac9d6a6c9d5b"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-win32.whl", hash = "sha256:ea0d8d539afa5eb2728aa1932a988a9a7af94f18582ffae4bc10b3fbdad0626e"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:329ce159e82018d646c7ac45b01a430369d526569ec08516081727a20e9e4af4"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b97e690a2118911e39b4042088092771b4ae3fc3aa86518f84b8cf6888dbdb41"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78baa6d91634dfb69ec52a463534bc0df05dbd546209b79a3880a34487f4b84f"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1a2bc9f351a75ef49d664206d51f8e5ede9da246602dc2d2726837620ea034b2"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:75832c08354f595c760a804588b9357d34ec00ba1c940c15e31e96d902093770"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0af291f4fe114be0280cdd29d533696a77b5b49cfde5467176ecab32353395c4"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0167ddc8ab6508fe81860a57dd472b2ef4060e8d378f0cc555707126830f2537"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2a75d49014d118e4198bcee5ee0a6f25856b29b12dbf7cd012791f8a6cc5c496"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:363e2f92b0f0174b2f8238240a1a30142e3db7b957a5dd5689b0e75fb717cc78"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ab36c8eb7e454e34e60eb55ca5d241a5d18b2c6244f6827a30e451c42410b5f7"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:4c0907b1928a36d5a998d72d64d8eaa7244989f7aaaf947500d3a800c83a3fd6"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:04432ad9479fa40ec0f387795ddad4437a2b50417c69fa275e212933519ff294"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-win32.whl", hash = "sha256:3bed14e9c89dcb10e8f3a29f9ccac4955aebe93c71ae803af79265c9ca5644c5"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:49402233c892a461407c512a19435d1ce275543138294f7ef013f0b63d5d3765"},
+ {file = "charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85"},
+ {file = "charset_normalizer-3.4.1.tar.gz", hash = "sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3"},
+]
+
+[[package]]
+name = "click"
+version = "8.1.8"
+description = "Composable command line interface toolkit"
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"},
+ {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "platform_system == \"Windows\""}
+
+[[package]]
+name = "cloudpickle"
+version = "3.1.1"
+description = "Pickler class to extend the standard pickle.Pickler functionality"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "cloudpickle-3.1.1-py3-none-any.whl", hash = "sha256:c8c5a44295039331ee9dad40ba100a9c7297b6f988e50e87ccdf3765a668350e"},
+ {file = "cloudpickle-3.1.1.tar.gz", hash = "sha256:b216fa8ae4019d5482a8ac3c95d8f6346115d8835911fd4aefd1a445e4242c64"},
+]
+
+[[package]]
+name = "colorama"
+version = "0.4.6"
+description = "Cross-platform colored terminal text."
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
+groups = ["main", "dev"]
+files = [
+ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
+ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
+]
+markers = {main = "platform_system == \"Windows\" or sys_platform == \"win32\"", dev = "sys_platform == \"win32\""}
+
+[[package]]
+name = "deprecated"
+version = "1.2.18"
+description = "Python @deprecated decorator to deprecate old python classes, functions or methods."
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7"
+groups = ["main"]
+files = [
+ {file = "Deprecated-1.2.18-py2.py3-none-any.whl", hash = "sha256:bd5011788200372a32418f888e326a09ff80d0214bd961147cfed01b5c018eec"},
+ {file = "deprecated-1.2.18.tar.gz", hash = "sha256:422b6f6d859da6f2ef57857761bfb392480502a64c3028ca9bbe86085d72115d"},
+]
+
+[package.dependencies]
+wrapt = ">=1.10,<2"
+
+[package.extras]
+dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "setuptools", "tox"]
+
+[[package]]
+name = "diffusers"
+version = "0.31.0"
+description = "State-of-the-art diffusion in PyTorch and JAX."
+optional = false
+python-versions = ">=3.8.0"
+groups = ["main"]
+files = [
+ {file = "diffusers-0.31.0-py3-none-any.whl", hash = "sha256:cbc498ae63f4abfc7c3a07649cdcbee229ef2f9a9a1f0d19c9bbaf22f8d30c1f"},
+ {file = "diffusers-0.31.0.tar.gz", hash = "sha256:b1d01a73e45d43a0630c299173915dddd69fc50f2ae8f2ab5de4fd245eaed72f"},
+]
+
+[package.dependencies]
+filelock = "*"
+huggingface-hub = ">=0.23.2"
+importlib-metadata = "*"
+numpy = "*"
+Pillow = "*"
+regex = "!=2019.12.17"
+requests = "*"
+safetensors = ">=0.3.1"
+
+[package.extras]
+dev = ["GitPython (<3.1.19)", "Jinja2", "accelerate (>=0.31.0)", "compel (==0.1.8)", "datasets", "flax (>=0.4.1)", "hf-doc-builder (>=0.3.0)", "invisible-watermark (>=0.2.0)", "isort (>=5.5.4)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "ruff (==0.1.5)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "torch (>=1.4,<2.5.0)", "torchvision", "transformers (>=4.41.2)", "urllib3 (<=2.0.0)"]
+docs = ["hf-doc-builder (>=0.3.0)"]
+flax = ["flax (>=0.4.1)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)"]
+quality = ["hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<=2.0.0)"]
+test = ["GitPython (<3.1.19)", "Jinja2", "compel (==0.1.8)", "datasets", "invisible-watermark (>=0.2.0)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "torchvision", "transformers (>=4.41.2)"]
+torch = ["accelerate (>=0.31.0)", "torch (>=1.4,<2.5.0)"]
+training = ["Jinja2", "accelerate (>=0.31.0)", "datasets", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "tensorboard"]
+
+[[package]]
+name = "diskcache"
+version = "5.6.3"
+description = "Disk Cache -- Disk and file backed persistent cache."
+optional = true
+python-versions = ">=3"
+groups = ["main"]
+files = [
+ {file = "diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19"},
+ {file = "diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc"},
+]
+
+[[package]]
+name = "exceptiongroup"
+version = "1.2.2"
+description = "Backport of PEP 654 (exception groups)"
+optional = false
+python-versions = ">=3.7"
+groups = ["dev"]
+markers = "python_version < \"3.11\""
+files = [
+ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
+ {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
+]
+
+[package.extras]
+test = ["pytest (>=6)"]
+
+[[package]]
+name = "filelock"
+version = "3.18.0"
+description = "A platform independent file lock."
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "filelock-3.18.0-py3-none-any.whl", hash = "sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de"},
+ {file = "filelock-3.18.0.tar.gz", hash = "sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2"},
+]
+
+[package.extras]
+docs = ["furo (>=2024.8.6)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"]
+testing = ["covdefaults (>=2.3)", "coverage (>=7.6.10)", "diff-cover (>=9.2.1)", "pytest (>=8.3.4)", "pytest-asyncio (>=0.25.2)", "pytest-cov (>=6)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.28.1)"]
+typing = ["typing-extensions (>=4.12.2)"]
+
+[[package]]
+name = "fsspec"
+version = "2025.3.2"
+description = "File-system specification"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "fsspec-2025.3.2-py3-none-any.whl", hash = "sha256:2daf8dc3d1dfa65b6aa37748d112773a7a08416f6c70d96b264c96476ecaf711"},
+ {file = "fsspec-2025.3.2.tar.gz", hash = "sha256:e52c77ef398680bbd6a98c0e628fbc469491282981209907bbc8aea76a04fdc6"},
+]
+
+[package.extras]
+abfs = ["adlfs"]
+adl = ["adlfs"]
+arrow = ["pyarrow (>=1)"]
+dask = ["dask", "distributed"]
+dev = ["pre-commit", "ruff"]
+doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"]
+dropbox = ["dropbox", "dropboxdrivefs", "requests"]
+full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"]
+fuse = ["fusepy"]
+gcs = ["gcsfs"]
+git = ["pygit2"]
+github = ["requests"]
+gs = ["gcsfs"]
+gui = ["panel"]
+hdfs = ["pyarrow (>=1)"]
+http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"]
+libarchive = ["libarchive-c"]
+oci = ["ocifs"]
+s3 = ["s3fs"]
+sftp = ["paramiko"]
+smb = ["smbprotocol"]
+ssh = ["paramiko"]
+test = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "numpy", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "requests"]
+test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"]
+test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"]
+tqdm = ["tqdm"]
+
+[[package]]
+name = "googleapis-common-protos"
+version = "1.70.0"
+description = "Common protobufs used in Google APIs"
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8"},
+ {file = "googleapis_common_protos-1.70.0.tar.gz", hash = "sha256:0e1b44e0ea153e6594f9f394fef15193a68aaaea2d843f83e2742717ca753257"},
+]
+
+[package.dependencies]
+protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
+
+[package.extras]
+grpc = ["grpcio (>=1.44.0,<2.0.0)"]
+
+[[package]]
+name = "grpc-interceptor"
+version = "0.15.4"
+description = "Simplifies gRPC interceptors"
+optional = false
+python-versions = ">=3.7,<4.0"
+groups = ["main"]
+files = [
+ {file = "grpc-interceptor-0.15.4.tar.gz", hash = "sha256:1f45c0bcb58b6f332f37c637632247c9b02bc6af0fdceb7ba7ce8d2ebbfb0926"},
+ {file = "grpc_interceptor-0.15.4-py3-none-any.whl", hash = "sha256:0035f33228693ed3767ee49d937bac424318db173fef4d2d0170b3215f254d9d"},
+]
+
+[package.dependencies]
+grpcio = ">=1.49.1,<2.0.0"
+
+[package.extras]
+testing = ["protobuf (>=4.21.9)"]
+
+[[package]]
+name = "grpcio"
+version = "1.72.0rc1"
+description = "HTTP/2-based RPC framework"
+optional = false
+python-versions = ">=3.9"
+groups = ["main", "dev"]
+files = [
+ {file = "grpcio-1.72.0rc1-cp310-cp310-linux_armv7l.whl", hash = "sha256:db7db4b246a7fb21aeb70e7220be480948aa9c535eaa777ea0c840416ed8cac9"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:baf028e61662fd320c18fb50070b6e330fa24b2b3a4d113f4d57b41e0f5b5873"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:bf84cf17dfbf49ebe11b081b7a3c83b23625a80c979741e2e98b0ddb41080397"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fd6f8700d34754b32d13af234da2e413f408c8b741c8039f11beb06d53c3f6a"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f05d243b8d814dd1c6fca19e4e0c5986fc70e2c3aa29e2c7c67e877e4c03ede6"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:390a70394e2c315d7c480496db259ec16c00baeebf759c8967247269f0fee981"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b08973c62eda11343e7131d78635d50ae0c138a8f39eb817ca83cca842527d04"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ce539397a258af1dee26118c40327004d023617bc99493baaf8e7938491f7361"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-win32.whl", hash = "sha256:4f97f628095bbdf6d4c2c15c1bc18f0514f90781528bc6082bb697ccc71d4f42"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-win_amd64.whl", hash = "sha256:dbcdf7a5463b61fca1586b54f7ea3c9dfd159f535224f457ae307f52d8d4a839"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-linux_armv7l.whl", hash = "sha256:23ebb3947783f10fec3e1d0b29b94db8e72f721900d1dd9c1d6db5876da69066"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:fd96b20846907ed4cd95bf1d628f16732f450114bde897eedb323fc3bc1eddb3"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:6df1ba4a5f5793ae210699e1b1745f77a4ac17f73510fc36ee12c215f02523b4"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a3398957c611f0af7cee4fdd34268b6664be8689eae0327440efb794e544908b"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ef66029da9cbe94ba3047c1b04653e1d5096ca8d036eb6e24092f0e847d2c4f"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6566e3e3458805381f8714492e8f559f082f8955ccd1c98d71f8afc0612dc841"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3c799bfa92450e95d3f1f9cc4b7d8cbefc8bd4356d3f6573d2fb5e698353192a"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a251992531f3b16be3c013ec45a9caa69ecfe9b45335652d5681659f6d117233"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-win32.whl", hash = "sha256:c9e5f2c628dedf0886b774eee17e003a043941024e68ee2ebe76be6981a7baab"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-win_amd64.whl", hash = "sha256:8b9c0a84ff584da3f5c0cb04ee3d87c0bc70d41ab5a21d3b943963a94c622892"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-linux_armv7l.whl", hash = "sha256:188ac9d8cb05c250e212ba946a65a8541419bdfd803373d6b7fb8b10fe5ff991"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:8bd956711dc21235bc78a70bf04a28b3f747c6576b9bb79362803707fec9f705"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:b032b9cbb325e28ff847b6aae1df5a090aa49b682dc80c926b24a96de43c01aa"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1ca12a4388a40eb0411264af291184e2cca38176996b591ac047844abd81d40b"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7cefd52f392f4d6747b401f825901c48176737f7b03b17be0a0a638da194749"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1a24408fb051b70efa440b95f7e1acbb1c3067609934aa53a953d8d2cfc4d824"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:c7b37608d14792d3dacb9aba55b96a17a074e139c4567b0ac5c1926302add910"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:81ca42a96299ca617f3bc7b60660f15cabb98de6fce440ecd4d0640a5554345f"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-win32.whl", hash = "sha256:9ff2ef2a553d4edc8c620df3735b15a1e7dc05a60262e8c28445f2676fb09189"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-win_amd64.whl", hash = "sha256:3c9a6613662591c198d9e4e499f3336bc5c1c0e3fe3f0922cf48e74b37b3dcd1"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-linux_armv7l.whl", hash = "sha256:995e3e5c43cab6d0f1922b43b3c01a2624a4497ce91c3124e807497654301c59"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:8dfb0ff2ddd708dbecdffa37245b79aef707e789ffb0fc6a8be01608d982afcd"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:7e08eb53d6123995da63df90ce50e5b834de0a8ebfb1a3ac0890a2e246d2771c"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:71cb52c0956fe7868692b490fda341a52d8187fab94e1136f5bd253c8e3560ac"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcf76ce8d4a6829f112ad88c4e6d528dbef922e01834d4a5cc3718bf599f7e84"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:8852b6234a52b6b694a5f9a5a687d59127b3e71c8e345eebd6d483abbc412217"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:d1a0fee8420d9e453dc8cba1c7c067ca2d3054487cb6616ab8dad41f15e57465"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:a13149f4fd3904093fa2dba484744dd7205f536650a533ab24dd95cca393c14c"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-win32.whl", hash = "sha256:cebe148511a1965363fc6aafd60a488fe9dc5d74dd92a59a8ecba66ddd53c573"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-win_amd64.whl", hash = "sha256:843352c352970a1df5bbf7da68d2770781f4bff2c85a4a0d20cc6eaaadf26e59"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-linux_armv7l.whl", hash = "sha256:2083c0cdff47ff7d4b093d05d703baeeef8db3b2c1f43c9f9d4288a99e444cdd"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:42df7e0f9d66f5c9b246d8e1da74605bce27b10dec20b6fc204edd6e7178da2d"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:1190c2e4f221b5bd0e6eba3e44d6758ef48eeb2216dcb9734c158e8a5d8ce6a3"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6d6c8d2ea63e1cdaaa81271e5c867fcd9732050324df372ff9d3163968be68c8"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f6ee161b9d112232e5d6be437bf56383dca2334bd17e8b7a4a3f97f33722bdd"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9abbdf945e3b151603d642f2bc7a637b87af2e3480ed047689bad9eb4fa9c712"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:2edab5d26319a1fed695ec658efe3846b75e0c7f3a6202b042099c9b11dc10fd"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:03b46e0041bee18a786ccef978bc29a26e4bd1b73a6ca0b21252387167843ff1"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-win32.whl", hash = "sha256:9b861cbfb63433e02b52f9971644095bec4a5fcd1e4d3f94e18cfad38f649d53"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-win_amd64.whl", hash = "sha256:2416792a567cba9f92bffc1a55ce0f2c8106956a2e32bfe8a22a8094a56b7108"},
+ {file = "grpcio-1.72.0rc1.tar.gz", hash = "sha256:221793dccd3332060f426975a041d319d6d57323d857d4afc25257ec4a5a67f3"},
+]
+
+[package.extras]
+protobuf = ["grpcio-tools (>=1.72.0rc1)"]
+
+[[package]]
+name = "grpcio-reflection"
+version = "1.71.0"
+description = "Standard Protobuf Reflection Service for gRPC"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "grpcio_reflection-1.71.0-py3-none-any.whl", hash = "sha256:8c88bdd9c92fcdd4d5df119997be05ecd0d7e10d377ec4a5072db507d2894612"},
+ {file = "grpcio_reflection-1.71.0.tar.gz", hash = "sha256:51504e977057ffabe66d1ed55557b15e969c42bb3a1f28ee45d730dd5f983bb5"},
+]
+
+[package.dependencies]
+grpcio = ">=1.71.0"
+protobuf = ">=5.26.1,<6.0dev"
+
+[[package]]
+name = "grpcio-status"
+version = "1.71.0"
+description = "Status proto mapping for gRPC"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "grpcio_status-1.71.0-py3-none-any.whl", hash = "sha256:843934ef8c09e3e858952887467f8256aac3910c55f077a359a65b2b3cde3e68"},
+ {file = "grpcio_status-1.71.0.tar.gz", hash = "sha256:11405fed67b68f406b3f3c7c5ae5104a79d2d309666d10d61b152e91d28fb968"},
+]
+
+[package.dependencies]
+googleapis-common-protos = ">=1.5.5"
+grpcio = ">=1.71.0"
+protobuf = ">=5.26.1,<6.0dev"
+
+[[package]]
+name = "grpcio-tools"
+version = "1.71.0"
+description = "Protobuf code generator for gRPC"
+optional = false
+python-versions = ">=3.9"
+groups = ["dev"]
+files = [
+ {file = "grpcio_tools-1.71.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:f4ad7f0d756546902597053d70b3af2606fbd70d7972876cd75c1e241d22ae00"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:64bdb291df61cf570b5256777ad5fe2b1db6d67bc46e55dc56a0a862722ae329"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:8dd9795e982d77a4b496f7278b943c2563d9afde2069cdee78c111a40cc4d675"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c1b5860c41a36b26fec4f52998f1a451d0525a5c9a4fb06b6ea3e9211abdb925"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3059c14035e5dc03d462f261e5900b9a077fd1a36976c3865b8507474520bad4"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f360981b215b1d5aff9235b37e7e1826246e35bbac32a53e41d4e990a37b8f4c"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bfe3888c3bbe16a5aa39409bc38744a31c0c3d2daa2b0095978c56e106c85b42"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:145985c0bf12131f0a1503e65763e0f060473f7f3928ed1ff3fb0e8aad5bc8ac"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-win32.whl", hash = "sha256:82c430edd939bb863550ee0fecf067d78feff828908a1b529bbe33cc57f2419c"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-win_amd64.whl", hash = "sha256:83e90724e3f02415c628e4ead1d6ffe063820aaaa078d9a39176793df958cd5a"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:1f19b16b49afa5d21473f49c0966dd430c88d089cd52ac02404d8cef67134efb"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-macosx_10_14_universal2.whl", hash = "sha256:459c8f5e00e390aecd5b89de67deb3ec7188a274bc6cb50e43cef35ab3a3f45d"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:edab7e6518de01196be37f96cb1e138c3819986bf5e2a6c9e1519b4d716b2f5a"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8b93b9f6adc7491d4c10144c0643409db298e5e63c997106a804f6f0248dbaf4"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ae5f2efa9e644c10bf1021600bfc099dfbd8e02b184d2d25dc31fcd6c2bc59e"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:65aa082f4435571d65d5ce07fc444f23c3eff4f3e34abef599ef8c9e1f6f360f"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1331e726e08b7bdcbf2075fcf4b47dff07842b04845e6e220a08a4663e232d7f"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6693a7d3ba138b0e693b3d1f687cdd9db9e68976c3fa2b951c17a072fea8b583"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-win32.whl", hash = "sha256:6d11ed3ff7b6023b5c72a8654975324bb98c1092426ba5b481af406ff559df00"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-win_amd64.whl", hash = "sha256:072b2a5805ac97e4623b3aa8f7818275f3fb087f4aa131b0fce00471065f6eaa"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:61c0409d5bdac57a7bd0ce0ab01c1c916728fe4c8a03d77a25135ad481eb505c"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-macosx_10_14_universal2.whl", hash = "sha256:28784f39921d061d2164a9dcda5164a69d07bf29f91f0ea50b505958292312c9"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:192808cf553cedca73f0479cc61d5684ad61f24db7a5f3c4dfe1500342425866"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:989ee9da61098230d3d4c8f8f8e27c2de796f1ff21b1c90110e636d9acd9432b"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:541a756276c8a55dec991f6c0106ae20c8c8f5ce8d0bdbfcb01e2338d1a8192b"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:870c0097700d13c403e5517cb7750ab5b4a791ce3e71791c411a38c5468b64bd"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:abd57f615e88bf93c3c6fd31f923106e3beb12f8cd2df95b0d256fa07a7a0a57"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:753270e2d06d37e6d7af8967d1d059ec635ad215882041a36294f4e2fd502b2e"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-win32.whl", hash = "sha256:0e647794bd7138b8c215e86277a9711a95cf6a03ff6f9e555d54fdf7378b9f9d"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-win_amd64.whl", hash = "sha256:48debc879570972d28bfe98e4970eff25bb26da3f383e0e49829b2d2cd35ad87"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:9a78d07d6c301a25ef5ede962920a522556a1dfee1ccc05795994ceb867f766c"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-macosx_10_14_universal2.whl", hash = "sha256:580ac88141c9815557e63c9c04f5b1cdb19b4db8d0cb792b573354bde1ee8b12"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:f7c678e68ece0ae908ecae1c4314a0c2c7f83e26e281738b9609860cc2c82d96"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:56ecd6cc89b5e5eed1de5eb9cafce86c9c9043ee3840888cc464d16200290b53"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e52a041afc20ab2431d756b6295d727bd7adee813b21b06a3483f4a7a15ea15f"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:2a1712f12102b60c8d92779b89d0504e0d6f3a59f2b933e5622b8583f5c02992"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:41878cb7a75477e62fdd45e7e9155b3af1b7a5332844021e2511deaf99ac9e6c"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:682e958b476049ccc14c71bedf3f979bced01f6e0c04852efc5887841a32ad6b"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-win32.whl", hash = "sha256:0ccfb837152b7b858b9f26bb110b3ae8c46675d56130f6c2f03605c4f129be13"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-win_amd64.whl", hash = "sha256:ffff9bc5eacb34dd26b487194f7d44a3e64e752fc2cf049d798021bf25053b87"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:834959b6eceb85de5217a411aba1643b5f782798680c122202d6a06177226644"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-macosx_10_14_universal2.whl", hash = "sha256:e3ae9556e2a1cd70e7d7b0e0459c35af71d51a7dae4cf36075068011a69f13ec"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:77fe6db1334e0ce318b2cb4e70afa94e0c173ed1a533d37aea69ad9f61ae8ea9"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57e3e2544c306b60ef2d76570bac4e977be1ad548641c9eec130c3bc47e80141"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af39e245fa56f7f5c2fe86b7d6c1b78f395c07e54d5613cbdbb3c24769a92b6e"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8f987d0053351217954543b174b0bddbf51d45b3cfcf8d6de97b0a43d264d753"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:8e6cdbba4dae7b37b0d25d074614be9936fb720144420f03d9f142a80be69ba2"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d3adc8b229e60c77bab5a5d62b415667133bd5ced7d59b5f71d6317c9143631e"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-win32.whl", hash = "sha256:f68334d28a267fabec6e70cb5986e9999cfbfd14db654094ddf9aedd804a293a"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-win_amd64.whl", hash = "sha256:1291a6136c07a86c3bb09f6c33f5cf227cc14956edd1b85cb572327a36e0aef8"},
+ {file = "grpcio_tools-1.71.0.tar.gz", hash = "sha256:38dba8e0d5e0fb23a034e09644fdc6ed862be2371887eee54901999e8f6792a8"},
+]
+
+[package.dependencies]
+grpcio = ">=1.71.0"
+protobuf = ">=5.26.1,<6.0dev"
+setuptools = "*"
+
+[[package]]
+name = "hf-transfer"
+version = "0.1.9"
+description = "Speed up file transfers with the Hugging Face Hub."
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "hf_transfer-0.1.9-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:6e94e8822da79573c9b6ae4d6b2f847c59a7a06c5327d7db20751b68538dc4f6"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3ebc4ab9023414880c8b1d3c38174d1c9989eb5022d37e814fa91a3060123eb0"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8674026f21ed369aa2a0a4b46000aca850fc44cd2b54af33a172ce5325b4fc82"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3a736dfbb2c84f5a2c975478ad200c0c8bfcb58a25a35db402678fb87ce17fa4"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:504b8427fd785dd8546d53b9fafe6e436bd7a3adf76b9dce556507650a7b4567"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c7fc1b85f4d0f76e452765d7648c9f4bfd0aedb9ced2ae1ebfece2d8cfaf8e2"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d991376f0eac70a60f0cbc95602aa708a6f7c8617f28b4945c1431d67b8e3c8"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:e6ac4eddcd99575ed3735ed911ddf9d1697e2bd13aa3f0ad7e3904dd4863842e"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:57fd9880da1ee0f47250f735f791fab788f0aa1ee36afc49f761349869c8b4d9"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:5d561f0520f493c66b016d99ceabe69c23289aa90be38dd802d2aef279f15751"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:a5b366d34cd449fe9b20ef25941e6eef0460a2f74e7389f02e673e1f88ebd538"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:e66acf91df4a8b72f60223059df3003062a5ae111757187ed1a06750a30e911b"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:8669dbcc7a3e2e8d61d42cd24da9c50d57770bd74b445c65123291ca842a7e7a"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8fd0167c4407a3bc4cdd0307e65ada2294ec04f1813d8a69a5243e379b22e9d8"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ee8b10afedcb75f71091bcc197c526a6ebf5c58bbbadb34fdeee6160f55f619f"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5828057e313de59300dd1abb489444bc452efe3f479d3c55b31a8f680936ba42"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc6bd19e1cc177c66bdef15ef8636ad3bde79d5a4f608c158021153b4573509d"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdca9bfb89e6f8f281890cc61a8aff2d3cecaff7e1a4d275574d96ca70098557"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:89a23f58b7b7effbc047b8ca286f131b17728c99a9f972723323003ffd1bb916"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:dc7fff1345980d6c0ebb92c811d24afa4b98b3e07ed070c8e38cc91fd80478c5"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:1a6bd16c667ebe89a069ca163060127a794fa3a3525292c900b8c8cc47985b0d"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:d2fde99d502093ade3ab1b53f80da18480e9902aa960dab7f74fb1b9e5bc5746"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-win32.whl", hash = "sha256:435cc3cdc8524ce57b074032b8fd76eed70a4224d2091232fa6a8cef8fd6803e"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-win_amd64.whl", hash = "sha256:16f208fc678911c37e11aa7b586bc66a37d02e636208f18b6bc53d29b5df40ad"},
+ {file = "hf_transfer-0.1.9.tar.gz", hash = "sha256:035572865dab29d17e783fbf1e84cf1cb24f3fcf8f1b17db1cfc7fdf139f02bf"},
+]
+
+[[package]]
+name = "huggingface-hub"
+version = "0.30.2"
+description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
+optional = false
+python-versions = ">=3.8.0"
+groups = ["main"]
+files = [
+ {file = "huggingface_hub-0.30.2-py3-none-any.whl", hash = "sha256:68ff05969927058cfa41df4f2155d4bb48f5f54f719dd0390103eefa9b191e28"},
+ {file = "huggingface_hub-0.30.2.tar.gz", hash = "sha256:9a7897c5b6fd9dad3168a794a8998d6378210f5b9688d0dfc180b1a228dc2466"},
+]
+
+[package.dependencies]
+filelock = "*"
+fsspec = ">=2023.5.0"
+packaging = ">=20.9"
+pyyaml = ">=5.1"
+requests = "*"
+tqdm = ">=4.42.1"
+typing-extensions = ">=3.7.4.3"
+
+[package.extras]
+all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
+cli = ["InquirerPy (==0.3.4)"]
+dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
+fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
+hf-transfer = ["hf-transfer (>=0.1.4)"]
+hf-xet = ["hf-xet (>=0.1.4)"]
+inference = ["aiohttp"]
+quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.9.0)"]
+tensorflow = ["graphviz", "pydot", "tensorflow"]
+tensorflow-testing = ["keras (<3.0)", "tensorflow"]
+testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
+torch = ["safetensors[torch]", "torch"]
+typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
+
+[[package]]
+name = "idna"
+version = "3.10"
+description = "Internationalized Domain Names in Applications (IDNA)"
+optional = false
+python-versions = ">=3.6"
+groups = ["main"]
+files = [
+ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"},
+ {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"},
+]
+
+[package.extras]
+all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"]
+
+[[package]]
+name = "importlib-metadata"
+version = "8.6.1"
+description = "Read metadata from Python packages"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "importlib_metadata-8.6.1-py3-none-any.whl", hash = "sha256:02a89390c1e15fdfdc0d7c6b25cb3e62650d0494005c97d6f148bf5b9787525e"},
+ {file = "importlib_metadata-8.6.1.tar.gz", hash = "sha256:310b41d755445d74569f993ccfc22838295d9fe005425094fad953d7f15c8580"},
+]
+
+[package.dependencies]
+zipp = ">=3.20"
+
+[package.extras]
+check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"]
+cover = ["pytest-cov"]
+doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+enabler = ["pytest-enabler (>=2.2)"]
+perf = ["ipython"]
+test = ["flufl.flake8", "importlib_resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-perf (>=0.9.2)"]
+type = ["pytest-mypy"]
+
+[[package]]
+name = "iniconfig"
+version = "2.1.0"
+description = "brain-dead simple config-ini parsing"
+optional = false
+python-versions = ">=3.8"
+groups = ["dev"]
+files = [
+ {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"},
+ {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"},
+]
+
+[[package]]
+name = "interegular"
+version = "0.3.3"
+description = "a regex intersection checker"
+optional = true
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "interegular-0.3.3-py37-none-any.whl", hash = "sha256:b0c07007d48c89d6d19f7204972d369b2a77222722e126b6aa63aa721dc3b19c"},
+ {file = "interegular-0.3.3.tar.gz", hash = "sha256:d9b697b21b34884711399ba0f0376914b81899ce670032486d0d048344a76600"},
+]
+
+[[package]]
+name = "jinja2"
+version = "3.1.6"
+description = "A very fast and expressive template engine."
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67"},
+ {file = "jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d"},
+]
+
+[package.dependencies]
+MarkupSafe = ">=2.0"
+
+[package.extras]
+i18n = ["Babel (>=2.7)"]
+
+[[package]]
+name = "joblib"
+version = "1.4.2"
+description = "Lightweight pipelining with Python functions"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"},
+ {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"},
+]
+
+[[package]]
+name = "jsonschema"
+version = "4.23.0"
+description = "An implementation of JSON Schema validation for Python"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566"},
+ {file = "jsonschema-4.23.0.tar.gz", hash = "sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4"},
+]
+
+[package.dependencies]
+attrs = ">=22.2.0"
+jsonschema-specifications = ">=2023.03.6"
+referencing = ">=0.28.4"
+rpds-py = ">=0.7.1"
+
+[package.extras]
+format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"]
+format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=24.6.0)"]
+
+[[package]]
+name = "jsonschema-specifications"
+version = "2024.10.1"
+description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "jsonschema_specifications-2024.10.1-py3-none-any.whl", hash = "sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf"},
+ {file = "jsonschema_specifications-2024.10.1.tar.gz", hash = "sha256:0f38b83639958ce1152d02a7f062902c41c8fd20d558b0c34344292d417ae272"},
+]
+
+[package.dependencies]
+referencing = ">=0.31.0"
+
+[[package]]
+name = "lark"
+version = "1.2.2"
+description = "a modern parsing library"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "lark-1.2.2-py3-none-any.whl", hash = "sha256:c2276486b02f0f1b90be155f2c8ba4a8e194d42775786db622faccd652d8e80c"},
+ {file = "lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80"},
+]
+
+[package.extras]
+atomic-cache = ["atomicwrites"]
+interegular = ["interegular (>=0.3.1,<0.4.0)"]
+nearley = ["js2py"]
+regex = ["regex"]
+
+[[package]]
+name = "llvmlite"
+version = "0.43.0"
+description = "lightweight wrapper around basic LLVM functionality"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "llvmlite-0.43.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a289af9a1687c6cf463478f0fa8e8aa3b6fb813317b0d70bf1ed0759eab6f761"},
+ {file = "llvmlite-0.43.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d4fd101f571a31acb1559ae1af30f30b1dc4b3186669f92ad780e17c81e91bc"},
+ {file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d434ec7e2ce3cc8f452d1cd9a28591745de022f931d67be688a737320dfcead"},
+ {file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6912a87782acdff6eb8bf01675ed01d60ca1f2551f8176a300a886f09e836a6a"},
+ {file = "llvmlite-0.43.0-cp310-cp310-win_amd64.whl", hash = "sha256:14f0e4bf2fd2d9a75a3534111e8ebeb08eda2f33e9bdd6dfa13282afacdde0ed"},
+ {file = "llvmlite-0.43.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3e8d0618cb9bfe40ac38a9633f2493d4d4e9fcc2f438d39a4e854f39cc0f5f98"},
+ {file = "llvmlite-0.43.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0a9a1a39d4bf3517f2af9d23d479b4175ead205c592ceeb8b89af48a327ea57"},
+ {file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1da416ab53e4f7f3bc8d4eeba36d801cc1894b9fbfbf2022b29b6bad34a7df2"},
+ {file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:977525a1e5f4059316b183fb4fd34fa858c9eade31f165427a3977c95e3ee749"},
+ {file = "llvmlite-0.43.0-cp311-cp311-win_amd64.whl", hash = "sha256:d5bd550001d26450bd90777736c69d68c487d17bf371438f975229b2b8241a91"},
+ {file = "llvmlite-0.43.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f99b600aa7f65235a5a05d0b9a9f31150c390f31261f2a0ba678e26823ec38f7"},
+ {file = "llvmlite-0.43.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:35d80d61d0cda2d767f72de99450766250560399edc309da16937b93d3b676e7"},
+ {file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eccce86bba940bae0d8d48ed925f21dbb813519169246e2ab292b5092aba121f"},
+ {file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df6509e1507ca0760787a199d19439cc887bfd82226f5af746d6977bd9f66844"},
+ {file = "llvmlite-0.43.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a2872ee80dcf6b5dbdc838763d26554c2a18aa833d31a2635bff16aafefb9c9"},
+ {file = "llvmlite-0.43.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9cd2a7376f7b3367019b664c21f0c61766219faa3b03731113ead75107f3b66c"},
+ {file = "llvmlite-0.43.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:18e9953c748b105668487b7c81a3e97b046d8abf95c4ddc0cd3c94f4e4651ae8"},
+ {file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74937acd22dc11b33946b67dca7680e6d103d6e90eeaaaf932603bec6fe7b03a"},
+ {file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc9efc739cc6ed760f795806f67889923f7274276f0eb45092a1473e40d9b867"},
+ {file = "llvmlite-0.43.0-cp39-cp39-win_amd64.whl", hash = "sha256:47e147cdda9037f94b399bf03bfd8a6b6b1f2f90be94a454e3386f006455a9b4"},
+ {file = "llvmlite-0.43.0.tar.gz", hash = "sha256:ae2b5b5c3ef67354824fb75517c8db5fbe93bc02cd9671f3c62271626bc041d5"},
+]
+
+[[package]]
+name = "loguru"
+version = "0.7.3"
+description = "Python logging made (stupidly) simple"
+optional = false
+python-versions = "<4.0,>=3.5"
+groups = ["main"]
+files = [
+ {file = "loguru-0.7.3-py3-none-any.whl", hash = "sha256:31a33c10c8e1e10422bfd431aeb5d351c7cf7fa671e3c4df004162264b28220c"},
+ {file = "loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6"},
+]
+
+[package.dependencies]
+colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""}
+win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
+
+[package.extras]
+dev = ["Sphinx (==8.1.3)", "build (==1.2.2)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.5.0)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.13.0)", "mypy (==v1.4.1)", "myst-parser (==4.0.0)", "pre-commit (==4.0.1)", "pytest (==6.1.2)", "pytest (==8.3.2)", "pytest-cov (==2.12.1)", "pytest-cov (==5.0.0)", "pytest-cov (==6.0.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.1.0)", "sphinx-rtd-theme (==3.0.2)", "tox (==3.27.1)", "tox (==4.23.2)", "twine (==6.0.1)"]
+
+[[package]]
+name = "markdown-it-py"
+version = "3.0.0"
+description = "Python port of markdown-it. Markdown parsing, done right!"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
+ {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
+]
+
+[package.dependencies]
+mdurl = ">=0.1,<1.0"
+
+[package.extras]
+benchmarking = ["psutil", "pytest", "pytest-benchmark"]
+code-style = ["pre-commit (>=3.0,<4.0)"]
+compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
+linkify = ["linkify-it-py (>=1,<3)"]
+plugins = ["mdit-py-plugins"]
+profiling = ["gprof2dot"]
+rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
+testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
+
+[[package]]
+name = "markupsafe"
+version = "3.0.2"
+description = "Safely add untrusted strings to HTML/XML markup."
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38a9ef736c01fccdd6600705b09dc574584b89bea478200c5fbf112a6b0d5579"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbcb445fa71794da8f178f0f6d66789a28d7319071af7a496d4d507ed566270d"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57cb5a3cf367aeb1d316576250f65edec5bb3be939e9247ae594b4bcbc317dfb"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3809ede931876f5b2ec92eef964286840ed3540dadf803dd570c3b7e13141a3b"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e07c3764494e3776c602c1e78e298937c3315ccc9043ead7e685b7f2b8d47b3c"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b424c77b206d63d500bcb69fa55ed8d0e6a3774056bdc4839fc9298a7edca171"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-win32.whl", hash = "sha256:fcabf5ff6eea076f859677f5f0b6b5c1a51e70a376b0579e0eadef8db48c6b50"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:6af100e168aa82a50e186c82875a5893c5597a0c1ccdb0d8b40240b1f28b969a"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9025b4018f3a1314059769c7bf15441064b2207cb3f065e6ea1e7359cb46db9d"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:93335ca3812df2f366e80509ae119189886b0f3c2b81325d39efdb84a1e2ae93"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cb8438c3cbb25e220c2ab33bb226559e7afb3baec11c4f218ffa7308603c832"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a123e330ef0853c6e822384873bef7507557d8e4a082961e1defa947aa59ba84"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e084f686b92e5b83186b07e8a17fc09e38fff551f3602b249881fec658d3eca"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8213e09c917a951de9d09ecee036d5c7d36cb6cb7dbaece4c71a60d79fb9798"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:5b02fb34468b6aaa40dfc198d813a641e3a63b98c2b05a16b9f80b7ec314185e"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-win32.whl", hash = "sha256:6c89876f41da747c8d3677a2b540fb32ef5715f97b66eeb0c6b66f5e3ef6f59d"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:70a87b411535ccad5ef2f1df5136506a10775d267e197e4cf531ced10537bd6b"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9778bd8ab0a994ebf6f84c2b949e65736d5575320a17ae8984a77fab08db94cf"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:846ade7b71e3536c4e56b386c2a47adf5741d2d8b94ec9dc3e92e5e1ee1e2225"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c99d261bd2d5f6b59325c92c73df481e05e57f19837bdca8413b9eac4bd8028"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17c96c14e19278594aa4841ec148115f9c7615a47382ecb6b82bd8fea3ab0c8"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:88416bd1e65dcea10bc7569faacb2c20ce071dd1f87539ca2ab364bf6231393c"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2181e67807fc2fa785d0592dc2d6206c019b9502410671cc905d132a92866557"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:52305740fe773d09cffb16f8ed0427942901f00adedac82ec8b67752f58a1b22"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-win32.whl", hash = "sha256:0f4ca02bea9a23221c0182836703cbf8930c5e9454bacce27e767509fa286a30"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:8e06879fc22a25ca47312fbe7c8264eb0b662f6db27cb2d3bbbc74b1df4b9b87"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ba9527cdd4c926ed0760bc301f6728ef34d841f405abf9d4f959c478421e4efd"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f8b3d067f2e40fe93e1ccdd6b2e1d16c43140e76f02fb1319a05cf2b79d99430"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:569511d3b58c8791ab4c2e1285575265991e6d8f8700c7be0e88f86cb0672094"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15ab75ef81add55874e7ab7055e9c397312385bd9ced94920f2802310c930396"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f3818cb119498c0678015754eba762e0d61e5b52d34c8b13d770f0719f7b1d79"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:cdb82a876c47801bb54a690c5ae105a46b392ac6099881cdfb9f6e95e4014c6a"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:cabc348d87e913db6ab4aa100f01b08f481097838bdddf7c7a84b7575b7309ca"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:444dcda765c8a838eaae23112db52f1efaf750daddb2d9ca300bcae1039adc5c"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-win32.whl", hash = "sha256:bcf3e58998965654fdaff38e58584d8937aa3096ab5354d493c77d1fdd66d7a1"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:e6a2a455bd412959b57a172ce6328d2dd1f01cb2135efda2e4576e8a23fa3b0f"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:b5a6b3ada725cea8a5e634536b1b01c30bcdcd7f9c6fff4151548d5bf6b3a36c"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a904af0a6162c73e3edcb969eeeb53a63ceeb5d8cf642fade7d39e7963a22ddb"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa4e5faecf353ed117801a068ebab7b7e09ffb6e1d5e412dc852e0da018126c"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0ef13eaeee5b615fb07c9a7dadb38eac06a0608b41570d8ade51c56539e509d"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d16a81a06776313e817c951135cf7340a3e91e8c1ff2fac444cfd75fffa04afe"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6381026f158fdb7c72a168278597a5e3a5222e83ea18f543112b2662a9b699c5"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:3d79d162e7be8f996986c064d1c7c817f6df3a77fe3d6859f6f9e7be4b8c213a"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:131a3c7689c85f5ad20f9f6fb1b866f402c445b220c19fe4308c0b147ccd2ad9"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-win32.whl", hash = "sha256:ba8062ed2cf21c07a9e295d5b8a2a5ce678b913b45fdf68c32d95d6c1291e0b6"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:eaa0a10b7f72326f1372a713e73c3f739b524b3af41feb43e4921cb529f5929a"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:48032821bbdf20f5799ff537c7ac3d1fba0ba032cfc06194faffa8cda8b560ff"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a9d3f5f0901fdec14d8d2f66ef7d035f2157240a433441719ac9a3fba440b13"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88b49a3b9ff31e19998750c38e030fc7bb937398b1f78cfa599aaef92d693144"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cfad01eed2c2e0c01fd0ecd2ef42c492f7f93902e39a42fc9ee1692961443a29"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1225beacc926f536dc82e45f8a4d68502949dc67eea90eab715dea3a21c1b5f0"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:3169b1eefae027567d1ce6ee7cae382c57fe26e82775f460f0b2778beaad66c0"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:eb7972a85c54febfb25b5c4b4f3af4dcc731994c7da0d8a0b4a6eb0640e1d178"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-win32.whl", hash = "sha256:8c4e8c3ce11e1f92f6536ff07154f9d49677ebaaafc32db9db4620bc11ed480f"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a"},
+ {file = "markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0"},
+]
+
+[[package]]
+name = "mdurl"
+version = "0.1.2"
+description = "Markdown URL utilities"
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
+ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
+]
+
+[[package]]
+name = "mpmath"
+version = "1.3.0"
+description = "Python library for arbitrary-precision floating-point arithmetic"
+optional = false
+python-versions = "*"
+groups = ["main"]
+files = [
+ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"},
+ {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"},
+]
+
+[package.extras]
+develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"]
+docs = ["sphinx"]
+gmpy = ["gmpy2 (>=2.1.0a4)"]
+tests = ["pytest (>=4.6)"]
+
+[[package]]
+name = "nest-asyncio"
+version = "1.6.0"
+description = "Patch asyncio to allow nested event loops"
+optional = true
+python-versions = ">=3.5"
+groups = ["main"]
+files = [
+ {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"},
+ {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"},
+]
+
+[[package]]
+name = "networkx"
+version = "3.2.1"
+description = "Python package for creating and manipulating graphs and networks"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"},
+ {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"},
+]
+
+[package.extras]
+default = ["matplotlib (>=3.5)", "numpy (>=1.22)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"]
+developer = ["changelist (==0.4)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"]
+doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"]
+extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"]
+test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"]
+
+[[package]]
+name = "numba"
+version = "0.60.0"
+description = "compiling Python code using LLVM"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "numba-0.60.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d761de835cd38fb400d2c26bb103a2726f548dc30368853121d66201672e651"},
+ {file = "numba-0.60.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:159e618ef213fba758837f9837fb402bbe65326e60ba0633dbe6c7f274d42c1b"},
+ {file = "numba-0.60.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1527dc578b95c7c4ff248792ec33d097ba6bef9eda466c948b68dfc995c25781"},
+ {file = "numba-0.60.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe0b28abb8d70f8160798f4de9d486143200f34458d34c4a214114e445d7124e"},
+ {file = "numba-0.60.0-cp310-cp310-win_amd64.whl", hash = "sha256:19407ced081d7e2e4b8d8c36aa57b7452e0283871c296e12d798852bc7d7f198"},
+ {file = "numba-0.60.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a17b70fc9e380ee29c42717e8cc0bfaa5556c416d94f9aa96ba13acb41bdece8"},
+ {file = "numba-0.60.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb02b344a2a80efa6f677aa5c40cd5dd452e1b35f8d1c2af0dfd9ada9978e4b"},
+ {file = "numba-0.60.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f4fde652ea604ea3c86508a3fb31556a6157b2c76c8b51b1d45eb40c8598703"},
+ {file = "numba-0.60.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4142d7ac0210cc86432b818338a2bc368dc773a2f5cf1e32ff7c5b378bd63ee8"},
+ {file = "numba-0.60.0-cp311-cp311-win_amd64.whl", hash = "sha256:cac02c041e9b5bc8cf8f2034ff6f0dbafccd1ae9590dc146b3a02a45e53af4e2"},
+ {file = "numba-0.60.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d7da4098db31182fc5ffe4bc42c6f24cd7d1cb8a14b59fd755bfee32e34b8404"},
+ {file = "numba-0.60.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:38d6ea4c1f56417076ecf8fc327c831ae793282e0ff51080c5094cb726507b1c"},
+ {file = "numba-0.60.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:62908d29fb6a3229c242e981ca27e32a6e606cc253fc9e8faeb0e48760de241e"},
+ {file = "numba-0.60.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0ebaa91538e996f708f1ab30ef4d3ddc344b64b5227b67a57aa74f401bb68b9d"},
+ {file = "numba-0.60.0-cp312-cp312-win_amd64.whl", hash = "sha256:f75262e8fe7fa96db1dca93d53a194a38c46da28b112b8a4aca168f0df860347"},
+ {file = "numba-0.60.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:01ef4cd7d83abe087d644eaa3d95831b777aa21d441a23703d649e06b8e06b74"},
+ {file = "numba-0.60.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:819a3dfd4630d95fd574036f99e47212a1af41cbcb019bf8afac63ff56834449"},
+ {file = "numba-0.60.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b983bd6ad82fe868493012487f34eae8bf7dd94654951404114f23c3466d34b"},
+ {file = "numba-0.60.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c151748cd269ddeab66334bd754817ffc0cabd9433acb0f551697e5151917d25"},
+ {file = "numba-0.60.0-cp39-cp39-win_amd64.whl", hash = "sha256:3031547a015710140e8c87226b4cfe927cac199835e5bf7d4fe5cb64e814e3ab"},
+ {file = "numba-0.60.0.tar.gz", hash = "sha256:5df6158e5584eece5fc83294b949fd30b9f1125df7708862205217e068aabf16"},
+]
+
+[package.dependencies]
+llvmlite = "==0.43.*"
+numpy = ">=1.22,<2.1"
+
+[[package]]
+name = "numpy"
+version = "1.26.4"
+description = "Fundamental package for array computing in Python"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
+ {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
+ {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
+ {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
+ {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
+ {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
+ {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
+ {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
+ {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
+ {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
+ {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
+ {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
+ {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
+ {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
+ {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
+ {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
+ {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
+ {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
+ {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
+ {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
+ {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
+ {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
+ {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
+ {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
+ {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"},
+ {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"},
+ {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"},
+ {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"},
+ {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"},
+ {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"},
+ {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"},
+ {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"},
+ {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
+ {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
+ {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
+ {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
+]
+
+[[package]]
+name = "opentelemetry-api"
+version = "1.32.0"
+description = "OpenTelemetry Python API"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_api-1.32.0-py3-none-any.whl", hash = "sha256:15df743c765078611f376037b0d9111ec5c1febf2ec9440cdd919370faa1ce55"},
+ {file = "opentelemetry_api-1.32.0.tar.gz", hash = "sha256:2623280c916f9b19cad0aa4280cb171265f19fd2909b0d47e4f06f7c83b02cb5"},
+]
+
+[package.dependencies]
+deprecated = ">=1.2.6"
+importlib-metadata = ">=6.0,<8.7.0"
+
+[[package]]
+name = "opentelemetry-exporter-otlp"
+version = "1.32.0"
+description = "OpenTelemetry Collector Exporters"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_exporter_otlp-1.32.0-py3-none-any.whl", hash = "sha256:8b563bee30f05415fb51e075eb6461cdaa7bcef1cc79917cfd79caf12e5bb548"},
+ {file = "opentelemetry_exporter_otlp-1.32.0.tar.gz", hash = "sha256:4c66681f8acd95dce44966842182e3690e77256e5791ceb34b76ea1c34b20463"},
+]
+
+[package.dependencies]
+opentelemetry-exporter-otlp-proto-grpc = "1.32.0"
+opentelemetry-exporter-otlp-proto-http = "1.32.0"
+
+[[package]]
+name = "opentelemetry-exporter-otlp-proto-common"
+version = "1.32.0"
+description = "OpenTelemetry Protobuf encoding"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_exporter_otlp_proto_common-1.32.0-py3-none-any.whl", hash = "sha256:277a63a18768b3b460d082a489f6f80d4ae2c1e6b185bb701c6bd4e91405e4bd"},
+ {file = "opentelemetry_exporter_otlp_proto_common-1.32.0.tar.gz", hash = "sha256:2bca672f2a279c4f517115e635c0cc1269d07b2982a36681c521f7e56179a222"},
+]
+
+[package.dependencies]
+opentelemetry-proto = "1.32.0"
+
+[[package]]
+name = "opentelemetry-exporter-otlp-proto-grpc"
+version = "1.32.0"
+description = "OpenTelemetry Collector Protobuf over gRPC Exporter"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_exporter_otlp_proto_grpc-1.32.0-py3-none-any.whl", hash = "sha256:85b7c42bebe48ef55866793a3123ebf357dcaf629d961b27067025fd60104dbe"},
+ {file = "opentelemetry_exporter_otlp_proto_grpc-1.32.0.tar.gz", hash = "sha256:c069c5d5f429a46fb1001f38191730939f593789c847648e4cea26dc8b6018a8"},
+]
+
+[package.dependencies]
+deprecated = ">=1.2.6"
+googleapis-common-protos = ">=1.52,<2.0"
+grpcio = {version = ">=1.63.2,<2.0.0", markers = "python_version < \"3.13\""}
+opentelemetry-api = ">=1.15,<2.0"
+opentelemetry-exporter-otlp-proto-common = "1.32.0"
+opentelemetry-proto = "1.32.0"
+opentelemetry-sdk = ">=1.32.0,<1.33.0"
+
+[[package]]
+name = "opentelemetry-exporter-otlp-proto-http"
+version = "1.32.0"
+description = "OpenTelemetry Collector Protobuf over HTTP Exporter"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_exporter_otlp_proto_http-1.32.0-py3-none-any.whl", hash = "sha256:e2ffecd6d2220eaf1291a46339f109bc0a57ee7c4c6abb8174df418bf00ce01f"},
+ {file = "opentelemetry_exporter_otlp_proto_http-1.32.0.tar.gz", hash = "sha256:a5dfd94603da86e313e4f4fb8d181fd3b64a7c2a9c7b408c3653d2b1bc68d14f"},
+]
+
+[package.dependencies]
+deprecated = ">=1.2.6"
+googleapis-common-protos = ">=1.52,<2.0"
+opentelemetry-api = ">=1.15,<2.0"
+opentelemetry-exporter-otlp-proto-common = "1.32.0"
+opentelemetry-proto = "1.32.0"
+opentelemetry-sdk = ">=1.32.0,<1.33.0"
+requests = ">=2.7,<3.0"
+
+[[package]]
+name = "opentelemetry-instrumentation"
+version = "0.53b0"
+description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_instrumentation-0.53b0-py3-none-any.whl", hash = "sha256:70600778fd567c9c5fbfca181378ae179c0dec3ff613171707d3d77c360ff105"},
+ {file = "opentelemetry_instrumentation-0.53b0.tar.gz", hash = "sha256:f2c21d71a3cdf28c656e3d90d247ee7558fb9b0239b3d9e9190266499dbed9d2"},
+]
+
+[package.dependencies]
+opentelemetry-api = ">=1.4,<2.0"
+opentelemetry-semantic-conventions = "0.53b0"
+packaging = ">=18.0"
+wrapt = ">=1.0.0,<2.0.0"
+
+[[package]]
+name = "opentelemetry-instrumentation-grpc"
+version = "0.53b0"
+description = "OpenTelemetry gRPC instrumentation"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_instrumentation_grpc-0.53b0-py3-none-any.whl", hash = "sha256:bd44f113c58fd66614b07bd9b8115ec311389ec58ef7e48a06581e302971c3f4"},
+ {file = "opentelemetry_instrumentation_grpc-0.53b0.tar.gz", hash = "sha256:a95b752e0782e7b503379de1c64a5afa2c7c1cd8196fa5f2b5c090d01c15e517"},
+]
+
+[package.dependencies]
+opentelemetry-api = ">=1.12,<2.0"
+opentelemetry-instrumentation = "0.53b0"
+opentelemetry-semantic-conventions = "0.53b0"
+wrapt = ">=1.0.0,<2.0.0"
+
+[package.extras]
+instruments = ["grpcio (>=1.42.0)"]
+
+[[package]]
+name = "opentelemetry-proto"
+version = "1.32.0"
+description = "OpenTelemetry Python Proto"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_proto-1.32.0-py3-none-any.whl", hash = "sha256:f699269dc037e18fba05442580a8682c9fbd0f4c7f5addfed82c44be0c53c5ff"},
+ {file = "opentelemetry_proto-1.32.0.tar.gz", hash = "sha256:f8b70ae52f4ef8a4e4c0760e87c9071e07ece2618c080d4839bef44c0156cd44"},
+]
+
+[package.dependencies]
+protobuf = ">=5.0,<6.0"
+
+[[package]]
+name = "opentelemetry-sdk"
+version = "1.32.0"
+description = "OpenTelemetry Python SDK"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_sdk-1.32.0-py3-none-any.whl", hash = "sha256:ed252d035c22a15536c1f603ca089298daab60850fc2f5ddfa95d95cc1c043ea"},
+ {file = "opentelemetry_sdk-1.32.0.tar.gz", hash = "sha256:5ff07fb371d1ab1189fa7047702e2e888b5403c5efcbb18083cae0d5aa5f58d2"},
+]
+
+[package.dependencies]
+opentelemetry-api = "1.32.0"
+opentelemetry-semantic-conventions = "0.53b0"
+typing-extensions = ">=3.7.4"
+
+[[package]]
+name = "opentelemetry-semantic-conventions"
+version = "0.53b0"
+description = "OpenTelemetry Semantic Conventions"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_semantic_conventions-0.53b0-py3-none-any.whl", hash = "sha256:561da89f766ab51615c0e72b12329e0a1bc16945dbd62c8646ffc74e36a1edff"},
+ {file = "opentelemetry_semantic_conventions-0.53b0.tar.gz", hash = "sha256:05b7908e1da62d72f9bf717ed25c72f566fe005a2dd260c61b11e025f2552cf6"},
+]
+
+[package.dependencies]
+deprecated = ">=1.2.6"
+opentelemetry-api = "1.32.0"
+
+[[package]]
+name = "optimum"
+version = "1.24.0"
+description = "Optimum Library is an extension of the Hugging Face Transformers library, providing a framework to integrate third-party libraries from Hardware Partners and interface with their specific functionality."
+optional = false
+python-versions = ">=3.9.0"
+groups = ["main"]
+files = [
+ {file = "optimum-1.24.0-py3-none-any.whl", hash = "sha256:196776949183cd3a56a15097a02be41e6f37aa92d824bd053de89c39ee6b0087"},
+ {file = "optimum-1.24.0.tar.gz", hash = "sha256:b502a2afbf78bb73370ebb1eff07b93108a1b386116e87eb17e882d210150551"},
+]
+
+[package.dependencies]
+huggingface-hub = ">=0.8.0"
+numpy = "*"
+packaging = "*"
+torch = ">=1.11"
+transformers = ">=4.29"
+
+[package.extras]
+amd = ["optimum-amd"]
+benchmark = ["evaluate (>=0.2.0)", "optuna", "scikit-learn", "seqeval", "torchvision", "tqdm"]
+dev = ["Pillow", "accelerate", "black (>=23.1,<24.0)", "einops", "parameterized", "pytest (<=8.0.0)", "pytest-xdist", "requests", "rjieba", "ruff (==0.1.5)", "sacremoses", "scikit-learn", "sentencepiece", "timm", "torchaudio", "torchvision"]
+diffusers = ["diffusers"]
+doc-build = ["accelerate"]
+exporters = ["onnx", "onnxruntime", "timm", "transformers (>=4.36,<4.49.0)"]
+exporters-gpu = ["onnx", "onnxruntime-gpu", "timm", "transformers (>=4.36,<4.49.0)"]
+exporters-tf = ["datasets (<=2.16)", "h5py", "numpy (<1.24.0)", "onnx", "onnxruntime", "tensorflow (>=2.4,<=2.12.1)", "tf2onnx", "timm", "transformers (>=4.36,<4.38)"]
+furiosa = ["optimum-furiosa"]
+graphcore = ["optimum-graphcore"]
+habana = ["optimum-habana", "transformers (>=4.45.0,<4.46.0)"]
+intel = ["optimum-intel (>=1.18.0)"]
+ipex = ["optimum-intel[ipex] (>=1.18.0)"]
+neural-compressor = ["optimum-intel[neural-compressor] (>=1.18.0)"]
+neuron = ["optimum-neuron[neuron] (>=0.0.20)", "transformers (>=4.36.2,<4.42.0)"]
+neuronx = ["optimum-neuron[neuronx] (>=0.0.20)", "transformers (>=4.36.2,<4.42.0)"]
+nncf = ["optimum-intel[nncf] (>=1.18.0)"]
+onnxruntime = ["datasets (>=1.2.1)", "evaluate", "onnx", "onnxruntime (>=1.11.0)", "protobuf (>=3.20.1)", "transformers (>=4.36,<4.49.0)"]
+onnxruntime-gpu = ["datasets (>=1.2.1)", "evaluate", "onnx", "onnxruntime-gpu (>=1.11.0)", "protobuf (>=3.20.1)", "transformers (>=4.36,<4.49.0)"]
+onnxruntime-training = ["accelerate", "datasets (>=1.2.1)", "evaluate", "onnxruntime-training (>=1.11.0)", "protobuf (>=3.20.1)", "torch-ort", "transformers (>=4.36,<4.49.0)"]
+openvino = ["optimum-intel[openvino] (>=1.18.0)"]
+quality = ["black (>=23.1,<24.0)", "ruff (==0.1.5)"]
+quanto = ["optimum-quanto (>=0.2.4)"]
+tests = ["Pillow", "accelerate", "einops", "parameterized", "pytest (<=8.0.0)", "pytest-xdist", "requests", "rjieba", "sacremoses", "scikit-learn", "sentencepiece", "timm", "torchaudio", "torchvision"]
+
+[[package]]
+name = "optimum-habana"
+version = "1.17.0"
+description = "Optimum Habana is the interface between the Hugging Face Transformers and Diffusers libraries and Habana's Gaudi processor (HPU). It provides a set of tools enabling easy model loading, training and inference on single- and multi-HPU settings for different downstream tasks."
+optional = false
+python-versions = "*"
+groups = ["main"]
+files = [
+ {file = "optimum_habana-1.17.0-py3-none-any.whl", hash = "sha256:4f1008c7e84248b62778c8d5f79443237026a5a281b50ee67c7db211c5ca7d2a"},
+ {file = "optimum_habana-1.17.0.tar.gz", hash = "sha256:634adaa775c5c1694a164bdec46133e9712676237e996aadf3392c820573ce92"},
+]
+
+[package.dependencies]
+accelerate = ">=0.33.0,<0.34.0"
+diffusers = ">=0.31.0,<0.32.0"
+huggingface_hub = ">=0.24.7"
+optimum = "*"
+sentence-transformers = "3.3.1"
+torch = "*"
+transformers = ">=4.49.0,<4.50.0"
+
+[package.extras]
+quality = ["hf_doc_builder", "ruff"]
+tests = ["GitPython", "datasets", "optuna", "parameterized", "peft", "psutil", "pytest (<8.0.0)", "safetensors", "scipy", "sentencepiece", "timm", "timm", "torchsde"]
+
+[[package]]
+name = "outlines"
+version = "0.0.36"
+description = "Probabilistic Generative Model Programming"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "outlines-0.0.36-py3-none-any.whl", hash = "sha256:afa02ca5c449c47731fa06af66d13c2f5ee8b30f8b82b4db90e08215d6f111d1"},
+ {file = "outlines-0.0.36.tar.gz", hash = "sha256:3cffb43143548cd78c6061990feb461cffd5479999391b8390471ea839c2d46e"},
+]
+
+[package.dependencies]
+cloudpickle = "*"
+diskcache = "*"
+interegular = "*"
+jinja2 = "*"
+joblib = "*"
+jsonschema = "*"
+lark = "*"
+nest-asyncio = "*"
+numba = "*"
+numpy = "*"
+pydantic = ">=2.0"
+referencing = "*"
+requests = "*"
+scipy = "*"
+torch = ">=2.1.0"
+transformers = "*"
+
+[package.extras]
+serve = ["fastapi", "pydantic (>=2.0)", "ray (==2.9.0)", "uvicorn", "vllm (>=0.3.0)"]
+test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "huggingface-hub", "llama-cpp-python", "openai (>=1.0.0)", "pre-commit", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "responses", "transformers"]
+
+[[package]]
+name = "packaging"
+version = "24.2"
+description = "Core utilities for Python packages"
+optional = false
+python-versions = ">=3.8"
+groups = ["main", "dev"]
+files = [
+ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"},
+ {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"},
+]
+
+[[package]]
+name = "peft"
+version = "0.15.1"
+description = "Parameter-Efficient Fine-Tuning (PEFT)"
+optional = false
+python-versions = ">=3.9.0"
+groups = ["main"]
+files = [
+ {file = "peft-0.15.1-py3-none-any.whl", hash = "sha256:5fb3960beb518f00668f2cdc53424a5cc495c78281697821ce24609c90ca0a10"},
+ {file = "peft-0.15.1.tar.gz", hash = "sha256:e4c65af70683a9ef3baf1ab450710f1eb7181f369ef6172ca8bf15bf4ae6ff71"},
+]
+
+[package.dependencies]
+accelerate = ">=0.21.0"
+huggingface_hub = ">=0.25.0"
+numpy = ">=1.17"
+packaging = ">=20.0"
+psutil = "*"
+pyyaml = "*"
+safetensors = "*"
+torch = ">=1.13.0"
+tqdm = "*"
+transformers = "*"
+
+[package.extras]
+dev = ["black", "black", "hf-doc-builder", "hf-doc-builder", "ruff (>=0.9.2,<0.10.0)"]
+docs-specific = ["black", "hf-doc-builder"]
+quality = ["black", "hf-doc-builder", "ruff (>=0.9.2,<0.10.0)"]
+test = ["black", "black", "datasets", "diffusers", "hf-doc-builder", "hf-doc-builder", "parameterized", "protobuf", "pytest", "pytest-cov", "pytest-xdist", "ruff (>=0.9.2,<0.10.0)", "scipy", "sentencepiece"]
+
+[[package]]
+name = "pillow"
+version = "11.2.1"
+description = "Python Imaging Library (Fork)"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "pillow-11.2.1-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:d57a75d53922fc20c165016a20d9c44f73305e67c351bbc60d1adaf662e74047"},
+ {file = "pillow-11.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:127bf6ac4a5b58b3d32fc8289656f77f80567d65660bc46f72c0d77e6600cc95"},
+ {file = "pillow-11.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b4ba4be812c7a40280629e55ae0b14a0aafa150dd6451297562e1764808bbe61"},
+ {file = "pillow-11.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8bd62331e5032bc396a93609982a9ab6b411c05078a52f5fe3cc59234a3abd1"},
+ {file = "pillow-11.2.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:562d11134c97a62fe3af29581f083033179f7ff435f78392565a1ad2d1c2c45c"},
+ {file = "pillow-11.2.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:c97209e85b5be259994eb5b69ff50c5d20cca0f458ef9abd835e262d9d88b39d"},
+ {file = "pillow-11.2.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:0c3e6d0f59171dfa2e25d7116217543310908dfa2770aa64b8f87605f8cacc97"},
+ {file = "pillow-11.2.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cc1c3bc53befb6096b84165956e886b1729634a799e9d6329a0c512ab651e579"},
+ {file = "pillow-11.2.1-cp310-cp310-win32.whl", hash = "sha256:312c77b7f07ab2139924d2639860e084ec2a13e72af54d4f08ac843a5fc9c79d"},
+ {file = "pillow-11.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:9bc7ae48b8057a611e5fe9f853baa88093b9a76303937449397899385da06fad"},
+ {file = "pillow-11.2.1-cp310-cp310-win_arm64.whl", hash = "sha256:2728567e249cdd939f6cc3d1f049595c66e4187f3c34078cbc0a7d21c47482d2"},
+ {file = "pillow-11.2.1-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:35ca289f712ccfc699508c4658a1d14652e8033e9b69839edf83cbdd0ba39e70"},
+ {file = "pillow-11.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0409af9f829f87a2dfb7e259f78f317a5351f2045158be321fd135973fff7bf"},
+ {file = "pillow-11.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4e5c5edee874dce4f653dbe59db7c73a600119fbea8d31f53423586ee2aafd7"},
+ {file = "pillow-11.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b93a07e76d13bff9444f1a029e0af2964e654bfc2e2c2d46bfd080df5ad5f3d8"},
+ {file = "pillow-11.2.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:e6def7eed9e7fa90fde255afaf08060dc4b343bbe524a8f69bdd2a2f0018f600"},
+ {file = "pillow-11.2.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:8f4f3724c068be008c08257207210c138d5f3731af6c155a81c2b09a9eb3a788"},
+ {file = "pillow-11.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:a0a6709b47019dff32e678bc12c63008311b82b9327613f534e496dacaefb71e"},
+ {file = "pillow-11.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f6b0c664ccb879109ee3ca702a9272d877f4fcd21e5eb63c26422fd6e415365e"},
+ {file = "pillow-11.2.1-cp311-cp311-win32.whl", hash = "sha256:cc5d875d56e49f112b6def6813c4e3d3036d269c008bf8aef72cd08d20ca6df6"},
+ {file = "pillow-11.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:0f5c7eda47bf8e3c8a283762cab94e496ba977a420868cb819159980b6709193"},
+ {file = "pillow-11.2.1-cp311-cp311-win_arm64.whl", hash = "sha256:4d375eb838755f2528ac8cbc926c3e31cc49ca4ad0cf79cff48b20e30634a4a7"},
+ {file = "pillow-11.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:78afba22027b4accef10dbd5eed84425930ba41b3ea0a86fa8d20baaf19d807f"},
+ {file = "pillow-11.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:78092232a4ab376a35d68c4e6d5e00dfd73454bd12b230420025fbe178ee3b0b"},
+ {file = "pillow-11.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25a5f306095c6780c52e6bbb6109624b95c5b18e40aab1c3041da3e9e0cd3e2d"},
+ {file = "pillow-11.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c7b29dbd4281923a2bfe562acb734cee96bbb129e96e6972d315ed9f232bef4"},
+ {file = "pillow-11.2.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3e645b020f3209a0181a418bffe7b4a93171eef6c4ef6cc20980b30bebf17b7d"},
+ {file = "pillow-11.2.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b2dbea1012ccb784a65349f57bbc93730b96e85b42e9bf7b01ef40443db720b4"},
+ {file = "pillow-11.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:da3104c57bbd72948d75f6a9389e6727d2ab6333c3617f0a89d72d4940aa0443"},
+ {file = "pillow-11.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:598174aef4589af795f66f9caab87ba4ff860ce08cd5bb447c6fc553ffee603c"},
+ {file = "pillow-11.2.1-cp312-cp312-win32.whl", hash = "sha256:1d535df14716e7f8776b9e7fee118576d65572b4aad3ed639be9e4fa88a1cad3"},
+ {file = "pillow-11.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:14e33b28bf17c7a38eede290f77db7c664e4eb01f7869e37fa98a5aa95978941"},
+ {file = "pillow-11.2.1-cp312-cp312-win_arm64.whl", hash = "sha256:21e1470ac9e5739ff880c211fc3af01e3ae505859392bf65458c224d0bf283eb"},
+ {file = "pillow-11.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:fdec757fea0b793056419bca3e9932eb2b0ceec90ef4813ea4c1e072c389eb28"},
+ {file = "pillow-11.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b0e130705d568e2f43a17bcbe74d90958e8a16263868a12c3e0d9c8162690830"},
+ {file = "pillow-11.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bdb5e09068332578214cadd9c05e3d64d99e0e87591be22a324bdbc18925be0"},
+ {file = "pillow-11.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d189ba1bebfbc0c0e529159631ec72bb9e9bc041f01ec6d3233d6d82eb823bc1"},
+ {file = "pillow-11.2.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:191955c55d8a712fab8934a42bfefbf99dd0b5875078240943f913bb66d46d9f"},
+ {file = "pillow-11.2.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:ad275964d52e2243430472fc5d2c2334b4fc3ff9c16cb0a19254e25efa03a155"},
+ {file = "pillow-11.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:750f96efe0597382660d8b53e90dd1dd44568a8edb51cb7f9d5d918b80d4de14"},
+ {file = "pillow-11.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fe15238d3798788d00716637b3d4e7bb6bde18b26e5d08335a96e88564a36b6b"},
+ {file = "pillow-11.2.1-cp313-cp313-win32.whl", hash = "sha256:3fe735ced9a607fee4f481423a9c36701a39719252a9bb251679635f99d0f7d2"},
+ {file = "pillow-11.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:74ee3d7ecb3f3c05459ba95eed5efa28d6092d751ce9bf20e3e253a4e497e691"},
+ {file = "pillow-11.2.1-cp313-cp313-win_arm64.whl", hash = "sha256:5119225c622403afb4b44bad4c1ca6c1f98eed79db8d3bc6e4e160fc6339d66c"},
+ {file = "pillow-11.2.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:8ce2e8411c7aaef53e6bb29fe98f28cd4fbd9a1d9be2eeea434331aac0536b22"},
+ {file = "pillow-11.2.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:9ee66787e095127116d91dea2143db65c7bb1e232f617aa5957c0d9d2a3f23a7"},
+ {file = "pillow-11.2.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9622e3b6c1d8b551b6e6f21873bdcc55762b4b2126633014cea1803368a9aa16"},
+ {file = "pillow-11.2.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63b5dff3a68f371ea06025a1a6966c9a1e1ee452fc8020c2cd0ea41b83e9037b"},
+ {file = "pillow-11.2.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:31df6e2d3d8fc99f993fd253e97fae451a8db2e7207acf97859732273e108406"},
+ {file = "pillow-11.2.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:062b7a42d672c45a70fa1f8b43d1d38ff76b63421cbbe7f88146b39e8a558d91"},
+ {file = "pillow-11.2.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:4eb92eca2711ef8be42fd3f67533765d9fd043b8c80db204f16c8ea62ee1a751"},
+ {file = "pillow-11.2.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f91ebf30830a48c825590aede79376cb40f110b387c17ee9bd59932c961044f9"},
+ {file = "pillow-11.2.1-cp313-cp313t-win32.whl", hash = "sha256:e0b55f27f584ed623221cfe995c912c61606be8513bfa0e07d2c674b4516d9dd"},
+ {file = "pillow-11.2.1-cp313-cp313t-win_amd64.whl", hash = "sha256:36d6b82164c39ce5482f649b437382c0fb2395eabc1e2b1702a6deb8ad647d6e"},
+ {file = "pillow-11.2.1-cp313-cp313t-win_arm64.whl", hash = "sha256:225c832a13326e34f212d2072982bb1adb210e0cc0b153e688743018c94a2681"},
+ {file = "pillow-11.2.1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:7491cf8a79b8eb867d419648fff2f83cb0b3891c8b36da92cc7f1931d46108c8"},
+ {file = "pillow-11.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8b02d8f9cb83c52578a0b4beadba92e37d83a4ef11570a8688bbf43f4ca50909"},
+ {file = "pillow-11.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:014ca0050c85003620526b0ac1ac53f56fc93af128f7546623cc8e31875ab928"},
+ {file = "pillow-11.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3692b68c87096ac6308296d96354eddd25f98740c9d2ab54e1549d6c8aea9d79"},
+ {file = "pillow-11.2.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:f781dcb0bc9929adc77bad571b8621ecb1e4cdef86e940fe2e5b5ee24fd33b35"},
+ {file = "pillow-11.2.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:2b490402c96f907a166615e9a5afacf2519e28295f157ec3a2bb9bd57de638cb"},
+ {file = "pillow-11.2.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dd6b20b93b3ccc9c1b597999209e4bc5cf2853f9ee66e3fc9a400a78733ffc9a"},
+ {file = "pillow-11.2.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:4b835d89c08a6c2ee7781b8dd0a30209a8012b5f09c0a665b65b0eb3560b6f36"},
+ {file = "pillow-11.2.1-cp39-cp39-win32.whl", hash = "sha256:b10428b3416d4f9c61f94b494681280be7686bda15898a3a9e08eb66a6d92d67"},
+ {file = "pillow-11.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:6ebce70c3f486acf7591a3d73431fa504a4e18a9b97ff27f5f47b7368e4b9dd1"},
+ {file = "pillow-11.2.1-cp39-cp39-win_arm64.whl", hash = "sha256:c27476257b2fdcd7872d54cfd119b3a9ce4610fb85c8e32b70b42e3680a29a1e"},
+ {file = "pillow-11.2.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:9b7b0d4fd2635f54ad82785d56bc0d94f147096493a79985d0ab57aedd563156"},
+ {file = "pillow-11.2.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:aa442755e31c64037aa7c1cb186e0b369f8416c567381852c63444dd666fb772"},
+ {file = "pillow-11.2.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0d3348c95b766f54b76116d53d4cb171b52992a1027e7ca50c81b43b9d9e363"},
+ {file = "pillow-11.2.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85d27ea4c889342f7e35f6d56e7e1cb345632ad592e8c51b693d7b7556043ce0"},
+ {file = "pillow-11.2.1-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bf2c33d6791c598142f00c9c4c7d47f6476731c31081331664eb26d6ab583e01"},
+ {file = "pillow-11.2.1-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e616e7154c37669fc1dfc14584f11e284e05d1c650e1c0f972f281c4ccc53193"},
+ {file = "pillow-11.2.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:39ad2e0f424394e3aebc40168845fee52df1394a4673a6ee512d840d14ab3013"},
+ {file = "pillow-11.2.1-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:80f1df8dbe9572b4b7abdfa17eb5d78dd620b1d55d9e25f834efdbee872d3aed"},
+ {file = "pillow-11.2.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:ea926cfbc3957090becbcbbb65ad177161a2ff2ad578b5a6ec9bb1e1cd78753c"},
+ {file = "pillow-11.2.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:738db0e0941ca0376804d4de6a782c005245264edaa253ffce24e5a15cbdc7bd"},
+ {file = "pillow-11.2.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9db98ab6565c69082ec9b0d4e40dd9f6181dab0dd236d26f7a50b8b9bfbd5076"},
+ {file = "pillow-11.2.1-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:036e53f4170e270ddb8797d4c590e6dd14d28e15c7da375c18978045f7e6c37b"},
+ {file = "pillow-11.2.1-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:14f73f7c291279bd65fda51ee87affd7c1e097709f7fdd0188957a16c264601f"},
+ {file = "pillow-11.2.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:208653868d5c9ecc2b327f9b9ef34e0e42a4cdd172c2988fd81d62d2bc9bc044"},
+ {file = "pillow-11.2.1.tar.gz", hash = "sha256:a64dd61998416367b7ef979b73d3a85853ba9bec4c2925f74e588879a58716b6"},
+]
+
+[package.extras]
+docs = ["furo", "olefile", "sphinx (>=8.2)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"]
+fpx = ["olefile"]
+mic = ["olefile"]
+test-arrow = ["pyarrow"]
+tests = ["check-manifest", "coverage (>=7.4.2)", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout", "trove-classifiers (>=2024.10.12)"]
+typing = ["typing-extensions"]
+xmp = ["defusedxml"]
+
+[[package]]
+name = "pluggy"
+version = "1.5.0"
+description = "plugin and hook calling mechanisms for python"
+optional = false
+python-versions = ">=3.8"
+groups = ["dev"]
+files = [
+ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
+ {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
+]
+
+[package.extras]
+dev = ["pre-commit", "tox"]
+testing = ["pytest", "pytest-benchmark"]
+
+[[package]]
+name = "prometheus-client"
+version = "0.21.1"
+description = "Python client for the Prometheus monitoring system."
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "prometheus_client-0.21.1-py3-none-any.whl", hash = "sha256:594b45c410d6f4f8888940fe80b5cc2521b305a1fafe1c58609ef715a001f301"},
+ {file = "prometheus_client-0.21.1.tar.gz", hash = "sha256:252505a722ac04b0456be05c05f75f45d760c2911ffc45f2a06bcaed9f3ae3fb"},
+]
+
+[package.extras]
+twisted = ["twisted"]
+
+[[package]]
+name = "protobuf"
+version = "5.29.4"
+description = ""
+optional = false
+python-versions = ">=3.8"
+groups = ["main", "dev"]
+files = [
+ {file = "protobuf-5.29.4-cp310-abi3-win32.whl", hash = "sha256:13eb236f8eb9ec34e63fc8b1d6efd2777d062fa6aaa68268fb67cf77f6839ad7"},
+ {file = "protobuf-5.29.4-cp310-abi3-win_amd64.whl", hash = "sha256:bcefcdf3976233f8a502d265eb65ea740c989bacc6c30a58290ed0e519eb4b8d"},
+ {file = "protobuf-5.29.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:307ecba1d852ec237e9ba668e087326a67564ef83e45a0189a772ede9e854dd0"},
+ {file = "protobuf-5.29.4-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:aec4962f9ea93c431d5714ed1be1c93f13e1a8618e70035ba2b0564d9e633f2e"},
+ {file = "protobuf-5.29.4-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:d7d3f7d1d5a66ed4942d4fefb12ac4b14a29028b209d4bfb25c68ae172059922"},
+ {file = "protobuf-5.29.4-cp38-cp38-win32.whl", hash = "sha256:1832f0515b62d12d8e6ffc078d7e9eb06969aa6dc13c13e1036e39d73bebc2de"},
+ {file = "protobuf-5.29.4-cp38-cp38-win_amd64.whl", hash = "sha256:476cb7b14914c780605a8cf62e38c2a85f8caff2e28a6a0bad827ec7d6c85d68"},
+ {file = "protobuf-5.29.4-cp39-cp39-win32.whl", hash = "sha256:fd32223020cb25a2cc100366f1dedc904e2d71d9322403224cdde5fdced0dabe"},
+ {file = "protobuf-5.29.4-cp39-cp39-win_amd64.whl", hash = "sha256:678974e1e3a9b975b8bc2447fca458db5f93a2fb6b0c8db46b6675b5b5346812"},
+ {file = "protobuf-5.29.4-py3-none-any.whl", hash = "sha256:3fde11b505e1597f71b875ef2fc52062b6a9740e5f7c8997ce878b6009145862"},
+ {file = "protobuf-5.29.4.tar.gz", hash = "sha256:4f1dfcd7997b31ef8f53ec82781ff434a28bf71d9102ddde14d076adcfc78c99"},
+]
+
+[[package]]
+name = "psutil"
+version = "7.0.0"
+description = "Cross-platform lib for process and system monitoring in Python. NOTE: the syntax of this script MUST be kept compatible with Python 2.7."
+optional = false
+python-versions = ">=3.6"
+groups = ["main"]
+files = [
+ {file = "psutil-7.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25"},
+ {file = "psutil-7.0.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da"},
+ {file = "psutil-7.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91"},
+ {file = "psutil-7.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34"},
+ {file = "psutil-7.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993"},
+ {file = "psutil-7.0.0-cp36-cp36m-win32.whl", hash = "sha256:84df4eb63e16849689f76b1ffcb36db7b8de703d1bc1fe41773db487621b6c17"},
+ {file = "psutil-7.0.0-cp36-cp36m-win_amd64.whl", hash = "sha256:1e744154a6580bc968a0195fd25e80432d3afec619daf145b9e5ba16cc1d688e"},
+ {file = "psutil-7.0.0-cp37-abi3-win32.whl", hash = "sha256:ba3fcef7523064a6c9da440fc4d6bd07da93ac726b5733c29027d7dc95b39d99"},
+ {file = "psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553"},
+ {file = "psutil-7.0.0.tar.gz", hash = "sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456"},
+]
+
+[package.extras]
+dev = ["abi3audit", "black (==24.10.0)", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest", "pytest-cov", "pytest-xdist", "requests", "rstcheck", "ruff", "setuptools", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "virtualenv", "vulture", "wheel"]
+test = ["pytest", "pytest-xdist", "setuptools"]
+
+[[package]]
+name = "py-cpuinfo"
+version = "9.0.0"
+description = "Get CPU info with pure Python"
+optional = false
+python-versions = "*"
+groups = ["main"]
+files = [
+ {file = "py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690"},
+ {file = "py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5"},
+]
+
+[[package]]
+name = "pydantic"
+version = "2.11.3"
+description = "Data validation using Python type hints"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "pydantic-2.11.3-py3-none-any.whl", hash = "sha256:a082753436a07f9ba1289c6ffa01cd93db3548776088aa917cc43b63f68fa60f"},
+ {file = "pydantic-2.11.3.tar.gz", hash = "sha256:7471657138c16adad9322fe3070c0116dd6c3ad8d649300e3cbdfe91f4db4ec3"},
+]
+
+[package.dependencies]
+annotated-types = ">=0.6.0"
+pydantic-core = "2.33.1"
+typing-extensions = ">=4.12.2"
+typing-inspection = ">=0.4.0"
+
+[package.extras]
+email = ["email-validator (>=2.0.0)"]
+timezone = ["tzdata"]
+
+[[package]]
+name = "pydantic-core"
+version = "2.33.1"
+description = "Core functionality for Pydantic validation and serialization"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "pydantic_core-2.33.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3077cfdb6125cc8dab61b155fdd714663e401f0e6883f9632118ec12cf42df26"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8ffab8b2908d152e74862d276cf5017c81a2f3719f14e8e3e8d6b83fda863927"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5183e4f6a2d468787243ebcd70cf4098c247e60d73fb7d68d5bc1e1beaa0c4db"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:398a38d323f37714023be1e0285765f0a27243a8b1506b7b7de87b647b517e48"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87d3776f0001b43acebfa86f8c64019c043b55cc5a6a2e313d728b5c95b46969"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c566dd9c5f63d22226409553531f89de0cac55397f2ab8d97d6f06cfce6d947e"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0d5f3acc81452c56895e90643a625302bd6be351e7010664151cc55b7b97f89"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d3a07fadec2a13274a8d861d3d37c61e97a816beae717efccaa4b36dfcaadcde"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f99aeda58dce827f76963ee87a0ebe75e648c72ff9ba1174a253f6744f518f65"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:902dbc832141aa0ec374f4310f1e4e7febeebc3256f00dc359a9ac3f264a45dc"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fe44d56aa0b00d66640aa84a3cbe80b7a3ccdc6f0b1ca71090696a6d4777c091"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-win32.whl", hash = "sha256:ed3eb16d51257c763539bde21e011092f127a2202692afaeaccb50db55a31383"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-win_amd64.whl", hash = "sha256:694ad99a7f6718c1a498dc170ca430687a39894a60327f548e02a9c7ee4b6504"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6e966fc3caaf9f1d96b349b0341c70c8d6573bf1bac7261f7b0ba88f96c56c24"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bfd0adeee563d59c598ceabddf2c92eec77abcb3f4a391b19aa7366170bd9e30"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91815221101ad3c6b507804178a7bb5cb7b2ead9ecd600041669c8d805ebd595"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9fea9c1869bb4742d174a57b4700c6dadea951df8b06de40c2fedb4f02931c2e"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d20eb4861329bb2484c021b9d9a977566ab16d84000a57e28061151c62b349a"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fb935c5591573ae3201640579f30128ccc10739b45663f93c06796854405505"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c964fd24e6166420d18fb53996d8c9fd6eac9bf5ae3ec3d03015be4414ce497f"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:681d65e9011f7392db5aa002b7423cc442d6a673c635668c227c6c8d0e5a4f77"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e100c52f7355a48413e2999bfb4e139d2977a904495441b374f3d4fb4a170961"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:048831bd363490be79acdd3232f74a0e9951b11b2b4cc058aeb72b22fdc3abe1"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:bdc84017d28459c00db6f918a7272a5190bec3090058334e43a76afb279eac7c"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-win32.whl", hash = "sha256:32cd11c5914d1179df70406427097c7dcde19fddf1418c787540f4b730289896"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-win_amd64.whl", hash = "sha256:2ea62419ba8c397e7da28a9170a16219d310d2cf4970dbc65c32faf20d828c83"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-win_arm64.whl", hash = "sha256:fc903512177361e868bc1f5b80ac8c8a6e05fcdd574a5fb5ffeac5a9982b9e89"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:1293d7febb995e9d3ec3ea09caf1a26214eec45b0f29f6074abb004723fc1de8"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:99b56acd433386c8f20be5c4000786d1e7ca0523c8eefc995d14d79c7a081498"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35a5ec3fa8c2fe6c53e1b2ccc2454398f95d5393ab398478f53e1afbbeb4d939"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b172f7b9d2f3abc0efd12e3386f7e48b576ef309544ac3a63e5e9cdd2e24585d"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9097b9f17f91eea659b9ec58148c0747ec354a42f7389b9d50701610d86f812e"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cc77ec5b7e2118b152b0d886c7514a4653bcb58c6b1d760134a9fab915f777b3"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5e3d15245b08fa4a84cefc6c9222e6f37c98111c8679fbd94aa145f9a0ae23d"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ef99779001d7ac2e2461d8ab55d3373fe7315caefdbecd8ced75304ae5a6fc6b"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fc6bf8869e193855e8d91d91f6bf59699a5cdfaa47a404e278e776dd7f168b39"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:b1caa0bc2741b043db7823843e1bde8aaa58a55a58fda06083b0569f8b45693a"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ec259f62538e8bf364903a7d0d0239447059f9434b284f5536e8402b7dd198db"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-win32.whl", hash = "sha256:e14f369c98a7c15772b9da98987f58e2b509a93235582838bd0d1d8c08b68fda"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-win_amd64.whl", hash = "sha256:1c607801d85e2e123357b3893f82c97a42856192997b95b4d8325deb1cd0c5f4"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-win_arm64.whl", hash = "sha256:8d13f0276806ee722e70a1c93da19748594f19ac4299c7e41237fc791d1861ea"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:70af6a21237b53d1fe7b9325b20e65cbf2f0a848cf77bed492b029139701e66a"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:282b3fe1bbbe5ae35224a0dbd05aed9ccabccd241e8e6b60370484234b456266"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b315e596282bbb5822d0c7ee9d255595bd7506d1cb20c2911a4da0b970187d3"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1dfae24cf9921875ca0ca6a8ecb4bb2f13c855794ed0d468d6abbec6e6dcd44a"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6dd8ecfde08d8bfadaea669e83c63939af76f4cf5538a72597016edfa3fad516"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f593494876eae852dc98c43c6f260f45abdbfeec9e4324e31a481d948214764"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:948b73114f47fd7016088e5186d13faf5e1b2fe83f5e320e371f035557fd264d"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e11f3864eb516af21b01e25fac915a82e9ddad3bb0fb9e95a246067398b435a4"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:549150be302428b56fdad0c23c2741dcdb5572413776826c965619a25d9c6bde"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:495bc156026efafd9ef2d82372bd38afce78ddd82bf28ef5276c469e57c0c83e"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ec79de2a8680b1a67a07490bddf9636d5c2fab609ba8c57597e855fa5fa4dacd"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-win32.whl", hash = "sha256:ee12a7be1742f81b8a65b36c6921022301d466b82d80315d215c4c691724986f"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-win_amd64.whl", hash = "sha256:ede9b407e39949d2afc46385ce6bd6e11588660c26f80576c11c958e6647bc40"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-win_arm64.whl", hash = "sha256:aa687a23d4b7871a00e03ca96a09cad0f28f443690d300500603bd0adba4b523"},
+ {file = "pydantic_core-2.33.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:401d7b76e1000d0dd5538e6381d28febdcacb097c8d340dde7d7fc6e13e9f95d"},
+ {file = "pydantic_core-2.33.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7aeb055a42d734c0255c9e489ac67e75397d59c6fbe60d155851e9782f276a9c"},
+ {file = "pydantic_core-2.33.1-cp313-cp313t-win_amd64.whl", hash = "sha256:338ea9b73e6e109f15ab439e62cb3b78aa752c7fd9536794112e14bee02c8d18"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5ab77f45d33d264de66e1884fca158bc920cb5e27fd0764a72f72f5756ae8bdb"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e7aaba1b4b03aaea7bb59e1b5856d734be011d3e6d98f5bcaa98cb30f375f2ad"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7fb66263e9ba8fea2aa85e1e5578980d127fb37d7f2e292773e7bc3a38fb0c7b"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3f2648b9262607a7fb41d782cc263b48032ff7a03a835581abbf7a3bec62bcf5"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:723c5630c4259400818b4ad096735a829074601805d07f8cafc366d95786d331"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d100e3ae783d2167782391e0c1c7a20a31f55f8015f3293647544df3f9c67824"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:177d50460bc976a0369920b6c744d927b0ecb8606fb56858ff542560251b19e5"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a3edde68d1a1f9af1273b2fe798997b33f90308fb6d44d8550c89fc6a3647cf6"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a62c3c3ef6a7e2c45f7853b10b5bc4ddefd6ee3cd31024754a1a5842da7d598d"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:c91dbb0ab683fa0cd64a6e81907c8ff41d6497c346890e26b23de7ee55353f96"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9f466e8bf0a62dc43e068c12166281c2eca72121dd2adc1040f3aa1e21ef8599"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-win32.whl", hash = "sha256:ab0277cedb698749caada82e5d099dc9fed3f906a30d4c382d1a21725777a1e5"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-win_amd64.whl", hash = "sha256:5773da0ee2d17136b1f1c6fbde543398d452a6ad2a7b54ea1033e2daa739b8d2"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5c834f54f8f4640fd7e4b193f80eb25a0602bba9e19b3cd2fc7ffe8199f5ae02"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:049e0de24cf23766f12cc5cc71d8abc07d4a9deb9061b334b62093dedc7cb068"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a28239037b3d6f16916a4c831a5a0eadf856bdd6d2e92c10a0da3a59eadcf3e"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d3da303ab5f378a268fa7d45f37d7d85c3ec19769f28d2cc0c61826a8de21fe"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:25626fb37b3c543818c14821afe0fd3830bc327a43953bc88db924b68c5723f1"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3ab2d36e20fbfcce8f02d73c33a8a7362980cff717926bbae030b93ae46b56c7"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:2f9284e11c751b003fd4215ad92d325d92c9cb19ee6729ebd87e3250072cdcde"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:048c01eee07d37cbd066fc512b9d8b5ea88ceeb4e629ab94b3e56965ad655add"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5ccd429694cf26af7997595d627dd2637e7932214486f55b8a357edaac9dae8c"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3a371dc00282c4b84246509a5ddc808e61b9864aa1eae9ecc92bb1268b82db4a"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:f59295ecc75a1788af8ba92f2e8c6eeaa5a94c22fc4d151e8d9638814f85c8fc"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08530b8ac922003033f399128505f513e30ca770527cc8bbacf75a84fcc2c74b"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bae370459da6a5466978c0eacf90690cb57ec9d533f8e63e564ef3822bfa04fe"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e3de2777e3b9f4d603112f78006f4ae0acb936e95f06da6cb1a45fbad6bdb4b5"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3a64e81e8cba118e108d7126362ea30e021291b7805d47e4896e52c791be2761"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:52928d8c1b6bda03cc6d811e8923dffc87a2d3c8b3bfd2ce16471c7147a24850"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:1b30d92c9412beb5ac6b10a3eb7ef92ccb14e3f2a8d7732e2d739f58b3aa7544"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:f995719707e0e29f0f41a8aa3bcea6e761a36c9136104d3189eafb83f5cec5e5"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:7edbc454a29fc6aeae1e1eecba4f07b63b8d76e76a748532233c4c167b4cb9ea"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:ad05b683963f69a1d5d2c2bdab1274a31221ca737dbbceaa32bcb67359453cdd"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df6a94bf9452c6da9b5d76ed229a5683d0306ccb91cca8e1eea883189780d568"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7965c13b3967909a09ecc91f21d09cfc4576bf78140b988904e94f130f188396"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3f1fdb790440a34f6ecf7679e1863b825cb5ffde858a9197f851168ed08371e5"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:5277aec8d879f8d05168fdd17ae811dd313b8ff894aeeaf7cd34ad28b4d77e33"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:8ab581d3530611897d863d1a649fb0644b860286b4718db919bfd51ece41f10b"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0483847fa9ad5e3412265c1bd72aad35235512d9ce9d27d81a56d935ef489672"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:de9e06abe3cc5ec6a2d5f75bc99b0bdca4f5c719a5b34026f8c57efbdecd2ee3"},
+ {file = "pydantic_core-2.33.1.tar.gz", hash = "sha256:bcc9c6fdb0ced789245b02b7d6603e17d1563064ddcfc36f046b61c0c05dd9df"},
+]
+
+[package.dependencies]
+typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
+
+[[package]]
+name = "pygments"
+version = "2.19.1"
+description = "Pygments is a syntax highlighting package written in Python."
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c"},
+ {file = "pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f"},
+]
+
+[package.extras]
+windows-terminal = ["colorama (>=0.4.6)"]
+
+[[package]]
+name = "pytest"
+version = "8.3.5"
+description = "pytest: simple powerful testing with Python"
+optional = false
+python-versions = ">=3.8"
+groups = ["dev"]
+files = [
+ {file = "pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820"},
+ {file = "pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "sys_platform == \"win32\""}
+exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
+iniconfig = "*"
+packaging = "*"
+pluggy = ">=1.5,<2"
+tomli = {version = ">=1", markers = "python_version < \"3.11\""}
+
+[package.extras]
+dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
+
+[[package]]
+name = "pyyaml"
+version = "6.0.2"
+description = "YAML parser and emitter for Python"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"},
+ {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"},
+ {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"},
+ {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"},
+ {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"},
+ {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"},
+ {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"},
+ {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"},
+ {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"},
+ {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"},
+ {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"},
+ {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"},
+ {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"},
+ {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"},
+ {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"},
+ {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"},
+ {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"},
+ {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"},
+ {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"},
+ {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"},
+ {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"},
+ {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"},
+ {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"},
+ {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"},
+ {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"},
+ {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"},
+ {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"},
+ {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"},
+ {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"},
+ {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"},
+ {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"},
+ {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"},
+ {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"},
+ {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"},
+ {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"},
+ {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"},
+ {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"},
+ {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"},
+ {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"},
+ {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"},
+ {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"},
+ {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"},
+ {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"},
+ {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"},
+ {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"},
+ {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"},
+ {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"},
+ {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"},
+ {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"},
+ {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"},
+ {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"},
+ {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"},
+ {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"},
+]
+
+[[package]]
+name = "referencing"
+version = "0.36.2"
+description = "JSON Referencing + Python"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "referencing-0.36.2-py3-none-any.whl", hash = "sha256:e8699adbbf8b5c7de96d8ffa0eb5c158b3beafce084968e2ea8bb08c6794dcd0"},
+ {file = "referencing-0.36.2.tar.gz", hash = "sha256:df2e89862cd09deabbdba16944cc3f10feb6b3e6f18e902f7cc25609a34775aa"},
+]
+
+[package.dependencies]
+attrs = ">=22.2.0"
+rpds-py = ">=0.7.0"
+typing-extensions = {version = ">=4.4.0", markers = "python_version < \"3.13\""}
+
+[[package]]
+name = "regex"
+version = "2024.11.6"
+description = "Alternative regular expression module, to replace re."
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "regex-2024.11.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff590880083d60acc0433f9c3f713c51f7ac6ebb9adf889c79a261ecf541aa91"},
+ {file = "regex-2024.11.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:658f90550f38270639e83ce492f27d2c8d2cd63805c65a13a14d36ca126753f0"},
+ {file = "regex-2024.11.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:164d8b7b3b4bcb2068b97428060b2a53be050085ef94eca7f240e7947f1b080e"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3660c82f209655a06b587d55e723f0b813d3a7db2e32e5e7dc64ac2a9e86fde"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d22326fcdef5e08c154280b71163ced384b428343ae16a5ab2b3354aed12436e"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f1ac758ef6aebfc8943560194e9fd0fa18bcb34d89fd8bd2af18183afd8da3a2"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:997d6a487ff00807ba810e0f8332c18b4eb8d29463cfb7c820dc4b6e7562d0cf"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02a02d2bb04fec86ad61f3ea7f49c015a0681bf76abb9857f945d26159d2968c"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f02f93b92358ee3f78660e43b4b0091229260c5d5c408d17d60bf26b6c900e86"},
+ {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:06eb1be98df10e81ebaded73fcd51989dcf534e3c753466e4b60c4697a003b67"},
+ {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:040df6fe1a5504eb0f04f048e6d09cd7c7110fef851d7c567a6b6e09942feb7d"},
+ {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:fdabbfc59f2c6edba2a6622c647b716e34e8e3867e0ab975412c5c2f79b82da2"},
+ {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8447d2d39b5abe381419319f942de20b7ecd60ce86f16a23b0698f22e1b70008"},
+ {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:da8f5fc57d1933de22a9e23eec290a0d8a5927a5370d24bda9a6abe50683fe62"},
+ {file = "regex-2024.11.6-cp310-cp310-win32.whl", hash = "sha256:b489578720afb782f6ccf2840920f3a32e31ba28a4b162e13900c3e6bd3f930e"},
+ {file = "regex-2024.11.6-cp310-cp310-win_amd64.whl", hash = "sha256:5071b2093e793357c9d8b2929dfc13ac5f0a6c650559503bb81189d0a3814519"},
+ {file = "regex-2024.11.6-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5478c6962ad548b54a591778e93cd7c456a7a29f8eca9c49e4f9a806dcc5d638"},
+ {file = "regex-2024.11.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2c89a8cc122b25ce6945f0423dc1352cb9593c68abd19223eebbd4e56612c5b7"},
+ {file = "regex-2024.11.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:94d87b689cdd831934fa3ce16cc15cd65748e6d689f5d2b8f4f4df2065c9fa20"},
+ {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1062b39a0a2b75a9c694f7a08e7183a80c63c0d62b301418ffd9c35f55aaa114"},
+ {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:167ed4852351d8a750da48712c3930b031f6efdaa0f22fa1933716bfcd6bf4a3"},
+ {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d548dafee61f06ebdb584080621f3e0c23fff312f0de1afc776e2a2ba99a74f"},
+ {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2a19f302cd1ce5dd01a9099aaa19cae6173306d1302a43b627f62e21cf18ac0"},
+ {file = "regex-2024.11.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bec9931dfb61ddd8ef2ebc05646293812cb6b16b60cf7c9511a832b6f1854b55"},
+ {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9714398225f299aa85267fd222f7142fcb5c769e73d7733344efc46f2ef5cf89"},
+ {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:202eb32e89f60fc147a41e55cb086db2a3f8cb82f9a9a88440dcfc5d37faae8d"},
+ {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:4181b814e56078e9b00427ca358ec44333765f5ca1b45597ec7446d3a1ef6e34"},
+ {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:068376da5a7e4da51968ce4c122a7cd31afaaec4fccc7856c92f63876e57b51d"},
+ {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ac10f2c4184420d881a3475fb2c6f4d95d53a8d50209a2500723d831036f7c45"},
+ {file = "regex-2024.11.6-cp311-cp311-win32.whl", hash = "sha256:c36f9b6f5f8649bb251a5f3f66564438977b7ef8386a52460ae77e6070d309d9"},
+ {file = "regex-2024.11.6-cp311-cp311-win_amd64.whl", hash = "sha256:02e28184be537f0e75c1f9b2f8847dc51e08e6e171c6bde130b2687e0c33cf60"},
+ {file = "regex-2024.11.6-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:52fb28f528778f184f870b7cf8f225f5eef0a8f6e3778529bdd40c7b3920796a"},
+ {file = "regex-2024.11.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fdd6028445d2460f33136c55eeb1f601ab06d74cb3347132e1c24250187500d9"},
+ {file = "regex-2024.11.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:805e6b60c54bf766b251e94526ebad60b7de0c70f70a4e6210ee2891acb70bf2"},
+ {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b85c2530be953a890eaffde05485238f07029600e8f098cdf1848d414a8b45e4"},
+ {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bb26437975da7dc36b7efad18aa9dd4ea569d2357ae6b783bf1118dabd9ea577"},
+ {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:abfa5080c374a76a251ba60683242bc17eeb2c9818d0d30117b4486be10c59d3"},
+ {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b7fa6606c2881c1db9479b0eaa11ed5dfa11c8d60a474ff0e095099f39d98e"},
+ {file = "regex-2024.11.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c32f75920cf99fe6b6c539c399a4a128452eaf1af27f39bce8909c9a3fd8cbe"},
+ {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:982e6d21414e78e1f51cf595d7f321dcd14de1f2881c5dc6a6e23bbbbd68435e"},
+ {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:a7c2155f790e2fb448faed6dd241386719802296ec588a8b9051c1f5c481bc29"},
+ {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:149f5008d286636e48cd0b1dd65018548944e495b0265b45e1bffecce1ef7f39"},
+ {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:e5364a4502efca094731680e80009632ad6624084aff9a23ce8c8c6820de3e51"},
+ {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0a86e7eeca091c09e021db8eb72d54751e527fa47b8d5787caf96d9831bd02ad"},
+ {file = "regex-2024.11.6-cp312-cp312-win32.whl", hash = "sha256:32f9a4c643baad4efa81d549c2aadefaeba12249b2adc5af541759237eee1c54"},
+ {file = "regex-2024.11.6-cp312-cp312-win_amd64.whl", hash = "sha256:a93c194e2df18f7d264092dc8539b8ffb86b45b899ab976aa15d48214138e81b"},
+ {file = "regex-2024.11.6-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a6ba92c0bcdf96cbf43a12c717eae4bc98325ca3730f6b130ffa2e3c3c723d84"},
+ {file = "regex-2024.11.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:525eab0b789891ac3be914d36893bdf972d483fe66551f79d3e27146191a37d4"},
+ {file = "regex-2024.11.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:086a27a0b4ca227941700e0b31425e7a28ef1ae8e5e05a33826e17e47fbfdba0"},
+ {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bde01f35767c4a7899b7eb6e823b125a64de314a8ee9791367c9a34d56af18d0"},
+ {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b583904576650166b3d920d2bcce13971f6f9e9a396c673187f49811b2769dc7"},
+ {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c4de13f06a0d54fa0d5ab1b7138bfa0d883220965a29616e3ea61b35d5f5fc7"},
+ {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3cde6e9f2580eb1665965ce9bf17ff4952f34f5b126beb509fee8f4e994f143c"},
+ {file = "regex-2024.11.6-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d7f453dca13f40a02b79636a339c5b62b670141e63efd511d3f8f73fba162b3"},
+ {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:59dfe1ed21aea057a65c6b586afd2a945de04fc7db3de0a6e3ed5397ad491b07"},
+ {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b97c1e0bd37c5cd7902e65f410779d39eeda155800b65fc4d04cc432efa9bc6e"},
+ {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f9d1e379028e0fc2ae3654bac3cbbef81bf3fd571272a42d56c24007979bafb6"},
+ {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:13291b39131e2d002a7940fb176e120bec5145f3aeb7621be6534e46251912c4"},
+ {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4f51f88c126370dcec4908576c5a627220da6c09d0bff31cfa89f2523843316d"},
+ {file = "regex-2024.11.6-cp313-cp313-win32.whl", hash = "sha256:63b13cfd72e9601125027202cad74995ab26921d8cd935c25f09c630436348ff"},
+ {file = "regex-2024.11.6-cp313-cp313-win_amd64.whl", hash = "sha256:2b3361af3198667e99927da8b84c1b010752fa4b1115ee30beaa332cabc3ef1a"},
+ {file = "regex-2024.11.6-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3a51ccc315653ba012774efca4f23d1d2a8a8f278a6072e29c7147eee7da446b"},
+ {file = "regex-2024.11.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ad182d02e40de7459b73155deb8996bbd8e96852267879396fb274e8700190e3"},
+ {file = "regex-2024.11.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ba9b72e5643641b7d41fa1f6d5abda2c9a263ae835b917348fc3c928182ad467"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40291b1b89ca6ad8d3f2b82782cc33807f1406cf68c8d440861da6304d8ffbbd"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cdf58d0e516ee426a48f7b2c03a332a4114420716d55769ff7108c37a09951bf"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a36fdf2af13c2b14738f6e973aba563623cb77d753bbbd8d414d18bfaa3105dd"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1cee317bfc014c2419a76bcc87f071405e3966da434e03e13beb45f8aced1a6"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50153825ee016b91549962f970d6a4442fa106832e14c918acd1c8e479916c4f"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ea1bfda2f7162605f6e8178223576856b3d791109f15ea99a9f95c16a7636fb5"},
+ {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:df951c5f4a1b1910f1a99ff42c473ff60f8225baa1cdd3539fe2819d9543e9df"},
+ {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:072623554418a9911446278f16ecb398fb3b540147a7828c06e2011fa531e773"},
+ {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:f654882311409afb1d780b940234208a252322c24a93b442ca714d119e68086c"},
+ {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:89d75e7293d2b3e674db7d4d9b1bee7f8f3d1609428e293771d1a962617150cc"},
+ {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:f65557897fc977a44ab205ea871b690adaef6b9da6afda4790a2484b04293a5f"},
+ {file = "regex-2024.11.6-cp38-cp38-win32.whl", hash = "sha256:6f44ec28b1f858c98d3036ad5d7d0bfc568bdd7a74f9c24e25f41ef1ebfd81a4"},
+ {file = "regex-2024.11.6-cp38-cp38-win_amd64.whl", hash = "sha256:bb8f74f2f10dbf13a0be8de623ba4f9491faf58c24064f32b65679b021ed0001"},
+ {file = "regex-2024.11.6-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5704e174f8ccab2026bd2f1ab6c510345ae8eac818b613d7d73e785f1310f839"},
+ {file = "regex-2024.11.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:220902c3c5cc6af55d4fe19ead504de80eb91f786dc102fbd74894b1551f095e"},
+ {file = "regex-2024.11.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5e7e351589da0850c125f1600a4c4ba3c722efefe16b297de54300f08d734fbf"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5056b185ca113c88e18223183aa1a50e66507769c9640a6ff75859619d73957b"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e34b51b650b23ed3354b5a07aab37034d9f923db2a40519139af34f485f77d0"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5670bce7b200273eee1840ef307bfa07cda90b38ae56e9a6ebcc9f50da9c469b"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08986dce1339bc932923e7d1232ce9881499a0e02925f7402fb7c982515419ef"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:93c0b12d3d3bc25af4ebbf38f9ee780a487e8bf6954c115b9f015822d3bb8e48"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:764e71f22ab3b305e7f4c21f1a97e1526a25ebdd22513e251cf376760213da13"},
+ {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:f056bf21105c2515c32372bbc057f43eb02aae2fda61052e2f7622c801f0b4e2"},
+ {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:69ab78f848845569401469da20df3e081e6b5a11cb086de3eed1d48f5ed57c95"},
+ {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:86fddba590aad9208e2fa8b43b4c098bb0ec74f15718bb6a704e3c63e2cef3e9"},
+ {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:684d7a212682996d21ca12ef3c17353c021fe9de6049e19ac8481ec35574a70f"},
+ {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:a03e02f48cd1abbd9f3b7e3586d97c8f7a9721c436f51a5245b3b9483044480b"},
+ {file = "regex-2024.11.6-cp39-cp39-win32.whl", hash = "sha256:41758407fc32d5c3c5de163888068cfee69cb4c2be844e7ac517a52770f9af57"},
+ {file = "regex-2024.11.6-cp39-cp39-win_amd64.whl", hash = "sha256:b2837718570f95dd41675328e111345f9b7095d821bac435aac173ac80b19983"},
+ {file = "regex-2024.11.6.tar.gz", hash = "sha256:7ab159b063c52a0333c884e4679f8d7a85112ee3078fe3d9004b2dd875585519"},
+]
+
+[[package]]
+name = "requests"
+version = "2.32.3"
+description = "Python HTTP for Humans."
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
+ {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
+]
+
+[package.dependencies]
+certifi = ">=2017.4.17"
+charset-normalizer = ">=2,<4"
+idna = ">=2.5,<4"
+urllib3 = ">=1.21.1,<3"
+
+[package.extras]
+socks = ["PySocks (>=1.5.6,!=1.5.7)"]
+use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
+
+[[package]]
+name = "rich"
+version = "14.0.0"
+description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
+optional = false
+python-versions = ">=3.8.0"
+groups = ["main"]
+files = [
+ {file = "rich-14.0.0-py3-none-any.whl", hash = "sha256:1c9491e1951aac09caffd42f448ee3d04e58923ffe14993f6e83068dc395d7e0"},
+ {file = "rich-14.0.0.tar.gz", hash = "sha256:82f1bc23a6a21ebca4ae0c45af9bdbc492ed20231dcb63f297d6d1021a9d5725"},
+]
+
+[package.dependencies]
+markdown-it-py = ">=2.2.0"
+pygments = ">=2.13.0,<3.0.0"
+typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""}
+
+[package.extras]
+jupyter = ["ipywidgets (>=7.5.1,<9)"]
+
+[[package]]
+name = "rpds-py"
+version = "0.24.0"
+description = "Python bindings to Rust's persistent data structures (rpds)"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "rpds_py-0.24.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:006f4342fe729a368c6df36578d7a348c7c716be1da0a1a0f86e3021f8e98724"},
+ {file = "rpds_py-0.24.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2d53747da70a4e4b17f559569d5f9506420966083a31c5fbd84e764461c4444b"},
+ {file = "rpds_py-0.24.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8acd55bd5b071156bae57b555f5d33697998752673b9de554dd82f5b5352727"},
+ {file = "rpds_py-0.24.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7e80d375134ddb04231a53800503752093dbb65dad8dabacce2c84cccc78e964"},
+ {file = "rpds_py-0.24.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60748789e028d2a46fc1c70750454f83c6bdd0d05db50f5ae83e2db500b34da5"},
+ {file = "rpds_py-0.24.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6e1daf5bf6c2be39654beae83ee6b9a12347cb5aced9a29eecf12a2d25fff664"},
+ {file = "rpds_py-0.24.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1b221c2457d92a1fb3c97bee9095c874144d196f47c038462ae6e4a14436f7bc"},
+ {file = "rpds_py-0.24.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:66420986c9afff67ef0c5d1e4cdc2d0e5262f53ad11e4f90e5e22448df485bf0"},
+ {file = "rpds_py-0.24.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:43dba99f00f1d37b2a0265a259592d05fcc8e7c19d140fe51c6e6f16faabeb1f"},
+ {file = "rpds_py-0.24.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:a88c0d17d039333a41d9bf4616bd062f0bd7aa0edeb6cafe00a2fc2a804e944f"},
+ {file = "rpds_py-0.24.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cc31e13ce212e14a539d430428cd365e74f8b2d534f8bc22dd4c9c55b277b875"},
+ {file = "rpds_py-0.24.0-cp310-cp310-win32.whl", hash = "sha256:fc2c1e1b00f88317d9de6b2c2b39b012ebbfe35fe5e7bef980fd2a91f6100a07"},
+ {file = "rpds_py-0.24.0-cp310-cp310-win_amd64.whl", hash = "sha256:c0145295ca415668420ad142ee42189f78d27af806fcf1f32a18e51d47dd2052"},
+ {file = "rpds_py-0.24.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:2d3ee4615df36ab8eb16c2507b11e764dcc11fd350bbf4da16d09cda11fcedef"},
+ {file = "rpds_py-0.24.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e13ae74a8a3a0c2f22f450f773e35f893484fcfacb00bb4344a7e0f4f48e1f97"},
+ {file = "rpds_py-0.24.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf86f72d705fc2ef776bb7dd9e5fbba79d7e1f3e258bf9377f8204ad0fc1c51e"},
+ {file = "rpds_py-0.24.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c43583ea8517ed2e780a345dd9960896afc1327e8cf3ac8239c167530397440d"},
+ {file = "rpds_py-0.24.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4cd031e63bc5f05bdcda120646a0d32f6d729486d0067f09d79c8db5368f4586"},
+ {file = "rpds_py-0.24.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:34d90ad8c045df9a4259c47d2e16a3f21fdb396665c94520dbfe8766e62187a4"},
+ {file = "rpds_py-0.24.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e838bf2bb0b91ee67bf2b889a1a841e5ecac06dd7a2b1ef4e6151e2ce155c7ae"},
+ {file = "rpds_py-0.24.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04ecf5c1ff4d589987b4d9882872f80ba13da7d42427234fce8f22efb43133bc"},
+ {file = "rpds_py-0.24.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:630d3d8ea77eabd6cbcd2ea712e1c5cecb5b558d39547ac988351195db433f6c"},
+ {file = "rpds_py-0.24.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ebcb786b9ff30b994d5969213a8430cbb984cdd7ea9fd6df06663194bd3c450c"},
+ {file = "rpds_py-0.24.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:174e46569968ddbbeb8a806d9922f17cd2b524aa753b468f35b97ff9c19cb718"},
+ {file = "rpds_py-0.24.0-cp311-cp311-win32.whl", hash = "sha256:5ef877fa3bbfb40b388a5ae1cb00636a624690dcb9a29a65267054c9ea86d88a"},
+ {file = "rpds_py-0.24.0-cp311-cp311-win_amd64.whl", hash = "sha256:e274f62cbd274359eff63e5c7e7274c913e8e09620f6a57aae66744b3df046d6"},
+ {file = "rpds_py-0.24.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:d8551e733626afec514b5d15befabea0dd70a343a9f23322860c4f16a9430205"},
+ {file = "rpds_py-0.24.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0e374c0ce0ca82e5b67cd61fb964077d40ec177dd2c4eda67dba130de09085c7"},
+ {file = "rpds_py-0.24.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d69d003296df4840bd445a5d15fa5b6ff6ac40496f956a221c4d1f6f7b4bc4d9"},
+ {file = "rpds_py-0.24.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8212ff58ac6dfde49946bea57474a386cca3f7706fc72c25b772b9ca4af6b79e"},
+ {file = "rpds_py-0.24.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:528927e63a70b4d5f3f5ccc1fa988a35456eb5d15f804d276709c33fc2f19bda"},
+ {file = "rpds_py-0.24.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a824d2c7a703ba6daaca848f9c3d5cb93af0505be505de70e7e66829affd676e"},
+ {file = "rpds_py-0.24.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44d51febb7a114293ffd56c6cf4736cb31cd68c0fddd6aa303ed09ea5a48e029"},
+ {file = "rpds_py-0.24.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3fab5f4a2c64a8fb64fc13b3d139848817a64d467dd6ed60dcdd6b479e7febc9"},
+ {file = "rpds_py-0.24.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9be4f99bee42ac107870c61dfdb294d912bf81c3c6d45538aad7aecab468b6b7"},
+ {file = "rpds_py-0.24.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:564c96b6076a98215af52f55efa90d8419cc2ef45d99e314fddefe816bc24f91"},
+ {file = "rpds_py-0.24.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:75a810b7664c17f24bf2ffd7f92416c00ec84b49bb68e6a0d93e542406336b56"},
+ {file = "rpds_py-0.24.0-cp312-cp312-win32.whl", hash = "sha256:f6016bd950be4dcd047b7475fdf55fb1e1f59fc7403f387be0e8123e4a576d30"},
+ {file = "rpds_py-0.24.0-cp312-cp312-win_amd64.whl", hash = "sha256:998c01b8e71cf051c28f5d6f1187abbdf5cf45fc0efce5da6c06447cba997034"},
+ {file = "rpds_py-0.24.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:3d2d8e4508e15fc05b31285c4b00ddf2e0eb94259c2dc896771966a163122a0c"},
+ {file = "rpds_py-0.24.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0f00c16e089282ad68a3820fd0c831c35d3194b7cdc31d6e469511d9bffc535c"},
+ {file = "rpds_py-0.24.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:951cc481c0c395c4a08639a469d53b7d4afa252529a085418b82a6b43c45c240"},
+ {file = "rpds_py-0.24.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c9ca89938dff18828a328af41ffdf3902405a19f4131c88e22e776a8e228c5a8"},
+ {file = "rpds_py-0.24.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ed0ef550042a8dbcd657dfb284a8ee00f0ba269d3f2286b0493b15a5694f9fe8"},
+ {file = "rpds_py-0.24.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2b2356688e5d958c4d5cb964af865bea84db29971d3e563fb78e46e20fe1848b"},
+ {file = "rpds_py-0.24.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78884d155fd15d9f64f5d6124b486f3d3f7fd7cd71a78e9670a0f6f6ca06fb2d"},
+ {file = "rpds_py-0.24.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6a4a535013aeeef13c5532f802708cecae8d66c282babb5cd916379b72110cf7"},
+ {file = "rpds_py-0.24.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:84e0566f15cf4d769dade9b366b7b87c959be472c92dffb70462dd0844d7cbad"},
+ {file = "rpds_py-0.24.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:823e74ab6fbaa028ec89615ff6acb409e90ff45580c45920d4dfdddb069f2120"},
+ {file = "rpds_py-0.24.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c61a2cb0085c8783906b2f8b1f16a7e65777823c7f4d0a6aaffe26dc0d358dd9"},
+ {file = "rpds_py-0.24.0-cp313-cp313-win32.whl", hash = "sha256:60d9b630c8025b9458a9d114e3af579a2c54bd32df601c4581bd054e85258143"},
+ {file = "rpds_py-0.24.0-cp313-cp313-win_amd64.whl", hash = "sha256:6eea559077d29486c68218178ea946263b87f1c41ae7f996b1f30a983c476a5a"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:d09dc82af2d3c17e7dd17120b202a79b578d79f2b5424bda209d9966efeed114"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5fc13b44de6419d1e7a7e592a4885b323fbc2f46e1f22151e3a8ed3b8b920405"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c347a20d79cedc0a7bd51c4d4b7dbc613ca4e65a756b5c3e57ec84bd43505b47"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:20f2712bd1cc26a3cc16c5a1bfee9ed1abc33d4cdf1aabd297fe0eb724df4272"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aad911555286884be1e427ef0dc0ba3929e6821cbeca2194b13dc415a462c7fd"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0aeb3329c1721c43c58cae274d7d2ca85c1690d89485d9c63a006cb79a85771a"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a0f156e9509cee987283abd2296ec816225145a13ed0391df8f71bf1d789e2d"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:aa6800adc8204ce898c8a424303969b7aa6a5e4ad2789c13f8648739830323b7"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:a18fc371e900a21d7392517c6f60fe859e802547309e94313cd8181ad9db004d"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:9168764133fd919f8dcca2ead66de0105f4ef5659cbb4fa044f7014bed9a1797"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5f6e3cec44ba05ee5cbdebe92d052f69b63ae792e7d05f1020ac5e964394080c"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-win32.whl", hash = "sha256:8ebc7e65ca4b111d928b669713865f021b7773350eeac4a31d3e70144297baba"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-win_amd64.whl", hash = "sha256:675269d407a257b8c00a6b58205b72eec8231656506c56fd429d924ca00bb350"},
+ {file = "rpds_py-0.24.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a36b452abbf29f68527cf52e181fced56685731c86b52e852053e38d8b60bc8d"},
+ {file = "rpds_py-0.24.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8b3b397eefecec8e8e39fa65c630ef70a24b09141a6f9fc17b3c3a50bed6b50e"},
+ {file = "rpds_py-0.24.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cdabcd3beb2a6dca7027007473d8ef1c3b053347c76f685f5f060a00327b8b65"},
+ {file = "rpds_py-0.24.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5db385bacd0c43f24be92b60c857cf760b7f10d8234f4bd4be67b5b20a7c0b6b"},
+ {file = "rpds_py-0.24.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8097b3422d020ff1c44effc40ae58e67d93e60d540a65649d2cdaf9466030791"},
+ {file = "rpds_py-0.24.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:493fe54318bed7d124ce272fc36adbf59d46729659b2c792e87c3b95649cdee9"},
+ {file = "rpds_py-0.24.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8aa362811ccdc1f8dadcc916c6d47e554169ab79559319ae9fae7d7752d0d60c"},
+ {file = "rpds_py-0.24.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d8f9a6e7fd5434817526815f09ea27f2746c4a51ee11bb3439065f5fc754db58"},
+ {file = "rpds_py-0.24.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8205ee14463248d3349131bb8099efe15cd3ce83b8ef3ace63c7e976998e7124"},
+ {file = "rpds_py-0.24.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:921ae54f9ecba3b6325df425cf72c074cd469dea843fb5743a26ca7fb2ccb149"},
+ {file = "rpds_py-0.24.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:32bab0a56eac685828e00cc2f5d1200c548f8bc11f2e44abf311d6b548ce2e45"},
+ {file = "rpds_py-0.24.0-cp39-cp39-win32.whl", hash = "sha256:f5c0ed12926dec1dfe7d645333ea59cf93f4d07750986a586f511c0bc61fe103"},
+ {file = "rpds_py-0.24.0-cp39-cp39-win_amd64.whl", hash = "sha256:afc6e35f344490faa8276b5f2f7cbf71f88bc2cda4328e00553bd451728c571f"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:619ca56a5468f933d940e1bf431c6f4e13bef8e688698b067ae68eb4f9b30e3a"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:4b28e5122829181de1898c2c97f81c0b3246d49f585f22743a1246420bb8d399"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e5ab32cf9eb3647450bc74eb201b27c185d3857276162c101c0f8c6374e098"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:208b3a70a98cf3710e97cabdc308a51cd4f28aa6e7bb11de3d56cd8b74bab98d"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbc4362e06f950c62cad3d4abf1191021b2ffaf0b31ac230fbf0526453eee75e"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ebea2821cdb5f9fef44933617be76185b80150632736f3d76e54829ab4a3b4d1"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b9a4df06c35465ef4d81799999bba810c68d29972bf1c31db61bfdb81dd9d5bb"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d3aa13bdf38630da298f2e0d77aca967b200b8cc1473ea05248f6c5e9c9bdb44"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:041f00419e1da7a03c46042453598479f45be3d787eb837af382bfc169c0db33"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:d8754d872a5dfc3c5bf9c0e059e8107451364a30d9fd50f1f1a85c4fb9481164"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:896c41007931217a343eff197c34513c154267636c8056fb409eafd494c3dcdc"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:92558d37d872e808944c3c96d0423b8604879a3d1c86fdad508d7ed91ea547d5"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f9e0057a509e096e47c87f753136c9b10d7a91842d8042c2ee6866899a717c0d"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:d6e109a454412ab82979c5b1b3aee0604eca4bbf9a02693bb9df027af2bfa91a"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc1c892b1ec1f8cbd5da8de287577b455e388d9c328ad592eabbdcb6fc93bee5"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9c39438c55983d48f4bb3487734d040e22dad200dab22c41e331cee145e7a50d"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d7e8ce990ae17dda686f7e82fd41a055c668e13ddcf058e7fb5e9da20b57793"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9ea7f4174d2e4194289cb0c4e172d83e79a6404297ff95f2875cf9ac9bced8ba"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb2954155bb8f63bb19d56d80e5e5320b61d71084617ed89efedb861a684baea"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04f2b712a2206e13800a8136b07aaedc23af3facab84918e7aa89e4be0260032"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:eda5c1e2a715a4cbbca2d6d304988460942551e4e5e3b7457b50943cd741626d"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-musllinux_1_2_i686.whl", hash = "sha256:9abc80fe8c1f87218db116016de575a7998ab1629078c90840e8d11ab423ee25"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:6a727fd083009bc83eb83d6950f0c32b3c94c8b80a9b667c87f4bd1274ca30ba"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e0f3ef95795efcd3b2ec3fe0a5bcfb5dadf5e3996ea2117427e524d4fbf309c6"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:2c13777ecdbbba2077670285dd1fe50828c8742f6a4119dbef6f83ea13ad10fb"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79e8d804c2ccd618417e96720ad5cd076a86fa3f8cb310ea386a3e6229bae7d1"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fd822f019ccccd75c832deb7aa040bb02d70a92eb15a2f16c7987b7ad4ee8d83"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0047638c3aa0dbcd0ab99ed1e549bbf0e142c9ecc173b6492868432d8989a046"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a5b66d1b201cc71bc3081bc2f1fc36b0c1f268b773e03bbc39066651b9e18391"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbcbb6db5582ea33ce46a5d20a5793134b5365110d84df4e30b9d37c6fd40ad3"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:63981feca3f110ed132fd217bf7768ee8ed738a55549883628ee3da75bb9cb78"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:3a55fc10fdcbf1a4bd3c018eea422c52cf08700cf99c28b5cb10fe97ab77a0d3"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:c30ff468163a48535ee7e9bf21bd14c7a81147c0e58a36c1078289a8ca7af0bd"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:369d9c6d4c714e36d4a03957b4783217a3ccd1e222cdd67d464a3a479fc17796"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:24795c099453e3721fda5d8ddd45f5dfcc8e5a547ce7b8e9da06fecc3832e26f"},
+ {file = "rpds_py-0.24.0.tar.gz", hash = "sha256:772cc1b2cd963e7e17e6cc55fe0371fb9c704d63e44cacec7b9b7f523b78919e"},
+]
+
+[[package]]
+name = "safetensors"
+version = "0.5.3"
+description = ""
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "safetensors-0.5.3-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:bd20eb133db8ed15b40110b7c00c6df51655a2998132193de2f75f72d99c7073"},
+ {file = "safetensors-0.5.3-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:21d01c14ff6c415c485616b8b0bf961c46b3b343ca59110d38d744e577f9cce7"},
+ {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11bce6164887cd491ca75c2326a113ba934be596e22b28b1742ce27b1d076467"},
+ {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4a243be3590bc3301c821da7a18d87224ef35cbd3e5f5727e4e0728b8172411e"},
+ {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8bd84b12b1670a6f8e50f01e28156422a2bc07fb16fc4e98bded13039d688a0d"},
+ {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:391ac8cab7c829452175f871fcaf414aa1e292b5448bd02620f675a7f3e7abb9"},
+ {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cead1fa41fc54b1e61089fa57452e8834f798cb1dc7a09ba3524f1eb08e0317a"},
+ {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1077f3e94182d72618357b04b5ced540ceb71c8a813d3319f1aba448e68a770d"},
+ {file = "safetensors-0.5.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:799021e78287bac619c7b3f3606730a22da4cda27759ddf55d37c8db7511c74b"},
+ {file = "safetensors-0.5.3-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:df26da01aaac504334644e1b7642fa000bfec820e7cef83aeac4e355e03195ff"},
+ {file = "safetensors-0.5.3-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:32c3ef2d7af8b9f52ff685ed0bc43913cdcde135089ae322ee576de93eae5135"},
+ {file = "safetensors-0.5.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:37f1521be045e56fc2b54c606d4455573e717b2d887c579ee1dbba5f868ece04"},
+ {file = "safetensors-0.5.3-cp38-abi3-win32.whl", hash = "sha256:cfc0ec0846dcf6763b0ed3d1846ff36008c6e7290683b61616c4b040f6a54ace"},
+ {file = "safetensors-0.5.3-cp38-abi3-win_amd64.whl", hash = "sha256:836cbbc320b47e80acd40e44c8682db0e8ad7123209f69b093def21ec7cafd11"},
+ {file = "safetensors-0.5.3.tar.gz", hash = "sha256:b6b0d6ecacec39a4fdd99cc19f4576f5219ce858e6fd8dbe7609df0b8dc56965"},
+]
+
+[package.extras]
+all = ["safetensors[jax]", "safetensors[numpy]", "safetensors[paddlepaddle]", "safetensors[pinned-tf]", "safetensors[quality]", "safetensors[testing]", "safetensors[torch]"]
+dev = ["safetensors[all]"]
+jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "safetensors[numpy]"]
+mlx = ["mlx (>=0.0.9)"]
+numpy = ["numpy (>=1.21.6)"]
+paddlepaddle = ["paddlepaddle (>=2.4.1)", "safetensors[numpy]"]
+pinned-tf = ["safetensors[numpy]", "tensorflow (==2.18.0)"]
+quality = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "isort (>=5.5.4)"]
+tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"]
+testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"]
+torch = ["safetensors[numpy]", "torch (>=1.10)"]
+
+[[package]]
+name = "scikit-learn"
+version = "1.6.1"
+description = "A set of python modules for machine learning and data mining"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "scikit_learn-1.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d056391530ccd1e501056160e3c9673b4da4805eb67eb2bdf4e983e1f9c9204e"},
+ {file = "scikit_learn-1.6.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0c8d036eb937dbb568c6242fa598d551d88fb4399c0344d95c001980ec1c7d36"},
+ {file = "scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8634c4bd21a2a813e0a7e3900464e6d593162a29dd35d25bdf0103b3fce60ed5"},
+ {file = "scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:775da975a471c4f6f467725dff0ced5c7ac7bda5e9316b260225b48475279a1b"},
+ {file = "scikit_learn-1.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:8a600c31592bd7dab31e1c61b9bbd6dea1b3433e67d264d17ce1017dbdce8002"},
+ {file = "scikit_learn-1.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:72abc587c75234935e97d09aa4913a82f7b03ee0b74111dcc2881cba3c5a7b33"},
+ {file = "scikit_learn-1.6.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b3b00cdc8f1317b5f33191df1386c0befd16625f49d979fe77a8d44cae82410d"},
+ {file = "scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc4765af3386811c3ca21638f63b9cf5ecf66261cc4815c1db3f1e7dc7b79db2"},
+ {file = "scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25fc636bdaf1cc2f4a124a116312d837148b5e10872147bdaf4887926b8c03d8"},
+ {file = "scikit_learn-1.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:fa909b1a36e000a03c382aade0bd2063fd5680ff8b8e501660c0f59f021a6415"},
+ {file = "scikit_learn-1.6.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:926f207c804104677af4857b2c609940b743d04c4c35ce0ddc8ff4f053cddc1b"},
+ {file = "scikit_learn-1.6.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2c2cae262064e6a9b77eee1c8e768fc46aa0b8338c6a8297b9b6759720ec0ff2"},
+ {file = "scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1061b7c028a8663fb9a1a1baf9317b64a257fcb036dae5c8752b2abef31d136f"},
+ {file = "scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e69fab4ebfc9c9b580a7a80111b43d214ab06250f8a7ef590a4edf72464dd86"},
+ {file = "scikit_learn-1.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:70b1d7e85b1c96383f872a519b3375f92f14731e279a7b4c6cfd650cf5dffc52"},
+ {file = "scikit_learn-1.6.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2ffa1e9e25b3d93990e74a4be2c2fc61ee5af85811562f1288d5d055880c4322"},
+ {file = "scikit_learn-1.6.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:dc5cf3d68c5a20ad6d571584c0750ec641cc46aeef1c1507be51300e6003a7e1"},
+ {file = "scikit_learn-1.6.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c06beb2e839ecc641366000ca84f3cf6fa9faa1777e29cf0c04be6e4d096a348"},
+ {file = "scikit_learn-1.6.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8ca8cb270fee8f1f76fa9bfd5c3507d60c6438bbee5687f81042e2bb98e5a97"},
+ {file = "scikit_learn-1.6.1-cp313-cp313-win_amd64.whl", hash = "sha256:7a1c43c8ec9fde528d664d947dc4c0789be4077a3647f232869f41d9bf50e0fb"},
+ {file = "scikit_learn-1.6.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a17c1dea1d56dcda2fac315712f3651a1fea86565b64b48fa1bc090249cbf236"},
+ {file = "scikit_learn-1.6.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:6a7aa5f9908f0f28f4edaa6963c0a6183f1911e63a69aa03782f0d924c830a35"},
+ {file = "scikit_learn-1.6.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0650e730afb87402baa88afbf31c07b84c98272622aaba002559b614600ca691"},
+ {file = "scikit_learn-1.6.1-cp313-cp313t-win_amd64.whl", hash = "sha256:3f59fe08dc03ea158605170eb52b22a105f238a5d512c4470ddeca71feae8e5f"},
+ {file = "scikit_learn-1.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6849dd3234e87f55dce1db34c89a810b489ead832aaf4d4550b7ea85628be6c1"},
+ {file = "scikit_learn-1.6.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:e7be3fa5d2eb9be7d77c3734ff1d599151bb523674be9b834e8da6abe132f44e"},
+ {file = "scikit_learn-1.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44a17798172df1d3c1065e8fcf9019183f06c87609b49a124ebdf57ae6cb0107"},
+ {file = "scikit_learn-1.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8b7a3b86e411e4bce21186e1c180d792f3d99223dcfa3b4f597ecc92fa1a422"},
+ {file = "scikit_learn-1.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:7a73d457070e3318e32bdb3aa79a8d990474f19035464dfd8bede2883ab5dc3b"},
+ {file = "scikit_learn-1.6.1.tar.gz", hash = "sha256:b4fc2525eca2c69a59260f583c56a7557c6ccdf8deafdba6e060f94c1c59738e"},
+]
+
+[package.dependencies]
+joblib = ">=1.2.0"
+numpy = ">=1.19.5"
+scipy = ">=1.6.0"
+threadpoolctl = ">=3.1.0"
+
+[package.extras]
+benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"]
+build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"]
+docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-design (>=0.6.0)", "sphinx-gallery (>=0.17.1)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)", "towncrier (>=24.8.0)"]
+examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"]
+install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"]
+maintenance = ["conda-lock (==2.5.6)"]
+tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.5.1)", "scikit-image (>=0.17.2)"]
+
+[[package]]
+name = "scipy"
+version = "1.13.1"
+description = "Fundamental algorithms for scientific computing in Python"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca"},
+ {file = "scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f"},
+ {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989"},
+ {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f"},
+ {file = "scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94"},
+ {file = "scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54"},
+ {file = "scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9"},
+ {file = "scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326"},
+ {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299"},
+ {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa"},
+ {file = "scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59"},
+ {file = "scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b"},
+ {file = "scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1"},
+ {file = "scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d"},
+ {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627"},
+ {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884"},
+ {file = "scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16"},
+ {file = "scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949"},
+ {file = "scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5"},
+ {file = "scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24"},
+ {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004"},
+ {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d"},
+ {file = "scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c"},
+ {file = "scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2"},
+ {file = "scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c"},
+]
+
+[package.dependencies]
+numpy = ">=1.22.4,<2.3"
+
+[package.extras]
+dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"]
+doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"]
+test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
+
+[[package]]
+name = "sentence-transformers"
+version = "3.3.1"
+description = "State-of-the-Art Text Embeddings"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "sentence_transformers-3.3.1-py3-none-any.whl", hash = "sha256:abffcc79dab37b7d18d21a26d5914223dd42239cfe18cb5e111c66c54b658ae7"},
+ {file = "sentence_transformers-3.3.1.tar.gz", hash = "sha256:9635dbfb11c6b01d036b9cfcee29f7716ab64cf2407ad9f403a2e607da2ac48b"},
+]
+
+[package.dependencies]
+huggingface-hub = ">=0.20.0"
+Pillow = "*"
+scikit-learn = "*"
+scipy = "*"
+torch = ">=1.11.0"
+tqdm = "*"
+transformers = ">=4.41.0,<5.0.0"
+
+[package.extras]
+dev = ["accelerate (>=0.20.3)", "datasets", "peft", "pre-commit", "pytest", "pytest-cov"]
+onnx = ["optimum[onnxruntime] (>=1.23.1)"]
+onnx-gpu = ["optimum[onnxruntime-gpu] (>=1.23.1)"]
+openvino = ["optimum-intel[openvino] (>=1.20.0)"]
+train = ["accelerate (>=0.20.3)", "datasets"]
+
+[[package]]
+name = "sentencepiece"
+version = "0.2.0"
+description = "SentencePiece python wrapper"
+optional = false
+python-versions = "*"
+groups = ["main"]
+files = [
+ {file = "sentencepiece-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:188779e1298a1c8b8253c7d3ad729cb0a9891e5cef5e5d07ce4592c54869e227"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bed9cf85b296fa2b76fc2547b9cbb691a523864cebaee86304c43a7b4cb1b452"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d7b67e724bead13f18db6e1d10b6bbdc454af574d70efbb36f27d90387be1ca3"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2fde4b08cfe237be4484c6c7c2e2c75fb862cfeab6bd5449ce4caeafd97b767a"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c378492056202d1c48a4979650981635fd97875a00eabb1f00c6a236b013b5e"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1380ce6540a368de2ef6d7e6ba14ba8f3258df650d39ba7d833b79ee68a52040"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-win32.whl", hash = "sha256:a1151d6a6dd4b43e552394aed0edfe9292820272f0194bd56c7c1660a0c06c3d"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:d490142b0521ef22bc1085f061d922a2a6666175bb6b42e588ff95c0db6819b2"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:17982700c4f6dbb55fa3594f3d7e5dd1c8659a274af3738e33c987d2a27c9d5c"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7c867012c0e8bcd5bdad0f791609101cb5c66acb303ab3270218d6debc68a65e"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7fd6071249c74f779c5b27183295b9202f8dedb68034e716784364443879eaa6"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27f90c55a65013cbb8f4d7aab0599bf925cde4adc67ae43a0d323677b5a1c6cb"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b293734059ef656dcd65be62ff771507bea8fed0a711b6733976e1ed3add4553"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e58b47f933aca74c6a60a79dcb21d5b9e47416256c795c2d58d55cec27f9551d"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-win32.whl", hash = "sha256:c581258cf346b327c62c4f1cebd32691826306f6a41d8c4bec43b010dee08e75"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:0993dbc665f4113017892f1b87c3904a44d0640eda510abcacdfb07f74286d36"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea5f536e32ea8ec96086ee00d7a4a131ce583a1b18d130711707c10e69601cb2"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0cb51f53b6aae3c36bafe41e86167c71af8370a039f542c43b0cce5ef24a68c"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3212121805afc58d8b00ab4e7dd1f8f76c203ddb9dc94aa4079618a31cf5da0f"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a3149e3066c2a75e0d68a43eb632d7ae728c7925b517f4c05c40f6f7280ce08"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:632f3594d3e7ac8b367bca204cb3fd05a01d5b21455acd097ea4c0e30e2f63d7"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f295105c6bdbb05bd5e1b0cafbd78ff95036f5d3641e7949455a3f4e5e7c3109"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-win32.whl", hash = "sha256:fb89f811e5efd18bab141afc3fea3de141c3f69f3fe9e898f710ae7fe3aab251"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a673a72aab81fef5ebe755c6e0cc60087d1f3a4700835d40537183c1703a45f"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4547683f330289ec4f093027bfeb87f9ef023b2eb6f879fdc4a8187c7e0ffb90"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cd6175f7eaec7142d2bf6f6597ce7db4c9ac89acf93fcdb17410c3a8b781eeb"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:859ba1acde782609a0910a26a60e16c191a82bf39b5621107552c0cd79fad00f"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcbbef6cc277f8f18f36959e305f10b1c620442d75addc79c21d7073ae581b50"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-win32.whl", hash = "sha256:536b934e244829e3fe6c4f198652cd82da48adb9aa145c9f00889542726dee3d"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-win_amd64.whl", hash = "sha256:0a91aaa3c769b52440df56fafda683b3aa48e3f2169cf7ee5b8c8454a7f3ae9b"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:787e480ca4c1d08c9985a7eb1eae4345c107729c99e9b5a9a00f2575fc7d4b4b"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4d158189eb2ecffea3a51edf6d25e110b3678ec47f1a40f2d541eafbd8f6250"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1e5ca43013e8935f25457a4fca47e315780172c3e821b4b13a890668911c792"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7140d9e5a74a0908493bb4a13f1f16a401297bd755ada4c707e842fbf6f0f5bf"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-win32.whl", hash = "sha256:6cf333625234f247ab357b0bd9836638405ea9082e1543d5b8408f014979dcbf"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ff88712338b01031910e8e61e7239aff3ce8869ee31a47df63cb38aadd591bea"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20813a68d4c221b1849c62c30e1281ea81687894d894b8d4a0f4677d9311e0f5"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:926ef920ae2e8182db31d3f5d081ada57804e3e1d3a8c4ef8b117f9d9fb5a945"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:89f65f69636b7e9c015b79dff9c9985a9bc7d19ded6f79ef9f1ec920fdd73ecf"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f67eae0dbe6f2d7d6ba50a354623d787c99965f068b81e145d53240198021b0"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:98501e075f35dd1a1d5a20f65be26839fcb1938752ec61539af008a5aa6f510b"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3d1d2cc4882e8d6a1adf9d5927d7716f80617fc693385661caff21888972269"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-win32.whl", hash = "sha256:b99a308a2e5e569031ab164b74e6fab0b6f37dfb493c32f7816225f4d411a6dd"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:cdb701eec783d3ec86b7cd4c763adad8eaf6b46db37ee1c36e5e6c44b3fe1b5f"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1e0f9c4d0a6b0af59b613175f019916e28ade076e21242fd5be24340d8a2f64a"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:298f21cc1366eb60311aedba3169d30f885c363ddbf44214b0a587d2908141ad"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3f1ec95aa1e5dab11f37ac7eff190493fd87770f7a8b81ebc9dd768d1a3c8704"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b06b70af54daa4b4904cbb90b4eb6d35c9f3252fdc86c9c32d5afd4d30118d8"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22e37bac44dd6603388cb598c64ff7a76e41ca774646f21c23aadfbf5a2228ab"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0461324897735512a32d222e3d886e24ad6a499761952b6bda2a9ee6e4313ea5"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-win32.whl", hash = "sha256:38aed822fb76435fa1f12185f10465a94ab9e51d5e8a9159e9a540ce926f0ffd"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:d8cf876516548b5a1d6ac4745d8b554f5c07891d55da557925e5c13ff0b4e6ad"},
+ {file = "sentencepiece-0.2.0.tar.gz", hash = "sha256:a52c19171daaf2e697dc6cbe67684e0fa341b1248966f6aebb541de654d15843"},
+]
+
+[[package]]
+name = "setuptools"
+version = "78.1.0"
+description = "Easily download, build, install, upgrade, and uninstall Python packages"
+optional = false
+python-versions = ">=3.9"
+groups = ["main", "dev"]
+files = [
+ {file = "setuptools-78.1.0-py3-none-any.whl", hash = "sha256:3e386e96793c8702ae83d17b853fb93d3e09ef82ec62722e61da5cd22376dcd8"},
+ {file = "setuptools-78.1.0.tar.gz", hash = "sha256:18fd474d4a82a5f83dac888df697af65afa82dec7323d09c3e37d1f14288da54"},
+]
+markers = {main = "python_version >= \"3.12\""}
+
+[package.extras]
+check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"]
+core = ["importlib_metadata (>=6)", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
+cover = ["pytest-cov"]
+doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
+enabler = ["pytest-enabler (>=2.2)"]
+test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
+type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"]
+
+[[package]]
+name = "shellingham"
+version = "1.5.4"
+description = "Tool to Detect Surrounding Shell"
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"},
+ {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"},
+]
+
+[[package]]
+name = "sympy"
+version = "1.13.1"
+description = "Computer algebra system (CAS) in Python"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8"},
+ {file = "sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f"},
+]
+
+[package.dependencies]
+mpmath = ">=1.1.0,<1.4"
+
+[package.extras]
+dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
+
+[[package]]
+name = "threadpoolctl"
+version = "3.6.0"
+description = "threadpoolctl"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb"},
+ {file = "threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e"},
+]
+
+[[package]]
+name = "tokenizers"
+version = "0.21.1"
+description = ""
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "tokenizers-0.21.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:e78e413e9e668ad790a29456e677d9d3aa50a9ad311a40905d6861ba7692cf41"},
+ {file = "tokenizers-0.21.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:cd51cd0a91ecc801633829fcd1fda9cf8682ed3477c6243b9a095539de4aecf3"},
+ {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28da6b72d4fb14ee200a1bd386ff74ade8992d7f725f2bde2c495a9a98cf4d9f"},
+ {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:34d8cfde551c9916cb92014e040806122295a6800914bab5865deb85623931cf"},
+ {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aaa852d23e125b73d283c98f007e06d4595732104b65402f46e8ef24b588d9f8"},
+ {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a21a15d5c8e603331b8a59548bbe113564136dc0f5ad8306dd5033459a226da0"},
+ {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2fdbd4c067c60a0ac7eca14b6bd18a5bebace54eb757c706b47ea93204f7a37c"},
+ {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2dd9a0061e403546f7377df940e866c3e678d7d4e9643d0461ea442b4f89e61a"},
+ {file = "tokenizers-0.21.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:db9484aeb2e200c43b915a1a0150ea885e35f357a5a8fabf7373af333dcc8dbf"},
+ {file = "tokenizers-0.21.1-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:ed248ab5279e601a30a4d67bdb897ecbe955a50f1e7bb62bd99f07dd11c2f5b6"},
+ {file = "tokenizers-0.21.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:9ac78b12e541d4ce67b4dfd970e44c060a2147b9b2a21f509566d556a509c67d"},
+ {file = "tokenizers-0.21.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e5a69c1a4496b81a5ee5d2c1f3f7fbdf95e90a0196101b0ee89ed9956b8a168f"},
+ {file = "tokenizers-0.21.1-cp39-abi3-win32.whl", hash = "sha256:1039a3a5734944e09de1d48761ade94e00d0fa760c0e0551151d4dd851ba63e3"},
+ {file = "tokenizers-0.21.1-cp39-abi3-win_amd64.whl", hash = "sha256:0f0dcbcc9f6e13e675a66d7a5f2f225a736745ce484c1a4e07476a89ccdad382"},
+ {file = "tokenizers-0.21.1.tar.gz", hash = "sha256:a1bb04dc5b448985f86ecd4b05407f5a8d97cb2c0532199b2a302a604a0165ab"},
+]
+
+[package.dependencies]
+huggingface-hub = ">=0.16.4,<1.0"
+
+[package.extras]
+dev = ["tokenizers[testing]"]
+docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
+testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"]
+
+[[package]]
+name = "tomli"
+version = "2.2.1"
+description = "A lil' TOML parser"
+optional = false
+python-versions = ">=3.8"
+groups = ["dev"]
+markers = "python_version < \"3.11\""
+files = [
+ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
+ {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
+ {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a"},
+ {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee"},
+ {file = "tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e"},
+ {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4"},
+ {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106"},
+ {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8"},
+ {file = "tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff"},
+ {file = "tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b"},
+ {file = "tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea"},
+ {file = "tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8"},
+ {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192"},
+ {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222"},
+ {file = "tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77"},
+ {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6"},
+ {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd"},
+ {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e"},
+ {file = "tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98"},
+ {file = "tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4"},
+ {file = "tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7"},
+ {file = "tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c"},
+ {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13"},
+ {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281"},
+ {file = "tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272"},
+ {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140"},
+ {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2"},
+ {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744"},
+ {file = "tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec"},
+ {file = "tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69"},
+ {file = "tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc"},
+ {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"},
+]
+
+[[package]]
+name = "tqdm"
+version = "4.67.1"
+description = "Fast, Extensible Progress Meter"
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"},
+ {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "platform_system == \"Windows\""}
+
+[package.extras]
+dev = ["nbval", "pytest (>=6)", "pytest-asyncio (>=0.24)", "pytest-cov", "pytest-timeout"]
+discord = ["requests"]
+notebook = ["ipywidgets (>=6)"]
+slack = ["slack-sdk"]
+telegram = ["requests"]
+
+[[package]]
+name = "transformers"
+version = "4.49.0"
+description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
+optional = false
+python-versions = ">=3.9.0"
+groups = ["main"]
+files = [
+ {file = "transformers-4.49.0-py3-none-any.whl", hash = "sha256:6b4fded1c5fee04d384b1014495b4235a2b53c87503d7d592423c06128cbbe03"},
+ {file = "transformers-4.49.0.tar.gz", hash = "sha256:7e40e640b5b8dc3f48743f5f5adbdce3660c82baafbd3afdfc04143cdbd2089e"},
+]
+
+[package.dependencies]
+filelock = "*"
+huggingface-hub = ">=0.26.0,<1.0"
+numpy = ">=1.17"
+packaging = ">=20.0"
+pyyaml = ">=5.1"
+regex = "!=2019.12.17"
+requests = "*"
+safetensors = ">=0.4.1"
+tokenizers = ">=0.21,<0.22"
+tqdm = ">=4.27"
+
+[package.extras]
+accelerate = ["accelerate (>=0.26.0)"]
+agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=2.0)"]
+all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision"]
+audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
+benchmark = ["optimum-benchmark (>=0.3.0)"]
+codecarbon = ["codecarbon (>=2.8.1)"]
+deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"]
+deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
+dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
+dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"]
+dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
+flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"]
+flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
+ftfy = ["ftfy"]
+integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"]
+ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"]
+modelcreation = ["cookiecutter (==1.7.3)"]
+natten = ["natten (>=0.14.6,<0.15.0)"]
+onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"]
+onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
+optuna = ["optuna"]
+quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"]
+ray = ["ray[tune] (>=2.7.0)"]
+retrieval = ["datasets (!=2.5.0)", "faiss-cpu"]
+ruff = ["ruff (==0.5.1)"]
+sagemaker = ["sagemaker (>=2.31.0)"]
+sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"]
+serving = ["fastapi", "pydantic", "starlette", "uvicorn"]
+sigopt = ["sigopt"]
+sklearn = ["scikit-learn"]
+speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
+testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
+tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
+tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"]
+tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
+tiktoken = ["blobfile", "tiktoken"]
+timm = ["timm (<=1.0.11)"]
+tokenizers = ["tokenizers (>=0.21,<0.22)"]
+torch = ["accelerate (>=0.26.0)", "torch (>=2.0)"]
+torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
+torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
+torchhub = ["filelock", "huggingface-hub (>=0.26.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "tqdm (>=4.27)"]
+video = ["av"]
+vision = ["Pillow (>=10.0.1,<=15.0)"]
+
+[[package]]
+name = "triton"
+version = "3.2.0"
+description = "A language and compiler for custom Deep Learning operations"
+optional = false
+python-versions = "*"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3e54983cd51875855da7c68ec05c05cf8bb08df361b1d5b69e05e40b0c9bd62"},
+ {file = "triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8009a1fb093ee8546495e96731336a33fb8856a38e45bb4ab6affd6dbc3ba220"},
+ {file = "triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d9b215efc1c26fa7eefb9a157915c92d52e000d2bf83e5f69704047e63f125c"},
+ {file = "triton-3.2.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5dfa23ba84541d7c0a531dfce76d8bcd19159d50a4a8b14ad01e91734a5c1b0"},
+ {file = "triton-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30ceed0eff2c4a73b14eb63e052992f44bbdf175f3fad21e1ac8097a772de7ee"},
+]
+
+[package.extras]
+build = ["cmake (>=3.20)", "lit"]
+tests = ["autopep8", "flake8", "isort", "llnl-hatchet", "numpy", "pytest", "scipy (>=1.7.1)"]
+tutorials = ["matplotlib", "pandas", "tabulate"]
+
+[[package]]
+name = "typer"
+version = "0.15.2"
+description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "typer-0.15.2-py3-none-any.whl", hash = "sha256:46a499c6107d645a9c13f7ee46c5d5096cae6f5fc57dd11eccbbb9ae3e44ddfc"},
+ {file = "typer-0.15.2.tar.gz", hash = "sha256:ab2fab47533a813c49fe1f16b1a370fd5819099c00b119e0633df65f22144ba5"},
+]
+
+[package.dependencies]
+click = ">=8.0.0"
+rich = ">=10.11.0"
+shellingham = ">=1.3.0"
+typing-extensions = ">=3.7.4.3"
+
+[[package]]
+name = "typing-extensions"
+version = "4.13.2"
+description = "Backported and Experimental Type Hints for Python 3.8+"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c"},
+ {file = "typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef"},
+]
+
+[[package]]
+name = "typing-inspection"
+version = "0.4.0"
+description = "Runtime typing introspection tools"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f"},
+ {file = "typing_inspection-0.4.0.tar.gz", hash = "sha256:9765c87de36671694a67904bf2c96e395be9c6439bb6c87b5142569dcdd65122"},
+]
+
+[package.dependencies]
+typing-extensions = ">=4.12.0"
+
+[[package]]
+name = "urllib3"
+version = "2.4.0"
+description = "HTTP library with thread-safe connection pooling, file post, and more."
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "urllib3-2.4.0-py3-none-any.whl", hash = "sha256:4e16665048960a0900c702d4a66415956a584919c03361cac9f1df5c5dd7e813"},
+ {file = "urllib3-2.4.0.tar.gz", hash = "sha256:414bc6535b787febd7567804cc015fee39daab8ad86268f1310a9250697de466"},
+]
+
+[package.extras]
+brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
+h2 = ["h2 (>=4,<5)"]
+socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
+zstd = ["zstandard (>=0.18.0)"]
+
+[[package]]
+name = "win32-setctime"
+version = "1.2.0"
+description = "A small Python utility to set file creation time on Windows"
+optional = false
+python-versions = ">=3.5"
+groups = ["main"]
+markers = "sys_platform == \"win32\""
+files = [
+ {file = "win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390"},
+ {file = "win32_setctime-1.2.0.tar.gz", hash = "sha256:ae1fdf948f5640aae05c511ade119313fb6a30d7eabe25fef9764dca5873c4c0"},
+]
+
+[package.extras]
+dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
+
+[[package]]
+name = "wrapt"
+version = "1.17.2"
+description = "Module for decorators, wrappers and monkey patching."
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "wrapt-1.17.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3d57c572081fed831ad2d26fd430d565b76aa277ed1d30ff4d40670b1c0dd984"},
+ {file = "wrapt-1.17.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5e251054542ae57ac7f3fba5d10bfff615b6c2fb09abeb37d2f1463f841ae22"},
+ {file = "wrapt-1.17.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:80dd7db6a7cb57ffbc279c4394246414ec99537ae81ffd702443335a61dbf3a7"},
+ {file = "wrapt-1.17.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a6e821770cf99cc586d33833b2ff32faebdbe886bd6322395606cf55153246c"},
+ {file = "wrapt-1.17.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b60fb58b90c6d63779cb0c0c54eeb38941bae3ecf7a73c764c52c88c2dcb9d72"},
+ {file = "wrapt-1.17.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b870b5df5b71d8c3359d21be8f0d6c485fa0ebdb6477dda51a1ea54a9b558061"},
+ {file = "wrapt-1.17.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4011d137b9955791f9084749cba9a367c68d50ab8d11d64c50ba1688c9b457f2"},
+ {file = "wrapt-1.17.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:1473400e5b2733e58b396a04eb7f35f541e1fb976d0c0724d0223dd607e0f74c"},
+ {file = "wrapt-1.17.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3cedbfa9c940fdad3e6e941db7138e26ce8aad38ab5fe9dcfadfed9db7a54e62"},
+ {file = "wrapt-1.17.2-cp310-cp310-win32.whl", hash = "sha256:582530701bff1dec6779efa00c516496968edd851fba224fbd86e46cc6b73563"},
+ {file = "wrapt-1.17.2-cp310-cp310-win_amd64.whl", hash = "sha256:58705da316756681ad3c9c73fd15499aa4d8c69f9fd38dc8a35e06c12468582f"},
+ {file = "wrapt-1.17.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ff04ef6eec3eee8a5efef2401495967a916feaa353643defcc03fc74fe213b58"},
+ {file = "wrapt-1.17.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4db983e7bca53819efdbd64590ee96c9213894272c776966ca6306b73e4affda"},
+ {file = "wrapt-1.17.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9abc77a4ce4c6f2a3168ff34b1da9b0f311a8f1cfd694ec96b0603dff1c79438"},
+ {file = "wrapt-1.17.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b929ac182f5ace000d459c59c2c9c33047e20e935f8e39371fa6e3b85d56f4a"},
+ {file = "wrapt-1.17.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f09b286faeff3c750a879d336fb6d8713206fc97af3adc14def0cdd349df6000"},
+ {file = "wrapt-1.17.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a7ed2d9d039bd41e889f6fb9364554052ca21ce823580f6a07c4ec245c1f5d6"},
+ {file = "wrapt-1.17.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:129a150f5c445165ff941fc02ee27df65940fcb8a22a61828b1853c98763a64b"},
+ {file = "wrapt-1.17.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1fb5699e4464afe5c7e65fa51d4f99e0b2eadcc176e4aa33600a3df7801d6662"},
+ {file = "wrapt-1.17.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9a2bce789a5ea90e51a02dfcc39e31b7f1e662bc3317979aa7e5538e3a034f72"},
+ {file = "wrapt-1.17.2-cp311-cp311-win32.whl", hash = "sha256:4afd5814270fdf6380616b321fd31435a462019d834f83c8611a0ce7484c7317"},
+ {file = "wrapt-1.17.2-cp311-cp311-win_amd64.whl", hash = "sha256:acc130bc0375999da18e3d19e5a86403667ac0c4042a094fefb7eec8ebac7cf3"},
+ {file = "wrapt-1.17.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:d5e2439eecc762cd85e7bd37161d4714aa03a33c5ba884e26c81559817ca0925"},
+ {file = "wrapt-1.17.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3fc7cb4c1c744f8c05cd5f9438a3caa6ab94ce8344e952d7c45a8ed59dd88392"},
+ {file = "wrapt-1.17.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8fdbdb757d5390f7c675e558fd3186d590973244fab0c5fe63d373ade3e99d40"},
+ {file = "wrapt-1.17.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bb1d0dbf99411f3d871deb6faa9aabb9d4e744d67dcaaa05399af89d847a91d"},
+ {file = "wrapt-1.17.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d18a4865f46b8579d44e4fe1e2bcbc6472ad83d98e22a26c963d46e4c125ef0b"},
+ {file = "wrapt-1.17.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc570b5f14a79734437cb7b0500376b6b791153314986074486e0b0fa8d71d98"},
+ {file = "wrapt-1.17.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6d9187b01bebc3875bac9b087948a2bccefe464a7d8f627cf6e48b1bbae30f82"},
+ {file = "wrapt-1.17.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9e8659775f1adf02eb1e6f109751268e493c73716ca5761f8acb695e52a756ae"},
+ {file = "wrapt-1.17.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e8b2816ebef96d83657b56306152a93909a83f23994f4b30ad4573b00bd11bb9"},
+ {file = "wrapt-1.17.2-cp312-cp312-win32.whl", hash = "sha256:468090021f391fe0056ad3e807e3d9034e0fd01adcd3bdfba977b6fdf4213ea9"},
+ {file = "wrapt-1.17.2-cp312-cp312-win_amd64.whl", hash = "sha256:ec89ed91f2fa8e3f52ae53cd3cf640d6feff92ba90d62236a81e4e563ac0e991"},
+ {file = "wrapt-1.17.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6ed6ffac43aecfe6d86ec5b74b06a5be33d5bb9243d055141e8cabb12aa08125"},
+ {file = "wrapt-1.17.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:35621ae4c00e056adb0009f8e86e28eb4a41a4bfa8f9bfa9fca7d343fe94f998"},
+ {file = "wrapt-1.17.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a604bf7a053f8362d27eb9fefd2097f82600b856d5abe996d623babd067b1ab5"},
+ {file = "wrapt-1.17.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cbabee4f083b6b4cd282f5b817a867cf0b1028c54d445b7ec7cfe6505057cf8"},
+ {file = "wrapt-1.17.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:49703ce2ddc220df165bd2962f8e03b84c89fee2d65e1c24a7defff6f988f4d6"},
+ {file = "wrapt-1.17.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8112e52c5822fc4253f3901b676c55ddf288614dc7011634e2719718eaa187dc"},
+ {file = "wrapt-1.17.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9fee687dce376205d9a494e9c121e27183b2a3df18037f89d69bd7b35bcf59e2"},
+ {file = "wrapt-1.17.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:18983c537e04d11cf027fbb60a1e8dfd5190e2b60cc27bc0808e653e7b218d1b"},
+ {file = "wrapt-1.17.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:703919b1633412ab54bcf920ab388735832fdcb9f9a00ae49387f0fe67dad504"},
+ {file = "wrapt-1.17.2-cp313-cp313-win32.whl", hash = "sha256:abbb9e76177c35d4e8568e58650aa6926040d6a9f6f03435b7a522bf1c487f9a"},
+ {file = "wrapt-1.17.2-cp313-cp313-win_amd64.whl", hash = "sha256:69606d7bb691b50a4240ce6b22ebb319c1cfb164e5f6569835058196e0f3a845"},
+ {file = "wrapt-1.17.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:4a721d3c943dae44f8e243b380cb645a709ba5bd35d3ad27bc2ed947e9c68192"},
+ {file = "wrapt-1.17.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:766d8bbefcb9e00c3ac3b000d9acc51f1b399513f44d77dfe0eb026ad7c9a19b"},
+ {file = "wrapt-1.17.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e496a8ce2c256da1eb98bd15803a79bee00fc351f5dfb9ea82594a3f058309e0"},
+ {file = "wrapt-1.17.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40d615e4fe22f4ad3528448c193b218e077656ca9ccb22ce2cb20db730f8d306"},
+ {file = "wrapt-1.17.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a5aaeff38654462bc4b09023918b7f21790efb807f54c000a39d41d69cf552cb"},
+ {file = "wrapt-1.17.2-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a7d15bbd2bc99e92e39f49a04653062ee6085c0e18b3b7512a4f2fe91f2d681"},
+ {file = "wrapt-1.17.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:e3890b508a23299083e065f435a492b5435eba6e304a7114d2f919d400888cc6"},
+ {file = "wrapt-1.17.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:8c8b293cd65ad716d13d8dd3624e42e5a19cc2a2f1acc74b30c2c13f15cb61a6"},
+ {file = "wrapt-1.17.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4c82b8785d98cdd9fed4cac84d765d234ed3251bd6afe34cb7ac523cb93e8b4f"},
+ {file = "wrapt-1.17.2-cp313-cp313t-win32.whl", hash = "sha256:13e6afb7fe71fe7485a4550a8844cc9ffbe263c0f1a1eea569bc7091d4898555"},
+ {file = "wrapt-1.17.2-cp313-cp313t-win_amd64.whl", hash = "sha256:eaf675418ed6b3b31c7a989fd007fa7c3be66ce14e5c3b27336383604c9da85c"},
+ {file = "wrapt-1.17.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5c803c401ea1c1c18de70a06a6f79fcc9c5acfc79133e9869e730ad7f8ad8ef9"},
+ {file = "wrapt-1.17.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f917c1180fdb8623c2b75a99192f4025e412597c50b2ac870f156de8fb101119"},
+ {file = "wrapt-1.17.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ecc840861360ba9d176d413a5489b9a0aff6d6303d7e733e2c4623cfa26904a6"},
+ {file = "wrapt-1.17.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb87745b2e6dc56361bfde481d5a378dc314b252a98d7dd19a651a3fa58f24a9"},
+ {file = "wrapt-1.17.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:58455b79ec2661c3600e65c0a716955adc2410f7383755d537584b0de41b1d8a"},
+ {file = "wrapt-1.17.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4e42a40a5e164cbfdb7b386c966a588b1047558a990981ace551ed7e12ca9c2"},
+ {file = "wrapt-1.17.2-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:91bd7d1773e64019f9288b7a5101f3ae50d3d8e6b1de7edee9c2ccc1d32f0c0a"},
+ {file = "wrapt-1.17.2-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:bb90fb8bda722a1b9d48ac1e6c38f923ea757b3baf8ebd0c82e09c5c1a0e7a04"},
+ {file = "wrapt-1.17.2-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:08e7ce672e35efa54c5024936e559469436f8b8096253404faeb54d2a878416f"},
+ {file = "wrapt-1.17.2-cp38-cp38-win32.whl", hash = "sha256:410a92fefd2e0e10d26210e1dfb4a876ddaf8439ef60d6434f21ef8d87efc5b7"},
+ {file = "wrapt-1.17.2-cp38-cp38-win_amd64.whl", hash = "sha256:95c658736ec15602da0ed73f312d410117723914a5c91a14ee4cdd72f1d790b3"},
+ {file = "wrapt-1.17.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:99039fa9e6306880572915728d7f6c24a86ec57b0a83f6b2491e1d8ab0235b9a"},
+ {file = "wrapt-1.17.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2696993ee1eebd20b8e4ee4356483c4cb696066ddc24bd70bcbb80fa56ff9061"},
+ {file = "wrapt-1.17.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:612dff5db80beef9e649c6d803a8d50c409082f1fedc9dbcdfde2983b2025b82"},
+ {file = "wrapt-1.17.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62c2caa1585c82b3f7a7ab56afef7b3602021d6da34fbc1cf234ff139fed3cd9"},
+ {file = "wrapt-1.17.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c958bcfd59bacc2d0249dcfe575e71da54f9dcf4a8bdf89c4cb9a68a1170d73f"},
+ {file = "wrapt-1.17.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc78a84e2dfbc27afe4b2bd7c80c8db9bca75cc5b85df52bfe634596a1da846b"},
+ {file = "wrapt-1.17.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ba0f0eb61ef00ea10e00eb53a9129501f52385c44853dbd6c4ad3f403603083f"},
+ {file = "wrapt-1.17.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:1e1fe0e6ab7775fd842bc39e86f6dcfc4507ab0ffe206093e76d61cde37225c8"},
+ {file = "wrapt-1.17.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c86563182421896d73858e08e1db93afdd2b947a70064b813d515d66549e15f9"},
+ {file = "wrapt-1.17.2-cp39-cp39-win32.whl", hash = "sha256:f393cda562f79828f38a819f4788641ac7c4085f30f1ce1a68672baa686482bb"},
+ {file = "wrapt-1.17.2-cp39-cp39-win_amd64.whl", hash = "sha256:36ccae62f64235cf8ddb682073a60519426fdd4725524ae38874adf72b5f2aeb"},
+ {file = "wrapt-1.17.2-py3-none-any.whl", hash = "sha256:b18f2d1533a71f069c7f82d524a52599053d4c7166e9dd374ae2136b7f40f7c8"},
+ {file = "wrapt-1.17.2.tar.gz", hash = "sha256:41388e9d4d1522446fe79d3213196bd9e3b301a336965b9e27ca2788ebd122f3"},
+]
+
+[[package]]
+name = "zipp"
+version = "3.21.0"
+description = "Backport of pathlib-compatible object wrapper for zip files"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931"},
+ {file = "zipp-3.21.0.tar.gz", hash = "sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4"},
+]
+
+[package.extras]
+check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"]
+cover = ["pytest-cov"]
+doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+enabler = ["pytest-enabler (>=2.2)"]
+test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"]
+type = ["pytest-mypy"]
+
+[metadata]
+lock-version = "2.1"
+python-versions = ">=3.9,<3.13"
+content-hash = "cb3921d3df77dd5a7c9c7f09fcdf4f4f61b307b04e5bd58c52cfb299ae053da3"
diff --git a/backends/gaudi/server/pyproject.toml b/backends/gaudi/server/pyproject.toml
new file mode 100644
index 00000000000..fa2c2697328
--- /dev/null
+++ b/backends/gaudi/server/pyproject.toml
@@ -0,0 +1,44 @@
+[tool.poetry]
+name = "text-generation-server"
+version = "2.0.4"
+description = "Text Generation Inference Python gRPC Server"
+authors = ["Olivier Dehaene "]
+
+[tool.poetry.scripts]
+text-generation-server = 'text_generation_server.cli:app'
+
+[tool.poetry.dependencies]
+python = ">=3.9,<3.13"
+protobuf = "^5.0"
+grpcio = "^1.71.1"
+grpcio-status = "*"
+grpcio-reflection = "*"
+grpc-interceptor = "^0.15.0"
+typer = "^0.15.0"
+loguru = "^0.7.3"
+opentelemetry-api = "^1.32.0"
+opentelemetry-exporter-otlp = "^1.32.0"
+opentelemetry-instrumentation-grpc = "^0.53b0"
+hf-transfer = "^0.1.9"
+sentencepiece = "^0.2.0"
+peft = "^0.15"
+transformers = "^4.52.4"
+numpy = "^1.26"
+accelerate = "^1.7.0"
+outlines = { version = "^0.0.36", optional = true }
+prometheus-client = "^0.21.1"
+py-cpuinfo = "^9.0.0"
+
+[tool.poetry.group.dev.dependencies]
+grpcio-tools = "*"
+pytest = "^8.3.5"
+
+[tool.pytest.ini_options]
+markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry.requires-plugins]
+poetry-plugin-export = ">=1.8"
diff --git a/backends/gaudi/server/requirements.txt b/backends/gaudi/server/requirements.txt
new file mode 100644
index 00000000000..e6c9abf2a0b
--- /dev/null
+++ b/backends/gaudi/server/requirements.txt
@@ -0,0 +1,86 @@
+accelerate==1.7.0 ; python_version >= "3.9" and python_version < "3.13"
+annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
+attrs==25.3.0 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2025.1.31 ; python_version >= "3.9" and python_version < "3.13"
+charset-normalizer==3.4.1 ; python_version >= "3.9" and python_version < "3.13"
+click==8.1.8 ; python_version >= "3.9" and python_version < "3.13"
+cloudpickle==3.1.1 ; python_version >= "3.9" and python_version < "3.13"
+colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Windows" or python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
+deprecated==1.2.18 ; python_version >= "3.9" and python_version < "3.13"
+diffusers==0.31.0 ; python_version >= "3.9" and python_version < "3.13"
+diskcache==5.6.3 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.18.0 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2025.3.2 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.70.0 ; python_version >= "3.9" and python_version < "3.13"
+grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.72.0rc1 ; python_version >= "3.9" and python_version < "3.13"
+hf-transfer==0.1.9 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.30.2 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
+importlib-metadata==8.6.1 ; python_version >= "3.9" and python_version < "3.13"
+interegular==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
+jinja2==3.1.6 ; python_version >= "3.9" and python_version < "3.13"
+joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13"
+jsonschema-specifications==2024.10.1 ; python_version >= "3.9" and python_version < "3.13"
+jsonschema==4.23.0 ; python_version >= "3.9" and python_version < "3.13"
+lark==1.2.2 ; python_version >= "3.9" and python_version < "3.13"
+llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13"
+loguru==0.7.3 ; python_version >= "3.9" and python_version < "3.13"
+markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
+markupsafe==3.0.2 ; python_version >= "3.9" and python_version < "3.13"
+mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
+mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
+nest-asyncio==1.6.0 ; python_version >= "3.9" and python_version < "3.13"
+networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
+numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
+numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-api==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-common==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-grpc==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-http==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation-grpc==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-proto==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-sdk==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-semantic-conventions==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
+optimum==1.24.0 ; python_version >= "3.9" and python_version < "3.13"
+outlines==0.0.36 ; python_version >= "3.9" and python_version < "3.13"
+packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
+peft==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
+pillow==11.2.1 ; python_version >= "3.9" and python_version < "3.13"
+prometheus-client==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
+protobuf==5.29.4 ; python_version >= "3.9" and python_version < "3.13"
+psutil==7.0.0 ; python_version >= "3.9" and python_version < "3.13"
+py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
+pydantic-core==2.33.1 ; python_version >= "3.9" and python_version < "3.13"
+pydantic==2.11.3 ; python_version >= "3.9" and python_version < "3.13"
+pygments==2.19.1 ; python_version >= "3.9" and python_version < "3.13"
+pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
+referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.13"
+regex==2024.11.6 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
+rich==14.0.0 ; python_version >= "3.9" and python_version < "3.13"
+rpds-py==0.24.0 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.5.3 ; python_version >= "3.9" and python_version < "3.13"
+scikit-learn==1.6.1 ; python_version >= "3.9" and python_version < "3.13"
+scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
+sentence-transformers==3.3.1 ; python_version >= "3.9" and python_version < "3.13"
+sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==78.1.0 ; python_version >= "3.12" and python_version < "3.13"
+shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13"
+sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
+threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
+tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.52.4 ; python_version >= "3.9" and python_version < "3.13"
+triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
+typer==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.13.2 ; python_version >= "3.9" and python_version < "3.13"
+typing-inspection==0.4.0 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.4.0 ; python_version >= "3.9" and python_version < "3.13"
+win32-setctime==1.2.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
+wrapt==1.17.2 ; python_version >= "3.9" and python_version < "3.13"
+zipp==3.21.0 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/.devcontainer/Dockerfile.trtllm b/backends/gaudi/server/text_generation_server/__init__.py
similarity index 100%
rename from .devcontainer/Dockerfile.trtllm
rename to backends/gaudi/server/text_generation_server/__init__.py
diff --git a/backends/gaudi/server/text_generation_server/adapters/__init__.py b/backends/gaudi/server/text_generation_server/adapters/__init__.py
new file mode 100644
index 00000000000..8697cb9ee98
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/adapters/__init__.py
@@ -0,0 +1,13 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/__init__.py
+# License: Apache License Version 2.0, January 2004
+
+from text_generation_server.adapters.weights import (
+ AdapterBatchData,
+ AdapterBatchMetadata,
+)
+
+__all__ = [
+ "AdapterBatchData",
+ "AdapterBatchMetadata",
+]
diff --git a/backends/gaudi/server/text_generation_server/adapters/config.py b/backends/gaudi/server/text_generation_server/adapters/config.py
new file mode 100644
index 00000000000..b7e27090018
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/adapters/config.py
@@ -0,0 +1,30 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/config.py
+# License: Apache License Version 2.0, January 2004
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Dict, Set, Tuple
+
+import torch
+
+from text_generation_server.adapters.weights import AdapterWeights
+
+
+@dataclass
+class ModuleMap:
+ module_name: str
+ module_weights: Dict[str, Tuple[torch.Tensor, str]]
+
+
+@dataclass
+class AdapterConfig(ABC):
+ base_model_name_or_path: str
+
+ @abstractmethod
+ def map_weights_for_model(
+ self,
+ adapter_weights: Dict[int, AdapterWeights],
+ weight_names: Tuple[str],
+ ) -> Tuple[ModuleMap, Set[str]]:
+ pass
diff --git a/backends/gaudi/server/text_generation_server/adapters/lora.py b/backends/gaudi/server/text_generation_server/adapters/lora.py
new file mode 100644
index 00000000000..a00338e7ca5
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/adapters/lora.py
@@ -0,0 +1,471 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/lora.py
+# License: Apache License Version 2.0, January 2004
+
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Set, Tuple, Type, Union
+
+import torch
+from peft import LoraConfig as _LoraConfig
+from torch.distributed import ProcessGroup
+
+from text_generation_server.adapters.config import AdapterConfig, ModuleMap
+
+from text_generation_server.adapters.weights import (
+ AdapterBatchMetadata,
+ AdapterWeights,
+ BatchAdapterWeights,
+)
+from text_generation_server.utils.sgmv import (
+ BGMV_MAX_RANK,
+ MAX_RANK_CUSTOM,
+ get_tmp_tensors,
+ orient_for_rank,
+ pad_rank,
+ use_cutlass_shrink,
+)
+
+
+def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
+ block_size = size // world_size
+ start = offset + rank * block_size
+ stop = offset + (rank + 1) * block_size
+ return start, stop
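+ # Worked example (illustrative, not part of the original change): with offset=0,
+ # size=1024 and world_size=4, block_size is 1024 // 4 = 256, so rank 2 owns the
+ # half-open range [512, 768).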
+
+
+def shard_on_dim(
+ t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
+):
+ world_size = process_group.size()
+ rank = process_group.rank()
+
+ size = t.shape[dim]
+ start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
+
+ if dim == 0:
+ tensor = t[start:stop]
+ elif dim == 1:
+ tensor = t[:, start:stop]
+ else:
+ raise NotImplementedError("Let's make that generic when needed")
+
+ return tensor
+
+
+def shard_lora_weights(
+ weights_a: List[torch.Tensor],
+ weights_b: List[torch.Tensor],
+ split_dim: int,
+ process_group: ProcessGroup,
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+ # [hidden_size, r]
+ weights_a = [
+ shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
+ ]
+
+ # [r, hidden_size]
+ weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
+
+ return weights_a, weights_b
+
+
+@dataclass
+class LoraConfig(AdapterConfig):
+ r: int
+ target_modules: Optional[Union[List[str], str]]
+ fan_in_fan_out: bool
+ lora_alpha: int
+ use_rslora: bool
+
+ def map_weights_for_model(
+ self,
+ adapter_weights: Dict[int, AdapterWeights],
+ weight_names: Tuple[str],
+ ) -> Tuple[ModuleMap, Set[str]]:
+ adapter_weight_names = set()
+ module_map = {}
+ for weight_name in weight_names:
+ lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
+ lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
+ if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
+ continue
+
+ module_map[weight_name] = {
+ "lora_A": (adapter_weights[lora_a_name], lora_a_name),
+ "lora_B": (adapter_weights[lora_b_name], lora_b_name),
+ }
+ adapter_weight_names.add(lora_a_name)
+ adapter_weight_names.add(lora_b_name)
+ return module_map, adapter_weight_names
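+ # Illustrative example (not part of the original change): for a weight name such
+ # as "model.layers.0.self_attn.q_proj", module_map maps it to its
+ # "base_model.model.<name>.lora_A.weight" / "...lora_B.weight" tensors, and both
+ # fully qualified names are added to adapter_weight_names.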
+
+ @classmethod
+ def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
+ hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
+ return cls(
+ base_model_name_or_path=hf_config.base_model_name_or_path,
+ r=hf_config.r,
+ target_modules=hf_config.target_modules,
+ fan_in_fan_out=hf_config.fan_in_fan_out,
+ lora_alpha=hf_config.lora_alpha,
+ use_rslora=(
+ hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
+ ),
+ )
+
+
+class LoraWeights(AdapterWeights):
+ """LoRA weights for a single adapter merged across all layers."""
+
+ def __init__(
+ self,
+ weights_a: List[torch.Tensor],
+ weights_b: List[torch.Tensor],
+ adapter_config: LoraConfig,
+ ):
+ self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
+ self.lora_b_r = weights_b[0].size(0) if len(weights_b) > 0 else 1
+
+ self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
+ self._is_transposed = False
+
+ # [num_layers, hidden_size, r]
+ weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
+ self._weights_a = torch.stack(weights_a)
+
+ # [num_layers, r, hidden_size]
+ self._weights_b = torch.stack(weights_b)
+
+ self.adapter_config = adapter_config
+
+ @property
+ def weights_a(self) -> torch.Tensor:
+ if self._is_transposed:
+ self._transpose_weights()
+ return self._weights_a
+
+ @property
+ def weights_b(self) -> torch.Tensor:
+ if self._is_transposed:
+ self._transpose_weights()
+ return self._weights_b
+
+ @property
+ def weights_a_t(self) -> torch.Tensor:
+ if not self._is_transposed:
+ self._transpose_weights()
+ return self._weights_a
+
+ @property
+ def weights_b_t(self) -> torch.Tensor:
+ if not self._is_transposed:
+ self._transpose_weights()
+ return self._weights_b
+
+ def _transpose_weights(self):
+ if self._use_cutlass_shrink:
+ # If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation
+ self._weights_a = self._weights_a.transpose(1, 2).contiguous()
+ self._weights_b = self._weights_b.transpose(1, 2).contiguous()
+ self._is_transposed = not self._is_transposed
+
+ @classmethod
+ def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
+ return [BatchLoraWeights]
+
+ # prepare pre-loaded lora weights for use in the model.
+ #
+ # this method processes and organizes lora weights for a specific layer type across all layers:
+ # - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor.
+ # - retrieves weights from `module_map` based on the `layer_type`.
+ # - processes `nlayers` number of layers.
+ # - converts weights to the specified `dtype`.
+ # - shards weights across `world_size` number of processes using the `process_group`.
+ # - maps weights to specific layers using `target_to_layer`.
+ # - tracks `unused_weight_names` to identify any unused weights.
+ #
+ # the method handles weight transposition, scaling, and padding to ensure compatibility
+ # with SGMV or BGMV operations.
+ @classmethod
+ def prepare_weights(
+ cls,
+ config: LoraConfig,
+ module_map: Dict[str, Dict],
+ layer_type: str,
+ unused_weight_names: Set[str],
+ nlayers: int,
+ dtype: torch.dtype,
+ world_size: int,
+ process_group: ProcessGroup,
+ target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
+ ) -> Optional[AdapterWeights]:
+ lora_a_list = [None] * nlayers
+ lora_b_list = [None] * nlayers
+
+ for layer_id in range(nlayers):
+ key = (layer_id, layer_type)
+ weight_name, layer = target_to_layer[key]
+ base_weight = layer.base_layer.linear.weight
+ base_device = base_weight.device
+
+ if weight_name not in module_map:
+ # There is no LoRA weight for this layer type in the adapter
+ return None
+
+ lora_a, lora_a_name = module_map[weight_name]["lora_A"]
+ lora_a = lora_a.to(base_device, dtype)
+
+ lora_b, lora_b_name = module_map[weight_name]["lora_B"]
+ lora_b = lora_b.to(base_device, dtype)
+
+ scale = get_scaling_factor(
+ config.lora_alpha,
+ config.r,
+ uses_rslora=config.use_rslora,
+ )
+
+ unused_weight_names.discard(lora_a_name)
+ unused_weight_names.discard(lora_b_name)
+
+ # Merge scaling factor into lora_b due to associativity of matrix multiplication:
+ # (A * B) * C = A * (B * C)
+ lora_a_list[layer_id] = lora_a.transpose(0, 1)
+ lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
+
+ # pad lora ranks to be compatible with sgmv
+ lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
+ lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]
+
+ if lora_a_list:
+ # update rank if it was padded
+ padded_rank = lora_a_list[0].size(1)
+ config.r = padded_rank
+
+ return LoraWeights(
+ *shard_lora_weights(
+ weights_a=lora_a_list,
+ weights_b=lora_b_list,
+ split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
+ process_group=process_group,
+ ),
+ config,
+ )
+
+
+@dataclass
+class RankSegments:
+ rank: int
+
+ lora_a_ptr: torch.Tensor
+ lora_b_ptr: torch.Tensor
+
+ # prefill (sgmv)
+ tmp_shrink: torch.Tensor
+ tmp_expand: torch.Tensor
+ segment_starts: torch.Tensor
+ segment_ends: torch.Tensor
+
+ # decode (bgmv)
+ indices: torch.Tensor
+
+
+@dataclass
+class BatchLoraWeights(BatchAdapterWeights):
+ lora_a: Dict[int, torch.Tensor]
+ lora_b: Dict[int, torch.Tensor]
+ adapter_index_configs: Dict[int, LoraConfig]
+ rank_data: Dict[int, RankSegments]
+ use_sgmv: bool
+
+ def has_adapter(self, adapter_index: int) -> bool:
+ return adapter_index in self.adapter_index_configs
+
+ def can_vectorize(self, pg: ProcessGroup) -> bool:
+ return all(
+ rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
+ for rank_data in self.rank_data.values()
+ )
+
+ @classmethod
+ def load(
+ cls,
+ adapter_weights: Dict[int, AdapterWeights],
+ meta: AdapterBatchMetadata,
+ prefill: bool,
+ prefill_head_indices: Optional[torch.Tensor],
+ ) -> Optional["BatchLoraWeights"]:
+ adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
+ adapter_weights = {
+ k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
+ }
+ if not adapter_weights:
+ return None
+
+ first_weights = next(iter(adapter_weights.values()))
+ device = first_weights.weights_a.device
+ segment_indices = meta.segment_indices
+
+ lora_a = {
+ idx: adapter_weights[idx].weights_a
+ for idx in segment_indices
+ if idx in adapter_weights
+ }
+ lora_b = {
+ idx: adapter_weights[idx].weights_b
+ for idx in segment_indices
+ if idx in adapter_weights
+ }
+
+ max_rank = max(
+ (
+ adapter_weights[idx].lora_a_r
+ for idx in segment_indices
+ if idx in adapter_weights
+ ),
+ default=0,
+ )
+
+ if prefill or max_rank > BGMV_MAX_RANK:
+ use_sgmv = True
+ lora_a_ptr = torch.tensor(
+ [
+ (
+ adapter_weights[idx].weights_a.data_ptr()
+ if idx in adapter_weights
+ else 0
+ )
+ for idx in segment_indices
+ ],
+ dtype=torch.int64,
+ device=device,
+ )
+ lora_b_ptr = torch.tensor(
+ [
+ (
+ adapter_weights[idx].weights_b.data_ptr()
+ if idx in adapter_weights
+ else 0
+ )
+ for idx in segment_indices
+ ],
+ dtype=torch.int64,
+ device=device,
+ )
+ else:
+ use_sgmv = False
+ lora_a_ptr = torch.tensor(
+ [
+ (
+ adapter_weights[idx].weights_a_t.data_ptr()
+ if idx in adapter_weights
+ else 0
+ )
+ for idx in segment_indices
+ ],
+ dtype=torch.int64,
+ device=device,
+ )
+ lora_b_ptr = torch.tensor(
+ [
+ (
+ adapter_weights[idx].weights_b_t.data_ptr()
+ if idx in adapter_weights
+ else 0
+ )
+ for idx in segment_indices
+ ],
+ dtype=torch.int64,
+ device=device,
+ )
+
+ adapter_index_configs = {
+ idx: adapter_weights[idx].adapter_config
+ for idx in segment_indices
+ if idx in adapter_weights
+ }
+
+ adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
+
+ rank_indices = defaultdict(list)
+ for segment_idx, adapter_idx in enumerate(segment_indices):
+ if adapter_idx not in adapter_weights:
+ continue
+ rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
+
+ if prefill_head_indices is not None:
+ j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
+ for head_index in prefill_head_indices:
+ # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
+ if head_index < meta.adapter_segments[j]:
+ prefill_head_segment_ends[-1] += 1
+ else:
+ prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
+ prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
+ j += 1
+
+ rank_data = {}
+ for rank, indices in rank_indices.items():
+ tmp_shrink = None
+ tmp_expand = None
+ segment_starts = None
+ segment_ends = None
+ batch_indices = None
+
+ if use_sgmv:
+ lora_a_ptr_indices = lora_a_ptr[indices]
+ tmp_shrink, tmp_expand = get_tmp_tensors(
+ lora_a_ptr_indices.size(0), rank, device
+ )
+ segment_starts = meta.adapter_segments[indices]
+ segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
+ if prefill_head_indices is not None:
+ for i, segment_index in enumerate(indices):
+ segment_starts[i] = prefill_head_segment_starts[segment_index]
+ segment_ends[i] = prefill_head_segment_ends[segment_index]
+ else:
+ rank_indices = set(indices)
+ batch_indices = [
+ adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
+ ]
+ batch_indices = [
+ idx if idx in rank_indices else -1 for idx in batch_indices
+ ]
+ batch_indices = torch.tensor(
+ batch_indices, dtype=torch.int64, device=device
+ )
+
+ rank_data[rank] = RankSegments(
+ rank=rank,
+ tmp_shrink=tmp_shrink,
+ tmp_expand=tmp_expand,
+ lora_a_ptr=lora_a_ptr[indices],
+ lora_b_ptr=lora_b_ptr[indices],
+ segment_starts=segment_starts,
+ segment_ends=segment_ends,
+ indices=batch_indices,
+ )
+
+ return BatchLoraWeights(
+ lora_a=lora_a,
+ lora_b=lora_b,
+ adapter_index_configs=adapter_index_configs,
+ rank_data=rank_data,
+ use_sgmv=use_sgmv,
+ )
+
+
+def get_scaling_factor(
+ lora_alpha: int,
+ r: int,
+ uses_rslora: bool = False,
+) -> float:
+ """Computes the scaling factor for the lora weights."""
+ if uses_rslora:
+ return lora_alpha / (r**0.5)
+ return lora_alpha / r
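+ # Worked example (illustrative): lora_alpha=16, r=8 gives 16 / 8 = 2.0 without
+ # rsLoRA, and 16 / sqrt(8) ≈ 5.66 when uses_rslora=True.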
+
+
+def _convert_lora(v: AdapterWeights) -> AdapterWeights:
+ if hasattr(v, "lora_weights"):
+ return v.lora_weights
+ return v
diff --git a/backends/gaudi/server/text_generation_server/adapters/weights.py b/backends/gaudi/server/text_generation_server/adapters/weights.py
new file mode 100644
index 00000000000..da75dbcdf98
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/adapters/weights.py
@@ -0,0 +1,146 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/weights.py
+# License: Apache License Version 2.0, January 2004
+
+from abc import ABC, abstractclassmethod
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Set, Type
+
+import torch
+
+
+@dataclass
+class AdapterBatchMetadata:
+ # [batch_size]
+ adapter_indices: torch.Tensor
+
+ # [num_adapters]
+ adapter_set: Set[int]
+
+ # [num_segments + 1]
+ adapter_segments: torch.Tensor
+
+ # [num_segments]
+ # maps from segment index to adapter index, i.e.:
+ # segment_indices[s] == adapter_indices[i]
+ segment_indices: List[int]
+
+
+class AdapterWeights(ABC):
+ @abstractclassmethod
+ def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
+ pass
+
+ @property
+ def speculative_tokens(self) -> int:
+ return 0
+
+
+class BatchAdapterWeights(ABC):
+ @abstractclassmethod
+ def has_adapter(self, adapter_index: int) -> bool:
+ pass
+
+ @abstractclassmethod
+ def load(
+ cls,
+ adapter_weights: Dict[int, AdapterWeights],
+ meta: "AdapterBatchMetadata",
+ prefill: bool,
+ prefill_head_indices: torch.Tensor,
+ ) -> Optional["BatchAdapterWeights"]:
+ pass
+
+
+class LayerAdapterWeights:
+ """Adapter weights that apply to a particular layer."""
+
+ def __init__(self):
+ self.adapter_weights: Dict[int, AdapterWeights] = {}
+
+ def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
+ self.adapter_weights[adapter_idx] = weights
+
+ def remove_adapter(self, adapter_idx: int):
+ if adapter_idx not in self.adapter_weights:
+ return
+ del self.adapter_weights[adapter_idx]
+
+ def is_empty(self) -> bool:
+ return len(self.adapter_weights) == 0
+
+ def get_data(
+ self,
+ meta: AdapterBatchMetadata,
+ prefill: bool,
+ prefill_head_indices: Optional[torch.Tensor],
+ ) -> Dict[str, BatchAdapterWeights]:
+ # bucket adapters by batch class
+ adapter_batch_types: Dict[
+ Type[BatchAdapterWeights], Dict[int, AdapterWeights]
+ ] = defaultdict(dict)
+ for adapter_index, adapter_weights in self.adapter_weights.items():
+ for batch_type in adapter_weights.get_batch_types():
+ adapter_batch_types[batch_type][adapter_index] = adapter_weights
+
+ batch_data = {}
+ for batch_type, adapter_weights in adapter_batch_types.items():
+ batched_weights = batch_type.load(
+ adapter_weights, meta, prefill, prefill_head_indices
+ )
+ if batched_weights is not None:
+ batch_data = batched_weights
+ return batch_data
+
+
+@dataclass
+class AdapterBatchData:
+ meta: AdapterBatchMetadata
+
+ # layer type -> adapter type -> batch weight data
+ data: Dict[str, Dict[str, BatchAdapterWeights]]
+
+ prefill: bool
+
+ @staticmethod
+ def from_meta(
+ meta: AdapterBatchMetadata,
+ weights: Dict[str, LayerAdapterWeights],
+ prefill: bool,
+ prefill_head_indices: Optional[torch.Tensor],
+ ) -> "AdapterBatchData":
+ data = {}
+ for k, v in weights.items():
+ if v.is_empty():
+ continue
+ data[k] = v.get_data(
+ meta, prefill, prefill_head_indices if k == "lm_head" else None
+ )
+ return AdapterBatchData(meta=meta, data=data, prefill=prefill)
+
+ def ranks(self) -> Set[int]:
+ # TODO(travis): refactor to be less coupled to lora implementation
+ ranks = set()
+ for lora_data in self.data.values():
+ if lora_data is None:
+ continue
+
+ for rank_data in lora_data.rank_data.values():
+ ranks.add(rank_data.rank)
+
+ return ranks
+
+ def layer_names(self) -> Set[str]:
+ return set(self.data.keys())
+
+ def adapter_keys(self) -> Set[str]:
+ adapter_keys = set()
+ for layer_data in self.data.values():
+ adapter_keys.update(layer_data.keys())
+ return adapter_keys
+
+ @property
+ def max_rank(self) -> int:
+ ranks = self.ranks()
+ return max(ranks) if len(ranks) > 0 else 0
diff --git a/backends/gaudi/server/text_generation_server/cache.py b/backends/gaudi/server/text_generation_server/cache.py
new file mode 100644
index 00000000000..4504733e51e
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/cache.py
@@ -0,0 +1,34 @@
+import torch
+
+from typing import Dict, Optional, TypeVar
+
+from text_generation_server.models.types import Batch
+
+B = TypeVar("B", bound=Batch)
+
+
+class Cache:
+ def __init__(self):
+ self.cache: Dict[int, B] = {}
+
+ def pop(self, batch_id: int) -> Optional[B]:
+ return self.cache.pop(batch_id, None)
+
+ def set(self, entry: B):
+ if entry is not None:
+ self.cache[entry.batch_id] = entry
+
+ def delete(self, batch_id: int):
+ batch = self.pop(batch_id)
+ if batch is not None:
+ del batch
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ def clear(self):
+ keys = list(self.cache.keys())
+ for k in keys:
+ self.delete(k)
+
+ def __len__(self):
+ return len(self.cache.keys())
diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py
new file mode 100644
index 00000000000..dc31ab2fd6d
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/cli.py
@@ -0,0 +1,373 @@
+import os
+import sys
+import typer
+
+from pathlib import Path
+from loguru import logger
+from typing import Optional
+from enum import Enum
+from huggingface_hub import hf_hub_download
+from text_generation_server.utils.adapter import parse_lora_adapters
+
+
+app = typer.Typer()
+
+
+class Quantization(str, Enum):
+ gptq = "gptq"
+ awq = "awq"
+ fp8 = "fp8"
+ compressed_tensors = "compressed-tensors"
+
+
+class Dtype(str, Enum):
+ float16 = "float16"
+ bfloat16 = "bfloat16"
+
+
+class KVCacheDtype(str, Enum):
+ fp8_e4m3fn = "fp8_e4m3fn"
+ fp8_e5m2 = "fp8_e5m2"
+
+
+@app.command()
+def serve(
+ model_id: str,
+ revision: Optional[str] = None,
+ sharded: bool = False,
+ quantize: Optional[Quantization] = None,
+ speculate: Optional[int] = None,
+ dtype: Optional[Dtype] = None,
+ kv_cache_dtype: Optional[KVCacheDtype] = None,
+ trust_remote_code: bool = False,
+ uds_path: Path = "/tmp/text-generation-server",
+ logger_level: str = "INFO",
+ json_output: bool = False,
+ otlp_endpoint: Optional[str] = None,
+ otlp_service_name: str = "text-generation-inference.server",
+ max_input_tokens: Optional[int] = None,
+):
+ if sharded:
+ # assert (
+ # os.getenv("RANK", None) is not None
+ # ), "RANK must be set when sharded is True"
+ assert (
+ os.getenv("WORLD_SIZE", None) is not None
+ ), "WORLD_SIZE must be set when sharded is True"
+ assert (
+ os.getenv("MASTER_ADDR", None) is not None
+ ), "MASTER_ADDR must be set when sharded is True"
+ assert (
+ os.getenv("MASTER_PORT", None) is not None
+ ), "MASTER_PORT must be set when sharded is True"
+
+ # Remove default handler
+ logger.remove()
+ logger.add(
+ sys.stdout,
+ format="{message}",
+ filter="text_generation_server",
+ level=logger_level,
+ serialize=json_output,
+ backtrace=True,
+ diagnose=False,
+ )
+
+ # Import here after the logger is added to log potential import exceptions
+ from text_generation_server import server
+ from text_generation_server.tracing import setup_tracing
+
+ # Setup OpenTelemetry distributed tracing
+ if otlp_endpoint is not None:
+ setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
+
+ lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))
+
+ # TODO: enable LoRA with CUDA graphs. For now, disable CUDA graphs if LoRA is
+ # enabled and warn the user.
+ if lora_adapters:
+ logger.warning("LoRA adapters enabled (experimental feature).")
+
+ if "CUDA_GRAPHS" in os.environ:
+ logger.warning(
+ "LoRA adapters incompatible with CUDA Graphs. Disabling CUDA Graphs."
+ )
+ global CUDA_GRAPHS
+ CUDA_GRAPHS = None
+
+ # Downgrade enum into str for easier management later on
+ quantize = None if quantize is None else quantize.value
+ dtype = "bfloat16" if dtype is None else dtype.value
+ kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
+ logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}")
+ if dtype is not None and quantize not in {
+ None,
+ "bitsandbytes",
+ "bitsandbytes-nf4",
+ "bitsandbytes-fp4",
+ "gptq",
+ "awq",
+ "fp8",
+ "compressed-tensors",
+ }:
+ raise RuntimeError(
+ "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
+ )
+ server.serve(
+ model_id,
+ lora_adapters,
+ revision,
+ sharded,
+ quantize,
+ speculate,
+ dtype,
+ kv_cache_dtype,
+ trust_remote_code,
+ uds_path,
+ max_input_tokens,
+ )
+
+
+@app.command()
+def download_weights(
+ model_id: str,
+ revision: Optional[str] = None,
+ extension: str = ".safetensors",
+ auto_convert: bool = True,
+ logger_level: str = "INFO",
+ json_output: bool = False,
+ trust_remote_code: bool = False,
+ merge_lora: bool = False,
+):
+ # Remove default handler
+ logger.remove()
+ logger.add(
+ sys.stdout,
+ format="{message}",
+ filter="text_generation_server",
+ level=logger_level,
+ serialize=json_output,
+ backtrace=True,
+ diagnose=False,
+ )
+
+ # Import here after the logger is added to log potential import exceptions
+ from text_generation_server import utils
+
+ # Check whether the files were already downloaded
+ try:
+ utils.weight_files(model_id, revision, extension)
+ logger.info("Files are already present on the host. " "Skipping download.")
+ return
+ # Local files not found
+ except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
+ pass
+
+ is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
+ "WEIGHTS_CACHE_OVERRIDE", None
+ ) is not None
+
+ if not is_local_model:
+ # TODO: maybe reverse the default value of merge_lora?
+ # currently by default we don't merge the weights with the base model
+ if merge_lora:
+ try:
+ hf_hub_download(
+ model_id, revision=revision, filename="adapter_config.json"
+ )
+ utils.download_and_unload_peft(
+ model_id, revision, trust_remote_code=trust_remote_code
+ )
+ is_local_model = True
+ utils.weight_files(model_id, revision, extension)
+ return
+ except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+ pass
+ else:
+ try:
+ utils.peft.download_peft(
+ model_id, revision, trust_remote_code=trust_remote_code
+ )
+ except Exception:
+ pass
+
+ try:
+ import json
+
+ config = hf_hub_download(
+ model_id, revision=revision, filename="config.json"
+ )
+ with open(config, "r") as f:
+ config = json.load(f)
+
+ base_model_id = config.get("base_model_name_or_path", None)
+ if base_model_id and base_model_id != model_id:
+ try:
+ logger.info(f"Downloading parent model {base_model_id}")
+ download_weights(
+ model_id=base_model_id,
+ revision="main",
+ extension=extension,
+ auto_convert=auto_convert,
+ logger_level=logger_level,
+ json_output=json_output,
+ trust_remote_code=trust_remote_code,
+ )
+ except Exception:
+ pass
+ except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+ pass
+
+ # Try to download weights from the hub
+ try:
+ filenames = utils.weight_hub_files(model_id, revision, extension)
+ utils.download_weights(filenames, model_id, revision)
+ # Successfully downloaded weights
+ return
+
+ # No weights found on the hub with this extension
+ except utils.EntryNotFoundError as e:
+ # Check if we want to automatically convert to safetensors or if we can use .bin weights instead
+ if not extension == ".safetensors" or not auto_convert:
+ raise e
+
+ elif (Path(model_id) / "adapter_config.json").exists():
+ # Try to load as a local PEFT model
+ try:
+ utils.download_and_unload_peft(
+ model_id, revision, trust_remote_code=trust_remote_code
+ )
+ utils.weight_files(model_id, revision, extension)
+ return
+ except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+ pass
+ elif (Path(model_id) / "config.json").exists():
+ # Try to load as a local Medusa model
+ try:
+ import json
+
+ config = Path(model_id) / "config.json"
+ with open(config, "r") as f:
+ config = json.load(f)
+
+ base_model_id = config.get("base_model_name_or_path", None)
+ if base_model_id:
+ try:
+ logger.info(f"Downloading parent model {base_model_id}")
+ download_weights(
+ model_id=base_model_id,
+ revision="main",
+ extension=extension,
+ auto_convert=auto_convert,
+ logger_level=logger_level,
+ json_output=json_output,
+ trust_remote_code=trust_remote_code,
+ )
+ except Exception:
+ pass
+ except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+ pass
+
+ # Try to see if there are local pytorch weights
+ try:
+ # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
+ try:
+ local_pt_files = utils.weight_files(model_id, revision, ".bin")
+ except Exception:
+ local_pt_files = utils.weight_files(model_id, revision, ".pt")
+
+ # No local pytorch weights
+ except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+ if extension == ".safetensors":
+ logger.warning(
+ f"No safetensors weights found for model {model_id} at revision {revision}. "
+ f"Downloading PyTorch weights."
+ )
+
+ # Try to see if there are pytorch weights on the hub
+ pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
+ # Download pytorch weights
+ local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
+
+ if auto_convert:
+ if not trust_remote_code:
+ logger.warning(
+ "🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
+ "Pickle files are unsafe and can essentially contain remote code execution!"
+ "Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
+ )
+
+ logger.warning(
+ f"No safetensors weights found for model {model_id} at revision {revision}. "
+ f"Converting PyTorch weights to safetensors."
+ )
+
+ # Safetensors final filenames
+ local_st_files = [
+ p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
+ for p in local_pt_files
+ ]
+ try:
+ import transformers
+ import json
+
+ if is_local_model:
+ config_filename = os.path.join(model_id, "config.json")
+ else:
+ config_filename = hf_hub_download(
+ model_id, revision=revision, filename="config.json"
+ )
+ with open(config_filename, "r") as f:
+ config = json.load(f)
+ architecture = config["architectures"][0]
+
+ class_ = getattr(transformers, architecture)
+
+ # The name of this variable depends on the transformers version.
+ discard_names = getattr(class_, "_tied_weights_keys", [])
+
+ except Exception:
+ discard_names = []
+ # Convert pytorch weights to safetensors
+ utils.convert_files(local_pt_files, local_st_files, discard_names)
+
+
+@app.command()
+def quantize(
+ model_id: str,
+ output_dir: str,
+ revision: Optional[str] = None,
+ logger_level: str = "INFO",
+ json_output: bool = False,
+ trust_remote_code: bool = False,
+ upload_to_model_id: Optional[str] = None,
+ percdamp: float = 0.01,
+ act_order: bool = False,
+ groupsize: int = 128,
+):
+ if revision is None:
+ revision = "main"
+ download_weights(
+ model_id=model_id,
+ revision=revision,
+ logger_level=logger_level,
+ json_output=json_output,
+ )
+ from text_generation_server.layers.gptq.quantize import quantize
+
+ quantize(
+ model_id=model_id,
+ bits=4,
+ groupsize=groupsize,
+ output_dir=output_dir,
+ revision=revision,
+ trust_remote_code=trust_remote_code,
+ upload_to_model_id=upload_to_model_id,
+ percdamp=percdamp,
+ act_order=act_order,
+ sym=True,
+ )
+
+
+if __name__ == "__main__":
+ app()
diff --git a/backends/gaudi/server/text_generation_server/interceptor.py b/backends/gaudi/server/text_generation_server/interceptor.py
new file mode 100644
index 00000000000..47f33cd0b9a
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/interceptor.py
@@ -0,0 +1,45 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+import torch
+import grpc
+
+from google.rpc import status_pb2, code_pb2
+from grpc_status import rpc_status
+from grpc_interceptor.server import AsyncServerInterceptor
+from loguru import logger
+from typing import Callable, Any
+import traceback
+import os
+
+
+class ExceptionInterceptor(AsyncServerInterceptor):
+ async def intercept(
+ self,
+ method: Callable,
+ request_or_iterator: Any,
+ context: grpc.ServicerContext,
+ method_name: str,
+ ) -> Any:
+ try:
+ response = method(request_or_iterator, context)
+ return await response
+ except Exception as err:
+ trace = " " + traceback.format_exc() if os.environ.get("DUMP_STACK") else ""
+ method_name = method_name.split("/")[-1]
+ logger.exception(f"Method {method_name} encountered an error.")
+
+ # A RuntimeError cannot be recovered from
+ if isinstance(err, RuntimeError):
+ exit(1)
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ from .utils.debug import dbg_trace
+
+ dbg_trace("EXCEPTION", traceback.format_exc())
+ await context.abort_with_status(
+ rpc_status.to_status(
+ status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)
+ )
+ )
diff --git a/backends/gaudi/server/text_generation_server/layers/__init__.py b/backends/gaudi/server/text_generation_server/layers/__init__.py
new file mode 100644
index 00000000000..fd146728ada
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/__init__.py
@@ -0,0 +1,36 @@
+from text_generation_server.layers.tensor_parallel import (
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+ TensorParallelEmbedding,
+)
+from text_generation_server.layers.linear import (
+ get_linear,
+ FastLinear,
+)
+from text_generation_server.layers.speculative import SpeculativeHead
+
+# Just to add the `load` methods.
+from text_generation_server.layers.layernorm import load_layer_norm
+from text_generation_server.layers.conv import load_conv2d
+from text_generation_server.layers.fp8 import Fp8Linear
+
+from text_generation_server.layers.lora import (
+ LoraLinear,
+ TensorParallelMultiAdapterLinear,
+ TensorParallelAdapterRowLinear,
+)
+
+__all__ = [
+ "get_linear",
+ "FastLinear",
+ "TensorParallelColumnLinear",
+ "TensorParallelRowLinear",
+ "TensorParallelEmbedding",
+ "SpeculativeHead",
+ "LoraLinear",
+ "Fp8Linear",
+ "TensorParallelMultiAdapterLinear",
+ "TensorParallelAdapterRowLinear",
+ "load_layer_norm",
+ "load_conv2d",
+]
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py
new file mode 100644
index 00000000000..aa639832135
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py
@@ -0,0 +1,35 @@
+from .common import (
+ Seqlen,
+ HPUPagedAttentionMetadata,
+ trim_attn_metadata,
+ trim_seqlen_metadata,
+ _async_h2d_tensor_copy,
+)
+
+from .hpu import (
+ SUPPORTS_WINDOWING,
+ attention,
+ paged_attention,
+ paged_attention_mla,
+ set_block_mapping,
+)
+
+
+# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
+from .kv_cache import KVCache, get_kv_scales, KVCompressCache
+
+__all__ = [
+ "attention",
+ "get_kv_scales",
+ "paged_attention",
+ "paged_attention_mla",
+ "set_block_mapping",
+ "SUPPORTS_WINDOWING",
+ "KVCache",
+ "KVCompressCache",
+ "Seqlen",
+ "HPUPagedAttentionMetadata",
+ "trim_seqlen_metadata",
+ "trim_attn_metadata",
+ "_async_h2d_tensor_copy",
+]
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/common.py b/backends/gaudi/server/text_generation_server/layers/attention/common.py
new file mode 100644
index 00000000000..1086c411411
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/attention/common.py
@@ -0,0 +1,186 @@
+from dataclasses import dataclass
+import torch
+from typing import Optional, List, Dict
+import collections
+import torch.nn.functional as F
+
+_TYPE_CACHE = {}
+
+
+@dataclass
+class HPUPagedAttentionMetadata:
+ """Metadata for PagedAttention."""
+
+ block_list: Optional[torch.Tensor]
+ block_mapping: Optional[torch.Tensor]
+ block_usage: Optional[torch.Tensor]
+ block_groups: Optional[torch.Tensor]
+ attn_bias: Optional[torch.Tensor]
+ slots_in_window_mask: Optional[torch.Tensor] = None
+ block_list_in_window: Optional[torch.Tensor] = None
+ block_mapping_in_window: Optional[torch.Tensor] = None
+ block_usage_in_window: Optional[torch.Tensor] = None
+ block_groups_in_window: Optional[torch.Tensor] = None
+ attn_bias_in_window: Optional[torch.Tensor] = None
+
+
+def subtuple(
+ obj: object,
+ typename: str,
+ to_copy: List[str],
+ to_override: Optional[Dict[str, object]] = None,
+):
+ if obj is None:
+ return None
+ if to_override is None:
+ to_override = {}
+ fields = set(to_copy) | set(to_override.keys())
+ if isinstance(obj, dict):
+ values = {key: obj[key] for key in fields if key in obj}
+ else:
+ values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
+ if typename not in _TYPE_CACHE:
+ _TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
+ return _TYPE_CACHE[typename](**values)
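+ # Illustrative usage (not part of the original change):
+ # subtuple(meta, "Trimmed", ["block_list", "block_usage"]) returns a namedtuple
+ # of type "Trimmed" carrying only those two attributes copied from meta; the
+ # namedtuple type is cached in _TYPE_CACHE under its typename and reused.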
+
+
+def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
+ # NOTE(kzawora): To anyone working on this in the future:
+ # Trimming metadata is required when using HPUGraphs.
+ # Attention metadata is going to be hashed by PT bridge, and
+ # appropriate HPUGraphs will be matched based on all inputs' hash.
+
+ # Before you put more keys in here, make sure you know their
+ # value type and make sure you know how it's going to be hashed.
+ # You can find that information in input_hash function
+ # in habana_frameworks/torch/hpu/graphs.py. You can also hash
+ # it manually with torch.hpu.graphs.input_hash(attention_metadata)
+
+ # If you use primitive types here - they will get hashed based
+ # on their value. You *will* get lots of excessive graph captures
+ # (and an OOM eventually) if you decide to put something like
+ # seq_len int here.
+ # If you absolutely need a scalar, put it in a tensor. Tensors
+ # get hashed using their metadata, not their values:
+ # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
+ # input_hash(123) != input_hash(321)
+ # input_hash("abc") != input_hash("cba")
+ attention_metadata = subtuple(
+ metadata,
+ "TrimmedAttentionMetadata",
+ [
+ "block_list",
+ "block_mapping",
+ "block_usage",
+ "block_groups",
+ "attn_bias",
+ "slots_in_window_mask",
+ "block_list_in_window",
+ "block_mapping_in_window",
+ "block_usage_in_window",
+ "block_groups_in_window",
+ "attn_bias_in_window",
+ ],
+ )
+ return attention_metadata
+
+
+@dataclass
+class Seqlen:
+ input_lengths: torch.Tensor
+ attn_mask: Optional[torch.Tensor] = None
+
+ def __init__(
+ self,
+ input_lengths,
+ ):
+ self.input_lengths = input_lengths
+
+ def clamp(self, max):
+ # Flash decoding doesn't need to clamp
+ return self
+
+ def make_sliding_window_bias(
+ self,
+ seq_lens: List[int],
+ window_size: Optional[int],
+ dtype: torch.dtype,
+ padded_input_len: Optional[int],
+ padded_bs: Optional[int],
+ ) -> List[torch.Tensor]:
+ attn_biases = []
+ for seq_len in seq_lens:
+ if seq_len != 0:
+ tensor = torch.full(
+ (1, seq_len, seq_len),
+ dtype=dtype,
+ fill_value=1,
+ )
+ shift = 0
+ mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
+ if window_size is not None:
+ mask = torch.triu(mask, diagonal=shift - window_size + 1)
+ mask = F.pad(
+ mask,
+ (
+ padded_input_len - seq_len,
+ 0,
+ padded_input_len - seq_len,
+ 0,
+ 0,
+ 0,
+ ),
+ value=0,
+ )
+ else:
+ mask = torch.full(
+ (1, padded_input_len, padded_input_len),
+ dtype=dtype,
+ fill_value=0,
+ )
+ attn_biases.append(mask)
+ attn_biases = torch.stack(attn_biases, dim=0)
+ return attn_biases.to(torch.bool)
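+ # Illustrative result (not part of the original change): for seq_len=4,
+ # window_size=2 and padded_input_len=4, each mask is the causal lower-triangular
+ # pattern restricted to a band of width 2, so position i may attend to i-1 and i;
+ # zero-length sequences get an all-False mask of the padded shape.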
+
+
+def _async_h2d_tensor_copy(source, device="hpu"):
+ if source is None:
+ return None
+ if source.device.type == "hpu":
+ return source
+ assert source.device.type == "cpu", "Source tensor is not present in host memory!"
+ target = torch.empty(source.shape, dtype=source.dtype, device=device)
+ target.copy_(source, non_blocking=True)
+ return target
+
+
+def trim_seqlen_metadata(metadata: Seqlen) -> object:
+ # NOTE(kzawora): To anyone working on this in the future:
+ # Trimming metadata is required when using HPUGraphs.
+ # Attention metadata is going to be hashed by PT bridge, and
+ # appropriate HPUGraphs will be matched based on all inputs' hash.
+
+ # Before you put more keys in here, make sure you know their
+ # value type and make sure you know how it's going to be hashed.
+ # You can find that information in input_hash function
+ # in habana_frameworks/torch/hpu/graphs.py. You can also hash
+ # it manually with torch.hpu.graphs.input_hash(attention_metadata)
+
+ # If you use primitive types here - they will get hashed based
+ # on their value. You *will* get lots of excessive graph captures
+ # (and an OOM eventually) if you decide to put something like
+ # seq_len int here.
+ # If you absolutely need a scalar, put it in a tensor. Tensors
+ # get hashed using their metadata, not their values:
+ # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
+ # input_hash(123) != input_hash(321)
+ # input_hash("abc") != input_hash("cba")
+ attention_metadata = subtuple(
+ metadata,
+ "TrimmedSeqlen",
+ [
+ "input_lengths",
+ "attn_mask",
+ ],
+ )
+ return attention_metadata
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
new file mode 100644
index 00000000000..d3588e253e1
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
@@ -0,0 +1,227 @@
+import torch
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
+from typing import Optional
+from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
+from vllm_hpu_extension import ops
+from vllm_hpu_extension.utils import Matmul
+from habana_frameworks.torch.hpex.kernels import FusedSDPA
+from vllm_hpu_extension.utils import ModuleFusedSDPA
+import os
+from text_generation_server.models.globals import BLOCK_SIZE
+import math
+
+SUPPORTS_WINDOWING = False
+
+
+class FP8Matmul(torch.nn.Module):
+
+ def __init__(self, scale_other):
+ super().__init__()
+ self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu")
+ self.scale_other = scale_other
+
+ def quant_input(self, x, scale):
+ return torch.ops.hpu.cast_to_fp8_v2(
+ x, scale, False, False, torch.float8_e4m3fn
+ )[0]
+
+ def matmul_fp8(
+ self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None
+ ):
+ return torch.ops.hpu.fp8_gemm_v2(
+ A=x,
+ trans_A=False,
+ B=other,
+ trans_B=False,
+ D=None,
+ out_dtype=out_dtype,
+ A_scale_inv=scale_input_inv,
+ B_scale_inv=scale_other_inv,
+ bias=None,
+ accumulate=False,
+ )
+
+ def forward(self, input, other):
+ qinput = self.quant_input(input, self.scale_input)
+ qother = self.quant_input(other, self.scale_other)
+ output = self.matmul_fp8(
+ qinput,
+ qother,
+ out_dtype=torch.bfloat16,
+ scale_input_inv=1.0 / self.scale_input,
+ scale_other_inv=1.0 / self.scale_other,
+ )
+ return output
+
+
+class FetchFromCache(torch.nn.Module):
+
+ def __init__(self, scale_inv):
+ super().__init__()
+ self.scale_inv = scale_inv
+
+ def forward(self, cache, blocks):
+ if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
+ out = cache[: blocks.size(0)]
+ else:
+ out = cache.index_select(0, blocks)
+ if out.dtype == torch.float8_e4m3fn:
+ out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)
+ return out
+
+
+def attention(
+ *,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ kv_cache: KVCache,
+ kv_scales: KVScales,
+ seqlen: Seqlen,
+ softmax_scale: float,
+ window_size_left: int = -1,
+ causal: bool = True,
+ softcap: Optional[float] = None,
+):
+ fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+ bs = seqlen.input_lengths.shape[0]
+ _, head_num, head_size = query.shape
+ _, kv_head_num, head_size = key.shape
+ query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
+ key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
+ value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
+ attn_output = fsdpa_op(
+ query,
+ key,
+ value,
+ attn_mask=seqlen.attn_mask if window_size_left != -1 else None,
+ dropout_p=0.0,
+ is_causal=causal if window_size_left == -1 else False,
+ scale=softmax_scale,
+ softmax_mode="None",
+ recompute_mode=None,
+ valid_sequence_lengths=seqlen.input_lengths if window_size_left == -1 else None,
+ padding_side="left",
+ )
+ attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
+ return attn_output
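+ # Shape note (illustrative, not part of the original change): query arrives as
+ # [total_tokens, head_num, head_size], is viewed as [bs, seq, head_num, head_size]
+ # and transposed to [bs, head_num, seq, head_size] for FusedSDPA, then transposed
+ # back (and squeezed when bs == 1) before being returned.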
+
+
+def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size):
+ block_mapping = torch.nn.functional.one_hot(
+ hpu_attention_meta.block_groups, num_classes=batch_size
+ )
+ dtype = hpu_attention_meta.block_usage.dtype
+ device = hpu_attention_meta.block_usage.device
+ mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
+ mask = mask >= hpu_attention_meta.block_usage.unsqueeze(-1)
+ attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
+ hpu_attention_meta = hpu_attention_meta._replace(
+ attn_bias=attn_bias, block_mapping=block_mapping.to(dtype)
+ )
+ if hpu_attention_meta.block_groups_in_window is not None:
+ block_mapping = torch.nn.functional.one_hot(
+ hpu_attention_meta.block_groups_in_window, num_classes=batch_size
+ )
+ attn_bias = torch.log(hpu_attention_meta.slots_in_window_mask.float())
+ hpu_attention_meta = hpu_attention_meta._replace(
+ attn_bias_in_window=attn_bias,
+ block_mapping_in_window=block_mapping.to(dtype),
+ )
+ return hpu_attention_meta
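+ # Illustrative example (not part of the original change): block_groups=[0, 0, 1]
+ # with batch_size=2 yields block_mapping rows [[1, 0], [1, 0], [0, 1]], and
+ # attn_bias places -inf at every slot index >= block_usage for its block.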
+
+
+def paged_attention(
+ query: torch.Tensor,
+ kv_cache: KVCache,
+ kv_head_mapping: torch.Tensor,
+ softmax_scale: float,
+ seqlen: Seqlen,
+ *,
+ kv_scales: KVScales,
+ softcap: Optional[float] = None,
+ hpu_attention_meta: HPUPagedAttentionMetadata,
+ window_size_left: int = -1,
+):
+ batch_size, head_num, head_size = query.shape
+ fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
+ output = ops.flat_pa(
+ query=query.view(batch_size, 1, head_num * head_size),
+ key_cache=kv_cache.key,
+ value_cache=kv_cache.value,
+ block_list=(
+ hpu_attention_meta.block_list
+ if window_size_left == -1
+ else hpu_attention_meta.block_list_in_window
+ ),
+ block_mapping=(
+ hpu_attention_meta.block_mapping
+ if window_size_left == -1
+ else hpu_attention_meta.block_mapping_in_window
+ ),
+ block_bias=(
+ hpu_attention_meta.attn_bias
+ if window_size_left == -1
+ else hpu_attention_meta.attn_bias_in_window
+ ),
+ block_groups=(
+ hpu_attention_meta.block_groups
+ if window_size_left == -1
+ else hpu_attention_meta.block_groups_in_window
+ ),
+ block_size=BLOCK_SIZE,
+ scale=softmax_scale,
+ matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
+ matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
+ batch2block_matmul_op=Matmul(),
+ block2batch_matmul_op=Matmul(),
+ keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
+ values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu),
+ )
+ # Reshape the output tensor.
+ return output.view(batch_size, head_num, head_size)
+
+
+def paged_attention_mla(
+ query: torch.Tensor,
+ kv_cache: KVCache,
+ kv_head_mapping: torch.Tensor,
+ softmax_scale: float,
+ seqlen: Seqlen,
+ *,
+ kv_scales: KVScales,
+ softcap: Optional[float] = None,
+ hpu_attention_meta: HPUPagedAttentionMetadata,
+ kv_lora_rank: int = 0,
+):
+ batch_size, head_num, head_size = query.shape
+ fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
+ output = ops.flat_pa_mla(
+ query=query,
+ key_cache=kv_cache.key,
+ value_cache=None,
+ block_list=hpu_attention_meta.block_list,
+ block_mapping=hpu_attention_meta.block_mapping,
+ block_bias=hpu_attention_meta.attn_bias,
+ block_groups=hpu_attention_meta.block_groups,
+ block_size=BLOCK_SIZE,
+ scale=softmax_scale,
+ matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
+ matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
+ batch2block_matmul_op=Matmul(),
+ block2batch_matmul_op=Matmul(),
+ keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
+ values_fetch_func=None,
+ kv_lora_rank=kv_lora_rank,
+ )
+ # Reshape the output tensor.
+ return output.view(batch_size, head_num, -1)
+
+
+__all__ = [
+ "SUPPORTS_WINDOWING",
+ "attention",
+ "paged_attention",
+ "paged_attention_mla",
+ "set_block_mapping",
+]
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
new file mode 100644
index 00000000000..723c1ec02f4
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
@@ -0,0 +1,205 @@
+from typing import Tuple
+from dataclasses import dataclass, field
+
+import torch
+
+from text_generation_server.models.globals import BLOCK_SIZE
+from text_generation_server.utils.weights import Weights
+
+
+@dataclass
+class KVScales:
+ """
+ Key-value scales for FP8 KV cache.
+
+ This data class stores key and value scales both as a GPU tensor and
+ as a GPU float. This inconvenience is necessary because some functions
+ (e.g. scaling kernels) take scales as a GPU tensor, whereas others
+ (e.g. flashinfer) take scales as a CPU scalar.
+ """
+
+ key_scale: torch.Tensor
+ value_scale: torch.Tensor
+ key_scale_cpu: float = field(init=False)
+ value_scale_cpu: float = field(init=False)
+
+ def __post_init__(self):
+ if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
+ raise ValueError("Key and value scales must be scalar tensors.")
+
+ self.key_scale_cpu = self.key_scale.item()
+ self.value_scale_cpu = self.value_scale.item()
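+
+ # Minimal usage sketch (values are illustrative):
+ #   scale = torch.tensor(1.0, dtype=torch.float32, device="hpu")
+ #   kv_scales = KVScales(key_scale=scale, value_scale=scale)
+ #   kv_scales.key_scale_cpu  # -> 1.0, for kernels that expect a Python float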
+
+
+class KVCache:
+ """
+ Key-value cache for attention layers.
+ """
+
+ kv_cache: Tuple[torch.Tensor, torch.Tensor]
+
+ def __init__(
+ self,
+ *,
+ num_blocks: int,
+ num_heads: int,
+ head_size: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ ):
+ """Construct the key-value cache for a layer."""
+ ## TODO FP8 kv cache support
+ if dtype is torch.float8_e5m2:
+ raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
+
+ self.kv_cache = (
+ torch.zeros(
+ (num_blocks * BLOCK_SIZE, num_heads, head_size),
+ dtype=dtype,
+ device=device,
+ ),
+ torch.zeros(
+ (num_blocks * BLOCK_SIZE, num_heads, head_size),
+ dtype=dtype,
+ device=device,
+ ),
+ )
+
+ @property
+ def dtype(self):
+ """Get the data type of the cache."""
+ return self.kv_cache[0].dtype
+
+ @property
+ def key(self):
+ """Get the key cache."""
+
+ return self.kv_cache[0]
+
+ @property
+ def value(self):
+ """Get the value cache."""
+
+ return self.kv_cache[1]
+
+ def store(
+ self,
+ *,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ slots: torch.Tensor,
+ kv_scales: KVScales,
+ ):
+ """Store the key and value at the given slots."""
+ ## TODO FP8 kv cache support
+
+ key_cache = self.kv_cache[0]
+ value_cache = self.kv_cache[1]
+
+ paged_reshape_and_cache(
+ key,
+ value,
+ key_cache,
+ value_cache,
+ slots,
+ kv_scales.key_scale,
+ kv_scales.value_scale,
+ )
+
+
+class KVCompressCache(KVCache):
+ """
+ Compressed key-value cache for MLA (multi-head latent attention) layers.
+ """
+
+ kv_cache: torch.Tensor
+
+ def __init__(
+ self,
+ *,
+ num_blocks: int,
+ head_size: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ ):
+ """Construct the key-value cache for a layer."""
+ ## TODO FP8 kv cache support
+ if dtype is torch.float8_e5m2:
+ raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
+
+ self.kv_cache = torch.zeros(
+ (num_blocks * BLOCK_SIZE, 1, head_size),
+ dtype=dtype,
+ device=device,
+ )
+
+ @property
+ def dtype(self):
+ """Get the data type of the cache."""
+ return self.kv_cache.dtype
+
+ @property
+ def key(self):
+ """Get the key cache."""
+
+ return self.kv_cache
+
+ @property
+ def value(self):
+ """Get the value cache."""
+
+ return self.kv_cache
+
+ def store(
+ self,
+ *,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ slots: torch.Tensor,
+ kv_scales: KVScales,
+ ):
+ """Store the key and value at the given slots."""
+ ## TODO FP8 kv cache support
+ if self.kv_cache.dtype == torch.float8_e4m3fn:
+ key = torch.ops.hpu.cast_to_fp8_v2(
+ key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
+ )[0]
+ self.kv_cache.index_copy_(0, slots, key)
+
+
+def paged_reshape_and_cache(
+ key: torch.Tensor,
+ value: torch.Tensor,
+ key_cache: torch.Tensor,
+ value_cache: torch.Tensor,
+ slots: torch.Tensor,
+ k_scale: torch.Tensor,
+ v_scale: torch.Tensor,
+):
+ if key_cache.dtype == torch.float8_e4m3fn:
+ key = torch.ops.hpu.cast_to_fp8_v2(
+ key, k_scale, False, False, torch.float8_e4m3fn
+ )[0]
+ value = torch.ops.hpu.cast_to_fp8_v2(
+ value, v_scale, False, False, torch.float8_e4m3fn
+ )[0]
+ key_cache.index_copy_(0, slots, key)
+ value_cache.index_copy_(0, slots, value)
+
+
+def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
+ """Load KV cache scales."""
+
+ key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
+ value_scale = key_scale
+ if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
+ f"{prefix}.v_scale"
+ ):
+ key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
+ value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
+ elif weights.has_tensor(f"{prefix}.kv_scale"):
+ # Fall back to older more coarse-grained scale when available.
+ key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
+ value_scale = key_scale
+
+ return KVScales(key_scale=key_scale, value_scale=value_scale)
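+
+ # Lookup sketch for a hypothetical prefix "model.layers.0.self_attn":
+ #   1. model.layers.0.self_attn.k_scale / .v_scale  (per-projection scales)
+ #   2. model.layers.0.self_attn.kv_scale            (older coarse-grained scale)
+ #   3. default scales of 1.0 when neither is present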
diff --git a/backends/gaudi/server/text_generation_server/layers/awq/conversion_utils.py b/backends/gaudi/server/text_generation_server/layers/awq/conversion_utils.py
new file mode 100644
index 00000000000..b19eafbbe2f
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/awq/conversion_utils.py
@@ -0,0 +1,97 @@
+import torch
+from typing import List
+
+
+AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
+REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
+
+
+def pack(imatrix: torch.Tensor, direction: str = "column"):
+ """
+ Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
+ Args:
+ imatrix (torch.Tensor): matrix of integers
+ direction (str): direction of packing, either "column" or "row"
+ Returns:
+ qmatrix (torch.Tensor): packed matrix of integers
+ """
+ shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)
+
+ imatrix = imatrix.to(torch.int8) & 0x0F  # mask to 4 bits to correct any overflow
+
+ if direction == "column":
+ imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
+ qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
+
+ elif direction == "row":
+ imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)
+ qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)
+
+ qmatrix = qmatrix.to(torch.int32)
+
+ return qmatrix
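+
+ # Illustrative example: eight 4-bit values packed column-wise land little-nibble
+ # first in a single int32, e.g.
+ #   pack(torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]), "column") -> [[0x76543210]]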
+
+
+def unpack(qmatrix: torch.Tensor, direction: str = "column"):
+ """
+ Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
+ Args:
+ qmatrix (torch.Tensor): matrix of packed integers
+ direction (str): direction of unpacking, either "column" or "row"
+ Returns:
+ imatrix (torch.Tensor): matrix of integers
+ """
+ shifts = torch.arange(0, 32, 4, device=qmatrix.device)
+
+ if direction == "column":
+ imatrix = torch.bitwise_right_shift(
+ qmatrix[:, :, None], shifts[None, None, :]
+ ).view(qmatrix.shape[0], -1)
+
+ elif direction == "row":
+ imatrix = torch.bitwise_right_shift(
+ qmatrix[:, None, :], shifts[None, :, None]
+ ).view(-1, qmatrix.shape[-1])
+
+ imatrix = imatrix.to(torch.int8) & 0x0F  # mask to 4 bits to correct any overflow
+
+ return imatrix
+
+
+def apply_order(
+ imatrix: torch.Tensor,
+ direction: str = "column",
+ order: List[int] = AWQ_PACK_ORDER,
+):
+ """
+ Applies the order to a 4-bit integer matrix.
+ Args:
+ imatrix (torch.Tensor): matrix of integers
+ direction (str): direction of applying order, either "column" or "row"
+ order (List[int]): order to apply, default is AWQ_PACK_ORDER
+ Returns:
+ imatrix (torch.Tensor): matrix of integers
+ """
+ if direction == "column":
+ imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)
+ elif direction == "row":
+ imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)
+
+ return imatrix
+
+
+def fast_awq_to_gptq(qweight, qzeros):
+ # awq uses column packing for both weights and zeros
+ izeros = unpack(qzeros, direction="column")
+ iweights = unpack(qweight, direction="column")
+
+ # Reverse the order of the iweight and izeros tensors
+ izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
+ iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
+ # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
+ izeros = izeros - 1
+ # exllama uses row packing for weights and column packing for zeros
+ qzeros = pack(izeros, direction="column")
+ qweight = pack(iweights, direction="row")
+
+ return qweight, qzeros
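+
+ # Shape sketch (4-bit packing assumed): an AWQ checkpoint with qweight of shape
+ # (in_features, out_features // 8) and qzeros of shape (n_groups, out_features // 8)
+ # comes back as GPTQ-style tensors: qweight becomes (in_features // 8, out_features)
+ # after the row repack, while qzeros keeps shape (n_groups, out_features // 8).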
diff --git a/backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py b/backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py
new file mode 100644
index 00000000000..856d7c28154
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py
@@ -0,0 +1,3 @@
+from .hpu import WQLinear
+
+__all__ = ["WQLinear"]
diff --git a/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py b/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py
new file mode 100644
index 00000000000..3af0131b3a3
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py
@@ -0,0 +1,134 @@
+from typing import Optional
+import torch
+import torch.nn as nn
+
+try:
+ import habana_frameworks.torch.hpu # noqa: F401
+
+ convert_from_uint4 = torch.ops.hpu.convert_from_uint4
+except Exception as e:
+ hpu_import_exception = e
+
+ def error_raiser_hpu(*args, **kwargs):
+ raise ValueError(
+ f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
+ )
+
+ convert_from_uint4 = error_raiser_hpu
+
+AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
+
+
+def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
+ shifts = torch.arange(0, 32, bits, device=qzeros.device)
+
+ # unpacking columnwise
+ iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
+ torch.int8 # smallest dtype available
+ )
+ iweights = iweights.view(iweights.shape[0], -1)
+
+ # unpacking columnwise
+ if qzeros is not None:
+ izeros = torch.bitwise_right_shift(
+ qzeros[:, :, None], shifts[None, None, :]
+ ).to(
+ torch.int8 # smallest dtype available
+ )
+ izeros = izeros.view(izeros.shape[0], -1)
+ else:
+ izeros = qzeros
+
+ return iweights, izeros
+
+
+def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
+ reverse_order_tensor = torch.arange(
+ iweights.shape[-1],
+ dtype=torch.int32,
+ device=izeros.device,
+ )
+ reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
+ reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
+ reverse_order_tensor = reverse_order_tensor.view(-1)
+
+ if izeros is not None:
+ izeros = izeros[:, reverse_order_tensor]
+ iweights = iweights[:, reverse_order_tensor]
+
+ return iweights, izeros
+
+
+def unpack_weight_and_zeros(qweight, qzeros, bits):
+ # Unpack the qweight and qzeros tensors
+ iweight, izeros = unpack_awq(qweight, qzeros, bits)
+ # Reverse the order of the iweight and izeros tensors
+ iweight, izeros = reverse_awq_order(iweight, izeros, bits)
+
+ # overflow checks
+ iweight = torch.bitwise_and(iweight, (2**bits) - 1)
+ izeros = torch.bitwise_and(izeros, (2**bits) - 1)
+
+ return iweight, izeros
+
+
+def pack_tensor(input, bits=4):
+ normal = input.to(torch.int32)
+ q = torch.zeros(
+ (normal.shape[0], normal.shape[1] // 32 * bits),
+ dtype=torch.int32,
+ device=input.device,
+ )
+ i = 0
+ col = 0
+ while col < q.shape[1]:
+ for j in range(i, i + (32 // bits)):
+ q[:, col] |= normal[:, j] << (bits * (j - i))
+ i += 32 // bits
+ col += 1
+ q = q.to(torch.int32)
+ return q
+
+
+class WQLinear(nn.Module):
+ def __init__(
+ self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
+ ):
+ super().__init__()
+
+ if w_bit not in [4]:
+ raise NotImplementedError("Only 4-bit are supported for now.")
+
+ self.in_features = qweight.shape[0]
+ self.out_features = qweight.shape[1] * 32 // w_bit
+
+ self.w_bit = w_bit
+ self.group_size = group_size if group_size != -1 else self.in_features
+ # quick sanity check (make sure alignment holds)
+ assert self.in_features % self.group_size == 0
+ assert self.out_features % (32 // self.w_bit) == 0
+
+ self.qweight = qweight
+ self.qzeros = qzeros
+ self.scales = scales
+ self.bias = bias
+ self._preprocessing()
+
+ def _preprocessing(self):
+ device = self.qweight.device
+ weight, zeros = unpack_weight_and_zeros(
+ self.qweight.cpu(), self.qzeros.cpu(), self.w_bit
+ )
+ self.qweight = pack_tensor(weight).to(device)
+ self.qzeros = pack_tensor(zeros).to(device)
+
+ @torch.no_grad()
+ def forward(self, x):
+ out_shape = x.shape[:-1] + (self.out_features,)
+ x = x.reshape(-1, x.shape[-1])
+ weights = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
+ outputs = torch.matmul(x, weights)
+
+ outputs = outputs + self.bias if self.bias is not None else outputs
+ outputs = outputs.reshape(out_shape)
+ return outputs
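+
+ # Usage sketch (hypothetical tensors from a pre-quantized 4-bit AWQ checkpoint):
+ #   layer = WQLinear(w_bit=4, group_size=128, qweight=qweight, qzeros=qzeros,
+ #                    scales=scales, bias=None)
+ #   y = layer(x)   # x: (batch, layer.in_features)
+ #   y.shape        # -> (batch, layer.out_features)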
diff --git a/backends/gaudi/server/text_generation_server/layers/bnb.py b/backends/gaudi/server/text_generation_server/layers/bnb.py
new file mode 100644
index 00000000000..791d9b6d8c6
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/bnb.py
@@ -0,0 +1,124 @@
+from dataclasses import dataclass
+
+import bitsandbytes as bnb
+import torch
+from bitsandbytes.nn import Int8Params, Params4bit
+from text_generation_server.utils.weights import UnquantizedWeight
+
+
+@dataclass
+class BNBWeight(UnquantizedWeight):
+ weight: torch.Tensor
+
+ def get_linear(self, bias: torch.Tensor):
+ return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0)
+
+
+class Linear8bitLt(torch.nn.Module):
+ def __init__(
+ self,
+ weight,
+ bias,
+ has_fp16_weights=True,
+ memory_efficient_backward=False,
+ threshold=0.0,
+ index=None,
+ ):
+ super().__init__()
+ assert (
+ not memory_efficient_backward
+ ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
+ self.state = bnb.MatmulLtState()
+ self.index = index
+
+ # Necessary for stacked layers
+ self.state.threshold = threshold
+ self.state.has_fp16_weights = has_fp16_weights
+ self.state.memory_efficient_backward = memory_efficient_backward
+ if threshold > 0.0 and not has_fp16_weights:
+ self.state.use_pool = True
+
+ self.weight = Int8Params(
+ weight.data,
+ has_fp16_weights=has_fp16_weights,
+ requires_grad=has_fp16_weights,
+ )
+ self.weight.cuda(weight.device)
+ self.bias = bias
+
+ def init_8bit_state(self):
+ self.state.CB = self.weight.CB
+ self.state.SCB = self.weight.SCB
+ self.weight.CB = None
+ self.weight.SCB = None
+
+ def forward(self, x: torch.Tensor):
+ self.state.is_training = self.training
+ if self.weight.CB is not None:
+ self.init_8bit_state()
+
+ # weights are cast automatically as Int8Params, but the bias has to be cast manually
+ if self.bias is not None and self.bias.dtype != x.dtype:
+ self.bias.data = self.bias.data.to(x.dtype)
+
+ out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
+
+ if not self.state.has_fp16_weights:
+ if self.state.CB is not None and self.state.CxB is not None:
+ # we converted 8-bit row major to turing/ampere format in the first inference pass
+ # we no longer need the row-major weight
+ del self.state.CB
+ self.weight.data = self.state.CxB
+ return out
+
+
+@dataclass
+class BNBFP4Weight(UnquantizedWeight):
+ weight: torch.Tensor
+
+ def get_linear(self, bias: torch.Tensor):
+ return Linear4bit(self.weight, bias, quant_type="fp4")
+
+
+@dataclass
+class BNBNF4Weight(UnquantizedWeight):
+ weight: torch.Tensor
+
+ def get_linear(self, bias: torch.Tensor):
+ return Linear4bit(self.weight, bias, quant_type="nf4")
+
+
+class Linear4bit(torch.nn.Module):
+ def __init__(self, weight, bias, quant_type):
+ super().__init__()
+ self.weight = Params4bit(
+ weight.data,
+ requires_grad=False,
+ compress_statistics=True,
+ quant_type=quant_type,
+ )
+ self.compute_dtype = None
+ self.weight.cuda(weight.device)
+ self.bias = bias
+
+ def forward(self, x: torch.Tensor):
+ # weights are cast automatically as Params4bit, but the bias has to be cast manually
+ if self.bias is not None and self.bias.dtype != x.dtype:
+ self.bias.data = self.bias.data.to(x.dtype)
+
+ if getattr(self.weight, "quant_state", None) is None:
+ print(
+ "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
+ )
+ inp_dtype = x.dtype
+ if self.compute_dtype is not None:
+ x = x.to(self.compute_dtype)
+
+ bias = None if self.bias is None else self.bias.to(self.compute_dtype)
+ out = bnb.matmul_4bit(
+ x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
+ )
+
+ out = out.to(inp_dtype)
+
+ return out
diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py
new file mode 100644
index 00000000000..507af706b9e
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py
@@ -0,0 +1,3 @@
+from .loader import CompressedTensorsLoader
+
+__all__ = ["CompressedTensorsLoader"]
diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py
new file mode 100644
index 00000000000..0dccf34a5c3
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py
@@ -0,0 +1,169 @@
+from typing import Any, Dict, List, Union
+
+from compressed_tensors import QuantizationConfig, QuantizationStatus
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import (
+ QuantizationScheme,
+ QuantizationType,
+ find_name_or_class_matches,
+)
+from loguru import logger
+from pydantic import ValidationError
+from torch import nn
+
+from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
+from text_generation_server.utils.log import log_once
+from text_generation_server.utils.weights import (
+ DefaultWeightsLoader,
+ UnquantizedWeight,
+ Weights,
+ WeightsLoader,
+)
+
+# compressed-tensors can match modules as quantization targets. However,
+# they need to be objects rather than classes or class names. Since we
+# need to match `Linear` targets, make an instance that can be re-used.
+_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0)
+
+
+class CompressedTensorsLoader(WeightsLoader):
+ """Loader for checkpoints stored in the compressed-tensors format."""
+
+ def __init__(self, config: Dict[str, Any]):
+ quantization_config_raw = config.get("quantization_config")
+ if quantization_config_raw is None:
+ # `compression_config` was renamed to `quantization_config`; support
+ # retained for backward compatibility.
+ quantization_config_raw = config.get("compression_config")
+ if quantization_config_raw is None:
+ raise ValueError(
+ "Checkpoint does not have compressed-tensors configuration"
+ )
+
+ try:
+ quantization_config = QuantizationConfig.model_validate(
+ quantization_config_raw
+ )
+ except ValidationError as e:
+ raise ValueError("Cannot parse compressed-tensors configuration") from e
+
+ if quantization_config.quantization_status not in (
+ QuantizationStatus.COMPRESSED,
+ QuantizationStatus.FROZEN,
+ ):
+ raise ValueError(
+ f"Model quantization was not finished, status was: {quantization_config.quantization_status}"
+ )
+
+ self.ignore = (
+ quantization_config.ignore if quantization_config.ignore is not None else []
+ )
+ self.loaders = self._get_target_loaders(quantization_config)
+
+ for target, loader in self.loaders.items():
+ log_once(
+ logger.info,
+ f"Using {loader} for compressed-tensors target '{target}'",
+ )
+
+ def get_weights(self, weights: Weights, prefix: str):
+ loader = self._lookup_loader(prefix)
+ return loader.get_weights(weights, prefix)
+
+ def get_weights_col_packed(
+ self,
+ weights: "Weights",
+ prefix: str,
+ block_sizes: Union[int, List[int]],
+ ):
+ loader = self._lookup_loader(prefix)
+ return loader.get_weights_col_packed(weights, prefix, block_sizes)
+
+ def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+ loader = self._lookup_loader(prefixes[0])
+ return loader.get_multi_weights_col(weights, prefixes, dim)
+
+ def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
+ loader = self._lookup_loader(prefixes[0])
+ return loader.get_multi_weights(weights, prefixes, dim)
+
+ def get_weights_row(self, weights: Weights, prefix: str):
+ loader = self._lookup_loader(prefix)
+ return loader.get_weights_row(weights, prefix)
+
+ def _get_target_loaders(
+ self, quantization_config: QuantizationConfig
+ ) -> Dict[str, WeightsLoader]:
+ """
+ A compressed-tensors checkpoint can use different quantizations
+ for different targets. This method returns a dictionary with a
+ loader per target.
+ """
+
+ loaders: Dict[str, WeightsLoader] = {}
+
+ format = quantization_config.format
+
+ for group_name, group in quantization_config.config_groups.items():
+ # The group configuration can be a string, but does that ever
+ # happen in a serialized quantization config?
+ assert isinstance(group, QuantizationScheme)
+
+ loader = self._create_loader_for_group(format, group_name, group)
+
+ # A quantized parameter group can have multiple targets, add the
+ # loader for all the targets.
+ for target in group.targets:
+ if target in loaders:
+ raise ValueError(
+ f"Target '{target} has multiple configured loaders'"
+ )
+ loaders[target] = loader
+
+ return loaders
+
+ def _create_loader_for_group(
+ self, format: str, group_name: str, group: QuantizationScheme
+ ) -> WeightsLoader:
+ """
+ Find and create a loader for the group with the given quantization
+ scheme.
+ """
+ # NOTE: we ignore group.output_activations because we don't support
+ # output quantization yet.
+
+ input_activations = group.input_activations
+ weights = group.weights
+ if (
+ format
+ in {
+ CompressionFormat.float_quantized.value,
+ CompressionFormat.naive_quantized.value,
+ }
+ and weights is not None
+ and weights.type == QuantizationType.FLOAT
+ and weights.num_bits == 8
+ ):
+ # FP W8A8 or W8A16.
+ return W8ANFpLoader(input_activations=input_activations, weights=weights)
+ else:
+ raise ValueError(
+ f"Group '{group_name}' has unsupported compressed-tensors configurtion"
+ )
+
+ def _lookup_loader(self, prefix: str) -> WeightsLoader:
+ """
+ Look up the loader to use for a given parameter name (prefix).
+ """
+
+ if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0:
+ return DefaultWeightsLoader(UnquantizedWeight)
+
+ # We currently only handle linear layers, so unconditionally pass
+ # a `Linear` instance.
+ targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys())
+ if len(targets) == 0:
+ raise ValueError(
+ f"Cannot find compressed-tensors target for prefix: {prefix}"
+ )
+ return self.loaders[targets[0]]
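+
+ # Dispatch sketch: prefixes matching `ignore` (e.g. "lm_head") fall back to the
+ # unquantized loader; every other linear layer is matched against the configured
+ # targets (such as "Linear" or a regex-style "re:..." pattern) and routed to its
+ # W8ANFpLoader.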
diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
new file mode 100644
index 00000000000..6eb003874d6
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
@@ -0,0 +1,253 @@
+from typing import List, Optional, Union
+
+import torch
+from compressed_tensors.quantization import QuantizationArgs, QuantizationType
+
+from text_generation_server.layers.fp8 import (
+ Fp8Weight,
+ _load_scalar_or_matrix_scale,
+ requantize_with_max_scale,
+)
+from text_generation_server.utils.weights import Weights, WeightsLoader
+
+
+class W8ANFpLoader(WeightsLoader):
+ """
+ Loader for W8A8/W8A16 FP compressed-tensors parameters.
+ """
+
+ def __init__(
+ self,
+ *,
+ input_activations: Optional[QuantizationArgs],
+ weights: QuantizationArgs,
+ ):
+ assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8
+
+ # We ignore the `strategy` option which sets the scales to be
+ # per-tensor, per-channel or per-token. What scales are supported
+ # is dependent on the kernels used (e.g. cutlass can do tokenwise,
+ # Torch cannot, and FP8-Marlin does not quantize inputs at all).
+ # So, instead we try to use the best-possible configuration.
+
+ self.load_weight_scale = not weights.dynamic
+ self.load_input_scale = (
+ input_activations is not None and not input_activations.dynamic
+ )
+ self.force_w8a16 = (
+ input_activations is not None and input_activations.num_bits == 16
+ )
+
+ def __str__(self) -> str:
+ def scale_to_str(scale):
+ return "static" if scale else "dynamic"
+
+ quantization_type = f"W8A{16 if self.force_w8a16 else 8}"
+
+ return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})"
+
+ def get_weights(self, weights: "Weights", prefix: str):
+ w = weights.get_tensor(f"{prefix}.weight")
+
+ weight_scale = None
+ if self.load_weight_scale:
+ weight_scale = (
+ weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+ .reshape(-1)
+ .expand(w.shape[0])
+ )
+ logical_widths = [w.shape[0]]
+ w, weight_scale = requantize_with_max_scale(
+ w,
+ weight_scale.unsqueeze(-1).to(weights.device),
+ logical_widths,
+ weights.dtype,
+ )
+
+ input_scale = None
+ if self.load_input_scale:
+ input_scale = weights.get_tensor(
+ f"{prefix}.input_scale", to_dtype=False
+ ).reshape(-1)
+
+ return Fp8Weight(
+ weight=w,
+ weight_scale=weight_scale,
+ input_scale=input_scale,
+ dtype=weights.dtype,
+ force_w8a16=self.force_w8a16,
+ )
+
+ def get_weights_col_packed(
+ self,
+ weights: Weights,
+ prefix: str,
+ block_sizes: Union[int, List[int]],
+ ):
+ w = weights.get_packed_sharded(
+ f"{prefix}.weight", dim=0, block_sizes=block_sizes
+ )
+
+ weight_scale = None
+ if self.load_weight_scale:
+ weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+ if weight_scale.numel() > 1:
+ weight_scale = weights.get_packed_sharded(
+ f"{prefix}.weight_scale",
+ dim=0,
+ block_sizes=block_sizes,
+ to_dtype=False,
+ )
+ weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
+ logical_widths = [w.shape[0]]
+ w, weight_scale = requantize_with_max_scale(
+ w,
+ weight_scale.unsqueeze(-1).to(weights.device),
+ logical_widths,
+ weights.dtype,
+ )
+
+ input_scale = None
+ if self.load_input_scale:
+ input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
+ if input_scale.numel() > 1:
+ input_scale = weights.get_packed_sharded(
+ f"{prefix}.input_scale",
+ dim=0,
+ block_sizes=block_sizes,
+ to_dtype=False,
+ )
+ input_scale = input_scale.reshape(-1).max()
+
+ return Fp8Weight(
+ weight=w,
+ weight_scale=weight_scale,
+ input_scale=input_scale,
+ dtype=weights.dtype,
+ force_w8a16=self.force_w8a16,
+ )
+
+ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
+ # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
+ w = [
+ weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
+ ]
+ shapes = [x.shape for x in w]
+
+ # Concat then send to the device
+ w = torch.cat(w, dim=dim).to(weights.device)
+
+ weight_scale = None
+ if self.load_weight_scale:
+ weight_scale = [
+ _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+ for p, shape in zip(prefixes, shapes)
+ ]
+ weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
+ logical_widths = [x[0] for x in shapes]
+ w, weight_scale = requantize_with_max_scale(
+ w,
+ weight_scale.unsqueeze(-1).to(weights.device),
+ logical_widths,
+ weights.dtype,
+ )
+
+ input_scale = None
+ if self.load_input_scale:
+ input_scale = [
+ _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
+ for p, shape in zip(prefixes, shapes)
+ if weights.has_tensor(f"{p}.input_scale")
+ ]
+ assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
+ input_scale = (
+ torch.cat(input_scale, dim=0).reshape(-1).max()
+ if len(input_scale) != 0
+ else None
+ )
+
+ return Fp8Weight(
+ weight=w,
+ weight_scale=weight_scale,
+ input_scale=input_scale,
+ dtype=weights.dtype,
+ force_w8a16=self.force_w8a16,
+ )
+
+ def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
+ # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
+ w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
+ shapes = [x.shape for x in w]
+
+ # Concat then send to the device
+ w = torch.cat(w, dim=dim).to(weights.device)
+
+ weight_scale = None
+
+ if self.load_weight_scale:
+ weight_scale = [
+ weights.get_tensor(f"{p}.weight_scale", to_dtype=False)
+ .reshape(-1)
+ .expand(shape[0])
+ for p, shape in zip(prefixes, shapes)
+ ]
+ weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
+ logical_widths = [x[0] for x in shapes]
+ w, weight_scale = requantize_with_max_scale(
+ w,
+ weight_scale.unsqueeze(-1).to(weights.device),
+ logical_widths,
+ weights.dtype,
+ )
+
+ input_scale = None
+ if self.load_input_scale:
+ input_scale = [
+ weights.get_tensor(f"{p}.input_scale", to_dtype=False)
+ .reshape(-1)
+ .expand(shape[0])
+ for p, shape in zip(prefixes, shapes)
+ if weights.has_tensor(f"{p}.input_scale")
+ ]
+ assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
+ input_scale = (
+ torch.cat(input_scale, dim=0).reshape(-1).max()
+ if len(input_scale) != 0
+ else None
+ )
+
+ return Fp8Weight(
+ weight=w,
+ weight_scale=weight_scale,
+ input_scale=input_scale,
+ dtype=weights.dtype,
+ force_w8a16=self.force_w8a16,
+ )
+
+ def get_weights_row(self, weights: "Weights", prefix: str):
+ w = weights.get_sharded(f"{prefix}.weight", dim=1)
+ weight_scale = None
+ if self.load_weight_scale:
+ weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+ weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
+ logical_widths = [w.shape[0]]
+ w, weight_scale = requantize_with_max_scale(
+ w,
+ weight_scale.unsqueeze(-1).to(weights.device),
+ logical_widths,
+ weights.dtype,
+ )
+
+ input_scale = None
+ if self.load_input_scale:
+ input_scale = weights.get_tensor(
+ f"{prefix}.input_scale", to_dtype=False
+ ).reshape(-1)
+
+ return Fp8Weight(
+ weight=w,
+ weight_scale=weight_scale,
+ input_scale=input_scale,
+ dtype=weights.dtype,
+ force_w8a16=self.force_w8a16,
+ )
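+
+ # Example: for a checkpoint with static weight scales and dynamic 8-bit activations,
+ # str() of this loader (see __str__ above) reads
+ #   "W8ANFpLoader (W8A8, weight: static, input: dynamic)"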
diff --git a/backends/gaudi/server/text_generation_server/layers/conv.py b/backends/gaudi/server/text_generation_server/layers/conv.py
new file mode 100644
index 00000000000..7fb18ab3f07
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/conv.py
@@ -0,0 +1,41 @@
+from accelerate import init_empty_weights
+import torch
+
+
+@classmethod
+def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
+ weight = weights.get_tensor(f"{prefix}.weight")
+ bias = weights.get_tensor(f"{prefix}.bias")
+ with init_empty_weights():
+ conv2d = cls(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ )
+
+ conv2d.weight = torch.nn.Parameter(weight)
+ conv2d.bias = torch.nn.Parameter(bias)
+ return conv2d
+
+
+@classmethod
+def load_conv2d_no_bias(
+ cls, prefix, weights, in_channels, out_channels, kernel_size, stride
+):
+ weight = weights.get_tensor(f"{prefix}.weight")
+ with init_empty_weights():
+ conv2d = cls(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ )
+
+ conv2d.weight = torch.nn.Parameter(weight)
+ conv2d.bias = None
+ return conv2d
+
+
+torch.nn.Conv2d.load = load_conv2d
+torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias
diff --git a/backends/gaudi/server/text_generation_server/layers/exl2.py b/backends/gaudi/server/text_generation_server/layers/exl2.py
new file mode 100644
index 00000000000..a6e07f45343
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/exl2.py
@@ -0,0 +1,78 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import torch
+from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
+
+
+@dataclass
+class Exl2Weight(Weight):
+ """
+ Exllama2 exl2 quantized weights.
+ """
+
+ q_weight: torch.Tensor
+ q_scale: torch.Tensor
+ q_invperm: torch.Tensor
+ q_scale_max: torch.Tensor
+ q_groups: torch.Tensor
+
+ def __post_init__(self):
+ self.q_scale_max /= 256
+ self.q_invperm = self.q_invperm.short()
+
+ @property
+ def device(self) -> torch.device:
+ return self.q_weight.device
+
+ def get_linear(self, bias: torch.Tensor):
+ from text_generation_server.layers.gptq import ExllamaQuantLinear
+
+ return ExllamaQuantLinear(self, bias)
+
+
+class Exl2WeightsLoader(WeightsLoader):
+ """Loader for exl2-quantized weights."""
+
+ def get_weights(self, weights: "Weights", prefix: str):
+ """
+ Get weights at the given prefix and apply without tensor parallelism.
+ """
+ try:
+ q_weight = weights.get_tensor(f"{prefix}.q_weight")
+ except RuntimeError:
+ raise RuntimeError(
+ "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
+ )
+
+ q_scale = weights.get_tensor(f"{prefix}.q_scale")
+ q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
+ q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
+ q_groups = weights.get_tensor(f"{prefix}.q_groups")
+
+ return Exl2Weight(
+ q_weight=q_weight,
+ q_scale=q_scale,
+ q_invperm=q_invperm,
+ q_scale_max=q_scale_max,
+ q_groups=q_groups,
+ )
+
+ def get_weights_col_packed(
+ self,
+ weights: Weights,
+ prefix: str,
+ block_sizes: Union[int, List[int]],
+ ):
+ raise RuntimeError("Column-packed weights are not supported for exl")
+
+ def get_weights_col(self, weights: Weights, prefix: str):
+ # Sharding is not yet supported, so we return the weights as-is.
+ return self.get_weights(weights, prefix)
+
+ def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+ raise ValueError("get_multi_weights_col is not supported for exl2")
+
+ def get_weights_row(self, weights: Weights, prefix: str):
+ # Sharding is not yet supported, so we return the weights as-is.
+ return self.get_weights(weights, prefix)
diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py
new file mode 100644
index 00000000000..8de335ac5e2
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/fp8.py
@@ -0,0 +1,655 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Type, Union, List
+
+import torch
+
+from text_generation_server.utils.weights import (
+ Weight,
+ WeightsLoader,
+ UnquantizedWeight,
+ Weights,
+)
+
+from vllm_hpu_extension.ops import scaled_fp8_quant
+from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
+
+quant_dtype: torch.dtype = torch.float8_e4m3fn
+FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
+if is_hpu_gaudi2():
+ FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max
+
+
+def pad_weight(weight, block_size):
+ """Pads a matrix to make its dimensions multiples of block_size."""
+ M, N = weight.shape[-2:]
+ block_size_m, block_size_n = block_size
+ pad_M = (block_size_m - M % block_size_m) % block_size_m
+ pad_N = (block_size_n - N % block_size_n) % block_size_n
+
+ if pad_M == 0 and pad_N == 0:
+ return weight, M, N # No padding needed
+ padded_weight = torch.nn.functional.pad(
+ weight, (0, pad_N, 0, pad_M), mode="constant", value=0
+ )
+ return padded_weight, M, N # Return original dimensions for unpadding
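+
+ # Example: with block_size = (128, 128), a (300, 500) weight is padded to (384, 512);
+ # the original (M, N) = (300, 500) is returned so the result can be unpadded later.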
+
+
+def unpad_weight(weight, original_M, original_N, keep_first_dim=False):
+ """Removes padding from the matrix to restore its original shape."""
+ if (weight.shape[-2] == original_M) and (weight.shape[-1] == original_N):
+ return weight
+ if keep_first_dim:
+ return weight[:, :original_M, :original_N]
+ else:
+ return weight[:original_M, :original_N]
+
+
+def pad_block_fp8_weight_naive(weight, weight_scale, block_size):
+
+ assert len(block_size) == 2
+
+ block_size_m, block_size_n = block_size
+ weight_scale_m, weight_scale_n = weight_scale.shape[-2:]
+
+ weight, orig_M, orig_N = pad_weight(weight, block_size)
+ M, N = weight.shape[-2:]
+
+ assert weight_scale_m == M // block_size_m
+ assert weight_scale_n == N // block_size_n
+
+ return weight, orig_M, orig_N
+
+
+def dynamic_quant(data, single_scale=False):
+ if single_scale:
+ scale = ((torch.abs(data)).max() + 1e-8) / FP8_MAX
+ else:
+ scale = ((torch.abs(data)).max(dim=-1).values + 1e-8) / FP8_MAX
+ scale = scale.unsqueeze(-1)
+ data_fp8 = torch.ops.hpu.cast_to_fp8_v2(
+ data, 1.0 / scale, False, False, torch.float8_e4m3fn
+ )[0]
+ return data_fp8, scale.float()
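+
+ # Shape sketch: for a (tokens, hidden) input, dynamic_quant returns an FP8 tensor of
+ # the same shape and a float32 scale of shape (tokens, 1); with single_scale=True the
+ # scale collapses to a single scalar for the whole tensor.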
+
+
+def dequant_block_fp8_weight_naive(
+ weight,
+ weight_scale,
+ block_size,
+ dtype=torch.bfloat16,
+ original_M=None,
+ original_N=None,
+ do_unpad=False,
+):
+ if weight_scale is None:
+ return weight
+ assert len(block_size) == 2
+
+ weight_shape_len = len(weight.shape)
+
+ block_size_m, block_size_n = block_size
+
+ # mul scale
+ if weight_shape_len == 2:
+ weight_scale_m, weight_scale_n = weight_scale.shape
+ weight_scale = weight_scale.view(weight_scale_m, 1, weight_scale_n, 1)
+ weight = weight.view(weight_scale_m, block_size_m, weight_scale_n, block_size_n)
+ if is_hpu_gaudi2():
+ fake_weight = weight.cpu().to(dtype).to(weight.device)
+ dequant_weight = fake_weight * weight_scale.to(dtype)
+ else:
+ dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
+ dequant_weight = dequant_weight.view(
+ weight_scale_m * block_size_m, weight_scale_n * block_size_n
+ )
+ keep_first_dim = False
+ elif weight_shape_len == 3:
+ fd, weight_scale_m, weight_scale_n = weight_scale.shape
+ weight_scale = weight_scale.view(fd, weight_scale_m, 1, weight_scale_n, 1)
+ weight = weight.view(
+ fd, weight_scale_m, block_size_m, weight_scale_n, block_size_n
+ )
+ if is_hpu_gaudi2():
+ fake_weight = weight.cpu().to(dtype).to(weight.device)
+ dequant_weight = fake_weight * weight_scale.to(dtype)
+ else:
+ dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
+ dequant_weight = dequant_weight.view(
+ fd, weight_scale_m * block_size_m, weight_scale_n * block_size_n
+ )
+ keep_first_dim = True
+ else:
+ raise ValueError("Only support original weight shape is either 2 or 3")
+
+ if do_unpad:
+ dequant_weight = unpad_weight(
+ dequant_weight, original_M, original_N, keep_first_dim=keep_first_dim
+ )
+
+ return dequant_weight
+
+
+def apply_block_fp8_linear_hpu_dynamic(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ weight_scale: torch.Tensor,
+ input_scale: Optional[torch.Tensor] = None,
+ bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ # View input as 2D matrix for fp8 methods
+ input_2d = input.view(-1, input.shape[-1])
+ output_shape = [*input.shape[:-1], weight.shape[0]]
+
+ x_fp8, x_scale = dynamic_quant(input_2d)
+
+ output = torch.ops.hpu.fp8_gemm_v2(
+ x_fp8,
+ False,
+ weight,
+ True,
+ None,
+ torch.bfloat16,
+ x_scale,
+ weight_scale,
+ None,
+ False,
+ )
+ if bias is not None:
+ output = output + bias
+ return output.to(dtype=input.dtype).view(*output_shape)
+
+
+def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
+ """
+ Return an FP8 linear `Module` that is compatible with the current system.
+ """
+ # On HPU we always use the custom Fp8Linear implementation defined in this module.
+ return Fp8Linear
+
+
+def normalize_e4m3fn_to_native_float8(
+ weight: torch.Tensor,
+ weight_scale: torch.Tensor,
+ input_scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ return weight, weight_scale, input_scale
+
+
+def per_tensor_dequantize(
+ tensor: torch.Tensor,
+ inv_scale: Union[float, torch.Tensor],
+ dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+ device = tensor.device
+ dtype = torch.bfloat16
+ if is_hpu_gaudi2():
+ # dequantize on CPU to avoid NaNs on Gaudi2
+ tensor = tensor.to("cpu")
+
+ fake_qweight = tensor.to(dtype).to(device)
+ dq_weight = fake_qweight * inv_scale
+ return dq_weight
+
+
+def requantize_with_max_scale(
+ weight: torch.Tensor,
+ weight_scale: torch.Tensor,
+ logical_widths: List[int],
+ dtype: torch.dtype,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Max scale to be used for requantization.
+ max_w_scale = weight_scale.max()
+
+ if is_hpu_gaudi2():
+ max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor()
+
+ start = 0
+ for idx, logical_width in enumerate(logical_widths):
+ end = start + logical_width
+ weight_dq = per_tensor_dequantize(
+ weight[start:end, :], weight_scale[start:end, :], dtype
+ )
+ weight[start:end, :], max_w_scale_normalized = fp8_quantize(
+ weight_dq, max_w_scale
+ )
+ start = end
+
+ return weight, max_w_scale_normalized
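+
+ # Rationale sketch: fused projections (e.g. q/k/v concatenated along dim 0) may ship
+ # one scale per logical shard; each slice is dequantized with its own scale and then
+ # requantized with the shared max scale, so a single scale can be fed to the FP8 GEMM.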
+
+
+def fp8_quantize(
+ weight: torch.Tensor,
+ scale: Optional[torch.Tensor] = None,
+ scale_upper_bound: Optional[torch.Tensor] = None,
+ qdtype: torch.dtype = torch.float8_e4m3fn,
+ scalar: bool = False,
+):
+ """
+ This function returns a reciprocal of the scale, so that a tensor can be unscaled
+ by multiplying it with the returned scale. If a scale is given through the `scale`
+ argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
+ be used without modification).
+ """
+ shape = weight.shape
+ qweight, scale = scaled_fp8_quant(
+ weight.reshape(-1, shape[-1]),
+ scale=scale,
+ scale_ub=scale_upper_bound,
+ # TODO: don't do this when we have to use the Torch kernel.
+ use_per_token_if_dynamic=not scalar,
+ )
+
+ return qweight.reshape(shape), scale
+
+
+class HybridFP8UnquantLoader(WeightsLoader):
+ """Weight loader that loads FP8 and unquantized Torch tensors."""
+
+ def __init__(
+ self,
+ activation_scale_ub: Optional[float],
+ to_fp8: bool,
+ weight_block_size: Optional[List[int]] = None,
+ ):
+ self.activation_scale_ub = activation_scale_ub
+ self.to_fp8 = to_fp8
+ self.weight_block_size = weight_block_size
+
+ def get_weights(self, weights: "Weights", prefix: str):
+ w = weights.get_tensor(f"{prefix}.weight")
+
+ if w.dtype == torch.float8_e4m3fn:
+ if self.weight_block_size is not None:
+ scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
+ return Fp8Weight(
+ weight=w,
+ weight_scale=scale,
+ activation_scale_ub=self.activation_scale_ub,
+ dtype=weights.dtype,
+ weight_block_size=self.weight_block_size,
+ )
+ # FP8 branch
+ scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+ scale = scale.reshape(-1).expand(w.shape[0])
+ logical_widths = [w.shape[0]]
+ w, scale = requantize_with_max_scale(
+ w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+ )
+
+ input_scale = None
+ if weights.has_tensor(f"{prefix}.input_scale"):
+ input_scale = (
+ weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
+ .reshape(-1)
+ .max()
+ )
+
+ return Fp8Weight(
+ weight=w,
+ weight_scale=scale,
+ input_scale=input_scale,
+ activation_scale_ub=self.activation_scale_ub,
+ dtype=weights.dtype,
+ )
+ if self.to_fp8:
+ return Fp8Weight(weight=w, dtype=weights.dtype)
+
+ return UnquantizedWeight(w)
+
+ def get_weights_col_packed(
+ self,
+ weights: Weights,
+ prefix: str,
+ block_sizes: Union[int, List[int]],
+ ):
+ w = weights.get_packed_sharded(
+ f"{prefix}.weight", dim=0, block_sizes=block_sizes
+ )
+
+ if w.dtype == torch.float8_e4m3fn:
+ # FP8 branch
+ scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+ if scale.numel() > 1:
+ scale = weights.get_packed_sharded(
+ f"{prefix}.weight_scale",
+ dim=0,
+ block_sizes=block_sizes,
+ to_dtype=False,
+ )
+ scale = scale.reshape(-1).expand(w.shape[0])
+ logical_widths = [w.shape[0]]
+ w, scale = requantize_with_max_scale(
+ w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+ )
+
+ input_scale = None
+ if weights.has_tensor(f"{prefix}.input_scale"):
+ input_scale = weights.get_tensor(
+ f"{prefix}.input_scale", to_dtype=False
+ )
+ if input_scale.numel() > 1:
+ input_scale = weights.get_packed_sharded(
+ f"{prefix}.input_scale",
+ dim=0,
+ block_sizes=block_sizes,
+ to_dtype=False,
+ )
+ input_scale = input_scale.reshape(-1).max()
+
+ return Fp8Weight(
+ weight=w,
+ weight_scale=scale,
+ input_scale=input_scale,
+ activation_scale_ub=self.activation_scale_ub,
+ dtype=weights.dtype,
+ )
+ if self.to_fp8:
+ return Fp8Weight(weight=w, dtype=weights.dtype)
+
+ return UnquantizedWeight(w)
+
+ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
+ # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
+ w = [
+ weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
+ ]
+ shapes = [x.shape for x in w]
+
+ # Concat then send to the device
+ w = torch.cat(w, dim=dim).to(weights.device)
+
+ # FP8 branch
+ if w.dtype == torch.float8_e4m3fn:
+ if self.weight_block_size is not None:
+ scale = [
+ weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
+ for p in prefixes
+ ]
+ scale = torch.cat(scale, dim=dim)
+ scale = scale.to(weights.device)
+ return Fp8Weight(
+ weight=w,
+ weight_scale=scale,
+ activation_scale_ub=self.activation_scale_ub,
+ dtype=weights.dtype,
+ weight_block_size=self.weight_block_size,
+ )
+
+ scale = [
+ _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+ for p, shape in zip(prefixes, shapes)
+ ]
+ scale = torch.cat(scale, dim=0).reshape(-1)
+
+ logical_widths = [x[0] for x in shapes]
+ w, scale = requantize_with_max_scale(
+ w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+ )
+
+ input_scale = [
+ _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
+ for p, shape in zip(prefixes, shapes)
+ if weights.has_tensor(f"{p}.input_scale")
+ ]
+ assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
+ input_scale = (
+ torch.cat(input_scale, dim=0).reshape(-1).max()
+ if len(input_scale) != 0
+ else None
+ )
+
+ return Fp8Weight(
+ weight=w,
+ weight_scale=scale,
+ input_scale=input_scale,
+ activation_scale_ub=self.activation_scale_ub,
+ dtype=weights.dtype,
+ )
+ if self.to_fp8:
+ return Fp8Weight(weight=w, dtype=weights.dtype)
+
+ return UnquantizedWeight(w)
+
+ def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
+ # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
+ w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
+ shapes = [x.shape for x in w]
+
+ # Concat then send to the device
+ w = torch.cat(w, dim=dim).to(weights.device)
+
+ # FP8 branch
+ if w.dtype == torch.float8_e4m3fn:
+ if self.weight_block_size is not None:
+ scale = [
+ weights.get_tensor(f"{p}.weight_scale_inv", to_device=False)
+ for p in prefixes
+ ]
+ scale = torch.cat(scale, dim=dim)
+ scale = scale.to(weights.device)
+ return Fp8Weight(
+ weight=w,
+ weight_scale=scale,
+ activation_scale_ub=self.activation_scale_ub,
+ dtype=weights.dtype,
+ weight_block_size=self.weight_block_size,
+ )
+
+ scale = [
+ weights.get_tensor(f"{p}.weight_scale", to_dtype=False)
+ .reshape(-1)
+ .expand(shape[0])
+ for p, shape in zip(prefixes, shapes)
+ ]
+ scale = torch.cat(scale, dim=0).reshape(-1)
+
+ logical_widths = [x[0] for x in shapes]
+ w, scale = requantize_with_max_scale(
+ w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+ )
+
+ input_scale = [
+ weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1)
+ for p in prefixes
+ if weights.has_tensor(f"{p}.input_scale")
+ ]
+ assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
+ input_scale = (
+ torch.cat(input_scale, dim=0).reshape(-1).max()
+ if len(input_scale) != 0
+ else None
+ )
+
+ return Fp8Weight(
+ weight=w,
+ weight_scale=scale,
+ input_scale=input_scale,
+ activation_scale_ub=self.activation_scale_ub,
+ dtype=weights.dtype,
+ )
+ if self.to_fp8:
+ return Fp8Weight(weight=w, dtype=weights.dtype)
+
+ return UnquantizedWeight(w)
+
+ def get_weights_row(self, weights: "Weights", prefix: str):
+ w = weights.get_sharded(f"{prefix}.weight", dim=1)
+ # FP8 branch
+ if w.dtype == torch.float8_e4m3fn:
+ if self.weight_block_size is not None:
+ # XXX: The tensor is named weight_scale_inv, but it appears to hold the scale itself.
+ scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
+
+ return Fp8Weight(
+ weight=w,
+ weight_scale=scale,
+ activation_scale_ub=self.activation_scale_ub,
+ dtype=weights.dtype,
+ weight_block_size=self.weight_block_size,
+ )
+
+ scale = (
+ weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+ .reshape(-1)
+ .expand(w.shape[0])
+ )
+ logical_widths = [w.shape[0]]
+ w, scale = requantize_with_max_scale(
+ w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+ )
+
+ input_scale = None
+ if weights.has_tensor(f"{prefix}.input_scale"):
+ input_scale = (
+ weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
+ .reshape(-1)
+ .max()
+ )
+
+ return Fp8Weight(
+ weight=w,
+ weight_scale=scale,
+ input_scale=input_scale,
+ activation_scale_ub=self.activation_scale_ub,
+ dtype=weights.dtype,
+ )
+ if self.to_fp8:
+ return Fp8Weight(weight=w, dtype=weights.dtype)
+
+ return UnquantizedWeight(w)
+
+
+@dataclass
+class Fp8Weight(Weight):
+ weight: torch.Tensor
+ dtype: torch.dtype
+ weight_scale: Optional[torch.Tensor] = None
+ input_scale: Optional[torch.Tensor] = None
+ activation_scale_ub: Optional[float] = None
+ force_w8a16: bool = False
+ weight_block_size: Optional[List[int]] = None
+
+ def get_linear(self, bias: torch.Tensor):
+ if self.weight_scale is None:
+ return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant(
+ self.weight, bias, self.dtype
+ )
+ # This is not checked by the fbgemm kernels, but they require contiguous
+ # memory. Can be non-contiguous when we e.g. expand from scalars.
+ self.weight_scale = self.weight_scale.contiguous()
+ return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8(
+ weight=self.weight,
+ scale=self.weight_scale,
+ dtype=self.dtype,
+ bias=bias,
+ input_scale=self.input_scale,
+ scale_upper_bound=self.activation_scale_ub,
+ weight_block_size=self.weight_block_size,
+ )
+
+
+class Fp8Linear(torch.nn.Module):
+ _device_identity_cache = {}
+
+ def __init__(
+ self,
+ qweight: torch.Tensor,
+ scale: torch.Tensor,
+ dtype: torch.dtype,
+ bias: Optional[torch.Tensor] = None,
+ input_scale: Optional[torch.Tensor] = None,
+ scale_upper_bound: Optional[float] = None,
+ weight_block_size: Optional[List[int]] = None,
+ ) -> None:
+ super().__init__()
+
+ self.dtype = dtype
+ self.qweight = qweight
+ self.scale = scale.float()
+ self.input_scale = input_scale.float() if input_scale is not None else None
+ self.weight_block_size = weight_block_size
+ self.scale_upper_bound = scale_upper_bound
+
+ self.bias = bias if bias is not None else None
+
+ @classmethod
+ def from_unquant(cls, weight, bias, dtype):
+ qweight, scale = fp8_quantize(weight, scalar=True)
+ return cls(
+ qweight=qweight,
+ scale=scale,
+ dtype=dtype,
+ bias=bias,
+ input_scale=None,
+ scale_upper_bound=None,
+ )
+
+ @classmethod
+ def from_fp8(
+ cls,
+ weight: torch.Tensor,
+ scale: torch.Tensor,
+ dtype: torch.dtype,
+ bias: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> "Fp8Linear":
+ input_scale = kwargs.get("input_scale", None)
+ scale_upper_bound = kwargs.get("scale_upper_bound", None)
+ weight_block_size = kwargs.get("weight_block_size", None)
+
+ if weight_block_size is not None:
+ weight, orig_M, orig_N = pad_block_fp8_weight_naive(
+ weight, scale, weight_block_size
+ )
+ weight, scale = dynamic_quant(
+ dequant_block_fp8_weight_naive(
+ weight,
+ scale,
+ weight_block_size,
+ original_M=orig_M,
+ original_N=orig_N,
+ do_unpad=True,
+ )
+ )
+ scale = scale.squeeze(-1)
+
+ return cls(
+ qweight=weight,
+ scale=scale,
+ input_scale=input_scale,
+ scale_upper_bound=scale_upper_bound,
+ bias=bias,
+ dtype=dtype,
+ weight_block_size=weight_block_size,
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ if self.weight_block_size is not None or self.input_scale is None:
+ return apply_block_fp8_linear_hpu_dynamic(
+ input, self.qweight, self.scale, self.input_scale, self.bias
+ )
+
+ x_fp8 = torch.ops.hpu.cast_to_fp8_v2(
+ input, 1.0 / self.input_scale, False, False, torch.float8_e4m3fn
+ )[0]
+ return torch.ops.hpu.fp8_gemm_v2(
+ A=x_fp8,
+ trans_A=False,
+ B=self.qweight,
+ trans_B=True,
+ D=None,
+ out_dtype=input.dtype,
+ A_scale_inv=self.input_scale,
+ B_scale_inv=self.scale,
+ bias=self.bias,
+ accumulate=False,
+ )
+
+
+def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
+ scale = weights.get_tensor(prefix, to_dtype=False)
+
+ if scale.numel() > 1:
+ scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+ return scale.reshape(-1).expand(shape[0])
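+
+ # Behaviour sketch: a scalar per-tensor scale is expanded to one entry per output row,
+ # while a per-row scale tensor is sharded along dim 0 like its weight; both cases end
+ # up as a 1-D tensor of length shape[0].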
diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py
new file mode 100644
index 00000000000..96b120b2417
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py
@@ -0,0 +1,438 @@
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import torch
+from loguru import logger
+from text_generation_server.utils.log import log_once
+from text_generation_server.utils.weights import (
+ Weight,
+ Weights,
+ WeightsLoader,
+ DefaultWeightsLoader,
+)
+
+
+from .hpu import QuantLinear
+
+
+@dataclass
+class GPTQWeight(Weight):
+ qweight: torch.Tensor
+ qzeros: torch.Tensor
+ scales: torch.Tensor
+ g_idx: Optional[torch.Tensor]
+ bits: int
+ groupsize: int
+ use_awq_kernel: bool
+ use_exllama: bool
+
+ def __post_init__(self):
+ if self.scales.dtype == torch.float:
+ self.scales = self.scales.half()
+
+ @property
+ def device(self) -> torch.device:
+ return self.qweight.device
+
+ def get_linear(self, bias: torch.Tensor):
+ if self.use_awq_kernel:
+ try:
+ from text_generation_server.layers.awq.quantize import WQLinear
+
+ return WQLinear(
+ w_bit=self.bits,
+ group_size=self.groupsize,
+ qweight=self.qweight,
+ qzeros=self.qzeros,
+ scales=self.scales,
+ bias=bias,
+ )
+ except ImportError:
+ raise NotImplementedError(
+ "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
+ )
+ else:
+ return QuantLinear(
+ self.qweight,
+ self.qzeros,
+ self.scales,
+ self.g_idx,
+ bias,
+ self.bits,
+ self.groupsize,
+ )
+
+
+class GPTQWeightsLoader(WeightsLoader):
+ """
+ Loader for GPTQ- and AWQ-quantized weights.
+ """
+
+ def __init__(
+ self,
+ *,
+ bits: int,
+ desc_act: bool,
+ groupsize: int,
+ quant_method: str,
+ quantize: str,
+ sym: bool,
+ modules_to_not_convert: List[str],
+ ):
+ self.bits = bits
+ self.desc_act = desc_act
+ self.groupsize = groupsize
+ self.quant_method = quant_method
+ self.quantize = quantize
+ self.sym = sym
+ self.modules_to_not_convert = modules_to_not_convert
+
+ def is_layer_skipped_quantization(
+ self, prefix: str, modules_to_not_convert: List[str]
+ ):
+ return any(module_name in prefix for module_name in modules_to_not_convert)
+
+ def get_weights(self, weights: Weights, prefix: str):
+ self._get_gptq_params(weights)
+
+ use_exllama = True
+ if self.bits != 4:
+ use_exllama = False
+
+ if self.desc_act:
+ log_once(logger.warning, "Disabling exllama because desc_act=True")
+ use_exllama = False
+
+ if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+ return DefaultWeightsLoader.get_weights(weights, prefix)
+
+ try:
+ qweight = weights.get_tensor(f"{prefix}.qweight")
+ except RuntimeError:
+ raise RuntimeError(
+ "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+ )
+
+ if self.quantize == "gptq" and self.quant_method == "gptq":
+ g_idx = weights.get_tensor(f"{prefix}.g_idx")
+ else:
+ g_idx = None
+
+ qzeros = weights.get_tensor(f"{prefix}.qzeros")
+ scales = weights.get_tensor(f"{prefix}.scales")
+
+ if use_exllama and g_idx is not None:
+ g_idx = g_idx - g_idx[0]
+
+ if self.quantize == "gptq" and self.quant_method == "awq":
+ log_once(
+ logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+ )
+ from text_generation_server.layers.awq.conversion_utils import (
+ fast_awq_to_gptq,
+ )
+
+ qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+ if use_exllama:
+ g_idx = None
+ else:
+ g_idx = (
+ torch.arange(
+ qweight.shape[0] * (32 // self.bits),
+ device=qweight.device,
+ )
+ // self.groupsize
+ ).to(dtype=torch.int32)
+
+ return GPTQWeight(
+ qweight=qweight,
+ qzeros=qzeros,
+ scales=scales,
+ g_idx=g_idx,
+ bits=self.bits,
+ groupsize=self.groupsize,
+ use_awq_kernel=self.quantize == "awq",
+ use_exllama=use_exllama,
+ )
+
+ def get_weights_col_packed(
+ self,
+ weights: Weights,
+ prefix: str,
+ block_sizes: Union[int, List[int]],
+ ):
+ if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+ return DefaultWeightsLoader.get_weights_col_packed(
+ weights, prefix, block_sizes
+ )
+ try:
+ qweight = weights.get_packed_sharded(
+ f"{prefix}.qweight", dim=1, block_sizes=block_sizes
+ )
+ except RuntimeError:
+ raise RuntimeError(
+ f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
+ )
+ scales = weights.get_packed_sharded(
+ f"{prefix}.scales", dim=1, block_sizes=block_sizes
+ )
+ scales = scales.to(dtype=weights.dtype)
+
+ self._get_gptq_params(weights)
+
+ qzeros = weights.get_packed_sharded(
+ f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
+ )
+ if self.quantize == "gptq" and self.quant_method == "gptq":
+ g_idx = weights.get_tensor(f"{prefix}.g_idx")
+ elif self.quantize == "gptq" and self.quant_method == "awq":
+ log_once(
+ logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+ )
+ from text_generation_server.layers.awq.conversion_utils import (
+ fast_awq_to_gptq,
+ )
+
+ qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+ g_idx = (
+ torch.arange(
+ qweight.shape[0] * (32 // self.bits),
+ device=qweight.device,
+ )
+ // self.groupsize
+ ).to(dtype=torch.int32)
+ else:
+ g_idx = None
+
+ return GPTQWeight(
+ qweight=qweight,
+ qzeros=qzeros,
+ scales=scales,
+ g_idx=g_idx,
+ bits=self.bits,
+ groupsize=self.groupsize,
+ use_awq_kernel=self.quantize == "awq",
+ use_exllama=False,
+ )
+
+ def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+ if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
+ return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
+ try:
+ qweight = torch.cat(
+ [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
+ )
+ except RuntimeError:
+ raise RuntimeError(
+ f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
+ )
+
+ scales = torch.cat(
+ [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
+ )
+
+ self._get_gptq_params(weights)
+
+ qzeros = torch.cat(
+ [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
+ )
+
+ use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
+
+ if self.quantize == "gptq" and self.quant_method == "gptq":
+ w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
+ for w2 in w[1:]:
+ torch.testing.assert_close(w2, w[0])
+ g_idx = w[0]
+ elif self.quantize == "gptq" and self.quant_method == "awq":
+ log_once(
+ logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+ )
+ from text_generation_server.layers.awq.conversion_utils import (
+ fast_awq_to_gptq,
+ )
+
+ qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+ if use_exllama:
+ g_idx = None
+ else:
+ g_idx = (
+ torch.arange(
+ qweight.shape[0] * (32 // self.bits),
+ device=qweight.device,
+ )
+ // self.groupsize
+ ).to(dtype=torch.int32)
+ else:
+ g_idx = None
+
+ return GPTQWeight(
+ qweight=qweight,
+ qzeros=qzeros,
+ scales=scales,
+ g_idx=g_idx,
+ bits=self.bits,
+ groupsize=self.groupsize,
+ use_awq_kernel=self.quantize == "awq",
+ use_exllama=use_exllama,
+ )
+
+ def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
+ if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
+ return DefaultWeightsLoader.get_multi_weights(weights, prefixes, dim)
+ try:
+ qweight = torch.cat(
+ [weights.get_tensor(f"{p}.qweight") for p in prefixes], dim=1
+ )
+ except RuntimeError:
+ raise RuntimeError(
+ f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
+ )
+
+ scales = torch.cat([weights.get_tensor(f"{p}.scales") for p in prefixes], dim=1)
+
+ self._get_gptq_params(weights)
+
+ qzeros = torch.cat([weights.get_tensor(f"{p}.qzeros") for p in prefixes], dim=1)
+
+ use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
+
+ if self.quantize == "gptq" and self.quant_method == "gptq":
+ w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
+ for w2 in w[1:]:
+ torch.testing.assert_close(w2, w[0])
+ g_idx = w[0]
+ elif self.quantize == "gptq" and self.quant_method == "awq":
+ log_once(
+ logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+ )
+ from text_generation_server.layers.awq.conversion_utils import (
+ fast_awq_to_gptq,
+ )
+
+ qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+ if use_exllama:
+ g_idx = None
+ else:
+                g_idx = (
+                    torch.arange(
+                        qweight.shape[0] * (32 // self.bits),
+                        device=qweight.device,
+                    )
+                    // self.groupsize
+                ).to(dtype=torch.int32)
+ else:
+ g_idx = None
+
+ return GPTQWeight(
+ qweight=qweight,
+ qzeros=qzeros,
+ scales=scales,
+ g_idx=g_idx,
+ bits=self.bits,
+ groupsize=self.groupsize,
+ use_awq_kernel=self.quantize == "awq",
+ use_exllama=use_exllama,
+ )
+
+ def get_weights_row(self, weights: Weights, prefix: str):
+ self._get_gptq_params(weights)
+
+ use_exllama = True
+ desc_act = self.desc_act
+ if self.bits != 4:
+ use_exllama = False
+
+ if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+ return DefaultWeightsLoader.get_weights_row(weights, prefix)
+
+ if self.desc_act:
+ log_once(logger.warning, "Disabling exllama because desc_act=True")
+ use_exllama = False
+
+ try:
+ qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
+ except RuntimeError:
+ raise RuntimeError(
+ "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+ )
+
+ if self.quantize == "gptq" and self.quant_method == "gptq":
+ g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
+ else:
+ g_idx = None
+
+ if weights.process_group.size() > 1:
+ if g_idx is not None:
+ if (
+ not torch.equal(
+ # Remove g_idx[0] to adapt the check with TP>1.
+ (g_idx - g_idx[0]).cpu(),
+ torch.tensor(
+ [i // self.groupsize for i in range(g_idx.shape[0])],
+ dtype=torch.int32,
+ ),
+ )
+ and not (g_idx == 0).all()
+ ):
+                    # The exllama implementation does not support row tensor parallelism with act-order, as
+                    # it would require reordering the input activations that are split across several GPUs.
+ use_exllama = False
+ desc_act = True
+
+ from text_generation_server.layers.gptq import (
+ GPTQWeight,
+ )
+
+ if not desc_act and self.groupsize != -1:
+ qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
+ scales = weights.get_sharded(f"{prefix}.scales", dim=0)
+ if g_idx is not None:
+                # qzeros and scales are sharded, so g_idx must be adjusted accordingly
+ g_idx = g_idx - g_idx[0]
+ else:
+ qzeros = weights.get_tensor(f"{prefix}.qzeros")
+ scales = weights.get_tensor(f"{prefix}.scales")
+
+ if self.quantize == "gptq" and self.quant_method == "awq":
+ log_once(
+ logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+ )
+ from text_generation_server.layers.awq.conversion_utils import (
+ fast_awq_to_gptq,
+ )
+
+ qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+ if use_exllama:
+ g_idx = None
+ else:
+ g_idx = (
+ torch.arange(
+ qweight.shape[0] * (32 // self.bits),
+ device=qweight.device,
+ )
+ // self.groupsize
+ ).to(dtype=torch.int32)
+
+ return GPTQWeight(
+ qweight=qweight,
+ qzeros=qzeros,
+ scales=scales,
+ g_idx=g_idx,
+ bits=self.bits,
+ groupsize=self.groupsize,
+ use_awq_kernel=self.quantize == "awq",
+ use_exllama=use_exllama,
+ )
+
+ def _get_gptq_params(self, weights: Weights):
+ if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
+ self.bits = weights.get_tensor("gptq_bits").item()
+ self.groupsize = weights.get_tensor("gptq_groupsize").item()
+ self.desc_act = False
+ # `server quantize` used asymmetric quantization unconditionally
+ # before the `gptq_sym` setting tensor was added.
+ self.sym = (
+ weights.get_tensor("gptq_sym").item()
+ if weights.has_tensor("gptq_sym")
+ else False
+ )
+ self.quant_method = "gptq"
diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py
new file mode 100644
index 00000000000..fa1d8a2e5b6
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py
@@ -0,0 +1,204 @@
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+
+try:
+
+ convert_from_uint4 = torch.ops.hpu.convert_from_uint4
+except Exception as e:
+ hpu_import_exception = e
+
+ def error_raiser_hpu(*args, **kwargs):
+ raise ValueError(
+ f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
+ )
+
+ convert_from_uint4 = error_raiser_hpu
+
+
+def pack_tensor(input, bits=4):
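+    # Pack groups of (32 // bits) values into int32 words along dim 1 (8 values per word for 4-bit).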
+ normal = input.to(torch.int32)
+ q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
+ i = 0
+ col = 0
+ while col < q.shape[1]:
+ for j in range(i, i + (32 // bits)):
+ q[:, col] |= normal[:, j] << (bits * (j - i))
+ i += 32 // bits
+ col += 1
+ q = q.to(torch.int32)
+ return q
+
+
+class QuantLinear(nn.Module):
+ def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
+ super().__init__()
+ self.register_buffer("qweight", qweight)
+ self.register_buffer("qzeros", qzeros)
+ self.register_buffer("scales", scales)
+ self.register_buffer("g_idx", g_idx)
+ if bias is not None:
+ self.register_buffer("bias", bias)
+ else:
+ self.bias = None
+ if bits not in [4]:
+ raise NotImplementedError("Only 4 bits are supported.")
+ self.bits = bits
+ self.maxq = 2**self.bits - 1
+ self.groupsize = groupsize
+
+ self.outfeatures = qweight.shape[1]
+ self.infeatures = qweight.shape[0] * 32 // bits
+ self.wf = torch.tensor(
+ list(range(0, 32, self.bits)), dtype=torch.int32
+ ).unsqueeze(0)
+ self._preprocessing()
+
+ def unpack_zeros_from_cuda_old_format(self):
+ zeros = torch.bitwise_right_shift(
+ torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
+ self.wf.unsqueeze(0),
+ ).to(torch.int16 if self.bits == 8 else torch.int8)
+
+ zeros = zeros + 1
+ zeros = torch.bitwise_and(zeros, (2**self.bits) - 1).to(
+ self.scales.dtype
+ ) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
+ zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2])
+ return zeros
+
+ def unpack_weight_from_cuda_old_format(self):
+ weight = torch.bitwise_right_shift(
+ torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
+ self.wf.unsqueeze(-1),
+ ).to(torch.int16 if self.bits == 8 else torch.int8)
+ weight = torch.bitwise_and(weight, (2**self.bits) - 1)
+ weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2]))
+ return weight
+
+ def _preprocessing(self):
+ orig_device = self.qweight.device
+ self.qweight = self.qweight.cpu()
+ weight = self.unpack_weight_from_cuda_old_format()
+ new_qweight = pack_tensor(weight)
+ self.qweight = new_qweight.to(orig_device)
+ # TODO: Support group indexing and remove the check
+ columns = self.qweight.shape[0]
+ g_idx_trivial = [i // self.groupsize for i in range(columns)]
+ g_idx_trivial = torch.tensor(
+ g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
+ )
+ sort_zeros = not (torch.equal(self.g_idx, g_idx_trivial))
+ self.qzeros = self.qzeros.cpu()
+ zeros = self.unpack_zeros_from_cuda_old_format()
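+        # Non-trivial g_idx (act-order): expand zeros/scales to one row per input feature so groupsize can be set to 1.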
+ if sort_zeros:
+ zeros_group_1 = torch.zeros(
+ (self.infeatures, self.outfeatures),
+ dtype=zeros.dtype,
+ device=zeros.device,
+ )
+ scales = self.scales.cpu()
+ scale_group_1 = torch.zeros(
+ (self.infeatures, self.outfeatures),
+ dtype=scales.dtype,
+ device=scales.device,
+ )
+ for i in range(self.infeatures):
+ zeros_group_1[i] = zeros[self.g_idx[i]]
+ scale_group_1[i] = self.scales[self.g_idx[i]]
+ self.qzeros = pack_tensor(zeros_group_1).to(orig_device)
+ self.scales = scale_group_1.to(orig_device)
+ self.groupsize = 1
+ self.g_idx = None
+ else:
+ new_qzeros = pack_tensor(zeros)
+ self.qzeros = new_qzeros.to(orig_device)
+
+ @classmethod
+ def new(cls, bits, groupsize, infeatures, outfeatures, bias):
+ if bits not in [4]:
+ raise NotImplementedError("Only 4 bits are supported.")
+
+ qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
+ qzeros = torch.zeros(
+ (math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
+ dtype=torch.int32,
+ )
+ scales = torch.zeros(
+ (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
+ )
+ g_idx = torch.tensor(
+ [i // groupsize for i in range(infeatures)], dtype=torch.int32
+ )
+ if bias:
+ bias = torch.zeros((outfeatures), dtype=torch.float16)
+ else:
+ bias = None
+ return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
+
+ def pack(self, linear, scales, zeros, g_idx=None):
+ self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
+
+ scales = scales.t().contiguous()
+ zeros = zeros.t().contiguous()
+ scale_zeros = zeros * scales
+ self.scales = scales.clone().half()
+ if linear.bias is not None:
+ self.bias = linear.bias.clone().half()
+
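+        # Re-quantize every input column to integers using its group's scale and zero point.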
+ intweight = []
+ for idx in range(self.infeatures):
+ intweight.append(
+ torch.round(
+ (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
+ / self.scales[self.g_idx[idx]]
+ ).to(torch.int)[:, None]
+ )
+ intweight = torch.cat(intweight, dim=1)
+ intweight = intweight.t().contiguous()
+ intweight = intweight.numpy().astype(np.uint32)
+ qweight = np.zeros(
+ (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
+ )
+ i = 0
+ row = 0
+ while row < qweight.shape[0]:
+ if self.bits in [4]:
+ for j in range(i, i + (32 // self.bits)):
+ qweight[row] |= intweight[j] << (self.bits * (j - i))
+ i += 32 // self.bits
+ row += 1
+ else:
+ raise NotImplementedError("Only 4 bits are supported.")
+
+ qweight = qweight.astype(np.int32)
+ self.qweight = torch.from_numpy(qweight)
+
+ zeros -= 1
+ zeros = zeros.numpy().astype(np.uint32)
+ qzeros = np.zeros(
+ (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
+ )
+ i = 0
+ col = 0
+ while col < qzeros.shape[1]:
+ if self.bits in [4]:
+ for j in range(i, i + (32 // self.bits)):
+ qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+ i += 32 // self.bits
+ col += 1
+ else:
+ raise NotImplementedError("Only 4 bits are supported.")
+
+ qzeros = qzeros.astype(np.int32)
+ self.qzeros = torch.from_numpy(qzeros)
+
+ def forward(self, x):
+ out_shape = x.shape[:-1] + (self.outfeatures,)
+ x = x.reshape(-1, x.shape[-1])
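+        # Dequantize the packed uint4 weights on HPU, then apply a regular matmul.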
+ weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
+ out = torch.matmul(x, weight)
+ out = out.reshape(out_shape)
+ out = out + self.bias if self.bias is not None else out
+ return out
diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py b/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py
new file mode 100644
index 00000000000..aa664ea607a
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py
@@ -0,0 +1,1026 @@
+import time
+import torch.nn as nn
+import math
+import json
+import os
+import torch
+import transformers
+
+from texttable import Texttable
+from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
+from huggingface_hub import HfApi
+from accelerate import init_empty_weights
+from text_generation_server.utils import initialize_torch_distributed, Weights
+from text_generation_server.utils.hub import weight_files
+from text_generation_server.layers.gptq import QuantLinear
+from loguru import logger
+from typing import Optional
+from text_generation_server.layers.gptq.utils import torch_snr_error
+
+from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
+
+DEV = torch.device("cuda:0")
+
+
+class Quantizer(nn.Module):
+ def __init__(self, shape=1):
+ super(Quantizer, self).__init__()
+ self.register_buffer("maxq", torch.tensor(0))
+ self.register_buffer("scale", torch.zeros(shape))
+ self.register_buffer("zero", torch.zeros(shape))
+
+ def configure(
+ self,
+ bits,
+ perchannel=False,
+ sym=True,
+ mse=False,
+ norm=2.4,
+ grid=100,
+ maxshrink=0.8,
+ trits=False,
+ ):
+ self.maxq = torch.tensor(2**bits - 1)
+ self.perchannel = perchannel
+ self.sym = sym
+ self.mse = mse
+ self.norm = norm
+ self.grid = grid
+ self.maxshrink = maxshrink
+ if trits:
+ self.maxq = torch.tensor(-1)
+ self.scale = torch.zeros_like(self.scale)
+
+ def _quantize(self, x, scale, zero, maxq):
+ if maxq < 0:
+ return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
+ q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
+ return scale * (q - zero)
+
+ def find_params(self, x, weight=False):
+ dev = x.device
+ self.maxq = self.maxq.to(dev)
+
+ shape = x.shape
+ if self.perchannel:
+ if weight:
+ x = x.flatten(1)
+ else:
+ if len(shape) == 4:
+ x = x.permute([1, 0, 2, 3])
+ x = x.flatten(1)
+ if len(shape) == 3:
+ x = x.reshape((-1, shape[-1])).t()
+ if len(shape) == 2:
+ x = x.t()
+ else:
+ x = x.flatten().unsqueeze(0)
+
+ tmp = torch.zeros(x.shape[0], device=dev)
+ xmin = torch.minimum(x.min(1)[0], tmp)
+ xmax = torch.maximum(x.max(1)[0], tmp)
+
+ if self.sym:
+ xmax = torch.maximum(torch.abs(xmin), xmax)
+ tmp = xmin < 0
+ if torch.any(tmp):
+ xmin[tmp] = -xmax[tmp]
+ tmp = (xmin == 0) & (xmax == 0)
+ xmin[tmp] = -1
+ xmax[tmp] = +1
+
+ if self.maxq < 0:
+ self.scale = xmax
+ self.zero = xmin
+ else:
+ self.scale = (xmax - xmin) / self.maxq
+ if self.sym:
+ self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
+ else:
+ self.zero = torch.round(-xmin / self.scale)
+
+ if self.mse:
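+            # Grid search over shrink factors p, keeping the scale/zero pair with the lowest quantization error.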
+ best = torch.full([x.shape[0]], float("inf"), device=dev)
+ for i in range(int(self.maxshrink * self.grid)):
+ p = 1 - i / self.grid
+ xmin1 = p * xmin
+ xmax1 = p * xmax
+ scale1 = (xmax1 - xmin1) / self.maxq
+ zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
+ q = self._quantize(
+ x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq
+ )
+ q -= x
+ q.abs_()
+ q.pow_(self.norm)
+ err = torch.sum(q, 1)
+ tmp = err < best
+ if torch.any(tmp):
+ best[tmp] = err[tmp]
+ self.scale[tmp] = scale1[tmp]
+ self.zero[tmp] = zero1[tmp]
+ if not self.perchannel:
+ if weight:
+ tmp = shape[0]
+ else:
+ tmp = shape[1] if len(shape) != 3 else shape[2]
+ self.scale = self.scale.repeat(tmp)
+ self.zero = self.zero.repeat(tmp)
+
+ if weight:
+ shape = [-1] + [1] * (len(shape) - 1)
+ self.scale = self.scale.reshape(shape)
+ self.zero = self.zero.reshape(shape)
+ return
+ if len(shape) == 4:
+ self.scale = self.scale.reshape((1, -1, 1, 1))
+ self.zero = self.zero.reshape((1, -1, 1, 1))
+ if len(shape) == 3:
+ self.scale = self.scale.reshape((1, 1, -1))
+ self.zero = self.zero.reshape((1, 1, -1))
+ if len(shape) == 2:
+ self.scale = self.scale.unsqueeze(0)
+ self.zero = self.zero.unsqueeze(0)
+
+ def quantize(self, x):
+ if self.ready():
+ return self._quantize(x, self.scale, self.zero, self.maxq)
+
+ return x
+
+ def enabled(self):
+ return self.maxq > 0
+
+ def ready(self):
+ return torch.all(self.scale != 0)
+
+
+class GPTQ:
+ def __init__(self, layer, observe=False):
+ self.layer = layer
+ self.dev = self.layer.weight.device
+ W = layer.weight.data.clone()
+ if isinstance(self.layer, nn.Conv2d):
+ W = W.flatten(1)
+ if isinstance(self.layer, transformers.Conv1D):
+ W = W.t()
+ self.rows = W.shape[0]
+ self.columns = W.shape[1]
+ self.H = torch.zeros((self.columns, self.columns), device=self.dev)
+ self.nsamples = 0
+ self.quantizer = Quantizer()
+ self.observe = observe
+
+ def add_batch(self, inp, out):
+        # Hessian H = 2 X X^T + λ I
+ if self.observe:
+ self.inp1 = inp
+ self.out1 = out
+ else:
+ self.inp1 = None
+ self.out1 = None
+
+ if len(inp.shape) == 2:
+ inp = inp.unsqueeze(0)
+ tmp = inp.shape[0]
+ if isinstance(self.layer, nn.Linear) or isinstance(
+ self.layer, transformers.Conv1D
+ ):
+ if len(inp.shape) == 3:
+ inp = inp.reshape((-1, inp.shape[-1]))
+ inp = inp.t()
+ if isinstance(self.layer, nn.Conv2d):
+ unfold = nn.Unfold(
+ self.layer.kernel_size,
+ dilation=self.layer.dilation,
+ padding=self.layer.padding,
+ stride=self.layer.stride,
+ )
+ inp = unfold(inp)
+ inp = inp.permute([1, 0, 2])
+ inp = inp.flatten(1)
+ self.H *= self.nsamples / (self.nsamples + tmp)
+ self.nsamples += tmp
+ # inp = inp.float()
+ inp = math.sqrt(2 / self.nsamples) * inp.float()
+ # self.H += 2 / self.nsamples * inp.matmul(inp.t())
+ self.H += inp.matmul(inp.t())
+
+ def print_loss(self, name, q_weight, weight_error, timecost):
+ table = Texttable()
+ length = 28
+ name = (
+ (name + " " * (length - len(name)))
+ if len(name) <= length
+ else name[:length]
+ )
+
+ table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"])
+
+ # assign weight
+ self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(
+ self.layer.weight.data.dtype
+ )
+
+ if self.inp1 is not None:
+ # quantize input to int8
+ quantizer = Quantizer()
+ quantizer.configure(8, perchannel=False, sym=True, mse=False)
+ quantizer.find_params(self.inp1)
+ q_in = quantizer.quantize(self.inp1).type(torch.float16)
+ q_out = self.layer(q_in)
+
+ # get kinds of SNR
+ q_SNR = torch_snr_error(q_out, self.out1).item()
+ fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()
+ else:
+ q_SNR = "-"
+ fp_SNR = "-"
+
+ table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
+ print(table.draw().split("\n")[-2])
+
+ def fasterquant(
+ self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name=""
+ ):
+ self.layer.to(self.dev)
+
+ W = self.layer.weight.data.clone()
+ if isinstance(self.layer, nn.Conv2d):
+ W = W.flatten(1)
+ if isinstance(self.layer, transformers.Conv1D):
+ W = W.t()
+ W = W.float()
+
+ tick = time.time()
+
+ if not self.quantizer.ready():
+ self.quantizer.find_params(W, weight=True)
+
+ H = self.H
+ if not self.observe:
+ del self.H
+ dead = torch.diag(H) == 0
+ H[dead, dead] = 1
+ W[:, dead] = 0
+
+ if act_order:
+ perm = torch.argsort(torch.diag(H), descending=True)
+ W = W[:, perm]
+ H = H[perm][:, perm]
+
+ Losses = torch.zeros_like(W)
+ Q = torch.zeros_like(W)
+
+ damp = percdamp * torch.mean(torch.diag(H))
+ diag = torch.arange(self.columns, device=self.dev)
+ H[diag, diag] += damp
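+        # Dampen the Hessian diagonal, then take the upper Cholesky factor of its inverse for the updates below.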
+ H = torch.linalg.cholesky(H)
+ H = torch.cholesky_inverse(H)
+ try:
+ H = torch.linalg.cholesky(H, upper=True)
+ except Exception:
+ # Addition because Falcon fails on h_to_4h
+ H = torch.linalg.cholesky(
+ H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True
+ )
+ Hinv = H
+
+ g_idx = []
+ scale = []
+ zero = []
+ now_idx = 1
+
+ for i1 in range(0, self.columns, blocksize):
+ i2 = min(i1 + blocksize, self.columns)
+ count = i2 - i1
+
+ W1 = W[:, i1:i2].clone()
+ Q1 = torch.zeros_like(W1)
+ Err1 = torch.zeros_like(W1)
+ Losses1 = torch.zeros_like(W1)
+ Hinv1 = Hinv[i1:i2, i1:i2]
+
+ for i in range(count):
+ w = W1[:, i]
+ d = Hinv1[i, i]
+
+ if groupsize != -1:
+ if (i1 + i) % groupsize == 0:
+ self.quantizer.find_params(
+ W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
+ )
+
+ if ((i1 + i) // groupsize) - now_idx == -1:
+ scale.append(self.quantizer.scale)
+ zero.append(self.quantizer.zero)
+ now_idx += 1
+
+ q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
+ Q1[:, i] = q
+ Losses1[:, i] = (w - q) ** 2 / d**2
+
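+                # Error feedback: spread this column's quantization error over the remaining columns.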
+ err1 = (w - q) / d
+ W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
+ Err1[:, i] = err1
+
+ Q[:, i1:i2] = Q1
+ Losses[:, i1:i2] = Losses1 / 2
+
+ W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+
+ torch.cuda.synchronize()
+ error = torch.sum(Losses).item()
+
+ groupsize = groupsize if groupsize != -1 else self.columns
+ g_idx = [i // groupsize for i in range(self.columns)]
+ g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
+ if act_order:
+ invperm = torch.argsort(perm)
+ Q = Q[:, invperm]
+ g_idx = g_idx[invperm]
+
+ if isinstance(self.layer, transformers.Conv1D):
+ Q = Q.t()
+
+ self.print_loss(
+ name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)
+ )
+
+ if scale == []:
+ scale.append(self.quantizer.scale)
+ zero.append(self.quantizer.zero)
+ scale = torch.cat(scale, dim=1)
+ zero = torch.cat(zero, dim=1)
+ return scale, zero, g_idx, error
+
+ def free(self):
+ self.inp1 = None
+ self.out1 = None
+ self.H = None
+ self.Losses = None
+ self.Trace = None
+ torch.cuda.empty_cache()
+
+
+def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code):
+ from datasets import load_dataset
+
+ traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
+ testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=False, trust_remote_code=trust_remote_code
+ )
+ except Exception:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=True, trust_remote_code=trust_remote_code
+ )
+
+ trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
+ testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
+
+ import random
+
+ random.seed(seed)
+ trainloader = []
+ for _ in range(nsamples):
+ i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ inp = trainenc.input_ids[:, i:j]
+ tar = inp.clone()
+ tar[:, :-1] = -100
+ trainloader.append((inp, tar))
+ return trainloader, testenc
+
+
+def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code):
+ from datasets import load_dataset
+
+ traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
+ valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation")
+
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=False, trust_remote_code=trust_remote_code
+ )
+ except Exception:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=True, trust_remote_code=trust_remote_code
+ )
+
+ trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt")
+ testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")
+
+ import random
+
+ random.seed(seed)
+ trainloader = []
+ for _ in range(nsamples):
+ i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ inp = trainenc.input_ids[:, i:j]
+ tar = inp.clone()
+ tar[:, :-1] = -100
+ trainloader.append((inp, tar))
+ return trainloader, testenc
+
+
+def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code):
+ from datasets import load_dataset
+
+ traindata = load_dataset(
+ "allenai/c4",
+ "allenai--c4",
+ data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
+ split="train",
+ use_auth_token=False,
+ )
+ valdata = load_dataset(
+ "allenai/c4",
+ "allenai--c4",
+ data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
+ split="validation",
+ use_auth_token=False,
+ )
+
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=False, trust_remote_code=trust_remote_code
+ )
+ except Exception:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=True, trust_remote_code=trust_remote_code
+ )
+
+ import random
+
+ random.seed(seed)
+ trainloader = []
+ for _ in range(nsamples):
+ while True:
+ i = random.randint(0, len(traindata) - 1)
+ trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
+ if trainenc.input_ids.shape[1] >= seqlen:
+ break
+ i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ inp = trainenc.input_ids[:, i:j]
+ tar = inp.clone()
+ tar[:, :-1] = -100
+ trainloader.append((inp, tar))
+
+ import random
+
+ random.seed(0)
+ valenc = []
+ for _ in range(256):
+ while True:
+ i = random.randint(0, len(valdata) - 1)
+ tmp = tokenizer(valdata[i]["text"], return_tensors="pt")
+ if tmp.input_ids.shape[1] >= seqlen:
+ break
+ i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ valenc.append(tmp.input_ids[:, i:j])
+ valenc = torch.hstack(valenc)
+
+ class TokenizerWrapper:
+ def __init__(self, input_ids):
+ self.input_ids = input_ids
+
+ valenc = TokenizerWrapper(valenc)
+
+ return trainloader, valenc
+
+
+def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code):
+ from datasets import load_dataset
+
+ traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
+ testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
+
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=False, trust_remote_code=trust_remote_code
+ )
+ except Exception:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=True, trust_remote_code=trust_remote_code
+ )
+
+ trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt")
+ testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt")
+
+ import random
+
+ random.seed(seed)
+ trainloader = []
+ for _ in range(nsamples):
+ i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ inp = trainenc.input_ids[:, i:j]
+ tar = inp.clone()
+ tar[:, :-1] = -100
+ trainloader.append((inp, tar))
+ return trainloader, testenc
+
+
+def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
+ from datasets import load_dataset
+
+ traindata = load_dataset(
+ "allenai/c4",
+ "allenai--c4",
+ data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
+ split="train",
+ )
+ valdata = load_dataset(
+ "allenai/c4",
+ "allenai--c4",
+ data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
+ split="validation",
+ )
+
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=False, trust_remote_code=trust_remote_code
+ )
+ except Exception:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=True, trust_remote_code=trust_remote_code
+ )
+
+ import random
+
+ random.seed(seed)
+ trainloader = []
+ for _ in range(nsamples):
+ while True:
+ i = random.randint(0, len(traindata) - 1)
+ trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
+ if trainenc.input_ids.shape[1] >= seqlen:
+ break
+ i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ inp = trainenc.input_ids[:, i:j]
+ tar = inp.clone()
+ tar[:, :-1] = -100
+ trainloader.append((inp, tar))
+
+ valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt")
+ valenc = valenc.input_ids[:, : (256 * seqlen)]
+
+ class TokenizerWrapper:
+ def __init__(self, input_ids):
+ self.input_ids = input_ids
+
+ valenc = TokenizerWrapper(valenc)
+
+ return trainloader, valenc
+
+
+def get_loaders(
+ name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False
+):
+ if "wikitext2" in name:
+ return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code)
+ if "ptb" in name:
+ if "new" in name:
+ return get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code)
+ return get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code)
+ if "c4" in name:
+ if "new" in name:
+ return get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code)
+ return get_c4(nsamples, seed, seqlen, model_id, trust_remote_code)
+
+
+def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""):
+    # Skip the final lm_head linear.
+    # isinstance is needed because Falcon's layers inherit from Linear.
+ if isinstance(module, layers) and "lm_head" not in name:
+ return {name: module}
+ res = {}
+ for name1, child in module.named_children():
+ res.update(
+ find_layers(
+ child, layers=layers, name=name + "." + name1 if name != "" else name1
+ )
+ )
+ return res
+
+
+@torch.no_grad()
+def sequential(
+ model,
+ dataloader,
+ dev,
+ nsamples,
+ bits,
+ groupsize,
+ *,
+ hooks,
+ percdamp=0.01,
+ sym: bool = False,
+ act_order: bool = False,
+):
+ print("Starting ...")
+
+ use_cache = model.config.use_cache
+ model.config.use_cache = False
+ try:
+ layers = model.model.layers
+ prefix = "model.layers"
+ except Exception:
+ layers = model.transformer.h
+ prefix = "transformer.h"
+
+ dtype = next(iter(model.parameters())).dtype
+ inps = torch.zeros(
+ (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
+ )
+
+ cache = {"i": 0}
+ extra = {}
+
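+    # Catcher records the inputs (and kwargs) reaching the first layer, then aborts the forward pass.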
+ class Catcher(nn.Module):
+ def __init__(self, module):
+ super().__init__()
+ self.module = module
+
+ def forward(self, inp, **kwargs):
+ inps[cache["i"]] = inp
+ cache["i"] += 1
+ extra.update(kwargs.copy())
+ raise ValueError
+
+ layers[0] = Catcher(layers[0])
+ for batch in dataloader:
+ try:
+ model(batch[0].cuda())
+ except ValueError:
+ pass
+ layers[0] = layers[0].module
+
+ # layers[0] = layers[0].cpu()
+ # model.model.embed_tokens = model.model.embed_tokens.cpu()
+ # model.model.norm = model.model.norm.cpu()
+ torch.cuda.empty_cache()
+ for hook in hooks:
+ hook.remove()
+
+ outs = torch.zeros_like(inps)
+
+ extra = {
+ k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items()
+ }
+
+ print("Ready.")
+
+ quantizers = {}
+ for i in range(len(layers)):
+ print(f"Quantizing layer {i+1}/{len(layers)}..")
+ print("+------------------+--------------+------------+-----------+-------+")
+ print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
+ print("+==================+==============+============+===========+=======+")
+
+ layer = layers[i]
+ layer.load()
+ full = find_layers(layer)
+ sequential = [list(full.keys())]
+
+ for names in sequential:
+ subset = {n: full[n] for n in names}
+ gptq = {}
+ for name in subset:
+ gptq[name] = GPTQ(subset[name])
+ gptq[name].quantizer.configure(
+ bits, perchannel=True, sym=sym, mse=False
+ )
+
+ def add_batch(name):
+ nonlocal gptq
+
+ def tmp(_, inp, out):
+ gptq[name].add_batch(inp[0].data, out.data)
+
+ return tmp
+
+ handles = []
+ for name in subset:
+ handles.append(subset[name].register_forward_hook(add_batch(name)))
+ for j in range(nsamples):
+ outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
+ for h in handles:
+ h.remove()
+
+ for name in subset:
+ scale, zero, g_idx, error = gptq[name].fasterquant(
+ percdamp=percdamp,
+ groupsize=groupsize,
+ act_order=act_order,
+ name=name,
+ )
+ quantizers[f"{prefix}.{i}.{name}"] = (
+ gptq[name].quantizer.cpu(),
+ scale.cpu(),
+ zero.cpu(),
+ g_idx.cpu(),
+ bits,
+ groupsize,
+ )
+
+ gptq[name].free()
+
+ for j in range(nsamples):
+ outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
+
+ layer.unload()
+ del layer
+ del gptq
+ torch.cuda.empty_cache()
+
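+        # The outputs of the quantized layer become the inputs of the next layer.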
+ inps, outs = outs, inps
+ print("+------------------+--------------+------------+-----------+-------+")
+ print("\n")
+
+ model.config.use_cache = use_cache
+
+ return quantizers
+
+
+def make_quant_linear(module, names, bits, groupsize, name=""):
+ if isinstance(module, QuantLinear):
+ return
+ for attr in dir(module):
+ tmp = getattr(module, attr)
+ name1 = name + "." + attr if name != "" else attr
+ if name1 in names:
+ delattr(module, attr)
+ setattr(
+ module,
+ attr,
+ QuantLinear.new(
+ bits,
+ groupsize,
+ tmp.in_features,
+ tmp.out_features,
+ tmp.bias is not None,
+ ),
+ )
+ for name1, child in module.named_children():
+ make_quant_linear(
+ child, names, bits, groupsize, name + "." + name1 if name != "" else name1
+ )
+
+
+# TODO: perform packing on GPU
+def pack(model, quantizers, bits, groupsize):
+ layers = find_layers(model)
+ layers = {n: layers[n] for n in quantizers}
+ make_quant_linear(model, quantizers, bits, groupsize)
+ qlayers = find_layers(model, (QuantLinear,))
+ print("Packing ...")
+ for name in qlayers:
+ print(name)
+ quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]
+ qlayers[name].pack(layers[name], scale, zero, g_idx)
+ print("Done.")
+ return model
+
+
+def setdeepattr(module, full_name, tensor):
+ current = module
+ tokens = full_name.split(".")
+ for token in tokens[:-1]:
+ current = getattr(current, token)
+ setattr(current, tokens[-1], tensor)
+
+
+def getdeepattr(module, full_name):
+ current = module
+ tokens = full_name.split(".")
+ for token in tokens:
+ current = getattr(current, token)
+ return current
+
+
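+# Forward pre-hook: load meta tensors from the checkpoint and move the module's tensors to cuda:0 right before it runs.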
+def load_weights_pre_hook(module_name, weights, recursive=False):
+ def inner(module, args):
+ print(f"Pre hook {module_name}")
+ local_params = {}
+ for k, v in module.named_parameters():
+ if not recursive and k.count(".") != 1:
+ continue
+ local_params[k] = v
+ for k, v in module.named_buffers():
+ if not recursive and k.count(".") != 1:
+ continue
+ local_params[k] = v
+
+ for local_param in local_params:
+ current_tensor = getdeepattr(module, local_param)
+ if current_tensor.device == torch.device("meta"):
+ # print(f"Loading {local_param}")
+ if module_name:
+ tensor_name = f"{module_name}.{local_param}"
+ else:
+ tensor_name = local_param
+ tensor = weights.get_tensor(tensor_name)
+ setdeepattr(module, local_param, nn.Parameter(tensor))
+ else:
+ tensor = current_tensor.to(device=torch.device("cuda:0"))
+ if current_tensor.requires_grad:
+ tensor = nn.Parameter(tensor)
+ setdeepattr(module, local_param, tensor)
+
+ return inner
+
+
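+# Forward hook: move this module's tensors back to the CPU once it has run, keeping GPU memory bounded.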
+def load_weights_post_hook(module_name, weights, recursive=False):
+ def inner(module, args, output):
+ print(f"Post hook {module_name}")
+ local_params = {}
+ for k, v in module.named_parameters():
+ if not recursive and k.count(".") != 1:
+ continue
+ local_params[k] = v
+ for k, v in module.named_buffers():
+ if not recursive and k.count(".") != 1:
+ continue
+ local_params[k] = v
+ for local_param in local_params:
+ # print(f"Unloading {local_param}")
+ current_tensor = getdeepattr(module, local_param)
+ setdeepattr(
+ module,
+ local_param,
+ nn.Parameter(current_tensor.to(device=torch.device("cpu"))),
+ )
+ return output
+
+ return inner
+
+
+def quantize(
+ model_id: str,
+ bits: int,
+ groupsize: int,
+ output_dir: str,
+ revision: str,
+ trust_remote_code: bool,
+ upload_to_model_id: Optional[str],
+ percdamp: float,
+ act_order: bool,
+ sym: bool,
+):
+ print("loading model")
+ config = AutoConfig.from_pretrained(
+ model_id,
+ trust_remote_code=trust_remote_code,
+ )
+
+ with init_empty_weights():
+ model = AutoModelForCausalLM.from_config(
+ config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
+ )
+ model = model.eval()
+
+ print("LOADED model")
+ files = weight_files(model_id, revision, extension=".safetensors")
+ process_group, _, _ = initialize_torch_distributed()
+ weights = Weights(
+ files,
+ device=torch.device("cuda:0"),
+ dtype=torch.float16,
+ process_group=process_group,
+ aliases={"embed_tokens.weight": ["lm_head.weight"]},
+ weights_loader=DefaultWeightsLoader(UnquantizedWeight),
+ )
+ hooks = []
+ for name, module in model.named_modules():
+
+ def load(module, name):
+ def _load():
+ load_weights_pre_hook(name, weights, recursive=True)(module, None)
+
+ return _load
+
+ def unload(module, name):
+ def _unload():
+ load_weights_post_hook(name, weights, recursive=True)(
+ module, None, None
+ )
+
+ return _unload
+
+ module.load = load(module, name)
+ module.unload = unload(module, name)
+ hooks.append(
+ module.register_forward_pre_hook(load_weights_pre_hook(name, weights))
+ )
+ hooks.append(
+ module.register_forward_hook(load_weights_post_hook(name, weights))
+ )
+ model.seqlen = 2048
+
+ dataset = "wikitext2"
+ nsamples = 128
+ seed = None
+
+ dataloader, testloader = get_loaders(
+ dataset,
+ nsamples=nsamples,
+ seed=seed,
+ model_id=model_id,
+ seqlen=model.seqlen,
+ trust_remote_code=trust_remote_code,
+ )
+
+ tick = time.time()
+ quantizers = sequential(
+ model,
+ dataloader,
+ DEV,
+ nsamples,
+ bits,
+ groupsize,
+ percdamp=percdamp,
+ act_order=act_order,
+ hooks=hooks,
+ sym=sym,
+ )
+ print(time.time() - tick)
+
+ pack(model, quantizers, bits, groupsize)
+ from safetensors.torch import save_file
+ from huggingface_hub import split_torch_state_dict_into_shards
+
+ state_dict = model.state_dict()
+ state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
+
+ max_shard_size = "10GB"
+ state_dict_split = split_torch_state_dict_into_shards(
+ state_dict,
+ filename_pattern="model.safetensors",
+ max_shard_size=max_shard_size,
+ )
+ index = None
+ if state_dict_split.is_sharded:
+ index = {
+ "metadata": state_dict_split.metadata,
+ "weight_map": state_dict_split.tensor_to_filename,
+ }
+ shards = state_dict_split.filename_to_tensors
+ os.makedirs(output_dir, exist_ok=True)
+    for shard_file, tensor_names in shards.items():
+        # `filename_to_tensors` maps each shard file to the names of the tensors it contains.
+        shard = {name: state_dict[name] for name in tensor_names}
+        save_file(
+            shard,
+ os.path.join(output_dir, shard_file),
+ metadata={
+ "format": "pt",
+ "quantized": "gptq",
+ "origin": "text-generation-inference",
+ },
+ )
+ if index is None:
+ path_to_weights = os.path.join(output_dir, "model.safetensors")
+ logger.info(f"Model weights saved in {path_to_weights}")
+ else:
+ save_index_file = "model.safetensors.index.json"
+ save_index_file = os.path.join(output_dir, save_index_file)
+ with open(save_index_file, "w", encoding="utf-8") as f:
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+ f.write(content)
+ logger.info(
+ f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+            f"split into {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
+ f"index located at {save_index_file}."
+ )
+ config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
+ config.quantization_config = {
+ "bits": bits,
+ "group_size": groupsize,
+ "damp_percent": percdamp,
+ "desc_act": act_order,
+ "static_groups": False,
+ "sym": sym,
+ "quant_method": "gptq",
+ }
+ config.save_pretrained(output_dir)
+ logger.info("Saved config")
+ logger.info("Saving tokenizer")
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, trust_remote_code=trust_remote_code
+ )
+ tokenizer.save_pretrained(output_dir)
+ logger.info("Saved tokenizer")
+
+ if upload_to_model_id:
+ api = HfApi()
+
+ api.upload_folder(
+ folder_path=output_dir, repo_id=upload_to_model_id, repo_type="model"
+ )
diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/utils.py b/backends/gaudi/server/text_generation_server/layers/gptq/utils.py
new file mode 100644
index 00000000000..cbc0f391faf
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/gptq/utils.py
@@ -0,0 +1,56 @@
+import torch
+
+
+# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py
+def torch_snr_error(
+ y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = "mean"
+) -> torch.Tensor:
+    """
+    Compute the SNR error between y_pred (tensor) and y_real (tensor).
+
+    SNR is calculated with the following equation:
+
+    SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2
+
+    If pred and real are matrices, the SNR error over the matrix is the mean of the SNR error over all elements:
+
+    SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)
+
+    Args:
+        y_pred (torch.Tensor): predicted (e.g. quantized) tensor.
+        y_real (torch.Tensor): reference tensor.
+        reduction (str, optional): 'mean', 'sum' or 'none'. Defaults to 'mean'.
+
+    Raises:
+        ValueError: if the two tensors have different shapes.
+        ValueError: if the reduction method is not supported.
+
+    Returns:
+        torch.Tensor: the (reduced) SNR error.
+    """
+ if y_pred.shape != y_real.shape:
+ raise ValueError(
+            "Cannot compute the SNR loss for tensors with different shapes "
+            f"({y_pred.shape} and {y_real.shape})."
+ )
+ reduction = str(reduction).lower()
+
+ if y_pred.ndim == 1:
+ y_pred = y_pred.unsqueeze(0)
+ y_real = y_real.unsqueeze(0)
+
+ y_pred = y_pred.flatten(start_dim=1)
+ y_real = y_real.flatten(start_dim=1)
+
+ noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
+ signal_power = torch.pow(y_real, 2).sum(dim=-1)
+ snr = (noise_power) / (signal_power + 1e-7)
+
+ if reduction == "mean":
+ return torch.mean(snr)
+ elif reduction == "sum":
+ return torch.sum(snr)
+ elif reduction == "none":
+ return snr
+ else:
+ raise ValueError("Unsupported reduction method.")
diff --git a/backends/gaudi/server/text_generation_server/layers/layernorm.py b/backends/gaudi/server/text_generation_server/layers/layernorm.py
new file mode 100644
index 00000000000..4bbb6c1fe88
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/layernorm.py
@@ -0,0 +1,62 @@
+import torch
+from torch import nn
+from accelerate import init_empty_weights
+
+
+# Monkey patching
+@classmethod
+def load_layer_norm(cls, prefix, weights, eps):
+ weight = weights.get_tensor(f"{prefix}.weight")
+ bias = weights.get_tensor(f"{prefix}.bias")
+ with init_empty_weights():
+ ln = cls(weight.shape, eps=eps)
+
+ ln.weight = torch.nn.Parameter(weight)
+ ln.bias = torch.nn.Parameter(bias)
+ return ln
+
+
+@classmethod
+def load_layer_norm_no_bias(cls, prefix, weights, eps):
+ weight = weights.get_tensor(f"{prefix}.weight")
+ with init_empty_weights():
+ ln = cls(weight.shape, eps=eps)
+
+ ln.weight = torch.nn.Parameter(weight)
+ ln.bias = None
+ return ln
+
+
+torch.nn.LayerNorm.load = load_layer_norm
+torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
+
+
+class FastLayerNorm(nn.LayerNorm):
+ def forward(self, hidden_states, residual=None):
+ if residual is not None:
+ hidden_states += residual
+ residual = hidden_states
+
+ return super().forward(hidden_states), residual
+
+
+class FastRMSNorm(nn.Module):
+ def __init__(self, weight: torch.Tensor, eps: float):
+ super().__init__()
+
+ self.weight = nn.Parameter(weight)
+ self.variance_epsilon = eps
+
+ @classmethod
+ def load(cls, prefix, weights, eps=1e-6):
+ weight = weights.get_tensor(f"{prefix}.weight")
+ return cls(weight, eps)
+
+ def forward(self, hidden_states, residual=None):
+ if residual is not None:
+ hidden_states += residual
+ residual = hidden_states
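+        # Compute the RMS statistics in float32 for stability, then rescale by the learned weight.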
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(self.weight.dtype), residual
diff --git a/backends/gaudi/server/text_generation_server/layers/linear.py b/backends/gaudi/server/text_generation_server/layers/linear.py
new file mode 100644
index 00000000000..cca80c44ede
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/linear.py
@@ -0,0 +1,38 @@
+import torch
+from torch.nn import functional as F
+
+
+class FastLinear(torch.nn.Module):
+ def __init__(
+ self,
+ weight,
+ bias,
+ ) -> None:
+ super().__init__()
+ self.weight = torch.nn.Parameter(weight, requires_grad=False)
+ if bias is not None:
+ self.bias = torch.nn.Parameter(bias, requires_grad=False)
+ else:
+ self.bias = None
+
+ @classmethod
+ def load(cls, config, prefix: str, weights, bias: bool):
+ weight = weights.get_tensor(f"{prefix}.weight")
+ if bias:
+ bias = weights.get_tensor(f"{prefix}.bias")
+ else:
+ bias = None
+ return cls(weight, bias)
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.linear(input, self.weight, self.bias)
+
+
+def get_linear(weight, bias):
+ # Weights that are loaded through methods that are not
+ # quantization-aware are still bare tensors. We may want
+ # to change this in the future.
+ if isinstance(weight, torch.Tensor):
+ return FastLinear(weight, bias)
+
+ return weight.get_linear(bias)
diff --git a/backends/gaudi/server/text_generation_server/layers/lora.py b/backends/gaudi/server/text_generation_server/layers/lora.py
new file mode 100644
index 00000000000..a4537b55bbf
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/lora.py
@@ -0,0 +1,279 @@
+from typing import TYPE_CHECKING, Optional, List
+
+import torch
+import torch.distributed
+from torch import nn
+from torch.distributed import ProcessGroup
+
+from text_generation_server.utils.sgmv import (
+ add_lora_a_bgmv,
+ add_lora_b_bgmv,
+ has_sgmv,
+ lora_a_sgmv_cutlass,
+ lora_b_sgmv_cutlass,
+ orient_for_rank,
+)
+
+if TYPE_CHECKING:
+ from text_generation_server.adapters import AdapterBatchData
+ from text_generation_server.adapters.lora import BatchLoraWeights
+
+
+class LoraLinear(nn.Module):
+ def __init__(
+ self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
+ ):
+ super().__init__()
+ self.base_layer = base_layer
+ self.layer_id = layer_id
+ self.process_group = process_group
+
+ def forward_layer_type(
+ self,
+ result: torch.Tensor,
+ input: torch.Tensor,
+ adapter_data: "AdapterBatchData",
+ layer_type: str,
+ start_idx: int,
+ end_idx: int,
+ ) -> torch.Tensor:
+ if adapter_data is None:
+ return result
+ data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
+
+ if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
+ # In tensor-parallel configurations, each GPU processes a specific segment of the output.
+ # The 'result' tensor represents the full output, which can vary in size based on
+ # the layer type (e.g., attention vs. feed-forward layers). We define the current
+ # segment using start_idx and end_idx. If the segment size doesn't match this GPU's
+ # slice of 'result', we create a zero tensor of the correct size for LoRA computation.
+ # This approach ensures accurate LoRA application across various layer sizes and
+ # configurations, adapting to different model architectures and parallelization strategies.
+ #
+ # Example scenarios where this is necessary:
+ # 1. The adapter's size doesn't evenly divide across GPUs.
+ # 2. We're processing the last segment which might be smaller.
+ # 3. Different projection layers (q, k, v) have different sizes.
+ if end_idx - start_idx != result.shape[1]:
+ proj = torch.zeros_like(result[:, start_idx:end_idx])
+ else:
+ proj = result
+
+ for r, rank_segments in data.rank_data.items():
+ lora_a_ptr = rank_segments.lora_a_ptr
+ lora_b_ptr = rank_segments.lora_b_ptr
+
+ if lora_a_ptr is None or lora_b_ptr is None:
+ raise ValueError("LoRA data is missing")
+
+ if data.use_sgmv:
+ # Use SGMV for prefill
+ v = lora_a_sgmv_cutlass(
+ input,
+ rank_segments.tmp_shrink,
+ lora_a_ptr,
+ rank_segments.segment_starts,
+ rank_segments.segment_ends,
+ self.layer_id,
+ r,
+ )
+
+ if self.process_group.size() > 1:
+ v = self.collect_lora_a(v)
+
+ lora_b_sgmv_cutlass(
+ proj,
+ v,
+ rank_segments.tmp_expand,
+ lora_b_ptr,
+ rank_segments.segment_starts,
+ rank_segments.segment_ends,
+ self.layer_id,
+ )
+ else:
+ # Use BGMV for decode
+ v = torch.zeros(
+ (input.size(0), r), dtype=input.dtype, device=input.device
+ )
+ # TODO: error with [-1, 0], but not [0, -1]
+ add_lora_a_bgmv(
+ v,
+ input,
+ lora_a_ptr,
+ rank_segments.indices,
+ self.layer_id,
+ )
+
+ if self.process_group.size() > 1:
+ v = self.collect_lora_a(v)
+
+ add_lora_b_bgmv(
+ proj,
+ v,
+ lora_b_ptr,
+ rank_segments.indices,
+ self.layer_id,
+ )
+
+ if end_idx - start_idx != result.shape[1]:
+ result[:, start_idx:end_idx] += proj
+ else:
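+            # Fallback path (no SGMV or non-vectorizable batch): apply each adapter's LoRA delta with dense matmuls, masked per request.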
+ for adapter_index in adapter_data.meta.adapter_set:
+ if data is not None and data.has_adapter(adapter_index):
+ adapter_mask = (
+ (adapter_data.meta.adapter_indices == adapter_index)
+ .to(input.dtype)
+ .view(-1, 1)
+ )
+ layer_result = self.forward_lora(
+ input, data, adapter_index, adapter_mask
+ )
+ result[:, start_idx:end_idx] += layer_result
+
+ return result
+
+ def forward_lora(
+ self,
+ input: torch.Tensor,
+ data: "BatchLoraWeights",
+ adapter_index: int,
+ adapter_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
+ lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
+
+ lora_a = orient_for_rank(lora_a, lora_b.size(0))
+
+ a_out = input @ lora_a
+ if self.process_group.size() > 1:
+ a_out = self.collect_lora_a(a_out)
+
+ result = (a_out @ lora_b) * adapter_mask
+ return result
+
+ def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
+ raise NotImplementedError("Implemented in subclasses")
+
+
+class TensorParallelMultiAdapterLinear(LoraLinear):
+ def __init__(
+ self,
+ base_layer: nn.Module,
+ layer_id: int,
+ layer_names: List[str],
+ sizes: List[int],
+ process_group: ProcessGroup,
+ ):
+ super().__init__(base_layer, layer_id, process_group)
+ self.layer_names = layer_names
+ self.sizes = sizes
+
+ @classmethod
+ def load(
+ cls,
+ base_layer: nn.Module,
+ layer_id: int,
+ layer_names: List[str],
+ sizes: List[int],
+ process_group: ProcessGroup,
+ ):
+ return TensorParallelMultiAdapterLinear(
+ base_layer, layer_id, layer_names, sizes, process_group
+ )
+
+ def forward(
+ self, input: torch.Tensor, adapter_data: "AdapterBatchData"
+ ) -> torch.Tensor:
+ result = self.base_layer(input)
+
+ # noop if no layer names are provided (e.g. for models without adapters)
+ if self.layer_names is None:
+ return result
+
+ # handle models like Bloom that have inputs of shape
+ # (batch_size, sequence_length, hidden_size)
+ # we need to reshape them to (batch_size * sequence_length, hidden_size)
+ # for the LoRA computation, then reshape back
+ prev_shape = result.shape
+ is_3d = len(input.shape) >= 3
+ if is_3d:
+ input = input.reshape(-1, input.shape[-1])
+ result = result.reshape(-1, result.shape[-1])
+
+ offset = 0
+ for i, layer_name in enumerate(self.layer_names):
+ start_idx = offset // self.process_group.size()
+ # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
+ # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
+ # ensures correct slicing of the result tensor, accommodating variations like grouped-query
+ # attention where k_proj and v_proj differ from q_proj. This allows precise application of
+ # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
+ # different projection sizes across layers and model architectures.
+ if self.sizes is not None:
+ offset += self.sizes[i]
+ end_idx = offset // self.process_group.size()
+ else:
+ end_idx = result.shape[1]
+
+ result = self.forward_layer_type(
+ result, input, adapter_data, layer_name, start_idx, end_idx
+ )
+
+ if is_3d:
+ result = result.reshape(prev_shape)
+
+ return result
+
+ def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
+ # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
+ # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
+ #
+ # TODO(travis): this is not very efficient as we do an all-gather for every adapter,
+ # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
+ # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
+ # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
+ gathered_tensors = [
+ torch.empty_like(a_out) for _ in range(self.process_group.size())
+ ]
+ torch.distributed.all_gather(gathered_tensors, a_out)
+ return torch.cat(gathered_tensors, dim=1)
+
+
+class TensorParallelAdapterRowLinear(LoraLinear):
+ def __init__(self, base_layer, layer_id, layer_name, process_group):
+ super().__init__(base_layer, layer_id, process_group)
+ self.layer_name = layer_name
+
+ @classmethod
+ def load(cls, base_layer, layer_id, layer_name, process_group):
+ return cls(base_layer, layer_id, layer_name, process_group)
+
+ def forward(
+ self, input: torch.Tensor, adapter_data: "AdapterBatchData"
+ ) -> torch.Tensor:
+ result = self.base_layer(input)
+
+ if self.layer_name is None:
+ return result
+
+ # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
+ stride = result.shape[-1] // self.process_group.size()
+ start_idx = self.process_group.rank() * stride
+ end_idx = (self.process_group.rank() + 1) * stride
+
+ self.forward_layer_type(
+ result, input, adapter_data, self.layer_name, start_idx, end_idx
+ )
+
+ return result
+
+ def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
+ # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
+ # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
+ #
+ # TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
+ # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
+ # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
+ # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
+ torch.distributed.all_reduce(a_out, group=self.process_group)
+ return a_out
diff --git a/backends/gaudi/server/text_generation_server/layers/medusa.py b/backends/gaudi/server/text_generation_server/layers/medusa.py
new file mode 100644
index 00000000000..139c4dc2500
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/medusa.py
@@ -0,0 +1,191 @@
+import torch
+from torch import nn
+from typing import Tuple, Optional
+from text_generation_server.utils.speculate import get_speculate
+from text_generation_server.layers.linear import FastLinear
+from text_generation_server.layers.tensor_parallel import (
+ TensorParallelHead,
+ TensorParallelColumnLinear,
+)
+
+
+class ResBlock(torch.nn.Module):
+ def __init__(self, config, prefix, weights):
+ super().__init__()
+ self.linear = FastLinear.load(
+ config, prefix=f"{prefix}.linear", weights=weights, bias=True
+ )
+ self.act = torch.nn.SiLU()
+
+ def forward(self, x):
+ return x + self.act(self.linear(x))
+
+
+class MedusaModel(torch.nn.Module):
+ def __init__(self, config, medusa_config, weights):
+ super().__init__()
+ self.heads = torch.nn.ModuleList(
+ [
+ MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
+ for i in range(get_speculate())
+ ]
+ )
+
+ def forward(self, x):
+ if not self.heads:
+ return None
+ speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
+ return speculative_logits
+
+
+class MedusaHead(torch.nn.Module):
+ def __init__(self, config, medusa_config, prefix, weights):
+ super().__init__()
+ self.blocks = torch.nn.ModuleList(
+ [
+ ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
+ for i in range(medusa_config["medusa_num_layers"])
+ ]
+ )
+ n = len(self.blocks)
+ self.out = FastLinear.load(
+ config, prefix=f"{prefix}.{n}", weights=weights, bias=False
+ )
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+ x = self.out(x)
+ return x
+
+
+class MedusaHeadV1(nn.Module):
+ def __init__(self, lm_head, medusa):
+ super().__init__()
+ self.lm_head = lm_head
+ self.medusa = medusa
+
+ @staticmethod
+ def load(config, prefix: str, weights):
+ from pathlib import Path
+ from safetensors import safe_open
+ import json
+
+ speculator = config.speculator
+
+ path = speculator["path"]
+ medusa_config = str(Path(path) / "config.json")
+
+ for fname in speculator["model_paths"]:
+ filename = str(Path(path) / fname)
+
+ with open(medusa_config, "r") as f:
+ medusa_config = json.load(f)
+ routing = weights.routing
+ with safe_open(filename, framework="pytorch") as f:
+ for k in f.keys():
+ if k in routing and routing[k] != filename:
+ raise RuntimeError(
+ f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+ )
+ routing[k] = filename
+
+ medusa = MedusaModel(config, medusa_config, weights)
+ lm_head = TensorParallelHead.load(config, prefix, weights)
+ return MedusaHeadV1(lm_head, medusa)
+
+ def forward(
+ self, input: torch.Tensor
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ logits = self.lm_head(input)
+ # If we have too many tokens, we skip speculative logits
+ if input.shape[0] > 128:
+ return logits, None
+
+ speculative_logits = self.medusa(input)
+ return logits, speculative_logits
+
+
+class MedusaHeadV2(nn.Module):
+ def __init__(self, config, prefix, weights):
+ super().__init__()
+ from pathlib import Path
+ from safetensors import safe_open
+ import json
+
+ speculator_path = config.speculator["path"]
+
+ medusa_config = str(Path(speculator_path) / "config.json")
+ filename = str(Path(speculator_path) / "medusa_lm_head.safetensors")
+
+ with open(medusa_config, "r") as f:
+ medusa_config = json.load(f)
+ routing = weights.routing
+ with safe_open(filename, framework="pytorch") as f:
+ for k in f.keys():
+ if k in routing and routing[k] != filename:
+ raise RuntimeError(
+ f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+ )
+ routing[k] = filename
+
+ self.n_medusa_heads = get_speculate()
+
+ assert medusa_config["medusa_num_layers"] == 1
+ self.linear = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
+ dim=0,
+ weights=weights,
+ bias=True,
+ )
+ self.process_group = weights.process_group
+ self.world_size = self.process_group.size()
+ self.rank = self.process_group.rank()
+
+ self.act = torch.nn.SiLU()
+
+ self.lm_head = TensorParallelHead.load(config, prefix, weights)
+
+ def forward(self, x):
+ # If we have too many tokens, we skip speculative logits
+ if x.shape[0] > 128:
+ logits = self.lm_head(x)
+ return logits, None
+
+ size = x.shape[-1]
+ block_size = (size + self.world_size - 1) // self.world_size
+ start = self.rank * block_size
+ stop = (self.rank + 1) * block_size
+
+ x_block = x[:, start:stop]
+
+ # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
+ medusa_res = self.act(self.linear(x)).reshape(
+ *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
+ )
+
+ # Apply all residual medusa heads
+ output = x[:, start:stop].unsqueeze(-2) + medusa_res
+
+ # Gather medusa heads
+ world_output = [
+ torch.empty_like(output) for _ in range(self.process_group.size())
+ ]
+ torch.distributed.all_gather(world_output, output, group=self.process_group)
+ world_output = torch.cat(world_output, dim=-1)
+
+ # Stack x and medusa residual x
+ stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)
+
+ # Compute lm head on x + medusa residual x
+ logits = self.lm_head(stacked_x)
+
+ # Finally, split logits from speculative logits
+ logits, speculative_logits = torch.split(
+ logits, [1, self.n_medusa_heads], dim=-2
+ )
+ # Squeeze added dimension
+ logits = logits.squeeze(-2)
+
+ return logits, speculative_logits
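
A minimal single-device sketch of the Medusa head structure above, using plain `nn.Linear` in place of `FastLinear` and `TensorParallelHead`; sizes and the number of heads are illustrative. It shows the residual SiLU block and the `(tokens, n_heads, vocab)` shape of the stacked speculative logits.

```python
import torch
from torch import nn

hidden, vocab, n_heads, n_layers = 64, 100, 3, 1

class ResBlock(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.linear = nn.Linear(hidden, hidden)
        self.act = nn.SiLU()

    def forward(self, x):
        # residual connection around the SiLU-activated projection
        return x + self.act(self.linear(x))

class MedusaHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(ResBlock(hidden) for _ in range(n_layers))
        self.out = nn.Linear(hidden, vocab, bias=False)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.out(x)

heads = nn.ModuleList(MedusaHead() for _ in range(n_heads))
x = torch.randn(5, hidden)                        # 5 tokens of hidden states
speculative_logits = torch.stack([h(x) for h in heads], dim=1)
print(speculative_logits.shape)                   # torch.Size([5, 3, 100])
```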
diff --git a/backends/gaudi/server/text_generation_server/layers/mlp.py b/backends/gaudi/server/text_generation_server/layers/mlp.py
new file mode 100644
index 00000000000..d33b41f323b
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/mlp.py
@@ -0,0 +1,282 @@
+import torch
+import math
+from torch import nn
+from torch.nn import functional as F
+from typing import Optional, Tuple
+from text_generation_server.layers import TensorParallelEmbedding, FastLinear
+from text_generation_server.layers.tensor_parallel import TensorParallelHead
+from text_generation_server.utils.speculate import get_speculate
+
+
+class MLPSpeculatorLayerNorm(nn.Module):
+ """
+ An L2 normalization implementation
+ ...
+ Args
+ ----
+ normalized_shape : int
+ Dimensionality of input data (size of final tensor axis)
+ elementwise_scale_weight : torch.Tensor
+ Learned scaling term applied after normalization.
+ elementwise_shift_bias : torch.Tensor
+ Learned bias term applied after normalization.
+ eps : float
+ Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (e.g. fp16 requires eps >= 6e-8).
+ """
+
+ def __init__(
+ self,
+ prefix,
+ config,
+ weights,
+ eps=1e-06,
+ ):
+ super(MLPSpeculatorLayerNorm, self).__init__()
+ self.weight = weights.get_tensor(f"{prefix}.weight")
+ self.bias = weights.get_tensor(f"{prefix}.bias")
+ self.eps = eps
+
+ def forward(self, x):
+ xf = x
+ xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
+ x = xf.type_as(x)
+ x = self.weight * x
+ x = x + self.bias
+ return x
+
+
+INV_SQRT2 = 2**-0.5
+
+
+def simple_norm(x: torch.Tensor, eps=1e-06):
+ xf = x
+ xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
+ x = xf.type_as(x)
+ return x * INV_SQRT2
+
+
+class MLPSpeculatorModelTied(torch.nn.Module):
+ def __init__(self, config, prefix, weights):
+ super().__init__()
+ self.config = config
+ self.n_predict = get_speculate()
+ self.hidden_size = config.hidden_size
+
+ self.emb = TensorParallelEmbedding(f"{prefix}.emb.0", weights)
+ self.proj0 = FastLinear.load(
+ config,
+ prefix=f"{prefix}.proj.0",
+ weights=weights,
+ bias=False,
+ )
+ self.proj1 = FastLinear.load(
+ config,
+ prefix=f"{prefix}.proj.1",
+ weights=weights,
+ bias=False,
+ )
+ self.head = FastLinear.load(config, f"{prefix}.head.0", weights, bias=False)
+ self.ln = MLPSpeculatorLayerNorm(
+ prefix=f"{prefix}.ln.0",
+ config=config,
+ weights=weights,
+ )
+
+ # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
+ self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
+ self.activation = nn.GELU()
+ self.vsize = config.vocab_size
+ self.inner_dim = config.speculator_config["inner_dim"]
+ self.top_k_tokens_per_head = [1] * self.n_predict
+ self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
+ self.inner_dim / 2
+ )
+ self.emb.weight *= self.emb_weight
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ input_ids: torch.Tensor,
+ ):
+ top_k_tokens_per_head = self.top_k_tokens_per_head
+
+ # k indicates # of candidates
+ # h indicates # of generated tokens
+ state = hidden_states
+ b = state.size(0)
+ ind = input_ids.unsqueeze(0)
+ all_probs = torch.empty(
+ b, self.n_predict, self.vsize, device=state.device
+ ) # b k h v
+ assert (
+ len(top_k_tokens_per_head) == self.n_predict
+ ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
+ for i in range(self.n_predict):
+ # Project and predict
+ z = self.emb(ind)
+ # z = z.mul(self.emb_weight) # b k d
+ if i == 0:
+ state = self.proj0(state) * self.state_weight + z
+ else:
+ state = self.proj1(state) * self.state_weight + z
+ state = self.activation(self.ln(state)) # b k d
+ probs = F.log_softmax(self.head(state), dim=-1) # b k v
+ _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
+
+ # Update candidate set with new predictions
+
+ # Update distribution set with new logits
+ all_probs[:, i] = probs.exp()
+
+ # Update state, log_probs and ind for new predictions
+ state = state.unsqueeze(2).expand(
+ -1, -1, top_k_tokens_per_head[i], -1
+ ) # b k k' d
+ state = state.reshape(-1, b, state.size(3)) # b kk' d
+ ind = preds.view(-1, b) # b kk'
+
+ speculative_logits = all_probs
+ return speculative_logits
+
+
+class MLPSpeculatorModel(torch.nn.Module):
+ def __init__(self, config, prefix, weights):
+ super().__init__()
+ self.config = config
+ self.n_predict = get_speculate()
+ self.hidden_size = config.hidden_size
+
+ self.emb = nn.ModuleList(
+ [
+ TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
+ for i in range(self.n_predict)
+ ]
+ )
+ self.proj = [
+ FastLinear.load(
+ config,
+ prefix=f"{prefix}.proj.{i}",
+ weights=weights,
+ bias=False,
+ )
+ for i in range(self.n_predict)
+ ]
+ self.head = nn.ModuleList(
+ [
+ FastLinear.load(config, f"{prefix}.head.{i}", weights, bias=False)
+ for i in range(self.n_predict)
+ ]
+ )
+ self.ln = nn.ModuleList(
+ [
+ MLPSpeculatorLayerNorm(
+ prefix=f"{prefix}.ln.{i}",
+ config=config,
+ weights=weights,
+ )
+ for i in range(self.n_predict)
+ ]
+ )
+
+ # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
+ self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
+ self.activation = nn.GELU()
+ self.vsize = config.vocab_size
+ self.inner_dim = config.speculator_config["inner_dim"]
+ self.top_k_tokens_per_head = [1] * self.n_predict
+ self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
+ self.inner_dim / 2
+ )
+ self.emb.weight *= self.emb_weight
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ input_ids: torch.Tensor,
+ ):
+ top_k_tokens_per_head = self.top_k_tokens_per_head
+
+ # k indicates # of candidates
+ # h indicates # of generated tokens
+ state = hidden_states
+ b = state.size(0)
+ ind = input_ids.unsqueeze(0)
+ all_probs = torch.empty(
+ b, self.n_predict, self.vsize, device=state.device
+ ) # b k h v
+ assert (
+ len(top_k_tokens_per_head) == self.n_predict
+ ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
+ for i in range(self.n_predict):
+ # Project and predict
+ z = self.emb[i](ind)
+ # z = z.mul(self.emb_weight) # b k d
+ state = self.proj[i](state) * self.state_weight + z
+ state = self.activation(self.ln[i](state)) # b k d
+ probs = F.log_softmax(self.head[i](state), dim=-1) # b k v
+ _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
+
+ # Update candidate set with new predictions
+
+ # Update distribution set with new logits
+ all_probs[:, i] = probs.exp()
+
+ # Update state, log_probs and ind for new predictions
+ state = state.unsqueeze(2).expand(
+ -1, -1, top_k_tokens_per_head[i], -1
+ ) # b k k' d
+ state = state.reshape(-1, b, state.size(3)) # b kk' d
+ ind = preds.view(-1, b) # b kk'
+
+ speculative_logits = all_probs
+ return speculative_logits
+
+
+class MLPSpeculatorHead(nn.Module):
+ def __init__(self, lm_head, mlp_speculator, scale_input: bool):
+ super().__init__()
+ self.lm_head = lm_head
+ self.mlp_speculator = mlp_speculator
+ self.scale_input = scale_input
+
+ def forward(
+ self, input: torch.Tensor
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ logits = self.lm_head(input)
+ # If we have too many tokens, we skip speculative logits
+ if input.shape[0] > 128:
+ return logits, None
+
+ input_ids = logits.argmax(dim=-1)
+ if self.scale_input:
+ input = simple_norm(input)
+ speculative_logits = self.mlp_speculator(input, input_ids)
+ return logits, speculative_logits
+
+ @staticmethod
+ def load(config, prefix: str, weights):
+ from pathlib import Path
+ from safetensors import safe_open
+
+ speculator_path = config.speculator["path"]
+
+ for fname in config.speculator["model_paths"]:
+ filename = str(Path(speculator_path) / fname)
+ routing = weights.routing
+ with safe_open(filename, framework="pytorch") as f:
+ for k in f.keys():
+ if k in routing and routing[k] != filename:
+ raise RuntimeError(
+ f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+ )
+ routing[k] = filename
+
+ tie_weights = config.speculator_config.get("tie_weights", False)
+ if tie_weights:
+ mlp_speculator = MLPSpeculatorModelTied(config, "speculator", weights)
+ else:
+ mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
+ # This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator
+ scale_input = config.speculator_config.get("scale_input", False)
+ lm_head = TensorParallelHead.load(config, prefix, weights)
+ return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input)
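
The core of the MLP speculator is the per-step state update `state = proj_i(state) * state_weight + emb_i(prev_token) * emb_weight`, followed by the norm, GELU, and a log-softmax head. The sketch below runs one such step with stock modules as stand-ins: `nn.LayerNorm` replaces `MLPSpeculatorLayerNorm`, the embedding scaling is applied explicitly instead of pre-scaling `emb.weight`, and all sizes are illustrative.

```python
import math
import torch
from torch import nn
from torch.nn import functional as F

hidden, inner_dim, vocab, n_predict = 64, 32, 100, 3

# Same weighting scheme as above: state_0 contributes ~50% of the state
# magnitude by the final head, the token embedding supplies the rest.
state_weight = 0.5 ** (0.5 / n_predict)
emb_weight = math.sqrt(1 - state_weight**2) * math.sqrt(inner_dim / 2)

emb = nn.Embedding(vocab, inner_dim)
proj = nn.Linear(hidden, inner_dim, bias=False)
head = nn.Linear(inner_dim, vocab, bias=False)
ln = nn.LayerNorm(inner_dim)                 # stand-in for the custom norm

state = torch.randn(1, 4, hidden)            # last hidden states, 4 tokens
prev_tokens = torch.randint(0, vocab, (1, 4))

z = emb(prev_tokens) * emb_weight
state = proj(state) * state_weight + z
state = F.gelu(ln(state))
probs = F.log_softmax(head(state), dim=-1)
print(probs.shape)                           # torch.Size([1, 4, 100])
```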
diff --git a/backends/gaudi/server/text_generation_server/layers/moe/__init__.py b/backends/gaudi/server/text_generation_server/layers/moe/__init__.py
new file mode 100644
index 00000000000..8b9d6fcb0a1
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/moe/__init__.py
@@ -0,0 +1,250 @@
+from typing import Optional, Protocol, runtime_checkable
+
+import torch
+import torch.nn as nn
+from loguru import logger
+from transformers.activations import ACT2FN
+
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+)
+from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
+from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
+from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
+from text_generation_server.utils.log import log_once
+from text_generation_server.utils.weights import (
+ DefaultWeightsLoader,
+ Weights,
+ UnquantizedWeight,
+)
+
+from .fused_moe import fused_topk, grouped_topk
+
+# NOTE: we are using a protocol here, because multiple inheritance is not nice.
+# We need `Module`, and `Module` -> some abstract class -> some concrete
+# class inheritance is whacky.
+
+
+@runtime_checkable
+class MoELayer(Protocol):
+ def __init__(
+ self,
+ *,
+ n_expert_group: Optional[int],
+ n_experts: int,
+ prefix: str,
+ renormalize: bool,
+ topk: int,
+ topk_group: Optional[int],
+ weights: Weights,
+ gate_proj_name: str = "gate_proj",
+ up_proj_name: str = "up_proj",
+ down_proj_name: str = "down_proj",
+ hidden_act: str = "silu",
+ scoring_func: Optional[str] = None,
+ e_score_correction_bias: Optional[float] = None,
+ ): ...
+
+ def forward(
+ self, x: torch.Tensor, *, gating_output: torch.Tensor
+ ) -> torch.Tensor: ...
+
+
+class DenseMoELayer(nn.Module):
+ """
+ Layer for MoE that applies *all* experts to each token and then weights
+ their outputs based on the calculated routing. This layer is much slower
+ than `SparseMoELayer` and should only be used when no fused kernels are
+ available (e.g. for unsupported quantizers).
+ """
+
+ def __init__(
+ self,
+ *,
+ n_expert_group: Optional[int],
+ n_experts: int,
+ prefix: str,
+ renormalize: bool,
+ topk: int,
+ topk_group: Optional[int],
+ weights: Weights,
+ gate_proj_name: str = "gate_proj",
+ up_proj_name: str = "up_proj",
+ down_proj_name: str = "down_proj",
+ hidden_act: str = "silu",
+ scoring_func: Optional[str] = None,
+ e_score_correction_bias: Optional[float] = None,
+ ):
+ super().__init__()
+
+ assert scoring_func is None, "scoring func is not handled"
+ assert e_score_correction_bias is None, "scoring correction bias is not handled"
+
+ log_once(
+ logger.info,
+ "No fused layers are available for this model type, using (slower) dense MoE layer",
+ )
+
+ assert (n_expert_group is None) == (
+ topk_group is None
+ ), "n_expert_group and topk_group must both be None or have some value"
+
+ self.n_expert_group = n_expert_group
+ self.n_experts = n_experts
+ self.renormalize = renormalize
+ self.topk = topk
+ self.topk_group = topk_group
+
+ if "gelu" in hidden_act:
+ self.act = lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh"
+ if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
+ else "none"
+ ),
+ )
+ elif "silu" in hidden_act:
+ self.act = torch.nn.functional.silu
+ else:
+ self.act = ACT2FN[hidden_act]
+
+ self.gate_proj = [
+ TensorParallelColumnLinear.load(
+ None,
+ prefix=f"{prefix}.{i}.{gate_proj_name}",
+ weights=weights,
+ bias=False,
+ )
+ for i in range(self.n_experts)
+ ]
+ self.up_proj = [
+ TensorParallelColumnLinear.load(
+ None,
+ prefix=f"{prefix}.{i}.{up_proj_name}",
+ weights=weights,
+ bias=False,
+ )
+ for i in range(self.n_experts)
+ ]
+ self.down_proj = [
+ TensorParallelRowLinear.load(
+ None,
+ prefix=f"{prefix}.{i}.{down_proj_name}",
+ weights=weights,
+ bias=False,
+ )
+ for i in range(self.n_experts)
+ ]
+
+ self.process_group = weights.process_group
+
+ def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+ """
+ x: (sequence_length, model_dim)
+ gating_output: (sequence_length, n_experts)
+ """
+ # optional reshape
+ input_shape = x.shape
+ x = x.view(-1, input_shape[-1])
+
+ if self.n_expert_group is not None and self.topk_group is not None:
+ topk_weights, topk_ids = grouped_topk(
+ x,
+ gating_output,
+ self.topk,
+ renormalize=self.renormalize,
+ num_expert_group=self.n_expert_group,
+ topk_group=self.topk_group,
+ )
+ else:
+ topk_weights, topk_ids = fused_topk(
+ x, gating_output, self.topk, self.renormalize
+ )
+ topk_weights = topk_weights.to(x.dtype)
+
+ weights = torch.zeros(
+ topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device
+ )
+
+ weights.scatter_(1, topk_ids.long(), topk_weights.to(weights.dtype))
+
+ out = torch.zeros_like(x)
+ for i in range(self.n_experts):
+ h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)
+ h = self.down_proj[i](h, reduce=False)
+ out += h * weights[:, i].view(-1, 1)
+
+ return out
+
+
+class SparseMoELayer(nn.Module):
+ """
+ Layer for MoE that uses fused kernels to only apply the active experts
+ for each token (rather than applying all experts and selecting the
+ outputs of active experts).
+ """
+
+ def __init__(
+ self,
+ *,
+ n_expert_group: Optional[int],
+ n_experts: int,
+ prefix: str,
+ renormalize: bool,
+ topk: int,
+ topk_group: Optional[int],
+ weights: Weights,
+ scoring_func: Optional[str] = "softmax",
+ e_score_correction_bias: Optional[float] = None,
+ gate_proj_name: str = "gate_proj",
+ up_proj_name: str = "up_proj",
+ down_proj_name: str = "down_proj",
+ ):
+ super().__init__()
+ if (
+ isinstance(weights.loader, DefaultWeightsLoader)
+ and isinstance(weights.loader.weight_class, UnquantizedWeight)
+ ) or isinstance(weights.loader, HybridFP8UnquantLoader):
+ if (
+ isinstance(weights.loader, HybridFP8UnquantLoader)
+ and weights.loader.to_fp8
+ ):
+ cls = FP8SparseMoELayer
+ else:
+ cls = UnquantizedSparseMoELayer
+ else:
+ raise ValueError(
+ f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights"
+ )
+
+ log_once(
+ logger.info,
+ "Using MoE layer wih fused gemm",
+ )
+
+ self.moe = cls(
+ n_expert_group=n_expert_group,
+ n_experts=n_experts,
+ prefix=prefix,
+ renormalize=renormalize,
+ topk=topk,
+ topk_group=topk_group,
+ weights=weights,
+ scoring_func=scoring_func,
+ e_score_correction_bias=e_score_correction_bias,
+ gate_proj_name=gate_proj_name,
+ up_proj_name=up_proj_name,
+ down_proj_name=down_proj_name,
+ )
+
+ def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+ return self.moe(x, gating_output=gating_output)
+
+ @staticmethod
+ def is_supported(weights: Weights) -> bool:
+ return (
+ isinstance(weights.loader, DefaultWeightsLoader)
+ and isinstance(weights.loader.weight_class, UnquantizedWeight)
+ ) or isinstance(weights.loader, HybridFP8UnquantLoader)
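
The dense path above can be summarized in a few lines of plain PyTorch: compute softmax routing scores, keep the top-k per token, scatter the renormalized weights into a dense `(tokens, n_experts)` matrix, then run every expert on every token and take the weighted sum. The sketch below mirrors that flow with `nn.Linear` experts and illustrative sizes; no fused kernels are involved.

```python
import torch
from torch import nn

tokens, hidden, n_experts, topk = 6, 16, 4, 2

x = torch.randn(tokens, hidden)
gating_output = torch.randn(tokens, n_experts)
experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(n_experts))

scores = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize

# Dense (tokens, n_experts) routing matrix: non-selected experts get weight 0.
weights = torch.zeros(tokens, n_experts)
weights.scatter_(1, topk_ids, topk_weights)

out = torch.zeros_like(x)
for i, expert in enumerate(experts):
    out += expert(x) * weights[:, i].view(-1, 1)
```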
diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fp8.py b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py
new file mode 100644
index 00000000000..d235180e416
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py
@@ -0,0 +1,270 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import os
+
+from text_generation_server.utils.weights import Weights
+from text_generation_server.layers.fp8 import (
+ Fp8Weight,
+ fp8_quantize,
+ quant_dtype,
+ normalize_e4m3fn_to_native_float8,
+ dynamic_quant,
+ dequant_block_fp8_weight_naive,
+)
+from text_generation_server.layers.moe.fused_moe import select_experts
+import habana_frameworks.torch as htorch
+
+
+class FP8SparseMoELayer(nn.Module):
+ def __init__(
+ self,
+ *,
+ n_expert_group: Optional[int],
+ n_experts: int,
+ prefix: str,
+ renormalize: bool,
+ topk: int,
+ topk_group: Optional[int],
+ weights: Weights,
+ scoring_func: Optional[str] = "softmax",
+ e_score_correction_bias: Optional[float] = None,
+ gate_proj_name: str = "gate_proj",
+ up_proj_name: str = "up_proj",
+ down_proj_name: str = "down_proj",
+ ):
+ super().__init__()
+
+ assert (n_expert_group is None) == (
+ topk_group is None
+ ), "n_expert_group and topk_group must both be None or have some value"
+
+ self.n_expert_group = n_expert_group
+ self.topk = topk
+ self.topk_group = topk_group
+ self.renormalize = renormalize
+ self.weight_block_size = weights.weights_loader.weight_block_size
+ self.scoring_func = scoring_func
+ self.e_score_correction_bias = e_score_correction_bias
+ self.world_size = weights.process_group.size()
+ self.rank = weights.process_group.rank()
+ self.ep_rank = self.rank
+ self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
+ if (n_experts + self.world_size - 1) // self.world_size < 4:
+ self.use_ep = False
+ if self.use_ep:
+ n_experts_per_rank = (n_experts + self.world_size - 1) // self.world_size
+ self.ep_offset = self.ep_rank * n_experts_per_rank
+ n_experts = min(n_experts_per_rank, n_experts - self.ep_offset)
+ else:
+ self.ep_offset = 0
+
+ (
+ self.gate_up_proj,
+ self.gate_up_proj_weight_scale,
+ self.gate_up_proj_input_scale,
+ ) = _load_expert_multi_weights_col(
+ prefix=prefix,
+ n_experts=n_experts,
+ gate_proj_name=gate_proj_name,
+ up_proj_name=up_proj_name,
+ weights=weights,
+ use_ep=self.use_ep,
+ ep_offset=self.ep_offset,
+ )
+
+ self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
+ _load_expert_weights_row(
+ prefix=prefix,
+ n_experts=n_experts,
+ name=down_proj_name,
+ weights=weights,
+ use_ep=self.use_ep,
+ ep_offset=self.ep_offset,
+ )
+ )
+ if self.weight_block_size is not None:
+ self.gate_up_proj, self.gate_up_proj_weight_scale = dynamic_quant(
+ dequant_block_fp8_weight_naive(
+ self.gate_up_proj,
+ self.gate_up_proj_weight_scale,
+ self.weight_block_size,
+ )
+ )
+ self.down_proj, self.down_proj_weight_scale = dynamic_quant(
+ dequant_block_fp8_weight_naive(
+ self.down_proj, self.down_proj_weight_scale, self.weight_block_size
+ )
+ )
+ self.gate_up_proj_weight_scale, self.down_proj_weight_scale = (
+ self.gate_up_proj_weight_scale.squeeze(-1),
+ self.down_proj_weight_scale.squeeze(-1),
+ )
+
+ def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+ topk_weights, topk_ids = select_experts(
+ hidden_states=x,
+ router_logits=gating_output,
+ use_grouped_topk=self.n_expert_group is not None,
+ top_k=self.topk,
+ renormalize=self.renormalize,
+ topk_group=self.topk_group,
+ num_expert_group=self.n_expert_group,
+ scoring_func=self.scoring_func,
+ e_score_correction_bias=self.e_score_correction_bias,
+ )
+ total_num_experts = gating_output.size(-1)
+ x_fp8, x_scale = dynamic_quant(x, single_scale=True)
+
+ if self.use_ep:
+ moe_n_slice = 1
+ n_expert_slice = (
+ total_num_experts + self.world_size - 1
+ ) // self.world_size
+ else:
+ moe_n_slice = 1
+ n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice
+ for i in range(moe_n_slice):
+ min_expert = i * n_expert_slice
+ max_expert = min((i + 1) * n_expert_slice, total_num_experts)
+ w13_list_slice = [
+ self.gate_up_proj[j, ...] for j in range(min_expert, max_expert)
+ ]
+ w2_list_slice = [
+ self.down_proj[j, ...] for j in range(min_expert, max_expert)
+ ]
+ w13_weight_scale = [
+ self.gate_up_proj_weight_scale[j, ...]
+ for j in range(min_expert, max_expert)
+ ]
+ w2_weight_scale = [
+ self.down_proj_weight_scale[j, ...]
+ for j in range(min_expert, max_expert)
+ ]
+
+ current_hidden_states = torch.ops.hpu.mixture_of_experts(
+ hidden_states=x_fp8,
+ expert_routing_table=topk_ids.to(torch.int64),
+ router_weights=topk_weights.to(x.dtype),
+ w12=w13_list_slice,
+ w3=w2_list_slice,
+ d_scale_hidden_states=x_scale,
+ d_scale_w12=w13_weight_scale,
+ d_scale_w3=w2_weight_scale,
+ permuted_weights=True,
+ activation="silu",
+ experts_min=min_expert + self.ep_offset,
+ experts_max=max_expert + self.ep_offset - 1,
+ )
+ htorch.core.mark_step()
+ if i == 0:
+ final_hidden_states = current_hidden_states
+ else:
+ final_hidden_states.add_(current_hidden_states)
+ return final_hidden_states
+
+
+def _load_expert_weights(
+ get_weight_fn,
+ *,
+ prefix: str,
+ n_experts: int,
+ name: str,
+ weights: Weights,
+ ep_offset: int = 0,
+) -> torch.Tensor:
+ all_weight = None
+ all_weight_scales = None
+ max_input_scale = None
+
+ for i in range(n_experts):
+ weight = get_weight_fn(prefix, i + ep_offset, name, weights)
+
+ assert isinstance(weight, Fp8Weight)
+
+ if all_weight is None:
+ all_weight = torch.empty(
+ (n_experts,) + weight.weight.shape,
+ dtype=quant_dtype,
+ device=weight.weight.device,
+ )
+ if all_weight_scales is None:
+ all_weight_scales = torch.empty(
+ (n_experts,) + weight.weight_scale.shape,
+ dtype=torch.float32,
+ device=weight.weight.device,
+ )
+
+ if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
+ all_weight[i], all_weight_scales[i], current_input_scale = (
+ normalize_e4m3fn_to_native_float8(
+ weight.weight, weight.weight_scale, weight.input_scale
+ )
+ )
+ if current_input_scale is not None:
+ if max_input_scale is None or current_input_scale > max_input_scale:
+ max_input_scale = current_input_scale
+ else:
+ all_weight[i], all_weight_scales[i] = fp8_quantize(
+ weight.weight, scalar=True
+ )
+
+ assert all_weight is not None
+
+ return all_weight, all_weight_scales, max_input_scale
+
+
+def _load_expert_multi_weights_col(
+ *,
+ prefix: str,
+ n_experts: int,
+ gate_proj_name: str,
+ up_proj_name: str,
+ weights: Weights,
+ use_ep: bool = False,
+ ep_offset: int = 0,
+) -> torch.Tensor:
+ def get_weight_fn_sharded(prefix, i, name, weights):
+ return weights.get_multi_weights_col(
+ [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+ )
+
+ def get_weight_fn(prefix, i, name, weights):
+ return weights.get_multi_weights(
+ [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+ )
+
+ return _load_expert_weights(
+ get_weight_fn if use_ep else get_weight_fn_sharded,
+ prefix=prefix,
+ n_experts=n_experts,
+ name=None,
+ weights=weights,
+ ep_offset=ep_offset if use_ep else 0,
+ )
+
+
+def _load_expert_weights_row(
+ *,
+ prefix: str,
+ n_experts: int,
+ name: str,
+ weights: Weights,
+ use_ep: bool = False,
+ ep_offset: int = 0,
+) -> torch.Tensor:
+ def get_weight_fn_sharded(prefix, i, name, weights):
+ return weights.get_weights_row(f"{prefix}.{i}.{name}")
+
+ def get_weight_fn(prefix, i, name, weights):
+ return weights.get_weights(f"{prefix}.{i}.{name}")
+
+ return _load_expert_weights(
+ get_weight_fn if use_ep else get_weight_fn_sharded,
+ prefix=prefix,
+ n_experts=n_experts,
+ name=name,
+ weights=weights,
+ ep_offset=ep_offset if use_ep else 0,
+ )
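
The expert-parallel bookkeeping above (ceil-division of experts across ranks, an `ep_offset` per rank, and a smaller share on the last rank) also appears in the unquantized layer below. A pure-Python sketch of that arithmetic, with illustrative values; note the real layer additionally falls back to tensor parallelism when a rank would receive fewer than 4 experts.

```python
def partition_experts(n_experts: int, world_size: int):
    per_rank = (n_experts + world_size - 1) // world_size   # ceil division
    shards = []
    for rank in range(world_size):
        ep_offset = rank * per_rank
        local = min(per_rank, n_experts - ep_offset)         # last rank may get fewer
        shards.append((ep_offset, local))
    return shards

# 10 experts over 4 ranks -> [(0, 3), (3, 3), (6, 3), (9, 1)]
print(partition_experts(10, 4))
```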
diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py
new file mode 100644
index 00000000000..1987f0edb77
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py
@@ -0,0 +1,131 @@
+# coding=utf-8
+# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple, Optional
+
+import torch
+
+
+def grouped_topk(
+ hidden_states: torch.Tensor,
+ gating_output: torch.Tensor,
+ topk: int,
+ renormalize: bool,
+ num_expert_group: int = 0,
+ topk_group: int = 0,
+ scoring_func: str = "softmax",
+ e_score_correction_bias: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+ gating_output = gating_output.float()
+ if e_score_correction_bias is not None:
+ e_score_correction_bias = e_score_correction_bias.float()
+
+ if scoring_func == "softmax":
+ scores = torch.softmax(gating_output, dim=-1)
+ elif scoring_func == "sigmoid":
+ scores = gating_output.sigmoid()
+ else:
+ raise ValueError(f"Unsupported scoring function: {scoring_func}")
+
+ num_token = scores.shape[0]
+ if e_score_correction_bias is not None:
+ # Store original scores before applying correction bias. We use biased
+ # scores for expert selection but original scores for routing weights
+ original_scores = scores
+ scores = scores + e_score_correction_bias.unsqueeze(0)
+ group_scores = (
+ scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
+ )
+ else:
+ group_scores = (
+ scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+ ) # [n, n_group]
+
+ group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+ 1
+ ] # [n, top_k_group]
+ group_mask = torch.zeros_like(group_scores) # [n, n_group]
+ group_mask.scatter_(1, group_idx, 1) # [n, n_group]
+ score_mask = (
+ group_mask.unsqueeze(-1)
+ .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+ .reshape(num_token, -1)
+ ) # [n, e]
+ tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
+
+ if e_score_correction_bias is not None:
+ topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
+ # Use original unbiased scores for the routing weights
+ topk_weights = original_scores.gather(1, topk_ids)
+ else:
+ topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+
+ if renormalize:
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+ return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+
+def fused_topk(
+ hidden_states: torch.Tensor,
+ gating_output: torch.Tensor,
+ topk: int,
+ renormalize: bool,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ topk_weights = torch.nn.functional.softmax(
+ gating_output, dim=1, dtype=torch.float32
+ )
+ topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+ if renormalize:
+ topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
+ return topk_weights, topk_ids
+
+
+def select_experts(
+ hidden_states: torch.Tensor,
+ router_logits: torch.Tensor,
+ top_k: int,
+ use_grouped_topk: bool,
+ renormalize: bool,
+ topk_group: Optional[int] = None,
+ num_expert_group: Optional[int] = None,
+ scoring_func: str = "softmax",
+ e_score_correction_bias: Optional[torch.Tensor] = None,
+):
+
+ # DeepSeek-V2 uses grouped_topk
+ if use_grouped_topk:
+ assert topk_group is not None
+ assert num_expert_group is not None
+ topk_weights, topk_ids = grouped_topk(
+ hidden_states=hidden_states,
+ gating_output=router_logits,
+ topk=top_k,
+ renormalize=renormalize,
+ num_expert_group=num_expert_group,
+ topk_group=topk_group,
+ scoring_func=scoring_func,
+ e_score_correction_bias=e_score_correction_bias,
+ )
+ else:
+ topk_weights, topk_ids = fused_topk(
+ hidden_states=hidden_states,
+ gating_output=router_logits,
+ topk=top_k,
+ renormalize=renormalize,
+ )
+ return topk_weights, topk_ids
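
A tiny worked example of the grouped routing in `grouped_topk` (softmax scoring, no correction bias): experts are split into groups, the best `topk_group` groups are kept, every other expert is masked to `-inf`, and the final top-k is taken over the surviving experts only. Values are made up for illustration.

```python
import torch

n_tokens, n_experts, num_expert_group, topk_group, topk = 1, 8, 4, 2, 2
gating_output = torch.tensor([[0.1, 2.0, 0.3, 0.2, 1.5, 0.0, 0.1, 0.4]])

scores = torch.softmax(gating_output, dim=-1)                        # [1, 8]
group_scores = scores.view(n_tokens, num_expert_group, -1).max(dim=-1).values
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = group_mask.unsqueeze(-1).expand(
    n_tokens, num_expert_group, n_experts // num_expert_group
).reshape(n_tokens, -1)
masked = scores.masked_fill(~score_mask.bool(), float("-inf"))

topk_weights, topk_ids = torch.topk(masked, k=topk, dim=-1, sorted=False)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
print(topk_ids)   # experts 1 and 4: the strongest expert in each kept group
```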
diff --git a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py
new file mode 100644
index 00000000000..617ff6ceadb
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py
@@ -0,0 +1,177 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from text_generation_server.utils.weights import UnquantizedWeight, Weights
+from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp
+import habana_frameworks.torch as htorch
+import torch.nn.functional as F
+import os
+
+
+class UnquantizedSparseMoELayer(nn.Module):
+ def __init__(
+ self,
+ *,
+ n_expert_group: Optional[int],
+ n_experts: int,
+ prefix: str,
+ renormalize: bool,
+ topk: int,
+ topk_group: Optional[int],
+ weights: Weights,
+ scoring_func: Optional[str] = "softmax",
+ e_score_correction_bias: Optional[float] = None,
+ gate_proj_name: str = "gate_proj",
+ up_proj_name: str = "up_proj",
+ down_proj_name: str = "down_proj",
+ ):
+ super().__init__()
+
+ assert (n_expert_group is None) == (
+ topk_group is None
+ ), "n_expert_group and topk_group must both be None or have some value"
+
+ self.n_expert_group = n_expert_group
+ self.topk = topk
+ self.topk_group = topk_group
+ self.renormalize = renormalize
+ self.weight_block_size = weights.weights_loader.weight_block_size
+ self.scoring_func = scoring_func
+ self.e_score_correction_bias = e_score_correction_bias
+ self.rank = weights.process_group.rank()
+ self.world_size = weights.process_group.size()
+ self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
+ if (n_experts + self.world_size - 1) // self.world_size < 4:
+ self.use_ep = False
+ if self.use_ep:
+ n_experts_per_rank = (n_experts + self.world_size - 1) // self.world_size
+ self.ep_offset = self.rank * n_experts_per_rank
+ n_experts = min(n_experts_per_rank, n_experts - self.ep_offset)
+ experts_min = self.ep_offset
+ experts_max = self.ep_offset + n_experts - 1
+ else:
+ self.ep_offset = 0
+ experts_min = 0
+ experts_max = n_experts - 1
+
+ self.gate_up_proj = _load_expert_multi_weights_col(
+ prefix=prefix,
+ n_experts=n_experts,
+ gate_proj_name=gate_proj_name,
+ up_proj_name=up_proj_name,
+ weights=weights,
+ use_ep=self.use_ep,
+ ep_offset=self.ep_offset,
+ )
+
+ self.down_proj = _load_expert_weights_row(
+ prefix=prefix,
+ n_experts=n_experts,
+ name=down_proj_name,
+ weights=weights,
+ use_ep=self.use_ep,
+ ep_offset=self.ep_offset,
+ )
+
+ self.MoeOp = VllmMixtureOfExpertsOp(n_experts, experts_min, experts_max)
+ for i in range(n_experts):
+ self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
+ self.MoeOp.w2_list[i].set_weight(self.down_proj[i])
+
+ def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+ htorch.core.mark_step()
+ routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32)
+ routing_weights, selected_experts = torch.topk(
+ routing_weights, self.topk, dim=-1
+ )
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ routing_weights = routing_weights.to(x.dtype)
+
+ final_hidden_states = self.MoeOp(
+ hidden_states=x,
+ expert_routing_table=selected_experts,
+ router_weights=routing_weights,
+ permuted_weights=True,
+ activation="silu",
+ )
+
+ return final_hidden_states.view(-1, x.shape[1])
+
+
+def _load_expert_multi_weights_col(
+ *,
+ prefix: str,
+ n_experts: int,
+ gate_proj_name: str,
+ up_proj_name: str,
+ weights: Weights,
+ use_ep: bool = False,
+ ep_offset: int = 0,
+) -> torch.Tensor:
+ all_weight = None
+ for i in range(n_experts):
+ if not use_ep:
+ weight = weights.get_multi_weights_col(
+ [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+ )
+ else:
+ weight = weights.get_multi_weights(
+ [
+ f"{prefix}.{i+ep_offset}.{gate_proj_name}",
+ f"{prefix}.{i+ep_offset}.{up_proj_name}",
+ ],
+ 0,
+ )
+
+ assert isinstance(weight, UnquantizedWeight)
+
+ if all_weight is None:
+ all_weight = torch.empty(
+ (n_experts,) + weight.weight.shape,
+ dtype=weight.weight.dtype,
+ device=weight.weight.device,
+ )
+
+ all_weight[i] = weight.weight
+
+ assert all_weight is not None
+
+ return all_weight
+
+
+def _load_expert_weights_row(
+ *,
+ prefix: str,
+ n_experts: int,
+ name: str,
+ weights: Weights,
+ use_ep: bool = False,
+ ep_offset: int = 0,
+) -> torch.Tensor:
+ all_weight = None
+ for i in range(n_experts):
+ if not use_ep:
+ weight = weights.get_weights_row(
+ f"{prefix}.{i}.{name}",
+ )
+ else:
+ weight = weights.get_weights(
+ f"{prefix}.{i+ep_offset}.{name}",
+ )
+
+ assert isinstance(weight, UnquantizedWeight)
+
+ if all_weight is None:
+ all_weight = torch.empty(
+ (n_experts,) + weight.weight.shape,
+ dtype=weight.weight.dtype,
+ device=weight.weight.device,
+ )
+
+ all_weight[i] = weight.weight
+
+ assert all_weight is not None
+
+ return all_weight
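
Both loader helpers above follow the same stacking pattern: per-expert 2-D weights are copied into one pre-allocated `(n_experts, out_features, in_features)` tensor so the MoE op can index experts along dim 0. A small sketch of that pattern with random tensors and illustrative shapes:

```python
import torch

n_experts, out_features, in_features = 4, 8, 16
expert_weights = [torch.randn(out_features, in_features) for _ in range(n_experts)]

all_weight = None
for i, w in enumerate(expert_weights):
    if all_weight is None:
        # allocate once, using the first expert's shape, dtype, and device
        all_weight = torch.empty((n_experts,) + w.shape, dtype=w.dtype, device=w.device)
    all_weight[i] = w

assert all_weight.shape == (n_experts, out_features, in_features)
```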
diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py
new file mode 100644
index 00000000000..7e740e5f69d
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/rotary.py
@@ -0,0 +1,604 @@
+import os
+import math
+import torch
+from torch import nn
+from habana_frameworks.torch.hpex.kernels import (
+ RotaryPosEmbeddingMode,
+ apply_rotary_pos_emb,
+)
+
+
+def _create_inv_freq(dim, base, device):
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+ )
+ return inv_freq
+
+
+def _get_rope_config(config):
+ if os.getenv("ROPE_SCALING", None) is not None:
+ rope_scaling = {
+ "type": os.environ["ROPE_SCALING"],
+ "factor": float(os.environ["ROPE_FACTOR"]),
+ }
+ return rope_scaling
+ return getattr(config, "rope_scaling", None)
+
+
+class PositionRotaryEmbedding(nn.Module):
+ def __init__(self, inv_freq, scaling_factor, max_position_embeddings):
+ super().__init__()
+ self.inv_freq = inv_freq
+ self._seq_len_cached = 0
+ self._cos_cached = None
+ self._sin_cached = None
+ self._cos_k_cached = None
+ self._sin_k_cached = None
+ self.scaling_factor = scaling_factor
+ self.dynamic_args = None
+ self._update_cos_sin_cache(
+ torch.float32, inv_freq.device, max_position_embeddings
+ )
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ ):
+ num_tokens = query.shape[0]
+ head_size = query.shape[-1]
+ # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
+ # to query hidden dimension, so the original tensors need to be
+ # expanded
+ # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
+ # and expansion of cos/sin tensors via concatenation
+ rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
+ cos = torch.cat((cos, cos), dim=-1)
+ sin = torch.cat((sin, sin), dim=-1)
+ rotary_dim = cos.shape[-1]
+ query_shape = query.shape
+ query = query.view(num_tokens, -1, head_size)
+ query_rot = query[..., :rotary_dim]
+ query_pass = query[..., rotary_dim:]
+ query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+ query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
+
+ key_shape = key.shape
+ key = key.view(num_tokens, -1, head_size)
+ key_rot = key[..., :rotary_dim]
+ key_pass = key[..., rotary_dim:]
+ key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+ key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
+
+ @classmethod
+ def static(cls, config, dim, base, device):
+ inv_freq = _create_inv_freq(dim, base, device)
+ scaling_factor = None
+ rope_scaling = _get_rope_config(config)
+ if not hasattr(config, "max_position_embeddings") and hasattr(
+ config, "max_seq_len"
+ ):
+ # handling for dbrx
+ config.max_position_embeddings = config.max_seq_len
+ if rope_scaling is not None:
+ # `rope_type` is now standard in transformers, but some existing models
+ # have `type` instead.
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
+
+ if rope_type == "linear":
+ pass
+ elif rope_type == "default":
+ pass
+ elif rope_type == "mrope":
+ mrope_section = rope_scaling["mrope_section"]
+ if mrope_section is not None:
+ return RotaryPositionEmbeddingMultimodalSections(
+ inv_freq,
+ scaling_factor,
+ mrope_section,
+ config.max_position_embeddings,
+ )
+ elif rope_type == "dynamic":
+ scaling_factor = rope_scaling["factor"]
+ return DynamicPositionRotaryEmbedding(
+ dim=dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=base,
+ device=inv_freq.device,
+ scaling_factor=scaling_factor,
+ )
+ elif rope_type == "llama3":
+ inv_freq = apply_llama3_scaling(
+ inv_freq,
+ scaling_factor=rope_scaling["factor"],
+ low_freq_factor=rope_scaling["low_freq_factor"],
+ high_freq_factor=rope_scaling["high_freq_factor"],
+ original_max_position_embeddings=rope_scaling[
+ "original_max_position_embeddings"
+ ],
+ )
+
+ return cls(inv_freq, scaling_factor, config.max_position_embeddings)
+
+ elif rope_type == "yarn":
+ scaling_factor = rope_scaling["factor"]
+ mscale = rope_scaling.get("mscale", 1.0)
+ mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
+ return YarnPositionRotaryEmbedding(
+ dim=2 * inv_freq.shape[0],
+ max_position_embeddings=rope_scaling[
+ "original_max_position_embeddings"
+ ],
+ base=base,
+ device=inv_freq.device,
+ scaling_factor=scaling_factor,
+ extrapolation_factor=1,
+ attn_factor=1,
+ beta_fast=32,
+ beta_slow=1,
+ mscale=mscale,
+ mscale_all_dim=mscale_all_dim,
+ )
+ elif rope_type in ["su", "longrope"]:
+ short_factor = torch.tensor(
+ rope_scaling["short_factor"], dtype=torch.float32, device=device
+ )
+ short_inv_freq = 1.0 / (
+ short_factor
+ * base
+ ** (
+ torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+ / dim
+ )
+ )
+ long_factor = torch.tensor(
+ rope_scaling["long_factor"], dtype=torch.float32, device=device
+ )
+ long_inv_freq = 1.0 / (
+ long_factor
+ * base
+ ** (
+ torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+ / dim
+ )
+ )
+
+ original_max_position_embeddings = (
+ config.original_max_position_embeddings
+ )
+ max_position_embeddings = config.max_position_embeddings
+ if max_position_embeddings <= original_max_position_embeddings:
+ scaling_factor = 1.0
+ else:
+ scale = max_position_embeddings / original_max_position_embeddings
+ scaling_factor = math.sqrt(
+ 1 + math.log(scale) / math.log(original_max_position_embeddings)
+ )
+
+ # if short_mscale and long_mscale are provided we need to scale the freqs
+ # using the Phi3LongRoPEScaledRotaryEmbedding
+ if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
+ short_mscale = rope_scaling["short_mscale"]
+ long_mscale = rope_scaling["long_mscale"]
+ return Phi3LongRoPEScaledRotaryEmbedding(
+ short_inv_freq=short_inv_freq,
+ long_inv_freq=long_inv_freq,
+ max_position_embeddings=config.max_position_embeddings,
+ short_mscale=short_mscale,
+ long_mscale=long_mscale,
+ original_max_position_embeddings=original_max_position_embeddings,
+ )
+
+ return SuRotaryEmbedding(
+ short_inv_freq=short_inv_freq,
+ long_inv_freq=long_inv_freq,
+ scaling_factor=scaling_factor,
+ original_max_position_embeddings=original_max_position_embeddings,
+ max_position_embeddings=config.max_position_embeddings,
+ )
+ else:
+ raise NotImplementedError(
+ f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
+ )
+ return cls(inv_freq, scaling_factor, config.max_position_embeddings)
+
+ @classmethod
+ def load(cls, config, prefix, weights):
+ # XXX: Always load this in float32 !
+ dtype = weights.dtype
+ weights.dtype = torch.float32
+ inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
+ weights.dtype = dtype
+
+ scaling_factor = None
+ rope_scaling = _get_rope_config(config)
+ if rope_scaling is not None:
+ scaling_factor = rope_scaling["factor"]
+ if rope_scaling["type"] == "linear":
+ pass
+ elif rope_scaling["type"] == "dynamic":
+ return DynamicPositionRotaryEmbedding(
+ dim=2 * inv_freq.shape[0],
+ max_position_embeddings=config.max_position_embeddings,
+ base=10000.0,
+ device=inv_freq.device,
+ scaling_factor=scaling_factor,
+ )
+ elif rope_scaling["type"] == "yarn":
+ mscale = rope_scaling.get("mscale", 1.0)
+ mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
+ return YarnPositionRotaryEmbedding(
+ dim=2 * inv_freq.shape[0],
+ max_position_embeddings=rope_scaling[
+ "original_max_position_embeddings"
+ ],
+ base=10000.0,
+ device=inv_freq.device,
+ scaling_factor=scaling_factor,
+ extrapolation_factor=1,
+ attn_factor=1,
+ beta_fast=32,
+ beta_slow=1,
+ mscale=mscale,
+ mscale_all_dim=mscale_all_dim,
+ )
+ else:
+ raise NotImplementedError(
+ f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
+ )
+ return cls(inv_freq, scaling_factor, config.max_position_embeddings)
+
+ def _update_cos_sin_cache(self, dtype, device, seqlen):
+ # Reset the tables if the sequence length has changed,
+ # or if we're on a new device (possibly due to tracing for instance)
+ if (
+ seqlen > self._seq_len_cached
+ or self._cos_cached.device != device
+ or self._cos_cached.dtype != dtype
+ ):
+ self._seq_len_cached = seqlen
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+ if self.scaling_factor is not None:
+ t /= self.scaling_factor
+ # Don't do einsum, it converts fp32 to fp16
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+ freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+ self._cos_cached = torch.cos(freqs).to(dtype)
+ self._sin_cached = torch.sin(freqs).to(dtype)
+
+ def get_cos_sin(self, position_ids: torch.Tensor):
+
+ cos = torch.index_select(self._cos_cached, 0, position_ids)
+ sin = torch.index_select(self._sin_cached, 0, position_ids)
+
+ # Note: this unsqueeze is not necessary on ROCm + VLLM ROPE implementation, but we leave it as is to avoid yet another control flow.
+ return cos.unsqueeze(1), sin.unsqueeze(1)
+
+
+class SuRotaryEmbedding(PositionRotaryEmbedding):
+ def __init__(
+ self,
+ short_inv_freq,
+ long_inv_freq,
+ scaling_factor,
+ original_max_position_embeddings,
+ max_position_embeddings,
+ ):
+ super(PositionRotaryEmbedding, self).__init__()
+ self.short_inv_freq = short_inv_freq
+ self.long_inv_freq = long_inv_freq
+ self.scaling_factor = scaling_factor
+ self.original_max_position_embeddings = original_max_position_embeddings
+ self._seq_len_cached = 0
+ self._cos_cached = None
+ self._sin_cached = None
+ self._cos_k_cached = None
+ self._sin_k_cached = None
+ self.dynamic_args = None
+ self._update_cos_sin_cache(
+ torch.float32, short_inv_freq.device, max_position_embeddings
+ )
+
+ def _update_cos_sin_cache(self, dtype, device, seqlen):
+ # Reset the tables if the sequence length has changed,
+ # or if we're on a new device (possibly due to tracing for instance)
+ if (
+ seqlen > self._seq_len_cached
+ or self._cos_cached is None
+ or self._cos_cached.device != device
+ or self._cos_cached.dtype != dtype
+ ):
+ self._seq_len_cached = seqlen
+
+ t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
+ short_freqs = torch.outer(
+ t[: self.original_max_position_embeddings],
+ self.short_inv_freq.to(device=t.device),
+ )
+ long_freqs = torch.outer(
+ t[self.original_max_position_embeddings :],
+ self.long_inv_freq.to(device=t.device),
+ )
+
+ freqs = torch.cat([short_freqs, long_freqs])
+
+ self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
+ self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
+
+
+class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
+ def __init__(
+ self,
+ short_inv_freq: torch.Tensor,
+ long_inv_freq: torch.Tensor,
+ max_position_embeddings: int,
+ short_mscale: float,
+ long_mscale: float,
+ original_max_position_embeddings: int,
+ ):
+ super(PositionRotaryEmbedding, self).__init__()
+ self.short_inv_freq = short_inv_freq
+ self.long_inv_freq = long_inv_freq
+ self.max_position_embeddings = max_position_embeddings
+ self.short_mscale = short_mscale
+ self.long_mscale = long_mscale
+ self.original_max_position_embeddings = original_max_position_embeddings
+
+ # cache
+ self._seq_len_cached = 0
+ self._cos_cached = None
+ self._sin_cached = None
+ self._cos_k_cached = None
+ self._sin_k_cached = None
+ self.dynamic_args = None
+ self._update_cos_sin_cache(
+ torch.float32, short_inv_freq.device, max_position_embeddings
+ )
+
+ def _update_cos_sin_cache(self, dtype, device, seqlen):
+ if (
+ seqlen > self._seq_len_cached
+ or self._cos_cached is None
+ or self._cos_cached.device != device
+ or self._cos_cached.dtype != dtype
+ ):
+ self._seq_len_cached = seqlen
+ t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
+
+ short_freqs = torch.outer(
+ t[: self.original_max_position_embeddings],
+ self.short_inv_freq.to(device=t.device),
+ )
+
+ long_freqs = torch.outer(
+ t[self.original_max_position_embeddings :],
+ self.long_inv_freq.to(device=t.device),
+ )
+
+ short_freqs = short_freqs * self.short_mscale
+ long_freqs = long_freqs * self.long_mscale
+
+ freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device)
+ freqs[: self.original_max_position_embeddings] = short_freqs
+ freqs[self.original_max_position_embeddings :] = long_freqs
+
+ self._cos_cached = torch.cos(freqs).to(dtype)
+ self._sin_cached = torch.sin(freqs).to(dtype)
+
+
+class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
+ def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
+ inv_freq = _create_inv_freq(dim, base, device)
+ super().__init__(inv_freq, scaling_factor, max_position_embeddings)
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+
+ def _update_cos_sin_cache(self, dtype, device, seqlen):
+ # Reset the tables if the sequence length has changed,
+ # or if we're on a new device (possibly due to tracing for instance)
+ if (
+ seqlen > self._seq_len_cached
+ or self._cos_cached.device != device
+ or self._cos_cached.dtype != dtype
+ ):
+ if seqlen > self.max_position_embeddings:
+ newbase = self.base * (
+ (self.scaling_factor * seqlen / self.max_position_embeddings)
+ - (self.scaling_factor - 1)
+ ) ** (self.dim / (self.dim - 2))
+ self.inv_freq = _create_inv_freq(
+ self.dim, newbase, self.inv_freq.device
+ )
+ self._seq_len_cached = seqlen
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+ # Don't do einsum, it converts fp32 to fp16
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+ freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+ self._cos_cached = torch.cos(freqs).to(dtype)
+ self._sin_cached = torch.sin(freqs).to(dtype)
+
+
+def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
+ 2 * math.log(base)
+ )
+
+
+# Find dim range bounds based on rotations
+def find_correction_range(
+ low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
+):
+ low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
+ high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
+ return max(low, 0), min(high, dim - 1) # Clamp values just in case
+
+
+def linear_ramp_mask(min, max, dim):
+ if min == max:
+ max += 0.001 # Prevent singularity
+
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+ ramp_func = torch.clamp(linear_func, 0, 1)
+ return ramp_func
+
+
+def get_mscale(scale: float = 1.0, mscale: float = 1.0):
+ if scale <= 1:
+ return 1.0
+ return 0.1 * mscale * math.log(scale) + 1.0
+
+
+class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
+ def __init__(
+ self,
+ dim,
+ max_position_embeddings,
+ base,
+ device,
+ scaling_factor,
+ *,
+ extrapolation_factor,
+ attn_factor,
+ beta_fast,
+ beta_slow,
+ mscale: float,
+ mscale_all_dim: float,
+ ):
+ inv_freq = _create_inv_freq(dim, base, device)
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ self.extrapolation_factor = extrapolation_factor
+ self.attn_factor = attn_factor
+ self.beta_fast = beta_fast
+ self.beta_slow = beta_slow
+ self.mscale_all_dim = mscale_all_dim
+ self.scaling_factor = scaling_factor
+ self.mscale = float(
+ get_mscale(self.scaling_factor, mscale)
+ / get_mscale(self.scaling_factor, mscale_all_dim)
+ * self.attn_factor
+ ) # Get n-d magnitude scaling corrected for interpolation
+ super().__init__(inv_freq, scaling_factor, max_position_embeddings)
+
+ def _update_cos_sin_cache(self, dtype, device, seqlen):
+ # Reset the tables if the sequence length has changed,
+ # or if we're on a new device (possibly due to tracing for instance)
+ if (
+ seqlen > self._seq_len_cached
+ or self._cos_cached.device != device
+ or self._cos_cached.dtype != dtype
+ ):
+ if seqlen > self.max_position_embeddings or True:
+ inv_freq_extrapolation = _create_inv_freq(
+ self.dim, self.base, self.inv_freq.device
+ )
+ freqs = 1.0 / inv_freq_extrapolation
+ inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
+ low, high = find_correction_range(
+ self.beta_fast,
+ self.beta_slow,
+ self.dim,
+ self.base,
+ self.max_position_embeddings,
+ )
+
+ inv_freq_mask = (
+ 1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
+ ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
+ inv_freq = (
+ inv_freq_interpolation * (1 - inv_freq_mask)
+ + inv_freq_extrapolation * inv_freq_mask
+ )
+
+ self.inv_freq = inv_freq
+
+ self._seq_len_cached = seqlen
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+ # Don't do einsum, it converts fp32 to fp16
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+ freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+ self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
+ self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
+
+
+def apply_llama3_scaling(
+ freqs: torch.Tensor,
+ *,
+ scaling_factor: int,
+ low_freq_factor: int,
+ high_freq_factor: int,
+ original_max_position_embeddings: int,
+):
+ low_freq_wavelen = original_max_position_embeddings / low_freq_factor
+ high_freq_wavelen = original_max_position_embeddings / high_freq_factor
+ new_freqs = []
+
+ for freq in freqs:
+ wavelen = 2 * math.pi / freq
+
+ if wavelen < high_freq_wavelen:
+ new_freqs.append(freq)
+ elif wavelen > low_freq_wavelen:
+ new_freqs.append(freq / scaling_factor)
+ else:
+ assert low_freq_wavelen != high_freq_wavelen
+ smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
+ high_freq_factor - low_freq_factor
+ )
+ new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
+
+ return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
+class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
+ def __init__(
+ self,
+ inv_freq: torch.Tensor,
+ scaling_factor: float,
+ sections: list,
+ max_position_embeddings,
+ ):
+ self.sections = sections
+ self._cos_cached = None
+ self._sin_cached = None
+ self.section_indices = (
+ torch.arange(len(self.sections))
+ .repeat_interleave(torch.tensor(self.sections))
+ .view(1, 1, -1)
+ .to(inv_freq.device)
+ )
+ super().__init__(inv_freq, scaling_factor, max_position_embeddings)
+
+ def _update_cos_sin_cache(
+ self, dtype: torch.dtype, device: torch.device, seqlen: int
+ ):
+ # always cache the cos/sin for the full sequence length to avoid
+ # recomputing if the sequence length is smaller than the cached one
+ if (
+ seqlen > self._seq_len_cached
+ or self._cos_cached.device != device
+ or self._cos_cached.dtype != dtype
+ ):
+ self._seq_len_cached = seqlen
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+ freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+ self._cos_cached = torch.cos(freqs).to(dtype)
+ self._sin_cached = torch.sin(freqs).to(dtype)
+ self._sections = self.section_indices.expand(seqlen, -1, -1)
+
+ def get_cos_sin(
+ self,
+ position_ids: torch.Tensor,
+ ):
+ slen = position_ids.shape[0]
+
+ cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
+ sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
+ return cos, sin
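
For reference, the math that the HPU `apply_rotary_pos_emb` kernel implements above can be written in a few lines of plain PyTorch: build `inv_freq`, take the outer product with the positions to get the cos/sin cache, duplicate it to the full head dimension, and apply the GPT-NeoX style (blockwise / rotate-half) rotation. This is a sketch with illustrative sizes, not the kernel itself.

```python
import torch

def rotate_half(x):
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

dim, base, seqlen, n_heads = 64, 10000.0, 32, 4

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(seqlen, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)                      # [seqlen, dim // 2]
cos = torch.cat((freqs.cos(), freqs.cos()), dim=-1)   # expand to full head dim
sin = torch.cat((freqs.sin(), freqs.sin()), dim=-1)

query = torch.randn(seqlen, n_heads, dim)
position_ids = torch.arange(seqlen)
cos_q, sin_q = cos[position_ids].unsqueeze(1), sin[position_ids].unsqueeze(1)

rotated = query * cos_q + rotate_half(query) * sin_q
print(rotated.shape)                                  # torch.Size([32, 4, 64])
```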
diff --git a/backends/gaudi/server/text_generation_server/layers/speculative.py b/backends/gaudi/server/text_generation_server/layers/speculative.py
new file mode 100644
index 00000000000..cf8469b53d0
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/speculative.py
@@ -0,0 +1,52 @@
+import torch
+import json
+from typing import Tuple, Optional
+from text_generation_server.layers.tensor_parallel import TensorParallelHead
+from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
+from text_generation_server.layers.mlp import MLPSpeculatorHead
+
+
+class SpeculativeHead(torch.nn.Module):
+ def __init__(self, lm_head, speculator):
+ super().__init__()
+ self.head = lm_head
+ self.speculator = speculator
+
+ @staticmethod
+ def load(config, prefix: str, weights):
+ speculator = config.speculator
+ if speculator:
+ speculator_path = config.speculator["path"]
+ speculator_config = str(speculator_path / "config.json")
+
+ with open(speculator_config, "r") as f:
+ speculator_config = json.load(f)
+
+ config.speculator_config = speculator_config
+ try:
+ architecture = speculator_config["architectures"][0]
+
+ if architecture == "MLPSpeculatorPreTrainedModel":
+ speculator = MLPSpeculatorHead.load(config, prefix, weights)
+ else:
+ speculator = None
+ except KeyError:
+ try:
+ speculator = MedusaHeadV1.load(config, prefix, weights)
+ except Exception:
+ speculator = MedusaHeadV2(config, prefix, weights)
+ lm_head = None
+ else:
+ lm_head = TensorParallelHead.load(config, prefix, weights)
+ speculator = None
+ return SpeculativeHead(lm_head, speculator)
+
+ def forward(
+ self, input: torch.Tensor
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ if self.speculator is not None:
+ return self.speculator(input)
+
+ assert self.head is not None
+ logits = self.head(input)
+ return logits, None
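
The dispatch order in `SpeculativeHead.load` is easy to lose in the nested try/except blocks, so here is a minimal sketch of the same decision tree on plain dictionaries. The helper name and the returned labels are illustrative only, not part of the server API.

```python
from typing import Optional


def pick_head_kind(speculator_config: Optional[dict]) -> str:
    """Illustrative mirror of the SpeculativeHead.load decision tree."""
    if speculator_config is None:
        # No speculator configured: load the regular tensor-parallel LM head.
        return "lm_head"
    try:
        architecture = speculator_config["architectures"][0]
        if architecture == "MLPSpeculatorPreTrainedModel":
            return "mlp_speculator"
        return "no_speculator"
    except KeyError:
        # Medusa checkpoints carry no `architectures` key; V1 is tried before V2.
        return "medusa_v1_or_v2"


assert pick_head_kind(None) == "lm_head"
assert pick_head_kind({"architectures": ["MLPSpeculatorPreTrainedModel"]}) == "mlp_speculator"
assert pick_head_kind({"medusa_num_heads": 4}) == "medusa_v1_or_v2"
```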
diff --git a/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py b/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py
new file mode 100644
index 00000000000..8f19174f80f
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py
@@ -0,0 +1,244 @@
+import torch
+from torch.nn import functional as F
+from typing import Iterable, List
+from text_generation_server.layers.linear import get_linear, FastLinear
+
+import habana_frameworks.torch as htorch
+
+
+class LayerConcat(torch.nn.Module):
+ """
+ Apply multiple layers to the input and concatenate their
+ outputs.
+ """
+
+ def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1):
+ """
+ `dim` is the dimension along which layer outputs are concatenated.
+ """
+ super().__init__()
+ self.layers = layers
+ self.dim = dim
+
+ def forward(self, x: torch.Tensor):
+ outputs = [layer(x) for layer in self.layers]
+ return torch.cat(outputs, self.dim)
+
+
+class SuperLayer(torch.nn.Module):
+ def __init__(self, linear):
+ super().__init__()
+ self.linear = linear
+
+ def forward(self, x):
+ return self.linear.forward(x)
+
+
+class TensorParallelHead(SuperLayer):
+ def __init__(self, linear, process_group, should_gather: bool):
+ super().__init__(linear)
+ self.process_group = process_group
+ self.should_gather = should_gather
+
+ @staticmethod
+ def load(config, prefix: str, weights):
+ if config.quantize == "exl2":
+ try:
+ # If the piece and LM head embeddings are shared, we have
+ # non-quantized weights...
+ weight = weights.get_tensor(f"{prefix}.weight")
+ except Exception:
+ # ...otherwise they are quantized.
+ weight = weights.get_weights_col(prefix)
+ should_gather = weights.process_group.size() > 1
+ elif weights.process_group.size() > 1:
+ try:
+ weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+ should_gather = True
+ except AssertionError:
+ # If the vocab size is not divisible by number of shards
+ # just load the entire thing.
+ weight = weights.get_tensor(f"{prefix}.weight")
+ should_gather = False
+ else:
+ weight = weights.get_tensor(f"{prefix}.weight")
+ should_gather = False
+
+ return TensorParallelHead(
+ get_linear(weight, bias=None),
+ process_group=weights.process_group,
+ should_gather=should_gather,
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ if not self.should_gather:
+ return super().forward(input)
+
+ world_size = self.process_group.size()
+ if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
+ out_dim = self.linear.weight.shape[0]
+
+ if input.shape[0] == 1:
+ world_out = input.new_empty(1, out_dim * world_size)
+ local_out = input.new_empty(1, out_dim)
+ gather_input = local_out
+ else:
+ world_out = input.new_empty(out_dim * world_size, input.shape[0])
+ gather_input = input.new_empty(out_dim, input.shape[0])
+ local_out = gather_input.T
+
+ torch.mm(input, self.linear.weight.T, out=local_out)
+ htorch.core.mark_step()
+ torch.distributed.all_gather_into_tensor(
+ world_out, gather_input, group=self.process_group
+ )
+
+ if input.shape[0] == 1:
+ return world_out
+ return world_out.T
+
+ output = super().forward(input)
+ world_output = [
+ torch.empty_like(output) for _ in range(self.process_group.size())
+ ]
+
+ htorch.core.mark_step()
+ torch.distributed.all_gather(world_output, output, group=self.process_group)
+ world_output = torch.cat(world_output, dim=-1)
+ return world_output
+
+
+class TensorParallelColumnLinear(SuperLayer):
+ @classmethod
+ def load_gate_up(cls, config, prefix: str, weights, bias: bool):
+ """Specific method when the QKV was joined after the fact"""
+ weight = weights.get_weights_col_packed_gate_up(prefix)
+ if bias:
+ raise NotImplementedError("packed_gate_up only implemented without bias")
+ else:
+ bias = None
+ linear = get_linear(weight, bias)
+ return cls(linear)
+
+ @classmethod
+ def load_qkv(
+ cls,
+ config,
+ prefix: str,
+ weights,
+ bias: bool,
+ num_heads: int,
+ num_key_value_heads: int,
+ ):
+ """Specific method when the QKV was joined after the fact"""
+ weight = weights.get_weights_col_packed_qkv(
+ prefix,
+ num_heads=num_heads,
+ num_key_value_heads=num_key_value_heads,
+ )
+ if bias:
+ raise NotImplementedError("packed_qkv only implemented for baichuan")
+ else:
+ bias = None
+ linear = get_linear(weight, bias)
+ return cls(linear)
+
+ @classmethod
+ def load(cls, config, prefix: str, weights, bias: bool):
+ weight = weights.get_weights_col(prefix)
+ if bias:
+ bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+ else:
+ bias = None
+ linear = get_linear(weight, bias)
+ return cls(linear)
+
+ @classmethod
+ def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
+ if config.quantize == "exl2":
+ linears = []
+ for prefix in prefixes:
+ weight = weights.get_weights_col(prefix)
+ b = weights.get_tensor(f"{prefix}.bias") if bias else None
+ linears.append(get_linear(weight, b))
+ linear = LayerConcat(linears)
+ else:
+ weight = weights.get_multi_weights_col(prefixes, dim=dim)
+ if bias:
+ b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
+ bias = torch.cat(b, dim=dim)
+ else:
+ bias = None
+ linear = get_linear(weight, bias)
+ return cls(linear)
+
+
+class TensorParallelRowLinear(SuperLayer):
+ def __init__(self, linear, process_group):
+ super().__init__(linear)
+ self.process_group = process_group
+
+ @classmethod
+ def load(cls, config, prefix: str, weights, bias: bool):
+ weight = weights.get_weights_row(prefix)
+
+ if bias and weights.process_group.rank() == 0:
+ # The bias is only loaded on the first rank so it is added exactly once after the all-reduce
+ bias = weights.get_tensor(f"{prefix}.bias")
+ else:
+ bias = None
+ return cls(
+ get_linear(weight, bias),
+ process_group=weights.process_group,
+ )
+
+ def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
+ out = super().forward(input)
+ if self.process_group.size() > 1 and reduce:
+ # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
+ # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
+ # (which is required for tensor parallel HPUGraph inference)
+ htorch.core.mark_step()
+ torch.distributed.all_reduce(out, group=self.process_group)
+ return out
+
+
+class TensorParallelEmbedding(torch.nn.Module):
+ def __init__(self, prefix: str, weights, reduce=True):
+ super().__init__()
+ weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
+ num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
+
+ process_group = weights.process_group
+
+ world_size = process_group.size()
+ rank = process_group.rank()
+
+ block_size = (num_embeddings + world_size - 1) // world_size
+ self.min_id = rank * block_size
+ self.max_id = min(num_embeddings, (rank + 1) * block_size)
+ self.null_idx = weight.shape[
+ 0
+ ] # Usually block_size, might be less in non even vocab_size.
+ self.process_group = weights.process_group
+ self.reduce = reduce
+
+ """Additional 0 entry used for masking"""
+ self.weight = torch.nn.Parameter(F.pad(weight, (0, 0, 0, 1)))
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ # default all out of bounds values to `self.null_idx` that will then be mapped to 0
+ # translate for [0, self.max_id - self.min_id[
+ input = torch.where(
+ (self.min_id > input) | (input >= self.max_id),
+ self.null_idx,
+ input - self.min_id,
+ )
+ out = torch.nn.functional.embedding(input, self.weight)
+ if self.reduce and self.process_group.size() > 1:
+ # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
+ # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
+ # (which is required for tensor parallel HPUGraph inference)
+ htorch.core.mark_step()
+ torch.distributed.all_reduce(out, group=self.process_group)
+ return out
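
The index translation in `TensorParallelEmbedding` is the subtle part of this file: every rank keeps only a slice of the vocabulary plus one extra zero row, and out-of-shard token ids are redirected to that row so the later all-reduce sums in the contribution from the shard that actually owns them. A single-process sketch with made-up sizes (rank 1 of 2, no `torch.distributed`):

```python
import torch
import torch.nn.functional as F

num_embeddings, hidden, world_size, rank = 10, 4, 2, 1
block_size = (num_embeddings + world_size - 1) // world_size   # 5
min_id = rank * block_size                                      # 5
max_id = min(num_embeddings, (rank + 1) * block_size)           # 10

# This rank only holds rows 5..9; pad one extra zero row used for masking.
shard = torch.randn(max_id - min_id, hidden)
null_idx = shard.shape[0]
weight = F.pad(shard, (0, 0, 0, 1))

input_ids = torch.tensor([0, 5, 7, 9, 3])
local_ids = torch.where(
    (min_id > input_ids) | (input_ids >= max_id),
    null_idx,             # out-of-shard tokens hit the zero row
    input_ids - min_id,   # in-shard tokens are shifted into [0, block_size)
)
out = F.embedding(local_ids, weight)

# Rows for tokens 0 and 3 are all zeros on this rank; the all_reduce across
# ranks would add in the embedding from the shard that owns them.
assert torch.all(out[0] == 0) and torch.all(out[4] == 0)
```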
diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py
new file mode 100644
index 00000000000..76e64f3a580
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/__init__.py
@@ -0,0 +1,1060 @@
+# ruff: noqa: F821
+# the above line disables the `undefined-name` rule for the model type variables
+import torch
+import os
+
+from loguru import logger
+from transformers.configuration_utils import PretrainedConfig
+from huggingface_hub import hf_hub_download, HfApi
+from typing import Optional
+from pathlib import Path
+from typing import List, Dict
+import enum
+
+# Needed to properly setup habana_frameworks
+
+from text_generation_server.utils.speculate import get_speculate, set_speculate
+from text_generation_server.models.model import Model
+from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
+ PhiMoEConfig,
+)
+
+from text_generation_server.utils.adapter import (
+ AdapterParameters,
+ build_layer_weight_lookup,
+ load_and_merge_adapters,
+ AdapterInfo,
+)
+from text_generation_server.adapters.lora import LoraWeights
+
+from text_generation_server.utils.log import log_master
+
+__all__ = [
+ "Model",
+ "CausalLM",
+ "Seq2SeqLM",
+ "get_model_with_lora_adapters",
+]
+
+VLM_BATCH_TYPES = set()
+
+FLASH_ATTENTION = True
+
+try:
+ from text_generation_server.models.flash_causal_lm import FlashCausalLM
+ from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
+ from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM
+ from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
+ FlashDeepseekV2ForCausalLM,
+ DeepseekV2Config,
+ )
+ from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (
+ FlashDeepseekV3ForCausalLM,
+ DeepseekV3Config,
+ )
+ from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+ FlashLlamaForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_llama4_modeling import (
+ Llama4ForConditionalGeneration,
+ )
+ from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
+ FlashCohereForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
+ FlashGemmaForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
+ FlashGemma2ForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
+ Gemma3ForConditionalGeneration,
+ FlashGemma3ForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
+ FlashDbrxForCausalLM,
+ DbrxConfig,
+ )
+ from text_generation_server.models.custom_modeling.flash_rw_modeling import (
+ RWConfig,
+ FlashRWForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_neox_modeling import (
+ FlashGPTNeoXForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
+ PaliGemmaForConditionalGeneration,
+ )
+ from text_generation_server.models.custom_modeling.flash_phi_modeling import (
+ FlashPhiForCausalLM,
+ )
+ from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
+ from text_generation_server.models.custom_modeling.flash_mllama import (
+ FlashMllamaForConditionalGeneration,
+ )
+ from text_generation_server.models.custom_modeling.flash_llava_next import (
+ FlashLlavaNextForConditionalGeneration,
+ )
+
+ from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
+ FlashSantacoderForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
+ FlashStarcoder2ForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
+ Qwen2ForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_qwen3_modeling import (
+ Qwen3ForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_qwen3_moe_modeling import (
+ Qwen3MoeForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
+ FlashMistralForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
+ FlashMixtralForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
+ FlashGPT2ForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
+ FlashGPTJForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.idefics2 import (
+ Idefics2ForConditionalGeneration,
+ )
+ from text_generation_server.models.custom_modeling.idefics3 import (
+ Idefics3ForConditionalGeneration,
+ )
+ from text_generation_server.models.custom_modeling.qwen2_vl import (
+ Qwen2VLForConditionalGeneration,
+ )
+ from text_generation_server.models.custom_modeling.qwen2_5_vl import (
+ Qwen2_5VLForConditionalGeneration,
+ Qwen2_5_VLConfig,
+ Qwen2_5_VLProcessor,
+ )
+ from text_generation_server.layers.attention import SUPPORTS_WINDOWING
+except ImportError as e:
+ log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
+ SUPPORTS_WINDOWING = False
+ FLASH_ATTENTION = False
+ VLM_BATCH_TYPES = set()
+
+if FLASH_ATTENTION:
+ __all__.append(FlashCausalLM)
+
+ from text_generation_server.models.flash_vlm_causal_lm import (
+ FlashVlmCausalLMBatch,
+ )
+
+ VLM_BATCH_TYPES = {
+ FlashVlmCausalLMBatch,
+ FlashMllamaCausalLMBatch,
+ }
+
+
+__all__.append(VLM_BATCH_TYPES)
+
+
+class ModelType(enum.Enum):
+ DEEPSEEK_V2 = {
+ "type": "deepseek_v2",
+ "name": "Deepseek V2",
+ "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
+ }
+ DEEPSEEK_V3 = {
+ "type": "deepseek_v3",
+ "name": "Deepseek V3",
+ "url": "https://huggingface.co/deepseek-ai/DeepSeek-V3",
+ }
+ IDEFICS2 = {
+ "type": "idefics2",
+ "name": "Idefics 2",
+ "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
+ "multimodal": True,
+ }
+ IDEFICS3 = {
+ "type": "idefics3",
+ "name": "Idefics 3",
+ "url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
+ "multimodal": True,
+ }
+ LLAVA_NEXT = {
+ "type": "llava_next",
+ "name": "Llava Next (1.6)",
+ "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
+ "multimodal": True,
+ }
+ LLAMA = {
+ "type": "llama",
+ "name": "Llama",
+ "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
+ }
+ LLAMA4 = {
+ "type": "llama4",
+ "name": "Llama4",
+ "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
+ }
+ PHI3 = {
+ "type": "phi3",
+ "name": "Phi 3",
+ "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
+ }
+ GRANITE = {
+ "type": "granite",
+ "name": "Granite",
+ "url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
+ }
+ GEMMA = {
+ "type": "gemma",
+ "name": "Gemma",
+ "url": "https://huggingface.co/google/gemma-7b",
+ }
+ PALIGEMMA = {
+ "type": "paligemma",
+ "name": "PaliGemma",
+ "url": "https://huggingface.co/google/paligemma-3b-pt-224",
+ }
+ GEMMA2 = {
+ "type": "gemma2",
+ "name": "Gemma2",
+ "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
+ }
+ GEMMA3 = {
+ "type": "gemma3",
+ "name": "Gemma3",
+ "url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
+ }
+ GEMMA3_TEXT = {
+ "type": "gemma3_text",
+ "name": "Gemma3 Text",
+ "url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
+ }
+ COHERE = {
+ "type": "cohere",
+ "name": "Cohere",
+ "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
+ }
+ DBRX = {
+ "type": "dbrx",
+ "name": "Dbrx",
+ "url": "https://huggingface.co/databricks/dbrx-instruct",
+ }
+ MAMBA = {
+ "type": "mamba",
+ "name": "Mamba",
+ "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
+ }
+ MISTRAL = {
+ "type": "mistral",
+ "name": "Mistral",
+ "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
+ }
+ MIXTRAL = {
+ "type": "mixtral",
+ "name": "Mixtral",
+ "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
+ }
+ GPT_BIGCODE = {
+ "type": "gpt_bigcode",
+ "name": "Gpt Bigcode",
+ "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
+ }
+ PHI = {
+ "type": "phi",
+ "name": "Phi",
+ "url": "https://huggingface.co/microsoft/phi-1_5",
+ }
+ PHI_MOE = {
+ "type": "phimoe",
+ "name": "PhiMoe",
+ "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
+ }
+ BAICHUAN = {
+ "type": "baichuan",
+ "name": "Baichuan",
+ "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
+ }
+ FALCON = {
+ "type": "falcon",
+ "name": "Falcon",
+ "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
+ }
+ STARCODER2 = {
+ "type": "starcoder2",
+ "name": "StarCoder 2",
+ "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
+ }
+ QWEN2 = {
+ "type": "qwen2",
+ "name": "Qwen 2",
+ "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
+ }
+ QWEN2_VL = {
+ "type": "qwen2_vl",
+ "name": "Qwen 2 VL",
+ "url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
+ }
+ QWEN2_5_VL = {
+ "type": "qwen2_5_vl",
+ "name": "Qwen 2.5 VL",
+ "url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
+ }
+ QWEN3 = {
+ "type": "qwen3",
+ "name": "Qwen 3",
+ "url": "https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f",
+ }
+ QWEN3_MOE = {
+ "type": "qwen3_moe",
+ "name": "Qwen 3 Moe",
+ "url": "https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f",
+ }
+ GALACTICA = {
+ "type": "galactica",
+ "name": "Galactica",
+ "url": "https://huggingface.co/facebook/galactica-120b",
+ }
+ SANTACODER = {
+ "type": "santacoder",
+ "name": "SantaCoder",
+ "url": "https://huggingface.co/bigcode/santacoder",
+ }
+ GPT2 = {
+ "type": "gpt2",
+ "name": "Gpt2",
+ "url": "https://huggingface.co/openai-community/gpt2",
+ }
+ GPT_NEOX = {
+ "type": "gpt_neox",
+ "name": "Gpt Neox",
+ "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
+ }
+ GPTJ = {
+ "type": "gptj",
+ "name": "Gptj",
+ "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
+ }
+ MLLAMA = {
+ "type": "mllama",
+ "name": "Mllama",
+ "url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
+ "multimodal": True,
+ }
+
+
+__GLOBALS = locals()
+for data in ModelType:
+ __GLOBALS[data.name] = data.value["type"]
+
+SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0))
+# Disable gradients
+torch.set_grad_enabled(False)
+
+
+def get_model(
+ model_id: str,
+ lora_adapter_ids: Optional[List[str]],
+ revision: Optional[str],
+ sharded: bool,
+ quantize: Optional[str],
+ speculate: Optional[int],
+ dtype: Optional[torch.dtype],
+ kv_cache_dtype: Optional[str],
+ trust_remote_code: bool,
+ max_input_tokens: int,
+) -> Model:
+ global FLASH_ATTENTION
+
+ if speculate is not None:
+ set_speculate(speculate)
+ else:
+ set_speculate(0)
+
+ config_dict, _ = PretrainedConfig.get_config_dict(
+ model_id, revision=revision, trust_remote_code=trust_remote_code
+ )
+ model_type = config_dict.get("model_type", None)
+
+ speculator = None
+ if "medusa_num_heads" in config_dict:
+ medusa_model_id = model_id
+ medusa_revision = revision
+ model_id = config_dict["base_model_name_or_path"]
+ revision = "main"
+ speculate_medusa = config_dict["medusa_num_heads"]
+ if speculate is not None:
+ if speculate > speculate_medusa:
+ raise RuntimeError(
+ f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
+ )
+ else:
+ set_speculate(speculate)
+ else:
+ set_speculate(speculate_medusa)
+
+ config_dict, _ = PretrainedConfig.get_config_dict(
+ model_id, revision=revision, trust_remote_code=trust_remote_code
+ )
+ # Reload model type from parent.
+ model_type = config_dict.get("model_type", None)
+ is_local = Path(medusa_model_id).exists()
+ if not is_local:
+ medusa_config = hf_hub_download(
+ medusa_model_id, revision=medusa_revision, filename="config.json"
+ )
+ hf_hub_download(
+ medusa_model_id,
+ revision=medusa_revision,
+ filename="medusa_lm_head.safetensors",
+ )
+ speculator = {
+ "path": Path(medusa_config).parent,
+ "model_paths": ["medusa_lm_head.safetensors"],
+ }
+ else:
+ speculator = {
+ "path": Path(medusa_model_id),
+ "model_paths": ["medusa_lm_head.safetensors"],
+ }
+
+ method = "medusa"
+ elif model_type == "mlp_speculator":
+ mlp_model_id = model_id
+ mlp_revision = revision
+ model_id = config_dict["base_model_name_or_path"]
+ revision = "main"
+ speculate_mlp = config_dict["n_predict"]
+ if speculate is not None:
+ if speculate > speculate_mlp:
+ raise RuntimeError(
+ f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
+ )
+ else:
+ set_speculate(speculate)
+ else:
+ set_speculate(speculate_mlp)
+
+ config_dict, _ = PretrainedConfig.get_config_dict(
+ model_id, revision=revision, trust_remote_code=trust_remote_code
+ )
+ # Reload model type from parent.
+ model_type = config_dict.get("model_type", None)
+ is_local = Path(mlp_model_id).exists()
+ extension = ".safetensors"
+ if not is_local:
+ mlp_speculator_config = hf_hub_download(
+ mlp_model_id, revision=mlp_revision, filename="config.json"
+ )
+ api = HfApi()
+ info = api.model_info(mlp_model_id, revision=mlp_revision)
+ filenames = [
+ s.rfilename
+ for s in info.siblings
+ if s.rfilename.endswith(extension)
+ and len(s.rfilename.split("/")) == 1
+ and "arguments" not in s.rfilename
+ and "args" not in s.rfilename
+ and "training" not in s.rfilename
+ ]
+ for filename in filenames:
+ hf_hub_download(
+ mlp_model_id,
+ revision=mlp_revision,
+ filename=filename,
+ )
+ speculator_dir_path = Path(mlp_speculator_config).parent
+ # if these are downloaded, they get converted to safetensors
+ filenames.extend(
+ [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
+ )
+ speculator = {
+ "path": Path(mlp_speculator_config).parent,
+ "model_paths": filenames,
+ }
+ else:
+ speculator = Path(mlp_model_id)
+ filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
+ speculator = {"path": speculator, "model_paths": filenames}
+ method = "mlp_speculator"
+ else:
+ method = "n-gram"
+
+ speculate = get_speculate()
+ if speculate > 0:
+ logger.info(f"Using speculation {method} with {speculate} input ids.")
+
+ model_type = config_dict["model_type"]
+
+ if kv_cache_dtype == "fp8_e4m3fn":
+ kv_cache_dtype = torch.float8_e4m3fn
+ elif kv_cache_dtype == "fp8_e5m2":
+ kv_cache_dtype = torch.float8_e5m2
+ else:
+ kv_cache_dtype = dtype
+
+ if FLASH_ATTENTION:
+ if model_type == DEEPSEEK_V2:
+ head_size = max(
+ config_dict.get("qk_nope_dim", 128)
+ + config_dict.get("qk_rope_dim", 64),
+ config_dict.get("v_head_dim", 128),
+ )
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashDeepseekV2ForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ default_dtype=torch.bfloat16,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ config_class=DeepseekV2Config,
+ head_size=head_size,
+ )
+ elif model_type == DEEPSEEK_V3:
+ head_size = max(
+ config_dict.get("qk_nope_dim", 128)
+ + config_dict.get("qk_rope_dim", 64),
+ config_dict.get("v_head_dim", 128),
+ )
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashDeepseekV3ForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ default_dtype=torch.bfloat16,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ config_class=DeepseekV3Config,
+ head_size=head_size,
+ )
+
+ elif (
+ model_type == GPT_BIGCODE
+ or model_type == GPT2
+ and model_id.startswith("bigcode/")
+ ):
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashSantacoderForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ aliases={"transformer.wte.weight": ["lm_head.weight"]},
+ num_kv_heads=1,
+ )
+ elif model_type == GPT2:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashGPT2ForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == GPTJ:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashGPTJForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == GPT_NEOX:
+ from text_generation_server.models.custom_modeling.flash_neox_modeling import (
+ GPTNeoXConfig,
+ )
+
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashGPTNeoXForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ config_class=GPTNeoXConfig,
+ )
+ elif model_type == PHI:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashPhiForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == PHI_MOE:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashLlamaForCausalLM,
+ config_class=PhiMoEConfig,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashLlamaForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == LLAMA4:
+ print(f"Llama4 model detected: {model_id}")
+ return FlashVlmCausalLM(
+ model_id=model_id,
+ model_class=Llama4ForConditionalGeneration,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ default_dtype=torch.bfloat16,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ support_chunking=False,
+ )
+ elif model_type == BAICHUAN:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashLlamaForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == GEMMA:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashGemmaForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ # Works better for these models
+ default_dtype=torch.bfloat16,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == GEMMA2:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashGemma2ForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ # Works better for these models
+ default_dtype=torch.bfloat16,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == GEMMA3:
+ return FlashVlmCausalLM(
+ model_id=model_id,
+ model_class=Gemma3ForConditionalGeneration,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ default_dtype=torch.bfloat16,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ support_chunking=False,
+ )
+ elif model_type == GEMMA3_TEXT:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashGemma3ForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ # Works better for these models
+ default_dtype=torch.bfloat16,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == COHERE:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashCohereForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == DBRX:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashDbrxForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ # Dbrx works better in bfloat16.
+ default_dtype=torch.bfloat16,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ config_class=DbrxConfig,
+ )
+ elif (
+ model_type in ["RefinedWeb", "RefinedWebModel", FALCON]
+ and not sharded
+ and not config_dict.get("alibi", False)
+ ):
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashRWForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ aliases={
+ "lm_head.weight": ["transformer.word_embeddings.weight"],
+ "transformer.word_embeddings.weight": ["lm_head.weight"],
+ },
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ config_class=RWConfig,
+ )
+ elif model_type == MISTRAL:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashMistralForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == MIXTRAL:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashMixtralForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == STARCODER2:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashStarcoder2ForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == QWEN2:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=Qwen2ForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == QWEN2_VL:
+ return FlashVlmCausalLM(
+ model_id=model_id,
+ model_class=Qwen2VLForConditionalGeneration,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ default_dtype=torch.bfloat16,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ # TODO: Fix bug in rust image_text_replacement implementation
+ support_chunking=False,
+ )
+ elif model_type == QWEN2_5_VL:
+ return FlashVlmCausalLM(
+ model_id=model_id,
+ model_class=Qwen2_5VLForConditionalGeneration,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ default_dtype=torch.bfloat16,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ config_class=Qwen2_5_VLConfig,
+ processor_class=Qwen2_5_VLProcessor,
+ # TODO: Fix bug in rust image_text_replacement implementation
+ support_chunking=False,
+ )
+ elif model_type == QWEN3:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=Qwen3ForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == QWEN3_MOE:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=Qwen3MoeForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == MLLAMA:
+ return FlashMllamaCausalLM(
+ model_id=model_id,
+ model_class=FlashMllamaForConditionalGeneration,
+ batch_class=FlashMllamaCausalLMBatch,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ default_dtype=torch.bfloat16,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ support_chunking=False,
+ )
+ elif model_type == IDEFICS2:
+ return FlashVlmCausalLM(
+ model_id=model_id,
+ model_class=Idefics2ForConditionalGeneration,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ # XXX: Extremely important to cap resolution in order to limit
+ # VRAM usage.
+ processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
+ )
+ elif model_type == IDEFICS3:
+ return FlashVlmCausalLM(
+ model_id=model_id,
+ model_class=Idefics3ForConditionalGeneration,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ default_dtype=torch.bfloat16,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ # XXX: Extremely important to cap resolution in order to limit
+ # VRAM usage.
+ processor_kwargs={"size": {"longest_edge": 1456}},
+ )
+ elif model_type == PALIGEMMA:
+ return FlashVlmCausalLM(
+ model_id=model_id,
+ model_class=PaliGemmaForConditionalGeneration,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ # Works better for these models
+ default_dtype=torch.bfloat16,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == LLAVA_NEXT:
+ return FlashVlmCausalLM(
+ model_class=FlashLlavaNextForConditionalGeneration,
+ model_id=model_id,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ )
+
+ raise ValueError(f"Unsupported model type {model_type}")
+
+
+# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters
+# this provides a post model loading hook to load adapters into the model after the model has been loaded
+def get_model_with_lora_adapters(
+ model_id: str,
+ lora_adapters: Optional[List[AdapterInfo]],
+ revision: Optional[str],
+ sharded: bool,
+ quantize: Optional[str],
+ speculate: Optional[int],
+ dtype: Optional[torch.dtype],
+ kv_cache_dtype: Optional[str],
+ trust_remote_code: bool,
+ max_input_tokens: int,
+ adapter_to_index: Dict[str, int],
+):
+ lora_adapter_ids = [adapter.id for adapter in lora_adapters]
+ model = get_model(
+ model_id,
+ lora_adapter_ids,
+ revision,
+ sharded,
+ quantize,
+ speculate,
+ dtype,
+ kv_cache_dtype,
+ trust_remote_code,
+ max_input_tokens,
+ )
+
+ if len(lora_adapters) > 0:
+ target_to_layer = build_layer_weight_lookup(model.model)
+
+ for index, adapter in enumerate(lora_adapters):
+ # The AdapterParameters object allows for merging multiple adapters into a single adapter.
+ # At the moment, we only support loading a single adapter into the model, but we keep the
+ # AdapterParameters object for easier extension in the future.
+ adapter_parameters = AdapterParameters(
+ adapter_info=[adapter],
+ # when merging multiple adapters we can weight them differently
+ # if this is not set, all adapters will be weighted equally
+ # see: text_generation_server.utils.merges.strategies for impl
+ weights=None,
+ merge_strategy=0,
+ density=1.0,
+ majority_sign_method=0,
+ )
+
+ adapter_index = index + 1
+ adapter_to_index[adapter.id] = adapter_index
+
+ logger.info(
+ f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
+ )
+ weight_names = tuple([v[0] for v in target_to_layer.values()])
+ (
+ module_map,
+ adapter_config,
+ adapter_weight_names,
+ adapter_tokenizer,
+ ) = load_and_merge_adapters(
+ model.model_id,
+ adapter_parameters,
+ adapter_index,
+ weight_names,
+ False,
+ )
+
+ unused_weight_names = adapter_weight_names.copy()
+
+ adapter_layers = [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ "o_proj",
+ "gate_proj",
+ "up_proj",
+ "down_proj",
+ "qkv_proj",
+ ]
+
+ for layer_name in adapter_layers:
+ nlayers = (
+ 1 if layer_name == "lm_head" else len(model.model.model.layers)
+ )
+ adapter_weights = LoraWeights.prepare_weights(
+ config=adapter_config,
+ module_map=module_map,
+ layer_type=layer_name,
+ unused_weight_names=unused_weight_names,
+ nlayers=nlayers,
+ dtype=model.dtype,
+ world_size=model.world_size,
+ process_group=model.process_group,
+ target_to_layer=target_to_layer,
+ )
+
+ if adapter_weights is None:
+ continue
+
+ model.layer_to_adapter_weights[layer_name].add_adapter(
+ adapter_index, adapter_weights
+ )
+
+ if len(unused_weight_names) > 0:
+ logger.warning(
+ f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}"
+ )
+
+ if adapter_tokenizer is not None:
+ model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
+
+ model.loaded_adapters.add(adapter_index)
+
+ return model
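
One small, self-contained piece of `get_model` worth calling out is the kv-cache dtype resolution at its top. The sketch below extracts that mapping so it can be checked in isolation; the helper name is illustrative, and running it assumes a PyTorch build that exposes the float8 dtypes.

```python
from typing import Optional

import torch


def resolve_kv_cache_dtype(
    kv_cache_dtype: Optional[str], dtype: Optional[torch.dtype]
) -> Optional[torch.dtype]:
    """Map the CLI string to a torch dtype, falling back to the model dtype."""
    if kv_cache_dtype == "fp8_e4m3fn":
        return torch.float8_e4m3fn
    if kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2
    return dtype


assert resolve_kv_cache_dtype("fp8_e5m2", torch.bfloat16) == torch.float8_e5m2
assert resolve_kv_cache_dtype(None, torch.bfloat16) == torch.bfloat16
```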
diff --git a/.devcontainer/devcontainer.json b/backends/gaudi/server/text_generation_server/models/custom_modeling/__init__.py
similarity index 100%
rename from .devcontainer/devcontainer.json
rename to backends/gaudi/server/text_generation_server/models/custom_modeling/__init__.py
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py
new file mode 100644
index 00000000000..84835ab89bb
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py
@@ -0,0 +1,923 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BLOOM model."""
+
+import math
+import os
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.distributed
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import LayerNorm
+from torch.nn import functional as F
+
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+)
+from transformers import BloomConfig, PreTrainedModel
+
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ SpeculativeHead,
+)
+
+CUSTOM_KERNELS_ENABLED = False
+if (
+ torch.cuda.is_available()
+ and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True"
+):
+ try:
+ from custom_kernels import fused_bloom_attention_cuda
+
+ CUSTOM_KERNELS_ENABLED = True
+ except ImportError:
+ pass
+
+_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
+_CONFIG_FOR_DOC = "BloomConfig"
+
+BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "bigscience/bigscience-small-testing",
+ "bigscience/bloom-560m",
+ "bigscience/bloom-1b1",
+ "bigscience/bloom-1b7",
+ "bigscience/bloom-3b",
+ "bigscience/bloom-7b1",
+ "bigscience/bloom",
+]
+
+
+def _make_causal_mask(
+ input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
+) -> torch.BoolTensor:
+ """
+ Make causal mask used for self-attention.
+ """
+ batch_size, target_length = input_ids_shape
+ mask = torch.ones(
+ (target_length, target_length + past_key_values_length),
+ dtype=torch.bool,
+ device=device,
+ )
+ mask = mask.triu(1 + past_key_values_length)
+
+ expanded_mask = mask.unsqueeze(0).expand(
+ batch_size, target_length, target_length + past_key_values_length
+ )
+ return expanded_mask
+
+
+def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
+ """
+ Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
+ """
+ batch_size, src_length = mask.shape
+ tgt_length = tgt_length if tgt_length is not None else src_length
+
+ expanded_mask = ~(mask[:, None, :].to(torch.bool))
+ return expanded_mask.expand(batch_size, tgt_length, src_length)
+
+
+def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
+ """
+ Link to paper: https://arxiv.org/abs/2108.12409. The ALiBi tensor is not causal as the original paper mentions; it
+ relies on the translation invariance of softmax for a quick implementation: for a tensor l and a fixed value a,
+ `softmax(l + a) = softmax(l)`. Based on
+ https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
+ TODO @thomasw21: this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
+ num_heads (`int`, *required*):
+ number of heads
+ Returns:
+ Tensor shaped (batch_size * num_heads, 1, max_seq_len)
+ """
+ batch_size, seq_length = attention_mask.shape
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+ base = torch.tensor(
+ 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
+ device=attention_mask.device,
+ dtype=torch.float32,
+ )
+ powers = torch.arange(
+ 1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32
+ )
+ slopes = torch.pow(base, powers)
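+ # Example: with num_heads = 4, closest_power_of_2 = 4 and base = 2**-2, so the
+ # per-head slopes come out as 2**-2, 2**-4, 2**-6, 2**-8.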
+
+ if closest_power_of_2 != num_heads:
+ extra_base = torch.tensor(
+ 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
+ device=attention_mask.device,
+ dtype=torch.float32,
+ )
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+ extra_powers = torch.arange(
+ 1,
+ 1 + 2 * num_remaining_heads,
+ 2,
+ device=attention_mask.device,
+ dtype=torch.int32,
+ )
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+
+ # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
+ # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
+ # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
+ # => the query_length dimension will then be broadcasted correctly
+ # This is more or less identical to T5's relative position bias:
+ # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
+ arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
+ alibi = slopes[..., None] * arange_tensor
+ return alibi
+
+
+# @torch.jit.script
+def dropout_add(
+ x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool
+) -> torch.Tensor:
+ """
+ Dropout add function
+
+ Args:
+ x (`torch.tensor`, *required*):
+ input tensor
+ residual (`torch.tensor`, *required*):
+ residual tensor
+ prob (`float`, *required*):
+ dropout probability
+ training (`bool`, *required*):
+ training mode
+ """
+ out = F.dropout(x, p=prob, training=training)
+ out = residual + out
+ return out
+
+
+# @torch.jit.script  # disabled: scripting this performs poorly for unknown reasons.
+def _split_heads(
+ fused_qkv: torch.Tensor, num_heads: int, head_dim: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
+ storage as `fused_qkv`
+
+ Args:
+ fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
+
+ Returns:
+ query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
+ value: [batch_size, seq_length, num_heads, head_dim]
+ """
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
+ fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)
+ query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1)
+
+ query_layer = query_layer.transpose(1, 2).reshape(
+ batch_size * num_heads, seq_length, head_dim
+ )
+ key_layer = key_layer.permute(0, 2, 3, 1).reshape(
+ batch_size * num_heads, head_dim, seq_length
+ )
+ value_layer = value_layer.transpose(1, 2).reshape(
+ batch_size * num_heads, seq_length, head_dim
+ )
+
+ return query_layer, key_layer, value_layer
+
+
+# @torch.jit.script
+def _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
+ """
+ Merge heads together over the last dimension
+
+ Args:
+ x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
+
+ Returns:
+ torch.tensor: [batch_size, seq_length, num_heads * head_dim]
+ """
+ # What we want to achieve is:
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
+ batch_size_and_num_heads, seq_length, _ = x.shape
+ batch_size = batch_size_and_num_heads // num_heads
+
+ # First view to decompose the batch size
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
+ x = x.view(batch_size, num_heads, seq_length, head_dim)
+
+ # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
+ x = x.permute(0, 2, 1, 3)
+
+ # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
+ return x.reshape(batch_size, seq_length, num_heads * head_dim)
+
+
+class BloomAttention(nn.Module):
+ def __init__(self, prefix, config: BloomConfig, weights):
+ super().__init__()
+
+ self.pretraining_tp = config.pretraining_tp
+ self.slow_but_exact = config.slow_but_exact
+
+ self.process_group = weights.process_group
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.n_head
+ self.head_dim = self.hidden_size // self.num_heads
+ self.split_size = self.hidden_size
+ self.hidden_dropout = config.hidden_dropout
+
+ if self.head_dim * self.num_heads != self.hidden_size:
+ raise ValueError(
+ f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+
+ # Layer-wise attention scaling
+ self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+ self.beta = 1.0
+
+ process_group = weights.process_group
+ if self.num_heads % process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {process_group.size()}"
+ )
+ self.num_heads = self.num_heads // process_group.size()
+ self.query_key_value = TensorParallelColumnLinear.load(
+ config=config,
+ prefix=f"{prefix}.query_key_value",
+ weights=weights,
+ bias=True,
+ )
+ self.dense = TensorParallelRowLinear.load(
+ config=config, prefix=f"{prefix}.dense", weights=weights, bias=True
+ )
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
+
+ @staticmethod
+ def compute_attention(
+ fused_qkv: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ head_mask: Optional[torch.Tensor],
+ beta: float,
+ inv_norm_factor: float,
+ num_heads: int,
+ use_cache: bool,
+ ):
+ batch_size, q_length, three_times_hidden_size = fused_qkv.shape
+ head_dim = three_times_hidden_size // (3 * num_heads)
+ batch_size * num_heads
+
+ ### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that?
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
+ (query_layer, key_layer, value_layer) = _split_heads(
+ fused_qkv, num_heads=num_heads, head_dim=head_dim
+ )
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ # concatenate along seq_length dimension:
+ # - key: [batch_size * self.num_heads, head_dim, kv_length]
+ # - value: [batch_size * self.num_heads, kv_length, head_dim]
+ past_key = past_key.view(-1, *past_key.shape[-2:])
+ key_layer = torch.cat((past_key, key_layer), dim=2)
+ past_value = past_value.view(-1, *past_value.shape[-2:])
+ value_layer = torch.cat((past_value, value_layer), dim=1)
+
+ _, _, kv_length = key_layer.shape
+
+ if use_cache is True:
+ present = (key_layer, value_layer)
+ else:
+ present = None
+ ###
+
+ # [batch_size * num_heads, q_length, kv_length]
+ # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
+ attention_scores = alibi.baddbmm(
+ batch1=query_layer,
+ batch2=key_layer,
+ beta=beta,
+ alpha=inv_norm_factor,
+ )
+
+ # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
+ input_dtype = attention_scores.dtype
+ # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
+ if input_dtype == torch.float16:
+ attention_scores = attention_scores.to(torch.float)
+ # mask out the disallowed positions with the minimum value of the current dtype
+ attn_weights = attention_scores.masked_fill_(
+ attention_mask, torch.finfo(attention_scores.dtype).min
+ )
+ attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+ input_dtype
+ )
+
+ # # [batch_size, num_heads, q_length, kv_length]
+ # attention_probs = self.attention_dropout(attention_probs)
+
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ # matmul: [batch_size * num_heads, q_length, head_dim]
+ context_layer = torch.bmm(attention_probs, value_layer, out=query_layer)
+
+ # change view [batch_size, num_heads, q_length, head_dim]
+ context_layer = _merge_heads(
+ context_layer, num_heads=num_heads, head_dim=head_dim
+ )
+
+ return context_layer, present, attention_probs
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ fused_qkv = self.query_key_value(
+ hidden_states
+ ) # [batch_size, seq_length, 3 x hidden_size]
+ batch_size, q_length, _ = fused_qkv.shape
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ layer_past = (
+ past_key.view(-1, *past_key.shape[-2:]),
+ past_value.view(-1, *past_value.shape[-2:]),
+ )
+
+ if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096:
+ assert self.training is False, "Only the forward pass was implemented"
+ assert (
+ attention_mask.shape[-1] < 4096
+ ), "Custom kernel support only up to 4096 tokens"
+ (
+ context_layer,
+ present,
+ attention_probs,
+ ) = fused_bloom_attention_cuda.forward(
+ fused_qkv,
+ layer_past,
+ alibi,
+ attention_mask,
+ head_mask,
+ self.beta,
+ self.inv_norm_factor,
+ self.num_heads,
+ use_cache,
+ )
+ else:
+ context_layer, present, attention_probs = self.compute_attention(
+ fused_qkv=fused_qkv,
+ layer_past=layer_past,
+ alibi=alibi,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ beta=self.beta,
+ inv_norm_factor=self.inv_norm_factor,
+ num_heads=self.num_heads,
+ use_cache=use_cache,
+ )
+
+ # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ slices = self.hidden_size / self.pretraining_tp
+ output_tensor = torch.zeros_like(context_layer)
+ for i in range(self.pretraining_tp):
+ output_tensor = output_tensor + F.linear(
+ context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
+ self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
+ )
+ else:
+ output_tensor = self.dense(context_layer)
+
+ # output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
+ output_tensor += residual
+
+ outputs = (output_tensor, present)
+ if output_attentions:
+ outputs += (attention_probs,)
+
+ return outputs
+
+
+class BloomMLP(nn.Module):
+ def __init__(self, prefix, config: BloomConfig, weights):
+ super().__init__()
+
+ self.pretraining_tp = config.pretraining_tp
+ self.slow_but_exact = config.slow_but_exact
+ self.dense_h_to_4h = TensorParallelColumnLinear.load(
+ config=config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
+ )
+ self.dense_4h_to_h = TensorParallelRowLinear.load(
+ config=config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
+ )
+ self.gelu_impl = torch.nn.GELU(approximate="tanh")
+ self.hidden_dropout = config.hidden_dropout
+
+ def forward(
+ self, hidden_states: torch.Tensor, residual: torch.Tensor
+ ) -> torch.Tensor:
+ hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
+
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ intermediate_output = torch.zeros_like(residual)
+ slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
+ for i in range(self.pretraining_tp):
+ intermediate_output = intermediate_output + F.linear(
+ hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
+ self.dense_4h_to_h.weight[
+ :, int(i * slices) : int((i + 1) * slices)
+ ],
+ )
+ else:
+ intermediate_output = self.dense_4h_to_h(hidden_states)
+
+ # output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
+ intermediate_output += residual
+
+ return intermediate_output
+
+
+class BloomBlock(nn.Module):
+ def __init__(self, layer_id: int, config: BloomConfig, weights):
+ super().__init__()
+
+ prefix = f"h.{layer_id}"
+ self.input_layernorm = LayerNorm.load(
+ prefix=f"{prefix}.input_layernorm",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+ self.num_heads = config.n_head
+ self.self_attention = BloomAttention(
+ prefix=f"{prefix}.self_attention", config=config, weights=weights
+ )
+ self.post_attention_layernorm = LayerNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+
+ self.mlp = BloomMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+ self.apply_residual_connection_post_layernorm = (
+ config.apply_residual_connection_post_layernorm
+ )
+ self.hidden_dropout = config.hidden_dropout
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ # hidden_states: [batch_size, seq_length, hidden_size]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+
+ # Layer norm post the self attention.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ # Self attention.
+ attn_outputs = self.self_attention(
+ layernorm_output,
+ residual,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ alibi=alibi,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ attention_output = attn_outputs[0]
+
+ outputs = attn_outputs[1:]
+
+ layernorm_output = self.post_attention_layernorm(attention_output)
+
+ # Get residual
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = attention_output
+
+ # MLP.
+ output = self.mlp(layernorm_output, residual)
+
+ if use_cache:
+ outputs = (output,) + outputs
+ else:
+ outputs = (output,) + outputs[1:]
+
+ return outputs # hidden_states, present, attentions
+
+
+class BloomPreTrainedModel(PreTrainedModel):
+ config_class = BloomConfig
+ base_model_prefix = "transformer"
+ _no_split_modules = ["BloomBlock"]
+
+ @staticmethod
+ def _convert_to_standard_cache(
+ past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
+ num_heads, ...]))
+ """
+ batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
+ num_heads = batch_size_times_num_heads // batch_size
+ # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+ # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+ return tuple(
+ (
+ layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+ layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
+ )
+ for layer_past in past_key_value
+ )
+
+ @staticmethod
+ def _convert_to_bloom_cache(
+ past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
+ """
+ batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
+ batch_size_times_num_heads = batch_size * num_heads
+ # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+ # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
+ return tuple(
+ (
+ layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+ layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
+ )
+ for layer_past in past_key_value
+ )
+
+
+class BloomModel(BloomPreTrainedModel):
+ def __init__(self, config: BloomConfig, weights):
+ super().__init__(config)
+
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.n_head
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+
+ self.word_embeddings = TensorParallelEmbedding(
+ prefix="word_embeddings", weights=weights
+ )
+
+ self.word_embeddings_layernorm = LayerNorm.load(
+ prefix="word_embeddings_layernorm",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+
+ # Transformer blocks
+ self.h = nn.ModuleList(
+ [
+ BloomBlock(layer_id=layer_id, config=config, weights=weights)
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+
+ # Final Layer Norm
+ self.ln_f = LayerNorm.load(
+ prefix="ln_f", weights=weights, eps=config.layer_norm_epsilon
+ )
+
+ def _prepare_attn_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_shape: Tuple[int, int],
+ past_key_values_length: int,
+ ) -> torch.BoolTensor:
+ # create causal mask
+ # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
+ combined_attention_mask = None
+ device = attention_mask.device
+ _, src_length = input_shape
+
+ if src_length > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ device=device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
+ combined_attention_mask = (
+ expanded_attn_mask
+ if combined_attention_mask is None
+ else expanded_attn_mask | combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
+ self.word_embeddings = new_embeddings
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **deprecated_arguments,
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+ if deprecated_arguments.pop("position_ids", False) is not False:
+            # `position_ids` could have been `torch.Tensor` or `None`, so defaulting `pop` to `False` lets us detect whether users explicitly passed `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if past_key_values is None:
+ past_key_values = tuple([None] * len(self.h))
+
+ # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape batch_size x num_heads x N x N
+ # head_mask has shape n_layer x batch x num_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ # Compute alibi tensor: check build_alibi_tensor documentation
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+ if past_key_values[0] is not None:
+ past_key_values_length = past_key_values[0][0].shape[-1]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), device=hidden_states.device
+ )
+ else:
+ attention_mask = attention_mask.to(hidden_states.device)
+
+ alibi = build_alibi_tensor(attention_mask, self.num_heads)
+
+ causal_mask = self._prepare_attn_mask(
+ attention_mask,
+ input_shape=(batch_size, seq_length),
+ past_key_values_length=past_key_values_length,
+ )
+
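+        # Under tensor parallelism each rank holds only a block of the attention
+        # heads, so keep just this rank's slice of the alibi tensor and repeat the
+        # causal mask to match the flattened (batch * heads) layout used below.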
+ if hasattr(self, "tp_rank"):
+ assert self.num_heads % self.tp_world_size == 0
+ block_size = self.num_heads // self.tp_world_size
+ alibi = alibi[
+ :, self.tp_rank * block_size : (self.tp_rank + 1) * block_size
+ ]
+ alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past)
+ causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
+ else:
+ alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past)
+ causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0)
+
+ alibi = alibi.to(hidden_states.dtype)
+
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=causal_mask,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ alibi=alibi,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (
+ outputs[2 if use_cache else 1],
+ )
+
+ # Add last hidden state
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ presents,
+ all_hidden_states,
+ all_self_attentions,
+ ]
+ if v is not None
+ )
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class BloomForCausalLM(BloomPreTrainedModel):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__(config)
+ self.transformer = BloomModel(config, weights)
+
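+        # BLOOM ties the LM head to the word embeddings, so the (speculative) head
+        # is loaded from the embedding weights rather than a separate tensor.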
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="word_embeddings",
+ weights=weights,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> dict:
+ # only last token for input_ids if past is not None
+ if past_key_values:
+ input_ids = input_ids[:, -1].unsqueeze(-1)
+
+            # the cache may be in the standard format (e.g. in contrastive search); convert to Bloom's cache format if needed
+ if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+ past_key_values = self._convert_to_bloom_cache(past_key_values)
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **deprecated_arguments,
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to
+            `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+ """
+ if deprecated_arguments.pop("position_ids", False) is not False:
+            # `position_ids` could have been `torch.Tensor` or `None`, so defaulting `pop` to `False` lets us detect whether users explicitly passed `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+
+ logits, speculative_logits = self.lm_head(hidden_states)
+ loss = None
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return (
+ CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ ),
+ speculative_logits,
+ )
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/clip.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/clip.py
new file mode 100644
index 00000000000..ab824da5b26
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/clip.py
@@ -0,0 +1,817 @@
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import (
+ _create_4d_causal_attention_mask,
+ _prepare_4d_attention_mask,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPooling,
+)
+from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
+
+from text_generation_server.layers import (
+ TensorParallelEmbedding,
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+)
+
+
+class CLIPVisionEmbeddings(nn.Module):
+ def __init__(self, prefix, config: CLIPVisionConfig, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ # TODO Should we TP this ?
+ self.class_embedding = weights.get_tensor(f"{prefix}.class_embedding")
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ bias=False,
+ )
+ self.patch_embedding.weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches + 1
+ self.position_embedding = TensorParallelEmbedding(
+ prefix=f"{prefix}.position_embedding", weights=weights
+ )
+ self.register_buffer(
+ "position_ids",
+ torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
+ persistent=False,
+ )
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ batch_size = pixel_values.shape[0]
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(
+ pixel_values.to(dtype=target_dtype)
+ ) # shape = [*, width, grid, grid]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
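+        # Prepend the learned class embedding and add position embeddings
+        # -> [batch_size, num_patches + 1, embed_dim]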
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ return embeddings
+
+
+class CLIPTextEmbeddings(nn.Module):
+ def __init__(self, config: CLIPTextConfig):
+ super().__init__()
+ embed_dim = config.hidden_size
+
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+ self.position_embedding = nn.Embedding(
+ config.max_position_embeddings, embed_dim
+ )
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids",
+ torch.arange(config.max_position_embeddings).expand((1, -1)),
+ persistent=False,
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ seq_length = (
+ input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+ )
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.token_embedding(input_ids)
+
+ position_embeddings = self.position_embedding(position_ids)
+ embeddings = inputs_embeds + position_embeddings
+
+ return embeddings
+
+
+class CLIPAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_size = self.embed_dim // self.num_heads
+ if self.head_size * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.embed_dim = self.embed_dim // weights.process_group.size()
+ self.scale = self.head_size**-0.5
+ self.dropout = config.attention_dropout
+
+ self.qkv = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=True,
+ )
+ self.out_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.out_proj",
+ weights=weights,
+ bias=True,
+ )
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return (
+ tensor.view(bsz, seq_len, self.num_heads, self.head_size)
+ .transpose(1, 2)
+ .contiguous()
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+        # fused query/key/value projection; split back into q, k, v below
+
+ qkv = self.qkv(hidden_states)
+ query_states, key_states, value_states = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ ]
+ * 3,
+ dim=2,
+ )
+ query_states = query_states * self.scale
+ key_states = self._shape(key_states, -1, bsz)
+ value_states = self._shape(value_states, -1, bsz)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_size)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # apply the causal_attention_mask first
+ if causal_attention_mask is not None:
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {causal_attention_mask.size()}"
+ )
+ attn_weights = (
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ + causal_attention_mask
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = (
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ + attention_mask
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ attn_probs = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ )
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, None
+
+
+class CLIPMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class CLIPEncoderLayer(nn.Module):
+ def __init__(self, prefix, config: CLIPConfig, weights):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = CLIPAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.layer_norm1 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
+ )
+ self.mlp = CLIPMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+ self.layer_norm2 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ causal_attention_mask: torch.Tensor,
+ ):
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class CLIPPreTrainedModel(nn.Module):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = CLIPConfig
+ base_model_prefix = "clip"
+ supports_gradient_checkpointing = True
+
+
+CLIP_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+    heads, etc.).
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+CLIP_TEXT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+"""
+
+CLIP_VISION_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
+"""
+
+CLIP_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+"""
+
+
+class CLIPEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`CLIPEncoderLayer`].
+
+ Args:
+ config: CLIPConfig
+ """
+
+ def __init__(self, prefix, config: CLIPConfig, weights):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList(
+ [
+ CLIPEncoderLayer(
+ prefix=f"{prefix}.layers.{i}", config=config, weights=weights
+ )
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ """
+
+ hidden_states = inputs_embeds
+ for idx, encoder_layer in enumerate(self.layers):
+ hidden_states = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ )
+
+ return hidden_states
+
+
+class CLIPTextTransformer(nn.Module):
+ def __init__(self, prefix: str, config: CLIPTextConfig, weights=None):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+ self.embeddings = CLIPTextEmbeddings(config)
+ # Initialize weights and apply final processing with `self.post_init()`
+ self.encoder = CLIPEncoder(
+ prefix=f"{prefix}.encoder", config=config, weights=weights
+ )
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ # For `pooled_output` computation
+ self.eos_token_id = config.eos_token_id
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ ):
+ r"""
+ Returns:
+
+ """
+ if input_ids is None:
+ raise ValueError("You have to specify input_ids")
+
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+
+ # CLIP's text model uses causal mask, prepare it here.
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+ causal_attention_mask = _create_4d_causal_attention_mask(
+ input_shape, hidden_states.dtype, device=hidden_states.device
+ )
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(
+ attention_mask, hidden_states.dtype
+ )
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+ if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+ # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+ # ------------------------------------------------------------
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+ last_hidden_state[
+ torch.arange(
+ last_hidden_state.shape[0], device=last_hidden_state.device
+ ),
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(
+ dim=-1
+ ),
+ ]
+ else:
+            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
+ last_hidden_state[
+ torch.arange(
+ last_hidden_state.shape[0], device=last_hidden_state.device
+ ),
+                # We need the first position of the `eos_token_id` value (`pad_token_ids` might be equal to `eos_token_id`)
+ (
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device)
+ == self.eos_token_id
+ )
+ .int()
+ .argmax(dim=-1),
+ ]
+
+ return last_hidden_state
+
+
+class CLIPTextModel(CLIPPreTrainedModel):
+ config_class = CLIPTextConfig
+
+ _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
+
+ def __init__(self, prefix, config: CLIPTextConfig):
+ super().__init__(config)
+ self.text_model = CLIPTextTransformer(prefix, config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, CLIPTextModel
+
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ )
+
+
+class CLIPVisionTransformer(nn.Module):
+ def __init__(self, prefix, config: CLIPVisionConfig, weights):
+ super().__init__()
+ self.config = config
+
+ self.embeddings = CLIPVisionEmbeddings(
+ prefix=f"{prefix}.embeddings", config=config, weights=weights
+ )
+ self.pre_layrnorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
+ )
+ self.encoder = CLIPEncoder(
+ prefix=f"{prefix}.encoder", config=config, weights=weights
+ )
+ # self.post_layernorm = nn.LayerNorm.load(prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Returns:
+
+ """
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.embeddings(pixel_values)
+ hidden_states = self.pre_layrnorm(hidden_states)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ )
+ last_hidden_state = encoder_outputs
+ # pooled_output = last_hidden_state[:, 0, :]
+ # pooled_output = self.post_layernorm(pooled_output)
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ # pooler_output=pooled_output,
+ # hidden_states=encoder_outputs,
+ )
+
+
+class CLIPVisionModel(CLIPPreTrainedModel):
+ config_class = CLIPVisionConfig
+ main_input_name = "pixel_values"
+ _no_split_modules = ["CLIPEncoderLayer"]
+
+ def __init__(self, config: CLIPVisionConfig):
+ super().__init__(config)
+ self.vision_model = CLIPVisionTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.vision_model.embeddings.patch_embedding
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, CLIPVisionModel
+
+ >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
+ ```"""
+
+ return self.vision_model(
+ pixel_values=pixel_values,
+ )
+
+
+class CLIPModel(nn.Module):
+ def __init__(self, prefix, config: CLIPConfig, weights):
+ super().__init__()
+ text_config = config.text_config
+ vision_config = config.vision_config
+
+ self.projection_dim = config.projection_dim
+ self.text_embed_dim = text_config.hidden_size
+ self.vision_embed_dim = vision_config.hidden_size
+
+ self.text_model = CLIPTextTransformer(text_config)
+ self.vision_model = CLIPVisionTransformer(vision_config)
+
+ self.visual_projection = nn.Linear(
+ self.vision_embed_dim, self.projection_dim, bias=False
+ )
+ self.text_projection = nn.Linear(
+ self.text_embed_dim, self.projection_dim, bias=False
+ )
+ self.logit_scale = nn.Parameter(
+ torch.tensor(self.config.logit_scale_init_value)
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_text_features(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+ applying the projection layer to the pooled output of [`CLIPTextModel`].
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, CLIPModel
+
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+ >>> text_features = model.get_text_features(**inputs)
+ ```"""
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ )
+
+ pooled_output = text_outputs[1]
+ text_features = self.text_projection(pooled_output)
+
+ return text_features
+
+ def get_image_features(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+ applying the projection layer to the pooled output of [`CLIPVisionModel`].
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, CLIPModel
+
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> image_features = model.get_image_features(**inputs)
+ ```"""
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ )
+
+ pooled_output = vision_outputs[1] # pooled_output
+ image_features = self.visual_projection(pooled_output)
+
+ return image_features
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, CLIPModel
+
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
+ ... )
+
+ >>> outputs = model(**inputs)
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
+ ```"""
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ )
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ )
+
+ image_embeds = vision_outputs[1]
+ image_embeds = self.visual_projection(image_embeds)
+
+ text_embeds = text_outputs[1]
+ text_embeds = self.text_projection(text_embeds)
+
+ # normalized features
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
+ logits_per_image = logits_per_text.t()
+
+ return logits_per_image, logits_per_text
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
new file mode 100644
index 00000000000..367c26c99d5
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -0,0 +1,510 @@
+# coding=utf-8
+# Copyright 2024 Cohere team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from typing import Optional, List, Tuple
+
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+)
+from text_generation_server.layers.layernorm import (
+ FastLayerNorm,
+)
+from text_generation_server.layers.rotary import (
+ PositionRotaryEmbedding,
+)
+from text_generation_server.utils.weights import UnquantizedWeight
+from habana_frameworks.torch.hpex.kernels import (
+ RotaryPosEmbeddingMode,
+ apply_rotary_pos_emb,
+)
+
+import habana_frameworks.torch as htorch
+
+
+class CohereRotary(PositionRotaryEmbedding):
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ ):
+        # Such control flow may add some overhead.
+ num_tokens = query.shape[0]
+ head_size = query.shape[-1]
+ rope_mode = RotaryPosEmbeddingMode.PAIRWISE
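+        # The HPU pairwise RoPE kernel expects cos/sin expanded to the full rotary
+        # dimension, so duplicate each entry before rotating the rotary slice of
+        # query and key; the remaining features pass through unchanged.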
+ sin = torch.repeat_interleave(sin, 2, dim=-1)
+ cos = torch.repeat_interleave(cos, 2, dim=-1)
+ rotary_dim = cos.shape[-1]
+ query_shape = query.shape
+ query = query.view(num_tokens, -1, head_size)
+ query_rot = query[..., :rotary_dim]
+ query_pass = query[..., rotary_dim:]
+ query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+ query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
+
+ key_shape = key.shape
+ key = key.view(num_tokens, -1, head_size)
+ key_rot = key[..., :rotary_dim]
+ key_pass = key[..., rotary_dim:]
+ key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+ key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
+
+
+class CohereLayerNorm(nn.Module):
+ def __init__(self, prefix, weights, eps):
+ super().__init__()
+ weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+ self.weight = nn.Parameter(weight)
+ # Fake weights
+ self.ones = weight.new_ones(weight.shape[1])
+ self.eps = eps
+
+ def forward(self, hidden_states):
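+        # Bias-free LayerNorm: normalize in float32 over the last dimension of the
+        # (-1, weight.shape[0], weight.shape[1]) view, then cast back to the input dtype.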
+ hidden_states = hidden_states.reshape(
+ -1, self.weight.shape[0], self.weight.shape[1]
+ )
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ mean = hidden_states.mean(-1, keepdim=True)
+ hidden_states_minus_mean = hidden_states - mean
+ variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
+ hidden_states = self.weight.to(torch.float32) * hidden_states
+ hidden_states = hidden_states.view(-1, self.weight.shape[1])
+ return hidden_states.to(input_dtype)
+
+
+def load_attention(config, prefix, weights):
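+    # With grouped-query attention (fewer KV heads than query heads) the fused QKV
+    # weight is assembled manually in `_load_gqa`; otherwise the three projections
+    # load directly as a single column-parallel linear.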
+ if config.num_attention_heads != config.num_key_value_heads:
+ return _load_gqa(config, prefix, weights)
+ else:
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=config.attention_bias,
+ )
+
+
+def _load_gqa(config, prefix: str, weights):
+ assert config.hidden_size % config.num_attention_heads == 0
+ assert config.num_attention_heads % weights.process_group.size() == 0
+
+ weight = weights.get_multi_weights_col(
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ )
+
+ if isinstance(weight, UnquantizedWeight):
+ weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
+
+ head_size = config.hidden_size // config.num_attention_heads
+ num_heads = config.num_attention_heads // weights.process_group.size()
+ num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+ assert list(weight.weight.shape) == [
+ (num_heads + 2 * num_key_value_heads) * head_size,
+ config.hidden_size,
+ ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+
+ if config.attention_bias:
+ w = [
+ weights.get_sharded(f"{p}.bias", dim=0)
+ for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
+ ]
+ bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
+ else:
+ bias = None
+
+ return TensorParallelColumnLinear(get_linear(weight, bias=bias))
+
+
+class FlashCohereAttention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ rotary_emb,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_heads
+
+ self.rotary_emb = rotary_emb
+
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights)
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.use_qk_norm = config.use_qk_norm
+ if self.use_qk_norm:
+ self.q_norm = CohereLayerNorm(
+ prefix=f"{prefix}.q_norm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+ self.k_norm = CohereLayerNorm(
+ prefix=f"{prefix}.k_norm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+ else:
+ self.q_norm = None
+ self.k_norm = None
+
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=config.attention_bias,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states)
+ query, key, value = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ self.head_size * self.num_key_value_heads,
+ self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+
+ if self.use_qk_norm:
+ query = query.reshape(-1, self.head_size)
+ key = key.reshape(-1, self.head_size)
+ query = self.q_norm(query.contiguous())
+ key = self.k_norm(key.contiguous())
+
+ query = query.view(-1, self.num_heads, self.head_size)
+ key = key.view(-1, self.num_key_value_heads, self.head_size)
+ value = value.view(-1, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, key, cos, sin)
+
+ kv_cache.store(
+ key=key,
+ value=value,
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=key,
+ value=value,
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(
+ attn_output.view(-1, self.num_heads * self.head_size), reduce=False
+ )
+
+
+class CohereMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ act = config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+ # Fuse gate and up proj
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ def forward(self, hidden_states):
+ gate_up_states = self.gate_up_proj(hidden_states)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=False
+ )
+
+
+class FlashCohereLayer(nn.Module):
+ def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
+ super().__init__()
+ prefix = f"{prefix}.layers.{layer_id}"
+ self.self_attn = FlashCohereAttention(
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+ self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+
+ self.input_layernorm = FastLayerNorm.load_no_bias(
+ prefix=f"{prefix}.input_layernorm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+ self.process_group = weights.process_group
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ mlp_output = self.mlp(normed_hidden_states)
+ output = attn_output + mlp_output
+
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(output, group=self.process_group)
+
+ return output, res
+
+
+class FlashCohereModel(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens", weights=weights
+ )
+ rotary_emb = CohereRotary.static(
+ config=config,
+ dim=config.hidden_size // config.num_attention_heads,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+ self.layers = nn.ModuleList(
+ [
+ FlashCohereLayer(
+ prefix,
+ layer_id,
+ config,
+ weights,
+ rotary_emb,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = FastLayerNorm.load_no_bias(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: torch.Tensor,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, input_ids.shape[0]
+ )
+ hidden_states = self.embed_tokens(input_ids)
+
+        # Get rotary cos and sin once for this forward pass
+        # to avoid indexing into them in every layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
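+        # In HPU lazy mode, mark_step() delimits the accumulated graph so each
+        # decoder layer executes as its own graph segment.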
+ if lazy_mode:
+ htorch.core.mark_step()
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashCohereForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ if not prefix:
+ prefix = "model"
+ else:
+ prefix = f"{prefix}.model"
+
+ self.model = FlashCohereModel(prefix, config, weights)
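+        # Some checkpoints have no dedicated `lm_head` tensor (tied embeddings);
+        # in that case fall back to loading the head from the token embeddings.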
+ try:
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="lm_head",
+ weights=weights,
+ )
+ except RuntimeError:
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix=f"{prefix}.embed_tokens",
+ weights=weights,
+ )
+ self.logit_scale = config.logit_scale
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
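+        # Cohere models scale the output logits by `config.logit_scale`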
+ logits *= self.logit_scale
+ if speculative_logits is not None:
+ speculative_logits *= self.logit_scale
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
new file mode 100644
index 00000000000..c097f71ee52
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -0,0 +1,765 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from typing import Optional, List, Tuple, Any
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+
+
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ FastLinear,
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+)
+from text_generation_server.layers.rotary import (
+ PositionRotaryEmbedding,
+)
+from text_generation_server.layers.layernorm import (
+ FastLayerNorm,
+)
+from vllm_hpu_extension.ops import DynamicFusedMOE
+import habana_frameworks.torch as htorch
+
+
+class DbrxAttentionConfig(PretrainedConfig):
+ def __init__(
+ self,
+ attn_pdrop: float = 0,
+ clip_qkv: Optional[float] = None,
+ kv_n_heads: int = 1,
+ rope_theta: float = 10000.0,
+ **kwargs: Any,
+ ):
+ super().__init__(**kwargs)
+ self.attn_pdrop = attn_pdrop
+ self.clip_qkv = clip_qkv
+ self.kv_n_heads = kv_n_heads
+ self.rope_theta = rope_theta
+
+ for k in ["model_type"]:
+ if k in kwargs:
+ kwargs.pop(k)
+ if len(kwargs) != 0:
+ raise ValueError(f"Found unknown {kwargs=}")
+
+
+class DbrxFFNConfig(PretrainedConfig):
+ def __init__(
+ self,
+ ffn_act_fn: Optional[dict] = None,
+ ffn_hidden_size: int = 3584,
+ moe_num_experts: int = 4,
+ moe_top_k: int = 1,
+ moe_jitter_eps: Optional[float] = None,
+ moe_loss_weight: float = 0.01,
+ moe_normalize_expert_weights: Optional[float] = 1,
+ uniform_expert_assignment: bool = False,
+ **kwargs: Any,
+ ):
+ super().__init__()
+ if ffn_act_fn is None:
+ ffn_act_fn = {"name": "silu"}
+ self.ffn_act_fn = ffn_act_fn
+ self.ffn_hidden_size = ffn_hidden_size
+ self.moe_num_experts = moe_num_experts
+ self.moe_top_k = moe_top_k
+ self.moe_jitter_eps = moe_jitter_eps
+ self.moe_loss_weight = moe_loss_weight
+ self.moe_normalize_expert_weights = moe_normalize_expert_weights
+ self.uniform_expert_assignment = uniform_expert_assignment
+
+ if uniform_expert_assignment:
+ raise ValueError("`uniform_expert_assignment = True` is not supported")
+
+ for k in ["model_type"]:
+ if k in kwargs:
+ kwargs.pop(k)
+ if len(kwargs) != 0:
+ raise ValueError(f"Found unknown {kwargs=}")
+
+
+class DbrxConfig(PretrainedConfig):
+ attribute_map = {
+ "hidden_size": "d_model",
+ "num_attention_heads": "n_heads",
+ "num_hidden_layers": "n_layers",
+ }
+
+ def __init__(
+ self,
+ d_model: int = 2048,
+ n_heads: int = 16,
+ n_layers: int = 24,
+ max_seq_len: int = 2048,
+ vocab_size: int = 32000,
+ resid_pdrop: float = 0.0,
+ emb_pdrop: float = 0.0,
+ attn_config: Optional[DbrxAttentionConfig] = None,
+ ffn_config: Optional[DbrxFFNConfig] = None,
+ use_cache: bool = True,
+ initializer_range: float = 0.02,
+ output_router_logits: bool = False,
+ router_aux_loss_coef: float = 0.05,
+ **kwargs: Any,
+ ):
+ if attn_config is None:
+ self.attn_config = DbrxAttentionConfig()
+ elif isinstance(attn_config, dict):
+ self.attn_config = DbrxAttentionConfig(**attn_config)
+ else:
+ self.attn_config = attn_config
+
+ if ffn_config is None:
+ self.ffn_config = DbrxFFNConfig()
+ elif isinstance(ffn_config, dict):
+ self.ffn_config = DbrxFFNConfig(**ffn_config)
+ else:
+ self.ffn_config = ffn_config
+
+ self.d_model = d_model
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.max_seq_len = max_seq_len
+ self.vocab_size = vocab_size
+ self.resid_pdrop = resid_pdrop
+ self.emb_pdrop = emb_pdrop
+ self.use_cache = use_cache
+ self.initializer_range = initializer_range
+ self.output_router_logits = output_router_logits
+ self.router_aux_loss_coef = router_aux_loss_coef
+
+ tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
+ if tie_word_embeddings:
+ raise ValueError("tie_word_embeddings is not supported for Dbrx models.")
+
+ super().__init__(
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ @property
+ def num_key_value_heads(self):
+        # We can't use the attribute map here, since the number of KV heads
+        # is not a top-level config attribute.
+ return self.attn_config.kv_n_heads
+
+
+def promote_scalar(x: torch.Tensor) -> torch.Tensor:
+ return x.view(1) if len(x.size()) == 0 else x
+
+
+def load_attention(config, prefix, weights):
+ return TensorParallelColumnLinear.load_qkv(
+ config,
+ prefix=f"{prefix}.Wqkv",
+ weights=weights,
+ bias=False,
+ num_heads=config.n_heads,
+ num_key_value_heads=config.attn_config.kv_n_heads,
+ )
+
+
+def _load_experts(config, prefix, weights):
+ world_size = weights.process_group.size()
+ rank = weights.process_group.rank()
+
+ assert (
+ config.ffn_config.ffn_hidden_size % world_size == 0
+ ), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards"
+
+ expert_size = config.ffn_config.ffn_hidden_size
+ block_size = expert_size // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+
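+    # Each rank materializes only its [start:stop) row slice of every expert,
+    # so the tensor assembled below holds this shard's portion of all experts
+    # stacked along dim 0.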
+ tensor = torch.empty(
+ (config.ffn_config.moe_num_experts * block_size, config.d_model),
+ dtype=weights.dtype,
+ device=weights.device,
+ )
+
+ slice_ = weights._get_slice(f"{prefix}")
+
+ for i in range(config.ffn_config.moe_num_experts):
+ offset = i * expert_size
+ expert_slice = slice_[start + offset : stop + offset]
+
+ tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
+ dtype=weights.dtype
+ ).to(device=weights.device)
+ return tensor
+
+
+def _load_experts_quantized(config, prefix, weights, cls):
+ world_size = weights.process_group.size()
+ rank = weights.process_group.rank()
+
+ assert (
+ config.ffn_config.ffn_hidden_size % world_size == 0
+ ), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards"
+
+ expert_size = config.ffn_config.ffn_hidden_size
+ block_size = expert_size // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+
+ slice_ = weights._get_slice(f"{prefix}")
+
+ experts = []
+ for i in range(config.ffn_config.moe_num_experts):
+ if config.quantize in ["gptq", "awq"]:
+ raise NotImplementedError(
+ "Dbrx does not support gptq/awq quantization yet."
+ )
+ else:
+ offset = i * expert_size
+ expert_slice = (
+ slice_[start + offset : stop + offset]
+ .to(dtype=weights.dtype)
+ .to(device=weights.device)
+ )
+
+ if cls == TensorParallelRowLinear:
+ expert_slice = expert_slice.t().contiguous()
+ linear = get_linear(expert_slice, None)
+ experts.append(cls(linear, weights.process_group))
+ else:
+ linear = get_linear(expert_slice, None)
+ experts.append(cls(linear))
+
+ return experts
+
+
+class DbrxAttention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ rotary_emb,
+ ):
+ super().__init__()
+ self.clip_qkv = config.attn_config.clip_qkv
+ self.num_heads = config.n_heads
+ self.hidden_size = config.d_model
+ self.head_size = self.hidden_size // self.num_heads
+
+ self.rotary_emb = rotary_emb
+
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.attn_config.kv_n_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights)
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.out_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states)
+ if self.clip_qkv is not None:
+ qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class DbrxNormAttentionNorm(nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ rotary_emb,
+ ):
+ super().__init__()
+ self.norm_1 = FastLayerNorm.load_no_bias(
+ prefix=f"{prefix}.norm_1", weights=weights, eps=1e-5
+ )
+ self.self_attn = DbrxAttention(
+ prefix=f"{prefix}.attn",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+ self.norm_2 = FastLayerNorm.load_no_bias(
+ prefix=f"{prefix}.norm_2",
+ weights=weights,
+ eps=1e-5,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, res = self.norm_1(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+        # fused post-attention layer norm (adds the residual and normalizes in one op)
+ normed_attn_res_output, attn_res = self.norm_2(attn_output, res)
+
+ return normed_attn_res_output, attn_res
+
+
+@torch.jit.script
+def select_experts(
+ gate_logits: torch.Tensor, top_k: int, moe_normalize_expert_weights: int
+):
+ # all_probs: (sequence_length, n_experts) and upcast for softmax
+ all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+ # weights, selected_experts: (sequence_length, top-k)
+ weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
+ if moe_normalize_expert_weights:
+ weights = weights / torch.norm(
+ weights, p=moe_normalize_expert_weights, dim=-1, keepdim=True
+ )
+ weights = weights.view(-1)
+ selected_experts = selected_experts.view(-1)
+
+ return selected_experts, weights
+
+
+@torch.jit.script
+def round_up(x: torch.Tensor, value: int):
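+    # Round each element of x up to the nearest multiple of `value`
+    # using an integer ceil-division.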
+ return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
+
+
+class BlockSparseMoE(nn.Module):
+ def __init__(self, prefix, config: DbrxConfig, weights):
+ super().__init__()
+ self.moe_normalize_expert_weights = (
+ config.ffn_config.moe_normalize_expert_weights
+ )
+ self.hidden_dim = config.d_model
+ self.ffn_dim = config.ffn_config.ffn_hidden_size // weights.process_group.size()
+ self.num_experts = config.ffn_config.moe_num_experts
+ self.top_k = config.ffn_config.moe_top_k
+
+ act = config.ffn_config.ffn_act_fn["name"]
+ if "gelu" in act:
+ self.act = lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ elif "silu" in act:
+ self.act = torch.nn.functional.silu
+ else:
+ self.act = ACT2FN[act]
+
+ # gating
+ self.gate = FastLinear.load(
+ config, f"{prefix}.router.layer", weights, bias=False
+ )
+
+ # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
+ w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights).view(
+ self.num_experts, self.ffn_dim, self.hidden_dim
+ )
+ v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights).view(
+ self.num_experts, self.ffn_dim, self.hidden_dim
+ )
+ self.wv1 = torch.cat([w1, v1], dim=1)
+ self.w2 = (
+ _load_experts(config, f"{prefix}.experts.mlp.w2", weights)
+ .view(self.num_experts, self.ffn_dim, self.hidden_dim)
+ .transpose(1, 2)
+ .contiguous()
+ )
+
+ self.process_group = weights.process_group
+
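+        # Register each expert's merged w1/v1 and w2 weights with the HPU
+        # fused-MoE op (DynamicFusedMOE), which performs the routed expert
+        # computation in forward().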
+ self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
+ for i in range(self.num_experts):
+ self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.wv1[i])
+ self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.w2[i])
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(x)
+ out = self.hpu_fused_moe(x, router_logits, self.top_k)
+
+ # Reduce sum
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(out, group=self.process_group)
+
+ return out.view(*x.shape)
+
+
+class DenseMoE(nn.Module):
+ def __init__(self, prefix, config: DbrxConfig, weights):
+ super().__init__()
+
+ self.moe_normalize_expert_weights = (
+ config.ffn_config.moe_normalize_expert_weights
+ )
+ self.hidden_dim = config.d_model
+ self.ffn_dim = config.ffn_config.ffn_hidden_size // weights.process_group.size()
+ self.num_experts = config.ffn_config.moe_num_experts
+ self.top_k = config.ffn_config.moe_top_k
+
+ act = config.ffn_config.ffn_act_fn["name"]
+ if "gelu" in act:
+ self.act = lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ elif "silu" in act:
+ self.act = torch.nn.functional.silu
+ else:
+ self.act = ACT2FN[act]
+
+ # gating
+ self.gate = FastLinear.load(
+ config, f"{prefix}.router.layer", weights, bias=False
+ )
+
+ self.w1 = _load_experts_quantized(
+ config,
+ prefix=f"{prefix}.experts.mlp.w1",
+ weights=weights,
+ cls=TensorParallelColumnLinear,
+ )
+ self.w2 = _load_experts_quantized(
+ config,
+ prefix=f"{prefix}.experts.mlp.w2",
+ weights=weights,
+ cls=TensorParallelRowLinear,
+ )
+ self.v1 = _load_experts_quantized(
+ config,
+ prefix=f"{prefix}.experts.mlp.v1",
+ weights=weights,
+ cls=TensorParallelColumnLinear,
+ )
+
+ self.process_group = weights.process_group
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ x: (sequence_length, model_dim)
+ gate_logits: (sequence_length, n_experts)
+ """
+ # optional reshape
+ input_shape = x.shape
+ x = x.view(-1, input_shape[-1])
+
+ # gate_logits: (sequence_length, n_experts)
+ gate_logits = self.gate(x)
+ # all_probs: (sequence_length, n_experts) and upcast for softmax
+ weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+
+ if self.top_k < self.num_experts:
+ _, not_selected_experts = torch.topk(
+ weights,
+ self.num_experts - self.top_k,
+ largest=False,
+ sorted=False,
+ dim=1,
+ )
+ # Mask not selected experts
+ weights.scatter_(1, not_selected_experts, 0)
+
+ # Re-normalize
+ if self.moe_normalize_expert_weights:
+ weights = weights / torch.norm(
+ weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
+ )
+ weights = weights.to(x.dtype)
+
+ # Final output tensor
+ out = x.new_zeros(x.shape[0], self.hidden_dim)
+ for i in range(self.num_experts):
+ h = self.act(self.w1[i](x)) * self.v1[i](x)
+ h = self.w2[i](h, reduce=False)
+ # Add expert output to out with masking
+ out += h * weights[:, i].view(-1, 1)
+
+ # Reduce sum
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(out, group=self.process_group)
+
+ return out
+
+
+class DbrxLayer(nn.Module):
+ def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
+ super().__init__()
+ prefix = f"{prefix}.blocks.{layer_id}"
+
+ self.attn = DbrxNormAttentionNorm(
+ prefix=f"{prefix}.norm_attn_norm",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+
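+        # Unquantized checkpoints use the fused BlockSparseMoE path; quantized
+        # ones fall back to the per-expert DenseMoE implementation.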
+ moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
+ self.moe = moe_cls(f"{prefix}.ffn", config, weights)
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ # Self Attention
+ attn_output, attn_res = self.attn(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ moe_output = self.moe(attn_output)
+
+ return moe_output, attn_res
+
+
+class DbrxModel(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.wte", weights=weights
+ )
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.d_model // config.n_heads,
+ base=config.attn_config.rope_theta,
+ device=weights.device,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ DbrxLayer(
+ prefix,
+ layer_id,
+ config,
+ weights,
+ rotary_emb,
+ )
+ for layer_id in range(config.n_layers)
+ ]
+ )
+ self.norm = FastLayerNorm.load_no_bias(
+ prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5
+ )
+
+ self.head_size = self.layers[0].attn.self_attn.head_size
+ self.num_heads = self.layers[0].attn.self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].attn.self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, input_ids.shape[0]
+ )
+ hidden_states = self.embed_tokens(input_ids)
+
+        # Get rotary cos and sin for this forward pass
+        # (computed once here to avoid indexing in every layer)
+ cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
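+        # In HPU lazy mode, mark_step() flushes the ops accumulated so far so
+        # that each decoder layer runs as its own graph segment.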
+ if lazy_mode:
+ htorch.core.mark_step()
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashDbrxForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ if not prefix:
+ prefix = "transformer"
+ else:
+ prefix = f"{prefix}.transformer"
+
+ self.model = DbrxModel(prefix, config, weights)
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="lm_head",
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
new file mode 100644
index 00000000000..2189fb627a8
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
@@ -0,0 +1,715 @@
+# coding=utf-8
+# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+import torch.distributed
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+
+from text_generation_server.layers import (
+ FastLinear,
+ SpeculativeHead,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ get_linear,
+ Fp8Linear,
+)
+from text_generation_server.layers.attention import (
+ Seqlen,
+ attention,
+ paged_attention_mla,
+ set_block_mapping,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
+from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
+from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
+from text_generation_server.utils.weights import Weights
+import habana_frameworks.torch as htorch
+
+
+def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
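+    # FP8 weights cannot be read back directly; multiplying an identity matrix
+    # through the layer recovers the dequantized weights in bf16.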
+ if isinstance(layer, Fp8Linear):
+ eye = torch.eye(
+ layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
+ )
+ dequant_weights = layer(eye)
+ del eye
+ # standardize to (output, input)
+ return dequant_weights.T
+ return layer.weight
+
+
+class DeepseekV2Config(PretrainedConfig):
+ def __init__(
+ self,
+ vocab_size=102400,
+ hidden_size=4096,
+ intermediate_size=11008,
+ moe_intermediate_size=1407,
+ num_hidden_layers=30,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ n_shared_experts=2,
+ n_routed_experts=160,
+ ep_size=1,
+ routed_scaling_factor=1.0,
+ kv_lora_rank=512,
+ q_lora_rank=1536,
+ qk_rope_head_dim=64,
+ v_head_dim=128,
+ qk_nope_head_dim=128,
+ topk_method="gready",
+ n_group=8,
+ topk_group=3,
+ num_experts_per_tok=6,
+ moe_layer_freq=1,
+ first_k_dense_replace=0,
+ norm_topk_prob=False,
+ scoring_func="softmax",
+ aux_loss_alpha=0.001,
+ seq_aux=True,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=100000,
+ eos_token_id=100001,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.moe_intermediate_size = moe_intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.n_shared_experts = n_shared_experts
+ self.n_routed_experts = n_routed_experts
+ self.ep_size = ep_size
+ self.routed_scaling_factor = routed_scaling_factor
+ self.kv_lora_rank = kv_lora_rank
+ self.q_lora_rank = q_lora_rank
+ self.qk_rope_head_dim = qk_rope_head_dim
+ self.v_head_dim = v_head_dim
+ self.qk_nope_head_dim = qk_nope_head_dim
+ self.topk_method = topk_method
+ self.n_group = n_group
+ self.topk_group = topk_group
+ self.num_experts_per_tok = num_experts_per_tok
+ self.moe_layer_freq = moe_layer_freq
+ self.first_k_dense_replace = first_k_dense_replace
+ self.norm_topk_prob = norm_topk_prob
+ self.scoring_func = scoring_func
+ self.aux_loss_alpha = aux_loss_alpha
+ self.seq_aux = seq_aux
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
+ if tie_word_embeddings:
+ raise ValueError(
+ "tie_word_embeddings is not supported for Deepseek V2 models."
+ )
+
+ if ep_size != 1:
+ raise ValueError(
+ f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}"
+ )
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class DeepseekV2Attention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights: Weights,
+ rotary_emb,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.kv_lora_rank = config.kv_lora_rank
+ self.q_lora_rank = config.q_lora_rank
+ self.qk_nope_head_dim = config.qk_nope_head_dim
+ self.qk_rope_head_dim = config.qk_rope_head_dim
+ self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
+ self.value_head_size = config.v_head_dim
+ self.head_pad_size = max(self.head_size, self.value_head_size)
+ self.rotary_emb = rotary_emb
+
+ mscale = get_mscale(
+ self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
+ )
+ self.softmax_scale = self.head_size**-0.5 * mscale * mscale
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ if self.q_lora_rank is None:
+ self.q_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.q_proj",
+ weights=weights,
+ bias=config.attention_bias,
+ )
+ else:
+ self.q_a_proj = get_linear(
+ weight=weights.get_weights(f"{prefix}.q_a_proj"),
+ bias=(
+ weights.get_tensor(f"{prefix}.q_a_proj.bias")
+ if config.attention_bias
+ else None
+ ),
+ )
+ self.q_a_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.q_a_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.q_b_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.q_b_proj",
+ weights=weights,
+ bias=config.attention_bias,
+ )
+
+ self.kv_a_proj_with_mqa = get_linear(
+ weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
+ bias=(
+ weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
+ if config.attention_bias
+ else None
+ ),
+ )
+
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.kv_a_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.kv_b_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.kv_b_proj",
+ weights=weights,
+ bias=config.attention_bias,
+ )
+
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
+ kv_b_proj_weight = kv_b_proj_weight.view(
+ self.kv_lora_rank,
+ self.num_heads,
+ self.qk_nope_head_dim + self.value_head_size,
+ )
+
+ W_UK, W_UV = kv_b_proj_weight.split(
+ [self.qk_nope_head_dim, self.value_head_size], dim=-1
+ )
+ # Convert from (L, N, V) to (N, L, V)
+ self.W_UV = W_UV.transpose(0, 1)
+ # Convert from (L, N, P) to (N, P, L)
+ self.W_UK_T = W_UK.permute(1, 2, 0)
+
+ def _q_proj_and_k_up_proj(self, x):
+ q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+ q_nope, q_pe = (
+ q_proj(x)
+ .view(-1, self.num_heads, self.head_size)
+ .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+ )
+
+ # Convert from (B, N, P) to (N, B, P)
+ q_nope = q_nope.transpose(0, 1)
+ # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+ ql_nope = torch.bmm(q_nope, self.W_UK_T)
+ # Convert from (N, B, L) to (B, N, L)
+ return ql_nope.transpose(0, 1), q_pe
+
+ def _v_up_proj_and_o_proj(self, x):
+ # Convert from (B, N, L) to (N, B, L)
+ x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+ # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+ x = torch.bmm(x, self.W_UV)
+ # Convert from (N, B, V) to (B, N * V)
+ x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
+ return self.o_proj(x)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ cu_seqlen_prefill: torch.Tensor,
+ kv_cache: KVCache,
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ):
+ if self.q_lora_rank is None:
+ hidden_states_or_q_c = hidden_states
+ else:
+ hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]
+
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+ compressed_kv, key_pe = torch.split(
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+ )
+
+ key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
+ kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+ query = q_proj(hidden_states_or_q_c)
+ query = query.view(-1, self.num_heads, self.head_size)
+ query_nope, query_pe = torch.split(
+ query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+ )
+ else:
+ query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
+
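+        # Reorder the rope dims from interleaved pairs (x0, y0, x1, y1, ...) to
+        # a half-split layout (x0, x1, ..., y0, y1, ...) before the rotary
+        # embedding is applied.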
+ batch_size, heads, head_dim = query_pe.shape
+ query_pe = (
+ query_pe.view(batch_size, heads, head_dim // 2, 2)
+ .transpose(2, 3)
+ .reshape(batch_size, heads, head_dim)
+ )
+ batch_size, heads, head_dim = key_pe.shape
+ key_pe = (
+ key_pe.view(batch_size, heads, head_dim // 2, 2)
+ .transpose(2, 3)
+ .reshape(batch_size, heads, head_dim)
+ )
+ self.rotary_emb(query_pe, key_pe, cos, sin)
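+        # MLA caches only the compressed latent (kv_c_normed) together with the
+        # rotary key part, instead of full per-head keys and values.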
+ latent_vec_k = torch.concat(
+ (kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
+ )
+ latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
+
+ latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))
+
+ kv_cache.store(
+ key=latent_vec_k,
+ value=None,
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ if cu_seqlen_prefill is not None:
+ kv = self.kv_b_proj(kv_c_normed).view(
+ -1,
+ self.num_key_value_heads,
+ self.qk_nope_head_dim + self.value_head_size,
+ )
+
+ key_nope, value = torch.split(
+ kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
+ )
+ query[..., self.qk_nope_head_dim :] = query_pe
+ key = torch.empty_like(query)
+ key[..., : self.qk_nope_head_dim] = key_nope
+ key[..., self.qk_nope_head_dim :] = key_pe
+
+ # We need to pad the heads because Flash Attention does not support
+ # qk and v with different head sizes.
+ query = torch.nn.functional.pad(
+ query, (0, self.head_pad_size - self.head_size), value=0
+ )
+ key = torch.nn.functional.pad(
+ key, (0, self.head_pad_size - self.head_size), value=0
+ )
+ value = torch.nn.functional.pad(
+ value, (0, self.head_pad_size - self.value_head_size), value=0
+ )
+
+ # flash attention
+ attn_output = attention(
+ query=query,
+ key=key,
+ value=value,
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ attn_output = attn_output[..., : self.value_head_size]
+
+ return self.o_proj(
+ attn_output.reshape(-1, self.num_heads * self.value_head_size)
+ )
+ else:
+ # Decode
+ query = torch.cat([query_nope, query_pe], dim=-1)
+ attn_output = paged_attention_mla(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ kv_lora_rank=self.kv_lora_rank,
+ )
+ attn_output = self._v_up_proj_and_o_proj(attn_output)
+ return attn_output
+
+
+class DeepseekV2MLP(nn.Module):
+ def __init__(self, prefix: str, config, weights, intermediate_size: int):
+ super().__init__()
+ self.hidden_act = config.hidden_act
+ if self.hidden_act != "silu":
+ # Bail out because MoE only supports silu.
+ raise NotImplementedError(
+ "Currently only `silu` is supported as an activation for Deepseek V2."
+ )
+ self.act = ACT2FN[self.hidden_act]
+
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ self.intermediate_size = intermediate_size // weights.process_group.size()
+
+ # TODO: This is a hotfix to be removed & properly refactored.
+ self.quantize = config.quantize
+
+ def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
+ gate_up_states = self.gate_up_proj(hidden_states)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
+ )
+
+
+class DeepseekV2MoE(nn.Module):
+ def __init__(
+ self,
+ prefix,
+ config: DeepseekV2Config,
+ moe_layer_cls: Type[MoELayer],
+ weights,
+ ):
+ super().__init__()
+
+ self.hidden_dim = config.hidden_size
+ self.moe_intermediate_size = (
+ config.moe_intermediate_size // weights.process_group.size()
+ )
+ self.routed_scaling_factor = config.routed_scaling_factor
+
+ # Gating
+ self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
+ self.moe_layer = moe_layer_cls(
+ prefix=f"{prefix}.experts",
+ n_experts=config.n_routed_experts,
+ n_expert_group=config.n_group,
+ renormalize=config.norm_topk_prob,
+ topk=config.num_experts_per_tok,
+ topk_group=config.topk_group,
+ weights=weights,
+ )
+ assert isinstance(self.moe_layer, MoELayer)
+
+ if config.n_shared_experts is not None:
+ self.shared_experts = DeepseekV2MLP(
+ prefix=f"{prefix}.shared_experts",
+ config=config,
+ weights=weights,
+ intermediate_size=config.moe_intermediate_size
+ * config.n_shared_experts,
+ )
+ else:
+ self.shared_experts = None
+
+ self.process_group = weights.process_group
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.shared_experts is not None:
+ shared_output = self.shared_experts(x, reduce=False)
+ else:
+ shared_output = None
+
+ router_logits = self.gate(x)
+
+ out = self.moe_layer(x, gating_output=router_logits)
+
+ if shared_output is not None:
+ out = out + shared_output
+
+ # Reduce sum
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(out, group=self.process_group)
+
+ return out.view(*x.shape)
+
+
+class DeepseekV2Layer(nn.Module):
+ def __init__(self, prefix, layer_id, config, weights, rotary_emb):
+ super().__init__()
+ prefix = f"{prefix}.layers.{layer_id}"
+
+ self.self_attn = DeepseekV2Attention(
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+
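+        # The first `first_k_dense_replace` layers keep a dense MLP; after that,
+        # every `moe_layer_freq`-th layer uses an MoE block.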
+ if (
+ config.n_routed_experts is not None
+ and layer_id >= config.first_k_dense_replace
+ and layer_id % config.moe_layer_freq == 0
+ ):
+ moe_layer_cls = (
+ SparseMoELayer
+ if SparseMoELayer.is_supported(weights)
+ else DenseMoELayer
+ )
+ self.mlp = DeepseekV2MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
+ else:
+ self.mlp = DeepseekV2MLP(
+ prefix=f"{prefix}.mlp",
+ config=config,
+ weights=weights,
+ intermediate_size=config.intermediate_size,
+ )
+
+ self.input_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ cu_seqlen_prefill: torch.Tensor,
+ kv_cache,
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ):
+ normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ # faster post attention rms norm
+ normed_attn_res_output, residual = self.post_attention_layernorm(
+ attn_output, residual
+ )
+
+ output = self.mlp(normed_attn_res_output)
+
+ return output, residual
+
+
+class DeepseekV2Model(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights: Weights):
+ super().__init__()
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens", weights=weights
+ )
+
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.qk_rope_head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+ self.layers = nn.ModuleList(
+ [
+ DeepseekV2Layer(
+ prefix,
+ layer_id,
+ config,
+ weights,
+ rotary_emb,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = FastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, input_ids.shape[0]
+ )
+ hidden_states = self.embed_tokens(input_ids)
+
+        # Get rotary cos and sin for this forward pass
+        # (computed once here to avoid indexing in every layer)
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashDeepseekV2ForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights: Weights):
+ super().__init__()
+
+ self.model = DeepseekV2Model(
+ "model" if not prefix else f"{prefix}.model", config, weights
+ )
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="lm_head" if not prefix else f"{prefix}.lm_head",
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
new file mode 100644
index 00000000000..3a6a974aeab
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
@@ -0,0 +1,723 @@
+# coding=utf-8
+# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+import torch.distributed
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+
+from text_generation_server.layers import (
+ FastLinear,
+ SpeculativeHead,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ get_linear,
+ Fp8Linear,
+)
+from text_generation_server.layers.attention import (
+ Seqlen,
+ attention,
+ paged_attention_mla,
+ set_block_mapping,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
+from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
+from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
+from text_generation_server.utils.weights import Weights
+import habana_frameworks.torch as htorch
+
+
+def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
+ if isinstance(layer, Fp8Linear):
+ eye = torch.eye(
+ layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
+ )
+ dequant_weights = layer(eye)
+ del eye
+ # standardize to (output, input)
+ return dequant_weights.T
+ return layer.weight
+
+
+class DeepseekV3Config(PretrainedConfig):
+ def __init__(
+ self,
+ vocab_size=102400,
+ hidden_size=4096,
+ intermediate_size=11008,
+ moe_intermediate_size=1407,
+ num_hidden_layers=30,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ n_shared_experts=2,
+ n_routed_experts=160,
+ ep_size=1,
+ routed_scaling_factor=1.0,
+ kv_lora_rank=512,
+ q_lora_rank=1536,
+ qk_rope_head_dim=64,
+ v_head_dim=128,
+ qk_nope_head_dim=128,
+ topk_method="gready",
+ n_group=8,
+ topk_group=3,
+ num_experts_per_tok=6,
+ moe_layer_freq=1,
+ first_k_dense_replace=0,
+ norm_topk_prob=False,
+ scoring_func="softmax",
+ aux_loss_alpha=0.001,
+ seq_aux=True,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=100000,
+ eos_token_id=100001,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.moe_intermediate_size = moe_intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.n_shared_experts = n_shared_experts
+ self.n_routed_experts = n_routed_experts
+ self.ep_size = ep_size
+ self.routed_scaling_factor = routed_scaling_factor
+ self.kv_lora_rank = kv_lora_rank
+ self.q_lora_rank = q_lora_rank
+ self.qk_rope_head_dim = qk_rope_head_dim
+ self.v_head_dim = v_head_dim
+ self.qk_nope_head_dim = qk_nope_head_dim
+ self.topk_method = topk_method
+ self.n_group = n_group
+ self.topk_group = topk_group
+ self.num_experts_per_tok = num_experts_per_tok
+ self.moe_layer_freq = moe_layer_freq
+ self.first_k_dense_replace = first_k_dense_replace
+ self.norm_topk_prob = norm_topk_prob
+ self.scoring_func = scoring_func
+ self.aux_loss_alpha = aux_loss_alpha
+ self.seq_aux = seq_aux
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
+ if tie_word_embeddings:
+ raise ValueError(
+ "tie_word_embeddings is not supported for Deepseek V2 models."
+ )
+
+ if ep_size != 1:
+ raise ValueError(
+ f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}"
+ )
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class DeepseekV3Attention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights: Weights,
+ rotary_emb,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.kv_lora_rank = config.kv_lora_rank
+ self.q_lora_rank = config.q_lora_rank
+ self.qk_nope_head_dim = config.qk_nope_head_dim
+ self.qk_rope_head_dim = config.qk_rope_head_dim
+ self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
+ self.value_head_size = config.v_head_dim
+ self.head_pad_size = max(self.head_size, self.value_head_size)
+ self.rotary_emb = rotary_emb
+
+ mscale = get_mscale(
+ self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
+ )
+ self.softmax_scale = self.head_size**-0.5 * mscale * mscale
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ if self.q_lora_rank is None:
+ self.q_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.q_proj",
+ weights=weights,
+ bias=config.attention_bias,
+ )
+ else:
+ self.q_a_proj = get_linear(
+ weight=weights.get_weights(f"{prefix}.q_a_proj"),
+ bias=(
+ weights.get_tensor(f"{prefix}.q_a_proj.bias")
+ if config.attention_bias
+ else None
+ ),
+ )
+ self.q_a_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.q_a_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.q_b_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.q_b_proj",
+ weights=weights,
+ bias=config.attention_bias,
+ )
+
+ self.kv_a_proj_with_mqa = get_linear(
+ weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
+ bias=(
+ weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
+ if config.attention_bias
+ else None
+ ),
+ )
+
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.kv_a_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.kv_b_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.kv_b_proj",
+ weights=weights,
+ bias=config.attention_bias,
+ )
+
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
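+        # As in the V2 implementation, split kv_b_proj into W_UK / W_UV so decode
+        # can absorb the key up-projection into the query path and the value
+        # up-projection into the output projection (MLA weight absorption).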
+ kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
+ kv_b_proj_weight = kv_b_proj_weight.view(
+ self.kv_lora_rank,
+ self.num_heads,
+ self.qk_nope_head_dim + self.value_head_size,
+ )
+ W_UK, W_UV = kv_b_proj_weight.split(
+ [self.qk_nope_head_dim, self.value_head_size], dim=-1
+ )
+ # Convert from (L, N, V) to (N, L, V)
+ self.W_UV = W_UV.transpose(0, 1)
+ # Convert from (L, N, P) to (N, P, L)
+ self.W_UK_T = W_UK.permute(1, 2, 0)
+
+ def _q_proj_and_k_up_proj(self, x):
+ q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+ q_nope, q_pe = (
+ q_proj(x)
+ .view(-1, self.num_heads, self.head_size)
+ .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+ )
+
+ # Convert from (B, N, P) to (N, B, P)
+ q_nope = q_nope.transpose(0, 1)
+ # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+ ql_nope = torch.bmm(q_nope, self.W_UK_T)
+ # Convert from (N, B, L) to (B, N, L)
+ return ql_nope.transpose(0, 1), q_pe
+
+ def _v_up_proj_and_o_proj(self, x):
+ # Convert from (B, N, L) to (N, B, L)
+ x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+ # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+ x = torch.bmm(x, self.W_UV)
+ # Convert from (N, B, V) to (B, N * V)
+ x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
+ return self.o_proj(x)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ cu_seqlen_prefill: torch.Tensor,
+ kv_cache: KVCache,
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ):
+ if self.q_lora_rank is None:
+ hidden_states_or_q_c = hidden_states
+ else:
+ hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]
+
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+ compressed_kv, key_pe = torch.split(
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+ )
+
+ key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
+ kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+ query = q_proj(hidden_states_or_q_c)
+ query = query.view(-1, self.num_heads, self.head_size)
+ query_nope, query_pe = torch.split(
+ query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+ )
+ else:
+ query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
+
+ batch_size, heads, head_dim = query_pe.shape
+ query_pe = (
+ query_pe.view(batch_size, heads, head_dim // 2, 2)
+ .transpose(2, 3)
+ .reshape(batch_size, heads, head_dim)
+ )
+ batch_size, heads, head_dim = key_pe.shape
+ key_pe = (
+ key_pe.view(batch_size, heads, head_dim // 2, 2)
+ .transpose(2, 3)
+ .reshape(batch_size, heads, head_dim)
+ )
+ self.rotary_emb(query_pe, key_pe, cos, sin)
+ latent_vec_k = torch.concat(
+ (kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
+ )
+ latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
+
+ latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))
+
+ kv_cache.store(
+ key=latent_vec_k,
+ value=None,
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ if cu_seqlen_prefill is not None:
+ kv = self.kv_b_proj(kv_c_normed).view(
+ -1,
+ self.num_key_value_heads,
+ self.qk_nope_head_dim + self.value_head_size,
+ )
+
+ key_nope, value = torch.split(
+ kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
+ )
+ query[..., self.qk_nope_head_dim :] = query_pe
+ key = torch.empty_like(query)
+ key[..., : self.qk_nope_head_dim] = key_nope
+ key[..., self.qk_nope_head_dim :] = key_pe
+
+ # We need to pad the heads because Flash Attention does not support
+ # qk and v with different head sizes.
+ query = torch.nn.functional.pad(
+ query, (0, self.head_pad_size - self.head_size), value=0
+ )
+ key = torch.nn.functional.pad(
+ key, (0, self.head_pad_size - self.head_size), value=0
+ )
+ value = torch.nn.functional.pad(
+ value, (0, self.head_pad_size - self.value_head_size), value=0
+ )
+
+ # flash attention
+ attn_output = attention(
+ query=query,
+ key=key,
+ value=value,
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ attn_output = attn_output[..., : self.value_head_size]
+
+ return self.o_proj(
+ attn_output.reshape(-1, self.num_heads * self.value_head_size)
+ )
+ else:
+ # Decode
+ query = torch.cat([query_nope, query_pe], dim=-1)
+ attn_output = paged_attention_mla(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ kv_lora_rank=self.kv_lora_rank,
+ )
+ attn_output = self._v_up_proj_and_o_proj(attn_output)
+ return attn_output
+
+
+class DeepseekV3MLP(nn.Module):
+ def __init__(self, prefix: str, config, weights, intermediate_size: int):
+ super().__init__()
+ self.hidden_act = config.hidden_act
+ if self.hidden_act != "silu":
+ # Bail out because MoE only supports silu.
+ raise NotImplementedError(
+ "Currently only `silu` is supported as an activation for Deepseek V2."
+ )
+ self.act = ACT2FN[self.hidden_act]
+
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ self.intermediate_size = intermediate_size // weights.process_group.size()
+
+ # TODO: This is a hotfix to be removed & properly refactored.
+ self.quantize = config.quantize
+
+ def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
+ gate_up_states = self.gate_up_proj(hidden_states)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
+ )
+
+
+class DeepseekV3MoE(nn.Module):
+ def __init__(
+ self,
+ prefix,
+ config: DeepseekV3Config,
+ moe_layer_cls: Type[MoELayer],
+ weights,
+ ):
+ super().__init__()
+
+ self.hidden_dim = config.hidden_size
+ self.moe_intermediate_size = (
+ config.moe_intermediate_size // weights.process_group.size()
+ )
+ self.routed_scaling_factor = config.routed_scaling_factor
+
+ # Gating
+ self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
+ if config.topk_method == "noaux_tc":
+ self.gate.e_score_correction_bias = torch.zeros(
+ config.n_routed_experts, device=weights.device
+ )
+ else:
+ self.gate.e_score_correction_bias = None
+
+ self.moe_layer = moe_layer_cls(
+ prefix=f"{prefix}.experts",
+ n_experts=config.n_routed_experts,
+ n_expert_group=config.n_group,
+ renormalize=config.norm_topk_prob,
+ topk=config.num_experts_per_tok,
+ topk_group=config.topk_group,
+ weights=weights,
+ scoring_func=config.scoring_func,
+ e_score_correction_bias=self.gate.e_score_correction_bias,
+ )
+ assert isinstance(self.moe_layer, MoELayer)
+
+ if config.n_shared_experts is not None:
+ self.shared_experts = DeepseekV3MLP(
+ prefix=f"{prefix}.shared_experts",
+ config=config,
+ weights=weights,
+ intermediate_size=config.moe_intermediate_size
+ * config.n_shared_experts,
+ )
+ else:
+ self.shared_experts = None
+
+ self.process_group = weights.process_group
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.shared_experts is not None:
+ shared_output = self.shared_experts(x, reduce=False)
+ else:
+ shared_output = None
+
+ router_logits = self.gate(x)
+
+ out = self.moe_layer(x, gating_output=router_logits)
+
+ if shared_output is not None:
+ out = out + shared_output
+
+ # Reduce sum
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(out, group=self.process_group)
+
+ return out.view(*x.shape)
+
+
+class DeepseekV3Layer(nn.Module):
+ def __init__(self, prefix, layer_id, config, weights, rotary_emb):
+ super().__init__()
+ prefix = f"{prefix}.layers.{layer_id}"
+
+ self.self_attn = DeepseekV3Attention(
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+
+ if (
+ config.n_routed_experts is not None
+ and layer_id >= config.first_k_dense_replace
+ and layer_id % config.moe_layer_freq == 0
+ ):
+ moe_layer_cls = (
+ SparseMoELayer
+ if SparseMoELayer.is_supported(weights)
+ else DenseMoELayer
+ )
+ self.mlp = DeepseekV3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
+ else:
+ self.mlp = DeepseekV3MLP(
+ prefix=f"{prefix}.mlp",
+ config=config,
+ weights=weights,
+ intermediate_size=config.intermediate_size,
+ )
+
+ self.input_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ cu_seqlen_prefill: torch.Tensor,
+ kv_cache,
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ):
+ normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ # faster post attention rms norm
+ normed_attn_res_output, residual = self.post_attention_layernorm(
+ attn_output, residual
+ )
+
+ output = self.mlp(normed_attn_res_output)
+
+ return output, residual
+
+
+class DeepseekV3Model(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights: Weights):
+ super().__init__()
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens", weights=weights
+ )
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.qk_rope_head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ DeepseekV3Layer(
+ prefix,
+ layer_id,
+ config,
+ weights,
+ rotary_emb,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = FastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, input_ids.shape[0]
+ )
+ hidden_states = self.embed_tokens(input_ids)
+
+ # Get rotary cos and sin for this forward
+        # Avoid re-indexing inside each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashDeepseekV3ForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights: Weights):
+ super().__init__()
+
+ self.model = DeepseekV3Model(
+ "model" if not prefix else f"{prefix}.model", config, weights
+ )
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="lm_head" if not prefix else f"{prefix}.lm_head",
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
new file mode 100644
index 00000000000..6ab1c4a9a64
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
@@ -0,0 +1,585 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from typing import Optional, List, Tuple
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+ TensorParallelMultiAdapterLinear,
+ TensorParallelAdapterRowLinear,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
+ FastRMSNorm,
+)
+from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch
+
+
+class Gemma2Config(PretrainedConfig):
+ def __init__(
+ self,
+ vocab_size=256128,
+ hidden_size=3072,
+ intermediate_size=24576,
+ num_hidden_layers=28,
+ num_attention_heads=16,
+ num_key_value_heads=16,
+ head_dim=256,
+ hidden_act="gelu_pytorch_tanh",
+ max_position_embeddings=8192,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=True,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.head_dim = head_dim
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class Gemma2FastRMSNorm(FastRMSNorm):
+ @classmethod
+ def load(cls, prefix: str, weights, eps=1e-6):
+ dtype = weights.dtype
+ weights.dtype = torch.float32
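+        # Gemma checkpoints store the RMSNorm scale zero-centered; add 1 to recover
+        # the effective weight. The tensor is loaded in float32 for precision.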
+ weight = weights.get_tensor(f"{prefix}.weight") + 1
+ weights.dtype = dtype
+ new = cls(weight, eps)
+ new.dtype = dtype
+ return new
+
+ # perform the multiplication in full precision and downcast after
+ def forward(self, hidden_states, residual=None):
+ if residual is not None:
+ hidden_states += residual
+ residual = hidden_states
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ hidden_states = hidden_states * self.weight
+ return hidden_states.to(self.dtype), residual
+
+
+def load_attention(config, prefix: str, weights):
+ if config.num_attention_heads != config.num_key_value_heads:
+ return _load_gqa(config, prefix, weights)
+ else:
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+
+
+def _load_gqa(config, prefix: str, weights):
+ assert config.num_attention_heads % weights.process_group.size() == 0
+
+ weight = weights.get_multi_weights_col(
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ )
+
+ if isinstance(weight, UnquantizedWeight):
+ weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
+
+ head_size = config.head_dim
+ num_heads = config.num_attention_heads // weights.process_group.size()
+ num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+ assert list(weight.weight.shape) == [
+ (num_heads + 2 * num_key_value_heads) * head_size,
+ config.hidden_size,
+ ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+
+ return TensorParallelColumnLinear(get_linear(weight, bias=None))
+
+
+class FlashGemma2Attention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ layer_id,
+ causal: bool,
+ is_sliding: bool,
+ rotary_emb,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.head_size = config.head_dim
+ self.causal = causal
+ if is_sliding:
+ self.window_size = config.sliding_window
+ else:
+ self.window_size = -1
+ self.rotary_emb = rotary_emb
+
+ # self.softmax_scale = self.head_size**-0.5
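+        # Gemma2 scales attention by query_pre_attn_scalar**-0.5 instead of head_size**-0.5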
+ self.softmax_scale = config.query_pre_attn_scalar**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+ self.softcap = config.attn_logit_softcapping
+
+ query_key_value = load_attention(config, prefix, weights)
+ self.query_key_value = TensorParallelMultiAdapterLinear.load(
+ query_key_value,
+ layer_id,
+ ["q_proj", "k_proj", "v_proj"],
+ sizes=[
+ self.head_size * config.num_attention_heads,
+ self.head_size * config.num_key_value_heads,
+ self.head_size * config.num_key_value_heads,
+ ],
+ process_group=weights.process_group,
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.o_proj = TensorParallelAdapterRowLinear.load(
+ o_proj,
+ layer_id,
+ "o_proj",
+ process_group=weights.process_group,
+ )
+
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states, adapter_data)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ window_size_left=self.window_size,
+ softcap=self.softcap,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ softcap=self.softcap,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ window_size_left=self.window_size,
+ )
+
+ return self.o_proj(
+ attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+ )
+
+
+class Gemma2MLP(nn.Module):
+ def __init__(self, prefix, config, weights, layer_id):
+ super().__init__()
+ act = config.hidden_activation
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+ # Fuse gate and up proj
+ gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+ gate_up_proj,
+ layer_id,
+ ["gate_proj", "up_proj"],
+ sizes=[
+ config.intermediate_size,
+ config.intermediate_size,
+ ],
+ process_group=weights.process_group,
+ )
+
+ down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.down_proj = TensorParallelAdapterRowLinear.load(
+ down_proj,
+ layer_id,
+ "down_proj",
+ process_group=weights.process_group,
+ )
+
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ def forward(self, hidden_states, adapter_data):
+ gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+ )
+
+
+class FlashGemma2Layer(nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ layer_id,
+ causal: bool,
+ is_sliding: bool,
+ rotary_emb,
+ ):
+ super().__init__()
+ self.self_attn = FlashGemma2Attention(
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ layer_id=layer_id,
+ causal=causal,
+ is_sliding=is_sliding,
+ rotary_emb=rotary_emb,
+ )
+ self.mlp = Gemma2MLP(
+ prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
+ )
+
+ self.input_layernorm = Gemma2FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = Gemma2FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.pre_feedforward_layernorm = Gemma2FastRMSNorm.load(
+ prefix=f"{prefix}.pre_feedforward_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.post_feedforward_layernorm = Gemma2FastRMSNorm.load(
+ prefix=f"{prefix}.post_feedforward_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+
+ # faster post attention rms norm
+ normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)
+ normed_attn_res_output = normed_attn_res_output + res
+ res = normed_attn_res_output
+
+ pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
+ mlp_output = self.mlp(pre_normed, adapter_data)
+ post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
+
+ return post_hidden_states, normed_attn_res_output
+
+
+class FlashGemma2Model(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, causal: bool):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
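+        # Gemma2 alternates local sliding-window attention (even layer ids) with
+        # global attention (odd layer ids).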
+ self.layers = nn.ModuleList(
+ [
+ FlashGemma2Layer(
+ prefix=f"{prefix}.layers.{layer_id}",
+ config=config,
+ weights=weights,
+ layer_id=layer_id,
+ causal=causal,
+ is_sliding=layer_id % 2 == 0,
+ rotary_emb=rotary_emb,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = Gemma2FastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ adapter_data: Optional[torch.Tensor],
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, inputs_embeds.shape[0]
+ )
+ hidden_states = inputs_embeds
+
+ # Get rotary cos and sin for this forward
+        # Avoid re-indexing inside each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashGemma2ForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, *, causal: bool = True):
+ super().__init__()
+
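+        # Gemma scales token embeddings by sqrt(hidden_size) before the decoder stack.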
+ embed_norm = config.hidden_size**0.5
+ if not prefix:
+ prefix = "model"
+ else:
+ prefix = f"{prefix}.model"
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens", weights=weights
+ )
+ self.embed_tokens.weight *= embed_norm
+
+ self.model = FlashGemma2Model(
+ prefix=prefix, config=config, weights=weights, causal=causal
+ )
+ self.lm_head = SpeculativeHead.load(
+ prefix=(
+ f"{prefix}.embed_tokens"
+ if config.tie_word_embeddings
+ else f"{prefix}.lm_head"
+ ),
+ config=config,
+ weights=weights,
+ )
+ self.softcap = config.final_logit_softcapping
+ assert isinstance(self.softcap, float)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_embeds = self.embed_tokens(input_ids)
+ hidden_states = self.model(
+ input_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+
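+        # Final logit soft-capping: logits = softcap * tanh(logits / softcap)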
+ logits /= self.softcap
+ logits = torch.tanh(logits)
+ logits *= self.softcap
+
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
new file mode 100644
index 00000000000..c7091f901ba
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
@@ -0,0 +1,755 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+from torch import nn
+from typing import Optional, List, Tuple
+import copy
+
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ get_linear,
+ #
+ SpeculativeHead,
+ TensorParallelMultiAdapterLinear,
+ TensorParallelAdapterRowLinear,
+)
+
+
+from text_generation_server.models.custom_modeling.vlm import (
+ load_text_model,
+ load_vision_model,
+)
+
+
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
+ FastRMSNorm,
+)
+from text_generation_server.utils.weights import UnquantizedWeight
+from transformers.activations import ACT2FN
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ Seqlen,
+ set_block_mapping,
+ HPUPagedAttentionMetadata,
+)
+import habana_frameworks.torch as htorch
+
+ATTENTION_TYPE_GLOBAL = "global"
+ATTENTION_TYPE_LOCAL = "local_sliding"
+
+
+class Gemma3FastRMSNorm(FastRMSNorm):
+ @classmethod
+ def load(cls, prefix: str, weights, eps=1e-6):
+ dtype = weights.dtype
+ weights.dtype = torch.float32
+ weight = weights.get_tensor(f"{prefix}.weight") + 1
+ weights.dtype = dtype
+ new = cls(weight, eps)
+ new.dtype = dtype
+ return new
+
+ # perform the multiplication in full precision and downcast after
+ def forward(self, hidden_states, residual=None):
+ if residual is not None:
+ hidden_states += residual
+ residual = hidden_states
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ hidden_states = hidden_states * self.weight
+ return hidden_states.to(self.dtype), residual
+
+
+def load_attention(config, prefix: str, weights):
+ if config.num_attention_heads != config.num_key_value_heads:
+ return _load_gqa(config, prefix, weights)
+ else:
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+
+
+def _load_gqa(config, prefix: str, weights):
+ assert config.num_attention_heads % weights.process_group.size() == 0
+
+ weight = weights.get_multi_weights_col(
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ )
+
+ if isinstance(weight, UnquantizedWeight):
+ weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
+
+ head_size = config.head_dim
+ num_heads = config.num_attention_heads // weights.process_group.size()
+ num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+ assert list(weight.weight.shape) == [
+ (num_heads + 2 * num_key_value_heads) * head_size,
+ config.hidden_size,
+ ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+
+ return TensorParallelColumnLinear(get_linear(weight, bias=None))
+
+
+class FlashGemma3Attention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ layer_id,
+ causal: bool,
+ is_sliding: bool,
+ local_rotary_emb,
+ global_rotary_emb,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.head_size = config.head_dim
+ self.causal = causal
+ if is_sliding:
+ self.window_size = config.sliding_window
+ self.rotary_emb = local_rotary_emb
+ else:
+ self.window_size = -1
+ self.rotary_emb = global_rotary_emb
+
+ self.softmax_scale = (
+ config.query_pre_attn_scalar**-0.5
+ if config.query_pre_attn_scalar is not None
+ else None
+ )
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+ self.softcap = None # config.attn_logit_softcapping
+
+ query_key_value = load_attention(config, prefix, weights)
+ self.query_key_value = TensorParallelMultiAdapterLinear.load(
+ query_key_value,
+ layer_id,
+ ["q_proj", "k_proj", "v_proj"],
+ sizes=[
+ self.head_size * config.num_attention_heads,
+ self.head_size * config.num_key_value_heads,
+ self.head_size * config.num_key_value_heads,
+ ],
+ process_group=weights.process_group,
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.o_proj = TensorParallelAdapterRowLinear.load(
+ o_proj,
+ layer_id,
+ "o_proj",
+ process_group=weights.process_group,
+ )
+
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
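+        # Gemma3 applies a per-head RMSNorm to queries and keys (QK-norm) before RoPE.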
+ self.q_norm = Gemma3FastRMSNorm.load(
+ prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.k_norm = Gemma3FastRMSNorm.load(
+ prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.enable_gqa = self.num_heads != self.num_key_value_heads
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ ):
+
+ qkv = self.query_key_value(hidden_states, adapter_data)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+
+ kv = kv.view(-1, 2, self.num_key_value_heads * self.head_size)
+ key = kv[:, 0]
+ value = kv[:, 1]
+
+ query = query.reshape(-1, self.head_size)
+ key = key.reshape(-1, self.head_size)
+
+ query, _ = self.q_norm(query.contiguous())
+ key, _ = self.k_norm(key.contiguous())
+
+ query = query.view(-1, self.num_heads, self.head_size)
+ key = key.view(-1, self.num_key_value_heads, self.head_size)
+ value = value.view(-1, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, key, cos, sin)
+
+ kv_cache.store(
+ key=key,
+ value=value,
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=key,
+ value=value,
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ window_size_left=self.window_size,
+ softcap=self.softcap,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ softcap=self.softcap,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ window_size_left=self.window_size,
+ )
+
+ return self.o_proj(
+ attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+ )
+
+
+class Gemma3MLP(nn.Module):
+ def __init__(self, prefix, config, weights, layer_id):
+ super().__init__()
+ act = config.hidden_activation
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+ # Fuse gate and up proj
+ gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+ gate_up_proj,
+ layer_id,
+ ["gate_proj", "up_proj"],
+ sizes=[
+ config.intermediate_size,
+ config.intermediate_size,
+ ],
+ process_group=weights.process_group,
+ )
+
+ down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.down_proj = TensorParallelAdapterRowLinear.load(
+ down_proj,
+ layer_id,
+ "down_proj",
+ process_group=weights.process_group,
+ )
+
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ def forward(self, hidden_states, adapter_data):
+ gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+ )
+
+
+class FlashGemma3Layer(nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ layer_id,
+ causal: bool,
+ is_sliding: bool,
+ local_rotary_emb,
+ global_rotary_emb,
+ ):
+ super().__init__()
+ self.self_attn = FlashGemma3Attention(
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ layer_id=layer_id,
+ causal=causal,
+ is_sliding=is_sliding,
+ local_rotary_emb=local_rotary_emb,
+ global_rotary_emb=global_rotary_emb,
+ )
+ self.mlp = Gemma3MLP(
+ prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
+ )
+
+ self.input_layernorm = Gemma3FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = Gemma3FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.pre_feedforward_layernorm = Gemma3FastRMSNorm.load(
+ prefix=f"{prefix}.pre_feedforward_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.post_feedforward_layernorm = Gemma3FastRMSNorm.load(
+ prefix=f"{prefix}.post_feedforward_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+
+ # faster post attention rms norm
+ normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)
+ normed_attn_res_output = normed_attn_res_output + res
+ res = normed_attn_res_output
+
+ pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
+ mlp_output = self.mlp(pre_normed, adapter_data)
+ post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
+
+ return post_hidden_states, normed_attn_res_output
+
+
+class FlashGemma3Model(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, causal: bool):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
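+        # Sliding-window ("local") layers use a separate rotary embedding built from
+        # rope_local_base_freq with default scaling; global layers use rope_theta.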
+ local_config = copy.deepcopy(config)
+ local_config.rope_scaling = dict(rope_type="default")
+ local_rotary_emb = PositionRotaryEmbedding.static(
+ config=local_config,
+ dim=config.head_dim,
+ base=config.rope_local_base_freq,
+ device=weights.device,
+ )
+ global_rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
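+        # Every `sliding_window_pattern`-th layer uses global attention; all other
+        # layers use local sliding-window attention.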
+ self.layers = nn.ModuleList(
+ [
+ FlashGemma3Layer(
+ prefix=f"{prefix}.layers.{layer_id}",
+ config=config,
+ weights=weights,
+ layer_id=layer_id,
+ causal=causal,
+ is_sliding=bool((layer_id + 1) % config.sliding_window_pattern),
+ local_rotary_emb=local_rotary_emb,
+ global_rotary_emb=global_rotary_emb,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = Gemma3FastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ adapter_data: Optional[torch.Tensor],
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, inputs_embeds.shape[0]
+ )
+ hidden_states = inputs_embeds
+
+        residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
+
+        for i, layer in enumerate(self.layers):
+            # Local and global layers use different rotary embeddings, so the
+            # cos/sin tables are fetched per layer instead of once up front.
+            cos, sin = layer.self_attn.rotary_emb.get_cos_sin(position_ids)
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashGemma3ForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, *, causal: bool = True):
+ super().__init__()
+
+ embed_norm = config.hidden_size**0.5
+ if not prefix:
+ prefix = "model"
+ else:
+ prefix = f"{prefix}.model"
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens", weights=weights
+ )
+ self.embed_tokens.weight *= embed_norm
+
+ self.model = FlashGemma3Model(
+ prefix=prefix, config=config, weights=weights, causal=causal
+ )
+ self.lm_head = SpeculativeHead.load(
+ prefix=(
+ f"{prefix}.embed_tokens"
+ if config.tie_word_embeddings
+ else f"{prefix}.lm_head"
+ ),
+ config=config,
+ weights=weights,
+ )
+ # self.softcap = config.attn_logit_softcapping
+ # assert isinstance(self.softcap, float)
+ self.softcap = None
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_embeds = self.embed_tokens(input_ids)
+
+ hidden_states = self.model(
+ input_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+
+ return logits, speculative_logits
+
+
+class Gemma3MultimodalInputProjection(torch.nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+
+ self.mm_input_projection_weight = weights.get_tensor(
+ "multi_modal_projector.mm_input_projection_weight"
+ )
+
+ self.mm_soft_emb_norm = Gemma3FastRMSNorm.load(
+ prefix=f"{prefix}.mm_soft_emb_norm",
+ weights=weights,
+ eps=config.vision_config.layer_norm_eps,
+ )
+
+ self.patches_per_image = int(
+ config.vision_config.image_size // config.vision_config.patch_size
+ )
+ self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
+ self.kernel_size = self.patches_per_image // self.tokens_per_side
+ self.avg_pool = nn.AvgPool2d(
+ kernel_size=self.kernel_size, stride=self.kernel_size
+ )
+
+ def forward(self, vision_outputs: torch.Tensor):
+ batch_size, _, seq_length = vision_outputs.shape
+
+ reshaped_vision_outputs = vision_outputs.transpose(1, 2)
+ reshaped_vision_outputs = reshaped_vision_outputs.reshape(
+ batch_size, seq_length, self.patches_per_image, self.patches_per_image
+ )
+ reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
+
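+        # Average-pool the patch grid down to mm_tokens_per_image tokens per image
+        # before the soft-embedding norm and projection into the text hidden size.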
+ pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
+ pooled_vision_outputs = pooled_vision_outputs.flatten(2)
+ pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
+
+ normed_vision_outputs, _ = self.mm_soft_emb_norm(pooled_vision_outputs)
+
+ projected_vision_outputs = torch.matmul(
+ normed_vision_outputs, self.mm_input_projection_weight
+ )
+ return projected_vision_outputs.type_as(vision_outputs)
+
+
+class Gemma3ForConditionalGeneration(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+
+ self.config = config
+
+        if config.vision_config is not None:
+            config.vision_config.quantize = config.quantize
+
+ self.post_vision_model_layernorm = nn.LayerNorm.load(
+ prefix="vision_tower.vision_model.post_layernorm",
+ weights=weights,
+ eps=config.vision_config.layer_norm_eps,
+ )
+
+ self.multimodal_projector = Gemma3MultimodalInputProjection(
+ prefix="multi_modal_projector",
+ config=config,
+ weights=weights,
+ )
+
+ text_config = config.text_config
+ text_config.speculator = config.speculator
+ text_config.quantize = config.quantize
+
+ self.vision_model = load_vision_model(
+ prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
+ config=config.vision_config,
+ weights=weights,
+ )
+
+ self.text_model = load_text_model(
+ prefix="language_model" if not prefix else f"{prefix}.language_model",
+ config=config.text_config,
+ weights=weights,
+ )
+ else:
+ config.text_config.quantize = config.quantize
+ config.text_config.speculator = config.speculator
+ self.text_model = load_text_model(
+ prefix=prefix,
+ config=config.text_config,
+ weights=weights,
+ )
+
+ self.pad_token_id = (
+ config.pad_token_id if config.pad_token_id is not None else -1
+ )
+ self.dtype = weights.dtype
+
+ def get_vision_embeds(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ pixel_values = pixel_values.to(dtype=self.dtype)
+ image_outputs = self.vision_model(pixel_values)
+ vision_outputs = self.post_vision_model_layernorm(
+ image_outputs.last_hidden_state
+ )
+ image_features = self.multimodal_projector(vision_outputs)
+ image_features = image_features.view(-1, image_features.shape[-1])
+ return image_features
+
+ def get_inputs_embeds(
+ self,
+ input_ids: torch.Tensor,
+        vision_embeds: Optional[torch.Tensor] = None,
+ ):
+ inputs_embeds = self.text_model.embed_tokens(input_ids)
+
+ if vision_embeds is not None:
+ # Replace the image token embeddings with the vision features
+ image_token_mask = (input_ids == self.config.image_token_index).to(
+ input_ids.device
+ )
+ inputs_embeds[image_token_mask] = vision_embeds.view(
+ -1, vision_embeds.shape[-1]
+ )
+ return inputs_embeds
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ if cu_seqlen_prefill is not None:
+ position_ids += 1
+
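+        # Build a sliding-window variant of the attention mask: positions more than
+        # `sliding_window` tokens in the past are set to the dtype minimum.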
+ if attention_mask is not None:
+ min_dtype = torch.finfo(inputs_embeds.dtype).min
+ # prefill may be larger than sliding window
+ effective_seq_len = max(
+ position_ids.shape[0], self.config.text_config.sliding_window
+ )
+ sliding_window_mask = torch.tril(
+ torch.ones_like(attention_mask, dtype=torch.bool),
+ diagonal=-self.config.text_config.sliding_window,
+ )
+ attention_mask_local = torch.where(
+ sliding_window_mask, min_dtype, attention_mask
+ )
+ offset = max(0, position_ids.shape[0] - effective_seq_len)
+ attention_mask_local = attention_mask_local[
+ :, :, :, offset : offset + effective_seq_len
+ ]
+ else:
+ attention_mask_local = None
+
+ hidden_states = self.text_model.model(
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ adapter_data=adapter_data,
+ )
+
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.text_model.lm_head(hidden_states)
+
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
new file mode 100644
index 00000000000..5d6dc67c0fd
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -0,0 +1,485 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from typing import Optional, List, Tuple
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
+ FastRMSNorm,
+)
+from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch
+
+
+class GemmaConfig(PretrainedConfig):
+ def __init__(
+ self,
+ vocab_size=256128,
+ hidden_size=3072,
+ intermediate_size=24576,
+ num_hidden_layers=28,
+ num_attention_heads=16,
+ num_key_value_heads=16,
+ head_dim=256,
+ hidden_act="gelu_pytorch_tanh",
+ max_position_embeddings=8192,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=True,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.head_dim = head_dim
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class GemmaFastRMSNorm(FastRMSNorm):
+ @classmethod
+ def load(cls, prefix: str, weights, eps=1e-6):
+ dtype = weights.dtype
+ weights.dtype = torch.float32
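+        # Gemma stores the RMSNorm scale zero-centered; add 1 to recover the weight.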
+ weight = weights.get_tensor(f"{prefix}.weight") + 1
+ weights.dtype = dtype
+ new = cls(weight, eps)
+ new.dtype = dtype
+ return new
+
+ # perform the multiplication in full precision and downcast after
+ def forward(self, hidden_states, residual=None):
+ if residual is not None:
+ hidden_states += residual
+ residual = hidden_states
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ hidden_states = hidden_states * self.weight
+ return hidden_states.to(self.dtype), residual
+
+
+def load_attention(config, prefix: str, weights):
+ if config.num_attention_heads != config.num_key_value_heads:
+ return _load_gqa(config, prefix, weights)
+ else:
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+
+
+def _load_gqa(config, prefix: str, weights):
+ assert config.num_attention_heads % weights.process_group.size() == 0
+
+ weight = weights.get_multi_weights_col(
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ )
+
+ if isinstance(weight, UnquantizedWeight):
+ weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
+
+ head_size = config.head_dim
+ num_heads = config.num_attention_heads // weights.process_group.size()
+ num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+ assert list(weight.weight.shape) == [
+ (num_heads + 2 * num_key_value_heads) * head_size,
+ config.hidden_size,
+ ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+
+ return TensorParallelColumnLinear(get_linear(weight, bias=None))
+
+
+class FlashGemmaAttention(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, causal: bool, rotary_emb):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.head_size = config.head_dim
+ self.causal = causal
+ self.rotary_emb = rotary_emb
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights)
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ causal=self.causal,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class GemmaMLP(nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ act = config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+ # Fuse gate and up proj
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ def forward(self, hidden_states):
+ gate_up_states = self.gate_up_proj(hidden_states)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+
+
+class FlashGemmaLayer(nn.Module):
+ def __init__(self, prefix: str, config, weights, causal: bool, rotary_emb):
+ super().__init__()
+ self.self_attn = FlashGemmaAttention(
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ causal=causal,
+ rotary_emb=rotary_emb,
+ )
+ self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+
+ self.input_layernorm = GemmaFastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = GemmaFastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ # faster post attention rms norm
+ normed_attn_res_output, attn_res = self.post_attention_layernorm(
+ attn_output, res
+ )
+
+ mlp_output = self.mlp(normed_attn_res_output)
+
+ return mlp_output, attn_res
+
+
+class FlashGemmaModel(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, causal: bool):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ FlashGemmaLayer(
+ prefix=f"{prefix}.layers.{layer_id}",
+ config=config,
+ weights=weights,
+ causal=causal,
+ rotary_emb=rotary_emb,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = GemmaFastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ adapter_data: Optional[torch.Tensor],
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, inputs_embeds.shape[0]
+ )
+ hidden_states = inputs_embeds
+
+ # Get rotary cos and sin for this forward
+        # Avoid re-indexing inside each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashGemmaForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, *, causal: bool = True):
+ super().__init__()
+
+ embed_norm = config.hidden_size**0.5
+ if not prefix:
+ prefix = "model"
+ else:
+ prefix = f"{prefix}.model"
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens", weights=weights
+ )
+ self.embed_tokens.weight *= embed_norm
+
+ self.model = FlashGemmaModel(
+ prefix=prefix, config=config, weights=weights, causal=causal
+ )
+ self.lm_head = SpeculativeHead.load(
+ prefix=(
+ f"{prefix}.embed_tokens"
+ if config.tie_word_embeddings
+ else f"{prefix}.lm_head"
+ ),
+ config=config,
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_embeds = self.embed_tokens(input_ids)
+ hidden_states = self.model(
+ input_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
new file mode 100644
index 00000000000..a6b53656ea5
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -0,0 +1,463 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from typing import Optional, List, Tuple
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+import habana_frameworks.torch as htorch
+
+
+def load_qkv(config, prefix: str, weights, head_size, num_heads):
+ if config.quantize == "gptq":
+ return _load_qkv_gptq(
+ config,
+ prefix,
+ weights,
+ )
+ else:
+ return _load_qkv(config, prefix, weights, head_size, num_heads)
+
+
+def _load_qkv_gptq(config, prefix: str, weights):
+ world_size = weights.process_group.size()
+ rank = weights.process_group.rank()
+
+ # Weights
+ weight = weights.get_weights_col_packed_qkv(
+ f"{prefix}.c_attn",
+ config.num_attention_heads,
+ config.num_attention_heads,
+ )
+
+ # Bias
+ slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
+ shape = slice_.get_shape()
+ total_size = shape[0]
+ assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
+ single_size = total_size // 3
+ assert single_size % world_size == 0
+ block_size = single_size // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ tensors = []
+ for i in range(3):
+ tensor = slice_[start + i * single_size : stop + i * single_size]
+ tensors.append(tensor)
+ bias = torch.cat(tensors, dim=0)
+ bias = bias.to(device=weights.device)
+
+ return TensorParallelColumnLinear(get_linear(weight, bias))
+
+
+def _load_qkv(config, prefix: str, weights, head_size, num_heads):
+ """Load QKV from a single, transposed matrix."""
+
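+    # GPT-2 packs q, k and v into a single Conv1D-style weight of shape
+    # [hidden_size, 3 * hidden_size] (transposed w.r.t. nn.Linear). Each of the
+    # three blocks is sliced along dim 1 for this tensor-parallel rank and the
+    # result is transposed back to the usual [out_features, in_features] layout.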
+ slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
+ shape = slice_.get_shape()
+ total_size = shape[1]
+ assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
+ world_size = weights.process_group.size()
+ single_size = total_size // 3
+ assert single_size % world_size == 0
+ rank = weights.process_group.rank()
+
+ # Weights
+ block_size = single_size // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ tensors = []
+ for i in range(3):
+ tensor = slice_[:, start + i * single_size : stop + i * single_size]
+ tensors.append(tensor)
+ weight = torch.cat(tensors, dim=1).T
+ weight = weight.to(dtype=weights.dtype)
+ weight = weight.to(device=weights.device)
+
+ # Bias
+ slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
+ shape = slice_.get_shape()
+ total_size = shape[0]
+ single_size = total_size // 3
+ block_size = single_size // world_size
+ assert single_size % world_size == 0
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ b = []
+ for i in range(3):
+ tensor = slice_[start + i * single_size : stop + i * single_size]
+ b.append(tensor)
+ bias = torch.cat(b, dim=0)
+ bias = bias.to(dtype=weights.dtype)
+ bias = bias.to(device=weights.device)
+    assert list(bias.shape) == [
+        3 * num_heads * head_size
+    ], f"{bias.shape} != {[3 * num_heads * head_size]}"
+
+ return TensorParallelColumnLinear(get_linear(weight, bias))
+
+
+def load_row(config, prefix: str, weights, bias: bool):
+ """load_row, but with transposed weight matrices."""
+
+ if config.quantize == "gptq":
+ weight = weights.get_weights_row(prefix)
+ else:
+ weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
+
+ if bias and weights.process_group.rank() == 0:
+        # The bias is only added on rank 0; the row-parallel all-reduce sums
+        # the partial outputs, so adding it once avoids double-counting.
+ bias = weights.get_tensor(f"{prefix}.bias")
+ else:
+ bias = None
+
+ return TensorParallelRowLinear(
+ get_linear(weight, bias), process_group=weights.process_group
+ )
+
+
+def load_col(config, prefix: str, weights, bias: bool):
+ """load_col, but with transposed weight matrices."""
+ if config.quantize == "gptq":
+ weight = weights.get_multi_weights_col([prefix], dim=1)
+ else:
+ weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
+
+ if bias:
+ bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+ else:
+ bias = None
+
+ return TensorParallelColumnLinear(get_linear(weight, bias))
+
+
+class FlashGPT2Attention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+
+ self.head_size = self.hidden_size // self.num_heads
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+
+ self.query_key_value = load_qkv(
+ config,
+ prefix=prefix,
+ weights=weights,
+ head_size=self.head_size,
+ num_heads=self.num_heads,
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.o_proj = load_row(
+ config,
+ prefix=f"{prefix}.c_proj",
+ weights=weights,
+ bias=True,
+ )
+
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_heads, dtype=torch.int32, device=weights.device
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ query, key, value = self.query_key_value(hidden_states).split(
+ self.head_size * self.num_heads, dim=1
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ key = key.view(-1, self.num_heads, self.head_size)
+ value = value.view(-1, self.num_heads, self.head_size)
+
+ kv_cache.store(
+ key=key,
+ value=value,
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=key,
+ value=value,
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class GPT2MLP(nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ act = config.activation_function
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
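+        # "gelu"-family activations are routed through F.gelu so that
+        # "gelu_fast" and "gelu_pytorch_tanh" use the tanh approximation and
+        # other gelu variants use the exact ("none") form; everything else is
+        # taken from ACT2FN.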
+
+ self.c_fc = load_col(
+ config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
+ )
+ self.c_proj = load_row(
+ config,
+ prefix=f"{prefix}.c_proj",
+ weights=weights,
+ bias=True,
+ )
+
+ intermediate_size = (
+ config.n_inner if config.n_inner is not None else 4 * config.hidden_size
+ )
+
+ self.intermediate_size = intermediate_size // weights.process_group.size()
+
+ def forward(self, hidden_states):
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ return self.c_proj(hidden_states)
+
+
+class FlashGPT2Layer(nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ self.self_attn = FlashGPT2Attention(
+ prefix=f"{prefix}.attn", config=config, weights=weights
+ )
+ self.mlp = GPT2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+
+ self.input_layernorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
+ )
+ self.post_attention_layernorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.ln_2",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ hidden_states,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states = attn_output + residual
+ residual = hidden_states
+
+ hidden_states = self.post_attention_layernorm(hidden_states)
+
+ mlp_output = self.mlp(hidden_states)
+
+ return residual + mlp_output, residual
+
+
+class FlashGPT2Model(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+ self.layers = nn.ModuleList(
+ [
+ FlashGPT2Layer(
+ prefix=(
+ f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}"
+ ),
+ config=config,
+ weights=weights,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+
+ self.norm = nn.LayerNorm.load(
+ prefix="ln_f" if not prefix else f"{prefix}.ln_f",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, inputs_embeds.shape[0]
+ )
+ hidden_states = inputs_embeds
+
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states = self.norm(hidden_states)
+
+ return hidden_states
+
+
+class FlashGPT2ForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=("wte" if not prefix else f"{prefix}.wte"),
+ weights=weights,
+ )
+ self.embed_positions = TensorParallelEmbedding(
+ prefix=("wpe" if not prefix else f"{prefix}.wpe"),
+ weights=weights,
+ )
+
+ self.model = FlashGPT2Model(prefix, config, weights)
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="wte" if not prefix else f"{prefix}.wte",
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ token_embeds = self.embed_tokens(input_ids)
+ position_embeds = self.embed_positions(position_ids)
+ inputs_embeds = token_embeds + position_embeds
+ hidden_states = self.model(
+ inputs_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
new file mode 100644
index 00000000000..1e7a867c520
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
@@ -0,0 +1,405 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from typing import Optional, List, Tuple
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+)
+from text_generation_server.layers.rotary import (
+ PositionRotaryEmbedding,
+)
+from text_generation_server.layers.layernorm import (
+ FastLayerNorm,
+)
+from habana_frameworks.torch.hpex.kernels import (
+ RotaryPosEmbeddingMode,
+ apply_rotary_pos_emb,
+)
+import habana_frameworks.torch as htorch
+
+
+def load_attention(config, prefix: str, weights):
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+
+
+def load_row(config, prefix: str, weights, bias: bool):
+ weight = weights.get_weights_row(prefix)
+
+ if bias and weights.process_group.rank() == 0:
+        # Only rank 0 loads the bias so it is added exactly once after the
+        # row-parallel all-reduce.
+ bias = weights.get_tensor(f"{prefix}.bias")
+ else:
+ bias = None
+
+ linear = get_linear(weight, bias)
+ return TensorParallelRowLinear(linear, process_group=weights.process_group)
+
+
+class GPTJRotary(PositionRotaryEmbedding):
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ ):
+ num_tokens = query.shape[0]
+ head_size = query.shape[-1]
+ rope_mode = RotaryPosEmbeddingMode.PAIRWISE
+ sin = torch.repeat_interleave(sin, 2, dim=-1)
+ cos = torch.repeat_interleave(cos, 2, dim=-1)
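+        # cos/sin are repeated along the last dim so both elements of each
+        # rotated (x_{2i}, x_{2i+1}) pair see the same angle (PAIRWISE mode).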
+ rotary_dim = cos.shape[-1]
+ query_shape = query.shape
+ query = query.view(num_tokens, -1, head_size)
+ query_rot = query[..., :rotary_dim]
+ query_pass = query[..., rotary_dim:]
+ query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+ query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
+
+ key_shape = key.shape
+ key = key.view(num_tokens, -1, head_size)
+ key_rot = key[..., :rotary_dim]
+ key_pass = key[..., rotary_dim:]
+ key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+ key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
+
+
+class FlashGPTJAttention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ rotary_emb,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+
+ self.head_size = self.hidden_size // self.num_heads
+ self.softmax_scale = self.head_size**-0.5
+ self.rotary_dim = config.rotary_dim
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+
+ self.query_key_value = load_attention(
+ config,
+ prefix=prefix,
+ weights=weights,
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.o_proj = load_row(
+ config,
+ prefix=f"{prefix}.out_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_heads, dtype=torch.int32, device=weights.device
+ )
+ self.rotary_emb = rotary_emb
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ query, key, value = self.query_key_value(hidden_states).split(
+ self.head_size * self.num_heads, dim=1
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ key = key.view(-1, self.num_heads, self.head_size)
+ value = value.view(-1, self.num_heads, self.head_size)
+
+ # Compute rotary embeddings on rotary_ndims
+ if self.rotary_dim is not None:
+ self.rotary_emb(
+ query[..., : self.rotary_dim], key[..., : self.rotary_dim], cos, sin
+ )
+ else:
+ self.rotary_emb(query, key, cos, sin)
+
+ kv_cache.store(
+ key=key,
+ value=value,
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=key,
+ value=value,
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class GPTJMLP(nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ act = config.activation_function
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+
+ self.fc_in = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.fc_in", weights=weights, bias=True
+ )
+
+ self.fc_out = load_row(
+ config,
+ prefix=f"{prefix}.fc_out",
+ weights=weights,
+ bias=True,
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.fc_in(hidden_states)
+ hidden_states = self.act(hidden_states)
+ return self.fc_out(hidden_states)
+
+
+class FlashGPTJLayer(nn.Module):
+ def __init__(self, prefix: str, config, weights, rotary_emb):
+ super().__init__()
+ self.self_attn = FlashGPTJAttention(
+ prefix=f"{prefix}.attn",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+ self.mlp = GPTJMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+
+ self.input_layernorm = FastLayerNorm.load(
+ prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+ # Self Attention
+ attn_output = self.self_attn(
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ feed_forward_hidden_states = self.mlp(hidden_states)
+
+ return attn_output + feed_forward_hidden_states, residual
+
+
+class FlashGPTJModel(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ self.config = config
+
+ self.wte = TensorParallelEmbedding(prefix=f"{prefix}.wte", weights=weights)
+ rotary_emb = GPTJRotary.static(
+ config=config,
+ dim=config.rotary_dim,
+ base=10000,
+ device=weights.device,
+ )
+ self.layers = nn.ModuleList(
+ [
+ FlashGPTJLayer(
+ prefix=(
+ f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}"
+ ),
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+
+ self.ln_f = FastLayerNorm.load(
+ prefix="ln_f" if not prefix else f"{prefix}.ln_f",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor],
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, input_ids.shape[0]
+ )
+ hidden_states = self.wte(input_ids)
+
+ # Get rotary cos and sin for this forward
+        # Avoid indexing in each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.ln_f(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashGPTJForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ if not prefix:
+ prefix = "transformer"
+ else:
+ prefix = f"{prefix}.transformer"
+ self.model = FlashGPTJModel(prefix, config, weights)
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="lm_head",
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
new file mode 100644
index 00000000000..1db8ad10e50
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
@@ -0,0 +1,1439 @@
+# coding=utf-8
+# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import math
+import torch.utils.checkpoint
+from torch import nn
+import torch.nn.functional as F
+
+import habana_frameworks.torch as htorch
+from transformers.cache_utils import Cache
+from transformers.activations import ACT2FN
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from transformers.modeling_outputs import (
+ BaseModelOutput,
+)
+
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ SpeculativeHead,
+ FastLinear,
+)
+from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers.attention import (
+ KVCache,
+ paged_attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+ FlashLlamaAttention,
+)
+
+
+def reshape_for_broadcast(freqs: torch.Tensor, target):
+ ndim = len(target)
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(target)]
+ return freqs.view(*shape)
+
+
+def apply_rotary_emb(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ freqs_ci: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ query_shape = query.shape
+ key_shape = key.shape
+ cos_emb, sin_emb = freqs_ci.split(1, dim=-1)
+
+ if len(query.shape) == 3:
+ query = query.unsqueeze(0)
+ key = key.unsqueeze(0)
+
+ query_reshaped = query.float().reshape(*query.shape[:-1], -1, 2)
+ key_reshaped = key.float().reshape(*key.shape[:-1], -1, 2)
+ q_shape = query_reshaped.shape[:-1]
+ cos_emb = reshape_for_broadcast(cos_emb, q_shape)
+ sin_emb = reshape_for_broadcast(sin_emb, q_shape)
+ x_q, y_q = query_reshaped.unbind(-1)
+ x_k, y_k = key_reshaped.unbind(-1)
+
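+    # Standard 2D rotation of each (x, y) pair:
+    #   x_rot = x * cos - y * sin,  y_rot = x * sin + y * cos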
+ x_q_rot = x_q * cos_emb - y_q * sin_emb
+ y_q_rot = x_q * sin_emb + y_q * cos_emb
+ x_k_rot = x_k * cos_emb - y_k * sin_emb
+ y_k_rot = x_k * sin_emb + y_k * cos_emb
+
+ query_out = torch.stack([x_q_rot, y_q_rot], dim=-1).flatten(-2)
+ key_out = torch.stack([x_k_rot, y_k_rot], dim=-1).flatten(-2)
+ query_out = query_out.view(*query_shape)
+ key_out = key_out.view(*key_shape)
+ return query_out.type_as(query), key_out.type_as(key)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(
+ batch, num_key_value_heads, n_rep, slen, head_dim
+ )
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class Llama4TextExperts(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.process_group = weights.process_group
+ self.num_experts = config.num_local_experts
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+ self.hidden_size = config.hidden_size
+ self.expert_dim = self.intermediate_size
+ self.gate_up_proj = nn.Parameter(
+ weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2),
+ requires_grad=False,
+ )
+ self.down_proj = nn.Parameter(
+ weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False
+ )
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """
+        Batched expert computation. On a single device this is compute bound:
+        - the inputs are expected to be already grouped ("sorted") per expert.
+        - the packed weights are viewed with a leading num_experts dimension so
+          a single batched matmul covers all experts at once.
+
+        Args:
+            hidden_states (torch.Tensor): (num_experts * num_tokens, hidden_size)
+        Returns:
+            torch.Tensor: (num_experts * num_tokens, hidden_size)
+ """
+ gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2 * self.expert_dim)
+ down_proj = self.down_proj.view(self.num_experts, self.expert_dim, -1)
+ hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
+ gate_up = torch.bmm(hidden_states, gate_up_proj)
+ gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors
+ next_states = torch.bmm((up * self.act_fn(gate)), down_proj)
+ next_states = next_states.view(-1, self.hidden_size)
+
+ # Reduce sum
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(next_states, group=self.process_group)
+
+ return next_states
+
+
+# Phi3MLP
+class Llama4TextMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config=config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config=config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ gate_up_states = self.gate_up_proj(x)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1])
+
+
+class Llama4TextL2Norm(torch.nn.Module):
+ def __init__(self, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ return self._norm(x.float()).type_as(x)
+
+ def extra_repr(self):
+ return f"eps={self.eps}"
+
+
+class Llama4TextMoe(nn.Module):
+ def __init__(
+ self,
+ prefix,
+ config,
+ weights,
+ ):
+ super().__init__()
+ self.top_k = config.num_experts_per_tok
+ self.hidden_dim = config.hidden_size
+ self.num_experts = config.num_local_experts
+ self.experts = Llama4TextExperts(
+ config=config, prefix=f"{prefix}.experts", weights=weights
+ )
+ self.router = FastLinear.load(
+ config=config, prefix=f"{prefix}.router", weights=weights, bias=False
+ )
+ self.shared_expert = Llama4TextMLP(
+ config=config, prefix=f"{prefix}.shared_expert", weights=weights
+ )
+ self.process_group = weights.process_group
+
+ def forward(self, hidden_states, adapter_data):
+ seq_len, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, self.hidden_dim)
+ tokens_per_expert = hidden_states.shape[0]
+ router_logits = self.router(hidden_states)
+
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
+ router_scores = (
+ torch.full_like(router_logits, float("-inf"))
+ .scatter_(1, router_indices, router_top_value)
+ .transpose(0, 1)
+ )
+        # Non top-k entries are set to -inf so that the sigmoid below maps
+        # them to ~0 before the expert outputs are weighted.
+        # router_indices simply enumerates every hidden state; it could be
+        # registered as a buffer instead of being rebuilt on every forward.
+ router_indices = (
+ torch.arange(tokens_per_expert, device=hidden_states.device)
+ .view(1, -1)
+ .expand(router_scores.size(0), -1)
+ )
+ router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
+
+ router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim)
+ routed_in = torch.gather(
+ input=hidden_states,
+ dim=0,
+ index=router_indices,
+ ).to(hidden_states.device)
+
+ # we gather inputs corresponding to each expert based on the router indices
+ routed_in = routed_in * router_scores.reshape(-1, 1)
+ routed_out = self.experts(routed_in)
+ out = self.shared_expert(hidden_states)
+
+        # After the expert computation, scatter-add the routed outputs back to
+        # their original token positions (the inverse of the gather above).
+        # Every expert was run on every token; this is faster than a Python
+        # loop but compute bound, and it scales much better with expert
+        # parallelism (EP).
+ out.scatter_add_(
+ dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim)
+ )
+ return out
+
+
+class Llama4TextRotaryEmbedding(nn.Module):
+ def __init__(self, config, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ self.rope_type = "llama3" if config.rope_scaling is not None else "default"
+
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ def forward(self, x, position_ids):
+ inv_freq_expanded = (
+ self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ )
+ position_ids_expanded = position_ids[:, None, :].float()
+ device_type = (
+ x.device.type
+ if isinstance(x.device.type, str) and x.device.type != "mps"
+ else "cpu"
+ )
+ inv_freq_expanded = inv_freq_expanded.to(device_type)
+ position_ids_expanded = position_ids_expanded.to(device_type)
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
+ freqs_cis = (
+ torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
+ * self.attention_scaling
+ )
+ return freqs_cis.to(dtype=x.dtype, device=x.device)
+
+
+class Llama4TextAttention(FlashLlamaAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, prefix, config, weights, layer_idx):
+ super().__init__(layer_idx, prefix, config, weights, None)
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(
+ config, "head_dim", config.hidden_size // config.num_attention_heads
+ )
+ self.num_key_value_groups = (
+ config.num_attention_heads // config.num_key_value_heads
+ )
+ self.scaling = self.head_dim**-0.5
+ self.attn_scale = config.attn_scale
+ self.floor_scale = config.floor_scale
+ self.attn_temperature_tuning = config.attn_temperature_tuning
+ self.attention_dropout = config.attention_dropout
+ self.use_rope = int((layer_idx + 1) % 4 != 0) # rope unused for dense layers
+
+ if self.config.use_qk_norm and self.use_rope:
+ self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ freqs_ci,
+ cu_seqlen_prefill,
+ kv_cache: KVCache,
+ slots,
+ seqlen,
+ adapter_data,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bs = seqlen.input_lengths.shape[0]
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+ qkv = self.query_key_value(hidden_states, adapter_data)
+ query_states, key_states, value_states = qkv.split(
+ [
+ self.head_dim * self.num_heads,
+ self.head_dim * self.num_key_value_heads,
+ self.head_dim * self.num_key_value_heads,
+ ],
+ dim=-1,
+ )
+
+ query_states = query_states.view(hidden_shape)
+ key_states = key_states.view(hidden_shape)
+ value_states = value_states.view(hidden_shape)
+
+ if self.use_rope: # the 16E model skips rope for long context on certain layers
+ query_states, key_states = apply_rotary_emb(
+ query_states, key_states, freqs_ci
+ )
+
+ if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm
+ query_states = self.qk_norm(query_states)
+ key_states = self.qk_norm(key_states)
+
+ kv_cache.store(
+ key=key_states,
+ value=value_states,
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+        # Apply temperature tuning (https://arxiv.org/abs/2501.19399) to the NoRoPE layers
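+        # The scale grows roughly logarithmically with position, boosting
+        # query magnitude at long context on the NoRoPE (global) layers.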
+ if self.attn_temperature_tuning and not self.use_rope:
+ attn_scales = (
+ torch.log(
+ torch.floor((position_ids.float() + 1.0) / self.floor_scale) + 1.0
+ )
+ * self.attn_scale
+ + 1.0
+ )
+ attn_scales = attn_scales.view(*input_shape, 1, 1)
+ query_states = (query_states * attn_scales).to(query_states.dtype)
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(
+ 1, 2
+ )
+ key = key_states.view(
+ bs, -1, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+ value = value_states.view(
+ bs, -1, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+ key = repeat_kv(key, self.num_key_value_groups)
+ value = repeat_kv(value, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None and causal_mask.ndim == 4:
+ causal_mask = causal_mask[:, :, :, : key.shape[-2]]
+ is_causal = query.shape[2] > 1 and causal_mask is None
+ # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query,
+ key,
+ value,
+ attn_mask=causal_mask,
+ dropout_p=0,
+ scale=self.scaling,
+ is_causal=is_causal,
+ )
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query_states,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output, adapter_data)
+ return attn_output
+
+
+class Llama4TextDecoderLayer(nn.Module):
+ def __init__(self, prefix, config, weights, layer_idx):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = Llama4TextAttention(
+ f"{prefix}.self_attn", config, weights, layer_idx
+ )
+ self.use_chunked_attention = int((layer_idx + 1) % 4 != 0) # <=> use rope
+ self.is_moe_layer = layer_idx in config.moe_layers
+ if self.is_moe_layer: # the 128E model interleaves dense / sparse
+ self.feed_forward = Llama4TextMoe(f"{prefix}.feed_forward", config, weights)
+ else:
+ self.feed_forward = Llama4TextMLP(f"{prefix}.feed_forward", config, weights)
+
+ self.input_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.post_attention_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ freqs_ci,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ attention_mask: Optional[torch.Tensor] = None,
+ chunk_causal_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
+ ) -> Tuple[
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+ ]:
+ residual = hidden_states
+ hidden_states, _ = self.input_layernorm(hidden_states)
+
+ # use local attention mask for ROPE layers
+ if self.use_chunked_attention and chunk_causal_mask is not None:
+ attention_mask = chunk_causal_mask
+
+ attention_states = self.self_attn(
+ hidden_states,
+ freqs_ci,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ hidden_states = residual + attention_states
+
+ # Fully Connected
+ residual = hidden_states
+
+ hidden_states, _ = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.feed_forward(hidden_states, adapter_data)
+ hidden_states = residual + hidden_states.view(residual.shape)
+ return hidden_states
+
+
+class Llama4TextModel(nn.Module):
+
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens", weights=weights
+ )
+ self.layers = nn.ModuleList(
+ [
+ Llama4TextDecoderLayer(
+ prefix=f"{prefix}.layers.{layer_idx}",
+ config=config,
+ weights=weights,
+ layer_idx=layer_idx,
+ )
+ for layer_idx in range(config.num_hidden_layers)
+ ]
+ )
+
+ # self.norm = Llama4TextRMSNorm(prefix=f"{prefix}.norm", config=config, weights=weights)
+ self.norm = FastRMSNorm.load(
+ prefix=f"{prefix}.norm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ adapter_data,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, inputs_embeds.shape[0]
+ )
+
+ hidden_states = inputs_embeds
+ bs = seqlen.input_lengths.shape[0]
+        seq_len = inputs_embeds.shape[0] // bs
+ cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask, chunk_causal_mask = self._update_causal_mask(
+ attention_mask,
+ inputs_embeds.view(bs, int(seq_len), -1),
+ cache_position,
+ None,
+ output_attentions=False,
+ use_cache=False,
+ )
+
+ freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1))
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ for i, layer in enumerate(self.layers):
+ hidden_states = layer(
+ hidden_states,
+ freqs_ci,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ adapter_data,
+ attention_mask=causal_mask,
+ chunk_causal_mask=chunk_causal_mask,
+ position_ids=position_ids,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.norm(hidden_states)
+
+ return hidden_states
+
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool = False,
+ chunked_attention_mask=None,
+ use_cache=True,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return (
+ attention_mask,
+ attention_mask,
+ ) # flash does not support chunked attn TODO support flash
+ return None, None
+
+ if self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]:
+ return None, None
+
+ sequence_length = input_tensor.shape[1]
+ attention_chunk_size = self.config.attention_chunk_size
+
+ first_cache_position = cache_position[0]
+
+ if past_key_values is not None:
+ full_cache_length = past_key_values.get_max_cache_shape() or sequence_length
+ else:
+ full_cache_length = (
+ attention_mask.shape[-1]
+ if attention_mask is not None
+ else sequence_length
+ )
+
+ cond1 = first_cache_position >= attention_chunk_size
+ cond2 = (first_cache_position < attention_chunk_size) & (
+ first_cache_position + sequence_length > attention_chunk_size
+ )
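+        # cond1: the query window starts at or past the first chunk boundary;
+        # cond2: it starts inside the first chunk but crosses the boundary.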
+ key_length = (
+ torch.where(
+ cond1,
+ attention_chunk_size + sequence_length - 1,
+ torch.where(
+ cond2, first_cache_position + sequence_length, attention_chunk_size
+ ),
+ )
+ if use_cache
+ else full_cache_length
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ dtype, device = input_tensor.dtype, input_tensor.device
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=max(full_cache_length, attention_chunk_size),
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ device=device,
+ )
+ if full_cache_length > self.config.attention_chunk_size:
+ start_idx = max(first_cache_position - attention_chunk_size + 1, 0)
+ end_idx = start_idx + key_length
+ chunked_attention_mask = self.create_chunked_attention_mask(
+ self.config.attention_chunk_size,
+ start=start_idx, # same offset as with flex
+ end=end_idx,
+ device=device,
+ )
+
+ local_attention_mask = attention_mask[
+ :, start_idx:end_idx
+ ] # offset here as well
+ # It may be smaller than attention_chunk_size -> pad it
+ requires_padding = local_attention_mask.shape[-1] < attention_chunk_size
+ if requires_padding:
+ local_attention_mask = nn.functional.pad(
+ local_attention_mask,
+ (0, attention_chunk_size - local_attention_mask.shape[-1]),
+ )
+ # Depending on the padding, take the query tokens from the end or the cache_position
+ if not requires_padding:
+ chunked_attention_mask = chunked_attention_mask[
+ None, None, -sequence_length:, :
+ ]
+ else:
+ chunked_attention_mask = chunked_attention_mask[
+ None, None, cache_position, :
+ ]
+
+ chunked_attention_mask = chunked_attention_mask.expand(
+ input_tensor.shape[0], -1, -1, -1
+ )
+ chunked_attention_mask = (
+ chunked_attention_mask * local_attention_mask[:, None, None, :]
+ )
+ if self.config._attn_implementation == "eager":
+ min_dtype = torch.finfo(dtype).min
+ chunked_attention_mask = torch.where(
+ chunked_attention_mask == 0, min_dtype, 0.0
+ ).to(dtype)
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ and attention_mask.ndim == 4
+ and not output_attentions # Only unmask for 4d masks
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(
+ causal_mask, min_dtype
+ )
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if (
+ self.config._attn_implementation == "sdpa"
+ and chunked_attention_mask is not None
+ ):
+ chunked_attention_mask = chunked_attention_mask.bool()
+ causal_mask = causal_mask.bool()
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=first_cache_position,
+ is_training=self.training,
+ ):
+ causal_mask = None
+ return causal_mask, chunked_attention_mask
+
+ def create_chunked_attention_mask(
+ self, attention_chunk_size: int, start: int, end: int, device: torch.device
+ ) -> torch.Tensor:
+ """
+ Generate the following:
+
+ 'What' : 0 ■ ⬚ ⬚ ⬚ ⬚ ⬚ |
+ '▁is' : 1 ■ ■ ⬚ ⬚ ⬚ ⬚ |
+ '▁ch' : 2 ■ ■ ■ ⬚ ⬚ ⬚ |
+ 'unked' : 3 ⬚ ⬚ ⬚ ■ ⬚ ⬚ |
+ '▁attention': 4 ⬚ ⬚ ⬚ ■ ■ ⬚ |
+ '?' : 5 ⬚ ⬚ ⬚ ■ ■ ■ |
+
+        for a chunk size of 3. The result can simply be combined with the
+        already created causal attention mask.
+ """
+ arange_vector = torch.arange(start, end, device=device)
+ block_pos = torch.abs(
+ arange_vector.unsqueeze(0) // attention_chunk_size
+ - arange_vector.unsqueeze(1) // attention_chunk_size
+ )
+ token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1)
+ mask = (block_pos == 0) & (token_pos <= 0)
+ return mask.to(device)
+
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+                The target length: when generating with a static cache, the mask should be as long as the static
+                cache, to account for the zero padding (the part of the cache that is not yet filled).
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ device (`torch.device`):
+ The device to place the 4D attention mask on.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length),
+ fill_value=min_dtype,
+ dtype=dtype,
+ device=device,
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(
+ target_length, device=device
+ ) > cache_position.to(device).reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = (
+ causal_mask.clone()
+ ) # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
+ :, None, None, :
+ ].to(device)
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[
+ :, :, :, :mask_length
+ ].masked_fill(padding_mask, min_dtype)
+
+ return causal_mask
+
+
+class Llama4ForCausalLM(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.model = Llama4TextModel(
+ prefix=f"{prefix}.model", config=config, weights=weights
+ )
+ self.vocab_size = config.vocab_size
+ self.lm_head = SpeculativeHead.load(
+ config,
+ f"{prefix}.lm_head",
+ weights,
+ )
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ adapter_data: Optional[torch.Tensor] = None,
+ lm_head_indices: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ hidden_states = self.model(
+ inputs_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data=adapter_data,
+ hpu_attention_meta=hpu_attention_meta,
+ attention_mask=attention_mask,
+ )
+
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
+
+
+class Llama4VisionMLP2(torch.nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.fc1 = TensorParallelColumnLinear.load(
+ config=config, prefix=f"{prefix}.fc1", weights=weights, bias=False
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ config=config, prefix=f"{prefix}.fc2", weights=weights, bias=False
+ )
+ self.activation_fn = nn.GELU() # ACT2FN[config.hidden_act]
+ self.dropout = config.projector_dropout
+
+ def forward(self, hidden_states):
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ return self.activation_fn(
+ hidden_states
+ ) # TODO: check if we need to apply activation again
+
+
+class Llama4MultiModalProjector(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.linear_1 = FastLinear.load(
+ config=config, prefix=f"{prefix}.linear_1", weights=weights, bias=False
+ )
+
+ def forward(self, image_features):
+ hidden_states = self.linear_1(image_features)
+ return hidden_states
+
+
+def pixel_shuffle(input_tensor, shuffle_ratio):
+ # input_tensor: [batch_size, num_patches, channels]
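+    # With shuffle_ratio r (typically < 1 here), (batch, P, C) becomes
+    # (batch, P * r**2, C / r**2): spatial patches are traded for channel depth.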
+ batch_size, num_patches, channels = input_tensor.shape
+ patch_size = int(math.sqrt(num_patches))
+
+ input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
+ batch_size, height, width, channels = input_tensor.size()
+ reshaped_tensor = input_tensor.view(
+ batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
+ )
+ reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
+ reshaped_tensor = reshaped_tensor.view(
+ batch_size,
+ int(height * shuffle_ratio),
+ int(width * shuffle_ratio),
+ int(channels / (shuffle_ratio**2)),
+ )
+ reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
+
+ output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])
+ return output_tensor
+
+
+class Llama4VisionPixelShuffleMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
+ self.inner_dim = int(
+ config.projector_input_dim // (self.pixel_shuffle_ratio**2)
+ )
+ self.output_dim = config.projector_output_dim
+ self.mlp = Llama4VisionMLP2(
+ prefix=f"{prefix}.mlp", config=config, weights=weights
+ )
+
+ def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
+ encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
+ return self.mlp(encoded_patches)
+
+
+# TODO: the vision encoder uses a different RoPE, defined below
+def vision_reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor):
+ ndim = query.ndim
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(query.shape)]
+ return freqs_ci.view(*shape)
+
+
+class Llama4VisionAttention(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads // weights.process_group.size()
+        self.process_group = weights.process_group
+
+ self.head_dim = config.hidden_size // config.num_attention_heads
+ self.num_key_value_groups = 1
+ self.attention_dropout = config.attention_dropout
+ self.qkv_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=True,
+ )
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=True,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ freqs_ci: torch.Tensor, # Now takes (cos_theta, sin_theta) instead of complex
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ qkv = self.qkv_proj(hidden_states)
+
+ query_states, key_states, value_states = qkv.split(
+ [
+ self.head_dim * self.num_heads,
+ self.head_dim * self.num_heads,
+ self.head_dim * self.num_heads,
+ ],
+ dim=2,
+ )
+ query_states = query_states.view(hidden_shape)
+ key_states = key_states.view(hidden_shape)
+ value_states = value_states.view(hidden_shape)
+
+ query_states, key_states = apply_rotary_emb(
+ query_states, key_states, freqs_ci=freqs_ci
+ )
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = F.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ is_causal=False,
+ dropout_p=0,
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output
+
+
+class Llama4VisionMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.activation_fn = nn.GELU() # ACT2FN[config.hidden_act]
+ self.fc1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Llama4VisionEncoderLayer(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = Llama4VisionAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.mlp = Llama4VisionMLP(
+ prefix=f"{prefix}.mlp", config=config, weights=weights
+ )
+
+ self.input_layernorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05
+ )
+ self.post_attention_layernorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05
+ )
+
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ freqs_ci: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ # Self Attention
+ residual = hidden_state
+
+ hidden_state = self.input_layernorm(hidden_state)
+
+ hidden_state = self.self_attn(
+ hidden_state,
+ freqs_ci=freqs_ci,
+ attention_mask=attention_mask,
+ )
+
+ hidden_state = residual + hidden_state
+
+ # Feed forward
+ residual = hidden_state
+ hidden_state = self.post_attention_layernorm(hidden_state)
+ hidden_state = self.mlp(hidden_state)
+ hidden_state = residual + hidden_state
+ outputs = (hidden_state,)
+ return outputs
+
+
+class Llama4VisionEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`Llama4VisionEncoderLayer`].
+
+ Args:
+ config: Llama4VisionConfig
+ """
+
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList(
+ [
+ Llama4VisionEncoderLayer(
+ prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.gradient_checkpointing = False
+ self.config = config
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ freqs_ci: torch.Tensor, # TODO move this to an attribute instead of keeping it around
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+
+ for encoder_layer in self.layers:
+ layer_outputs = encoder_layer(
+ hidden_state=hidden_states,
+ attention_mask=attention_mask,
+ freqs_ci=freqs_ci,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ return hidden_states
+
+
+class Llama4UnfoldConvolution(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ kernel_size = config.patch_size
+ if isinstance(kernel_size, int):
+ kernel_size = (kernel_size, kernel_size)
+ self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
+ self.linear = FastLinear.load(
+ config=config, prefix=f"{prefix}.linear", weights=weights, bias=False
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.unfold(hidden_states)
+ hidden_states = hidden_states.permute(0, 2, 1)
+ hidden_states = self.linear(hidden_states)
+ return hidden_states
+
+
+class Llama4VisionRotaryEmbedding(nn.Module):
+ def __init__(self, config, weights):
+ super().__init__()
+ # Calculate image grid indices
+ idx = config.image_size // config.patch_size
+ img_idx = torch.arange(
+ idx**2, dtype=torch.int32, device=weights.device
+ ).reshape(idx**2, 1)
+ img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
+
+ img_idx[-1, -1] = -2 # ID_CLS_TOKEN
+ # Calculate x and y coordinates
+ frequencies_x = img_idx % idx # x coordinates
+ frequencies_y = torch.div(img_idx, idx, rounding_mode="floor") # y coordinates
+ # Calculate frequency components
+ freq_dim = config.hidden_size // config.num_attention_heads // 2
+ rope_freq = 1.0 / (
+ config.rope_theta
+ ** (
+ torch.arange(0, freq_dim, 2, device=weights.device)[
+ : (freq_dim // 2)
+ ].float()
+ / freq_dim
+ )
+ )
+
+ # Compute frequencies for x and y directions
+ freqs_x = (frequencies_x + 1)[..., None] * rope_freq[None, None, :]
+ freqs_x = freqs_x.repeat_interleave(2, dim=-1)
+ freqs_y = (frequencies_y + 1)[..., None] * rope_freq[None, None, :]
+ freqs_y = freqs_y.repeat_interleave(2, dim=-1)
+
+ # Combine frequencies and mask special tokens
+ freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
+ freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
+
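+        # Precompute the (cos, sin) table once at load time; forward() only casts it to the input dtype/device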
+ freq_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
+ self.freqs_ci = freq_cis # idx**2, idx**2, idx * 2
+
+ def forward(self, hidden_states):
+ """
+ Returns the rotary embedding components (cosθ, sinθ) for the given hidden states
+ """
+ return self.freqs_ci.to(dtype=hidden_states.dtype, device=hidden_states.device)
+
+
+class Llama4VisionModel(nn.Module):
+
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+ self.hidden_size = config.hidden_size
+ self.num_channels = config.num_channels
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
+ self.scale = config.hidden_size**-0.5
+
+ self.patch_embedding = Llama4UnfoldConvolution(
+ prefix=f"{prefix}.patch_embedding", config=config, weights=weights
+ )
+
+ self.class_embedding = nn.Parameter(
+ weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False
+ )
+
+ self.positional_embedding_vlm = nn.Parameter(
+ weights.get_tensor(f"{prefix}.positional_embedding_vlm"),
+ requires_grad=False,
+ )
+
+ self.rotary_embedding = Llama4VisionRotaryEmbedding(config, weights)
+
+ # layer norms
+ self.layernorm_pre = nn.LayerNorm.load(
+ prefix=f"{prefix}.layernorm_pre", weights=weights, eps=config.norm_eps
+ )
+ self.layernorm_post = nn.LayerNorm.load(
+ prefix=f"{prefix}.layernorm_post", weights=weights, eps=config.norm_eps
+ )
+
+ # encoders
+ self.model = Llama4VisionEncoder(
+ prefix=f"{prefix}.model", config=config, weights=weights
+ )
+ self.vision_adapter = Llama4VisionPixelShuffleMLP(
+ prefix=f"{prefix}.vision_adapter", config=config, weights=weights
+ )
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ # num_concurrent_media and num_chunks are both currently 1
+ batch_size_times_num_tiles, num_channels, height, width = pixel_values.shape
+ num_concurrent_media = 1
+ num_chunks = 1
+ hidden_state = self.patch_embedding(pixel_values)
+ _, num_patches, hidden_dim = hidden_state.shape
+
+ # Add cls token
+ hidden_state = hidden_state.reshape(
+ batch_size_times_num_tiles * num_concurrent_media * num_chunks,
+ num_patches,
+ hidden_dim,
+ )
+ class_embedding = self.class_embedding.expand(
+ hidden_state.shape[0], 1, hidden_state.shape[-1]
+ )
+ hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
+ num_patches += 1
+
+ # Position embeddings
+ hidden_state = hidden_state.reshape(
+ batch_size_times_num_tiles * num_concurrent_media,
+ num_chunks,
+ num_patches,
+ hidden_dim,
+ )
+ positional_embedding = self.positional_embedding_vlm.to(
+ dtype=hidden_state.dtype, device=hidden_state.device
+ )
+ hidden_state = hidden_state + positional_embedding
+ hidden_state = self.layernorm_pre(hidden_state)
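+        # Merge chunks and patches into a single token sequence per tile before the encoder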
+ hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim)
+ freqs_ci = self.rotary_embedding(pixel_values)
+
+ hidden_state = self.model(
+ hidden_state,
+ attention_mask=None,
+ freqs_ci=freqs_ci,
+ )
+
+ hidden_state = self.layernorm_post(hidden_state)
+
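+        # Drop the class token that was appended at the end of the patch sequence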
+ hidden_state = hidden_state[:, :-1, :]
+
+ # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
+ hidden_state = self.vision_adapter(hidden_state)
+ return hidden_state
+
+
+class Llama4ForConditionalGeneration(nn.Module):
+
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ self.config = config
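+        # The vision tower stays unquantized; quantize/speculator settings are propagated to the sub-configs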
+ config.vision_config.quantize = None
+ config.vision_config.speculator = config.speculator
+ config.text_config.quantize = config.quantize
+ config.text_config.speculator = config.speculator
+ config.text_config._attn_implementation = None
+
+ self.vision_model = Llama4VisionModel(
+ prefix="vision_model", config=config.vision_config, weights=weights
+ )
+
+ self.multi_modal_projector = Llama4MultiModalProjector(
+ prefix="multi_modal_projector", config=config, weights=weights
+ )
+
+ self.text_model = Llama4ForCausalLM(
+ prefix="language_model", config=config.text_config, weights=weights
+ )
+ self.vocab_size = config.text_config.vocab_size
+ self.pad_token_id = (
+ self.config.pad_token_id if self.config.pad_token_id is not None else -1
+ )
+ self.config = config
+ self.dtype = weights.dtype
+ self.device = weights.device
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ vision_feature_layer: Union[int, List[int]],
+ vision_feature_select_strategy: str,
+ **kwargs,
+ ):
+ """
+        Obtains the image's last hidden states from the vision tower; the multimodal projection is applied separately in `get_vision_embeds`.
+
+ Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+ vision_feature_layer (`Union[int, List[int]]`):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`
+ Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ if vision_feature_select_strategy not in ["default", "full"]:
+ raise ValueError(
+                f"Unexpected select feature strategy: {vision_feature_select_strategy}"
+ )
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
+ hidden_state = self.vision_model(pixel_values)
+ return hidden_state
+
+ def get_vision_embeds(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ image_features = self.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=self.config.vision_config.vision_feature_layer,
+ vision_feature_select_strategy=self.config.vision_config.vision_feature_select_strategy,
+ image_sizes=image_sizes,
+ )
+ vision_flat = image_features.view(-1, image_features.size(-1))
+ image_features = self.multi_modal_projector(vision_flat)
+ return image_features
+
+ def get_inputs_embeds(
+ self,
+ input_ids: torch.Tensor,
+ vision_embeds: torch.Tensor = None,
+ pixel_values: torch.FloatTensor = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ ):
+ inputs_embeds = self.text_model.model.embed_tokens(input_ids)
+
+ if vision_embeds is not None:
+            # During generation we must not overwrite image_token_ids that the model itself produced,
+            # since they have no corresponding image.
+ original_inputs_embeds_shape = inputs_embeds.shape
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
+ -1
+ )
+ final_mask = special_image_mask.to(inputs_embeds.device)
+ inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
+
+ final_mask_1d = final_mask[..., 0].reshape(-1)
+ num_tokens_to_fill = final_mask_1d.sum()
+
+ if num_tokens_to_fill != vision_embeds.size(0):
+ raise ValueError(
+ f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
+ f"but multi_modal_projector returned {vision_embeds.size(0)}"
+ )
+
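+            # Broadcast the image-token mask over the hidden dimension and scatter the vision embeddings into those positions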
+ expanded_mask = final_mask_1d.unsqueeze(-1).expand(
+ -1, inputs_embeds.size(-1)
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
+ inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
+ return inputs_embeds
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: Optional[torch.LongTensor] = None,
+ cu_seqlen_prefill: Optional[torch.Tensor] = None,
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
+ slots: torch.Tensor = None,
+ seqlen: Seqlen = None,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
+ lm_head_indices: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ **lm_kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+ logits, speculative_logits = self.text_model(
+ inputs_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ adapter_data,
+ lm_head_indices,
+ attention_mask,
+ )
+
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
new file mode 100644
index 00000000000..5ec10f86eb0
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -0,0 +1,675 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import contextmanager
+from typing import List, Optional, Tuple, Type
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+import habana_frameworks.torch as htorch
+from text_generation_server.layers.attention import (
+ KVCache,
+ get_kv_scales,
+)
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ TensorParallelMultiAdapterLinear,
+ TensorParallelAdapterRowLinear,
+)
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
+ FastRMSNorm,
+ FastLayerNorm,
+)
+from text_generation_server.layers import (
+ FastLinear,
+)
+from text_generation_server.utils.weights import (
+ Weights,
+)
+from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
+
+
+def load_attention(config, prefix: str, weights, layer_id):
+ # Only defined in granite.
+ bias = getattr(config, "attention_bias", False)
+ head_size = config.hidden_size // config.num_attention_heads
+ sizes = None
+ prefixes = None
+
+ if config.model_type == "phi3":
+ base_layer = TensorParallelColumnLinear.load_qkv(
+ config,
+ prefix=f"{prefix}.qkv_proj",
+ weights=weights,
+ bias=bias,
+ num_heads=config.num_attention_heads,
+ num_key_value_heads=config.num_key_value_heads,
+ )
+ prefixes = ["qkv_proj"]
+ elif config.model_type == "baichuan":
+ prefix = f"{prefix}.W_pack"
+ base_layer = TensorParallelColumnLinear.load_qkv(
+ config,
+ prefix=prefix,
+ weights=weights,
+ bias=bias,
+ num_heads=config.num_attention_heads,
+ num_key_value_heads=config.num_key_value_heads,
+ )
+ prefixes = [prefix]
+ else:
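+        # Regular Llama-style checkpoints keep separate q/k/v projections; fuse them into one column-parallel linear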
+ prefixes = ["q_proj", "k_proj", "v_proj"]
+ sizes = [
+ head_size * config.num_attention_heads,
+ head_size * config.num_key_value_heads,
+ head_size * config.num_key_value_heads,
+ ]
+ base_layer = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=bias,
+ )
+
+ return TensorParallelMultiAdapterLinear.load(
+ base_layer=base_layer,
+ layer_id=layer_id,
+ layer_names=prefixes,
+ sizes=sizes,
+ process_group=weights.process_group,
+ )
+
+
+@contextmanager
+def no_fp8(weights: Weights):
+    """Deactivate fp8 auto-conversion for the duration of this context manager."""
+ weights_loader = weights.weights_loader
+ if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8:
+ weights_loader = HybridFP8UnquantLoader(
+ weights_loader.activation_scale_ub, to_fp8=False
+ )
+
+ with weights.use_loader(weights_loader):
+ yield
+
+
+class FlashLlamaAttention(torch.nn.Module):
+ def __init__(
+ self,
+ index: int,
+ prefix: str,
+ config,
+ weights,
+ rotary_emb,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_heads
+
+ self.rotary_emb = rotary_emb
+
+ # `config.attention_multiplier` is used in Granite
+ self.softmax_scale = getattr(
+ config, "attention_multiplier", self.head_size**-0.5
+ )
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()})"
+ )
+ if config.num_key_value_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} "
+                f"and `num_shards`: {weights.process_group.size()})"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights, index)
+ self.index = index
+
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=getattr(config, "attention_bias", False),
+ )
+
+ self.o_proj = TensorParallelAdapterRowLinear.load(
+ o_proj,
+ index,
+ "o_proj",
+ process_group=weights.process_group,
+ )
+
+ self.num_groups = self.num_heads // self.num_key_value_heads
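+        # Grouped-query attention: map each query-head group to the KV head it shares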
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache: KVCache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ):
+ qkv = self.query_key_value(hidden_states, adapter_data)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_scales=self.kv_scales,
+ kv_cache=kv_cache,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(
+ attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+ )
+
+
+class Phi3MoE(nn.Module):
+ def __init__(
+ self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights: Weights
+ ):
+ super().__init__()
+
+ # gating
+ self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
+ self.moe = moe_layer_cls(
+ prefix=f"{prefix}.experts",
+ n_experts=config.num_local_experts,
+ n_expert_group=None,
+ renormalize=True,
+ topk=config.num_experts_per_tok,
+ topk_group=None,
+ weights=weights,
+ gate_proj_name="w1",
+ up_proj_name="w3",
+ down_proj_name="w2",
+ )
+
+ self.process_group = weights.process_group
+
+ def forward(self, x, adapter_data) -> torch.Tensor:
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(x)
+ out = self.moe(x, gating_output=router_logits)
+
+ # Reduce sum
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(out, group=self.process_group)
+
+ return out.view(*x.shape)
+
+
+class LlamaMLP(nn.Module):
+ def __init__(self, prefix, config, weights, index):
+ super().__init__()
+ self.hidden_act = config.hidden_act
+ self.act = (
+ ACT2FN[self.hidden_act]
+ if "gelu" not in self.hidden_act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh"
+ if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
+ else "none"
+ ),
+ )
+ )
+ prefixes = None
+ sizes = None
+
+ # Fuse gate and up proj
+ bias = getattr(config, "mlp_bias", False)
+ if config.model_type == "phi3":
+ gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+ config,
+ prefix=f"{prefix}.gate_up_proj",
+ weights=weights,
+ bias=bias,
+ )
+ else:
+ prefixes = ["gate_proj", "up_proj"]
+ sizes = [
+ config.intermediate_size,
+ config.intermediate_size,
+ ]
+ gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=bias,
+ )
+
+ self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+ gate_up_proj,
+ index,
+ layer_names=prefixes,
+ sizes=sizes,
+ process_group=weights.process_group,
+ )
+
+ down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=bias,
+ )
+
+ self.down_proj = TensorParallelAdapterRowLinear.load(
+ down_proj,
+ index,
+ "down_proj",
+ process_group=weights.process_group,
+ )
+
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ # TODO: This is a hotfix to be removed & properly refactored.
+ self.quantize = config.quantize
+
+ self.hidden_size = config.hidden_size
+
+ def forward(self, hidden_states, adapter_data):
+ gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
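+        # The fused projection packs [gate, up]; split them, activate the gate and multiply by up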
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+ )
+
+
+class FlashLlamaLayer(nn.Module):
+ def __init__(self, index, prefix, config, weights, rotary_emb):
+ super().__init__()
+
+ with no_fp8(weights):
+ self.self_attn = FlashLlamaAttention(
+ index=index,
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+
+ if config.model_type == "phimoe":
+ moe_layer_cls = (
+ SparseMoELayer
+ if SparseMoELayer.is_supported(weights)
+ else DenseMoELayer
+ )
+ self.mlp = Phi3MoE(
+ f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
+ )
+            # With MoE the layernorms are not RMSNorms and they have a bias
+ self.input_layernorm = FastLayerNorm.load(
+ prefix=f"{prefix}.input_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.post_attention_layernorm = FastLayerNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ else:
+ self.mlp = LlamaMLP(
+ prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
+ )
+ self.input_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.post_attention_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ # Used in Granite
+ # This could eventually be baked into the weights like we do for the embeddings/lm_head
+ # but this would mean modifying the lora code
+ self.residual_multiplier = getattr(config, "residual_multiplier", None)
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ cross_attention_states,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ):
+ normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+ if self.residual_multiplier is not None:
+ attn_output *= self.residual_multiplier
+
+ normed_attn_res_output, attn_res = self.post_attention_layernorm(
+ attn_output, res
+ )
+
+ mlp_output = self.mlp(normed_attn_res_output, adapter_data)
+ if self.residual_multiplier is not None:
+ mlp_output *= self.residual_multiplier
+
+ return mlp_output, attn_res
+
+
+class FlashLlamaModel(torch.nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+
+ # Skip fp8 quant for first and last layers
+ self.layers = nn.ModuleList()
+ self.cross_attention_layers = getattr(config, "cross_attention_layers", [])
+        # Set defaults for the Baichuan custom config, which doesn't define them.
+ config.rope_theta = getattr(config, "rope_theta", 10000)
+ config.num_key_value_heads = getattr(
+ config, "num_key_value_heads", config.num_attention_heads
+ )
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.hidden_size // config.num_attention_heads,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+ with no_fp8(weights):
+ self.layers.append(
+ FlashLlamaLayer(
+ index=0,
+ prefix=f"{prefix}.layers.0",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+ )
+
+ # Skip first and last layers
+ for layer_id in range(1, config.num_hidden_layers - 1):
+ if layer_id in self.cross_attention_layers:
+ from text_generation_server.models.custom_modeling.flash_mllama import (
+ FlashLlamaCrossLayer,
+ )
+
+ self.layers.append(
+ FlashLlamaCrossLayer(
+ index=layer_id,
+ prefix=(f"{prefix}.layers.{layer_id}"),
+ config=config,
+ weights=weights,
+ )
+ )
+ else:
+ self.layers.append(
+ FlashLlamaLayer(
+ index=layer_id,
+ prefix=(f"{prefix}.layers.{layer_id}"),
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+ )
+
+ with no_fp8(weights):
+ last_layer_id = config.num_hidden_layers - 1
+ self.layers.append(
+ FlashLlamaLayer(
+ index=last_layer_id,
+ prefix=(f"{prefix}.layers.{last_layer_id}"),
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+ )
+
+ self.norm = FastRMSNorm.load(
+ prefix=f"{prefix}.norm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ adapter_data,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ cross_attention_states=None,
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, inputs_embeds.shape[0]
+ )
+
+ hidden_states = inputs_embeds
+
+        # Get rotary cos and sin once for this forward pass
+        # to avoid indexing into them in each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
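+        # In HPU lazy mode, mark_step() delimits the graph captured and executed for each layer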
+ if lazy_mode:
+ htorch.core.mark_step()
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ adapter_data,
+ cross_attention_states,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashLlamaForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, name=None):
+ if name is None:
+ name = "model"
+ super().__init__()
+ with no_fp8(weights):
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=(
+ f"{name}.embed_tokens"
+ if not prefix
+ else f"{prefix}.{name}.embed_tokens"
+ ),
+ weights=weights,
+ )
+ self.model = FlashLlamaModel(
+ prefix=name if not prefix else f"{prefix}.{name}",
+ config=config,
+ weights=weights,
+ )
+ if config.tie_word_embeddings:
+ suffix = "model.embed_tokens"
+ else:
+ suffix = "lm_head"
+
+ # Used in Granite
+ embedding_multiplier = getattr(config, "embedding_multiplier", None)
+ if embedding_multiplier is not None:
+ self.embed_tokens.weight.data *= embedding_multiplier
+ prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
+ with no_fp8(weights):
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix,
+ weights,
+ )
+
+ # Used in Granite
+ self.logits_scaling = getattr(config, "logits_scaling", None)
+ if self.logits_scaling is not None and self.lm_head.head is not None:
+ try:
+ # Scale the weights directly
+ self.lm_head.head.linear.weight.data /= self.logits_scaling
+ self.logits_scaled = True
+ except Exception:
+ self.logits_scaled = False
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ cross_attention_states=None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ inputs_embeds = self.embed_tokens(input_ids)
+ hidden_states = self.model(
+ inputs_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data=adapter_data,
+ cross_attention_states=cross_attention_states,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+
+ # Used in Granite
+ if self.logits_scaling is not None and not self.logits_scaled:
+ logits /= self.logits_scaling
+ if speculative_logits is not None:
+ speculative_logits /= self.logits_scaling
+
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py
new file mode 100644
index 00000000000..d884f413cdd
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py
@@ -0,0 +1,298 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Llava-NeXT model."""
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.activations import ACT2FN
+from transformers.image_processing_utils import select_best_resolution
+
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
+from text_generation_server.models.custom_modeling.vlm import (
+ load_text_model,
+ load_vision_model,
+)
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+)
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+ """
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (`tuple`):
+ The size of the input image in the format (height, width).
+ grid_pinpoints (`List`):
+ A list containing possible resolutions. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ patch_size (`int`):
+ The size of each image patch.
+
+ Returns:
+ tuple: The shape of the image patch grid in the format (height, width).
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise ValueError("grid_pinpoints should be a list of tuples or lists")
+
+ height, width = select_best_resolution(image_size, grid_pinpoints)
+ return height // patch_size, width // patch_size
+
+
+def unpad_image(tensor, original_size):
+ """
+ Unpads a PyTorch tensor of a padded and resized image.
+
+ Args:
+ tensor (`torch.Tensor`):
+ The image tensor, assumed to be of shape (num_channels, height, width).
+ original_size (`tuple`):
+ The original size of the image (height, width).
+
+ Returns:
+ `torch.Tensor`: The unpadded image tensor.
+ """
+ original_height, original_width = original_size
+ current_height, current_width = tensor.shape[1:]
+
+ original_aspect_ratio = original_width / original_height
+ current_aspect_ratio = current_width / current_height
+
+ if original_aspect_ratio > current_aspect_ratio:
+ scale_factor = current_width / original_width
+ new_height = int(original_height * scale_factor)
+ padding = (current_height - new_height) // 2
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
+ else:
+ scale_factor = current_height / original_height
+ new_width = int(original_width * scale_factor)
+ padding = (current_width - new_width) // 2
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
+
+ return unpadded_tensor
+
+
+# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
+class LlavaNextMultiModalProjector(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+
+ self.linear_1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
+ )
+ self.act = ACT2FN[config.projector_hidden_act]
+ self.linear_2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
+ )
+
+ def forward(self, image_features):
+ hidden_states = self.linear_1(image_features)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+class FlashLlavaNextForConditionalGeneration(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ config.vision_config.quantize = config.quantize
+ vision_config = config.vision_config
+        # Instead of selecting hidden_states[-2] afterwards, only build the layers
+        # up to vision_feature_layer and skip the final pooling.
+ if config.vision_feature_layer < 0:
+ vision_config.num_hidden_layers += config.vision_feature_layer + 1
+ else:
+ vision_config.num_hidden_layers = config.vision_feature_layer + 1
+ self.vision_tower = load_vision_model(
+ prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
+ config=config.vision_config,
+ weights=weights,
+ )
+
+ self.multi_modal_projector = LlavaNextMultiModalProjector(
+ prefix="multi_modal_projector", config=config, weights=weights
+ )
+
+ self.image_newline = weights.get_tensor("image_newline")
+
+ self.vocab_size = config.text_config.vocab_size
+ self.config = config
+ config.text_config.quantize = config.quantize
+ config.text_config.speculator = config.speculator
+ self.text_model = load_text_model(
+ prefix="language_model" if not prefix else f"{prefix}.language_model",
+ config=config.text_config,
+ weights=weights,
+ )
+ self.pad_token_id = (
+ config.pad_token_id if config.pad_token_id is not None else -1
+ )
+
+ def _merge_input_ids_with_image_features(
+ self,
+ input_ids: torch.Tensor,
+ inputs_embeds: torch.Tensor,
+ image_features: torch.Tensor,
+ ):
+ """In place merges in vision_embeddings with inputs_embeds."""
+ mask = torch.where(input_ids == self.config.image_token_index)
+ # Let's pray we have enabled enough slots !
+ try:
+ inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+ except Exception as e:
+ raise RuntimeError(
+                f"Cannot fill images right now. If this error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If it happens at regular runtime, please file an issue: {e}"
+ )
+ return inputs_embeds
+
+ def get_vision_embeds(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
+ # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
+ # 1. Extract the input embeddings
+
+ # 2. Merge text and images
+ num_images, num_patches, channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.view(
+ num_images * num_patches, channels, height, width
+ )
+ image_features = self.vision_tower(pixel_values)
+
+ # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
+ # Already done within the clip model
+ selected_image_feature = image_features.last_hidden_state
+
+ if self.config.vision_feature_select_strategy == "default":
+ selected_image_feature = selected_image_feature[:, 1:]
+ elif self.config.vision_feature_select_strategy == "full":
+ selected_image_feature = selected_image_feature
+ else:
+ raise RuntimeError(
+ f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
+ )
+
+ image_features = self.multi_modal_projector(selected_image_feature)
+
+ # split up image_features for each of the individual images
+ # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
+ # if we assume each image has 5 image features (base image + 4 patches)
+ split_sizes = [num_patches] * num_images
+ image_features = torch.split(image_features, split_sizes, dim=0)
+
+ # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+ height = width = (
+ self.config.vision_config.image_size // self.config.vision_config.patch_size
+ )
+
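+        # Re-assemble each image: base (low-res) features followed by the unpadded high-res patch grid with newline separators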
+ new_image_features = []
+ for image_idx, image_feature in enumerate(image_features):
+ if image_feature.shape[0] > 1:
+ base_image_feature = image_feature[0]
+ image_feature = image_feature[1:]
+
+ if height * width != base_image_feature.shape[0]:
+ raise ValueError(
+ "The number of patches is not consistent with the image size."
+ )
+
+ # Dimensions are intentionally swapped to be bug-compatible with
+ # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+ image_sizes[image_idx],
+ self.config.image_grid_pinpoints,
+ self.config.vision_config.image_size,
+ )
+ image_feature = image_feature.view(
+ num_patch_height, num_patch_width, height, width, -1
+ )
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
+ image_feature = torch.cat(
+ (
+ image_feature,
+ self.image_newline[:, None, None].expand(
+ *image_feature.shape[:-1], 1
+ ),
+ ),
+ dim=-1,
+ )
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+ else:
+ image_feature = image_feature[0]
+ image_feature = torch.cat(
+ (image_feature, self.image_newline[None]), dim=0
+ )
+ new_image_features.append(image_feature)
+ image_features = torch.stack(new_image_features, dim=0)
+ return image_features.view(-1, image_features.shape[-1])
+
+ def get_inputs_embeds(
+ self,
+ input_ids: torch.Tensor,
+ vision_embeds: torch.Tensor = None,
+ pixel_values: torch.FloatTensor = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ ):
+ inputs_embeds = self.text_model.embed_tokens(input_ids)
+
+ if vision_embeds is not None:
+            # During generation we must not overwrite image_token_ids that the model itself produced,
+            # since they have no corresponding image.
+ inputs_embeds = self._merge_input_ids_with_image_features(
+ input_ids, inputs_embeds, vision_embeds
+ )
+ return inputs_embeds
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ):
+
+ hidden_states = self.text_model.model(
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ adapter_data=adapter_data,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.text_model.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
new file mode 100644
index 00000000000..43584d9115e
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -0,0 +1,504 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from typing import Optional, List, Tuple
+
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ TensorParallelMultiAdapterLinear,
+ TensorParallelAdapterRowLinear,
+)
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
+ FastRMSNorm,
+)
+import habana_frameworks.torch as htorch
+
+
+class MistralConfig(PretrainedConfig):
+ model_type = "mistral"
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=4096,
+ intermediate_size=14336,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=4096 * 32,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ sliding_window=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.sliding_window = sliding_window
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class MistralAttention(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, layer_id, rotary_emb):
+ super().__init__()
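+        # config.sliding_window is forwarded to the attention kernels as window_size_left; -1 disables the window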
+ self.max_past = (
+ config.sliding_window if config.sliding_window is not None else -1
+ )
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+
+ if getattr(config, "head_dim", None) is not None:
+ self.head_size = config.head_dim
+ else:
+ self.head_size = self.hidden_size // self.num_heads
+
+ self.rotary_emb = rotary_emb
+
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()})"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ query_key_value = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+
+ self.query_key_value = TensorParallelMultiAdapterLinear.load(
+ query_key_value,
+ layer_id,
+ ["q_proj", "k_proj", "v_proj"],
+ sizes=[
+ self.head_size * config.num_attention_heads,
+ self.head_size * config.num_key_value_heads,
+ self.head_size * config.num_key_value_heads,
+ ],
+ process_group=weights.process_group,
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.o_proj = TensorParallelAdapterRowLinear.load(
+ o_proj,
+ layer_id,
+ "o_proj",
+ process_group=weights.process_group,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states, adapter_data)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ window_size_left=self.max_past,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ window_size_left=self.max_past,
+ )
+
+ return self.o_proj(
+ attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+ )
+
+
+class MistralMLP(nn.Module):
+ def __init__(self, prefix: str, config, weights, layer_id):
+ super().__init__()
+ self.hidden_act = config.hidden_act
+ self.act = (
+ ACT2FN[self.hidden_act]
+ if "gelu" not in self.hidden_act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh"
+ if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
+ else "none"
+ ),
+ )
+ )
+ # Fuse gate and up proj
+ gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+ gate_up_proj,
+ layer_id,
+ ["gate_proj", "up_proj"],
+ sizes=[
+ config.intermediate_size,
+ config.intermediate_size,
+ ],
+ process_group=weights.process_group,
+ )
+
+ down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ self.down_proj = TensorParallelAdapterRowLinear.load(
+ down_proj,
+ layer_id,
+ "down_proj",
+ process_group=weights.process_group,
+ )
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ # TODO: This is a hotfix to be removed & properly refactored.
+ self.quantize = config.quantize
+
+ def forward(self, hidden_states, adapter_data):
+ gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+ )
+
+
+class MistralLayer(nn.Module):
+ def __init__(self, prefix: str, config, weights, layer_id, rotary_emb):
+ super().__init__()
+ self.self_attn = MistralAttention(
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ layer_id=layer_id,
+ rotary_emb=rotary_emb,
+ )
+ self.mlp = MistralMLP(
+ prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
+ )
+
+ self.input_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+
+ # faster post attention rms norm
+ normed_attn_res_output, attn_res = self.post_attention_layernorm(
+ attn_output, res
+ )
+
+ mlp_output = self.mlp(normed_attn_res_output, adapter_data)
+
+ return mlp_output, attn_res
+
+
+class MistralModel(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+
+ if getattr(config, "head_dim", None) is not None:
+ head_dim = config.head_dim
+ else:
+ head_dim = config.hidden_size // config.num_attention_heads
+
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ MistralLayer(
+ prefix=f"{prefix}.layers.{layer_id}",
+ config=config,
+ weights=weights,
+ layer_id=layer_id,
+ rotary_emb=rotary_emb,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = FastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ adapter_data: Optional[torch.Tensor] = None,
+ ):
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, inputs_embeds.shape[0]
+ )
+ hidden_states = inputs_embeds
+        # Get rotary cos and sin once for this forward pass
+        # to avoid indexing into them in each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+
+class FlashMistralForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, name=None):
+ if name is None:
+ name = "model"
+ super().__init__()
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=(
+ f"{name}.embed_tokens"
+ if not prefix
+ else f"{prefix}.{name}.embed_tokens"
+ ),
+ weights=weights,
+ )
+ self.model = MistralModel(
+ prefix=name if not prefix else f"{prefix}.{name}",
+ config=config,
+ weights=weights,
+ )
+ self.lm_head = SpeculativeHead.load(
+ config,
+ # TODO dirty hack for idefics2.
+ prefix=(
+ "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
+ ),
+ weights=weights,
+ )
+ self.max_past = config.sliding_window
+ self.max_past_tensor = (
+ torch.tensor(config.sliding_window, device=weights.device)
+ if self.max_past is not None
+ else None
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.embed_tokens(input_ids)
+ hidden_states = self.model(
+ inputs_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ adapter_data,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.lm_head(hidden_states)
+ return logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
new file mode 100644
index 00000000000..8c682c7f1ba
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -0,0 +1,530 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+import torch.distributed
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+from text_generation_server.layers import (
+ FastLinear,
+ SpeculativeHead,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ get_linear,
+)
+from text_generation_server.layers.attention import (
+ Seqlen,
+ attention,
+ paged_attention,
+ set_block_mapping,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch
+
+
+class MixtralConfig(PretrainedConfig):
+ model_type = "mixtral"
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=4096,
+ intermediate_size=14336,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=4096 * 32,
+ initializer_range=0.02,
+ rms_norm_eps=1e-05,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ sliding_window=None,
+ num_experts_per_tok=2,
+ num_local_experts=8,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.sliding_window = sliding_window
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.num_experts_per_tok = num_experts_per_tok
+ self.num_local_experts = num_local_experts
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+def promote_scalar(x: torch.Tensor) -> torch.Tensor:
+ return x.view(1) if len(x.size()) == 0 else x
+
+
+def load_attention(config, prefix: str, weights):
+ if config.num_attention_heads != config.num_key_value_heads:
+ return _load_gqa(config, prefix, weights)
+ else:
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+
+
+def _load_gqa(config, prefix: str, weights):
+ assert config.hidden_size % config.num_attention_heads == 0
+ assert config.num_attention_heads % weights.process_group.size() == 0
+
+ weight = weights.get_multi_weights_col(
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ )
+
+ if isinstance(weight, UnquantizedWeight):
+ weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
+
+ head_size = config.hidden_size // config.num_attention_heads
+ num_heads = config.num_attention_heads // weights.process_group.size()
+ num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+ assert list(weight.weight.shape) == [
+ (num_heads + 2 * num_key_value_heads) * head_size,
+ config.hidden_size,
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"
+
+ return TensorParallelColumnLinear(get_linear(weight, bias=None))
+
+
+def _load_experts(config, prefix: str, mat, weights):
+ if config.quantize is not None:
+ raise NotImplementedError("Mixtral does not support weight quantization yet.")
+
+ assert mat in ["w1", "w2", "w3"]
+
+ world_size = weights.process_group.size()
+ rank = weights.process_group.rank()
+
+ assert (
+ config.intermediate_size % world_size == 0
+ ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
+
+ block_size = config.intermediate_size // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+
+ tensor = torch.empty(
+ (config.num_local_experts * block_size, config.hidden_size),
+ dtype=weights.dtype,
+ device=weights.device,
+ )
+
+ for i in range(config.num_local_experts):
+ slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
+
+ if mat == "w2":
+ expert_slice = slice_[:, start:stop].t().contiguous()
+ else:
+ expert_slice = slice_[start:stop]
+ tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
+ dtype=weights.dtype
+ ).to(device=weights.device)
+ return tensor
+
+
+class MixtralAttention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ rotary_emb,
+ ):
+ super().__init__()
+ self.max_past = (
+ config.sliding_window if config.sliding_window is not None else -1
+ )
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_heads
+ self.rotary_emb = rotary_emb
+
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights)
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
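+ # Map each query head to the KV head it shares under GQA so paged attention reads the right cache entries.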
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ window_size_left=self.max_past,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+@torch.jit.script
+def select_experts(gate_logits: torch.Tensor, top_k: int):
+ # all_probs: (sequence_length, n_experts) and upcast for softmax
+ all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+ # weights, selected_experts: (sequence_length, top-k)
+ weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
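+ # Renormalize the top-k routing weights so they sum to 1 per token.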
+ weights /= weights.sum(dim=-1, keepdim=True)
+ weights = weights.view(-1)
+ selected_experts = selected_experts.view(-1)
+
+ return selected_experts, weights
+
+
+@torch.jit.script
+def round_up(x: torch.Tensor, value: int):
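+ # Round each element of x up to the nearest multiple of `value` using integer arithmetic.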
+ return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
+
+
+class MixtralMoE(nn.Module):
+ def __init__(
+ self, prefix, config: MixtralConfig, moe_layer_cls: Type[MoELayer], weights
+ ):
+ super().__init__()
+
+ # gating
+ self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
+ self.moe = moe_layer_cls(
+ n_expert_group=None,
+ n_experts=config.num_local_experts,
+ prefix=f"{prefix}.experts",
+ renormalize=True,
+ topk=config.num_experts_per_tok,
+ topk_group=None,
+ weights=weights,
+ gate_proj_name="w1",
+ up_proj_name="w3",
+ down_proj_name="w2",
+ )
+ assert isinstance(self.moe, MoELayer)
+
+ self.process_group = weights.process_group
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(x)
+ out = self.moe(x, gating_output=router_logits)
+
+ # Reduce sum
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(out, group=self.process_group)
+
+ return out.view(*x.shape)
+
+
+class MixtralLayer(nn.Module):
+ def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
+ super().__init__()
+ prefix = f"{prefix}.layers.{layer_id}"
+
+ self.self_attn = MixtralAttention(
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+
+ moe_layer_cls = (
+ SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
+ )
+ self.moe = MixtralMoE(
+ f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
+ )
+
+ self.input_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ # faster post attention rms norm
+ normed_attn_res_output, attn_res = self.post_attention_layernorm(
+ attn_output, res
+ )
+
+ moe_output = self.moe(normed_attn_res_output)
+
+ return moe_output, attn_res
+
+
+class MixtralModel(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=(
+ "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
+ ),
+ weights=weights,
+ )
+
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.hidden_size // config.num_attention_heads,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+ self.layers = nn.ModuleList(
+ [
+ MixtralLayer(
+ "model" if not prefix else f"{prefix}.model",
+ layer_id,
+ config,
+ weights,
+ rotary_emb,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = FastRMSNorm.load(
+ prefix="model.norm" if not prefix else f"{prefix}.model.norm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, input_ids.shape[0]
+ )
+ hidden_states = self.embed_tokens(input_ids)
+
+ # Get rotary cos and sin for this forward
+ # Avoid indexing in each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
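+ # In HPU lazy mode, mark_step() flushes the accumulated graph so each decoder layer executes as its own unit.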
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashMixtralForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ self.model = MixtralModel(prefix, config, weights)
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="lm_head" if not prefix else f"{prefix}.lm_head",
+ weights=weights,
+ )
+ self.max_past = config.sliding_window
+ self.max_past_tensor = (
+ torch.tensor(config.sliding_window, device=weights.device)
+ if self.max_past is not None
+ else None
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.lm_head(hidden_states)
+ return logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
new file mode 100644
index 00000000000..fe6d137b61d
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
@@ -0,0 +1,951 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Mllama model."""
+
+from typing import Optional, Tuple, List
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.activations import ACT2FN
+import torch.nn.functional as F
+
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ FastLinear,
+)
+from text_generation_server.layers.attention import (
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+ FlashLlamaForCausalLM,
+)
+from habana_frameworks.torch.hpex.kernels import FusedSDPA
+from vllm_hpu_extension.utils import ModuleFusedSDPA
+import habana_frameworks.torch as htorch
+
+
+def _prepare_aspect_ratio_attention_mask(
+ aspect_ratio_mask: torch.Tensor,
+ num_patches: int,
+ target_length: int,
+ dtype: torch.dtype,
+) -> torch.Tensor:
+ # Expand aspect ratio mask to target_length
+ batch_size, max_num_tiles = aspect_ratio_mask.shape
+ attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
+ attention_mask = attention_mask.repeat(1, 1, target_length, 1)
+
+ # Mask padding patches
+ pad_patches = target_length - num_patches
+ attention_mask[:, :, -pad_patches:] = 0
+
+ # Invert the mask (0 -> 1, 1 -> 0)
+ attention_mask = 1 - attention_mask
+
+ # Reshape to 2D and create 4D attention mask
+ # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
+ attention_mask = attention_mask.reshape(
+ batch_size, max_num_tiles * target_length, 1
+ )
+ attention_mask = (
+ attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min
+ )
+ attention_mask = attention_mask.unsqueeze(1)
+
+ return attention_mask
+
+
+# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
+def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ min_dtype: float,
+ cache_position: torch.Tensor,
+ batch_size: int,
+):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with a static cache, the mask should be as long as the static cache to account for the zero padding (the part of the cache that is not filled yet).
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ device (`torch.device`):
+ The device to place the 4D attention mask on.
+ min_dtype (`float`):
+ The minimum value representable with the dtype `dtype`.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ causal_mask = torch.full(
+ (sequence_length, target_length),
+ fill_value=min_dtype,
+ dtype=dtype,
+ device=device,
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(
+ target_length, device=device
+ ) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = (
+ causal_mask.clone()
+ ) # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = (
+ causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[
+ :, :, :, :mask_length
+ ].masked_fill(padding_mask, min_dtype)
+
+ return causal_mask
+
+
+def _prepare_cross_attention_mask(
+ cross_attention_mask: torch.Tensor,
+ num_vision_tokens: int,
+ dtype: torch.dtype,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # reshape so it can be used by attn module
+ batch_size, text_total_length, *_ = cross_attention_mask.shape
+ cross_attention_mask = cross_attention_mask.repeat_interleave(
+ num_vision_tokens, dim=3
+ )
+ cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
+ cross_attention_mask = cross_attention_mask.unsqueeze(1)
+
+ # invert the mask
+ inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
+ cross_attention_mask = inverted_cross_attn_mask.masked_fill(
+ inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
+ )
+
+ # apply full-row bias: this returns a 4D tensor of shape [B, H, S1, 1] whose value is 0 if a full row in the
+ # cross-attention mask's last dimension contains negative-infinity values, otherwise 1
+ negative_inf_value = torch.finfo(dtype).min
+ full_text_row_masked_out_mask = (
+ (cross_attention_mask != negative_inf_value)
+ .any(dim=-1)
+ .type_as(cross_attention_mask)[..., None]
+ )
+ cross_attention_mask *= full_text_row_masked_out_mask
+
+ return cross_attention_mask, full_text_row_masked_out_mask
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
+class MllamaVisionMLP(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class MllamaVisionSdpaAttention(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+
+ self.embed_dim = config.hidden_size
+ self.head_dim = config.hidden_size // config.attention_heads
+ self.num_heads = config.attention_heads // weights.process_group.size()
+
+ self.qkv_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ qkv = self.qkv_proj(hidden_state)
+ query, key, value = qkv.split(
+ [
+ self.head_dim * self.num_heads,
+ self.head_dim * self.num_heads,
+ self.head_dim * self.num_heads,
+ ],
+ dim=2,
+ )
+
+ batch_size, q_seq_len, _ = query.shape
+ _, kv_seq_len, _ = key.shape
+
+ query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)
+ key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
+ value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+ attn_output = fsdpa_op(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ scale=None,
+ softmax_mode="None",
+ recompute_mode=None,
+ valid_sequence_lengths=None,
+ )
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
+
+ output = self.o_proj(attn_output)
+ return output
+
+
+class MllamaVisionEncoderLayer(nn.Module):
+ def __init__(self, *, prefix, config, weights, is_gated: bool):
+ super().__init__()
+
+ self.hidden_size = config.hidden_size
+ self.num_attention_heads = config.attention_heads
+ self.is_gated = is_gated
+ self.intermediate_size = config.intermediate_size
+
+ self.self_attn = MllamaVisionSdpaAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.mlp = MllamaVisionMLP(
+ prefix=f"{prefix}.mlp", config=config, weights=weights
+ )
+
+ self.input_layernorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05
+ )
+ self.post_attention_layernorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05
+ )
+
+ # Gate parameters only exist for gated layers (the global transformer); plain encoder layers skip them.
+ if is_gated:
+ self.gate_attn = nn.Parameter(
+ weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False
+ )
+ self.gate_ffn = nn.Parameter(
+ weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False
+ )
+
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ # Self Attention
+ residual = hidden_state
+ hidden_state = self.input_layernorm(hidden_state)
+ hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
+ gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
+ hidden_state = residual + gate_attn * hidden_state
+
+ # Feed forward
+ residual = hidden_state
+ hidden_state = self.post_attention_layernorm(hidden_state)
+ hidden_state = self.mlp(hidden_state)
+ gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
+ hidden_state = residual + gate_ffn * hidden_state
+ return hidden_state
+
+
+class MllamaVisionEncoder(nn.Module):
+ def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
+ super().__init__()
+ self.config = config
+ self.layers = [
+ MllamaVisionEncoderLayer(
+ prefix=f"{prefix}.layers.{i}",
+ config=config,
+ weights=weights,
+ is_gated=is_gated,
+ )
+ for i in range(num_layers)
+ ]
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
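+ # Keep the hidden states after every layer so the vision model can later pick out intermediate layers.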
+ encoder_states = [hidden_states]
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+ for encoder_layer in self.layers:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ )
+
+ hidden_states = layer_outputs
+ encoder_states.append(hidden_states)
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ return hidden_states, encoder_states
+
+
+class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.max_num_tiles = config.max_num_tiles
+ self.hidden_size = config.hidden_size
+ self.max_aspect_ratio_id = config.max_aspect_ratio_id
+
+ self.embedding = TensorParallelEmbedding(
+ prefix=f"{prefix}.embedding", weights=weights
+ )
+ self.gate = nn.Parameter(
+ weights.get_tensor(f"{prefix}.gate"), requires_grad=False
+ )
+
+ def forward(
+ self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
+ ) -> torch.Tensor:
+ embeddings = self.embedding(aspect_ratio_ids)
+ embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)
+
+ # Always gated.
+ embeddings = embeddings * self.gate.tanh()
+
+ hidden_state = hidden_state + embeddings
+ return hidden_state
+
+
+class MllamaPrecomputedPositionEmbedding(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.max_num_tiles = config.max_num_tiles
+ self.max_aspect_ratio_id = config.max_aspect_ratio_id
+ self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
+ self.hidden_size = config.hidden_size
+ self.scale = config.hidden_size**-0.5
+
+ self.gate = nn.Parameter(
+ weights.get_tensor(f"{prefix}.gate"), requires_grad=False
+ )
+
+ # position embedding
+ embedding = nn.Parameter(
+ weights.get_tensor(f"{prefix}.embedding"), requires_grad=False
+ )
+ self.gated_position_embedding = (1 - self.gate.tanh()) * embedding
+ self.tile_embedding = TensorParallelEmbedding(
+ prefix=f"{prefix}.tile_embedding", weights=weights
+ )
+
+ def forward(
+ self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
+ ) -> torch.Tensor:
+ # position embeddings
+ hidden_state = hidden_state + self.gated_position_embedding.view(
+ 1, 1, self.num_patches, self.hidden_size
+ )
+
+ # precomputed tile position embeddings
+ tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
+ batch_size = hidden_state.shape[0]
+ tile_position_embedding = tile_position_embedding.reshape(
+ batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
+ )
+ gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
+ hidden_state = hidden_state + gated_tile_position_embedding
+
+ return hidden_state
+
+
+class MllamaVisionModel(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+ self.max_num_tiles = config.max_num_tiles
+ self.hidden_size = config.hidden_size
+ self.num_channels = config.num_channels
+ self.intermediate_layers_indices = config.intermediate_layers_indices
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
+ self.scale = config.hidden_size**-0.5
+ self.dtype = weights.dtype
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.hidden_size,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ bias=False,
+ )
+ self.patch_embedding.weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
+ )
+
+ self.class_embedding = nn.Parameter(
+ weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False
+ )
+
+ self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
+ prefix=f"{prefix}.gated_positional_embedding",
+ config=config,
+ weights=weights,
+ )
+
+ self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
+ prefix=f"{prefix}.pre_tile_positional_embedding",
+ config=config,
+ weights=weights,
+ )
+ self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
+ prefix=f"{prefix}.post_tile_positional_embedding",
+ config=config,
+ weights=weights,
+ )
+
+ ## layer norms
+ self.layernorm_pre = nn.LayerNorm.load(
+ prefix=f"{prefix}.layernorm_pre",
+ weights=weights,
+ # torch default
+ eps=1e-05,
+ )
+ self.layernorm_post = nn.LayerNorm.load(
+ prefix=f"{prefix}.layernorm_post",
+ weights=weights,
+ # torch default
+ eps=1e-05,
+ )
+
+ ## encoders
+ self.transformer = MllamaVisionEncoder(
+ prefix=f"{prefix}.transformer",
+ config=config,
+ weights=weights,
+ is_gated=False,
+ num_layers=config.num_hidden_layers,
+ )
+ self.global_transformer = MllamaVisionEncoder(
+ prefix=f"{prefix}.global_transformer",
+ config=config,
+ weights=weights,
+ is_gated=True,
+ num_layers=config.num_global_layers,
+ )
+
+ def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ batch_size, _, hidden_size = hidden_state.shape
+ class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
+ hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
+ return hidden_state
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ aspect_ratio_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ (
+ batch_size,
+ num_concurrent_media,
+ num_tiles,
+ num_channels,
+ height,
+ width,
+ ) = pixel_values.shape
+
+ pixel_values = pixel_values.reshape(
+ batch_size * num_concurrent_media * num_tiles, num_channels, height, width
+ )
+ aspect_ratio_ids = aspect_ratio_ids.reshape(
+ batch_size * num_concurrent_media, -1
+ )
+
+ # patch embedding
+ patch_embeds = self.patch_embedding(pixel_values)
+ hidden_state = patch_embeds.flatten(2).transpose(1, 2)
+
+ # tile embeddings
+ _, num_patches, dim = hidden_state.shape
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media, num_tiles, -1, dim
+ )
+ hidden_state = self.pre_tile_positional_embedding(
+ hidden_state, aspect_ratio_ids
+ )
+
+ # apply cls token
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media * num_tiles, num_patches, dim
+ )
+ hidden_state = self.apply_class_embedding(hidden_state)
+ num_patches += 1
+
+ # apply position embeddings
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media, num_tiles, num_patches, dim
+ )
+ hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)
+
+ # apply encoder
+ hidden_state = self.layernorm_pre(hidden_state)
+
+ # Compute the number of tokens to pad
+ num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
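+ # (padding the patch dimension to a multiple of 8 presumably matches the alignment expected by the fused attention kernels)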
+ # Compute padding tuple for pad function
+ padding = (
+ 0,
+ 0,
+ 0,
+ num_padding_patches,
+ ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
+ # Pad the tensor
+ hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
+ slice_index = -num_padding_patches if num_padding_patches > 0 else None
+
+ if attention_mask is not None:
+ attention_mask = attention_mask.reshape(
+ batch_size * num_concurrent_media, -1
+ )
+ attention_mask = _prepare_aspect_ratio_attention_mask(
+ aspect_ratio_mask=attention_mask,
+ num_patches=self.num_patches,
+ target_length=hidden_state.shape[2],
+ dtype=self.dtype,
+ )
+
+ hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
+ hidden_state, all_intermediate_hidden_states = self.transformer(
+ hidden_state,
+ attention_mask=attention_mask,
+ )
+ intermediate_hidden_states = [
+ hidden_state
+ for idx, hidden_state in enumerate(all_intermediate_hidden_states)
+ if idx in self.intermediate_layers_indices
+ ]
+ intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)
+
+ # apply global encoder
+ hidden_state = self.layernorm_post(hidden_state)
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media,
+ num_tiles,
+ num_patches + num_padding_patches,
+ dim,
+ )
+ hidden_state = self.post_tile_positional_embedding(
+ hidden_state, aspect_ratio_ids
+ )
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media,
+ num_tiles * (num_patches + num_padding_patches),
+ dim,
+ )
+ hidden_state, _ = self.global_transformer(
+ hidden_state, attention_mask=attention_mask
+ )
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media,
+ num_tiles,
+ num_patches + num_padding_patches,
+ dim,
+ )
+ hidden_state = hidden_state[:, :, :slice_index]
+
+ # adding intermediate layer outputs
+ hidden_state = hidden_state.reshape(
+ batch_size, num_concurrent_media, num_tiles, num_patches, dim
+ )
+ intermediate_hidden_states = intermediate_hidden_states.reshape(
+ batch_size * num_concurrent_media,
+ num_tiles,
+ num_patches + num_padding_patches,
+ -1,
+ )
+ intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
+ intermediate_hidden_states = intermediate_hidden_states.reshape(
+ batch_size, num_concurrent_media, num_tiles, num_patches, -1
+ )
+ hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
+ return hidden_state
+
+
+class MllamaTextCrossAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, *, prefix, config, weights, layer_idx):
+ super().__init__()
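+ # w2 is the down-projection and is sharded along its input dimension, so slice columns and transpose; w1/w3 are sharded along rows.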
+ self.config = config
+ self.num_heads = self.config.num_attention_heads
+ self.num_key_value_heads = self.config.num_key_value_heads
+ self.dropout = config.dropout
+ self.hidden_size = config.hidden_size
+ self.head_size = config.hidden_size // self.num_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.layer_idx = layer_idx
+
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ self.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.q_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.q_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.k_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.k_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.v_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.v_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ self.q_norm = MllamaTextRMSNorm.load(
+ prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.k_norm = MllamaTextRMSNorm.load(
+ prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.softmax_scale = self.head_size**-0.5
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ # past_key_value=None,
+ # attention_mask: Optional[torch.Tensor] = None,
+ # cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+ # hidden_states = hidden_states.unsqueeze(0)
+ # bsz, q_len, _ = hidden_states.size()
+ (
+ cross_attention_states,
+ cross_attention_len,
+ indices,
+ ) = cross_attention_states
+ bs = cross_attention_len.size(0)
+ query_states = self.q_proj(hidden_states)
+ query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
+ query_states = self.q_norm(query_states)
+
+ key_states = self.k_proj(cross_attention_states)
+ value_states = self.v_proj(cross_attention_states)
+ key_states = key_states.view(bs, -1, self.num_key_value_heads, self.head_size)
+ value_states = value_states.view(
+ bs, -1, self.num_key_value_heads, self.head_size
+ )
+ key_states = self.k_norm(key_states)
+
+ # key_states = key_states.repeat(1, self.num_key_value_groups, 1)
+ # value_states = value_states.repeat(1, self.num_key_value_groups, 1)
+ # logger.info(
+ # f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
+ # )
+ # execute sdpa
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+ attn_output = fsdpa_op(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ scale=None,
+ softmax_mode="None",
+ recompute_mode=None,
+ valid_sequence_lengths=None,
+ )
+ attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
+ attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+ return attn_output
+
+
+# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
+class MllamaTextMLP(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ shape = x.shape
+ gate_up_states = self.gate_up_proj(x)
+ gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size)
+ result = self.down_proj(
+ self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]
+ )
+ return result
+
+
+class FlashLlamaCrossLayer(torch.nn.Module):
+ """Cross-attention transformer block with tanh-gated attention and feedforward."""
+
+ def __init__(self, *, prefix, config, weights, index) -> None:
+ layer_idx = index
+ super().__init__()
+ self.cross_attn = MllamaTextCrossAttention(
+ prefix=f"{prefix}.cross_attn",
+ config=config,
+ weights=weights,
+ layer_idx=layer_idx,
+ )
+
+ self.input_layernorm = MllamaTextRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.cross_attn_attn_gate = torch.nn.Parameter(
+ weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False
+ )
+
+ self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+ self.post_attention_layernorm = MllamaTextRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.cross_attn_mlp_gate = torch.nn.Parameter(
+ weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False
+ )
+ self.layer_idx = layer_idx
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ cross_attention_states, # [ IB, ...]
+ hpu_attention_meta,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if cross_attention_states is None:
+ return hidden_states, residual
+ if residual is not None:
+ hidden_states += residual
+
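+ # Only tokens that attend to an image (selected via `indices`) run through cross-attention; the rest are copied back unchanged below.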
+ indices = cross_attention_states[-1]
+ out_hidden_states = hidden_states[:]
+ hidden_states = hidden_states[indices]
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ hidden_states = self.cross_attn(
+ hidden_states=hidden_states,
+ # attention_mask=cross_attention_mask,
+ cross_attention_states=cross_attention_states,
+ )
+ hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
+
+ out_hidden_states[indices] = hidden_states
+ hidden_states = out_hidden_states
+
+ return hidden_states, None
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
+class MllamaTextRMSNorm(nn.Module):
+ def __init__(self, weight, eps):
+ super().__init__()
+ self.weight = weight
+ self.variance_epsilon = eps
+
+ @classmethod
+ def load(cls, *, prefix, weights, eps):
+ weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.weight"), requires_grad=False
+ )
+ return cls(weight=weight, eps=eps)
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class FlashMllamaForConditionalGeneration(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ config.vision_config.quantize = None
+ config.vision_config.speculator = config.speculator
+ config.text_config.quantize = config.quantize
+ config.text_config.speculator = config.speculator
+ config.text_config._attn_implementation = "sdpa"
+ self.hidden_size = config.text_config.hidden_size
+ self.vision_model = MllamaVisionModel(
+ prefix="vision_model", config=config.vision_config, weights=weights
+ )
+ self.multi_modal_projector = FastLinear.load(
+ prefix="multi_modal_projector", config=config, weights=weights, bias=True
+ )
+ self.text_model = FlashLlamaForCausalLM(
+ prefix="language_model", config=config.text_config, weights=weights
+ )
+ self.config = config
+ self.dtype = weights.dtype
+ self.device = weights.device
+
+ def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
+ if aspect_ratio_ids is None:
+ raise ValueError(
+ "`aspect_ratio_ids` must be provided if `pixel_values` is provided"
+ )
+ # logger.info(f"PIxel values {pixel_values.shape}")
+ batch_size = pixel_values.shape[0]
+ vision_states = self.vision_model(
+ pixel_values, aspect_ratio_ids, aspect_ratio_mask
+ )
+ cross_attention_states = self.multi_modal_projector(vision_states).reshape(
+ -1, vision_states.shape[-2], self.hidden_size
+ )
+ _, _, h = cross_attention_states.shape
+ cross_attention_states = cross_attention_states.view(batch_size, -1, h)
+ # logger.info(f"cross {cross_attention_states.shape}")
+ return cross_attention_states
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor],
+ adapter_data: Optional[torch.Tensor] = None,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ indices=None,
+ cross_attention_len: Optional[torch.Tensor] = None,
+ ):
+ if cross_attention_states is not None:
+ cross_attention_states = (
+ cross_attention_states,
+ cross_attention_len,
+ indices,
+ )
+
+ outputs = self.text_model(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ lm_head_indices=lm_head_indices,
+ adapter_data=adapter_data,
+ cross_attention_states=cross_attention_states,
+ )
+
+ return outputs
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
new file mode 100644
index 00000000000..8ee1dfa2411
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -0,0 +1,436 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
+from typing import Optional, List, Tuple
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.layernorm import (
+ FastLayerNorm,
+)
+from text_generation_server.layers.rotary import (
+ PositionRotaryEmbedding,
+)
+from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch
+
+
+class GPTNeoXConfig(TransformersGPTNeoXConfig):
+ attribute_map = {
+ "num_key_value_heads": "num_attention_heads",
+ }
+
+
+def load_row(config, prefix: str, weights, bias: bool):
+ weight = weights.get_weights_row(prefix)
+
+ if bias and weights.process_group.rank() == 0:
+ # Only the first rank loads the bias so it is added exactly once across shards.
+ bias = weights.get_tensor(f"{prefix}.bias")
+ else:
+ bias = None
+
+ linear = get_linear(weight, bias)
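+ # With parallel residual the all-reduce is deferred to the layer, where attention and MLP outputs are summed first, so return the bare linear.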
+ if config.use_parallel_residual:
+ return linear
+ else:
+ return TensorParallelRowLinear(linear, process_group=weights.process_group)
+
+
+def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
+ weight = weights.get_multi_weights_col([prefix], dim=0)
+ if isinstance(weight, UnquantizedWeight):
+ # Only reorder non-quantized weights
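+ # GPT-NeoX stores fused QKV interleaved per head as [head, (q, k, v), head_size]; reorder to all-Q, all-K, all-V so it matches the layout expected downstream.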
+ weight.weight = (
+ weight.weight.view(
+ num_heads,
+ 3,
+ head_size,
+ hidden_size,
+ )
+ .permute(1, 0, 2, 3)
+ .reshape(-1, hidden_size)
+ )
+
+ bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+ bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
+
+ linear = get_linear(weight, bias)
+ if config.use_parallel_residual:
+ return linear
+ else:
+ return TensorParallelColumnLinear(linear)
+
+
+class FlashNeoxAttention(torch.nn.Module):
+ def __init__(self, config, prefix, weights, rotary_emb):
+ super().__init__()
+ num_heads = config.num_attention_heads
+ hidden_size = config.hidden_size
+
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.head_size = hidden_size // num_heads
+
+ self.rotary_dim = int(config.rotary_pct * self.head_size)
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.rotary_emb = rotary_emb
+ self.softmax_scale = self.head_size ** (-0.5)
+
+ self.query_key_value = load_qkv(
+ config,
+ prefix=f"{prefix}.query_key_value",
+ weights=weights,
+ num_heads=self.num_heads,
+ head_size=self.head_size,
+ hidden_size=self.hidden_size,
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+ self.dense = load_row(
+ config, prefix=f"{prefix}.dense", weights=weights, bias=True
+ )
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_heads, dtype=torch.int32, device=weights.device
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states)
+ qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+
+ # Compute rotary embeddings on rotary_ndims
+ query_rot = qkv[:, 0][..., : self.rotary_dim]
+ query_pass = qkv[:, 0][..., self.rotary_dim :]
+ key_rot = qkv[:, 1][..., : self.rotary_dim]
+ key_pass = qkv[:, 1][..., self.rotary_dim :]
+
+ # Inplace rotary
+ self.rotary_emb(query_rot, key_rot, cos, sin)
+ qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
+ qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
+
+ kv_cache.store(
+ key=qkv[:, 1],
+ value=qkv[:, 2],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=qkv[:, 0],
+ key=qkv[:, 1],
+ value=qkv[:, 2],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ qkv[:, 0],
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class FlashMLP(nn.Module):
+ def __init__(self, config, prefix, weights):
+ super().__init__()
+ act = config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+
+ self.dense_h_to_4h = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
+ )
+ self.dense_4h_to_h = load_row(
+ config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense_h_to_4h(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.dense_4h_to_h(hidden_states)
+ return hidden_states
+
+
+class FlashNeoXLayer(nn.Module):
+ def __init__(self, layer_id, config, weights, rotary_emb):
+ super().__init__()
+
+ layer_norm_eps = config.layer_norm_eps
+
+ prefix = f"gpt_neox.layers.{layer_id}"
+
+ self.use_parallel_residual = config.use_parallel_residual
+ self.input_layernorm = FastLayerNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=layer_norm_eps
+ )
+ self.post_attention_layernorm = FastLayerNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=layer_norm_eps,
+ )
+ self.attention = FlashNeoxAttention(
+ config,
+ prefix=f"{prefix}.attention",
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+
+ self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
+ self.process_group = weights.process_group
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ if self.use_parallel_residual:
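+ # Parallel residual (GPT-NeoX style): attention and MLP both read from layernorms of the same input, and their outputs are summed before a single all-reduce.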
+ ln1_hidden_states, _ = self.input_layernorm(hidden_states)
+
+ attn_output = self.attention(
+ ln1_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
+
+ mlp_output = self.mlp(ln2_hidden_states)
+ intermediate = mlp_output + attn_output
+
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(intermediate, group=self.process_group)
+
+ return intermediate + hidden_states, None
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ hidden_states = self.attention(
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual
+ )
+
+ mlp_output = self.mlp(hidden_states)
+
+ return mlp_output, residual
+
+
+class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
+ config_class = GPTNeoXConfig
+ base_model_prefix = "gpt_neox"
+ supports_gradient_checkpointing = False
+ _no_split_modules = None
+
+
+class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__(config)
+ self.config = config
+
+ self.embed_in = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_in", weights=weights
+ )
+
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=int(
+ config.rotary_pct * (config.hidden_size // config.num_attention_heads)
+ ),
+ base=config.rotary_emb_base,
+ device=weights.device,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ FlashNeoXLayer(layer_id, config, weights, rotary_emb)
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.final_layer_norm = FastLayerNorm.load(
+ prefix=f"{prefix}.final_layer_norm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].attention.head_size
+ self.num_heads = self.layers[0].attention.num_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, input_ids.shape[0]
+ )
+ hidden_states = self.embed_in(input_ids)
+
+ # Get rotary cos and sin for this forward
+ # Avoid indexing in each layer
+ cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.final_layer_norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
+ def __init__(self, prefix, config, weights):
+ super().__init__(config)
+
+ if not prefix:
+ prefix = "gpt_neox"
+ else:
+ prefix = f"{prefix}.gpt_neox"
+
+ self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights)
+
+ self.embed_out = SpeculativeHead.load(
+ config, prefix="embed_out", weights=weights
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ hidden_states = self.gpt_neox(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.embed_out(hidden_states)
+ return logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
new file mode 100644
index 00000000000..a13b9f095b9
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -0,0 +1,128 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+from torch import nn
+from typing import Optional, List, Tuple
+
+from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
+from text_generation_server.models.custom_modeling.vlm import (
+ load_text_model,
+ load_vision_model,
+)
+
+
+class PaliGemmaForConditionalGeneration(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ config.vision_config.quantize = config.quantize
+ self.vision_tower = load_vision_model(
+ prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
+ config=config.vision_config,
+ weights=weights,
+ )
+ self.post_vision_tower_layernorm = nn.LayerNorm.load(
+ prefix="vision_tower.vision_model.post_layernorm",
+ weights=weights,
+ eps=config.vision_config.layer_norm_eps,
+ )
+
+ self.multi_modal_projector = TensorParallelColumnLinear.load(
+ config,
+ prefix="multi_modal_projector.linear",
+ weights=weights,
+ bias=True,
+ )
+
+ self.vocab_size = config.text_config.vocab_size
+ self.config = config
+
+ text_config = config.text_config
+ text_config.speculator = config.speculator
+ text_config.quantize = config.quantize
+ self.text_model = load_text_model(
+ prefix="language_model" if not prefix else f"{prefix}.language_model",
+ config=config.text_config,
+ weights=weights,
+ )
+ self.pad_token_id = (
+ config.pad_token_id if config.pad_token_id is not None else -1
+ )
+ self.dtype = weights.dtype
+
+ def get_vision_embeds(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ pixel_values = pixel_values.to(dtype=self.dtype)
+ image_outputs = self.vision_tower(pixel_values)
+ last_hidden_state = self.post_vision_tower_layernorm(
+ image_outputs.last_hidden_state
+ )
+ image_features = self.multi_modal_projector(last_hidden_state)
+ image_features = image_features.view(-1, image_features.shape[-1])
+ return image_features
+
+ def get_inputs_embeds(
+ self,
+ input_ids: torch.Tensor,
+ vision_embeds: torch.Tensor = None,
+ ):
+ inputs_embeds = self.text_model.embed_tokens(input_ids)
+
+ if vision_embeds is not None:
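+ # Replace the image placeholder tokens with the projected vision features.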
+ mask = input_ids == self.config.image_token_index
+ inputs_embeds[mask] = vision_embeds
+
+ return inputs_embeds
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ # TODO This is odd but apparently pali gemma position ids start at 1.
+ if cu_seqlen_prefill is not None:
+ position_ids += 1
+
+ hidden_states = self.text_model.model(
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ adapter_data=adapter_data,
+ )
+
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.text_model.lm_head(hidden_states)
+
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
new file mode 100644
index 00000000000..d7fc844b32e
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -0,0 +1,434 @@
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from typing import Optional, List, Tuple
+
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.layernorm import (
+ FastLayerNorm,
+)
+from text_generation_server.layers.rotary import (
+ PositionRotaryEmbedding,
+)
+import habana_frameworks.torch as htorch
+
+
+class PhiConfig(PretrainedConfig):
+ def __init__(
+ self,
+ vocab_size=51200,
+ hidden_size=2560,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ hidden_act="gelu_fast", # llama uses silu
+ layer_norm_eps=1e-05, # rms in llama,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ resid_pdrop=0.1, # llama doesn't have this
+ partial_rotary_factor=0.5, # important difference between llama and phi
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.rope_theta = rope_theta
+ self.resid_pdrop = resid_pdrop
+ self.partial_rotary_factor = partial_rotary_factor
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+# this is the same as llama except for Phi uses bias=True
+def load_attention(config, prefix, weights):
+ if config.num_attention_heads != config.num_key_value_heads:
+ return _load_gqa(config, prefix, weights)
+ else:
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=True,
+ )
+
+
+def _load_gqa(config, prefix: str, weights):
+ assert config.hidden_size % config.num_attention_heads == 0
+ assert config.num_attention_heads % weights.process_group.size() == 0
+
+ weight = weights.get_multi_weights_col(
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ )
+
+ if config.quantize not in ["gptq", "awq"]:
+ weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+
+ head_size = config.hidden_size // config.num_attention_heads
+ num_heads = config.num_attention_heads // weights.process_group.size()
+ num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
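+    # The fused QKV weight is laid out as [q | k | v] along dim 0: q has num_heads * head_size rows,
+    # k and v each have num_key_value_heads * head_size rows per shard, which is what the check below expects.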
+ assert list(weight.shape) == [
+ (num_heads + 2 * num_key_value_heads) * head_size,
+ config.hidden_size,
+ ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+
+ # this is the same as llama except for Phi uses bias=True
+ return TensorParallelColumnLinear(get_linear(weight, bias=True))
+
+
+class FlashPhiAttention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ rotary_emb,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_heads
+
+ self.softmax_scale = self.head_size**-0.5
+ self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
+ self.rotary_emb = rotary_emb
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights)
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ # in llama the dense layer is called "o_proj" and has bias=False
+ self.dense = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.dense",
+ weights=weights,
+ bias=True,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
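+        # With grouped-query attention, each group of `num_groups` query heads shares a single KV head;
+        # this mapping tells the paged attention kernel which cached KV head to read for each query head.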
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ # Compute query, key, value and split
+ qkv = self.query_key_value(hidden_states)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+
+ # Reshape query and key for rotary embeddings
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ # NOTE: this is the main difference between Llama and Phi
+ # in llama the rotary embeddings are applied to the whole query and key.
+        # Phi uses PARTIAL rotary embeddings, which are applied only to the first `rotary_dim` dimensions of each head
+ #
+ # Apply partial positional embeddings in place
+ self.rotary_emb(
+ query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
+ )
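+        # With the defaults above (head_size = 2560 // 32 = 80, partial_rotary_factor = 0.5) rotary_dim
+        # is 40, so only dims [0:40] of each head are rotated; dims [40:80] pass through unchanged.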
+
+ # Reshape key and value and cache
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_scales=self.kv_scales,
+ kv_cache=kv_cache,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class PhiMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ act = config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
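+        # gelu variants bypass ACT2FN so the right approximation can be chosen explicitly:
+        # "gelu_fast" and "gelu_pytorch_tanh" use the tanh approximation, other gelu names use exact gelu.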
+
+ # llama weights are up_proj and down_proj and bias=False
+ self.up_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.fc1",
+ weights=weights,
+ bias=True,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.fc2",
+ weights=weights,
+ bias=True,
+ )
+
+ def forward(self, hidden_states):
+        # NOTE: Llama fuses the gate and up projections, so their output has to be `view`ed
+        # and split at the intermediate size; Phi has a single `fc1`, so we can skip the `view` operation.
+ return self.down_proj(self.act(self.up_proj(hidden_states)))
+
+
+class FlashPhiLayer(nn.Module):
+ def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
+ super().__init__()
+ prefix = f"{prefix}.layers.{layer_id}"
+ self.self_attn = FlashPhiAttention(
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+ self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+ self.input_layernorm = FastLayerNorm.load(
+ prefix=f"{prefix}.input_layernorm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+ self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ hidden_states, res = self.input_layernorm(hidden_states, residual)
+ # Self Attention
+ attn_output = self.self_attn(
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states = self.resid_dropout(attn_output).add(
+ self.resid_dropout(self.mlp(hidden_states))
+ )
+
+ return hidden_states, res
+
+
+class FlashPhiModel(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens", weights=weights
+ )
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=int(
+ config.partial_rotary_factor
+ * (config.hidden_size // config.num_attention_heads)
+ ),
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ FlashPhiLayer(
+ prefix,
+ layer_id,
+ config,
+ weights,
+ rotary_emb,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ self.norm = FastLayerNorm.load(
+ prefix="model.final_layernorm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, input_ids.shape[0]
+ )
+ hidden_states = self.embed_tokens(input_ids)
+
+ # Get rotary cos and sin for this forward
+ # Avoid to index in each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashPhiForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ if not prefix:
+ prefix = "model"
+ else:
+ prefix = f"{prefix}.model"
+
+ self.model = FlashPhiModel(prefix, config, weights)
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="lm_head",
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+
+ return self.lm_head(hidden_states)
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py
new file mode 100644
index 00000000000..c28f3aeeb55
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py
@@ -0,0 +1,253 @@
+# coding=utf-8
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PyTorch Phi-MoE model."""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+PHIMOE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "microsoft/Phi-3.5-MoE-instruct": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct/resolve/main/config.json",
+}
+
+
+class PhiMoEConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`PhiMoEModel`]. It is used to instantiate a Phi-MoE
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the
+ [microsoft/Phi-3.5-MoE-instruct](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32064):
+ Vocabulary size of the PhiMoE model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`PhiMoEModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 6400):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details check out [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
+ The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention
+            allows sequences of up to 4096*32 tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ The id of the padding token.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the "beginning-of-sequence" token.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the "end-of-sequence" token.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`dict`, *optional*):
+ The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
+ contain the following keys: `type`, `short_factor`, `long_factor`, `short_mscale`, `long_mscale` and
+            `original_max_position_embeddings`. The `type` must be `longrope`, the `short_mscale` and `long_mscale` must
+ be numbers, the `short_factor` and `long_factor` must be lists of numbers with the same length as half of
+ the attention head size and the `original_max_position_embeddings` must be an integer.
+ sliding_window (`int`, *optional*):
+ Sliding window attention window size. If not specified, will default to `262144`.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ num_experts_per_tok (`int`, *optional*, defaults to 2):
+            The number of experts to route per token; can also be interpreted as the `top-k` routing
+ parameter
+ num_local_experts (`int`, *optional*, defaults to 16):
+ Number of experts per Sparse MLP layer.
+ output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary loss.
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.0):
+ The aux loss factor for the total loss.
+ router_jitter_noise (`float`, *optional*, defaults to 0.01):
+ Amount of noise to add to the router.
+
+ ```python
+ >>> from transformers import PhiMoEModel, PhiMoEConfig
+
+    >>> # Initializing a Phi-3.5-MoE style configuration
+ >>> configuration = PhiMoEConfig.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
+
+ >>> # Initializing a model from the configuration
+ >>> model = PhiMoEModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "phimoe"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=32064,
+ hidden_size=4096,
+ intermediate_size=6400,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=4096 * 32,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=1e6,
+ rope_scaling=None,
+ sliding_window=None,
+ attention_dropout=0.0,
+ num_experts_per_tok=2,
+ num_local_experts=16,
+ output_router_logits=False,
+ router_aux_loss_coef=0.001,
+ router_jitter_noise=0.01,
+ input_jitter_noise=0.0,
+ attention_bias=False,
+ lm_head_bias=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.sliding_window = sliding_window
+ self.attention_bias = attention_bias
+ self.lm_head_bias = lm_head_bias
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_dropout = attention_dropout
+
+ self.num_experts_per_tok = num_experts_per_tok
+ self.num_local_experts = num_local_experts
+ self.output_router_logits = output_router_logits
+ self.router_aux_loss_coef = router_aux_loss_coef
+ self.router_jitter_noise = router_jitter_noise
+ self.input_jitter_noise = input_jitter_noise
+
+ self.rope_scaling = rope_scaling
+ self._rope_scaling_validation()
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 6:
+ raise ValueError(
+ "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor`, `long_factor`, "
+ f"`short_mscale`, `long_mscale` and `original_max_position_embeddings`, got {self.rope_scaling}"
+ )
+ rope_scaling_type = self.rope_scaling.get("type", None)
+ rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
+ rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
+ rope_scaling_short_mscale = self.rope_scaling.get("short_mscale", None)
+ rope_scaling_long_mscale = self.rope_scaling.get("long_mscale", None)
+ original_max_position_embeddings = self.rope_scaling.get(
+ "original_max_position_embeddings", None
+ )
+ if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
+ raise ValueError(
+ f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}"
+ )
+ if not (
+ isinstance(rope_scaling_short_factor, list)
+ and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
+ ):
+ raise ValueError(
+ f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
+ )
+ if (
+ not len(rope_scaling_short_factor)
+ == self.hidden_size // self.num_attention_heads // 2
+ ):
+ raise ValueError(
+ f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
+ )
+ if not (
+ isinstance(rope_scaling_long_factor, list)
+ and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
+ ):
+ raise ValueError(
+ f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
+ )
+ if (
+ not len(rope_scaling_long_factor)
+ == self.hidden_size // self.num_attention_heads // 2
+ ):
+ raise ValueError(
+ f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
+ )
+ if not isinstance(rope_scaling_short_mscale, (int, float)):
+ raise ValueError(
+ f"`rope_scaling`'s short_mscale field must be a number, got {rope_scaling_short_mscale}"
+ )
+ if not isinstance(rope_scaling_long_mscale, (int, float)):
+ raise ValueError(
+ f"`rope_scaling`'s long_mscale field must be a number, got {rope_scaling_long_mscale}"
+ )
+ if not isinstance(original_max_position_embeddings, int):
+ raise ValueError(
+ f"`rope_scaling`'s original_max_position_embeddings field must be an integer, got {original_max_position_embeddings}"
+ )
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
new file mode 100644
index 00000000000..de7641e357c
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -0,0 +1,391 @@
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from typing import Optional, List, Tuple
+
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
+ FastRMSNorm,
+)
+import habana_frameworks.torch as htorch
+
+
+def load_attention(config, prefix, weights):
+ if config.num_attention_heads != config.num_key_value_heads:
+ return _load_gqa(config, prefix, weights)
+ else:
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=True,
+ )
+
+
+def _load_gqa(config, prefix: str, weights):
+ assert config.hidden_size % config.num_attention_heads == 0
+ assert config.num_attention_heads % weights.process_group.size() == 0
+
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=True,
+ )
+
+
+class Qwen2Attention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ rotary_emb,
+ ):
+ super().__init__()
+ self.max_past = (
+ config.sliding_window
+ if config.use_sliding_window and config.sliding_window is not None
+ else -1
+ )
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_heads
+ self.rotary_emb = rotary_emb
+
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights)
+
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
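+        # Prefill attends over the freshly computed K/V with the variable-length kernel,
+        # while decode gathers past tokens from the paged KV cache.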
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ window_size_left=self.max_past,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ window_size_left=self.max_past,
+ )
+
+ return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class Qwen2MLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ act = config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+ # Fuse gate and up proj
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ def forward(self, hidden_states):
+ gate_up_states = self.gate_up_proj(hidden_states)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
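+        # SwiGLU-style MLP: slot 0 of the view is the gate projection, slot 1 the up projection;
+        # the activation (silu in standard Qwen2 configs) gates the up branch before the down projection.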
+ return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+
+
+class Qwen2Layer(nn.Module):
+ def __init__(self, prefix, layer_id, config, weights, rotary_emb):
+ super().__init__()
+ prefix = f"{prefix}.layers.{layer_id}"
+ self.self_attn = Qwen2Attention(
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+ self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+ self.input_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, residual = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ hidden_states = attn_output + residual
+
+ # faster post attention rms norm
+ hidden_states, residual = self.post_attention_layernorm(hidden_states)
+ mlp_output = self.mlp(hidden_states)
+ hidden_states = mlp_output + residual
+ return hidden_states
+
+
+class Qwen2Model(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ prefix = f"{prefix}.model" if prefix else "model"
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.hidden_size // config.num_attention_heads,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ Qwen2Layer(
+ prefix,
+ layer_id,
+ config,
+ weights,
+ rotary_emb,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = FastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, inputs_embeds.shape[0]
+ )
+ hidden_states = inputs_embeds
+
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+ position_ids,
+ )
+
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+ for i, layer in enumerate(self.layers):
+ hidden_states = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.norm(hidden_states)
+
+ return hidden_states
+
+
+class Qwen2ForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ self.model = Qwen2Model(prefix, config, weights)
+
+ if config.tie_word_embeddings:
+ suffix = "model.embed_tokens"
+ else:
+ suffix = "lm_head"
+
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix=f"{prefix}.{suffix}" if prefix else suffix,
+ weights=weights,
+ )
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
+ weights=weights,
+ )
+
+ self.max_past = config.sliding_window
+ self.max_past_tensor = (
+ torch.tensor(config.sliding_window, device=weights.device)
+ if self.max_past is not None
+ else None
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ hidden_states = self.model(
+ inputs_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.lm_head(hidden_states)
+ return logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
new file mode 100644
index 00000000000..8ffbde98aaa
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
@@ -0,0 +1,366 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, List
+
+import torch
+from torch import nn
+import habana_frameworks.torch as htorch
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers import (
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ SpeculativeHead,
+)
+
+
+from text_generation_server.layers.layernorm import (
+ FastRMSNorm,
+)
+from .flash_qwen2_modeling import Qwen2MLP
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+
+
+class Qwen3Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config, prefix, weights, layer_idx, rotary_emb):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(
+ config, "head_dim", config.hidden_size // config.num_attention_heads
+ )
+ self.num_key_value_groups = (
+ config.num_attention_heads // config.num_key_value_heads
+ )
+ self.num_heads = config.num_attention_heads
+ self.attention_dropout = config.attention_dropout
+ self.softmax_scale = self.head_dim**-0.5
+ self.rotary_emb = rotary_emb
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+ self.query_key_value = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ self.max_past = (
+ config.sliding_window if config.sliding_window is not None else -1
+ )
+
+ self.q_norm = FastRMSNorm.load(
+ prefix=f"{prefix}.q_norm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.k_norm = FastRMSNorm.load(
+ prefix=f"{prefix}.k_norm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.sliding_window = config.sliding_window
+ if not (
+ self.config.use_sliding_window
+ and getattr(self.config, "sliding_window", None) is not None
+ and self.layer_idx >= self.config.max_window_layers
+ ):
+ self.sliding_window = None
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+    ) -> torch.Tensor:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+ qkv = self.query_key_value(hidden_states)
+ query_states, key_states, value_states = qkv.split(
+ [
+ self.head_dim * self.num_heads,
+ self.head_dim * self.num_key_value_heads,
+ self.head_dim * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+
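+        # Qwen3 adds QK-Norm: a per-head RMSNorm on the query and key projections, applied on the
+        # (*, num_heads, head_dim) views before the rotary embedding.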
+ query_states, _ = self.q_norm(query_states.view(hidden_shape))
+ key_states, _ = self.k_norm(key_states.view(hidden_shape))
+ value_states = value_states.view(hidden_shape)
+ self.rotary_emb(query_states, key_states, cos, sin)
+
+ kv_cache.store(
+ key=key_states,
+ value=value_states,
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query_states,
+ key=key_states,
+ value=value_states,
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ window_size_left=self.max_past,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query_states,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ window_size_left=self.max_past,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ return self.o_proj(attn_output)
+
+
+class Qwen3DecoderLayer(nn.Module):
+ def __init__(self, config, prefix, weights, layer_idx: int, rotary_emb):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = Qwen3Attention(
+ config=config,
+ prefix=f"{prefix}.self_attn",
+ weights=weights,
+ layer_idx=layer_idx,
+ rotary_emb=rotary_emb,
+ )
+ self.mlp = Qwen2MLP(config=config, prefix=f"{prefix}.mlp", weights=weights)
+ self.input_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states, _ = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states = self.self_attn(
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states, _ = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class Qwen3Model(nn.Module):
+ def __init__(self, config, prefix: str, weights):
+ super().__init__()
+ self.config = config
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ head_dim = getattr(
+ config, "head_dim", config.hidden_size // config.num_attention_heads
+ )
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ Qwen3DecoderLayer(
+ config=config,
+ prefix=f"{prefix}.layers.{layer_idx}",
+ weights=weights,
+ layer_idx=layer_idx,
+ rotary_emb=rotary_emb,
+ )
+ for layer_idx in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = FastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, inputs_embeds.shape[0]
+ )
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+ position_ids,
+ )
+
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ for i, decoder_layer in enumerate(self.layers):
+ hidden_states = decoder_layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ return hidden_states
+
+
+class Qwen3ForCausalLM(nn.Module):
+
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ self.model = Qwen3Model(config=config, prefix="model", weights=weights)
+ self.vocab_size = config.vocab_size
+ if config.tie_word_embeddings:
+ suffix = "model.embed_tokens"
+ else:
+ suffix = "lm_head"
+
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix=f"{prefix}.{suffix}" if prefix else suffix,
+ weights=weights,
+ )
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.embed_tokens(input_ids)
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ hidden_states = self.model(
+ inputs_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.lm_head(hidden_states)
+
+ return logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py
new file mode 100644
index 00000000000..bd48519c23d
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py
@@ -0,0 +1,554 @@
+# coding=utf-8
+# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from text_generation_server.layers.attention import (
+ attention,
+ paged_attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
+from text_generation_server.layers import (
+ TensorParallelEmbedding,
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+ SpeculativeHead,
+ FastLinear,
+)
+
+from text_generation_server.layers.layernorm import (
+ FastRMSNorm,
+)
+from .flash_qwen2_modeling import Qwen2MLP
+from .flash_qwen3_modeling import Qwen3Attention
+from transformers.activations import ACT2FN
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
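+
+# NOTE: `rotate_half` and `apply_rotary_pos_emb` are standard RoPE helpers that are not called in this
+# file; the attention modules apply the fused in-place `rotary_emb` instead (see the commented-out
+# call in `Qwen3MoeAttention.forward`).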
+
+
+class Qwen3MoeAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config, prefix, weights, layer_idx, rotary_emb):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(
+ config, "head_dim", config.hidden_size // config.num_attention_heads
+ )
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = (
+ config.num_attention_heads // config.num_key_value_heads
+ )
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = FastLinear.load(
+ config, f"{prefix}.q_proj", weights, bias=config.attention_bias
+ )
+
+ self.k_proj = FastLinear.load(
+ config, f"{prefix}.k_proj", weights, bias=config.attention_bias
+ )
+ self.v_proj = FastLinear.load(
+ config, f"{prefix}.v_proj", weights, bias=config.attention_bias
+ )
+ self.o_proj = FastLinear.load(
+ config, f"{prefix}.o_proj", weights, bias=config.attention_bias
+ )
+ self.rotary_emb = rotary_emb
+
+ self.q_norm = FastRMSNorm.load(
+ prefix=f"{prefix}.q_norm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ self.k_norm = FastRMSNorm.load(
+ prefix=f"{prefix}.k_norm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ self.max_past = (
+ config.sliding_window if config.sliding_window is not None else -1
+ )
+
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_key_value_groups)
+
+ self.sliding_window = config.sliding_window
+ if not (
+ self.config.use_sliding_window
+ and getattr(self.config, "sliding_window", None) is not None
+ and self.layer_idx >= self.config.max_window_layers
+ ):
+ self.sliding_window = None
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+    ) -> torch.Tensor:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states, _ = self.q_norm(self.q_proj(hidden_states).view(hidden_shape))
+ key_states, _ = self.k_norm(self.k_proj(hidden_states).view(hidden_shape))
+ value_states = self.v_proj(hidden_states).view(hidden_shape)
+
+ self.rotary_emb(query_states, key_states, cos, sin)
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ kv_cache.store(
+ key=key_states,
+ value=value_states,
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query_states,
+ key=key_states,
+ value=value_states,
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.scaling,
+ window_size_left=self.max_past,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query_states,
+ kv_cache,
+ self.kv_head_mapping,
+ self.scaling,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ window_size_left=self.max_past,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output
+
+
+class Qwen3MoE(nn.Module):
+ def __init__(self, prefix, config, moe_layer_cls: Type[MoELayer], weights):
+ super().__init__()
+
+ # gating
+ self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
+ self.moe = moe_layer_cls(
+ n_expert_group=None,
+ n_experts=config.num_experts,
+ prefix=f"{prefix}.experts",
+ renormalize=True,
+ topk=config.num_experts_per_tok,
+ topk_group=None,
+ weights=weights,
+ )
+ # gate_proj_name="w1",
+ # up_proj_name="w3",
+ # down_proj_name="w2",
+
+ assert isinstance(self.moe, MoELayer)
+
+ self.process_group = weights.process_group
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ router_logits = self.gate(x)
+ out = self.moe(x, gating_output=router_logits)
+
+ # Reduce sum
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(out, group=self.process_group)
+
+ return out.view(*x.shape)
+
+
+class Qwen3MoeMLP(nn.Module):
+ def __init__(self, prefix, config, weights, intermediate_size=None):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = (
+ intermediate_size
+ if intermediate_size is not None
+ else config.intermediate_size
+ )
+ # Fuse gate and up proj
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ gate_up_states = self.gate_up_proj(x)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+        return self.down_proj(self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1])
+
+
+class Qwen3MoeSparseMoeBlock(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.num_experts = config.num_experts
+ self.top_k = config.num_experts_per_tok
+ self.norm_topk_prob = config.norm_topk_prob
+
+ # gating
+ # self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
+ self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+ self.experts = nn.ModuleList(
+ [
+ Qwen3MoeMLP(
+ prefix=f"{prefix}.experts.{i}",
+ config=config,
+ weights=weights,
+ intermediate_size=config.moe_intermediate_size,
+ )
+ for i in range(self.num_experts)
+ ]
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """ """
+ input_shape = hidden_states.shape
+ _, hidden_dim = hidden_states.shape
+ # hidden_states = hidden_states.view(-1, hidden_dim)
+ # router_logits: (batch * sequence_length, n_experts)
+ router_logits = self.gate(hidden_states)
+
+ routing_weights = F.softmax(router_logits, dim=1, dtype=hidden_states.dtype)
+ routing_weights, selected_experts = torch.topk(
+ routing_weights, self.top_k, dim=-1
+ )
+ if self.norm_topk_prob: # only diff with mixtral sparse moe block!
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ # we cast back to the input dtype
+ routing_weights = routing_weights.to(hidden_states.dtype)
+
+ final_hidden_states = torch.zeros(
+ (input_shape), dtype=hidden_states.dtype, device=hidden_states.device
+ )
+
+ # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be solicited
+ expert_mask = torch.nn.functional.one_hot(
+ selected_experts, num_classes=self.num_experts
+ ).permute(2, 1, 0)
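+        # Shape walk-through: `selected_experts` is (num_tokens, top_k), one_hot gives
+        # (num_tokens, top_k, num_experts) and the permute yields (num_experts, top_k, num_tokens),
+        # so `torch.where(expert_mask[e])` returns the (top_k slot, token index) pairs routed to expert e.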
+ # Loop over all available experts in the model and perform the computation on each expert
+ for expert_idx in range(self.num_experts):
+ expert_layer = self.experts[expert_idx]
+ idx, top_x = torch.where(expert_mask[expert_idx])
+
+ # Index the correct hidden states and compute the expert hidden state for
+ # the current expert. We need to make sure to multiply the output hidden
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+ current_hidden_states = (
+ expert_layer(current_state) * routing_weights[top_x, idx, None]
+ )
+
+            # However `index_add_` only supports torch tensors for indexing, so we'll use
+ # the `top_x` tensor here.
+ final_hidden_states.index_add_(
+ 0, top_x, current_hidden_states.to(hidden_states.dtype)
+ )
+ final_hidden_states = final_hidden_states.reshape(input_shape)
+ return final_hidden_states
+
+
+class Qwen3MoeDecoderLayer(nn.Module):
+ def __init__(self, config, prefix, weights, layer_idx: int, rotary_emb):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ if config.num_key_value_heads // weights.process_group.size() > 0:
+ self.self_attn = Qwen3Attention(
+ config,
+ prefix=f"{prefix}.self_attn",
+ weights=weights,
+ layer_idx=layer_idx,
+ rotary_emb=rotary_emb,
+ )
+ else:
+ self.self_attn = Qwen3MoeAttention(
+ config,
+ prefix=f"{prefix}.self_attn",
+ weights=weights,
+ layer_idx=layer_idx,
+ rotary_emb=rotary_emb,
+ )
+
+ moe_layer_cls = (
+ SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
+ )
+
+ if (layer_idx not in config.mlp_only_layers) and (
+ config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
+ ):
+ self.mlp = Qwen3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
+ # self.mlp = Qwen3MoeSparseMoeBlock(f"{prefix}.mlp", config, weights)
+
+ else:
+ self.mlp = Qwen2MLP(config=config, prefix=f"{prefix}.mlp", weights=weights)
+
+ self.input_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ) -> torch.Tensor:
+ if residual is None:
+ residual = hidden_states
+
+ hidden_states, _ = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states = self.self_attn(
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states, _ = self.post_attention_layernorm(hidden_states)
+
+ hidden_states = self.mlp(hidden_states)
+
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class Qwen3MoeModel(nn.Module):
+ def __init__(self, config, prefix: str, weights):
+ super().__init__()
+ self.config = config
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ head_dim = getattr(
+ config, "head_dim", config.hidden_size // config.num_attention_heads
+ )
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ Qwen3MoeDecoderLayer(
+ config=config,
+ prefix=f"{prefix}.layers.{layer_idx}",
+ weights=weights,
+ layer_idx=layer_idx,
+ rotary_emb=rotary_emb,
+ )
+ for layer_idx in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = FastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, inputs_embeds.shape[0]
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+ position_ids,
+ )
+
+ residual = None
+ for i, decoder_layer in enumerate(self.layers):
+ hidden_states = decoder_layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.norm(hidden_states)
+
+        # Return the normalized hidden states of the last decoder layer
+ return hidden_states
+
+
+class Qwen3MoeForCausalLM(nn.Module):
+
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ self.model = Qwen3MoeModel(config=config, prefix="model", weights=weights)
+ self.vocab_size = config.vocab_size
+ if config.tie_word_embeddings:
+ suffix = "model.embed_tokens"
+ else:
+ suffix = "lm_head"
+
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix=f"{prefix}.{suffix}" if prefix else suffix,
+ weights=weights,
+ )
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+
+ inputs_embeds = self.embed_tokens(input_ids)
+        # The decoder returns only the final hidden states
+ hidden_states = self.model(
+ inputs_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+        # Only compute logits for the positions that are actually needed (typically the last token of each sequence)
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.lm_head(hidden_states)
+
+ return logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
new file mode 100644
index 00000000000..bd8397f18d8
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -0,0 +1,704 @@
+from typing import List, Optional, Tuple
+
+import torch
+import torch.distributed
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_utils import PreTrainedModel
+from text_generation_server.layers import (
+ SpeculativeHead,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ get_linear,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.layernorm import FastLayerNorm
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.attention import (
+ attention,
+ paged_attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+import habana_frameworks.torch as htorch
+
+
+def load_row(config, prefix: str, weights, bias: bool):
+ weight = weights.get_weights_row(prefix)
+
+ if bias and weights.process_group.rank() == 0:
+        # Only load the bias on rank 0: row-parallel partial outputs are summed, so the bias must be added exactly once
+ bias = weights.get_tensor(f"{prefix}.bias")
+ else:
+ bias = None
+
+ linear = get_linear(weight, bias)
+ if config.parallel_attn:
+ return linear
+ else:
+ return TensorParallelRowLinear(linear, process_group=weights.process_group)
+
+
+class RWConfig(PretrainedConfig):
+ attribute_map = {
+ "num_hidden_layers": "n_layer",
+ "num_attention_heads": "n_head",
+ "num_key_value_heads": "n_head_kv",
+ }
+
+ def __init__(
+ self,
+ model_type="RefinedWeb",
+ vocab_size=250880,
+ hidden_size=64,
+ num_hidden_layers=None,
+ num_attention_heads=None,
+ num_ln_in_prallel_attention=None,
+ layer_norm_epsilon=1e-5,
+ initializer_range=0.02,
+ use_cache=True,
+ bos_token_id=1,
+ eos_token_id=2,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ num_kv_heads=None,
+ multi_query=False,
+ alibi=False,
+ new_decoder_architecture=None,
+ bias=False,
+ parallel_attn=False,
+ rope_theta=10_000.0,
+ **kwargs,
+ ):
+ if alibi:
+ raise NotImplementedError(
+ "alibi is not supported by this version of the model"
+ )
+
+ self.model_type = model_type
+ self.alibi = False
+ self.rotary = True
+ self.rope_theta = rope_theta
+ self.max_position_embeddings = 2048
+
+ self.vocab_size = vocab_size
+ # Backward compatibility with n_embed kwarg
+ n_embed = kwargs.pop("n_embed", None)
+ self.hidden_size = hidden_size if n_embed is None else n_embed
+ self.n_layer = (
+ num_hidden_layers
+ if num_hidden_layers is not None
+ else kwargs.pop("n_layer", 2)
+ )
+ self.n_head = (
+ num_attention_heads
+ if num_attention_heads is not None
+ else kwargs.pop("n_head", 8)
+ )
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.num_ln_in_parallel_attn = num_ln_in_prallel_attention
+ self.initializer_range = initializer_range
+ self.use_cache = use_cache
+ self.hidden_dropout = hidden_dropout
+ self.attention_dropout = attention_dropout
+ self.bias = bias
+ self.parallel_attn = parallel_attn
+
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+
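+        # Resolve the number of KV heads with the following precedence: the
+        # explicit `num_kv_heads` argument, then the legacy `n_head_kv` kwarg,
+        # then 1 for multi-query attention, and finally the full head count.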
+ if num_kv_heads is not None:
+ self.n_head_kv = num_kv_heads
+ else:
+ old_n_head_kv = kwargs.pop("n_head_kv", None)
+ if old_n_head_kv is not None:
+ self.n_head_kv = old_n_head_kv
+ else:
+ self.n_head_kv = 1 if multi_query else self.n_head
+
+ if new_decoder_architecture is not None:
+ self.new_decoder_architecture = new_decoder_architecture
+ elif model_type == "RefinedWeb":
+ self.new_decoder_architecture = True
+ else:
+ self.new_decoder_architecture = False
+
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+
+class FlashRWAttention(torch.nn.Module):
+ def __init__(
+ self,
+ config,
+ prefix: str,
+ weights,
+ rotary_emb,
+ ):
+ super().__init__()
+ self.num_heads = config.n_head
+ self.num_heads_kv = config.n_head_kv
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_heads
+ self.rope_theta = config.rope_theta
+ self.rotary_emb = rotary_emb
+
+ self.softmax_scale = self.head_size ** (-0.5)
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()})"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+
+ self.query_key_value = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.query_key_value",
+ weights=weights,
+ bias=config.bias,
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+ self.dense = load_row(
+ config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
+ )
+
+ if self.num_heads_kv == 1:
+ self.kv_head_mapping = torch.zeros(
+ self.num_heads, dtype=torch.int32, device=weights.device
+ )
+ else:
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_heads, dtype=torch.int32, device=weights.device
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states)
+
+ # Split query from key_value
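+        # The fused projection lays its output out as
+        # [num_heads * head_size | 2 * num_heads_kv * head_size]: all query
+        # heads first, followed by the key and value heads.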
+ query, kv = qkv.split(
+ [self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv],
+ dim=1,
+ )
+
+ # Prepare query and key_value for indexing
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_heads_kv, self.head_size)
+
+ # Inplace rotary
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class FlashRWLargeAttention(torch.nn.Module):
+ def __init__(
+ self,
+ config,
+ prefix: str,
+ weights,
+ rotary_emb,
+ ):
+ super().__init__()
+
+ hidden_size = config.hidden_size
+ num_heads = config.n_head
+ # num_heads_kv = config.n_head_kv
+ num_groups = config.n_head_kv
+
+ self.hidden_size = hidden_size
+ self.head_size = hidden_size // num_heads
+ self.num_groups = num_groups
+ self.rope_theta = config.rope_theta
+ self.rotary_emb = rotary_emb
+
+ self.softmax_scale = self.head_size ** (-0.5)
+
+ # self.num_groups = num_heads // (num_heads_kv * 2)
+ self.num_heads = num_heads // self.num_groups
+ # self.num_heads_kv = num_heads_kv // self.num_groups
+ process_group = weights.process_group
+
+ if process_group.size() > self.num_groups:
+ raise NotImplementedError(
+ "Tensor Parallelism is not implemented for world_size > n groups"
+ )
+ if self.num_groups % process_group.size() != 0:
+ raise NotImplementedError(
+ f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
+ )
+
+ self.num_groups = self.num_groups // process_group.size()
+
+ self.query_key_value = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.query_key_value",
+ weights=weights,
+ bias=config.bias,
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+ self.dense = load_row(
+ config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
+ )
+
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_groups, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_heads)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states)
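+        # Grouped layout: each of the `num_groups` groups holds `num_heads`
+        # query heads followed by one key head and one value head, hence the
+        # (num_groups, num_heads + 2, head_size) view below.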
+ qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
+
+ # Split on group dimension
+ query, kv = qkv.split(
+ [self.num_heads, 2],
+ dim=2,
+ )
+ # Merge groups and heads
+ query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
+
+ # Inplace rotary
+ self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, :, 0].contiguous(),
+ value=kv[:, :, 1].contiguous(),
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # flash attention
+ attn_output = attention(
+ query=query,
+ key=kv[:, :, 0],
+ value=kv[:, :, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.dense(
+ attn_output.view(-1, self.num_groups * self.num_heads * self.head_size)
+ )
+
+
+class FlashMLP(nn.Module):
+ def __init__(self, config, prefix: str, weights):
+ super().__init__()
+ self.act = torch.nn.functional.gelu
+
+ self.dense_h_to_4h = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias
+ )
+ self.dense_4h_to_h = load_row(
+ config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense_h_to_4h(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.dense_4h_to_h(hidden_states)
+ return hidden_states
+
+
+class FlashRWLayer(nn.Module):
+ def __init__(
+ self,
+ layer_id,
+ prefix: str,
+ config,
+ weights,
+ rotary_emb,
+ ):
+ super().__init__()
+
+ parallel_attn = config.parallel_attn
+ self.parallel_attn = parallel_attn
+
+ prefix = f"{prefix}.h.{layer_id}"
+
+ # NOTE: Falcon 180B uses the ln_attn prefix
+ ln_prefix = "input_layernorm"
+ if config.num_hidden_layers == 80:
+ ln_prefix = "ln_attn"
+
+ self.input_layernorm = FastLayerNorm.load(
+ prefix=f"{prefix}.{ln_prefix}",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+ self.self_attention = FlashRWAttention(
+ config,
+ prefix=f"{prefix}.self_attention",
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+ self.post_attention_layernorm = (
+ FastLayerNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+ if not parallel_attn
+ else None
+ )
+
+ self.mlp = FlashMLP(
+ config,
+ prefix=f"{prefix}.mlp",
+ weights=weights,
+ )
+
+ self.process_group = weights.process_group
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ if self.parallel_attn:
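+            # Parallel-attention path: attention and MLP both consume the same
+            # layernorm output, their results are summed, and a single
+            # all-reduce covers both tensor-parallel partial sums.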
+ ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ attn_output = self.self_attention(
+ ln_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ mlp_output = self.mlp(ln_hidden_states)
+ intermediate = mlp_output + attn_output
+
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(intermediate, group=self.process_group)
+
+ return intermediate, residual
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ hidden_states = self.self_attention(
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ if self.post_attention_layernorm is not None:
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual
+ )
+
+ mlp_output = self.mlp(hidden_states)
+
+ return mlp_output, residual
+
+
+class FlashRWLayerNorm(nn.Module):
+ def __init__(self, config, prefix: str, weights):
+ super().__init__()
+        # Falcon2 exposes the number of layer norms in the config; when it is
+        # not provided, default to 1.
+ self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
+
+ # Falcon 180B uses the ln_attn prefix and has 2 layer norms
+ if config.num_hidden_layers == 80:
+ self.num_ln = 2
+
+ if self.num_ln == 1:
+ self.input_ln = FastLayerNorm.load(
+ prefix=f"{prefix}.input_layernorm",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+ elif self.num_ln == 2:
+ self.ln_attn = FastLayerNorm.load(
+ prefix=f"{prefix}.ln_attn",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+ self.ln_mlp = FastLayerNorm.load(
+ prefix=f"{prefix}.ln_mlp",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+ else:
+ raise ValueError("Number of layer norms can either be 1 or 2.")
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ ):
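+        # Returns (attention norm input, MLP norm input, residual); with a
+        # single layer norm the same normalized tensor is reused for both.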
+ if self.num_ln == 1:
+ ln_hidden_states, residual = self.input_ln(hidden_states, residual)
+ return ln_hidden_states, ln_hidden_states, residual
+ elif self.num_ln == 2:
+ ln_attn, residual = self.ln_attn(hidden_states, residual)
+ ln_mlp, _ = self.ln_mlp(residual)
+ return ln_attn, ln_mlp, residual
+
+
+class FlashRWLargeLayer(nn.Module):
+ def __init__(self, layer_id, prefix: str, config, weights, rotary_emb):
+ super().__init__()
+ prefix = f"{prefix}.h.{layer_id}"
+
+ self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
+
+ self.self_attention = FlashRWLargeAttention(
+ config,
+ prefix=f"{prefix}.self_attention",
+ weights=weights,
+ rotary_emb=rotary_emb,
+ )
+ assert config.parallel_attn, "This version doesn't support non parallel_attn"
+
+ self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
+
+ self.process_group = weights.process_group
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ # Layer norm.
+ ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual)
+
+ # Self attention.
+ attn_output = self.self_attention(
+ ln_attn,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ # MLP.
+ mlp_output = self.mlp(ln_mlp)
+
+ intermediate = attn_output + mlp_output
+
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(intermediate, group=self.process_group)
+
+ return intermediate, residual
+
+
+class FlashRWPreTrainedModel(PreTrainedModel):
+ config_class = RWConfig
+
+
+class FlashRWModel(FlashRWPreTrainedModel):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__(config)
+ self.config = config
+
+ self.word_embeddings = TensorParallelEmbedding(
+ prefix=f"{prefix}.word_embeddings", weights=weights
+ )
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.hidden_size // config.n_head,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
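+        # New-decoder-architecture checkpoints use the grouped ("large")
+        # attention layers, older ones the original layers; the per-layer
+        # KV-cache width follows (number of groups vs. number of KV heads).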
+ if config.new_decoder_architecture:
+ self.h = nn.ModuleList(
+ [
+ FlashRWLargeLayer(layer_id, prefix, config, weights, rotary_emb)
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.cache_size = self.h[0].self_attention.num_groups
+ else:
+ self.h = nn.ModuleList(
+ [
+ FlashRWLayer(layer_id, prefix, config, weights, rotary_emb)
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.cache_size = self.h[0].self_attention.num_heads_kv
+
+ self.ln_f = FastLayerNorm.load(
+ prefix=f"{prefix}.ln_f",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+
+ self.head_size = self.h[0].self_attention.head_size
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, input_ids.shape[0]
+ )
+ hidden_states = self.word_embeddings(input_ids)
+
+        # Compute rotary cos/sin once for this forward pass so each layer does
+        # not have to recompute them.
+ cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+ for i, layer in enumerate(self.h):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.ln_f(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashRWForCausalLM(FlashRWPreTrainedModel):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__(config)
+
+ if not prefix:
+ prefix = "transformer"
+ else:
+ prefix = f"{prefix}.transformer"
+
+ self.transformer = FlashRWModel(prefix, config, weights)
+
+ self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ hidden_states = self.transformer(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.lm_head(hidden_states)
+ return logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
new file mode 100644
index 00000000000..b6a0d32aab9
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -0,0 +1,511 @@
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from typing import Optional, List, Tuple
+
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ SpeculativeHead,
+ TensorParallelEmbedding,
+ get_linear,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.gptq import GPTQWeightsLoader
+from text_generation_server.layers.layernorm import (
+ FastLayerNorm,
+)
+import habana_frameworks.torch as htorch
+
+
+def load_multi_mqa(
+ config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
+ if config.quantize == "gptq":
+ return _load_multi_mqa_gptq(
+ config, prefix, weights, bias, head_size, num_heads, hidden_size
+ )
+ else:
+ return _load_multi_mqa(
+ config, prefix, weights, bias, head_size, num_heads, hidden_size
+ )
+
+
+def _load_multi_mqa_gptq(
+ config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
+ from text_generation_server.layers.gptq import GPTQWeight
+
+ if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
+ world_size = weights.process_group.size()
+ rank = weights.process_group.rank()
+
+ slice_ = weights._get_slice(f"{prefix}.c_attn.qweight")
+ shape = slice_.get_shape()
+ block_size = (shape[1] - 2 * head_size) // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ assert (shape[1] - 2 * head_size) % world_size == 0
+ q_tensor = slice_[:, start:stop]
+ kv_tensor = slice_[:, -2 * head_size :]
+ qweight = torch.cat([q_tensor, kv_tensor], dim=1)
+ qweight = qweight.to(device=weights.device)
+
+ slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
+ shape = slice_.get_shape()
+ block_size = (shape[1] - 2 * head_size) // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ assert (shape[1] - 2 * head_size) % world_size == 0
+ q_tensor = slice_[:, start:stop]
+ kv_tensor = slice_[:, -2 * head_size :]
+ scales = torch.cat([q_tensor, kv_tensor], dim=1)
+ scales = scales.to(device=weights.device)
+
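+        # qzeros pack eight 4-bit zero-points per int32 (the hardcoded
+        # `4 // 32` assumes 4-bit quantization), so the shared KV block only
+        # spans `2 * head_size * 4 // 32` packed columns.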
+ slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
+ shape = slice_.get_shape()
+ block_size = (shape[1] - (2 * head_size) * 4 // 32) // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ assert 2 * head_size % (32 // 4) == 0
+ q_tensor = slice_[:, start:stop]
+ kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
+ qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
+ qzeros = qzeros.to(device=weights.device)
+
+ loader = weights.weights_loader
+ assert isinstance(loader, GPTQWeightsLoader)
+ loader._get_gptq_params(weights)
+ if loader.quant_method == "gptq":
+ g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
+ g_idx = g_idx.to(device=weights.device)
+ elif loader.quant_method == "awq":
+ g_idx = None
+ from text_generation_server.layers.awq.conversion_utils import (
+ fast_awq_to_gptq,
+ )
+
+ qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+
+ from text_generation_server.layers.gptq import HAS_EXLLAMA
+
+ weight = GPTQWeight(
+ qweight=qweight,
+ qzeros=qzeros,
+ scales=scales,
+ g_idx=g_idx,
+ bits=loader.bits,
+ groupsize=loader.groupsize,
+ use_awq_kernel=loader.quantize == "awq",
+ use_exllama=HAS_EXLLAMA,
+ )
+
+ if bias:
+ slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
+ shape = slice_.get_shape()
+ block_size = (shape[0] - 2 * head_size) // world_size
+ assert (shape[0] - 2 * head_size) % world_size == 0
+            start = rank * block_size
+            stop = (rank + 1) * block_size
+            q_tensor = slice_[start:stop]
+ kv_tensor = slice_[-2 * head_size :]
+ bias = torch.cat([q_tensor, kv_tensor], dim=0)
+ bias = bias.to(device=weights.device)
+
+ return TensorParallelColumnLinear(get_linear(weight, bias))
+ else:
+ raise NotImplementedError("Gptq loading with santacoder is not implemented")
+
+
+def _load_multi_mqa(
+ config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
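+    # Multi-query attention: the fused c_attn tensor stores all query heads
+    # followed by a single shared key/value head (2 * head_size rows/columns).
+    # Only the query block is sharded across ranks; the KV block is replicated.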
+ if any("c_attn" in k for k in weights.routing.keys()):
+ slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
+ shape = slice_.get_shape()
+ world_size = weights.process_group.size()
+ rank = weights.process_group.rank()
+ if config.transpose:
+ block_size = (shape[1] - 2 * head_size) // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ assert (shape[1] - 2 * head_size) % world_size == 0
+ q_tensor = slice_[:, start:stop]
+ kv_tensor = slice_[:, -2 * head_size :]
+ weight = torch.cat([q_tensor, kv_tensor], dim=1).T
+ else:
+ block_size = (shape[0] - 2 * head_size) // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ assert (shape[0] - 2 * head_size) % world_size == 0
+ q_tensor = slice_[start:stop]
+ kv_tensor = slice_[-2 * head_size :]
+ weight = torch.cat([q_tensor, kv_tensor], dim=0)
+ if bias:
+ slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
+ shape = slice_.get_shape()
+ block_size = (shape[0] - 2 * head_size) // world_size
+ assert (shape[0] - 2 * head_size) % world_size == 0
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ q_tensor = slice_[start:stop]
+ kv_tensor = slice_[-2 * head_size :]
+ bias = torch.cat([q_tensor, kv_tensor], dim=0)
+ else:
+ if config.transpose:
+ w = [
+ weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T,
+ weights.get_tensor(f"{prefix}.kv_attn.weight").T,
+ ]
+ weight = torch.cat(w, dim=0)
+ else:
+ w = [
+ weights.get_sharded(f"{prefix}.q_attn.weight", dim=0),
+ weights.get_tensor(f"{prefix}.kv_attn.weight"),
+ ]
+ weight = torch.cat(w, dim=1)
+
+ if bias:
+ b = [
+ weights.get_sharded(f"{prefix}.q_attn.bias", dim=0),
+ weights.get_tensor(f"{prefix}.kv_attn.bias"),
+ ]
+ bias = torch.cat(b, dim=0)
+ else:
+ bias = None
+
+ weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+ assert list(weight.shape) == [
+ (num_heads + 2) * head_size,
+ hidden_size,
+ ], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"
+ if bias is not None:
+ bias = bias.to(dtype=weights.dtype).to(device=weights.device)
+ assert list(bias.shape) == [
+ (num_heads + 2) * head_size
+        ], f"{bias.shape} != {[(num_heads + 2) * head_size]}"
+ return TensorParallelColumnLinear(get_linear(weight, bias))
+
+
+def load_col(config, prefix: str, weights, bias: bool):
+ if config.transpose:
+ weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
+ else:
+ weight = weights.get_multi_weights_col([prefix], dim=0)
+
+ if bias:
+ bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+ else:
+ bias = None
+ return TensorParallelColumnLinear(get_linear(weight, bias))
+
+
+def load_row(config, prefix: str, weights, bias: bool):
+ if config.transpose:
+ weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
+ else:
+ weight = weights.get_weights_row(prefix)
+
+ if bias and weights.process_group.rank() == 0:
+        # Only load the bias on rank 0: row-parallel partial outputs are summed, so the bias must be added exactly once
+ bias = weights.get_tensor(f"{prefix}.bias")
+ else:
+ bias = None
+ return TensorParallelRowLinear(
+ get_linear(weight, bias), process_group=weights.process_group
+ )
+
+
+class FlashMQAttention(torch.nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ num_heads = config.num_attention_heads
+ hidden_size = config.hidden_size
+
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.head_size = hidden_size // num_heads
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()})"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+
+ self.softmax_scale = self.head_size ** (-0.5)
+
+ self.c_attn = load_multi_mqa(
+ config,
+ prefix=prefix,
+ weights=weights,
+ bias=True,
+ head_size=self.head_size,
+ hidden_size=hidden_size,
+ num_heads=self.num_heads,
+ )
+ self.c_proj = load_row(
+ config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+ self.kv_head_mapping = torch.zeros(
+ self.num_heads, dtype=torch.int32, device=weights.device
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.c_attn(hidden_states)
+
+ # Split query from key_value
+ query, key_value = qkv.split(
+ [self.head_size * self.num_heads, 2 * self.head_size], dim=1
+ )
+
+ # Prepare query and key_value for indexing
+ query = query.view(-1, self.num_heads, self.head_size)
+ key_value = key_value.view(-1, 2, 1, self.head_size)
+
+ kv_cache.store(
+ key=key_value[:, 0],
+ value=key_value[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=key_value[:, 0],
+ value=key_value[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class MLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ act = config.activation_function
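+        # GELU variants are handled explicitly: "gelu_fast" and
+        # "gelu_pytorch_tanh" use the tanh approximation, plain "gelu" the
+        # exact form; every other activation is resolved through ACT2FN.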
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+
+ self.c_fc = load_col(
+ config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
+ )
+ self.c_proj = load_row(
+ config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+ return hidden_states
+
+
+class Block(nn.Module):
+ def __init__(self, prefix: str, layer_id, config, weights):
+ super().__init__()
+ prefix = f"{prefix}.h.{layer_id}"
+ self.ln_1 = FastLayerNorm.load(
+ prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
+ )
+ self.ln_2 = FastLayerNorm.load(
+ prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
+ )
+ self.self_attn = FlashMQAttention(
+ prefix=f"{prefix}.attn",
+ config=config,
+ weights=weights,
+ )
+ self.mlp = MLP(
+ prefix=f"{prefix}.mlp",
+ config=config,
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ hidden_states, residual = self.ln_1(hidden_states, residual)
+ hidden_states = self.self_attn(
+ hidden_states,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, residual = self.ln_2(hidden_states, residual)
+
+ mlp_output = self.mlp(hidden_states)
+
+ return mlp_output, residual
+
+
+class FlashSantacoderModel(nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ self.config = config
+
+ self.process_group = weights.process_group
+ self.wte = TensorParallelEmbedding(
+ prefix=f"{prefix}.wte",
+ weights=weights,
+ reduce=False,
+ )
+ self.wpe = TensorParallelEmbedding(
+ prefix=f"{prefix}.wpe",
+ weights=weights,
+ reduce=False,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ Block(
+ prefix,
+ layer_id,
+ config,
+ weights,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.ln_f = FastLayerNorm.load(
+ prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
+ )
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, input_ids.shape[0]
+ )
+ hidden_states = self.wte(input_ids) + self.wpe(position_ids)
+
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(hidden_states, group=self.process_group)
+
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.ln_f(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashSantacoderForCausalLM(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+
+ if not prefix:
+ prefix = "transformer"
+ else:
+ prefix = f"{prefix}.transformer"
+
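+        # GPT2-style checkpoints store Conv1D weights (transposed relative to
+        # nn.Linear); the loaders above transpose them back when this is set.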
+ config.transpose = config.architectures[0].startswith("GPT2")
+ self.model = FlashSantacoderModel(prefix, config, weights)
+ self.lm_head = SpeculativeHead.load(
+ config, prefix=f"{prefix}.wte", weights=weights
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.lm_head(hidden_states)
+ return logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
new file mode 100644
index 00000000000..b36ead7d7ab
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -0,0 +1,612 @@
+# coding=utf-8
+# Copyright 2024 Starcoder2 AI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from typing import Optional, List, Tuple
+
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ set_block_mapping,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelMultiAdapterLinear,
+ TensorParallelAdapterRowLinear,
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.layernorm import (
+ FastLayerNorm,
+ FastRMSNorm,
+)
+from text_generation_server.layers.rotary import (
+ PositionRotaryEmbedding,
+)
+from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch
+
+
+class Starcoder2Config(PretrainedConfig):
+ model_type = "starcoder2"
+
+ def __init__(
+ self,
+ vocab_size=49152,
+ hidden_size=3072,
+ intermediate_size=12288,
+ num_hidden_layers=30,
+ num_attention_heads=24,
+ num_key_value_heads=2,
+ mlp_type="default",
+ hidden_act="gelu_pytorch_tanh",
+ max_position_embeddings=4096,
+ initializer_range=0.018042,
+ norm_type="layer_norm",
+ norm_epsilon=1e-5,
+ use_cache=True,
+ bos_token_id=50256,
+ eos_token_id=50256,
+ rope_theta=10000.0,
+ sliding_window=None,
+ attention_dropout=0.0,
+ residual_dropout=0.0,
+ embedding_dropout=0.0,
+ use_bias: bool = True,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.sliding_window = sliding_window
+ self.use_bias = use_bias
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.mlp_type = mlp_type
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.norm_type = norm_type
+ self.norm_epsilon = norm_epsilon
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_dropout = attention_dropout
+ self.residual_dropout = residual_dropout
+ self.embedding_dropout = embedding_dropout
+
+ super().__init__(
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+
+
+def load_attention(config, prefix, weights, layer_id):
+ prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
+ head_size = config.hidden_size // config.num_attention_heads
+ sizes = [
+ head_size * config.num_attention_heads,
+ head_size * config.num_key_value_heads,
+ head_size * config.num_key_value_heads,
+ ]
+ if config.num_attention_heads != config.num_key_value_heads:
+ base_layer = _load_gqa(config, prefix, weights)
+ else:
+ base_layer = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=prefixes,
+ dim=0,
+ weights=weights,
+ bias=config.use_bias,
+ )
+ return TensorParallelMultiAdapterLinear.load(
+ base_layer=base_layer,
+ layer_id=layer_id,
+ layer_names=prefixes,
+ sizes=sizes,
+ process_group=weights.process_group,
+ )
+
+
+def _load_gqa(config, prefix: str, weights):
+ assert config.hidden_size % config.num_attention_heads == 0
+ assert config.num_attention_heads % weights.process_group.size() == 0
+
+ weight = weights.get_multi_weights_col(
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ )
+
+ if isinstance(weight, UnquantizedWeight):
+ weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
+
+ head_size = config.hidden_size // config.num_attention_heads
+ num_heads = config.num_attention_heads // weights.process_group.size()
+ num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+ assert list(weight.weight.shape) == [
+ (num_heads + 2 * num_key_value_heads) * head_size,
+ config.hidden_size,
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"
+
+ if config.use_bias:
+ w = [
+ weights.get_sharded(f"{p}.bias", dim=0)
+ for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
+ ]
+ bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
+ else:
+ bias = None
+
+ return TensorParallelColumnLinear(get_linear(weight, bias=bias))
+
+
+class Starcoder2Attention(torch.nn.Module):
+ def __init__(
+ self,
+ index: int,
+ prefix: str,
+ config,
+ weights,
+ rotary_emb,
+ ):
+ super().__init__()
+ self.max_past = (
+ config.sliding_window if config.sliding_window is not None else -1
+ )
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_heads
+ self.rotary_emb = rotary_emb
+
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()})"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights, index)
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=getattr(config, "use_bias", False),
+ )
+
+ self.o_proj = TensorParallelAdapterRowLinear.load(
+ o_proj,
+ index,
+ "o_proj",
+ process_group=weights.process_group,
+ )
+
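+        # Map every (sharded) query head to its KV head so paged attention can
+        # look up the right cache entries for grouped-query attention.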
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states, adapter_data)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ window_size_left=self.max_past,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ window_size_left=self.max_past,
+ )
+
+ return self.o_proj(
+ attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+ )
+
+
+class Starcoder2MLP(nn.Module):
+ def __init__(self, prefix, config, weights, index):
+ super().__init__()
+ act = config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+        # Single up-projection (c_fc); the default Starcoder2 MLP has no gate
+ c_fc = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.c_fc",
+ weights=weights,
+ bias=config.use_bias,
+ )
+ c_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.c_proj",
+ weights=weights,
+ bias=config.use_bias,
+ )
+
+ self.c_fc = TensorParallelMultiAdapterLinear.load(
+ c_fc,
+ layer_id=index,
+ layer_names=[f"{prefix}.c_fc"],
+ sizes=[config.intermediate_size, config.intermediate_size],
+ process_group=weights.process_group,
+ )
+
+ self.c_proj = TensorParallelAdapterRowLinear.load(
+ c_proj,
+ index,
+ "c_proj",
+ process_group=weights.process_group,
+ )
+
+ def forward(self, hidden_states, adapter_data):
+ hidden_states = self.c_fc(hidden_states, adapter_data)
+ hidden_states = self.act(hidden_states)
+ return self.c_proj(hidden_states, adapter_data)
+
+
+class Starcoder2GatedMLP(nn.Module):
+ def __init__(self, index, prefix, config, weights):
+ super().__init__()
+ act = config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+ # Fuse gate and up proj
+ prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
+ sizes = [
+ config.intermediate_size,
+ config.intermediate_size,
+ ]
+ gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=prefixes,
+ weights=weights,
+ dim=0,
+ bias=config.use_bias,
+ )
+ self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+ gate_up_proj,
+ index,
+ layer_names=prefixes,
+ sizes=sizes,
+ process_group=weights.process_group,
+ )
+ down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=config.use_bias,
+ )
+ self.down_proj = TensorParallelAdapterRowLinear.load(
+ down_proj,
+ index,
+ "down_proj",
+ process_group=weights.process_group,
+ )
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ def forward(self, hidden_states, adapter_data):
+ gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
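+        # The fused projection concatenates the gate and up outputs along the
+        # last dim; the view below splits them so the gated activation
+        # act(gate) * up can be applied.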
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+ )
+
+
+STARCODER2_NORMALIZATION_CLASSES = {
+ "layer_norm": FastLayerNorm,
+ "rms_norm": FastRMSNorm,
+}
+
+STARCODER2_MLP_CLASSES = {
+ "default": Starcoder2MLP,
+ "gated": Starcoder2GatedMLP,
+}
+
+
+class Starcoder2Layer(nn.Module):
+ def __init__(self, layer_id, config, weights, rotary_emb):
+ super().__init__()
+ prefix = f"model.layers.{layer_id}"
+ self.self_attn = Starcoder2Attention(
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ index=layer_id,
+ rotary_emb=rotary_emb,
+ )
+
+ self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
+ prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
+ )
+
+ self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.norm_epsilon
+ )
+ self.post_attention_layernorm = STARCODER2_NORMALIZATION_CLASSES[
+ config.norm_type
+ ].load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.norm_epsilon,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+
+        # Fused residual-add + post-attention norm (layer norm or RMS norm depending on the config)
+ normed_attn_res_output, attn_res = self.post_attention_layernorm(
+ attn_output, res
+ )
+
+ mlp_output = self.mlp(normed_attn_res_output, adapter_data)
+
+ return mlp_output, attn_res
+
+
+class Starcoder2Model(torch.nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens", weights=weights
+ )
+ rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=config.hidden_size // config.num_attention_heads,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+ self.layers = nn.ModuleList(
+ [
+ Starcoder2Layer(
+ layer_id,
+ config,
+ weights,
+ rotary_emb,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.norm_epsilon
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ adapter_data,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ if hpu_attention_meta is not None:
+ hpu_attention_meta = set_block_mapping(
+ hpu_attention_meta, input_ids.shape[0]
+ )
+ hidden_states = self.embed_tokens(input_ids)
+
+        # Compute rotary cos/sin once for this forward pass so each layer does
+        # not have to recompute them.
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashStarcoder2ForCausalLM(torch.nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+
+ if not prefix:
+ prefix = "model"
+ else:
+ prefix = f"{prefix}.model"
+
+ self.model = Starcoder2Model(prefix, config, weights)
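+        # Some checkpoints do not ship a separate lm_head; fall back to the
+        # tied input embeddings as the output projection in that case.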
+ try:
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="lm_head",
+ weights=weights,
+ )
+ except RuntimeError:
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix=f"{prefix}.embed_tokens",
+ weights=weights,
+ )
+
+ self.max_past = config.sliding_window
+ self.max_past_tensor = (
+ torch.tensor(config.sliding_window, device=weights.device)
+ if self.max_past is not None
+ else None
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.lm_head(hidden_states)
+ return logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py
new file mode 100644
index 00000000000..0579ca5db90
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py
@@ -0,0 +1,861 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Idefics2 model."""
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+import math
+
+from transformers.activations import ACT2FN
+from text_generation_server.models.custom_modeling.vlm import (
+ load_text_model,
+)
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
+
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+)
+from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(
+ batch, num_key_value_heads, n_rep, slen, head_dim
+ )
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class Idefics2VisionEmbeddings(nn.Module):
+ """
+    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
+ resolution.
+
+ The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
+    which allows treating images in their native aspect ratio without resizing them to one fixed size. In
+    particular, we start from the original pre-trained SigLIP model (which uses fixed-size square images)
+    and adapt it by training on images of variable resolutions.
+ """
+
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+ self.patch_embedding.weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
+ )
+ self.patch_embedding.bias = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
+ )
+
+ self.num_patches_per_side = self.image_size // self.patch_size
+ self.num_patches = self.num_patches_per_side**2
+ self.num_positions = self.num_patches
+ self.position_embedding = TensorParallelEmbedding(
+ prefix=f"{prefix}.position_embedding", weights=weights
+ )
+
+ def forward(
+ self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
+ ) -> torch.Tensor:
+ batch_size, _, max_im_h, max_im_w = pixel_values.shape
+
+ patch_embeds = self.patch_embedding(pixel_values)
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ max_nb_patches_h, max_nb_patches_w = (
+ max_im_h // self.patch_size,
+ max_im_w // self.patch_size,
+ )
+ boundaries = torch.arange(
+ 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
+ )
+ position_ids = torch.full(
+ size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
+ )
+
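+        # For each image, bucketize the fractional coordinates of its real
+        # patch grid (which may be smaller than the padded maximum) onto the
+        # pretrained num_patches_per_side grid, so every patch reuses the
+        # position embedding of its closest fixed-resolution location.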
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+ nb_patches_h = p_attn_mask[:, 0].sum()
+ nb_patches_w = p_attn_mask[0].sum()
+
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+ bucket_coords_h = torch.bucketize(
+ fractional_coords_h, boundaries, right=True
+ )
+ bucket_coords_w = torch.bucketize(
+ fractional_coords_w, boundaries, right=True
+ )
+
+ pos_ids = (
+ bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
+ ).flatten()
+ position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+
+ position_ids = position_ids.to(self.position_embedding.weight.device)
+ embeddings = embeddings + self.position_embedding(position_ids)
+ return embeddings
+
+
+class Idefics2VisionAttention(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_size = self.embed_dim // self.num_heads
+ if self.head_size * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_size**-0.5
+ self.dropout = config.attention_dropout
+
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.embed_dim = self.embed_dim // weights.process_group.size()
+
+ self.qkv = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=True,
+ )
+ self.out_proj = TensorParallelRowLinear.load(
+ config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
+ )
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, q_len, _ = hidden_states.size()
+
+ qkv = self.qkv(hidden_states)
+ query_states, key_states, value_states = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ self.head_size * self.num_heads,
+ self.head_size * self.num_heads,
+ ],
+ dim=2,
+ )
+
+ query_states = query_states.view(
+ batch_size, q_len, self.num_heads, self.head_size
+ ).transpose(1, 2)
+ key_states = key_states.view(
+ batch_size, q_len, self.num_heads, self.head_size
+ ).transpose(1, 2)
+ value_states = value_states.view(
+ batch_size, q_len, self.num_heads, self.head_size
+ ).transpose(1, 2)
+
+ k_v_seq_len = key_states.shape[-2]
+ attn_weights = (
+ torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
+ )
+
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ )
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
+ raise ValueError(
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output
+
+
+class Idefics2VisionMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Idefics2EncoderLayer(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = Idefics2VisionAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.layer_norm1 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
+ )
+ self.layer_norm2 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
+ )
+ self.mlp = Idefics2VisionMLP(
+ prefix=f"{prefix}.mlp", config=config, weights=weights
+ )
+
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class Idefics2Encoder(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList(
+ [
+ Idefics2EncoderLayer(
+ prefix=f"{prefix}.layers.{i}", config=config, weights=weights
+ )
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+
+ # Ignore copy
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ hidden_states = encoder_layer(
+ hidden_states,
+ attention_mask,
+ )
+ return hidden_states
+
+
+class Idefics2VisionTransformer(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.embeddings = Idefics2VisionEmbeddings(
+ prefix=f"{prefix}.embeddings", config=config, weights=weights
+ )
+ self.encoder = Idefics2Encoder(
+ prefix=f"{prefix}.encoder", config=config, weights=weights
+ )
+ self.post_layernorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.post_layernorm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+
+ def forward(
+ self,
+ pixel_values,
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
+ ):
+ batch_size = pixel_values.size(0)
+ if patch_attention_mask is None:
+ patch_size = self.config.patch_size
+ patch_attention_mask = torch.ones(
+ (
+ batch_size,
+ pixel_values.size(2) // patch_size,
+ pixel_values.size(3) // patch_size,
+ )
+ )
+ patch_attention_mask = patch_attention_mask.to(
+ dtype=torch.bool, device=pixel_values.device
+ )
+
+ hidden_states = self.embeddings(
+ pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
+ )
+
+ patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+        # The call to `_upad_input` in `_flash_attention_forward` is expensive,
+        # so when `patch_attention_mask` is all 1s (i.e. every patch is attended to)
+        # we avoid passing an attention_mask, which is equivalent to attending to the full sequence.
+ if not torch.any(~patch_attention_mask):
+ patch_attention_mask = None
+ else:
+ patch_attention_mask = _prepare_4d_attention_mask(
+ patch_attention_mask, hidden_states.dtype
+ )
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=patch_attention_mask,
+ )
+
+ last_hidden_state = encoder_outputs
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ return last_hidden_state
+
+
+class Idefics2MLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ act = config.text_config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ def forward(self, hidden_states):
+ start_shape = hidden_states.shape[:-1]
+ gate_up_states = self.gate_up_proj(hidden_states)
+ intermediate_size = gate_up_states.shape[-1] // 2
+ gate_up_states = gate_up_states.view(-1, 2, intermediate_size)
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]
+ ).view(*start_shape, -1)
+
+
+class Idefics2RMSNorm(nn.Module):
+ def __init__(self, prefix, weights, eps):
+ """
+ Idefics2RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.weight"), requires_grad=False
+ )
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
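+        # RMSNorm: weight * x / sqrt(mean(x**2) + eps); the mean of squares is computed in fp32 for stability.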
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+class Idefics2PerceiverAttention(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+
+ self.layer_idx = None
+ self.hidden_size = config.text_config.hidden_size
+ self.num_heads = config.perceiver_config.resampler_n_heads
+ self.head_size = config.perceiver_config.resampler_head_dim
+ self.num_key_value_heads = config.perceiver_config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.attention_dropout = config.perceiver_config.attention_dropout
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ self.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.q_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.q_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.kv = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+ self.o_proj = TensorParallelRowLinear.load(
+ config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
+ )
+
+ self.is_causal = False
+
+ def forward(
+ self,
+ latents: torch.Tensor,
+ context: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = latents.size()
+ kv_seq_len = q_len + context.size()[1]
+
+ hidden_states = torch.concat([context, latents], dim=-2)
+ query_states = self.q_proj(latents)
+ kv = self.kv(hidden_states)
+ key_states, value_states = kv.split(
+ [
+ self.head_size * self.num_key_value_heads,
+ self.head_size * self.num_key_value_heads,
+ ],
+ dim=2,
+ )
+
+ query_states = query_states.view(
+ bsz, q_len, self.num_heads, self.head_size
+ ).transpose(1, 2)
+ key_states = key_states.view(
+ bsz, kv_seq_len, self.num_key_value_heads, self.head_size
+ ).transpose(1, 2)
+ value_states = value_states.view(
+ bsz, kv_seq_len, self.num_key_value_heads, self.head_size
+ ).transpose(1, 2)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(
+ query_states, key_states.transpose(2, 3)
+ ) / math.sqrt(self.head_size)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output
+
+
+class Idefics2PerceiverLayer(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.hidden_size = config.text_config.hidden_size
+ self.n_latents = config.perceiver_config.resampler_n_latents
+ self.depth = config.perceiver_config.resampler_depth
+ self.rms_norm_eps = config.text_config.rms_norm_eps
+
+ self.input_latents_norm = Idefics2RMSNorm(
+ prefix=f"{prefix}.input_latents_norm",
+ weights=weights,
+ eps=self.rms_norm_eps,
+ )
+ self.input_context_norm = Idefics2RMSNorm(
+ prefix=f"{prefix}.input_context_norm",
+ weights=weights,
+ eps=self.rms_norm_eps,
+ )
+ self.self_attn = Idefics2PerceiverAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.post_attention_layernorm = Idefics2RMSNorm(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=self.rms_norm_eps,
+ )
+ self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+
+ def forward(
+ self,
+ latents: torch.Tensor,
+ context: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ """
+ Args:
+ latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ """
+ residual = latents
+
+ latents = self.input_latents_norm(latents)
+ context = self.input_context_norm(context)
+
+ latents = self.self_attn(
+ latents=latents,
+ context=context,
+ attention_mask=attention_mask,
+ )
+ latents = residual + latents
+ residual = latents
+
+ latents = self.post_attention_layernorm(latents)
+ latents = self.mlp(latents)
+ latents = residual + latents
+
+ return latents
+
+
+class Idefics2PerceiverResampler(nn.Module):
+ def __init__(self, prefix, config, weights) -> None:
+ super().__init__()
+ self.hidden_size = config.text_config.hidden_size
+ self.hidden_act = config.perceiver_config.hidden_act
+ self.n_latents = config.perceiver_config.resampler_n_latents
+ self.depth = config.perceiver_config.resampler_depth
+ self.rms_norm_eps = config.text_config.rms_norm_eps
+
+ # Create Latents for Perceiver
+ self.latents = weights.get_tensor(f"{prefix}.latents")
+
+ # Create Transformer Blocks
+ self.layers = nn.ModuleList(
+ [
+ Idefics2PerceiverLayer(
+ prefix=f"{prefix}.layers.{idx}", config=config, weights=weights
+ )
+ for idx in range(self.depth)
+ ]
+ )
+ self.norm = Idefics2RMSNorm(
+ prefix=f"{prefix}.norm",
+ weights=weights,
+ eps=config.text_config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ context: torch.Tensor,
+ attention_mask,
+ ) -> torch.Tensor:
+ # seq embed -> bsz seq embed
+ latents = self.latents.unsqueeze(0).expand(
+ (context.shape[0], *self.latents.size())
+ )
+
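+        # The latents attend to the concatenation [context, latents] (see Idefics2PerceiverAttention),
+        # so the mask is extended with ones over the latent positions before being expanded to 4D.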
+ latent_attention_mask = torch.ones(
+ (attention_mask.size(0), latents.size(1)),
+ dtype=attention_mask.dtype,
+ device=attention_mask.device,
+ )
+ attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
+ attention_mask = _prepare_4d_attention_mask(
+ attention_mask, latents.dtype, tgt_len=self.n_latents
+ )
+
+ compressed_context = latents
+ for perceiver_layer in self.layers:
+ compressed_context = perceiver_layer(
+ compressed_context,
+ context,
+ attention_mask=attention_mask,
+ )
+ compressed_context = self.norm(compressed_context)
+
+ return compressed_context
+
+
+class Idefics2Connector(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.modality_projection = Idefics2MLP(
+ prefix=f"{prefix}.modality_projection", config=config, weights=weights
+ )
+ self.perceiver_resampler = Idefics2PerceiverResampler(
+ prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights
+ )
+
+ def forward(self, image_hidden_states, attention_mask):
+ image_hidden_states = self.modality_projection(image_hidden_states)
+ image_hidden_states = self.perceiver_resampler(
+ context=image_hidden_states, attention_mask=attention_mask
+ )
+ return image_hidden_states
+
+
+class Idefics2ForConditionalGeneration(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ config.vision_config.quantize = None
+ config.vision_config.speculator = config.speculator
+ config.text_config.quantize = config.quantize
+ config.text_config.speculator = config.speculator
+
+ vision_config = config.vision_config
+ self.text_model = load_text_model(
+ prefix="model" if not prefix else f"{prefix}.model",
+ config=config.text_config,
+ weights=weights,
+ name="text_model",
+ )
+ self.dtype = weights.dtype
+
+ # The vision and connector models are not quantized.
+ with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
+ self.vision_model = Idefics2VisionTransformer(
+ prefix=(
+ f"{prefix}.model.vision_model" if prefix else "model.vision_model"
+ ),
+ config=vision_config,
+ weights=weights,
+ )
+
+ config.quantize = None
+ self.connector = Idefics2Connector(
+ prefix=f"{prefix}.model.connector" if prefix else "model.connector",
+ config=config,
+ weights=weights,
+ )
+
+ self.config = config
+ self.image_seq_len = config.perceiver_config.resampler_n_latents
+ self.image_token_id = config.image_token_id
+ self.pad_token_id = (
+ config.pad_token_id if config.pad_token_id is not None else -1
+ )
+
+ def _merge_input_ids_with_image_features(
+ self,
+ input_ids: torch.Tensor,
+ inputs_embeds: torch.Tensor,
+ image_features: torch.Tensor,
+ ):
+ """In place merges in vision_embeddings with inputs_embeds."""
+ # mask = input_ids == self.config.image_token_index
+        # - replace `==` with torch.where to work around an issue with HPU graphs
+ mask = torch.where(input_ids == self.config.image_token_id)
+ # Let's pray we have enabled enough slots !
+ inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+ return inputs_embeds
+
+ def get_vision_embeds(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ assert pixel_values is not None
+ batch_size, num_images, num_channels, height, width = pixel_values.shape
+ all_states = []
+ all_pixel_values = pixel_values
+ all_pixel_mask = pixel_attention_mask
+ for i in range(batch_size):
+ pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
+ pixel_values = pixel_values[i : i + 1]
+ pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
+
+            # Remove padding images - padding images are all zeros.
+ nb_values_per_image = pixel_values.shape[1:].numel()
+ real_images_inds = (pixel_values == 0.0).sum(
+ dim=(-1, -2, -3)
+ ) != nb_values_per_image
+ pixel_values = pixel_values[real_images_inds].contiguous()
+
+ # Handle the vision attention mask
+ if pixel_attention_mask is None:
+ pixel_attention_mask = torch.ones(
+ size=(
+ pixel_values.size(0),
+ pixel_values.size(2),
+ pixel_values.size(3),
+ ),
+ dtype=torch.bool,
+ device=pixel_values.device,
+ )
+ else:
+                # Remove padding images from the mask.
+ pixel_attention_mask = all_pixel_mask[i : i + 1]
+ pixel_attention_mask = pixel_attention_mask.view(
+ 1 * num_images, *pixel_attention_mask.shape[2:]
+ )
+ pixel_attention_mask = pixel_attention_mask[
+ real_images_inds
+ ].contiguous()
+
+ patch_size = self.config.vision_config.patch_size
+ """
+ patches_subgrid = pixel_attention_mask.unfold(
+ dimension=1, size=patch_size, step=patch_size
+ )
+ patches_subgrid = patches_subgrid.unfold(
+ dimension=2, size=patch_size, step=patch_size
+ )
+ patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+ """
+            # HPU does not support unfold, so emulate the patch sub-grid reduction with a strided conv2d.
+ conv_kernel = torch.ones(
+ [1, 1, patch_size, patch_size],
+ dtype=pixel_values.dtype,
+ device=pixel_values.device,
+ )
+ patches_subgrid = torch.nn.functional.conv2d(
+ pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
+ conv_kernel,
+ stride=patch_size,
+ ).squeeze(1)
+ patch_attention_mask = torch.gt(patches_subgrid, 0)
+
+ # Get sequence from the vision encoder
+ image_hidden_states = self.vision_model(
+ pixel_values=pixel_values,
+ patch_attention_mask=patch_attention_mask,
+ )
+
+ # Modality projection & resampling
+ image_hidden_states = self.connector(
+ image_hidden_states,
+ attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
+ )
+ all_states.append(image_hidden_states)
+ image_hidden_states = torch.stack(all_states, dim=0)
+ return image_hidden_states.view(-1, image_hidden_states.shape[-1])
+
+ def get_inputs_embeds(
+ self,
+ input_ids: torch.Tensor,
+ vision_embeds: torch.Tensor = None,
+ ):
+ inputs_embeds = self.text_model.embed_tokens(input_ids)
+
+ if vision_embeds is not None:
+            # During generation we must not replace an image_token_id that the model itself produced
+            # with image features that simply don't exist.
+ inputs_embeds = self._merge_input_ids_with_image_features(
+ input_ids, inputs_embeds, vision_embeds
+ )
+ return inputs_embeds
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ):
+ hidden_states = self.text_model.model(
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ adapter_data=adapter_data,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.text_model.lm_head(hidden_states)
+ return logits, speculative_logits
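+
+
+# Illustrative sketch (not used by the server): the strided conv2d in
+# `get_vision_embeds` above emulates the unfold-based patch-mask computation
+# shown in the commented-out block. `_patch_mask_conv2d_matches_unfold` is a
+# hypothetical helper that makes the equivalence explicit on CPU.
+def _patch_mask_conv2d_matches_unfold(
+    pixel_attention_mask: torch.Tensor, patch_size: int
+) -> bool:
+    # Reference: reduce each (patch_size x patch_size) sub-grid with unfold.
+    subgrid = pixel_attention_mask.unfold(1, patch_size, patch_size).unfold(
+        2, patch_size, patch_size
+    )
+    unfold_mask = subgrid.sum(dim=(-1, -2)) > 0
+    # Emulation: sum each sub-grid with an all-ones kernel, then threshold.
+    kernel = torch.ones(1, 1, patch_size, patch_size, dtype=torch.float32)
+    summed = torch.nn.functional.conv2d(
+        pixel_attention_mask.unsqueeze(1).to(torch.float32), kernel, stride=patch_size
+    ).squeeze(1)
+    return torch.equal(unfold_mask, torch.gt(summed, 0))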
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py
new file mode 100644
index 00000000000..e12f220922e
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py
@@ -0,0 +1,605 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Idefics3 model."""
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.activations import ACT2FN
+from text_generation_server.models.custom_modeling.vlm import (
+ load_text_model,
+)
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
+
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+)
+from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(
+ batch, num_key_value_heads, n_rep, slen, head_dim
+ )
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
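+# Minimal sketch (illustrative only, not called by the model): the reshape-based
+# repeat_kv above is equivalent to torch.repeat_interleave along the key/value
+# head dimension. The shapes below are hypothetical.
+def _repeat_kv_matches_repeat_interleave() -> bool:
+    x = torch.randn(2, 4, 5, 8)  # (batch, num_key_value_heads, seqlen, head_dim)
+    return torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=1))
+
+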
+class Idefics3VisionEmbeddings(nn.Module):
+ """
+    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
+ resolution.
+
+ The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
+ which allows treating images in their native aspect ratio and without the need to resize them to the same
+ fixed size. In particular, we start from the original pre-trained SigLIP model
+    (which uses fixed-size square images) and adapt it by training on images of variable resolutions.
+ """
+
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+ self.patch_embedding.weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
+ )
+ self.patch_embedding.bias = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
+ )
+
+ self.num_patches_per_side = self.image_size // self.patch_size
+ self.num_patches = self.num_patches_per_side**2
+ self.num_positions = self.num_patches
+ self.position_embedding = TensorParallelEmbedding(
+ prefix=f"{prefix}.position_embedding", weights=weights
+ )
+
+ def forward(
+ self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
+ ) -> torch.Tensor:
+ batch_size, _, max_im_h, max_im_w = pixel_values.shape
+
+ patch_embeds = self.patch_embedding(pixel_values)
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ max_nb_patches_h, max_nb_patches_w = (
+ max_im_h // self.patch_size,
+ max_im_w // self.patch_size,
+ )
+ boundaries = torch.arange(
+ 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
+ )
+ position_ids = torch.full(
+ size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
+ )
+
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+ nb_patches_h = p_attn_mask[:, 0].sum()
+ nb_patches_w = p_attn_mask[0].sum()
+
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+ bucket_coords_h = torch.bucketize(
+ fractional_coords_h, boundaries, right=True
+ )
+ bucket_coords_w = torch.bucketize(
+ fractional_coords_w, boundaries, right=True
+ )
+
+ pos_ids = (
+ bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
+ ).flatten()
+ position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+
+ position_ids = position_ids.to(self.position_embedding.weight.device)
+ embeddings = embeddings + self.position_embedding(position_ids)
+ return embeddings
+
+
+class Idefics3VisionAttention(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_size = self.embed_dim // self.num_heads
+ if self.head_size * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_size**-0.5
+ self.dropout = config.attention_dropout
+
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.embed_dim = self.embed_dim // weights.process_group.size()
+
+ self.qkv = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=True,
+ )
+ self.out_proj = TensorParallelRowLinear.load(
+ config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
+ )
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, q_len, _ = hidden_states.size()
+
+ qkv = self.qkv(hidden_states)
+ query_states, key_states, value_states = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ self.head_size * self.num_heads,
+ self.head_size * self.num_heads,
+ ],
+ dim=2,
+ )
+
+ query_states = query_states.view(
+ batch_size, q_len, self.num_heads, self.head_size
+ ).transpose(1, 2)
+ key_states = key_states.view(
+ batch_size, q_len, self.num_heads, self.head_size
+ ).transpose(1, 2)
+ value_states = value_states.view(
+ batch_size, q_len, self.num_heads, self.head_size
+ ).transpose(1, 2)
+
+ k_v_seq_len = key_states.shape[-2]
+ attn_weights = (
+ torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
+ )
+
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ )
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
+ raise ValueError(
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output
+
+
+class Idefics3VisionMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Idefics3EncoderLayer(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = Idefics3VisionAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.layer_norm1 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
+ )
+ self.layer_norm2 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
+ )
+ self.mlp = Idefics3VisionMLP(
+ prefix=f"{prefix}.mlp", config=config, weights=weights
+ )
+
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class Idefics3Encoder(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList(
+ [
+ Idefics3EncoderLayer(
+ prefix=f"{prefix}.layers.{i}", config=config, weights=weights
+ )
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+
+ # Ignore copy
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ hidden_states = encoder_layer(
+ hidden_states,
+ attention_mask,
+ )
+ return hidden_states
+
+
+class Idefics3VisionTransformer(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.embeddings = Idefics3VisionEmbeddings(
+ prefix=f"{prefix}.embeddings", config=config, weights=weights
+ )
+ self.encoder = Idefics3Encoder(
+ prefix=f"{prefix}.encoder", config=config, weights=weights
+ )
+ self.post_layernorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.post_layernorm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+
+ def forward(
+ self,
+ pixel_values,
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
+ ):
+ batch_size = pixel_values.size(0)
+ if patch_attention_mask is None:
+ patch_size = self.config.patch_size
+ patch_attention_mask = torch.ones(
+ (
+ batch_size,
+ pixel_values.size(2) // patch_size,
+ pixel_values.size(3) // patch_size,
+ )
+ )
+ patch_attention_mask = patch_attention_mask.to(
+ dtype=torch.bool, device=pixel_values.device
+ )
+
+ hidden_states = self.embeddings(
+ pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
+ )
+
+ patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+        # The call to `_upad_input` in `_flash_attention_forward` is expensive,
+        # so when `patch_attention_mask` is all 1s (i.e. every patch is attended to)
+        # we avoid passing an attention_mask, which is equivalent to attending to the full sequence.
+ if not torch.any(~patch_attention_mask):
+ patch_attention_mask = None
+ else:
+ patch_attention_mask = _prepare_4d_attention_mask(
+ patch_attention_mask, hidden_states.dtype
+ )
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=patch_attention_mask,
+ )
+
+ last_hidden_state = encoder_outputs
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ return last_hidden_state
+
+
+class Idefics3SimpleMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ input_size = config.vision_config.hidden_size * (config.scale_factor**2)
+ output_size = config.text_config.hidden_size
+ proj = nn.Parameter(
+ weights.get_tensor(f"{prefix}.modality_projection.proj.weight"),
+ requires_grad=False,
+ ).to(weights.dtype)
+ self.proj = nn.Linear(input_size, output_size, bias=False)
+ self.proj.weight = proj
+
+ def forward(self, x):
+ return self.proj(x)
+
+
+class Idefics3Connector(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.modality_projection = Idefics3SimpleMLP(prefix, config, weights)
+ self.scale_factor = config.scale_factor
+
+ def pixel_shuffle(self, x, scale_factor=2):
+ bsz, seq, embed_dim = x.size()
+ height = width = int(seq**0.5)
+ x = x.view(bsz, height, width, embed_dim)
+ x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
+ x = x.permute(0, 2, 1, 3)
+ x = x.reshape(
+ bsz,
+ int(width / scale_factor),
+ int(height / scale_factor),
+ embed_dim * (scale_factor**2),
+ )
+ x = x.permute(0, 2, 1, 3)
+ x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
+ return x
+
+ def forward(self, image_hidden_states):
+ image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
+ image_hidden_states = self.modality_projection(image_hidden_states)
+ return image_hidden_states
+
+
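+# Minimal sketch (illustrative only): pixel_shuffle trades sequence length for
+# width, turning (bsz, seq, dim) into (bsz, seq // s**2, dim * s**2) before the
+# modality projection. The shapes below are hypothetical; `self` is unused by
+# pixel_shuffle, so the method is called unbound here.
+def _pixel_shuffle_shape_demo() -> torch.Size:
+    x = torch.randn(1, 64, 32)  # 64 = 8 x 8 patches, hidden size 32
+    return Idefics3Connector.pixel_shuffle(None, x, scale_factor=2).shape  # (1, 16, 128)
+
+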
+class Idefics3ForConditionalGeneration(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ config.vision_config.quantize = None
+ config.vision_config.speculator = config.speculator
+ config.text_config.quantize = config.quantize
+ config.text_config.speculator = config.speculator
+ # set tie_word_embeddings to True to load `.embed_tokens.weight` instead of `.lm_head.weight`
+ # since Idefics3 uses the `embed_tokens` for the final prediction
+ # config.text_config.tie_word_embeddings = True
+
+ vision_config = config.vision_config
+ self.text_model = load_text_model(
+ prefix="model" if not prefix else f"{prefix}.model",
+ config=config.text_config,
+ weights=weights,
+ name="text_model",
+ )
+ self.dtype = weights.dtype
+
+ # The vision and connector models are not quantized.
+ with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
+ self.vision_model = Idefics3VisionTransformer(
+ prefix=(
+ f"{prefix}.model.vision_model" if prefix else "model.vision_model"
+ ),
+ config=vision_config,
+ weights=weights,
+ )
+
+ config.quantize = None
+ self.connector = Idefics3Connector(
+ prefix=f"{prefix}.model.connector" if prefix else "model.connector",
+ config=config,
+ weights=weights,
+ )
+
+ self.config = config
+ self.image_token_id = config.image_token_id
+ self.pad_token_id = (
+ config.pad_token_id if config.pad_token_id is not None else -1
+ )
+
+ def _merge_input_ids_with_image_features(
+ self,
+ input_ids: torch.Tensor,
+ inputs_embeds: torch.Tensor,
+ image_features: torch.Tensor,
+ ):
+ """In place merges in vision_embeddings with inputs_embeds."""
+ # mask = input_ids == self.config.image_token_index
+        # - replace `==` with torch.where to work around an issue with HPU graphs
+ mask = torch.where(input_ids == self.config.image_token_id)
+ # Let's pray we have enabled enough slots !
+ inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+ return inputs_embeds
+
+ def get_vision_embeds(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ batch_size, num_images, num_channels, height, width = pixel_values.shape
+ all_states = []
+ all_pixel_values = pixel_values
+ all_pixel_mask = pixel_attention_mask
+ for i in range(batch_size):
+ pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
+ pixel_values = pixel_values[i : i + 1]
+ pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
+
+            # Remove padding images - padding images are all zeros.
+ nb_values_per_image = pixel_values.shape[1:].numel()
+ real_images_inds = (pixel_values == 0.0).sum(
+ dim=(-1, -2, -3)
+ ) != nb_values_per_image
+ pixel_values = pixel_values[real_images_inds].contiguous()
+ # Handle the vision attention mask
+ if pixel_attention_mask is None:
+ pixel_attention_mask = torch.ones(
+ size=(
+ pixel_values.size(0),
+ pixel_values.size(2),
+ pixel_values.size(3),
+ ),
+ dtype=torch.bool,
+ device=pixel_values.device,
+ )
+ else:
+                # Remove padding images from the mask.
+ pixel_attention_mask = all_pixel_mask[i : i + 1]
+ pixel_attention_mask = pixel_attention_mask.view(
+ 1 * num_images, *pixel_attention_mask.shape[2:]
+ )
+ pixel_attention_mask = pixel_attention_mask[
+ real_images_inds
+ ].contiguous()
+
+ patch_size = self.config.vision_config.patch_size
+
+ """
+ patches_subgrid = pixel_attention_mask.unfold(
+ dimension=1, size=patch_size, step=patch_size
+ )
+ patches_subgrid = patches_subgrid.unfold(
+ dimension=2, size=patch_size, step=patch_size
+ )
+ patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+ """
+            # HPU does not support unfold, so emulate the patch sub-grid reduction with a strided conv2d.
+ conv_kernel = torch.ones(
+ [1, 1, patch_size, patch_size],
+ dtype=pixel_values.dtype,
+ device=pixel_values.device,
+ )
+ patches_subgrid = torch.nn.functional.conv2d(
+ pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
+ conv_kernel,
+ stride=patch_size,
+ ).squeeze(1)
+ patch_attention_mask = torch.gt(patches_subgrid, 0)
+
+ # Get sequence from the vision encoder
+ image_hidden_states = self.vision_model(
+ pixel_values=pixel_values,
+ patch_attention_mask=patch_attention_mask,
+ )
+
+ # Modality projection & resampling
+ image_hidden_states = self.connector(
+ image_hidden_states,
+ )
+
+ all_states.append(image_hidden_states)
+ image_hidden_states = torch.stack(all_states, dim=0)
+
+ return image_hidden_states.view(-1, image_hidden_states.shape[-1])
+
+ def get_inputs_embeds(
+ self,
+ input_ids: torch.Tensor,
+ vision_embeds: torch.Tensor = None,
+ ):
+ inputs_embeds = self.text_model.embed_tokens(input_ids)
+
+ if vision_embeds is not None:
+            # During generation we must not replace an image_token_id that the model itself produced
+            # with image features that simply don't exist.
+ inputs_embeds = self._merge_input_ids_with_image_features(
+ input_ids, inputs_embeds, vision_embeds
+ )
+ return inputs_embeds
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ image_indices=None,
+ ):
+ hidden_states = self.text_model.model(
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ adapter_data=adapter_data,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.text_model.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py
new file mode 100644
index 00000000000..5a9c058871c
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py
@@ -0,0 +1,238 @@
+import torch
+import torch.distributed
+
+from mamba_ssm.ops.triton.selective_state_update import selective_state_update
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+from torch import nn
+from typing import Optional, Tuple, Any
+from transformers.configuration_utils import PretrainedConfig
+import torch.nn.functional as F
+
+from text_generation_server.layers import (
+ SpeculativeHead,
+ TensorParallelEmbedding,
+ FastLinear,
+)
+from text_generation_server.layers.layernorm import FastRMSNorm
+
+from einops import rearrange
+from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+import math
+from dataclasses import dataclass
+
+
+@dataclass
+class InferenceParams:
+ """Inference parameters that are passed to the main model in order
+    to efficiently calculate and store the context during inference."""
+
+ max_seqlen: int
+ max_batch_size: int
+ conv_states: torch.Tensor
+ ssm_states: torch.Tensor
+ seqlen_offset: int
+
+
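+
+
+# Minimal sketch (illustrative only; the exact buffer shapes are an assumption
+# inferred from how MambaBlock indexes the caches below): one
+# (batch, d_inner, d_conv) conv state and one (batch, d_inner, d_state) ssm
+# state per layer.
+def _init_inference_params(
+    config: "MambaConfig", max_batch_size: int, max_seqlen: int, device, dtype
+) -> InferenceParams:
+    return InferenceParams(
+        max_seqlen=max_seqlen,
+        max_batch_size=max_batch_size,
+        conv_states=torch.zeros(
+            config.n_layer, max_batch_size, config.d_inner, config.d_conv,
+            device=device, dtype=dtype,
+        ),
+        ssm_states=torch.zeros(
+            config.n_layer, max_batch_size, config.d_inner, config.d_state,
+            device=device, dtype=dtype,
+        ),
+        seqlen_offset=0,
+    )
+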
+class MambaConfig(PretrainedConfig):
+ def __init__(
+ self,
+ vocab_size=50280,
+ d_model=768,
+ d_state=16,
+ n_layer=32,
+ layer_norm_epsilon=1e-5,
+ tie_word_embeddings=False,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ expand=2,
+ dt_rank="auto",
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.n_layer = n_layer
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.d_model = d_model
+ self.d_inner = d_model * 2
+ self.d_conv = 4
+ self.d_state = d_state
+ self.expand = expand
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class MambaBlock(nn.Module):
+ def __init__(self, prefix, config, weights, layer_id):
+ super().__init__()
+ self.layer_id = layer_id
+ self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
+ self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
+ self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
+ self.dt_proj_no_bias = FastLinear.load(
+ config, f"{prefix}.dt_proj", weights, bias=False
+ )
+ self.out_proj = FastLinear.load(
+ config, f"{prefix}.out_proj", weights, bias=False
+ )
+ self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
+ self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
+ self.D = weights.get_tensor(f"{prefix}.D")
+ self.activation = "silu"
+ self.dt_rank = config.dt_rank
+ self.d_state = config.d_state
+ self.d_conv = config.d_conv
+ self.act = nn.SiLU()
+
+    # When decoding (inference_params.seqlen_offset > 0) we reuse the cached conv/ssm states
+    # and take a single-token step; otherwise we run the full causal-conv + selective-scan over the prompt.
+ def forward(self, hidden_states: torch.Tensor, inference_params=None):
+ if inference_params.seqlen_offset > 0:
+ conv_state = inference_params.conv_states[self.layer_id]
+ ssm_state = inference_params.ssm_states[self.layer_id]
+ out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
+ return out, conv_state, ssm_state
+
+ _, seqlen, _ = hidden_states.shape
+ projected_states = self.in_proj(hidden_states).transpose(1, 2)
+ # assert projected_states.shape == [batch_size, 2 * dstate, seqlen], f"{projected_states.shape} [{batch_size}, {dstate}, {seqlen}]"
+ x, z = projected_states.chunk(2, dim=1)
+ conv_state = F.pad(x, (self.d_conv - seqlen, 0))
+ x = causal_conv1d_fn(
+ x=x,
+ weight=self.conv1d.weight.squeeze(1),
+ bias=self.conv1d.bias,
+ activation=self.activation,
+ )
+
+ # We're careful here about the layout, to avoid extra transposes.
+ # We want dt to have d as the slowest moving dimension
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
+ dt, B, C = torch.split(
+ x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
+ )
+ dt = self.dt_proj.weight @ dt.t()
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
+ y, last_state = selective_scan_fn(
+ x,
+ dt,
+ self.negA,
+ B,
+ C,
+ self.D.float(),
+ z=z,
+ delta_bias=self.dt_proj.bias.float(),
+ delta_softplus=True,
+ return_last_state=True,
+ )
+ y = rearrange(y, "b d l -> b l d")
+ attn_outputs = self.out_proj(y)
+ return attn_outputs, conv_state, last_state
+
+ def step(self, hidden_states, conv_state, ssm_state):
+ xz = self.in_proj(hidden_states.squeeze(1))
+ x, z = xz.chunk(2, dim=-1) # (B D)
+ x = causal_conv1d_update(
+ x,
+ conv_state,
+ self.conv1d.weight.squeeze(1),
+ self.conv1d.bias,
+ self.activation,
+ )
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+ dt = F.linear(dt, self.dt_proj.weight)
+ A = self.negA
+ y = selective_state_update(
+ ssm_state,
+ x,
+ dt,
+ A,
+ B,
+ C,
+ self.D,
+ z=z,
+ dt_bias=self.dt_proj.bias,
+ dt_softplus=True,
+ )
+ out = self.out_proj(y)
+ return out.unsqueeze(1), conv_state.clone(), ssm_state.clone()
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, prefix, config, weights, layer_id):
+ super().__init__()
+ self.mamba_block = MambaBlock(
+ prefix=f"{prefix}.mixer", config=config, weights=weights, layer_id=layer_id
+ )
+ self.layer_norm = FastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_epsilon
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ residual: Optional[torch.Tensor] = None,
+ inference_params: Optional[Any] = None,
+ ):
+ residual = (hidden_states + residual) if residual is not None else hidden_states
+ shape = residual.shape
+ hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
+ hidden_states, conv_state, last_ssm_state = self.mamba_block(
+ hidden_states.view(*shape), inference_params
+ )
+ return hidden_states, residual, conv_state, last_ssm_state
+
+
+class MambaModel(nn.Module):
+ def __init__(self, config, weights):
+ super().__init__()
+ prefix = "backbone"
+ try:
+ self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embeddings", weights)
+ except RuntimeError:
+ self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
+ self.blocks = nn.ModuleList(
+ [
+ ResidualBlock(f"{prefix}.layers.{i}", config, weights, layer_id=i)
+ for i in range(config.n_layer)
+ ]
+ )
+ self.norm_f = FastRMSNorm.load(
+ f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
+ )
+ try:
+ self.lm_head = SpeculativeHead.load(config, f"{prefix}.embeddings", weights)
+ except RuntimeError:
+ self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
+ self.config = config
+
+ def forward(
+ self, input_ids: torch.Tensor, inference_params=None, residual=None
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ hidden_states = self.embed_tokens(input_ids)
+ for i, block in enumerate(self.blocks):
+ hidden_states, residual, conv_state, ssm_state = block(
+ hidden_states, residual, inference_params
+ )
+ inference_params.conv_states[i].copy_(conv_state)
+ inference_params.ssm_states[i].copy_(ssm_state)
+
+ hidden_states = (
+ hidden_states + residual if residual is not None else hidden_states
+ )
+ hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
+ hidden_states = hidden_states.view(residual.shape)
+ logits, speculative_logits = self.lm_head(hidden_states)
+
+ # update the offset for the next inference using these params
+ inference_params.seqlen_offset += input_ids.size(1)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
new file mode 100644
index 00000000000..ac1578e994f
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
@@ -0,0 +1,949 @@
+# coding=utf-8
+# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen2.5 VL model."""
+
+from typing import Optional, Tuple, List
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from habana_frameworks.torch.hpex.kernels import FusedSDPA
+from vllm_hpu_extension.utils import ModuleFusedSDPA
+
+
+import numpy as np
+
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+
+import torch.nn.functional as F
+
+from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+)
+from text_generation_server.layers.attention import (
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
+ Qwen2Model,
+)
+from habana_frameworks.torch.hpex.kernels import (
+ RotaryPosEmbeddingMode,
+ apply_rotary_pos_emb,
+)
+import habana_frameworks.torch as htorch
+
+# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
+from typing import Union
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput
+from transformers.video_utils import VideoInput
+from transformers.processing_utils import (
+ ProcessingKwargs,
+ ProcessorMixin,
+ Unpack,
+ VideosKwargs,
+)
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
+ fps: Union[List[float], float]
+
+
+class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
+ videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ },
+ "videos_kwargs": {"fps": 2.0},
+ }
+
+
+class Qwen2_5_VLProcessor(ProcessorMixin):
+ r"""
+ Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.
+ [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
+ [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.
+ Args:
+ image_processor ([`Qwen2VLImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`Qwen2TokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ valid_kwargs = ["chat_template"]
+
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
+
+ def __init__(
+ self, image_processor=None, tokenizer=None, chat_template=None, **kwargs
+ ):
+ self.image_token = (
+ "<|image_pad|>"
+ if not hasattr(tokenizer, "image_token")
+ else tokenizer.image_token
+ )
+ self.video_token = (
+ "<|video_pad|>"
+ if not hasattr(tokenizer, "video_token")
+ else tokenizer.video_token
+ )
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ images: ImageInput = None,
+ text: Union[
+ TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
+ ] = None,
+ videos: VideoInput = None,
+ **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
+ ) -> BatchFeature:
+ """
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+ and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
+ Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
+ tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
+ - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
+ - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
+ - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ Qwen2_5_VLProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ if images is not None:
+ image_inputs = self.image_processor(
+ images=images, videos=None, **output_kwargs["images_kwargs"]
+ )
+ image_grid_thw = image_inputs["image_grid_thw"]
+ else:
+ image_inputs = {}
+ image_grid_thw = None
+
+ if videos is not None:
+ videos_inputs = self.image_processor(
+ images=None, videos=videos, **output_kwargs["images_kwargs"]
+ )
+ video_grid_thw = videos_inputs["video_grid_thw"]
+
+ fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
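+            # temporal_patch_size frames are merged into one time-grid step, so each step spans temporal_patch_size / fps seconds.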
+ if isinstance(fps, (int, float)):
+ second_per_grid_ts = [
+ self.image_processor.temporal_patch_size / fps
+ ] * len(video_grid_thw)
+ elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
+ second_per_grid_ts = [
+ self.image_processor.temporal_patch_size / tmp for tmp in fps
+ ]
+ else:
+ raise ValueError(
+ f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
+ )
+ videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
+
+ else:
+ videos_inputs = {}
+ video_grid_thw = None
+
+ if not isinstance(text, list):
+ text = [text]
+
+ if image_grid_thw is not None:
+ merge_length = self.image_processor.merge_size**2
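+            # Each image expands to (t * h * w) // merge_size**2 placeholder tokens, i.e. the number of visual tokens left after spatial merging.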
+ index = 0
+ for i in range(len(text)):
+ while self.image_token in text[i]:
+ text[i] = text[i].replace(
+ self.image_token,
+ "<|placeholder|>"
+ * (image_grid_thw[index].prod() // merge_length),
+ 1,
+ )
+ index += 1
+ text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+ if video_grid_thw is not None:
+ merge_length = self.image_processor.merge_size**2
+ index = 0
+ for i in range(len(text)):
+ while self.video_token in text[i]:
+ text[i] = text[i].replace(
+ self.video_token,
+ "<|placeholder|>"
+ * (video_grid_thw[index].prod() // merge_length),
+ 1,
+ )
+ index += 1
+ text[i] = text[i].replace("<|placeholder|>", self.video_token)
+
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+
+ return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ def post_process_image_text_to_text(self, generated_outputs):
+ """
+ Post-process the output of the model to decode the text.
+
+ Args:
+ generated_outputs (`torch.Tensor` or `np.ndarray`):
+ The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+ or `(sequence_length,)`.
+
+ Returns:
+ `List[str]`: The decoded text.
+ """
+ return self.tokenizer.batch_decode(
+ generated_outputs,
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ )
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ names_from_processor = list(
+ dict.fromkeys(tokenizer_input_names + image_processor_input_names)
+ )
+ return names_from_processor + ["second_per_grid_ts"]
+
+
+# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
+class Qwen2_5_VLVisionConfig(PretrainedConfig):
+ model_type = "qwen2_5_vl"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ depth=32,
+ hidden_size=3584,
+ hidden_act="silu",
+ intermediate_size=3420,
+ num_heads=16,
+ in_channels=3,
+ patch_size=14,
+ spatial_merge_size=2,
+ spatial_patch_size=14,
+ temporal_patch_size=2,
+ tokens_per_second=4,
+ window_size=112,
+ out_hidden_size=3584,
+ fullatt_block_indexes=[7, 15, 23, 31],
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.spatial_patch_size = spatial_patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.tokens_per_second = tokens_per_second
+ self.window_size = window_size
+ self.fullatt_block_indexes = fullatt_block_indexes
+ self.out_hidden_size = out_hidden_size
+
+
+class Qwen2_5_VLConfig(PretrainedConfig):
+
+ def __init__(
+ self,
+ vocab_size=152064,
+ hidden_size=8192,
+ intermediate_size=29568,
+ num_hidden_layers=80,
+ num_attention_heads=64,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ initializer_range=0.02,
+ rms_norm_eps=1e-05,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=1000000.0,
+ use_sliding_window=False,
+ sliding_window=4096,
+ max_window_layers=80,
+ attention_dropout=0.0,
+ vision_config=None,
+ rope_scaling=None,
+ **kwargs,
+ ):
+ if vision_config is not None:
+ self.vision_config = Qwen2_5_VLVisionConfig(**vision_config)
+
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.use_sliding_window = use_sliding_window
+ self.sliding_window = sliding_window
+ self.max_window_layers = max_window_layers
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_dropout = attention_dropout
+ self.rope_scaling = rope_scaling
+
+        # Validate the rotary position embedding parameters.
+        # BC: if there is a 'type' field, move it to 'rope_type',
+        # and change the type from 'mrope' to 'default' because `mrope` performs the default
+        # RoPE calculations; it can be set to "linear"/"dynamic" etc. to get scaled RoPE.
+ # TODO: @raushan update config in the hub
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ if self.rope_scaling["type"] == "mrope":
+ self.rope_scaling["type"] = "default"
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen2_5VLAttention(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.embed_dim = config.hidden_size // weights.process_group.size()
+ self.head_dim = config.hidden_size // config.num_heads
+ self.num_heads = config.num_heads // weights.process_group.size()
+
+ self.qkv = TensorParallelColumnLinear.load_qkv(
+ config,
+ prefix=f"{prefix}.qkv",
+ weights=weights,
+ bias=False,
+ num_heads=self.num_heads,
+ num_key_value_heads=self.num_heads,
+ )
+ self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
+
+ self.proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.proj",
+ weights=weights,
+ bias=True,
+ )
+ self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)
+
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ max_seqlen: int,
+ ) -> torch.Tensor:
+ # apply the qkv linear layer to the hidden state
+ qkv = self.qkv(hidden_state)
+ query, key, value = qkv.split(
+ [self.embed_dim, self.embed_dim, self.embed_dim], dim=1
+ )
+
+ # reshape the query, key, and value tensors
+ _shape = (
+ hidden_state.shape[0],
+ self.num_heads,
+ self.embed_dim // self.num_heads,
+ )
+ query = query.view(*_shape)
+ key = key.view(*_shape)
+ value = value.view(*_shape)
+ # apply rotary positional embeddings
+ rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
+ rotary_dim = cos.shape[-1]
+ query_rot = query[..., :rotary_dim]
+ query_pass = query[..., rotary_dim:]
+ query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+ query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))
+
+ key_rot = key[..., :rotary_dim]
+ key_pass = key[..., rotary_dim:]
+ key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+ key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
+
+ # execute sdpa
+ causal = False
+ query = query.transpose(0, 1)
+ key = key.transpose(0, 1)
+ value = value.transpose(0, 1)
+ fsdpa_op = ModuleFusedSDPA(FusedSDPA)
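+        # build a block-diagonal boolean mask from cu_seqlens so that each patch attends only
+        # to patches within the same image / window segment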
+ attention_mask = torch.zeros(
+ [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
+ )
+ for i in range(1, len(cu_seqlens)):
+ attention_mask[
+ :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
+ ] = True
+ attn_output = fsdpa_op(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=causal,
+ scale=None,
+ softmax_mode="None",
+ recompute_mode=None,
+ valid_sequence_lengths=None,
+ )
+ attn_output = attn_output.transpose(0, 1)
+
+ # reshape output to original dimensions
+ attn_output = attn_output.reshape(hidden_state.shape[0], -1)
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class Qwen2_5VLVisionMLP(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.activation_fn = ACT2FN[config.hidden_act]
+
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ self.up = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.up_proj", weights=weights, config=config, bias=True
+ )
+ self.gate = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.gate_proj", weights=weights, config=config, bias=True
+ )
+ self.down = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.down_proj", weights=weights, config=config, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ gate_states = self.gate(hidden_states)
+ up_states = self.up(hidden_states)
+ activated_states = self.activation_fn(gate_states) * up_states
+ down_states = self.down(activated_states)
+ return down_states
+
+
+class Qwen2_5VLVisionBlock(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.attn = Qwen2_5VLAttention(
+ prefix=f"{prefix}.attn",
+ config=config,
+ weights=weights,
+ )
+ self.norm1 = FastRMSNorm.load(
+ prefix=f"{prefix}.norm1",
+ weights=weights,
+ eps=1e-6,
+ )
+ self.norm2 = FastRMSNorm.load(
+ prefix=f"{prefix}.norm2",
+ weights=weights,
+ eps=1e-6,
+ )
+ self.mlp = Qwen2_5VLVisionMLP(
+ prefix=f"{prefix}.mlp",
+ config=config,
+ weights=weights,
+ )
+
+ def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:
+ norm1_out, _ = self.norm1(hidden_states)
+ attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen)
+ hidden_states = hidden_states + attn_out
+ norm2_out, _ = self.norm2(hidden_states)
+ mlp_out = self.mlp(norm2_out)
+ hidden_states = hidden_states + mlp_out
+ return hidden_states
+
+
+class Qwen2_5VLPatchMerger(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
+ self.patch_merger_ln_q = FastRMSNorm.load(
+ prefix=f"{prefix}.ln_q",
+ weights=weights,
+ eps=1e-6,
+ )
+ self.fc1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
+ )
+
+ def forward(self, hidden_states) -> torch.Tensor:
+ hidden_states, _ = self.patch_merger_ln_q(hidden_states)
+ hidden_states = hidden_states.view(-1, self.hidden_size)
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = F.gelu(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Qwen2_5VisionModel(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+
+ self.spatial_merge_size = config.spatial_merge_size
+ kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
+ self.patch_embedding = nn.Conv3d(
+ in_channels=config.in_channels,
+ out_channels=config.hidden_size,
+ kernel_size=kernel_size,
+ stride=kernel_size,
+ bias=False,
+ )
+ self.patch_embedding.weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
+ )
+ head_dim = config.hidden_size // config.num_heads
+
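+        # 2D rotary embedding for the vision patches: half of the rotary dimensions encode the
+        # patch's height index and the other half its width index (see pos_ids in forward)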
+ theta = 10000.0
+ dim = head_dim // 2
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ self.blocks = nn.ModuleList(
+ [
+ Qwen2_5VLVisionBlock(
+ prefix=f"{prefix}.blocks.{i}",
+ config=config,
+ weights=weights,
+ )
+ for i in range(config.depth)
+ ]
+ )
+ self.merger = Qwen2_5VLPatchMerger(
+ prefix=f"{prefix}.merger",
+ config=config,
+ weights=weights,
+ )
+
+ self.temporal_patch_size = config.temporal_patch_size
+ self.spatial_patch_size = config.spatial_patch_size
+ self.in_channels = config.in_channels
+ self.embed_dim = config.hidden_size
+ self.window_size = config.window_size
+ self.patch_size = config.patch_size
+ self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size
+ self.fullatt_block_indexes = config.fullatt_block_indexes
+
+ def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ batch_size, _, hidden_size = hidden_state.shape
+ class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
+ hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
+ return hidden_state
+
+ def get_window_index(self, grid_thw):
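+        # For each image/video in grid_thw, compute the permutation that regroups the spatially
+        # merged patches into local attention windows of
+        # window_size // spatial_merge_size // patch_size merged patches per side, together with
+        # cu_window_seqlens, the cumulative sequence lengths delimiting each window.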
+ window_index: list = []
+ cu_window_seqlens: list = [0]
+ window_index_id = 0
+ vit_merger_window_size = (
+ self.window_size // self.spatial_merge_size // self.patch_size
+ )
+
+ for grid_t, grid_h, grid_w in grid_thw:
+ llm_grid_h, llm_grid_w = (
+ grid_h // self.spatial_merge_size,
+ grid_w // self.spatial_merge_size,
+ )
+ index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+ grid_t, llm_grid_h, llm_grid_w
+ )
+ pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+ pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+ num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+ num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+ index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+ index_padded = index_padded.reshape(
+ grid_t,
+ num_windows_h,
+ vit_merger_window_size,
+ num_windows_w,
+ vit_merger_window_size,
+ )
+ index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+ grid_t,
+ num_windows_h * num_windows_w,
+ vit_merger_window_size,
+ vit_merger_window_size,
+ )
+ seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+ index_padded = index_padded.reshape(-1)
+ index_new = index_padded[index_padded != -100]
+ window_index.append(index_new + window_index_id)
+ cu_seqlens_tmp = (
+ seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+ )
+ cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+ window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+ window_index = torch.cat(window_index, dim=0)
+
+ return window_index, cu_window_seqlens
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ grid_thw: Optional[torch.LongTensor] = None,
+ ) -> torch.Tensor:
+
+ # reshape the input tensor for processing
+ shape = (
+ -1,
+ self.in_channels,
+ self.temporal_patch_size,
+ self.spatial_patch_size,
+ self.spatial_patch_size,
+ )
+ pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
+ hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
+ # TODO: revisit to see if we can avoid some of these reshapes
+
+ # find the position ids for the input tensor based on the grid_thw
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+
+ pos_ids = torch.cat(pos_ids, dim=0)
+
+ max_grid_size = grid_thw[:, 1:].max()
+
+ # apply the positional embeddings to the position ids
+ seq = torch.arange(
+ max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+ )
+ rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+ seq_len = hidden_states.shape[0]
+ patch_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ og_shape = (seq_len, -1)
+
+ hidden_states = hidden_states.view(patch_shape)[window_index, :, :].view(
+ og_shape
+ )
+ rotary_pos_emb = rotary_pos_emb.view(patch_shape)[window_index, :, :].view(
+ og_shape
+ )
+
+ rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)
+ cos = rotary_pos_emb.cos()
+ sin = rotary_pos_emb.sin()
+ cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
+ sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
+
+ cu_window_seqlens = torch.tensor(
+ cu_window_seqlens,
+ device="cpu",
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens).to(
+ hidden_states.device
+ )
+
+ # create a cu_seqlens tensor to be used in the attention mask
+ cu_seqlens = torch.repeat_interleave(
+ grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+ ).cumsum(dim=0, dtype=torch.int32)
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+ max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
+
+        # iteratively apply the blocks to the hidden states
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+ for layer_num, block in enumerate(self.blocks):
+ # NOTE: qwen2_5_vl.py has a concept of full attention blocks
+ # that are applied at specific layers.
+ if layer_num in self.fullatt_block_indexes:
+ cu_seqlens_now = cu_seqlens
+ else:
+ cu_seqlens_now = cu_window_seqlens
+
+ hidden_states = block(hidden_states, cu_seqlens_now, cos, sin, max_seqlen)
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ # apply the final patch merger to the hidden states
+ hidden_states = self.merger(hidden_states)
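+        # undo the window permutation so the merged features follow the original patch order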
+ reverse_indices = torch.argsort(window_index)
+ hidden_states = hidden_states[reverse_indices, :]
+ return hidden_states
+
+
+class Qwen2_5VLForConditionalGeneration(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ config.vision_config.quantize = None
+ config.vision_config.speculator = config.speculator
+ # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly
+ # returns rope_scaling.type == "default" for Qwen2_5-VL model at the moment
+ if (
+ hasattr(config, "rope_scaling")
+ and config.rope_scaling is not None
+ and config.rope_scaling.get("type", None) == "default"
+ ):
+ config.rope_scaling.update({"rope_type": "mrope"})
+ self.hidden_size = config.hidden_size
+ self.vision_start_token_id = config.vision_start_token_id
+ self.vision_end_token_id = config.vision_end_token_id
+ self.image_token_id = config.image_token_id
+ self.video_token_id = config.video_token_id
+ self.spatial_merge_size = config.vision_config.spatial_merge_size
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix="model.embed_tokens", weights=weights
+ )
+ self.visual = Qwen2_5VisionModel(
+ prefix="visual", config=config.vision_config, weights=weights
+ )
+ self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
+ if config.tie_word_embeddings:
+ suffix = "model.embed_tokens"
+ else:
+ suffix = "lm_head"
+
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix=suffix if not prefix else f"{prefix}.{suffix}",
+ weights=weights,
+ )
+ self.device = weights.device
+
+ # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
+ # modified to first find segments then initialize position ids for each segment
+ # Steps:
+ # locate all vision and text segments
+    # calculate `vision_segment_lengths` for each vision segment to be used as an offset
+    # calculate `text_segment_lengths` for each text segment to be used as an offset
+ # create position ids for each vision segment based on the image grid
+ # create position ids for each text segment
+ # combine all the position ids
+    # the final text segment runs from the end of the last vision segment to the end of the input
+ # combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
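+    # Illustrative example (hypothetical sizes): with spatial_merge_size = 2, a single image
+    # with image_grid_thw = [1, 4, 4] maps to a 1 x 2 x 2 grid of merged patches, i.e. 4
+    # image tokens. Text tokens use the same index for all three (t, h, w) components,
+    # e.g. (0, 0, 0), (1, 1, 1), ...; the 4 image tokens share one temporal index and
+    # enumerate the 2 x 2 grid in their (h, w) components; text after the image continues
+    # from max(previous position ids) + 1.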
+ def get_position_ids(
+ self,
+ input_ids: torch.Tensor,
+ image_grid_thw: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if image_grid_thw is None:
+ return (
+ torch.arange(input_ids.shape[0], device=input_ids.device)
+ .unsqueeze(1)
+ .repeat(1, 3)
+ )
+
+ spatial_merge_size = self.spatial_merge_size
+ vision_start_token_id = self.vision_start_token_id
+ vision_end_token_id = self.vision_end_token_id
+ device = input_ids.device
+ dtype = input_ids.dtype
+ input_ids_len = input_ids.shape[0]
+
+ vision_starts = torch.where(input_ids == vision_start_token_id)[0]
+ vision_ends = torch.where(input_ids == vision_end_token_id)[0]
+ vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
+ prev_vision_end = torch.cat(
+ [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
+ )
+ text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
+ vision_widths_max = torch.cat(
+ [
+ torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
+ image_grid_thw[:-1, 2] // spatial_merge_size,
+ ]
+ )
+ vision_segment_lengths = vision_widths_max + text_lengths_between_vision
+ vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
+ text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
+
+ # create position ids for each vision segment based on the image grid
+ llm_pos_ids_list = []
+ for i, _ in enumerate(vision_segments):
+ t, h, w = (
+ image_grid_thw[i][0],
+ image_grid_thw[i][1] // spatial_merge_size,
+ image_grid_thw[i][2] // spatial_merge_size,
+ )
+ t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
+ h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
+ w_indices = torch.arange(w, device=device).repeat(t * h)
+ image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
+
+ # offset by the position of the last vision segment
+ im = image_position_ids + vision_segment_lengths[i]
+ llm_pos_ids_list.append(im)
+
+ # create position ids for each text segment
+ text_ranges = [
+ torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
+ + text_segment_lengths[i]
+ for i, seq_len in enumerate(text_lengths_between_vision)
+ ]
+
+ full_llm_pos_ids_list = [
+ item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
+ ]
+ max_s = full_llm_pos_ids_list[-1].max() + 1
+ final_text_len = input_ids_len - vision_ends[-1]
+ if final_text_len > 0:
+ m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
+ full_llm_pos_ids_list.append(m + max_s)
+
+ position_ids = (
+ torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
+ )
+ return position_ids
+
+ def get_vision_embeds(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
+ return image_embeds
+
+ def get_inputs_embeds(
+ self,
+ input_ids: torch.Tensor,
+ vision_embeds: torch.Tensor = None,
+ ):
+ inputs_embeds = self.embed_tokens(input_ids)
+
+        # scatter the precomputed vision embeddings into the positions of the image tokens
+ if vision_embeds is not None:
+ mask = torch.where(input_ids == self.image_token_id)
+ inputs_embeds[mask] = vision_embeds
+
+ return inputs_embeds
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor],
+ attention_mask: Optional[torch.BoolTensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ image_indices=None,
+ ):
+
+ hidden_states = self.text_model(
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
new file mode 100644
index 00000000000..96acef31569
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -0,0 +1,528 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen2 VL model."""
+
+from typing import Optional, Tuple, List
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+
+from habana_frameworks.torch.hpex.kernels import FusedSDPA
+from vllm_hpu_extension.utils import ModuleFusedSDPA
+
+
+import numpy as np
+
+from transformers.activations import ACT2FN
+import torch.nn.functional as F
+
+from text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+)
+from text_generation_server.layers.attention import (
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
+ Qwen2Model,
+)
+from habana_frameworks.torch.hpex.kernels import (
+ RotaryPosEmbeddingMode,
+ apply_rotary_pos_emb,
+)
+import habana_frameworks.torch as htorch
+
+
+class Qwen2VLAttention(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.embed_dim = config.embed_dim // weights.process_group.size()
+ self.head_dim = config.hidden_size // config.num_heads
+ self.num_heads = config.num_heads // weights.process_group.size()
+
+ self.qkv = TensorParallelColumnLinear.load_qkv(
+ config,
+ prefix=f"{prefix}.qkv",
+ weights=weights,
+ bias=False,
+ num_heads=self.num_heads,
+ num_key_value_heads=self.num_heads,
+ )
+ self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
+ self.proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.proj",
+ weights=weights,
+ bias=True,
+ )
+ self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)
+
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ max_seqlen: int,
+ ) -> torch.Tensor:
+ # apply the qkv linear layer to the hidden state
+ qkv = self.qkv(hidden_state)
+ query, key, value = qkv.split(
+ [self.embed_dim, self.embed_dim, self.embed_dim], dim=1
+ )
+
+ # reshape the query, key, and value tensors
+ _shape = (
+ hidden_state.shape[0],
+ self.num_heads,
+ self.embed_dim // self.num_heads,
+ )
+ query = query.view(*_shape)
+ key = key.view(*_shape)
+ value = value.view(*_shape)
+
+ # apply rotary positional embeddings
+ rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
+ rotary_dim = cos.shape[-1]
+ query_rot = query[..., :rotary_dim]
+ query_pass = query[..., rotary_dim:]
+ query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+ query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))
+
+ key_rot = key[..., :rotary_dim]
+ key_pass = key[..., rotary_dim:]
+ key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+ key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
+
+ # execute sdpa
+ causal = False
+ query = query.transpose(0, 1)
+ key = key.transpose(0, 1)
+ value = value.transpose(0, 1)
+ fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+ attention_mask = torch.zeros(
+ [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
+ )
+ for i in range(1, len(cu_seqlens)):
+ attention_mask[
+ :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
+ ] = True
+ attn_output = fsdpa_op(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=causal,
+ scale=None,
+ softmax_mode="None",
+ recompute_mode=None,
+ valid_sequence_lengths=None,
+ )
+ attn_output = attn_output.transpose(0, 1)
+ # reshape output to original dimensions
+ attn_output = attn_output.reshape(hidden_state.shape[0], -1)
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class Qwen2VLVisionMLP(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Qwen2VLVisionBlock(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.attn = Qwen2VLAttention(
+ prefix=f"{prefix}.attn",
+ config=config,
+ weights=weights,
+ )
+ self.norm1 = FastLayerNorm.load(
+ prefix=f"{prefix}.norm1",
+ weights=weights,
+ eps=1e-6,
+ )
+ self.norm2 = FastLayerNorm.load(
+ prefix=f"{prefix}.norm2",
+ weights=weights,
+ eps=1e-6,
+ )
+ self.mlp = Qwen2VLVisionMLP(
+ prefix=f"{prefix}.mlp",
+ config=config,
+ weights=weights,
+ )
+
+ def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:
+ norm1_out, residual = self.norm1(hidden_states)
+ attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen)
+ hidden_states = attn_out + residual
+ norm2_out, residual = self.norm2(hidden_states)
+ hidden_states = hidden_states + self.mlp(norm2_out)
+ return hidden_states
+
+
+class Qwen2VLPatchMerger(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.hidden_size = config.embed_dim * (config.spatial_merge_size**2)
+ self.patch_merger_ln_q = FastLayerNorm.load(
+ prefix=f"{prefix}.ln_q",
+ weights=weights,
+ eps=1e-6,
+ )
+ self.fc1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
+ )
+
+ def forward(self, hidden_states) -> torch.Tensor:
+ hidden_states, _ = self.patch_merger_ln_q(hidden_states)
+ hidden_states = hidden_states.view(-1, self.hidden_size)
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = F.gelu(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Qwen2VisionModel(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.spatial_merge_size = config.spatial_merge_size
+ kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
+ self.patch_embedding = nn.Conv3d(
+ in_channels=config.in_chans,
+ out_channels=config.embed_dim,
+ kernel_size=kernel_size,
+ stride=kernel_size,
+ bias=False,
+ )
+ self.patch_embedding.weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
+ )
+ head_dim = config.embed_dim // config.num_heads
+ # TODO: replace with static positional embeddings once implemented
+ theta = 10000.0
+ dim = head_dim // 2
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ self.blocks = nn.ModuleList(
+ [
+ Qwen2VLVisionBlock(
+ prefix=f"{prefix}.blocks.{i}",
+ config=config,
+ weights=weights,
+ )
+ for i in range(config.depth)
+ ]
+ )
+ self.merger = Qwen2VLPatchMerger(
+ prefix=f"{prefix}.merger",
+ config=config,
+ weights=weights,
+ )
+
+ self.temporal_patch_size = config.temporal_patch_size
+ self.spatial_patch_size = config.spatial_patch_size
+ self.in_channels = config.in_channels
+ self.embed_dim = config.embed_dim
+
+ def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ batch_size, _, hidden_size = hidden_state.shape
+ class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
+ hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
+ return hidden_state
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ grid_thw: Optional[torch.LongTensor] = None,
+ ) -> torch.Tensor:
+ # reshape the input tensor for processing
+ shape = (
+ -1,
+ self.in_channels,
+ self.temporal_patch_size,
+ self.spatial_patch_size,
+ self.spatial_patch_size,
+ )
+ pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
+ hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
+ # TODO: revisit to see if we can avoid some of these reshapes
+
+ # find the position ids for the input tensor based on the grid_thw
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+
+ # apply the positional embeddings to the position ids
+ seq = torch.arange(
+ max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+ )
+ rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)
+
+ cos = rotary_pos_emb.cos()
+ sin = rotary_pos_emb.sin()
+ cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
+ sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
+
+ # create a cu_seqlens tensor to be used in the attention mask
+ cu_seqlens = torch.repeat_interleave(
+ grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+ ).cumsum(dim=0, dtype=torch.int32)
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+ max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
+        # iteratively apply the blocks to the hidden states
+ lazy_mode = htorch.utils.internal.is_lazy()
+ if lazy_mode:
+ htorch.core.mark_step()
+ for block in self.blocks:
+ hidden_states = block(hidden_states, cu_seqlens, cos, sin, max_seqlen)
+ if lazy_mode:
+ htorch.core.mark_step()
+
+ # apply the final patch merger to the hidden states
+ hidden_states = self.merger(hidden_states)
+ return hidden_states
+
+
+class Qwen2VLForConditionalGeneration(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ config.vision_config.quantize = None
+ config.vision_config.speculator = config.speculator
+ # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly
+ # returns rope_scaling.type == "default" for Qwen2-VL model at the moment
+ if (
+ hasattr(config, "rope_scaling")
+ and config.rope_scaling is not None
+ and config.rope_scaling.get("type", None) == "default"
+ ):
+ config.rope_scaling.update({"rope_type": "mrope"})
+ self.hidden_size = config.hidden_size
+ self.vision_start_token_id = config.vision_start_token_id
+ self.vision_end_token_id = config.vision_end_token_id
+ self.image_token_id = config.image_token_id
+ self.video_token_id = config.video_token_id
+ self.spatial_merge_size = config.vision_config.spatial_merge_size
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix="model.embed_tokens", weights=weights
+ )
+ self.visual = Qwen2VisionModel(
+ prefix="visual", config=config.vision_config, weights=weights
+ )
+ self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
+ if config.tie_word_embeddings:
+ suffix = "model.embed_tokens"
+ else:
+ suffix = "lm_head"
+
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix=suffix if not prefix else f"{prefix}.{suffix}",
+ weights=weights,
+ )
+ self.norm = FastRMSNorm.load(
+ prefix="model.norm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.device = weights.device
+
+ # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
+ # modified to first find segments then initialize position ids for each segment
+ # Steps:
+ # locate all vision and text segments
+    # calculate `vision_segment_lengths` for each vision segment to be used as an offset
+    # calculate `text_segment_lengths` for each text segment to be used as an offset
+ # create position ids for each vision segment based on the image grid
+ # create position ids for each text segment
+ # combine all the position ids
+    # the final text segment runs from the end of the last vision segment to the end of the input
+ # combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
+ def get_position_ids(
+ self,
+ input_ids: torch.Tensor,
+ image_grid_thw: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if image_grid_thw is None:
+ return (
+ torch.arange(input_ids.shape[0], device=input_ids.device)
+ .unsqueeze(1)
+ .repeat(1, 3)
+ )
+
+ spatial_merge_size = self.spatial_merge_size
+ vision_start_token_id = self.vision_start_token_id
+ vision_end_token_id = self.vision_end_token_id
+ device = input_ids.device
+ dtype = input_ids.dtype
+ input_ids_len = input_ids.shape[0]
+
+ vision_starts = torch.where(input_ids == vision_start_token_id)[0]
+ vision_ends = torch.where(input_ids == vision_end_token_id)[0]
+ vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
+ prev_vision_end = torch.cat(
+ [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
+ )
+ text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
+ vision_widths_max = torch.cat(
+ [
+ torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
+ image_grid_thw[:-1, 2] // spatial_merge_size,
+ ]
+ )
+ vision_segment_lengths = vision_widths_max + text_lengths_between_vision
+ vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
+ text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
+
+ # create position ids for each vision segment based on the image grid
+ llm_pos_ids_list = []
+ for i, _ in enumerate(vision_segments):
+ t, h, w = (
+ image_grid_thw[i][0],
+ image_grid_thw[i][1] // spatial_merge_size,
+ image_grid_thw[i][2] // spatial_merge_size,
+ )
+ t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
+ h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
+ w_indices = torch.arange(w, device=device).repeat(t * h)
+ image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
+
+ # offset by the position of the last vision segment
+ im = image_position_ids + vision_segment_lengths[i]
+ llm_pos_ids_list.append(im)
+
+ # create position ids for each text segment
+ text_ranges = [
+ torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
+ + text_segment_lengths[i]
+ for i, seq_len in enumerate(text_lengths_between_vision)
+ ]
+
+ full_llm_pos_ids_list = [
+ item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
+ ]
+ max_s = full_llm_pos_ids_list[-1].max() + 1
+ final_text_len = input_ids_len - vision_ends[-1]
+ if final_text_len > 0:
+ m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
+ full_llm_pos_ids_list.append(m + max_s)
+
+ position_ids = (
+ torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
+ )
+ return position_ids
+
+ def get_vision_embeds(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_attention_mask: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
+ return image_embeds
+
+ def get_inputs_embeds(
+ self,
+ input_ids: torch.Tensor,
+ vision_embeds: torch.Tensor = None,
+ ):
+ inputs_embeds = self.embed_tokens(input_ids)
+
+        # scatter the precomputed vision embeddings into the positions of the image tokens
+ if vision_embeds is not None:
+ mask = torch.where(input_ids == self.image_token_id)
+ inputs_embeds[mask] = vision_embeds
+
+ return inputs_embeds
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor],
+ attention_mask: Optional[torch.BoolTensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ image_indices=None,
+ ):
+ hidden_states = self.text_model(
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/siglip.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/siglip.py
new file mode 100644
index 00000000000..95ac9edee20
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/siglip.py
@@ -0,0 +1,410 @@
+from typing import Optional, Tuple
+import warnings
+import math
+import torch
+from torch import nn
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPooling,
+)
+from transformers import SiglipConfig, SiglipVisionConfig
+from torch.nn.init import _calculate_fan_in_and_fan_out
+
+from text_generation_server.layers.tensor_parallel import (
+ TensorParallelEmbedding,
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+)
+
+
+class SiglipVisionEmbeddings(nn.Module):
+ def __init__(self, prefix, config: SiglipVisionConfig, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+ self.patch_embedding.weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
+ )
+ self.patch_embedding.bias = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
+ )
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches
+ self.position_embedding = TensorParallelEmbedding(
+ prefix=f"{prefix}.position_embedding", weights=weights
+ )
+ self.register_buffer(
+ "position_ids",
+ torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
+ persistent=False,
+ )
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ patch_embeds = self.patch_embedding(
+ pixel_values
+ ) # shape = [*, width, grid, grid]
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ return embeddings
+
+
+class SiglipAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ self.head_size = self.head_dim
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.embed_dim = self.embed_dim // weights.process_group.size()
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
+ )
+ self.v_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.v_proj", weights=weights, bias=True
+ )
+ self.q_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.q_proj", weights=weights, bias=True
+ )
+ self.out_proj = TensorParallelRowLinear.load(
+ config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
+ )
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return (
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ .contiguous()
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, _ = hidden_states.size()
+ query_states = self.q_proj(hidden_states)
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ # scale post matmul
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) * self.scale
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = (
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ + attention_mask
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(attn_weights.dtype)
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ )
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class SiglipMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = TensorParallelColumnLinear.load( # config.hidden_size, config.intermediate_size
+ prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load( # config.intermediate_size, config.hidden_size
+ prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class SiglipEncoderLayer(nn.Module):
+ def __init__(self, prefix, config: SiglipConfig, weights):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = SiglipAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
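+        # NOTE: nn.LayerNorm.load is not a stock PyTorch API; it is assumed to be the `load`
+        # classmethod that text_generation_server.layers.layernorm patches onto torch.nn.LayerNorm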
+ self.layer_norm1 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
+ )
+ self.mlp = SiglipMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+ self.layer_norm2 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ ) -> Tuple[torch.FloatTensor]:
+ residual = hidden_states
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ )
+ hidden_states = residual + hidden_states
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states, None
+
+
+class SiglipMultiheadAttentionPoolingHead(nn.Module):
+ """Multihead Attention Pooling."""
+
+ def __init__(self, prefix, config: SiglipVisionConfig, weights):
+ super().__init__()
+
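+        # a single learned query (the "probe") attends over all patch tokens, pooling them
+        # into one vector per image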
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ self.attention = torch.nn.MultiheadAttention(
+ config.hidden_size, config.num_attention_heads, batch_first=True
+ )
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.mlp = SiglipMLP(prefix, config, weights)
+
+ def forward(self, hidden_state):
+ batch_size = hidden_state.shape[0]
+ probe = self.probe.repeat(batch_size, 1, 1)
+
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+
+ residual = hidden_state
+ hidden_state = self.layernorm(hidden_state)
+ hidden_state = residual + self.mlp(hidden_state)
+
+ return hidden_state[:, 0]
+
+
+def _trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2,
+ )
+
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ lower = norm_cdf((a - mean) / std)
+ upper = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * lower - 1, 2 * upper - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.0))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+
+
+def trunc_normal_tf_(
+ tensor: torch.Tensor,
+ mean: float = 0.0,
+ std: float = 1.0,
+ a: float = -2.0,
+ b: float = 2.0,
+) -> torch.Tensor:
+ """Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+    best when :math:`a \\leq \\text{mean} \\leq b`.
+
+    NOTE: this 'tf' variant behaves closer to the TensorFlow / JAX implementation, where the
+    bounds [a, b] are applied while sampling the normal distribution with mean=0, std=1.0,
+    and the result is subsequently scaled and shifted by the mean and std args.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ """
+ with torch.no_grad():
+ _trunc_normal_(tensor, 0, 1.0, a, b)
+ tensor.mul_(std).add_(mean)
+
+
+def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+ if mode == "fan_in":
+ denom = fan_in
+ elif mode == "fan_out":
+ denom = fan_out
+ elif mode == "fan_avg":
+ denom = (fan_in + fan_out) / 2
+
+ variance = scale / denom
+
+ if distribution == "truncated_normal":
+ # constant is stddev of standard normal truncated to (-2, 2)
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+ elif distribution == "normal":
+ with torch.no_grad():
+ tensor.normal_(std=math.sqrt(variance))
+ elif distribution == "uniform":
+ bound = math.sqrt(3 * variance)
+ with torch.no_grad():
+ tensor.uniform_(-bound, bound)
+ else:
+ raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+def default_flax_embed_init(tensor):
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
+
+
+class SiglipEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`SiglipEncoderLayer`].
+
+ Args:
+ config: SiglipConfig
+ """
+
+ def __init__(self, prefix, config: SiglipConfig, weights):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList(
+ [
+ SiglipEncoderLayer(
+ prefix=f"{prefix}.layers.{i}", config=config, weights=weights
+ )
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ hidden_states = inputs_embeds
+ for idx, encoder_layer in enumerate(self.layers):
+ hidden_states, _ = encoder_layer(
+ hidden_states,
+ attention_mask,
+ )
+
+ return hidden_states
+
+
+class SiglipVisionTransformer(nn.Module):
+ def __init__(self, prefix, config: SiglipVisionConfig, weights):
+ super().__init__()
+ self.config = config
+
+ self.embeddings = SiglipVisionEmbeddings(
+ prefix=f"{prefix}.embeddings", config=config, weights=weights
+ )
+ self.encoder = SiglipEncoder(
+ prefix=f"{prefix}.encoder", config=config, weights=weights
+ )
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ ):
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.embeddings(pixel_values)
+
+        # NOTE: up to this point the outputs match the transformers implementation exactly;
+        # the values diverge slightly inside our encoder layer.
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ )
+ last_hidden_state = encoder_outputs
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ # pooler_output=pooled_output,
+ # hidden_states=encoder_outputs,
+ )
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py
new file mode 100644
index 00000000000..23cba7b0535
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py
@@ -0,0 +1,67 @@
+def load_text_model(prefix, config, weights, name=None):
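+    # Dispatch on config.model_type to the matching flash-attention text model implementation;
+    # `name` is only forwarded to the llama and mistral variants below.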
+ if config.model_type == "llama":
+ from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+ FlashLlamaForCausalLM,
+ )
+
+ return FlashLlamaForCausalLM(prefix, config, weights, name=name)
+ elif config.model_type == "mistral":
+ from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
+ FlashMistralForCausalLM,
+ )
+
+ return FlashMistralForCausalLM(prefix, config, weights, name=name)
+ elif config.model_type == "gemma":
+ from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
+ FlashGemmaForCausalLM,
+ )
+
+ return FlashGemmaForCausalLM(prefix, config, weights)
+ elif config.model_type == "gemma2":
+ from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
+ FlashGemma2ForCausalLM,
+ )
+
+ return FlashGemma2ForCausalLM(prefix, config, weights)
+ elif config.model_type == "gemma3" or config.model_type == "gemma3_text":
+ from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
+ FlashGemma3ForCausalLM,
+ )
+
+ return FlashGemma3ForCausalLM(prefix, config, weights)
+ elif config.model_type == "paligemma":
+ from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
+ FlashGemmaForCausalLM,
+ )
+
+ return FlashGemmaForCausalLM(prefix, config, weights)
+ else:
+ raise RuntimeError(f"Unsupported model type {config.model_type}")
+
+
+def load_vision_model(prefix, config, weights):
+ if config.model_type == "clip_vision_model":
+ from text_generation_server.models.custom_modeling.clip import (
+ CLIPVisionTransformer,
+ )
+
+ return CLIPVisionTransformer(
+ prefix=f"{prefix}.vision_model", config=config, weights=weights
+ )
+ if (
+ config.model_type == "siglip_vision_model"
+ or config.model_type == "gemma3_vision"
+ ):
+ from text_generation_server.models.custom_modeling.siglip import (
+ SiglipVisionTransformer,
+ )
+
+ # TODO: ensure that using the prefix doesn't break any existing models
+ # that rely on the old prefix (update the old models if necessary)
+ return SiglipVisionTransformer(
+ prefix=f"{prefix}.vision_model",
+ config=config,
+ weights=weights,
+ )
+ else:
+ raise RuntimeError(f"Unsupported model type {config.model_type}")
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
new file mode 100644
index 00000000000..d1e1595b4d3
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -0,0 +1,2653 @@
+import math
+import os
+import time
+import torch
+import torch.distributed
+
+import numpy as np
+
+from loguru import logger
+from dataclasses import dataclass
+from opentelemetry import trace
+from transformers import (
+ PreTrainedTokenizerBase,
+ AutoConfig,
+ AutoTokenizer,
+ GenerationConfig,
+)
+from typing import (
+ Any,
+ Iterable,
+ Optional,
+ Tuple,
+ List,
+ Type,
+ Dict,
+ Union,
+)
+import torch.nn.functional as F
+from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
+from text_generation_server.utils.chunks import concat_text_chunks
+from text_generation_server.models import Model
+from text_generation_server.utils.log import log_master
+from text_generation_server.utils.tokens import batch_top_tokens
+from text_generation_server.utils.speculate import get_speculate
+from text_generation_server.utils import (
+ initialize_torch_distributed,
+ weight_files,
+ Weights,
+ pad_next_token_chooser_parameters,
+)
+from text_generation_server.models.types import (
+ Batch,
+ Tokens,
+ Generation,
+ GeneratedText,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.globals import (
+ BLOCK_SIZE,
+ REQUEST_LOGPROBS,
+ TGI_WIGGLE_ROOM,
+ get_adapter_to_index,
+)
+from text_generation_server.layers.attention import (
+ KVCache,
+ KVCompressCache,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+ trim_attn_metadata,
+ trim_seqlen_metadata,
+ _async_h2d_tensor_copy,
+)
+from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
+from text_generation_server.utils.dist import MEMORY_FRACTION
+from text_generation_server.utils.quantization import get_loader
+from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
+from text_generation_server.utils.import_utils import (
+ empty_cache,
+ synchronize,
+ get_free_memory,
+)
+from text_generation_server.utils.prefill_chunking import (
+ get_max_prefill_tokens,
+)
+import vllm_hpu_extension.environment as environment
+import habana_frameworks.torch as htorch
+import itertools
+from vllm_hpu_extension.bucketing.common import get_bucketing_context
+from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
+
+tracer = trace.get_tracer(__name__)
+
+
+def generate_block_metadata(
+ dtype,
+ use_contiguous_pa,
+ slots,
+ block_tables,
+ bucketing_ctx,
+ slots_in_window=None,
+ block_bucket_size=None,
+):
+    # Prepare the per-block values needed to build HPUPagedAttentionMetadata
+    # for decoding.
+ def flatten(in_list):
+ return list(itertools.chain(*in_list))
+
+ def gather_list(input, indices, v):
+ return [input[i] if i is not None else v for i in indices]
+
+ def pad_list(input, k, v):
+ input_len = len(input)
+ target_len = (input_len + k - 1) // k * k
+ padding = target_len - input_len
+ return input + [v] * padding
+
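+    # last_block_usage: number of occupied slots in the final block of each sequence.
+    # block_groups maps every block back to its owning sequence index and
+    # block_usage records how many slots are used in each block.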
+ last_block_usage = [slot % BLOCK_SIZE + 1 for slot in slots]
+ block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
+ block_usage = [
+ [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
+ for bt, lbu in zip(block_tables, last_block_usage)
+ if bt
+ ]
+
+ block_list = flatten(block_tables)
+ block_groups = flatten(block_groups)
+ block_usage = flatten(block_usage)
+ assert len(block_list) == len(block_groups)
+ assert len(block_list) == len(block_usage)
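+    # With contiguous PA, each block is scattered into a dense table indexed by its
+    # block id; otherwise the lists are simply padded up to the bucket size.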
+ if use_contiguous_pa:
+ if block_bucket_size is None:
+ block_bucket_size = max(max(block_list) + 1, len(block_list))
+ if bucketing_ctx is not None:
+ block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(
+ block_bucket_size
+ )
+ indices: List[Any]
+ indices = [None] * block_bucket_size
+ for i, bid in enumerate(block_list):
+ indices[bid] = i
+ block_list = gather_list(block_list, indices, 0)
+ block_groups = gather_list(block_groups, indices, -1)
+ block_usage = gather_list(block_usage, indices, 1)
+ else:
+ if block_bucket_size is None:
+ block_bucket_size = len(block_list)
+ if bucketing_ctx is not None:
+ block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(
+ block_bucket_size
+ )
+ block_list = pad_list(block_list, block_bucket_size, 0)
+ block_groups = pad_list(block_groups, block_bucket_size, -1)
+ block_usage = pad_list(block_usage, block_bucket_size, 1)
+ slots_in_window_mask = None
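+    # When a sliding window is used, build a per-block mask over the slots that fall
+    # inside the window; rows with no valid slot keep their first position set to
+    # avoid a fully masked row.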
+ if slots_in_window is not None:
+ slot_list = [
+ block_id * BLOCK_SIZE + slot_idx
+ for block_id in block_list
+ for slot_idx in range(BLOCK_SIZE)
+ ]
+ slot_list = torch.tensor(slot_list, dtype=torch.int64)
+ slot_list = slot_list.view(-1, BLOCK_SIZE)
+ slots_in_window_mask = torch.isin(slot_list, slots_in_window)
+ for i in range(slots_in_window_mask.shape[0]):
+ if not slots_in_window_mask[i].any():
+ slots_in_window_mask[i, 0] = True
+
+ block_list = torch.tensor(block_list, dtype=torch.int, device="cpu")
+ block_groups = torch.tensor(block_groups, dtype=torch.int, device="cpu")
+ block_usage = torch.tensor(block_usage, dtype=dtype, device="cpu")
+ return (
+ block_list,
+ block_groups,
+ block_usage,
+ slots_in_window_mask,
+ block_bucket_size,
+ )
+
+
+@dataclass
+class FlashCausalLMBatch(Batch):
+ batch_id: int
+ requests: List[generate_pb2.Request]
+ # request id -> idx in list mapping
+ requests_idx_mapping: Dict[int, int]
+
+ # Decoder values
+ # Can be a list for easy filtering
+ # If `input_ids` is a list, it needs to be materialized to a tensor first
+ input_ids: Union[torch.Tensor, List[List[int]]]
+ # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
+ position_ids: Optional[torch.Tensor]
+ speculative_ids: Optional[torch.Tensor]
+
+ # Set when creating the batch
+ # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
+ # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
+ slot_indices: Optional[torch.Tensor]
+
+ # list of length b of list of length s_i // block_size
+ block_tables: List[List[int]]
+ # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
+ block_tables_tensor: torch.Tensor
+ # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
+ slots: torch.Tensor
+ # list of length b + 1 containing the cumulative sequence slot lengths of the sequences in the batch
+ # used for filtering
+ cu_slots: torch.Tensor
+
+ max_input_length: int
+ max_current_length: int
+
+ # Whether this batch contains at least one request that is prefilling
+ prefilling: bool
+ # Whether each request is prefilling
+ prefilling_mask: List[bool]
+
+ # Prefill metadata tensors to efficiently compute logprobs
+ # tensor of length b + 1 containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
+ cu_seqlen_prefill: Optional[torch.Tensor]
+ # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
+ # as we only keep SLIDING_WINDOW values instead of the whole tensor
+ prefill_cache_indices: Optional[torch.Tensor]
+ # Will be set by `generate_token` and reset after each prefill forward
+ prefill_head_indices: Optional[torch.Tensor]
+ # Will be set by `generate_token` and reset after each prefill forward
+ prefill_next_token_indices: Optional[torch.tensor]
+ # Will be set by `generate_token` and reset after each prefill forward
+ prefill_cu_outlens: Optional[List[int]]
+ # Will be set by `generate_token` and reset after each prefill forward
+ prefill_logprob_tokens: List[Optional[Tokens]]
+
+ # All tokens
+ all_input_ids: List[List[int]]
+ all_input_ids_tensor: torch.Tensor
+
+ # Lengths of all generations present in the batch
+ input_lengths: List[int]
+ # size [b], containing the number of blocks that can be retrieved from the cache
+ cache_lengths: List[int]
+ prompt_lengths: List[int]
+ # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
+ input_lengths_tensor: Optional[torch.Tensor]
+ cache_lengths_tensor: Optional[torch.Tensor]
+ prompt_lengths_tensor: torch.Tensor
+
+ prefix_offsets: List[Optional[int]]
+ read_offsets: List[Optional[int]]
+
+ # Generation helpers
+ next_token_chooser: HeterogeneousNextTokenChooser
+ stopping_criterias: List[StoppingCriteria]
+ top_n_tokens: List[int]
+ top_n_tokens_tensor: torch.Tensor
+
+ # Adapter metadata for each request
+ # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
+ adapter_meta: Optional[AdapterBatchMetadata]
+
+ # Number of blocks in this batch
+ num_blocks: int
+ # Maximum number of blocks
+ max_blocks: int
+
+ hpu_attn_meta: Optional[HPUPagedAttentionMetadata]
+
+ next_token_logits: Optional[torch.Tensor]
+ speculative_logits: Optional[torch.Tensor]
+ valid_indices: Optional[List[int]]
+
+ def to_pb(self) -> generate_pb2.CachedBatch:
+ return generate_pb2.CachedBatch(
+ id=self.batch_id,
+ request_ids=[r.id for r in self.requests],
+ size=len(self),
+ max_tokens=self.num_blocks * BLOCK_SIZE,
+ current_tokens=(
+ sum([len(i) for i in self.input_ids])
+ if isinstance(self.input_ids, list)
+ else len(self.input_ids)
+ ),
+ )
+
+ @classmethod
+ def batch_tokenized_inputs(
+ cls, requests: Iterable[generate_pb2.Request], tokenizer
+ ):
+ max_length = 0
+ all_input_ids = []
+ batch_size = 0
+ for r in requests:
+ batch_size += 1
+ inputs = concat_text_chunks(r.input_chunks.chunks)
+ input_ids = tokenizer(
+ inputs,
+ truncation=True,
+ max_length=r.truncate,
+ add_special_tokens=r.add_special_tokens,
+ )["input_ids"]
+ max_length = max(max_length, len(input_ids))
+ all_input_ids.append(input_ids)
+ return all_input_ids
+
+ @classmethod
+ def from_tokenized(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ batch_tokenized_inputs,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "FlashCausalLMBatch":
+ cache_lengths = []
+ input_lengths = []
+ prompt_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+ all_input_ids = []
+ all_postfix_ids = []
+ requests_idx_mapping = {}
+ slots = []
+ cu_slots = [0]
+
+ next_token_chooser_parameters = []
+ stopping_criterias = []
+ top_n_tokens = []
+
+ num_blocks = 0
+ max_input_length = 0
+ max_current_length = 0
+ max_length = 0
+ max_blocks = 0
+
+ cu_blocks = [0]
+ block_tables = []
+ block_tables_ragged = []
+
+ # Parse batch
+ for i, (r, tokenized_input) in enumerate(
+ zip(pb.requests, batch_tokenized_inputs)
+ ):
+ ### XXX: This consumes so much memory on long requests
+ ### Deactivating it by default seems like the best course.
+ if not REQUEST_LOGPROBS:
+ r.prefill_logprobs = False
+ else:
+ assert False, "prefill_logprobs not supported yet"
+ # request id -> idx in list mapping
+ requests_idx_mapping[r.id] = i
+
+ prompt_length = len(tokenized_input)
+ prompt_lengths.append(prompt_length)
+
+ cache_length = r.cache_len
+
+ assert (
+ cache_length <= prompt_length
+ ), f"Prefix {cache_length} vs input {prompt_length}"
+ if cache_length == prompt_length:
+ assert False, "unreachable"
+
+            # `chunk_len` is an optional field in the protobuf.
+            # It is only set if the model supports chunking.
+            # Use all the remaining ids.
+ postfix_ids = tokenized_input[cache_length:]
+ input_length = len(postfix_ids)
+
+ input_lengths.append(input_length)
+
+ prefix_offsets.append(prompt_length - 5)
+ read_offsets.append(prompt_length)
+
+ all_postfix_ids.append(postfix_ids)
+ all_input_ids.append(tokenized_input)
+
+ next_token_chooser_parameters.append(r.parameters)
+
+ stopping_criteria = StoppingCriteria.from_pb(
+ r.stopping_parameters, tokenizer
+ )
+ max_new_tokens = stopping_criteria.max_new_tokens
+ stopping_criterias.append(stopping_criteria)
+ top_n_tokens.append(r.top_n_tokens)
+
+ # Paged attention
+            # Remove one as the first token does not have a past
+ speculative_length = get_speculate()
+ speculative_length = 0 if speculative_length is None else speculative_length
+
+ # Tokens that need to be mapped to blocks.
+ block_tokens = prompt_length + max_new_tokens - 1 + speculative_length
+
+ # blocks and slots can be empty (for example in warmup)
+ if not r.blocks:
+ needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
+ request_blocks = [
+ b for b in range(num_blocks, num_blocks + needed_blocks)
+ ]
+ request_slots = [
+ s
+ for b in request_blocks
+ for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
+ ]
+ else:
+ request_blocks = r.blocks
+ request_slots = r.slots
+
+ block_tables.append(request_blocks)
+ block_tables_ragged.extend(request_blocks)
+ cu_blocks.append(len(block_tables_ragged))
+
+ slots.extend(request_slots)
+ cu_slots.append(len(slots))
+
+ cache_lengths.append(cache_length)
+ num_blocks += len(request_blocks)
+
+ # Update
+ max_blocks = max(max_blocks, len(request_blocks))
+ max_input_length = max(max_input_length, input_length)
+ max_current_length = max(max_current_length, cache_length + input_length)
+ max_length = max(
+ max_length,
+ prompt_length + max_new_tokens + speculative_length,
+ )
+
+ next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+ next_token_chooser_parameters, dtype, device, tokenizer
+ )
+
+ # Padded all_input_ids_tensor
+ all_input_ids_tensor = np.zeros(
+ (len(all_input_ids), max_length), dtype=np.int64
+ )
+ for i, input_ids in enumerate(all_input_ids):
+ all_input_ids_tensor[i, : len(input_ids)] = input_ids
+
+ # put on cpu temporarily, move to hpu in prepare_for_prefill
+ all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)
+
+ top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
+
+ block_tables_ragged = torch.tensor(block_tables_ragged, dtype=torch.int32)
+ cu_blocks = torch.tensor(cu_blocks, dtype=torch.int64)
+ block_tables_tensor = torch.empty(
+ (len(block_tables), max_blocks),
+ dtype=torch.int32,
+ )
+
+ for i, request_blocks in enumerate(block_tables):
+ block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
+
+ prompt_lengths_tensor = torch.tensor(prompt_lengths, dtype=torch.int32)
+
+ slots = torch.tensor(slots, dtype=torch.int64)
+ cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
+
+ return cls(
+ batch_id=pb.id,
+ requests=pb.requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=all_postfix_ids,
+ block_tables=block_tables,
+ block_tables_tensor=block_tables_tensor,
+ cache_lengths=cache_lengths,
+ max_input_length=max_input_length,
+ max_current_length=max_current_length,
+ prefilling=True,
+ prefilling_mask=[True] * len(pb.requests),
+ prefill_logprob_tokens=[None] * len(pb.requests),
+ input_lengths=input_lengths,
+ prompt_lengths=prompt_lengths,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
+ all_input_ids=all_input_ids,
+ all_input_ids_tensor=all_input_ids_tensor,
+ next_token_chooser=next_token_chooser,
+ stopping_criterias=stopping_criterias,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ num_blocks=num_blocks,
+ max_blocks=max_blocks,
+ speculative_ids=None,
+ prompt_lengths_tensor=prompt_lengths_tensor,
+ # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
+ position_ids=None,
+ cu_seqlen_prefill=None,
+ prefill_cache_indices=None,
+ slot_indices=None,
+ slots=slots,
+ cu_slots=cu_slots,
+ prefill_head_indices=None,
+ prefill_next_token_indices=None,
+ prefill_cu_outlens=None,
+ cache_lengths_tensor=None,
+ input_lengths_tensor=None,
+ adapter_meta=None,
+ hpu_attn_meta=None,
+ next_token_logits=None,
+ speculative_logits=None,
+ valid_indices=None,
+ )
+
+ @classmethod
+ def from_pb(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "FlashCausalLMBatch":
+ assert len(pb.requests) > 0
+ batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
+ return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+
+ @tracer.start_as_current_span("filter")
+ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
+ if len(request_ids) == 0:
+ raise ValueError("Batch must have at least one request")
+ # We assume that if len(requests) == len(self) then the requests are the same
+ if len(request_ids) == len(self):
+ return self
+
+ device = self.block_tables_tensor.device
+
+ # New values after filtering
+ requests_idx_mapping = {}
+
+ # Used to index into tensors
+ indices = []
+
+ # slots to keep after filtering
+ slot_filtering_indices = torch.zeros(self.slots.shape[0], dtype=torch.bool)
+
+ # Create on CPU to only move to GPU once instead of at every copy
+ slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
+ max_input_length = 0
+ max_current_length = 0
+
+ requests = []
+ block_tables = []
+ all_input_ids = []
+ input_ids = []
+
+ prompt_lengths = []
+ input_lengths = []
+ cache_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+ cu_slots = [0]
+
+ prefilling_mask = []
+ prefill_logprob_tokens = []
+
+ stopping_criterias = []
+ adapter_set = set()
+
+ num_blocks = 0
+ max_blocks = 0
+ max_slots = 0
+ cumulative_slot_tokens = 0
+
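+        # Gather the per-request values to keep and mark the slot ranges that survive
+        # the filtering.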
+ for i, request_id in enumerate(request_ids):
+ idx = self.requests_idx_mapping[request_id]
+ indices.append(idx)
+ requests_idx_mapping[request_id] = i
+
+ requests.append(self.requests[idx])
+
+ # Prefilling
+ request_prefilling = self.prefilling_mask[idx]
+ prefilling_mask.append(request_prefilling)
+
+ # Get length
+ request_input_length = self.input_lengths[idx]
+ request_cache_length = self.cache_lengths[idx]
+ max_input_length = max(max_input_length, request_input_length)
+ max_current_length = max(
+ max_current_length, request_cache_length + request_input_length
+ )
+
+ all_input_ids.append(self.all_input_ids[idx])
+
+ prompt_lengths.append(self.prompt_lengths[idx])
+ input_lengths.append(request_input_length)
+ cache_lengths.append(request_cache_length)
+ prefix_offsets.append(self.prefix_offsets[idx])
+ read_offsets.append(self.read_offsets[idx])
+
+ stopping_criteria = self.stopping_criterias[idx]
+ stopping_criterias.append(stopping_criteria)
+
+ prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])
+
+ ADAPTER_TO_INDEX = get_adapter_to_index()
+ adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
+ adapter_set.add(adapter_index)
+
+ request_block_table = self.block_tables[idx]
+ num_blocks += len(request_block_table)
+ block_tables.append(request_block_table)
+
+ start_slot = self.cu_slots[idx]
+ end_slot = self.cu_slots[idx + 1]
+ slot_length = end_slot - start_slot
+
+ # Set slice
+ slot_filtering_indices[start_slot:end_slot] = True
+
+ cu_slots.append(cumulative_slot_tokens + slot_length)
+
+ # Input ids if the request was part of a prefilling batch
+ # If the batch was decoding we can index into the tensor directly later
+ if self.prefilling:
+ input_ids.append(self.input_ids[idx])
+ else:
+ # Copy to tensor (CPU)
+ slot_indices[i] = cumulative_slot_tokens + request_cache_length
+
+ cumulative_slot_tokens += slot_length
+ max_blocks = max(max_blocks, len(request_block_table))
+ max_slots = max(max_slots, slot_length)
+
+ block_tables_tensor = self.block_tables_tensor[indices]
+ prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
+
+ cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
+
+ slots = self.slots[slot_filtering_indices]
+
+ if self.prefilling:
+ # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
+ position_ids = None
+ slot_indices = None
+ cache_lengths_tensor = None
+ input_lengths_tensor = None
+ adapter_meta = None
+ else:
+ # Index into tensors
+ input_ids = self.input_ids[indices]
+ position_ids = self.position_ids[indices]
+ input_lengths_tensor = self.input_lengths_tensor[indices]
+ cache_lengths_tensor = self.cache_lengths_tensor[indices]
+
+ # Move to GPU now that we have the whole tensor
+ slot_indices = slot_indices.to(device)
+ if self.adapter_meta is not None:
+ adapter_indices = self.adapter_meta.adapter_indices[indices]
+ adapter_segments, adapter_segment_indices = find_segments(
+ adapter_indices
+ )
+ adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
+ adapter_meta = AdapterBatchMetadata(
+ adapter_indices=adapter_indices,
+ adapter_set=adapter_set,
+ adapter_segments=adapter_segments,
+ segment_indices=adapter_segment_indices,
+ )
+ else:
+ adapter_meta = None
+ htorch.core.mark_step()
+ return type(self)(
+ batch_id=self.batch_id,
+ requests=requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=input_ids,
+ position_ids=position_ids,
+ cu_seqlen_prefill=None,
+ prefill_cache_indices=None,
+ slot_indices=slot_indices,
+ block_tables=block_tables,
+ block_tables_tensor=block_tables_tensor,
+ slots=slots,
+ cu_slots=cu_slots,
+ max_input_length=max_input_length,
+ max_current_length=max_current_length,
+ prefilling=self.prefilling,
+ prefilling_mask=prefilling_mask,
+ prefill_head_indices=None,
+ prefill_next_token_indices=None,
+ prefill_cu_outlens=None,
+ prefill_logprob_tokens=prefill_logprob_tokens,
+ prompt_lengths=prompt_lengths,
+ prompt_lengths_tensor=prompt_lengths_tensor,
+ input_lengths=input_lengths,
+ input_lengths_tensor=input_lengths_tensor,
+ cache_lengths=cache_lengths,
+ cache_lengths_tensor=cache_lengths_tensor,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
+ all_input_ids=all_input_ids,
+ all_input_ids_tensor=self.all_input_ids_tensor,
+ next_token_chooser=self.next_token_chooser,
+ stopping_criterias=stopping_criterias,
+ top_n_tokens=self.top_n_tokens,
+ top_n_tokens_tensor=self.top_n_tokens_tensor,
+ num_blocks=num_blocks,
+ max_blocks=max_blocks,
+ speculative_ids=self.speculative_ids,
+ adapter_meta=adapter_meta,
+ hpu_attn_meta=None,
+ valid_indices=indices,
+ next_token_logits=self.next_token_logits,
+ speculative_logits=self.speculative_logits,
+ )
+
+ @classmethod
+ @tracer.start_as_current_span("concatenate")
+ def concatenate(
+ cls, batches: List["FlashCausalLMBatch"], padded_total_bs: int = 0
+ ) -> "FlashCausalLMBatch":
+ # Batch attributes
+ requests = []
+ requests_idx_mapping = {}
+
+ prefilling = False
+ num_blocks = 0
+ total_batch_size = 0
+ total_slots = 0
+ max_blocks = 0
+ max_length = 0
+ max_input_length = 0
+ max_current_length = 0
+ ADAPTER_TO_INDEX = get_adapter_to_index()
+ for b in batches:
+ total_batch_size += len(b)
+ max_blocks = max(max_blocks, b.max_blocks)
+ total_slots += len(b.slots)
+ num_blocks += b.num_blocks
+ speculative_length = (
+ b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
+ )
+ max_input_length = max(max_input_length, b.max_input_length)
+ max_current_length = max(max_current_length, b.max_current_length)
+ max_length = max(
+ max_length,
+ max(
+ prompt_length
+ + stopping_criteria.max_new_tokens
+ + speculative_length
+ for prompt_length, stopping_criteria in zip(
+ b.prompt_lengths, b.stopping_criterias
+ )
+ ),
+ )
+ prefilling = prefilling or b.prefilling
+
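+        # Pre-allocate the concatenated tensors; per-batch values are copied in below.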
+ slots = batches[0].slots.new_empty(total_slots)
+ cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64)
+ if prefilling:
+ input_ids = []
+ # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
+ position_ids = None
+ slot_indices = None
+ cache_lengths_tensor = None
+ input_lengths_tensor = None
+ adapter_meta = None
+ adapter_segment_builder = None
+ else:
+ if padded_total_bs == batches[0].input_ids.shape[0]:
+ input_ids = batches[0].input_ids
+ else:
+ input_ids = batches[0].input_ids.new_empty(total_batch_size)
+ if (
+ batches[0].position_ids is not None
+ and batches[0].position_ids.dim() == 2
+ ):
+ # Qwen2_vl case:
+ position_ids = batches[0].position_ids.new_empty(
+ (total_batch_size, batches[0].position_ids.shape[-1])
+ )
+ else:
+ position_ids = batches[0].position_ids.new_empty(total_batch_size)
+ slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
+ input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
+ total_batch_size
+ )
+ cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
+ total_batch_size
+ )
+ if ADAPTER_TO_INDEX:
+ total_indices_size = sum(
+ b.adapter_meta.adapter_indices.shape[0] for b in batches
+ )
+ adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
+ total_indices_size
+ )
+ adapter_segment_builder = SegmentConcatBuilder()
+ adapter_set = set()
+
+ prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
+ total_batch_size
+ )
+ block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
+ (total_batch_size, max_blocks)
+ )
+ all_input_ids_tensor = batches[0].all_input_ids_tensor
+ top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+ total_batch_size,
+ )
+
+ block_tables = []
+ cache_lengths = []
+ all_input_ids = []
+
+ prompt_lengths = []
+ input_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+
+ prefill_logprob_tokens = []
+
+ next_token_chooser_parameters = []
+ fsm_grammar_states = []
+ stopping_criterias = []
+ top_n_tokens = []
+ prefilling_mask = []
+
+ # Cumulative length
+ cumulative_batch_size = 0
+ cumulative_slots = 0
+ cumulative_adapter_indices_size = 0
+
+ for i, batch in enumerate(batches):
+ requests.extend(batch.requests)
+ valid_bsize = len(batch)
+
+ if i == 0:
+ requests_idx_mapping = batch.requests_idx_mapping
+ else:
+ # We need to offset the mapping for each batch by the cumulative batch size
+ for k, v in batch.requests_idx_mapping.items():
+ requests_idx_mapping[k] = v + cumulative_batch_size
+
+ start_index = cumulative_batch_size
+ end_index = cumulative_batch_size + valid_bsize
+
+ index = torch.tensor(list(range(start_index, end_index)), device="cpu")
+ top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
+ if i > 0:
+ all_input_ids_tensor.index_copy_(
+ 0,
+ index.to(batch.all_input_ids_tensor.device),
+ batch.all_input_ids_tensor[:valid_bsize, :],
+ )
+
+ block_tables_tensor[
+ start_index:end_index, : batch.block_tables_tensor.shape[1]
+ ] = batch.block_tables_tensor[:, :max_blocks]
+ prompt_lengths_tensor.index_copy_(0, index, batch.prompt_lengths_tensor)
+
+ slots_start_index = cumulative_slots
+ slots_end_index = cumulative_slots + len(batch.slots)
+ slot_index = torch.tensor(
+ list(range(slots_start_index, slots_end_index)),
+ device=batch.slots.device,
+ )
+
+ slots.index_copy_(0, slot_index, batch.slots)
+ cu_slots[start_index + 1 : end_index + 1] = (
+ batch.cu_slots[1:] + cumulative_slots
+ )
+
+ if not prefilling:
+ if padded_total_bs != batches[0].input_ids.shape[0] or i > 0:
+ input_ids.index_copy_(
+ 0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
+ )
+ position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
+ slot_indices.index_copy_(
+ 0, index, batch.slot_indices + cumulative_slots
+ )
+ input_lengths_tensor.index_copy_(
+ 0, index, batch.input_lengths_tensor[:valid_bsize]
+ )
+ cache_lengths_tensor.index_copy_(
+ 0, index, batch.cache_lengths_tensor[:valid_bsize]
+ )
+ if ADAPTER_TO_INDEX:
+ adapter_start_index = cumulative_adapter_indices_size
+ adapter_end_index = (
+ cumulative_adapter_indices_size
+ + batch.adapter_meta.adapter_indices.shape[0]
+ )
+ adapter_indices[adapter_start_index:adapter_end_index] = (
+ batch.adapter_meta.adapter_indices
+ )
+ cumulative_adapter_indices_size = adapter_end_index
+ adapter_set.update(batch.adapter_meta.adapter_set)
+ adapter_segment_builder.concat(
+ batch.adapter_meta.adapter_segments,
+ batch.adapter_meta.segment_indices,
+ )
+ else:
+ if isinstance(batch.input_ids, torch.Tensor):
+ batch.input_ids = batch.input_ids.view(-1, 1).tolist()
+ input_ids.extend(batch.input_ids)
+
+ prefilling_mask.extend(batch.prefilling_mask)
+ block_tables.extend(batch.block_tables)
+ cache_lengths.extend(batch.cache_lengths)
+ all_input_ids.extend(batch.all_input_ids)
+
+ prompt_lengths.extend(batch.prompt_lengths)
+ input_lengths.extend(batch.input_lengths)
+ prefix_offsets.extend(batch.prefix_offsets)
+ read_offsets.extend(batch.read_offsets)
+
+ prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)
+
+ next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
+ fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
+ stopping_criterias.extend(batch.stopping_criterias)
+
+ top_n_tokens.extend(batch.top_n_tokens)
+
+ # Update
+ cumulative_slots += len(batch.slots)
+ cumulative_batch_size += len(batch)
+
+ next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+ next_token_chooser_parameters,
+ dtype=batches[0].next_token_chooser.dtype,
+ device=batches[0].next_token_chooser.device,
+ tokenizer=batches[0].next_token_chooser.tokenizer,
+ fsm_grammar_states=fsm_grammar_states,
+ )
+
+ # We skip computing the speculative_ids when the batch size is too large, so
+ # we must check that all batches have them, otherwise they must be discarded
+ if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches):
+ speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
+ else:
+ speculative_ids = None
+
+ if ADAPTER_TO_INDEX and adapter_segment_builder is not None:
+ adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
+ adapter_meta = AdapterBatchMetadata(
+ adapter_indices=adapter_indices,
+ adapter_set=adapter_set,
+ adapter_segments=adapter_segments,
+ segment_indices=adapter_segment_indices,
+ )
+
+ return cls(
+ batch_id=batches[0].batch_id,
+ requests=requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=input_ids,
+ position_ids=position_ids,
+ cu_seqlen_prefill=None,
+ prefill_cache_indices=None,
+ slot_indices=slot_indices,
+ block_tables=block_tables,
+ block_tables_tensor=block_tables_tensor,
+ cache_lengths=cache_lengths,
+ cache_lengths_tensor=cache_lengths_tensor,
+ slots=slots,
+ cu_slots=cu_slots,
+ max_input_length=max_input_length,
+ max_current_length=max_current_length,
+ prefilling=prefilling,
+ prefilling_mask=prefilling_mask,
+ prefill_head_indices=None,
+ prefill_next_token_indices=None,
+ prefill_cu_outlens=None,
+ prefill_logprob_tokens=prefill_logprob_tokens,
+ prompt_lengths=prompt_lengths,
+ prompt_lengths_tensor=prompt_lengths_tensor,
+ input_lengths=input_lengths,
+ input_lengths_tensor=input_lengths_tensor,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
+ all_input_ids=all_input_ids,
+ all_input_ids_tensor=all_input_ids_tensor,
+ next_token_chooser=next_token_chooser,
+ stopping_criterias=stopping_criterias,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ num_blocks=num_blocks,
+ max_blocks=max_blocks,
+ speculative_ids=speculative_ids,
+ adapter_meta=adapter_meta if ADAPTER_TO_INDEX else None,
+ hpu_attn_meta=None,
+ next_token_logits=None,
+ speculative_logits=None,
+ valid_indices=None,
+ )
+
+ def prepare_for_decode(
+ self, dtype, use_contiguous_pa, bucketing_ctx, pad_token_id, sliding_window
+ ):
+ block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths]
+ block_tables = []
+ for i, bt in enumerate(self.block_tables):
+ block_tables.append(bt[0 : block_num[i]])
+ if bucketing_ctx is not None:
+ padded_bs = bucketing_ctx.get_padded_decode_batch_size(
+ self.input_ids.shape[0]
+ )
+ else:
+ padded_bs = self.input_ids.shape[0]
+ slots = self.slots[self.slot_indices]
+
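+        # Build the padded block metadata consumed by HPU paged attention for this decode step.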
+ block_list, block_groups, block_usage, _, block_bucket_size = (
+ generate_block_metadata(
+ dtype,
+ use_contiguous_pa,
+ slots,
+ block_tables,
+ bucketing_ctx,
+ )
+ )
+ meta = HPUPagedAttentionMetadata(
+ block_list=_async_h2d_tensor_copy(block_list),
+ block_groups=_async_h2d_tensor_copy(block_groups),
+ block_usage=_async_h2d_tensor_copy(block_usage),
+ block_mapping=None,
+ attn_bias=None,
+ )
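+        # For sliding-window attention, restrict each sequence's blocks and slots to the
+        # most recent window and build a second set of block metadata for it.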
+ if sliding_window is not None:
+ block_tables_in_window = []
+ for i, bt in enumerate(self.block_tables):
+ block_num_in_window = (
+ sliding_window + 2 * BLOCK_SIZE - 2 - slots[i] % BLOCK_SIZE
+ ) // BLOCK_SIZE
+ block_tables_in_window.append(
+ bt[max(0, block_num[i] - block_num_in_window) : block_num[i]]
+ )
+ slots_in_window = []
+ for i, indice in enumerate(self.slot_indices):
+ start_idx = indice - self.cache_lengths[i]
+ mask = (
+ indice
+ - torch.arange(
+ start_idx,
+ indice + 1,
+ device=self.slots.device,
+ )
+ ) < sliding_window
+ slots_in_window.append(self.slots[start_idx : indice + 1][mask])
+ slots_in_window = torch.cat(slots_in_window, dim=0)
+ (
+ block_list_in_window,
+ block_groups_in_window,
+ block_usage_in_window,
+ slots_in_window_mask,
+ _,
+ ) = generate_block_metadata(
+ dtype,
+ use_contiguous_pa,
+ slots,
+ block_tables_in_window,
+ bucketing_ctx,
+ slots_in_window,
+ block_bucket_size,
+ )
+ meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)
+ meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)
+ meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)
+ meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)
+
+ self.hpu_attn_meta = trim_attn_metadata(meta)
+ self.input_ids = F.pad(
+ self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=pad_token_id
+ )
+
+ if self.position_ids.dim() == 2:
+ # Qwen VL case
+ self.position_ids = F.pad(
+ self.position_ids,
+ (0, 0, 0, padded_bs - self.position_ids.shape[0]),
+ value=1,
+ )
+ else:
+ self.position_ids = F.pad(
+ self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1
+ )
+ self.input_lengths_tensor = F.pad(
+ self.input_lengths_tensor,
+ (0, padded_bs - self.input_lengths_tensor.shape[0]),
+ value=0,
+ )
+ self.cache_lengths_tensor = F.pad(
+ self.cache_lengths_tensor,
+ (0, padded_bs - self.cache_lengths_tensor.shape[0]),
+ value=0,
+ )
+ if len(self.next_token_chooser.do_sample) != padded_bs:
+ next_token_chooser_parameters = []
+ next_token_chooser_parameters.extend([r.parameters for r in self.requests])
+ pad_next_token_chooser_parameters(next_token_chooser_parameters, padded_bs)
+ # update past grammar states
+ fsm_grammar_states = [0] * padded_bs
+
+ for i, req in enumerate(self.requests):
+ fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
+
+ self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+ next_token_chooser_parameters,
+ self.next_token_chooser.dtype,
+ self.next_token_chooser.device,
+ self.next_token_chooser.tokenizer,
+ fsm_grammar_states,
+ )
+
+ def prepare_for_prefill(
+ self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
+ ):
+ # Prepare values if we need to continue prefilling
+        # Speculation must be ignored while we prefill, even with chunking;
+        # it simplifies everything.
+ assert self.speculative_ids is None
+
+ # device = self.block_tables_tensor.device
+
+        # HPU does not support varlen for prefill, so we use SDPA instead and need to
+        # pad input_ids and position_ids. Padding is applied on the left so that it
+        # works with sliding-window attention. prefill_cache_indices marks the valid
+        # KV slots, and prefill_next_token_indices is updated to point at the correct
+        # logit positions.
+ input_ids_padded_length = []
+        # Extra padding is needed to match the warmed-up sequence length
+ extra_pad = max_padded_input_len - self.max_input_length
+ extra_pad_bs = max_padded_bs - len(self)
+ device = "hpu"
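+        # input_ids may be a list of lists (several requests), a single list (one
+        # request), or already a tensor; in every case each request is left-padded
+        # to max_padded_input_len.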
+ if isinstance(self.input_ids, list) and len(self) > 1:
+ input_ids_padded_length = []
+ input_ids = []
+ for input_id in self.input_ids:
+ padded = self.max_input_length - len(input_id) + extra_pad
+ if padded > 0:
+ input_id = [pad_token_id] * padded + input_id
+ input_ids.append(input_id)
+ input_ids_padded_length.append(padded)
+ input_ids = np.concatenate(input_ids, dtype=np.int64)
+ self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+ elif isinstance(self.input_ids, list):
+ input_ids = self.input_ids[0]
+ input_ids_padded_length.append(extra_pad)
+ input_ids = [pad_token_id] * extra_pad + input_ids
+ self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+ else:
+ input_ids = torch.full(
+ (max_padded_input_len * len(self),),
+ pad_token_id,
+ dtype=torch.int64,
+ device=self.input_ids.device,
+ )
+ src_pos = 0
+ for i in range(len(self)):
+ end_pos = (i + 1) * max_padded_input_len
+ start_pos = end_pos - self.input_lengths[i]
+ input_ids[start_pos:end_pos] = self.input_ids[
+ src_pos : src_pos + self.input_lengths[i]
+ ]
+ input_ids_padded_length.append(
+ max_padded_input_len - self.input_lengths[i]
+ )
+ src_pos += self.input_lengths[i]
+ self.input_ids = input_ids
+
+ self.input_ids = F.pad(
+ self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=pad_token_id
+ )
+
+ self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32)
+
+ self.input_lengths_tensor = F.pad(
+ self.input_lengths_tensor, (0, extra_pad_bs), value=0
+ )
+
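+        # Cumulative input lengths over the padded batch, used below to derive the
+        # prefill head / next-token indices.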
+ cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(max_padded_bs + 1)
+ torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)
+ self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)
+ self.cache_lengths_tensor = torch.tensor(self.cache_lengths, dtype=torch.int32)
+ self.cache_lengths_tensor = F.pad(
+ self.cache_lengths_tensor, (0, extra_pad_bs), value=0
+ )
+
+ position_ids = []
+ slot_indices = []
+ prefill_cache_indices = []
+ all_prefill_logprobs = True
+ no_prefill_logprobs = True
+ prefill_cu_outlens = [0]
+
+ # Cumulative length
+ cumulative_length = 0
+ cumulative_slot_tokens = 0
+ prefill_out_cumulative_length = 0
+
+ adapter_indices_list = []
+ adapter_set = set()
+
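+        # Per-request loop: build left-padded position ids, slot indices and the
+        # prefill cache indices that skip padded positions.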
+ for i, (
+ r,
+ cache_length,
+ input_length,
+ prompt_length,
+ request_prefilling,
+ blocks,
+ ) in enumerate(
+ zip(
+ self.requests,
+ self.cache_lengths,
+ self.input_lengths,
+ self.prompt_lengths,
+ self.prefilling_mask,
+ self.block_tables,
+ )
+ ):
+ next_chunk_length = input_length
+
+ # Position ids
+ request_position_ids = torch.arange(
+ cache_length, cache_length + input_length, dtype=torch.int32
+ )
+ request_position_ids = F.pad(
+ request_position_ids, (input_ids_padded_length[i], 0), value=1
+ )
+ position_ids.append(request_position_ids)
+
+ if not r.slots:
+ request_slots = [
+ s
+ for b in blocks
+ for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
+ ]
+ else:
+ request_slots = r.slots
+
+ request_slot_indices = torch.arange(
+ cache_length + cumulative_slot_tokens,
+ cache_length + cumulative_slot_tokens + input_length,
+ dtype=torch.int64,
+ )
+
+ slot_indices.append(request_slot_indices)
+
+ # Update
+ cumulative_slot_tokens += len(request_slots)
+
+            # Create a tensor used to slice into the KV tensor during prefill.
+            # HPU needs request_prefill_cache_indices to skip the padded positions in the KV cache.
+ sliding_window = input_length
+ cumulative_length += input_ids_padded_length[i]
+ if sliding_window is not None:
+ request_prefill_cache_indices = torch.arange(
+ cumulative_length + max(0, input_length - sliding_window),
+ cumulative_length + input_length,
+ dtype=torch.int64,
+ )
+
+ # Prefill logprobs is ignored if the request is done prefilling
+ prefill_logprobs = r.prefill_logprobs and request_prefilling
+
+ all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs
+ no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
+
+ if prefill_logprobs:
+ prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
+ prefill_out_cumulative_length += input_length
+ else:
+ prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
+ prefill_out_cumulative_length += 1
+
+ prefill_cache_indices.append(request_prefill_cache_indices)
+
+ ADAPTER_TO_INDEX = get_adapter_to_index()
+ if ADAPTER_TO_INDEX:
+ adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
+ adapter_indices_list.append(
+ torch.full((next_chunk_length,), adapter_index)
+ )
+ adapter_set.add(adapter_index)
+
+ # Update
+ cumulative_length += next_chunk_length
+
+ if not all_prefill_logprobs and not no_prefill_logprobs:
+ prefill_head_indices = []
+ prefill_next_token_indices = []
+
+ # Cumulative length
+ cumulative_length = 0
+ prefill_out_cumulative_length = 0
+
+ for i, (
+ r,
+ input_length,
+ request_prefilling,
+ ) in enumerate(
+ zip(
+ self.requests,
+ self.input_lengths,
+ self.prefilling_mask,
+ )
+ ):
+ # Prefill logprobs is ignored if the request is done prefilling
+ prefill_logprobs = r.prefill_logprobs and request_prefilling
+
+ if prefill_logprobs:
+ prefill_head_indices.append(
+ torch.arange(
+ cumulative_length,
+ cumulative_length + input_length,
+ dtype=torch.int32,
+ )
+ )
+ prefill_next_token_indices.append(
+ prefill_out_cumulative_length + input_length - 1
+ )
+ prefill_out_cumulative_length += input_length
+ else:
+ prefill_head_indices.append(
+ torch.tensor(
+ [cumulative_length + input_length - 1],
+ dtype=torch.int32,
+ )
+ )
+ prefill_next_token_indices.append(prefill_out_cumulative_length)
+ prefill_out_cumulative_length += 1
+
+ # Update
+ cumulative_length += input_length
+
+ if len(self) > 1:
+ if position_ids:
+ position_ids = torch.cat(position_ids)
+ if slot_indices:
+ slot_indices = torch.cat(slot_indices)
+ prefill_cache_indices = torch.cat(prefill_cache_indices)
+ else:
+ if position_ids:
+ position_ids = position_ids[0]
+ if slot_indices:
+ slot_indices = slot_indices[0]
+ prefill_cache_indices = prefill_cache_indices[0]
+
+ self.position_ids = position_ids
+ self.position_ids = F.pad(
+ self.position_ids, (0, extra_pad_bs * max_padded_input_len), value=1
+ )
+ self.slot_indices = slot_indices
+
+ self.prefill_cu_outlens = prefill_cu_outlens
+ self.prefill_cache_indices = torch.zeros_like(
+ self.input_ids, dtype=torch.bool, device="cpu"
+ )
+ self.prefill_cache_indices[prefill_cache_indices] = True
+
+ if all_prefill_logprobs:
+ prefill_head_indices = None
+ prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1
+ elif no_prefill_logprobs:
+ prefill_head_indices = self.cu_seqlen_prefill[1:] - 1
+ prefill_next_token_indices = None
+ else:
+ prefill_head_indices = torch.cat(prefill_head_indices)
+ prefill_next_token_indices = torch.tensor(
+ prefill_next_token_indices, dtype=torch.int64
+ )
+
+ self.prefill_head_indices = prefill_head_indices
+ self.prefill_next_token_indices = prefill_next_token_indices
+ input_ids_padded_length_tensor = torch.cumsum(
+ torch.tensor(input_ids_padded_length, dtype=torch.int32),
+ dim=-1,
+ ).to(torch.int32)
+ input_ids_padded_length_tensor = F.pad(
+ input_ids_padded_length_tensor, (0, extra_pad_bs), value=0
+ )
+ if self.prefill_head_indices is not None:
+ self.prefill_head_indices = (
+ self.prefill_head_indices + input_ids_padded_length_tensor
+ )
+
+ if self.prefill_next_token_indices is not None:
+ self.prefill_next_token_indices = (
+ self.prefill_next_token_indices + input_ids_padded_length_tensor
+ )
+ all_input_ids_tensor = torch.full(
+ (max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),
+ pad_token_id,
+ dtype=torch.int64,
+ device="hpu",
+ )
+ for i in range(len(self)):
+ all_input_ids_tensor[i, : self.all_input_ids_tensor.shape[-1]] = (
+ self.all_input_ids_tensor[i]
+ )
+ self.all_input_ids_tensor = all_input_ids_tensor
+ if len(self.next_token_chooser.do_sample) != max_padded_bs:
+ next_token_chooser_parameters = []
+ next_token_chooser_parameters.extend([r.parameters for r in self.requests])
+ pad_next_token_chooser_parameters(
+ next_token_chooser_parameters, max_padded_bs
+ )
+ # update past grammar states
+ fsm_grammar_states = [0] * max_padded_bs
+
+ for i, req in enumerate(self.requests):
+ fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
+
+ self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+ next_token_chooser_parameters,
+ self.next_token_chooser.dtype,
+ self.next_token_chooser.device,
+ self.next_token_chooser.tokenizer,
+ fsm_grammar_states,
+ )
+
+ if ADAPTER_TO_INDEX:
+ if adapter_set:
+ adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
+ adapter_segments, adapter_segment_indices = find_segments(
+ adapter_indices
+ )
+ else:
+ adapter_indices = torch.zeros_like(self.input_ids)
+ adapter_segments = [0, len(adapter_indices)]
+ adapter_segment_indices = [len(adapter_indices) - 1]
+
+ adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
+ self.adapter_meta = AdapterBatchMetadata(
+ adapter_indices=adapter_indices,
+ adapter_set=adapter_set,
+ adapter_segments=adapter_segments,
+ segment_indices=adapter_segment_indices,
+ )
+
+ def __len__(self):
+ return len(self.requests)
+
+
+ADAPTER_LAYERS = [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ "o_proj",
+ "gate_proj",
+ "up_proj",
+ "down_proj",
+]
+ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
+
+
+class FlashCausalLM(Model):
+ def __init__(
+ self,
+ model_id: str,
+ model_class,
+ revision: Optional[str] = None,
+ quantize: Optional[str] = None,
+ speculator: Optional[str] = None,
+ dtype: Optional[torch.dtype] = None,
+ trust_remote_code: bool = False,
+ lora_adapter_ids: Optional[list] = [],
+ tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
+ config_class: PreTrainedTokenizerBase = AutoConfig,
+ default_dtype=torch.float16,
+ aliases=None,
+ # Used for Santacoder override of config
+ num_kv_heads: Optional[int] = None,
+ # Deepseek V2 uses different QK and V dims.
+ head_size: Optional[int] = None,
+ skip_special_tokens: bool = True,
+ kv_cache_dtype: Optional[torch.dtype] = None,
+ support_chunking: bool = True,
+ ):
+ self.quantize = quantize
+ self.process_group, rank, world_size = initialize_torch_distributed()
+ if world_size > 1:
+ self.process_group_cpu = torch.distributed.new_group(backend="gloo")
+
+ device = torch.device("hpu")
+ dtype = torch.bfloat16 if dtype is None else dtype
+
+ tokenizer = tokenizer_class.from_pretrained(
+ model_id,
+ revision=revision,
+ padding_side="left",
+ truncation_side="left",
+ trust_remote_code=trust_remote_code,
+ )
+ try:
+ generation_config = GenerationConfig.from_pretrained(
+ model_id, revision=revision, trust_remote_code=trust_remote_code
+ )
+ if isinstance(generation_config.eos_token_id, (list, set)):
+ # TODO Huge hack
+ tokenizer._eos_token_ids = set(generation_config.eos_token_id)
+ except Exception:
+ pass
+
+ config = config_class.from_pretrained(
+ model_id, revision=revision, trust_remote_code=trust_remote_code
+ )
+ config.quantize = quantize
+ config.speculator = speculator
+
+ torch.distributed.barrier(group=self.process_group)
+
+ weights_loader = get_loader(quantize, model_id, revision)
+ filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+ weights = Weights(
+ filenames,
+ device,
+ dtype,
+ process_group=self.process_group,
+ aliases=aliases,
+ weights_loader=weights_loader,
+ )
+
+ prefix = None
+ model = model_class(prefix, config, weights)
+ torch.distributed.barrier(group=self.process_group)
+
+ # VLM models define the config we care about in their text_config
+ text_config = getattr(config, "text_config", None)
+ if text_config is not None:
+ config = text_config
+
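+        # Ensure config.sliding_window exists and disable it when use_sliding_window
+        # is explicitly set to False.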
+ if getattr(config, "sliding_window", None) is None:
+ config.sliding_window = None
+ if getattr(config, "use_sliding_window", True) is False:
+ config.sliding_window = None
+
+ self.num_layers = config.num_hidden_layers
+ self.num_heads = config.num_attention_heads // self.process_group.size()
+ self.config = config
+ # Validation is done in the model itself
+ if num_kv_heads is None:
+ num_kv_heads = getattr(config, "num_key_value_heads", None)
+ # GPT-2 workaround
+ if num_kv_heads is None:
+ num_kv_heads = getattr(config, "n_head", None)
+ if num_kv_heads is None:
+ raise ValueError("Cannot get the number of key/value heads")
+ self.num_kv_heads = (
+ num_kv_heads // self.process_group.size()
+ if num_kv_heads // self.process_group.size() > 0
+ else num_kv_heads
+ )
+ assert self.num_kv_heads > 0
+
+ if head_size is None:
+            # Some models use GQA with different sizes for o_proj and q_proj;
+            # config.head_dim accounts for that.
+ if getattr(config, "head_dim", None) is not None:
+ self.head_size = config.head_dim
+ else:
+ self.head_size = config.hidden_size // config.num_attention_heads
+ else:
+ self.head_size = head_size
+
+ self.cuda_graphs = {}
+ self.kv_cache = []
+ self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
+ self.bucketing_ctx = None
+ self.max_total_tokens = None
+ self.max_input_tokens = None
+ htorch.core.hpu_set_env()
+ if htorch.utils.internal.is_lazy():
+ htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
+ environment.set_model_config(self.config)
+ self.use_contiguous_pa = (
+ os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
+ )
+ self.limit_hpu_graph = (
+ os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true"
+ )
+ self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
+ self.max_seq_len_to_capture = 8192
+ if tokenizer.pad_token_id is None:
+ if config.pad_token_id is not None:
+ tokenizer.pad_token_id = config.pad_token_id
+ elif config.eos_token_id is not None:
+ tokenizer.pad_token_id = (
+ config.eos_token_id[0]
+ if isinstance(config.eos_token_id, list)
+ else config.eos_token_id
+ )
+ elif tokenizer.eos_token_id is not None:
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+ else:
+ tokenizer.pad_token_id = 0
+ super().__init__(
+ model_id=model_id,
+ model=model,
+ tokenizer=tokenizer,
+ requires_padding=False,
+ dtype=dtype,
+ device=device,
+ rank=rank,
+ world_size=world_size,
+ sliding_window=config.sliding_window,
+ support_chunking=support_chunking,
+ )
+
+ @property
+ def batch_type(self) -> Type[FlashCausalLMBatch]:
+ return FlashCausalLMBatch
+
+ def max_past(self) -> int:
+ return getattr(self.model, "max_past", None)
+
+ def init_kv_cache(
+ self,
+ num_blocks: int,
+ num_layers: int,
+ num_heads: int,
+ head_size: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ ):
+ self.kv_cache = []
+ empty_cache()
+ if self.config.model_type in ["deepseek_v3", "deepseek_v2"]:
+ self.kv_cache = [
+ KVCompressCache(
+ num_blocks=num_blocks,
+ head_size=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
+ dtype=dtype,
+ device=device,
+ )
+ for _ in range(num_layers)
+ ]
+ else:
+ self.kv_cache = [
+ KVCache(
+ num_blocks=num_blocks,
+ num_heads=num_heads,
+ head_size=head_size,
+ dtype=dtype,
+ device=device,
+ )
+ for _ in range(num_layers)
+ ]
+
+ def warmup(
+ self,
+ batch: FlashCausalLMBatch,
+ max_input_tokens: Optional[int],
+ max_total_tokens: Optional[int],
+ ):
+ if os.environ.get("MAX_BATCH_SIZE") is None:
+ raise RuntimeError(
+ "MAX_BATCH_SIZE is not set, it should be set in the launcher "
+ "using `--max-batch-size xxx`"
+ )
+ # The warmup batch is the biggest batch we could ever receive
+ self.kv_cache = []
+ empty_cache()
+ self.graphed_buckets = set()
+ # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
+ # Calculate the number of blocks that can be allocated with the free memory
+ dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
+ if self.config.model_type in ["deepseek_v3", "deepseek_v2"]:
+ cache_block_size = BLOCK_SIZE * (
+ self.config.kv_lora_rank + self.config.qk_rope_head_dim
+ )
+ else:
+ cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
+ cache_block_size = cache_block_size * 2
+ total_cache_size = self.num_layers * cache_block_size * dtype_size
+ free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
+ self.mem_reserved = int(free_memory * (1 - MEMORY_FRACTION))
+ graph_reserved_mem = (
+ float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
+ if htorch.utils.internal.is_lazy()
+ else 0
+ )
+ mem_used_from_graph = int(
+ (free_memory - self.mem_reserved) * graph_reserved_mem
+ )
+ log_master(
+ logger.info,
+ f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
+ )
+ if max_total_tokens is None:
+ max_total_tokens = sum(batch.input_lengths)
+
+ if max_input_tokens is None:
+ max_input_tokens = max_total_tokens - 1
+
+ self.max_total_tokens = max_total_tokens
+ self.max_input_tokens = max_input_tokens
+ try:
+ self.init_kv_cache(
+ batch.num_blocks,
+ self.num_layers,
+ self.num_kv_heads,
+ self.head_size,
+ self.kv_cache_dtype,
+ self.device,
+ )
+
+ batch_num_blocks = batch.num_blocks
+
+ num_tokens = batch.to_pb().current_tokens
+ synchronize(self.device)
+ _, _batch, _ = self.generate_token([batch])
+        except Exception:
+            # Recompute the token count here so the error message also works when the
+            # failure happened before `num_tokens` was assigned above.
+            num_tokens = batch.to_pb().current_tokens
+            raise RuntimeError(
+                f"Not enough memory to handle {num_tokens} prefill tokens. "
+                f"You need to decrease `--max-batch-prefill-tokens`"
+            )
+
+ synchronize(self.device)
+ free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
+
+ kv_memory = free_memory - self.mem_reserved - mem_used_from_graph
+ num_blocks = (
+ # Leave 5% for some wiggle room
+ int(kv_memory // total_cache_size)
+ # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+ + batch_num_blocks
+ )
+
+ log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
+
+ self.kv_cache = []
+ empty_cache()
+ self.init_kv_cache(
+ num_blocks,
+ self.num_layers,
+ self.num_kv_heads,
+ self.head_size,
+ self.kv_cache_dtype,
+ self.device,
+ )
+ self.max_batch_prefill_tokens = get_max_prefill_tokens()
+ max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
+ HPUBucketingContext = get_bucketing_context()
+        # Warm up one extra step since blocks are allocated starting from 1.
+        block_step = int(os.getenv("VLLM_DECODE_BLOCK_BUCKET_STEP", BLOCK_SIZE))
+ max_total_tokens_aligned = math.ceil(
+ max_total_tokens / BLOCK_SIZE
+ ) * BLOCK_SIZE + math.ceil(block_step * BLOCK_SIZE / max_num_seqs)
+ model_max_length = self.tokenizer.model_max_length
+ max_position_embeddings = getattr(
+ self.config, "max_position_embeddings", model_max_length
+ )
+
+ self.bucketing_ctx = HPUBucketingContext(
+ max_num_seqs,
+ max_num_seqs, # self.max_num_prefill_seqs, #TODO
+ BLOCK_SIZE,
+ max_num_seqs * max_total_tokens_aligned,
+ False,
+ min(model_max_length, max_position_embeddings),
+ max_input_tokens,
+ max_total_tokens_aligned,
+ )
+ max_blocks = max(
+ BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
+ )
+ self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
+ synchronize(self.device)
+ if self.skip_warmup:
+ self.bucketing_ctx.generate_prompt_buckets()
+ self.bucketing_ctx.generate_decode_buckets(
+ self.bucketing_ctx.num_hpu_blocks
+ )
+ log_master(
+                logger.info, "skipping HPU graph warmup, not recommended, may cause OOM"
+ )
+ del _batch, batch
+ return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
+ self.warmup_hpu_graph(batch)
+ del _batch, batch
+
+ return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
+
+ def log_warmup(self, prefilling, i, max_i, batch_size, seq_len):
+ free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
+ phase = "Prompt" if prefilling else "Decode"
+ dim = "seq_len" if prefilling else "num_blocks"
+ graphed_bucket = (batch_size, seq_len, prefilling)
+ bypass = graphed_bucket not in self.graphed_buckets
+ msg = (
+ f"[Warmup][{phase}][{i+1}/{max_i}] "
+ f"batch_size:{batch_size} "
+ f"{dim}:{seq_len} "
+ f"bypass:{bypass} "
+ f"free_mem:{free_mem}"
+ ", this may take a while..."
+ )
+ log_master(logger.info, msg)
+
+ def use_graphs(self, prefill, seq_len, batch_size):
+ if self.limit_hpu_graph and prefill:
+ return False
+
+ if self.skip_warmup:
+ return True
+
+ return (batch_size, seq_len, prefill) in self.graphed_buckets
+
+ def align_workers(self, value, op):
+ if self.world_size <= 1:
+ return value
+ value_t = torch.tensor(value, device="cpu")
+ torch.distributed.all_reduce(value_t, op=op, group=self.process_group_cpu)
+ return value_t.item()
+
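+ # Capture HPU graphs for the prefill and decode bucket shapes that fit into the available memory budget.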
+ def warmup_hpu_graph(self, batch):
+ prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
+ free_mem = HabanaMemoryProfiler.current_free_device_memory()
+ graph_free_mem = free_mem - self.mem_reserved
+ graph_free_mem = self.align_workers(
+ graph_free_mem, torch.distributed.ReduceOp.MIN
+ )
+ prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
+ decode_available_memory = graph_free_mem - prompt_available_memory
+ msg = (
+ f"Using {format_bytes(graph_free_mem)}"
+ f"/{format_bytes(free_mem)} "
+ "of free device memory for HPUGraphs, "
+ f"{format_bytes(prompt_available_memory)} for prompt and "
+ f"{format_bytes(decode_available_memory)} for decode "
+ f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
+ )
+ log_master(logger.info, msg)
+ start_time = time.time()
+ warmup_shape_count = 0
+ warmup_times = 3
+ self.bucketing_ctx.generate_prompt_buckets()
+
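+ # Warm up prefill buckets in ascending order of total token count (batch_size * seq_len).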
+ def ordering_function_min_tokens(b):
+ return (b[0] * b[1], b[1], b[0])
+
+ buckets = list(
+ sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
+ )
+ total_batch_seq = 0.001
+ total_mem = 0
+ available_mem = prompt_available_memory
+ msg = (
+ f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
+ f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
+ )
+ log_master(logger.info, msg)
+ for i, (batch_size, seq_len) in enumerate(buckets):
+ if batch_size * seq_len > self.max_batch_prefill_tokens:
+ continue
+ # Graph memory usage is proportional to seq dimension in a batch
+ batch_seq = batch_size * seq_len
+ mem_estimate = batch_seq / total_batch_seq * total_mem
+ graphed_bucket = (batch_size, seq_len, True)
+ if not (
+ mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
+ ):
+ if graphed_bucket not in self.graphed_buckets:
+ self.graphed_buckets.add(graphed_bucket)
+ warmup_shape_count += 1
+ self.log_warmup(True, i, len(buckets), batch_size, seq_len)
+ with HabanaMemoryProfiler() as mem_prof:
+ for index in range(warmup_times):
+ self.warmup_prefill(seq_len, batch_size, batch)
+ synchronize(self.device)
+ used_mem = self.align_workers(
+ mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
+ )
+ if graphed_bucket in self.graphed_buckets:
+ available_mem -= used_mem
+ total_mem += used_mem
+ total_batch_seq += batch_seq
+
+ log_master(logger.info, "Prefill warmup successful.\n")
+
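+ # Warm up decode buckets starting from the largest batch sizes.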
+ def ordering_function_max_bs(b):
+ return (-b[0], b[1])
+
+ self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
+ buckets = list(
+ sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
+ )
+ free_mem = HabanaMemoryProfiler.current_free_device_memory()
+ total_batch_seq = 0.001
+ total_mem = 0
+ available_mem = free_mem - self.mem_reserved
+ log_master(
+ logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
+ )
+ for i, (batch_size, block_num) in enumerate(buckets):
+ if batch_size > block_num:
+ continue
+ # Graph memory usage is proportional to seq dimension in a batch
+ batch_seq = batch_size
+ mem_estimate = batch_seq / total_batch_seq * total_mem
+ graphed_bucket = (batch_size, block_num, False)
+ if not mem_estimate >= available_mem:
+ if graphed_bucket not in self.graphed_buckets:
+ self.graphed_buckets.add(graphed_bucket)
+ warmup_shape_count += 1
+ self.log_warmup(False, i, len(buckets), batch_size, block_num)
+ with HabanaMemoryProfiler() as mem_prof:
+ for index in range(warmup_times):
+ self.warmup_decode(batch_size, block_num, batch)
+ synchronize(self.device)
+ used_mem = self.align_workers(
+ mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
+ )
+ if graphed_bucket in self.graphed_buckets:
+ available_mem -= used_mem
+ total_mem += used_mem
+ total_batch_seq += batch_seq
+
+ log_master(logger.info, "Decode warmup successful.\n")
+
+ log_master(
+ logger.info,
+ f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
+ )
+
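+ # Run a synthetic prefill forward pass with dummy tensors of the given (batch_size, prompt_len) shape.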
+ def warmup_prefill(
+ self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
+ ):
+ input_ids = torch.zeros(prompt_len, dtype=batch.input_ids.dtype).repeat(
+ batch_size
+ )
+ position_ids = torch.arange(prompt_len, dtype=batch.position_ids.dtype).repeat(
+ batch_size
+ )
+ max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
+ block_tables = torch.arange(max_bt, dtype=torch.int32).reshape(batch_size, -1)
+ slot_acc = []
+ for i in range(batch_size):
+ slots = []
+ for b in block_tables[i]:
+ slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
+ slot_acc.extend(slots[:prompt_len])
+ slots = torch.tensor(slot_acc, dtype=batch.slots.dtype)
+
+ input_lengths = torch.ones(batch_size, dtype=torch.int32) * prompt_len
+ cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)
+ torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+
+ seqlen = Seqlen(
+ input_lengths=_async_h2d_tensor_copy(input_lengths),
+ )
+ lm_head_indices = input_lengths - 1
+ kwargs = {}
+ if htorch.utils.internal.is_lazy():
+ kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+ True, prompt_len, batch_size
+ )
+ if self.sliding_window is not None:
+ attn_mask = seqlen.make_sliding_window_bias(
+ input_lengths.tolist(),
+ self.sliding_window,
+ self.dtype,
+ prompt_len,
+ batch_size,
+ )
+ seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)
+
+ # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+ self.model.forward(
+ input_ids=_async_h2d_tensor_copy(input_ids),
+ position_ids=_async_h2d_tensor_copy(position_ids),
+ cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
+ kv_cache=self.kv_cache,
+ slots=_async_h2d_tensor_copy(slots),
+ seqlen=trim_seqlen_metadata(seqlen),
+ lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
+ adapter_data=None,
+ hpu_attention_meta=None,
+ **kwargs,
+ )
+
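+ # Run a synthetic decode forward pass with dummy tensors covering block_num KV-cache blocks.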
+ def warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBatch):
+ input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)
+ position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
+ blocks = [block_num // batch_size for _ in range(batch_size)]
+ blocks[0] += block_num % batch_size
+ block_tables = []
+ slots = []
+ start_idx = 0
+ slot_indices = []
+
+ # Use the last slot of each sequence's final block so the warmup touches all block_num blocks
+ for i in range(batch_size):
+ block_array = list(range(start_idx, start_idx + blocks[i]))
+ slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
+ slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1)
+ block_tables.append(block_array)
+ start_idx += blocks[i]
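+ # Decode processes exactly one token per sequence.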
+ input_lengths = torch.ones(batch_size, dtype=torch.int32)
+ cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)
+ torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+
+ seqlen = Seqlen(
+ input_lengths=_async_h2d_tensor_copy(input_lengths),
+ )
+ block_list, block_groups, block_usage, _, block_bucket_size = (
+ generate_block_metadata(
+ self.dtype,
+ self.use_contiguous_pa,
+ slots,
+ block_tables,
+ self.bucketing_ctx,
+ )
+ )
+ meta = HPUPagedAttentionMetadata(
+ block_list=_async_h2d_tensor_copy(block_list),
+ block_groups=_async_h2d_tensor_copy(block_groups),
+ block_usage=_async_h2d_tensor_copy(block_usage),
+ block_mapping=None,
+ attn_bias=None,
+ )
+ if self.sliding_window is not None:
+ block_tables_in_window = []
+ for i, bt in enumerate(block_tables):
+ block_num_in_window = (
+ self.sliding_window + BLOCK_SIZE - 1
+ ) // BLOCK_SIZE
+ block_tables_in_window.append(
+ bt[max(0, blocks[i] - block_num_in_window) : blocks[i]]
+ )
+ slots_in_window = []
+ start_idx = 0
+ for i, indice in enumerate(slot_indices):
+ mask = (
+ indice - torch.arange(start_idx, indice + 1)
+ ) < self.sliding_window
+ slots_in_window.append(torch.arange(start_idx, indice + 1)[mask])
+ start_idx += blocks[i] * BLOCK_SIZE
+ slots_in_window = torch.cat(slots_in_window, dim=0)
+ (
+ block_list_in_window,
+ block_groups_in_window,
+ block_usage_in_window,
+ slots_in_window_mask,
+ _,
+ ) = generate_block_metadata(
+ self.dtype,
+ self.use_contiguous_pa,
+ slots,
+ block_tables_in_window,
+ self.bucketing_ctx,
+ slots_in_window,
+ block_bucket_size,
+ )
+ meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)
+ meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)
+ meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)
+ meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)
+
+ hpu_attention_meta = trim_attn_metadata(meta)
+ slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
+ kwargs = {}
+ if htorch.utils.internal.is_lazy():
+ kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+ False, hpu_attention_meta.block_list.shape[0], batch_size
+ )
+ # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+ self.model.forward(
+ input_ids=_async_h2d_tensor_copy(input_ids),
+ position_ids=_async_h2d_tensor_copy(position_ids),
+ cu_seqlen_prefill=None,
+ kv_cache=self.kv_cache,
+ slots=_async_h2d_tensor_copy(slots_tensor),
+ seqlen=trim_seqlen_metadata(seqlen),
+ lm_head_indices=None,
+ adapter_data=None,
+ hpu_attention_meta=hpu_attention_meta,
+ **kwargs,
+ )
+
+ def forward(
+ self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ # Model Forward
+ if batch.speculative_ids is not None:
+ input_ids = batch.input_ids
+ position_ids = batch.position_ids
+ cu_seqlen_prefill = batch.cu_seqlen_prefill
+ kv_cache = self.kv_cache
+ block_tables = batch.block_tables_tensor
+ slots = batch.slots[batch.slot_indices]
+ input_lengths = batch.input_lengths_tensor
+ max_s = batch.max_current_length
+ lm_head_indices = batch.prefill_head_indices
+
+ speculative_ids = batch.speculative_ids
+
+ B, speculative_length = speculative_ids.shape
+ new_length = speculative_length + 1
+ new_input_ids = torch.cat(
+ [input_ids.unsqueeze(-1), speculative_ids], dim=1
+ ).reshape(-1)
+ arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+ arange_int = arange.to(dtype=torch.int32)
+ new_position_ids = (
+ position_ids.unsqueeze(-1).expand(B, new_length) + arange
+ ).view(-1)
+
+ # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
+ # then update the slots with the additional indices to ensure we're grabbing the ones that have been
+ # allocated
+ slot_indices = (
+ batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int
+ ).view(-1)
+ slots = batch.slots[slot_indices]
+
+ input_lengths = (
+ input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+ ).view(-1)
+
+ # Copy the block tables for all members
+ block_tables = (
+ block_tables.unsqueeze(1)
+ .expand(B, new_length, -1)
+ .reshape(B * new_length, -1)
+ .contiguous()
+ )
+ max_s = max_s + speculative_length
+
+ input_ids = new_input_ids
+ position_ids = new_position_ids
+ else:
+ input_ids = batch.input_ids
+ position_ids = batch.position_ids
+ cu_seqlen_prefill = batch.cu_seqlen_prefill
+ kv_cache = self.kv_cache
+ block_tables = batch.block_tables_tensor
+ slots = batch.slots[batch.slot_indices]
+ input_lengths = batch.input_lengths_tensor
+ max_s = batch.max_current_length
+ lm_head_indices = batch.prefill_head_indices
+
+ if cu_seqlen_prefill is None and self.max_past() is not None:
+ # In decode, not prefill, we're actually overwriting the KV-cache
+ # in a circular buffer mode.
+ # This makes sure the max_s for the decode pass is correct.
+ max_s = min(self.max_past(), max_s)
+ if batch.prefill_cache_indices is not None:
+ slots_pad = torch.zeros_like(input_ids, device=slots.device)
+ slots_pad[batch.prefill_cache_indices] = slots
+ slots = slots_pad
+ else:
+ slots_pad = torch.zeros_like(input_ids, device=slots.device)
+ slots_pad[: slots.shape[0]] = slots
+ slots = slots_pad
+ seqlen = Seqlen(
+ input_lengths=_async_h2d_tensor_copy(input_lengths),
+ )
+
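+ # Decide whether a captured HPU graph can be used: the bucket key is the padded prompt length in prefill and the block-list length in decode.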
+ kwargs = {}
+ batch_size = input_lengths.shape[0]
+ prompt_len = (
+ input_ids.shape[0] // batch_size
+ if batch.prefilling
+ else batch.hpu_attn_meta.block_list.shape[0]
+ )
+ if htorch.utils.internal.is_lazy():
+ kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+ batch.prefilling, prompt_len, batch_size
+ )
+ if self.sliding_window is not None and batch.prefilling:
+ attn_mask = seqlen.make_sliding_window_bias(
+ input_lengths.tolist(),
+ self.sliding_window,
+ self.dtype,
+ prompt_len,
+ batch_size,
+ )
+ seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)
+
+ logits, speculative_logits = self.model.forward(
+ input_ids=input_ids,
+ position_ids=_async_h2d_tensor_copy(position_ids),
+ cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
+ kv_cache=kv_cache,
+ slots=_async_h2d_tensor_copy(slots),
+ seqlen=trim_seqlen_metadata(seqlen),
+ lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
+ # TODO: adapters are not supported yet; add support in the future
+ adapter_data=None,
+ hpu_attention_meta=batch.hpu_attn_meta,
+ **kwargs,
+ )
+ return logits, speculative_logits
+
+ @tracer.start_as_current_span("generate_token")
+ def generate_token(
+ self, batches: List[FlashCausalLMBatch]
+ ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
+
+ # In order to pipeline any actions on CPU we perform the operation in 3 main stages:
+ # Stage 1. Collect next token ids of any previously started generations
+ start = time.time_ns()
+ prev_batches = []
+ requests_to_generate = []
+ for batch_id, batch in enumerate(batches):
+ if batch.next_token_logits is not None:
+ prefill = batch.prefilling
+ if batch.prefilling:
+ batch.prefilling = False
+ batch.prefilling_mask = [False] * len(batch)
+
+ speculate = get_speculate()
+ (
+ next_input_ids,
+ next_token_logprobs,
+ logprobs,
+ accepted_ids,
+ speculative_ids,
+ ) = batch.next_token_chooser(
+ batch.all_input_ids_tensor[
+ : batch.next_token_logits.shape[0], : batch.max_current_length
+ ],
+ batch.next_token_logits,
+ speculate,
+ batch.speculative_ids,
+ batch.speculative_logits,
+ )
+
+ batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+ batch.top_n_tokens,
+ _async_h2d_tensor_copy(batch.top_n_tokens_tensor),
+ logprobs,
+ accepted_ids,
+ )
+ if batch.valid_indices is not None:
+ # TODO speculative decoding handling missing
+ index = torch.arange(
+ 0,
+ len(batch.valid_indices),
+ device=batch.all_input_ids_tensor.device,
+ )
+ batch.all_input_ids_tensor.index_copy_(
+ 0, index, batch.all_input_ids_tensor[batch.valid_indices]
+ )
+ padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
+ len(batch.valid_indices)
+ )
+ next_input_ids.index_copy_(
+ 0, index, next_input_ids[batch.valid_indices]
+ )
+ next_input_ids = next_input_ids[:padded_total_bs]
+
+ next_token_logprobs.index_copy_(
+ 0, index, next_token_logprobs[batch.valid_indices]
+ )
+ accepted_ids.index_copy_(
+ 0, index, accepted_ids[batch.valid_indices]
+ )
+ if speculative_ids is not None:
+ speculative_ids = speculative_ids[batch.valid_indices]
+ batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[
+ batch.valid_indices
+ ]
+ top_n_tokens = []
+ batch_top_token_ids_v = []
+ batch_top_token_logprobs_v = []
+ for i in batch.valid_indices:
+ top_n_tokens.append(batch.top_n_tokens[i])
+ batch_top_token_ids_v.append(batch_top_token_ids[i])
+ batch_top_token_logprobs_v.append(batch_top_token_logprobs[i])
+ batch_top_token_ids = batch_top_token_ids_v
+ batch_top_token_logprobs = batch_top_token_logprobs_v
+ batch.top_n_tokens = top_n_tokens
+ batch.next_token_chooser = batch.next_token_chooser.filter(
+ batch.valid_indices
+ )
+ batch.valid_indices = None
+
+ # Since we are done prefilling, all the tensors that were concatenating values for all the requests
+ # instantly become of shape [BATCH_SIZE]
+ if prefill:
+ indices = batch.cu_seqlen_prefill[1:] - 1
+ # Inputs are left-padded; select the non-padding positions first
+ if batch.prefill_cache_indices is not None:
+ batch.position_ids = batch.position_ids[
+ batch.prefill_cache_indices
+ ][indices]
+ else:
+ batch.position_ids = batch.position_ids[indices]
+
+ batch.slot_indices = batch.slot_indices[indices[: len(batch)]]
+ if batch.adapter_meta is not None:
+ batch.adapter_meta.adapter_indices = (
+ batch.adapter_meta.adapter_indices[indices]
+ )
+ # For each member of the batch
+ # Cumulative length
+
+ if batch.speculative_logits is not None:
+ cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
+ torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
+ for i in range(len(batch)):
+ batch.all_input_ids_tensor[
+ i,
+ batch.cache_lengths[i]
+ + batch.input_lengths[i] : batch.cache_lengths[i]
+ + batch.input_lengths[i]
+ + accepted_ids[i],
+ ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
+ batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
+ accepted_ids = accepted_ids.cpu()
+ if batch.position_ids.dim() == 2:
+ # Qwen2_vl case:
+ batch.position_ids += accepted_ids.unsqueeze(-1)
+ else:
+ batch.position_ids += accepted_ids
+ batch.cache_lengths_tensor += (
+ batch.input_lengths_tensor + accepted_ids - 1
+ )
+ batch.input_lengths_tensor = torch.ones_like(
+ batch.input_lengths_tensor
+ )
+ batch.slot_indices += accepted_ids[: len(batch)]
+ else:
+ index = batch.cache_lengths_tensor + batch.input_lengths_tensor
+ index = F.pad(
+ index, (0, next_input_ids.shape[0] - index.shape[0]), value=0
+ )
+ index = index.to(batch.all_input_ids_tensor.device)
+ batch_idx = torch.arange(
+ 0,
+ index.shape[0],
+ dtype=torch.long,
+ device=batch.all_input_ids_tensor.device,
+ )
+ batch.all_input_ids_tensor.index_put_(
+ (batch_idx, index.long()), next_input_ids
+ )
+ batch.input_ids = next_input_ids
+ batch.position_ids += 1
+ batch.cache_lengths_tensor += batch.input_lengths_tensor
+ batch.input_lengths_tensor = torch.ones_like(
+ batch.input_lengths_tensor
+ )
+ batch.slot_indices += 1
+
+ batch.speculative_ids = speculative_ids
+
+ # Does a HPU <-> CPU sync internally
+ if prefill and batch.adapter_meta is not None:
+ # adjust segment lengths to account for all request lengths being 1 during decoding
+ adapter_segments, _ = find_segments(
+ batch.adapter_meta.adapter_indices
+ )
+ batch.adapter_meta.adapter_segments = torch.tensor(
+ adapter_segments,
+ dtype=torch.int32,
+ device=batch.adapter_meta.adapter_segments.device,
+ )
+ prev_batches.append(
+ {
+ "next_token_ids": next_input_ids,
+ "next_token_logprobs": next_token_logprobs,
+ "accepted_ids": accepted_ids,
+ }
+ )
+ idx = len(prev_batches) - 1
+
+ for req_idx, req in enumerate(batch.requests):
+ new_input_length = 1
+ if batch.speculative_logits is not None:
+ new_cache_length = (
+ batch.cache_lengths[req_idx]
+ + batch.input_lengths[req_idx]
+ + accepted_ids[req_idx]
+ - 1
+ )
+ else:
+ new_cache_length = (
+ batch.cache_lengths[req_idx] + batch.input_lengths[req_idx]
+ )
+ batch.cache_lengths[req_idx] = new_cache_length
+ batch.max_input_length = max(
+ batch.max_input_length, new_input_length
+ )
+ batch.input_lengths[req_idx] = new_input_length
+ current_length = new_cache_length + new_input_length
+ batch.max_current_length = max(
+ batch.max_current_length, current_length
+ )
+
+ requests_to_generate.append(
+ {
+ "idx": idx,
+ "request_id": req.id,
+ "prefix_offset": batch.prefix_offsets[req_idx],
+ "read_offset": batch.read_offsets[req_idx],
+ "stopping_criteria": batch.stopping_criterias[req_idx],
+ "all_input_ids": batch.all_input_ids[req_idx],
+ "do_sample": batch.next_token_chooser.do_sample[req_idx],
+ "seed": batch.next_token_chooser.seeds[req_idx],
+ "top_n_tokens": batch.top_n_tokens[req_idx],
+ "top_token_ids": batch_top_token_ids[req_idx],
+ "top_token_logprobs": batch_top_token_logprobs[req_idx],
+ }
+ )
+ if prefill:
+ # We do not need prefill tensors anymore
+ batch.cu_seqlen_prefill = None
+ batch.prefill_cache_indices = None
+ batch.prefill_cu_outlens = None
+ batch.prefill_head_indices = None
+ batch.prefill_next_token_indices = None
+ batch.next_token_logits = None
+ batch.speculative_ids = None
+
+ htorch.core.mark_step()
+ # Stage 2. Prepare new batch for speculative scheduling
+ if len(batches) > 1:
+ if self.bucketing_ctx is not None:
+ total_batch_size = 0
+ for b in batches:
+ total_batch_size += len(b)
+ padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
+ total_batch_size
+ )
+ batch = self.batch_type.concatenate(
+ batches, padded_total_bs=padded_total_bs
+ )
+ else:
+ batch = self.batch_type.concatenate(batches)
+ else:
+ batch = batches[0]
+ prefill = batch.prefilling
+ if prefill:
+ if self.bucketing_ctx is not None:
+ batch.prepare_for_prefill(
+ self.bucketing_ctx.get_padded_prompt_seq_len(
+ batch.max_input_length
+ ),
+ self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
+ self.max_total_tokens,
+ self.tokenizer.pad_token_id,
+ )
+ else:
+ batch.prepare_for_prefill(
+ batch.max_input_length,
+ len(batch),
+ self.max_total_tokens,
+ self.tokenizer.pad_token_id,
+ )
+ else:
+ batch.prepare_for_decode(
+ self.dtype,
+ self.use_contiguous_pa,
+ self.bucketing_ctx,
+ self.tokenizer.pad_token_id,
+ self.sliding_window,
+ )
+ if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
+ self.set_inputs_embeds(batch)
+ prefill_logprobs = batch.prefill_next_token_indices is not None
+ # Update adapter indices for speculative tokens (if present)
+ adapter_meta = batch.adapter_meta
+ if adapter_meta is not None:
+ if batch.speculative_ids is not None:
+ B, speculative_length = batch.speculative_ids.shape
+ new_length = speculative_length + 1
+ adapter_indices = (
+ adapter_meta.adapter_indices.unsqueeze(-1)
+ .expand(B, new_length)
+ .reshape(-1)
+ )
+ adapter_segments = adapter_meta.adapter_segments * new_length
+ adapter_meta = AdapterBatchMetadata(
+ adapter_indices=adapter_indices,
+ adapter_set=adapter_meta.adapter_set,
+ adapter_segments=adapter_segments,
+ segment_indices=adapter_meta.segment_indices,
+ )
+
+ # Assign pointers to adapter weights
+ # TODO(travis): don't update this if indices haven't changed
+ adapter_data = AdapterBatchData.from_meta(
+ adapter_meta,
+ self.layer_to_adapter_weights,
+ prefill,
+ batch.prefill_head_indices,
+ )
+ else:
+ adapter_data = None
+
+ out, speculative_logits = self.forward(batch, adapter_data)
+
+ if prefill:
+ batch.next_token_logits = (
+ out[batch.prefill_next_token_indices] if prefill_logprobs else out
+ )
+ if speculative_logits is not None:
+ speculative_logits = (
+ speculative_logits[batch.prefill_next_token_indices]
+ if prefill_logprobs
+ else speculative_logits
+ )
+ else:
+ prefill_logprobs = None
+ batch.next_token_logits = out
+ batch.speculative_logits = speculative_logits
+
+ # HPU->CPU sync
+ htorch.core.mark_step()
+ start_decode = time.time_ns()
+ for prev_batch in prev_batches:
+ prev_batch["next_token_logprobs"] = prev_batch[
+ "next_token_logprobs"
+ ].tolist()
+ prev_batch["next_token_ids"] = prev_batch["next_token_ids"].tolist()
+ prev_batch["accepted_ids"] = prev_batch["accepted_ids"].tolist()
+ htorch.core.mark_step()
+ # Stage 3. Finish and return previous generations
+ # Results
+ generations: List[Generation] = []
+ stopped = len(requests_to_generate) > 0
+ # Reset max_input_length
+ batch.max_input_length = 0
+ # For each member of the batch
+ indexes = [0] * len(prev_batches)
+ idx_accept_ids = [0] * len(prev_batches)
+ for i, req_data in enumerate(requests_to_generate):
+ idx = req_data["idx"]
+ request_id = req_data["request_id"]
+ prefix_offset = req_data["prefix_offset"]
+ read_offset = req_data["read_offset"]
+ stopping_criteria = req_data["stopping_criteria"]
+ all_input_ids = req_data["all_input_ids"]
+ do_sample = req_data["do_sample"]
+ seed = req_data["seed"]
+ top_n_tokens = req_data["top_n_tokens"]
+ n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]]
+ top_token_ids = req_data["top_token_ids"]
+ top_token_logprobs = req_data["top_token_logprobs"]
+ # Append next token to all tokens
+ next_token_texts = []
+ left = 0
+
+ if n_accepted_ids > 1:
+ log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
+
+ current_stopped = False
+ index = indexes[idx]
+ for j in range(index, index + n_accepted_ids):
+ # Generated token
+ next_token_id = prev_batches[idx]["next_token_ids"][j]
+ all_input_ids.append(next_token_id)
+ next_token_text, prefix_offset, read_offset = self.decode_token(
+ all_input_ids,
+ prefix_offset,
+ read_offset,
+ )
+ next_token_texts.append(next_token_text)
+
+ stop, reason = stopping_criteria(
+ next_token_id,
+ next_token_text,
+ )
+
+ if stop:
+ left = index + n_accepted_ids - j - 1
+ current_stopped = True
+ break
+ else:
+ current_stopped = False
+ stopped = stopped and current_stopped
+
+ _next_token_ids = prev_batches[idx]["next_token_ids"][
+ index : index + n_accepted_ids - left
+ ]
+ _next_token_logprobs = prev_batches[idx]["next_token_logprobs"][
+ index : index + n_accepted_ids - left
+ ]
+
+ # Shard generations
+ # All generations will be appended in the rust sharded client
+ if request_id % self.world_size == self.rank:
+ if stop:
+ # Decode generated tokens
+ output_text, _, _ = self.decode_token(
+ all_input_ids,
+ prefix_offset=len(all_input_ids)
+ - stopping_criteria.current_tokens
+ - 1,
+ read_offset=len(all_input_ids)
+ - stopping_criteria.current_tokens,
+ skip_special_tokens=True,
+ )
+ generated_text = GeneratedText(
+ output_text,
+ stopping_criteria.current_tokens,
+ reason,
+ seed if do_sample else None,
+ )
+ else:
+ generated_text = None
+
+ if top_n_tokens > 0:
+ all_top_tokens = []
+ for top_token_ids, top_token_logprobs in zip(
+ top_token_ids, top_token_logprobs
+ ):
+ toptoken_texts = self.tokenizer.batch_decode(
+ top_token_ids,
+ clean_up_tokenization_spaces=False,
+ skip_special_tokens=False,
+ )
+ special_toptokens = [
+ token_id in self.all_special_ids
+ for token_id in top_token_ids
+ ]
+ top_tokens = Tokens(
+ top_token_ids,
+ top_token_logprobs,
+ toptoken_texts,
+ special_toptokens,
+ )
+ all_top_tokens.append(top_tokens)
+ top_tokens = all_top_tokens
+ else:
+ top_tokens = None
+
+ generation = Generation(
+ request_id,
+ None,
+ Tokens(
+ _next_token_ids,
+ _next_token_logprobs,
+ next_token_texts,
+ [nid in self.all_special_ids for nid in _next_token_ids],
+ ),
+ generated_text,
+ top_tokens,
+ )
+
+ generations.append(generation)
+
+ # accept each new token for this specific request since we may
+ # have more than one new token per request with speculative decoding
+ for next_token_id in _next_token_ids:
+ batch.next_token_chooser = (
+ batch.next_token_chooser.advance_grammar_single(
+ i, next_token_id
+ )
+ )
+
+ # Update values
+ indexes[idx] += n_accepted_ids
+ idx_accept_ids[idx] += 1
+
+ batch.prefix_offsets[i] = prefix_offset
+ batch.read_offsets[i] = read_offset
+ batch.all_input_ids[i] = all_input_ids
+ htorch.core.mark_step()
+ if stopped:
+ # No need to return a batch if we know that all requests stopped
+ forward_ns = start_decode - start
+ decode_ns = time.time_ns() - start_decode
+ return generations, None, (forward_ns, decode_ns)
+
+ forward_ns = start_decode - start
+ decode_ns = time.time_ns() - start_decode
+ return generations, batch, (forward_ns, decode_ns)
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
new file mode 100644
index 00000000000..0cd49d45f19
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -0,0 +1,1108 @@
+import torch
+from PIL import Image
+from io import BytesIO
+from dataclasses import dataclass
+from opentelemetry import trace
+from typing import Iterable, Optional, Tuple, List, Type, Dict
+
+from transformers import PreTrainedTokenizerBase
+from transformers.image_processing_utils import select_best_resolution
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.flash_causal_lm import (
+ FlashCausalLMBatch,
+ FlashCausalLM,
+ generate_block_metadata,
+)
+from text_generation_server.models.globals import PREFIX_CACHING, BLOCK_SIZE
+from loguru import logger
+from text_generation_server.utils.log import log_master
+from transformers import AutoProcessor
+from text_generation_server.layers.attention import (
+ Seqlen,
+ trim_seqlen_metadata,
+ _async_h2d_tensor_copy,
+ HPUPagedAttentionMetadata,
+ trim_attn_metadata,
+)
+import habana_frameworks.torch as htorch
+import time
+from text_generation_server.utils.import_utils import (
+ synchronize,
+)
+from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
+
+tracer = trace.get_tracer(__name__)
+
+IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
+IDEFICS2_IMAGE_TOKEN = "<image>"
+
+IDEFICS3_IMAGE_TOKEN = "<image>"
+IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
+IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
+
+
+def prompt_split_image_llama4(aspect_ratio, num_patches_per_chunk):
+ """
+ Create a structured string representation of image tokens
+
+ Args:
+ aspect_ratio: Tuple (ratio_h, ratio_w) describing how the image is tiled
+ num_patches_per_chunk: Number of patches in each image chunk
+
+ Returns:
+ String with appropriate image tokens
+ """
+ img_string = "<|image_start|>"
+ ratio_h, ratio_w = aspect_ratio
+ if ratio_h * ratio_w > 1:
+ for yy in range(ratio_h):
+ for xx in range(ratio_w):
+ img_string += "<|patch|>" * num_patches_per_chunk
+ if xx < ratio_w - 1:
+ img_string += "<|tile_x_separator|>"
+
+ img_string += "<|tile_y_separator|>"
+ img_string += "<|image|>"
+ img_string += "<|patch|>" * num_patches_per_chunk
+ img_string += "<|image_end|>"
+
+ return img_string
+
+
+# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
+def _prompt_split_image(
+ *,
+ image_seq_len: int,
+ image_rows: int,
+ image_cols: int,
+ fake_token_around_image: str,
+ image_token: str,
+ global_img_token: str,
+):
+ """Prompt with expanded image tokens for when the image is split into patches."""
+ text_split_images = ""
+ for n_h in range(image_rows):
+ for n_w in range(image_cols):
+ text_split_images += (
+ f"{fake_token_around_image}"
+ + f""
+ + f"{image_token}" * image_seq_len
+ )
+ text_split_images += "\n"
+
+ text_split_images += (
+ f"\n{fake_token_around_image}"
+ + f"{global_img_token}"
+ + f"{image_token}" * image_seq_len
+ + f"{fake_token_around_image}"
+ )
+ return text_split_images
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+ """
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (`tuple`):
+ The size of the input image in the format (height, width).
+ grid_pinpoints (`List`):
+ A list containing possible resolutions. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ patch_size (`int`):
+ The size of each image patch.
+
+ Returns:
+ tuple: The shape of the image patch grid in the format (height, width).
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise ValueError("grid_pinpoints should be a list of tuples or lists")
+
+ height, width = select_best_resolution(image_size, grid_pinpoints)
+ return height // patch_size, width // patch_size
+
+
+def image_text_replacement(processor, image_input, config) -> str:
+ if config.model_type == "idefics2":
+ image_seq_len = 64
+ image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
+ if processor.image_processor.do_image_splitting:
+ image_str *= 5
+ return image_str, IDEFICS2_FAKE_TOKEN
+ if config.model_type == "idefics3":
+ # TODO: implement this in a more general way
+ n_rows = image_input["rows"][0][0]
+ n_cols = image_input["cols"][0][0]
+ image_seq_len = int(
+ ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
+ / (config.scale_factor**2)
+ )
+ image_str = _prompt_split_image(
+ image_seq_len=image_seq_len,
+ image_rows=n_rows,
+ image_cols=n_cols,
+ fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
+ image_token=IDEFICS3_IMAGE_TOKEN,
+ global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
+ )
+ return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
+ elif config.model_type == "llava_next":
+ height, width = image_input["image_sizes"][0]
+ num_features = get_number_of_features(height, width, config)
+
+ log_master(
+ logger.info,
+ f"Found {num_features} features in image of resolution {height}x{width}",
+ )
+ return "" * num_features, ""
+
+ elif config.model_type == "paligemma":
+ return "" * config.text_config.num_image_tokens, ""
+ elif config.model_type == "qwen2_vl":
+ grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
+ num_pads = grid_t * grid_h * grid_w // 4
+ padding = "<|image_pad|>" * num_pads
+ return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
+ elif config.model_type == "qwen2_5_vl":
+ grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
+ num_pads = grid_t * grid_h * grid_w // 4
+ padding = "<|image_pad|>" * num_pads
+ return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
+ elif config.model_type == "gemma3":
+ # TODO: get correct number of features via reviewing the Gemma3 architecture
+ # and calculating the number of image tokens
+ num_pads = 256
+ padding = "" * num_pads
+ return f"\n\n{padding}\n\n", ""
+ elif config.model_type == "llama4":
+ patch_size = config.vision_config.patch_size
+ pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
+ downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
+ aspect_ratios = image_input["aspect_ratios"][0]
+ image_height, image_width = image_input["pixel_values"][0].shape[-2:]
+
+ num_patches_per_chunk = int(
+ (image_height // patch_size)
+ * (image_width // patch_size)
+ // downsample_ratio
+ )
+ tokens_for_this_image = prompt_split_image_llama4(
+ aspect_ratios, num_patches_per_chunk
+ )
+
+ return tokens_for_this_image, "<|image_start|>"
+ else:
+ raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
+
+
+def image_text_replacement_fixup(config, text: str) -> str:
+ if config.model_type == "idefics2":
+ return text.replace(
+ f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
+ )
+ return text
+
+
+def preprocess_text(config, text: str) -> str:
+ if config.model_type == "paligemma":
+ return "" + text + "\n"
+ return text
+
+
+def preprocess_image(config, img):
+ model_type = config.model_type
+
+ if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
+ img = img.resize((img.width * 2, img.height * 2))
+ if model_type == "paligemma":
+ img = img.convert("RGB")
+
+ if model_type not in {"llava_next", "gemma3", "llama4"}:
+ # TODO: check if this is needed
+ img = [img]
+
+ return img
+
+
+def get_unpadded_features(
+ original_height: int,
+ original_width: int,
+ npatches: int,
+ num_patch_height: int,
+ num_patch_width: int,
+) -> Tuple[int, int]:
+ current_height = npatches * num_patch_height
+ current_width = npatches * num_patch_width
+
+ aspect_ratio: float = original_width / original_height
+ current_aspect_ratio: float = current_width / current_height
+
+ if aspect_ratio > current_aspect_ratio:
+ new_height = (original_height * current_width) // original_width
+ padding = (current_height - new_height) // 2
+ current_height = current_height - (2 * padding)
+ else:
+ new_width = (original_width * current_height) // original_height
+ padding = (current_width - new_width) // 2
+ current_width = current_width - (2 * padding)
+
+ unpadded_features = current_height * current_width
+ newline_features = current_height
+ return (unpadded_features, newline_features)
+
+
+def get_number_of_features(height: int, width: int, config) -> int:
+ # From config
+ # Hardcoded for CLIP for now
+ # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
+ image_grid_pinpoints = config.image_grid_pinpoints
+ image_size = config.vision_config.image_size
+ patch_size = config.vision_config.patch_size
+
+ assert image_size % patch_size == 0
+
+ npatches = image_size // patch_size
+
+ # Dimensions are intentionally swapped to be bug-compatible with
+ # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+ [height, width],
+ image_grid_pinpoints,
+ image_size,
+ )
+ unpadded_features, newline_features = get_unpadded_features(
+ height, width, npatches, num_patch_height, num_patch_width
+ )
+ # The base patch covers the entire image
+ base_features = npatches**2
+ return unpadded_features + newline_features + base_features
+
+
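+ # Scatter image embeddings back onto the full placeholder span, filling non-embed positions with NaN.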
+def scatter_image_embeds(
+ embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
+) -> torch.Tensor:
+ if is_embed is None:
+ return embeds
+
+ placeholders = embeds.new_full(
+ (is_embed.shape[0], embeds.shape[-1]),
+ fill_value=torch.nan,
+ )
+ placeholders[is_embed.to(embeds.device)] = embeds
+ return placeholders
+
+
+def gather_image_embeds(
+ embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
+) -> Optional[torch.Tensor]:
+ if is_embed is None:
+ return embeds
+ sel = embeds[is_embed.to(embeds.device)]
+ return sel if sel.numel() else None
+
+
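+ # Describes where one image's tokens are located inside a request's input_ids.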
+@dataclass
+class ImagePositions:
+ offset: int
+ length: int
+ id: int
+ num_placeholder_tokens: int
+ is_embed: Optional[torch.Tensor] = None
+
+
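+ # Batch that additionally tracks per-request image inputs, their token positions, and cached vision embeddings.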
+class FlashVlmCausalLMBatch(FlashCausalLMBatch):
+ image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]]
+ image_positions: Optional[List[List[ImagePositions]]]
+ encoder_cache: Optional[List[Dict[int, torch.Tensor]]]
+ pixel_values: Optional[List[torch.Tensor]]
+ pixel_attention_mask: Optional[List[torch.Tensor]]
+ image_sizes: Optional[List[Tuple[int, int]]]
+ image_grid_thw: Optional[torch.Tensor]
+ cache_entries_to_free: List[Tuple[int, int]]
+ has_image_inputs: bool = False
+ inputs_embeds: Optional[torch.Tensor] = None
+
+ @classmethod
+ @tracer.start_as_current_span("concatenate")
+ def concatenate(cls, batches, padded_total_bs: int = 0):
+ batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs)
+ batch.image_inputs = []
+ batch.image_positions = []
+ batch.encoder_cache = []
+ for b in batches:
+ if b.image_inputs is not None:
+ batch.image_inputs.extend(b.image_inputs)
+ else:
+ batch.image_inputs.append(None)
+ if b.image_positions is not None:
+ batch.image_positions.extend(b.image_positions)
+ else:
+ batch.image_positions.append(None)
+ if b.encoder_cache is not None:
+ batch.encoder_cache.extend(b.encoder_cache)
+ else:
+ batch.encoder_cache.append(None)
+
+ batch.pixel_values = None
+ batch.pixel_attention_mask = None
+ batch.image_sizes = None
+ batch.image_grid_thw = None
+ batch.inputs_embeds = None
+ # To be filled in prepare_for_prefill
+ batch.has_image_inputs = False
+ batch.cache_entries_to_free = []
+ return batch
+
+ @tracer.start_as_current_span("filter")
+ def filter(self, request_ids: List[int]):
+ if len(request_ids) == 0:
+ raise ValueError("Batch must have at least one request")
+
+ image_inputs = []
+ image_positions = []
+ encoder_cache = []
+
+ for request_id in request_ids:
+ idx = self.requests_idx_mapping[request_id]
+ image_inputs.append(self.image_inputs[idx])
+ image_positions.append(self.image_positions[idx])
+ encoder_cache.append(self.encoder_cache[idx])
+
+ batch = super().filter(request_ids)
+ batch.pixel_values = None
+ batch.pixel_attention_mask = None
+ batch.image_sizes = None
+ batch.image_grid_thw = None
+ batch.inputs_embeds = None
+ batch.image_inputs = image_inputs
+ batch.image_positions = image_positions
+ batch.encoder_cache = encoder_cache
+
+ # To be filled in prepare_for_prefill
+ batch.has_image_inputs = False
+ batch.cache_entries_to_free = []
+ return batch
+
+ @classmethod
+ def batch_tokenized_inputs(
+ cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
+ ):
+ kwargs = {}
+ if (
+ hasattr(processor, "image_processor_class")
+ and processor.image_processor_class == "Idefics3ImageProcessor"
+ ):
+ kwargs["return_row_col_info"] = True
+
+ max_length = 0
+ vocab = tokenizer.get_vocab()
+
+ if not hasattr(config, "image_token_index"):
+ config.image_token_index = config.image_token_id
+
+ batch_tokenized_inputs: List[List[int]] = []
+ batch_image_inputs: List[Optional[List[dict]]] = []
+ batch_image_positions: List[Optional[List[ImagePositions]]] = []
+
+ for r in requests:
+ text_parts = []
+ image_inputs = []
+ image_texts = []
+
+ image_id = 0
+
+ for chunk in r.input_chunks.chunks:
+ chunk_type = chunk.WhichOneof("chunk")
+ if chunk_type == "text":
+ text = preprocess_text(config, chunk.text)
+ text_parts.append(text)
+ elif chunk_type == "image":
+ img = Image.open(BytesIO(chunk.image.data))
+ img = preprocess_image(config, img)
+
+ image_input = processor.image_processor(
+ [img], return_tensors="pt", **kwargs
+ )
+ image_inputs.append(image_input)
+
+ img_text, img_start_token_str = image_text_replacement(
+ processor, image_input, config
+ )
+ text_parts.append(img_text)
+
+ image_texts.append([image_id, img_start_token_str, img_text])
+ image_id += 1
+ else:
+ raise RuntimeError(f"Invalid chunk type {chunk_type}")
+
+ full_text = image_text_replacement_fixup(config, "".join(text_parts))
+ input_ids = tokenizer(
+ full_text,
+ truncation=True,
+ max_length=r.truncate,
+ add_special_tokens=(
+ r.add_special_tokens if config.model_type != "paligemma" else False
+ ),
+ )["input_ids"]
+ max_length = max(max_length, len(input_ids))
+
+ if len(image_inputs) > 0:
+ img_start_token = vocab[image_texts[0][1]]
+ image_positions = cls.get_image_positions(
+ input_ids, image_texts, img_start_token, config, tokenizer
+ )
+ else:
+ image_inputs = None
+ image_positions = None
+
+ batch_tokenized_inputs.append(input_ids)
+ batch_image_inputs.append(image_inputs)
+ batch_image_positions.append(image_positions)
+
+ return batch_tokenized_inputs, batch_image_inputs, batch_image_positions
+
+ @classmethod
+ def get_image_positions(
+ cls,
+ input_ids: List[int],
+ image_texts: List[Tuple[int, str, str]],
+ img_start_token: int,
+ config,
+ tokenizer: PreTrainedTokenizerBase,
+ ) -> List[ImagePositions]:
+ image_positions = []
+ num_images = len(image_texts)
+
+ input_ids_t = torch.as_tensor(input_ids)
+ img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0]
+ num_tokens = input_ids_t.numel()
+
+ last_pos = 0
+ for i in range(num_images):
+ image_id, img_start_token_str, img_text = image_texts[i]
+ img_text = image_text_replacement_fixup(config, img_text)
+
+ if config.model_type == "gemma3":
+ img_text = img_text.replace("\n\n", "")
+
+ tokens = tokenizer(img_text, add_special_tokens=False, return_tensors="pt")[
+ "input_ids"
+ ][0]
+ length = tokens.numel()
+
+ assert (
+ length <= num_tokens
+ ), f"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens"
+
+ pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
+ index = img_start_token_pos[pos]
+ assert torch.equal(
+ input_ids_t[index : index + length], tokens
+ ), "Image tokens not found in input_ids"
+
+ is_embed = tokens == config.image_token_index
+ num_placeholder_tokens = int(is_embed.sum())
+ if num_placeholder_tokens == length:
+ is_embed = None
+
+ pos = ImagePositions(
+ offset=index,
+ length=length,
+ id=image_id,
+ num_placeholder_tokens=num_placeholder_tokens,
+ is_embed=is_embed,
+ )
+
+ image_positions.append(pos)
+ last_pos = index + length
+
+ if (
+ config.model_type == "idefics2"
+ and i + 1 != num_images
+ and input_ids[last_pos] == config.image_token_index
+ ):
+ fake_token = last_pos - 1
+ fake_token_index = torch.searchsorted(
+ img_start_token_pos, fake_token, right=False
+ )
+ img_start_token_pos[fake_token_index] = last_pos
+ image_texts[i + 1][2] = image_texts[i + 1][2][
+ len(img_start_token_str) :
+ ]
+
+ return image_positions
+
+ @classmethod
+ def from_pb_processor(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ processor,
+ config,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "FlashVlmCausalLMBatch":
+ batch_tokenized_inputs, image_inputs, image_positions = (
+ cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config)
+ )
+ batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+ batch.image_inputs = image_inputs
+ batch.image_positions = image_positions
+ batch.encoder_cache = [{} for _ in range(len(pb.requests))]
+ if len(image_inputs):
+ batch.pixel_values = None
+ batch.pixel_attention_mask = None
+ batch.image_sizes = None
+ batch.image_grid_thw = None
+ return batch
+
+ def prepare_for_prefill(
+ self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
+ ):
+ super().prepare_for_prefill(
+ max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
+ )
+
+ self.has_image_inputs = False
+ self.cache_entries_to_free = []
+
+ self.pixel_values = []
+
+ assert (
+ len(self.cache_lengths)
+ == len(self.input_lengths)
+ == len(self.prefilling_mask)
+ ), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask"
+
+ for i, (
+ cache_length,
+ input_length,
+ request_prefilling,
+ ) in enumerate(
+ zip(
+ self.cache_lengths,
+ self.input_lengths,
+ self.prefilling_mask,
+ )
+ ):
+ if not request_prefilling or self.image_positions[i] is None:
+ continue
+
+ for image_position in self.image_positions[i]:
+ if image_position is None:
+ continue
+ start_pos = image_position.offset
+ length = image_position.length
+
+ if start_pos >= cache_length + input_length:
+ # No encoder input required at this step
+ break
+ if start_pos + length <= cache_length:
+ # The encoder input has already been processed
+ continue
+
+ self.has_image_inputs = True
+
+ if image_position.id not in self.encoder_cache[i]:
+ image_inputs = self.image_inputs[i][image_position.id]
+ self.pixel_values.append((i, image_position.id, image_inputs))
+
+ # Remove the image from the image_inputs
+ self.image_inputs[i][image_position.id] = None
+
+ if not self.has_image_inputs:
+ self.pixel_values = None
+ self.pixel_attention_mask = None
+ self.image_sizes = None
+ self.image_grid_thw = None
+ else:
+ image_grid_thw_list = [
+ x[2]["image_grid_thw"]
+ for x in self.pixel_values
+ if "image_grid_thw" in x[2]
+ ]
+ if image_grid_thw_list:
+ self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
+ else:
+ self.image_grid_thw = None
+
+ def update_encoder_cache(self, encoder_outputs, request_id, img_pos):
+ self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(
+ encoder_outputs, img_pos.is_embed
+ )
+
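+ # Collect the cached vision embeddings that overlap the chunk currently being prefilled.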
+ def gather_vision_embeds(self):
+ device = self.input_ids.device
+ chunks = []
+ for (
+ i,
+ cache_length,
+ input_length,
+ request_prefilling,
+ ) in zip(
+ range(len(self.requests)),
+ self.cache_lengths,
+ self.input_lengths,
+ self.prefilling_mask,
+ ):
+ if not request_prefilling or self.image_positions[i] is None:
+ continue
+
+ for image_position in self.image_positions[i]:
+ if image_position is None:
+ continue
+ start_pos = image_position.offset
+ length = image_position.length
+
+ if start_pos >= cache_length + input_length:
+ # No encoder input required at this step
+ break
+ if start_pos + length <= cache_length:
+ # The encoder input has already been processed
+ continue
+
+ start_idx = max(cache_length - start_pos, 0)
+ end_idx = min(cache_length - start_pos + input_length, length)
+
+ assert (
+ image_position.id in self.encoder_cache[i]
+ ), f"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}"
+ encoder_output = self.encoder_cache[i][image_position.id]
+
+ is_embed = image_position.is_embed
+ if is_embed is not None:
+ is_embed = is_embed[start_idx:end_idx]
+
+ logger.info(
+ f"image_id {image_position.id} start_idx {start_idx} end_idx {end_idx}, length {length}"
+ )
+
+ embeds = gather_image_embeds(
+ encoder_output[start_idx:end_idx],
+ is_embed=is_embed,
+ )
+ if embeds is not None:
+ chunks.append(embeds)
+
+ if end_idx == length:
+ self.cache_entries_to_free.append((i, image_position.id))
+ self.image_positions[i][image_position.id] = None
+
+ if len(chunks) == 0:
+ return None
+ return torch.cat(chunks, dim=0).to(device)
+
+ def free_encoder_cache(self):
+ for i, image_id in self.cache_entries_to_free:
+ self.encoder_cache[i].pop(image_id, None)
+
+ self.cache_entries_to_free = []
+
+
+class FlashVlmCausalLM(FlashCausalLM):
+ def __init__(
+ self,
+ model_id: str,
+ *,
+ processor_class=AutoProcessor,
+ processor_kwargs=None,
+ batch_class=FlashVlmCausalLMBatch,
+ revision,
+ trust_remote_code: bool,
+ support_chunking: bool = False,
+ **kwargs,
+ ):
+ if PREFIX_CACHING:
+ raise NotImplementedError("Vlm do not work with prefix caching yet")
+ if processor_kwargs is None:
+ processor_kwargs = {}
+ self.processor = processor_class.from_pretrained(
+ model_id,
+ revision=revision,
+ trust_remote_code=trust_remote_code,
+ **processor_kwargs,
+ )
+ self.batch_class = batch_class
+ super().__init__(
+ model_id=model_id,
+ revision=revision,
+ trust_remote_code=trust_remote_code,
+ support_chunking=support_chunking,
+ **kwargs,
+ )
+
+ @property
+ def batch_type(self) -> Type[FlashVlmCausalLMBatch]:
+ return self.batch_class
+
+ def max_past(self) -> Optional[int]:
+ return getattr(self.model.text_model, "max_past", None)
+
+ def warmup_decode(
+ self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch
+ ):
+ input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)
+ position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
+ if batch.position_ids is not None and batch.position_ids.dim() == 2:
+ # qwen2_vl and qwen2_5_vl case
+ position_ids = position_ids.unsqueeze(-1).repeat(
+ (1, batch.position_ids.shape[-1])
+ )
+ blocks = [block_num // batch_size for _ in range(batch_size)]
+ blocks[0] += block_num % batch_size
+ block_tables = []
+ slots = []
+ start_idx = 0
+ slot_indices = []
+
+ # Use the last slot of each sequence's final block so the warmup touches all block_num blocks
+
+ for i in range(batch_size):
+ block_array = list(range(start_idx, start_idx + blocks[i]))
+ slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
+ block_tables.append(block_array)
+ slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1)
+ start_idx += blocks[i]
+ input_lengths = torch.ones(batch_size, dtype=torch.int32)
+
+ seqlen = Seqlen(
+ input_lengths=_async_h2d_tensor_copy(input_lengths),
+ )
+ block_list, block_groups, block_usage, _, block_bucket_size = (
+ generate_block_metadata(
+ self.dtype,
+ self.use_contiguous_pa,
+ slots,
+ block_tables,
+ self.bucketing_ctx,
+ )
+ )
+ meta = HPUPagedAttentionMetadata(
+ block_list=_async_h2d_tensor_copy(block_list),
+ block_groups=_async_h2d_tensor_copy(block_groups),
+ block_usage=_async_h2d_tensor_copy(block_usage),
+ block_mapping=None,
+ attn_bias=None,
+ )
+ if self.sliding_window is not None:
+ block_tables_in_window = []
+ for i, bt in enumerate(block_tables):
+ block_num_in_window = (
+ self.sliding_window + BLOCK_SIZE - 1
+ ) // BLOCK_SIZE
+ block_tables_in_window.append(
+ bt[max(0, blocks[i] - block_num_in_window) : blocks[i]]
+ )
+ slots_in_window = []
+ start_idx = 0
+ for i, indice in enumerate(slot_indices):
+ mask = (
+ indice - torch.arange(start_idx, indice + 1)
+ ) < self.sliding_window
+ slots_in_window.append(torch.arange(start_idx, indice + 1)[mask])
+ start_idx += blocks[i] * BLOCK_SIZE
+ slots_in_window = torch.cat(slots_in_window, dim=0)
+ (
+ block_list_in_window,
+ block_groups_in_window,
+ block_usage_in_window,
+ slots_in_window_mask,
+ _,
+ ) = generate_block_metadata(
+ self.dtype,
+ self.use_contiguous_pa,
+ slots,
+ block_tables_in_window,
+ self.bucketing_ctx,
+ slots_in_window,
+ block_bucket_size,
+ )
+ meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)
+ meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)
+ meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)
+ meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)
+
+ hpu_attention_meta = trim_attn_metadata(meta)
+ slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
+ inputs_embeds = self.get_inputs_embeds(
+ input_ids=input_ids.to(self.device),
+ )
+ # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+ self.model.forward(
+ inputs_embeds=inputs_embeds,
+ position_ids=_async_h2d_tensor_copy(position_ids),
+ cu_seqlen_prefill=None,
+ kv_cache=self.kv_cache,
+ slots=_async_h2d_tensor_copy(slots_tensor),
+ seqlen=trim_seqlen_metadata(seqlen),
+ hpu_attention_meta=hpu_attention_meta,
+ lm_head_indices=None,
+ attention_mask=None,
+ )
+
+ def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
+ free_mem = HabanaMemoryProfiler.current_free_device_memory()
+ graph_free_mem = free_mem - self.mem_reserved
+ graph_free_mem = self.align_workers(
+ graph_free_mem, torch.distributed.ReduceOp.MIN
+ )
+ decode_available_memory = graph_free_mem
+ msg = (
+ f"Using {format_bytes(graph_free_mem)}"
+ f"/{format_bytes(free_mem)} "
+ "of free device memory for HPUGraphs, "
+ f"{format_bytes(decode_available_memory)} for decode "
+ )
+ log_master(logger.info, msg)
+ start_time = time.time()
+ warmup_shape_count = 0
+ warmup_times = 3
+
+ # Only warm up decode; for prefill, the image pixel size may change, which would make the warmup useless
+ def ordering_function_max_bs(b):
+ return (-b[0], b[1])
+
+ self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
+ buckets = list(
+ sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
+ )
+ total_batch_seq = 0.001
+ total_mem = 0
+ available_mem = decode_available_memory
+ log_master(
+ logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
+ )
+ for i, (batch_size, block_num) in enumerate(buckets):
+ if batch_size > block_num:
+ continue
+ # Graph memory usage is proportional to seq dimension in a batch
+ batch_seq = batch_size
+ mem_estimate = batch_seq / total_batch_seq * total_mem
+ graphed_bucket = (batch_size, block_num, False)
+ if not mem_estimate >= available_mem:
+ if graphed_bucket not in self.graphed_buckets:
+ self.graphed_buckets.add(graphed_bucket)
+ warmup_shape_count += 1
+ self.log_warmup(False, i, len(buckets), batch_size, block_num)
+ with HabanaMemoryProfiler() as mem_prof:
+ for index in range(warmup_times):
+ self.warmup_decode(batch_size, block_num, batch)
+ synchronize(self.device)
+ used_mem = self.align_workers(
+ mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
+ )
+ if graphed_bucket in self.graphed_buckets:
+ available_mem -= used_mem
+ total_mem += used_mem
+ total_batch_seq += batch_seq
+
+ log_master(logger.info, "Decode warmup successful.\n")
+
+ log_master(
+ logger.info,
+ f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
+ )
+
+ def get_vision_embeds(
+ self,
+ pixel_values: torch.Tensor,
+ pixel_attention_mask: torch.Tensor,
+ image_sizes: torch.Tensor,
+ image_grid_thw: torch.Tensor,
+ ):
+ embeds = self.model.get_vision_embeds(
+ pixel_values=pixel_values,
+ pixel_attention_mask=pixel_attention_mask,
+ image_sizes=image_sizes,
+ image_grid_thw=image_grid_thw,
+ )
+ return embeds
+
+ def get_inputs_embeds(
+ self,
+ input_ids: torch.Tensor,
+ vision_embeds: Optional[torch.Tensor] = None,
+ ):
+ return self.model.get_inputs_embeds(
+ input_ids=input_ids,
+ vision_embeds=vision_embeds,
+ )
+
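+ # Run the vision tower for every pending image and cache the resulting embeddings per request.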
+ def encode_images(self, batch):
+ if batch.pixel_values is not None:
+ device = batch.input_ids.device
+ for request_id, image_id, image_input in batch.pixel_values:
+ pixel_values = image_input["pixel_values"].to(device)
+
+ if "pixel_attention_mask" in image_input:
+ pixel_attention_mask = image_input["pixel_attention_mask"].to(
+ device
+ )
+ else:
+ pixel_attention_mask = None
+
+ if "image_sizes" in image_input:
+ image_sizes = image_input["image_sizes"].to(device)
+ else:
+ image_sizes = None
+
+ if "image_grid_thw" in image_input:
+ image_grid_thw = image_input["image_grid_thw"]
+ else:
+ image_grid_thw = None
+
+ encoder_outputs = self.get_vision_embeds(
+ pixel_values=pixel_values,
+ pixel_attention_mask=pixel_attention_mask,
+ image_sizes=image_sizes,
+ image_grid_thw=image_grid_thw,
+ )
+ batch.update_encoder_cache(
+ encoder_outputs,
+ request_id,
+ batch.image_positions[request_id][image_id],
+ )
+
+ batch.pixel_values = None
+ batch.pixel_attention_mask = None
+ batch.image_sizes = None
+
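+ # Encode any pending images, then merge text and vision embeddings into batch.inputs_embeds.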
+ def set_inputs_embeds(self, batch):
+ if batch.has_image_inputs:
+ self.encode_images(batch)
+ vision_embeds = batch.gather_vision_embeds()
+ batch.has_image_inputs = False
+ else:
+ vision_embeds = None
+
+ inputs_embeds = self.get_inputs_embeds(
+ batch.input_ids, vision_embeds=vision_embeds
+ )
+
+ batch.inputs_embeds = inputs_embeds
+
+ def forward(
+ self,
+ batch: FlashVlmCausalLMBatch,
+ adapter_data: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ # Model Forward
+ if batch.speculative_ids is not None:
+ input_ids = batch.input_ids
+ position_ids = batch.position_ids
+ cu_seqlen_prefill = batch.cu_seqlen_prefill
+ kv_cache = self.kv_cache
+ block_tables = batch.block_tables_tensor
+ slots = batch.slots[batch.slot_indices]
+ input_lengths = batch.input_lengths_tensor
+ max_s = batch.max_current_length
+ lm_head_indices = batch.prefill_head_indices
+
+ speculative_ids = batch.speculative_ids
+
+ B, speculative_length = speculative_ids.shape
+ new_length = speculative_length + 1
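+ # Expand every sequence with its speculative tokens so they can be scored in a single forward pass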
+ new_input_ids = torch.cat(
+ [input_ids.unsqueeze(-1), speculative_ids], dim=1
+ ).reshape(-1)
+ arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+ arange_int = arange.to(dtype=torch.int32)
+ new_position_ids = (
+ position_ids.unsqueeze(-1).expand(B, new_length) + arange
+ ).view(-1)
+ slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+ input_lengths = (
+ input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+ ).view(-1)
+
+ # Copy the block tables for all members of the expanded batch
+ block_tables = (
+ block_tables.unsqueeze(1)
+ .expand(B, new_length, -1)
+ .reshape(B * new_length, -1)
+ .contiguous()
+ )
+ max_s = max_s + speculative_length
+
+ input_ids = new_input_ids
+ position_ids = new_position_ids
+ else:
+ input_ids = batch.input_ids
+ inputs_embeds = batch.inputs_embeds
+ position_ids = batch.position_ids
+ cu_seqlen_prefill = batch.cu_seqlen_prefill
+ kv_cache = self.kv_cache
+ block_tables = batch.block_tables_tensor
+ slots = batch.slots[batch.slot_indices]
+ input_lengths = batch.input_lengths_tensor
+ max_s = batch.max_current_length
+ lm_head_indices = batch.prefill_head_indices
+
+ if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
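+ # Qwen2-VL position ids depend on the image grid, so recompute them during prefill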
+ if position_ids.dim() == 1 and batch.prefilling:
+ position_ids = self.model.get_position_ids(
+ input_ids.cpu(), batch.image_grid_thw
+ )
+ batch.position_ids = position_ids
+
+ attention_mask = None
+ attention_mask_forward = None
+ if self.model.config.model_type == "llama4":
+ attention_mask = (input_ids != self.tokenizer.pad_token_id).long()
+ attention_mask_forward = attention_mask.view(input_lengths.shape[0], -1)
+
+ if cu_seqlen_prefill is None and self.max_past() is not None:
+ # In decode, not prefill, we're actually overwriting the KV-cache
+ # in a circular buffer mode.
+ # This makes sure the max_s for the decode pass is correct.
+ max_s = min(self.max_past(), max_s)
+
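+ # Pad the slot mapping so it matches the (padded) length of input_ids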
+ if batch.prefill_cache_indices is not None:
+ slots_pad = torch.zeros_like(input_ids, device=slots.device)
+ slots_pad[batch.prefill_cache_indices] = slots
+ slots = slots_pad
+ else:
+ slots_pad = torch.zeros_like(input_ids, device=slots.device)
+ slots_pad[: slots.shape[0]] = slots
+ slots = slots_pad
+
+ seqlen = Seqlen(
+ input_lengths=_async_h2d_tensor_copy(input_lengths),
+ )
+ kwargs = {}
+ batch_size = input_lengths.shape[0]
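+ # prompt_len is the padded per-request length in prefill, or the block-list length in decode (used for graph-bucket lookup)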
+ prompt_len = (
+ input_ids.shape[0] // batch_size
+ if batch.prefilling
+ else batch.hpu_attn_meta.block_list.shape[0]
+ )
+ if htorch.utils.internal.is_lazy():
+ kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+ batch.prefilling, prompt_len, batch_size
+ )
+ if self.sliding_window is not None:
+ attn_mask = seqlen.make_sliding_window_bias(
+ input_lengths.tolist(),
+ self.sliding_window,
+ self.dtype,
+ prompt_len,
+ batch_size,
+ )
+ seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)
+ logits, speculative_logits = self.model.forward(
+ inputs_embeds=inputs_embeds,
+ position_ids=_async_h2d_tensor_copy(position_ids),
+ cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
+ kv_cache=kv_cache,
+ slots=_async_h2d_tensor_copy(slots),
+ seqlen=trim_seqlen_metadata(seqlen),
+ hpu_attention_meta=batch.hpu_attn_meta,
+ lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
+ attention_mask=attention_mask_forward,
+ **kwargs,
+ )
+ batch.image_grid_thw = None
+ batch.free_encoder_cache()
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/globals.py b/backends/gaudi/server/text_generation_server/models/globals.py
new file mode 100644
index 00000000000..cdde67ca486
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/globals.py
@@ -0,0 +1,52 @@
+import os
+from typing import Dict, Optional
+from loguru import logger
+from text_generation_server.utils.log import log_master
+
+REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
+ATTENTION = os.getenv("ATTENTION", "paged")
+# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
+PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {
+ "1",
+ "true",
+}
+log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
+_expected = {"paged"}
+assert (
+ ATTENTION in _expected
+), f"Invalid attention implementation {ATTENTION}, expected one of {_expected}"
+log_master(logger.info, f"Using Attention = {ATTENTION}")
+
+TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
+assert TGI_WIGGLE_ROOM > 0
+assert TGI_WIGGLE_ROOM < 1
+
+# This is overridden by the cli
+BLOCK_SIZE: int
+
+BLOCK_SIZE = 128
+
+
+# This is overridden at model loading.
+global MODEL_ID
+MODEL_ID = None
+
+
+def set_model_id(model_id: str):
+ global MODEL_ID
+ MODEL_ID = model_id
+
+
+# NOTE: eventually we should move this into the router and pass back the
+# index in all cases.
+ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
+
+
+def set_adapter_to_index(adapter_to_index: Dict[str, int]):
+ global ADAPTER_TO_INDEX
+ ADAPTER_TO_INDEX = adapter_to_index
+
+
+def get_adapter_to_index():
+ global ADAPTER_TO_INDEX
+ return ADAPTER_TO_INDEX
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
new file mode 100644
index 00000000000..d266aad905a
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -0,0 +1,624 @@
+import torch
+
+import numpy as np
+
+from typing import Iterable, Optional, Tuple, List, Dict
+from text_generation_server.pb.generate_pb2 import Request
+from io import BytesIO
+from PIL import Image
+from dataclasses import dataclass
+from opentelemetry import trace
+from transformers import (
+ PreTrainedTokenizerBase,
+)
+from text_generation_server.models.flash_causal_lm import (
+ generate_block_metadata,
+)
+from text_generation_server.models.flash_vlm_causal_lm import (
+ FlashVlmCausalLMBatch,
+ FlashVlmCausalLM,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.layers.attention import (
+ Seqlen,
+ trim_seqlen_metadata,
+ _async_h2d_tensor_copy,
+ HPUPagedAttentionMetadata,
+ trim_attn_metadata,
+)
+import habana_frameworks.torch as htorch
+from loguru import logger
+from text_generation_server.models.globals import BLOCK_SIZE
+from text_generation_server.utils.import_utils import (
+ synchronize,
+)
+import torch.nn.functional as F
+from text_generation_server.utils.log import log_master
+import time
+import os
+from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
+
+tracer = trace.get_tracer(__name__)
+
+
+@dataclass
+class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
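+ # The image_indices default is a placeholder; the actual indices are filled in by from_pb_processor and concatenate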
+ image_indices: List[int] = 42
+ aspect_ratio_ids: Optional[torch.Tensor] = None
+ aspect_ratio_mask: Optional[torch.Tensor] = None
+ cross_attention_states: Optional[torch.Tensor] = None
+
+ def prepare_for_prefill(
+ self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
+ ):
+ super(FlashVlmCausalLMBatch, self).prepare_for_prefill(
+ max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
+ )
+
+ @classmethod
+ @tracer.start_as_current_span("concatenate")
+ def concatenate(cls, batches, padded_total_bs: int = 0):
+ batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs)
+ batch.pixel_values = None
+ batch.pixel_attention_mask = None
+
+ offset = 0
+ image_indices = []
+ attention_states = []
+ for b in batches:
+ if b.cross_attention_states is not None:
+ attention_states.append(b.cross_attention_states)
+ image_indices.extend([i + offset for i in b.image_indices])
+ offset += len(b.image_indices)
+ if len(attention_states) > 0:
+ assert len(image_indices) > 0
+ batch.cross_attention_states = torch.cat(attention_states, dim=0)
+ batch.image_indices = image_indices
+ else:
+ batch.cross_attention_states = None
+ batch.image_indices = []
+ return batch
+
+ @tracer.start_as_current_span("filter")
+ def filter(self, request_ids: List[int]):
+ assert self.image_indices is not None
+ batch = super(FlashVlmCausalLMBatch, self).filter(request_ids)
+ assert self.image_indices is not None
+ indices = []
+ for i, request_id in enumerate(request_ids):
+ idx = self.requests_idx_mapping[request_id]
+ indices.append(idx)
+
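+ # Keep only the images whose request survived the filter and slice their cross-attention states accordingly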
+ offset = 0
+ new_image_indices = []
+ prev_i = None
+ for i in self.image_indices:
+ if i in indices:
+ new_image_indices.append(offset)
+ if i != prev_i:
+ offset += 1
+ prev_i = i
+
+ batch.image_indices = new_image_indices
+ if len(new_image_indices) > 0:
+ assert max(new_image_indices) < self.cross_attention_states.shape[0]
+ assert offset <= self.cross_attention_states.shape[0]
+ batch.cross_attention_states = self.cross_attention_states[
+ new_image_indices
+ ]
+ else:
+ batch.cross_attention_states = None
+ batch.pixel_values = None
+ return batch
+
+ @classmethod
+ def batch_tokenized_inputs(
+ cls, requests: Iterable[Request], tokenizer, processor, config
+ ):
+ image_inputs = []
+ texts = []
+ image_indices = []
+ batch_tokenized_inputs = []
+
+ for i, r in enumerate(requests):
+ # Each request's input is a list of chunks, where each chunk is either text or an image
+ curr_text = ""
+ curr_image = None
+ curr_i = None
+ for chunk in r.input_chunks.chunks:
+ chunk_type = chunk.WhichOneof("chunk")
+ if chunk_type == "text":
+ curr_text += chunk.text
+ elif chunk_type == "image":
+ image = Image.open(BytesIO(chunk.image.data))
+ # TODO unsure about BOS
+ curr_text += "<|image|>"
+ image_input = processor.image_processor(image, return_tensors="pt")
+ curr_image = image_input
+ curr_i = i
+ # image_inputs.append(image_input)
+ # image_indices.append(i)
+ else:
+ raise RuntimeError(f"Invalid chunk type {chunk_type}")
+ texts.append(curr_text)
+ if curr_image is not None:
+ image_inputs.append(curr_image)
+ image_indices.append(curr_i)
+
+ input_ids = tokenizer(
+ curr_text,
+ truncation=True,
+ max_length=r.truncate,
+ add_special_tokens=r.add_special_tokens,
+ )["input_ids"]
+ batch_tokenized_inputs.append(input_ids)
+ if image_inputs:
+ image_input = image_inputs[0]
+ new_image_inputs = {
+ "pixel_values": torch.cat(
+ [img["pixel_values"] for img in image_inputs], dim=0
+ ),
+ }
+ if "aspect_ratio_ids" in image_input:
+ new_image_inputs["aspect_ratio_ids"] = torch.cat(
+ [img["aspect_ratio_ids"] for img in image_inputs], dim=0
+ )
+ if "aspect_ratio_mask" in image_input:
+ new_image_inputs["aspect_ratio_mask"] = torch.cat(
+ [img["aspect_ratio_mask"] for img in image_inputs], dim=0
+ )
+ image_inputs = new_image_inputs
+ image_inputs["image_indices"] = image_indices
+ else:
+ image_inputs = None
+
+ if image_inputs is not None:
+ assert len(image_indices) == image_inputs["pixel_values"].shape[0]
+
+ return batch_tokenized_inputs, image_inputs
+
+ @classmethod
+ def from_pb_processor(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ processor,
+ config,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "FlashVlmCausalLMBatch":
+ batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
+ pb.requests, tokenizer, processor, config
+ )
+ batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+ # XXX: <|image|> token is actually out of bounds and bugs out the logit processors.
+ batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
+ max=config.text_config.vocab_size - 1
+ )
+ if isinstance(batch.input_ids, list):
+ if len(batch) > 1:
+ input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
+ else:
+ input_ids = batch.input_ids[0]
+ batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+
+ batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
+
+ if image_inputs is not None:
+ batch.pixel_values = image_inputs["pixel_values"].to(
+ device=device, dtype=dtype
+ )
+ batch.aspect_ratio_ids = image_inputs["aspect_ratio_ids"].to(device=device)
+ batch.aspect_ratio_mask = image_inputs["aspect_ratio_mask"].to(
+ device=device
+ )
+ batch.image_indices = image_inputs["image_indices"]
+ else:
+ batch.pixel_values = None
+ batch.aspect_ratio_ids = None
+ batch.aspect_ratio_mask = None
+ batch.image_indices = []
+ assert batch.image_indices is not None
+ return batch
+
+
+def generate_cross_attention_states(
+ cross_attention_states, image_indices, input_lengths, pad_seq_len, prefilling
+):
+ if cross_attention_states is None:
+ return None, None
+ indices_list = []
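+ # In prefill, each request occupies pad_seq_len consecutive token rows; in decode, the image indices are used directly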
+ if prefilling:
+ for i in image_indices:
+ indices_list.append(torch.arange(pad_seq_len * i, pad_seq_len * (i + 1)))
+ indices = torch.cat(indices_list, dim=0)
+ else:
+ indices = image_indices[:]
+ return indices, input_lengths.index_select(0, image_indices)
+
+
+class FlashMllamaCausalLM(FlashVlmCausalLM):
+ def set_inputs_embeds(self, batch):
+ # Set the input embeddings to None, as we are using the input_ids for the model
+ batch.inputs_embeds = None
+
+ def warmup_decode(
+ self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
+ ):
+ input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)
+ position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
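+ # Spread block_num blocks as evenly as possible across the batch; the first sequence absorbs the remainder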
+ blocks = [block_num // batch_size for _ in range(batch_size)]
+ blocks[0] += block_num % batch_size
+ block_tables = []
+ slots = []
+ start_idx = 0
+ slot_indices = []
+
+ # point each sequence at the last slot of its final block so the warmup covers block_num blocks in total
+ for i in range(batch_size):
+ block_array = list(range(start_idx, start_idx + blocks[i]))
+ slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
+ block_tables.append(block_array)
+ slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1)
+ start_idx += blocks[i]
+ input_lengths = torch.ones(batch_size, dtype=torch.int32)
+
+ seqlen = Seqlen(
+ input_lengths=_async_h2d_tensor_copy(input_lengths),
+ )
+ block_list, block_groups, block_usage, _, block_bucket_size = (
+ generate_block_metadata(
+ self.dtype,
+ self.use_contiguous_pa,
+ slots,
+ block_tables,
+ self.bucketing_ctx,
+ )
+ )
+ meta = HPUPagedAttentionMetadata(
+ block_list=_async_h2d_tensor_copy(block_list),
+ block_groups=_async_h2d_tensor_copy(block_groups),
+ block_usage=_async_h2d_tensor_copy(block_usage),
+ block_mapping=None,
+ attn_bias=None,
+ )
+
+ hpu_attention_meta = trim_attn_metadata(meta)
+ # Tile the batch's image indices and cross-attention states to the warmup batch size
+ image_indices = torch.tensor(batch.image_indices)
+ image_indices = image_indices.repeat(batch_size)
+ cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
+ indices, cross_attention_len = generate_cross_attention_states(
+ cross_attention_states, image_indices, input_lengths, 1, False
+ )
+ slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
+ kwargs = {}
+ if htorch.utils.internal.is_lazy():
+ kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+ False, hpu_attention_meta.block_list.shape[0], batch_size
+ )
+ self.model.forward(
+ input_ids=_async_h2d_tensor_copy(input_ids),
+ position_ids=_async_h2d_tensor_copy(position_ids),
+ cu_seqlen_prefill=None,
+ kv_cache=self.kv_cache,
+ slots=_async_h2d_tensor_copy(slots_tensor),
+ seqlen=trim_seqlen_metadata(seqlen),
+ hpu_attention_meta=hpu_attention_meta,
+ lm_head_indices=None,
+ adapter_data=None,
+ cross_attention_states=cross_attention_states,
+ indices=_async_h2d_tensor_copy(indices),
+ cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
+ **kwargs,
+ )
+
+ def warmup_prefill(
+ self, prompt_len: int, batch_size: int, batch: FlashMllamaCausalLMBatch
+ ):
+ input_ids = torch.zeros(prompt_len, dtype=batch.input_ids.dtype).repeat(
+ batch_size
+ )
+ position_ids = torch.arange(prompt_len, dtype=batch.position_ids.dtype).repeat(
+ batch_size
+ )
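+ # max_bt: number of block-table entries needed to cover prompt_len tokens for every sequence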
+ max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
+ block_tables = torch.arange(max_bt, dtype=torch.int32).reshape(batch_size, -1)
+ slot_acc = []
+ for i in range(batch_size):
+ slots = []
+ for b in block_tables[i]:
+ slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
+ slot_acc.extend(slots[:prompt_len])
+ slots = torch.tensor(slot_acc, dtype=batch.slots.dtype)
+
+ input_lengths = (
+ torch.ones(
+ batch_size,
+ dtype=torch.int32,
+ )
+ * prompt_len
+ )
+ cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)
+ torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+
+ lm_head_indices = input_lengths - 1
+
+ # Tile the batch's image indices and cross-attention states to the warmup batch size
+ image_indices = torch.tensor(batch.image_indices)
+ image_indices = image_indices.repeat(batch_size)
+ cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
+ indices, cross_attention_len = generate_cross_attention_states(
+ cross_attention_states, image_indices, input_lengths, prompt_len, True
+ )
+ seqlen = Seqlen(
+ input_lengths=_async_h2d_tensor_copy(input_lengths),
+ )
+ kwargs = {}
+ if htorch.utils.internal.is_lazy():
+ kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+ True, prompt_len, batch_size
+ )
+ self.model.forward(
+ input_ids=_async_h2d_tensor_copy(input_ids),
+ position_ids=_async_h2d_tensor_copy(position_ids),
+ cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
+ kv_cache=self.kv_cache,
+ slots=_async_h2d_tensor_copy(slots),
+ seqlen=trim_seqlen_metadata(seqlen),
+ hpu_attention_meta=None,
+ lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
+ adapter_data=None,
+ cross_attention_states=cross_attention_states,
+ indices=_async_h2d_tensor_copy(indices),
+ cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
+ **kwargs,
+ )
+
+ def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
+ prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
+ free_mem = HabanaMemoryProfiler.current_free_device_memory()
+ graph_free_mem = free_mem - self.mem_reserved
+ graph_free_mem = self.align_workers(
+ graph_free_mem, torch.distributed.ReduceOp.MIN
+ )
+ prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
+ decode_available_memory = graph_free_mem - prompt_available_memory
+ msg = (
+ f"Using {format_bytes(graph_free_mem)}"
+ f"/{format_bytes(free_mem)} "
+ "of free device memory for HPUGraphs, "
+ f"{format_bytes(prompt_available_memory)} for prompt and "
+ f"{format_bytes(decode_available_memory)} for decode "
+ f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
+ )
+ log_master(logger.info, msg)
+ start_time = time.time()
+ warmup_shape_count = 0
+ warmup_times = 3
+ self.bucketing_ctx.generate_prompt_buckets()
+
+ def ordering_function_min_tokens(b):
+ return (b[0] * b[1], b[1], b[0])
+
+ buckets = list(
+ sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
+ )
+ total_batch_seq = 0.001
+ total_mem = 0
+ available_mem = prompt_available_memory
+ msg = (
+ f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
+ f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
+ )
+ log_master(logger.info, msg)
+ for i, (batch_size, seq_len) in enumerate(buckets):
+ if batch_size * seq_len > self.max_batch_prefill_tokens:
+ continue
+ # Graph memory usage is proportional to seq dimension in a batch
+ batch_seq = batch_size * seq_len
+ mem_estimate = batch_seq / total_batch_seq * total_mem
+ graphed_bucket = (batch_size, seq_len, True)
+ if not (
+ mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
+ ):
+ if graphed_bucket not in self.graphed_buckets:
+ self.graphed_buckets.add(graphed_bucket)
+ warmup_shape_count += 1
+ self.log_warmup(True, i, len(buckets), batch_size, seq_len)
+ with HabanaMemoryProfiler() as mem_prof:
+ for index in range(warmup_times):
+ self.warmup_prefill(seq_len, batch_size, batch)
+ synchronize(self.device)
+ used_mem = self.align_workers(
+ mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
+ )
+ if graphed_bucket in self.graphed_buckets:
+ available_mem -= used_mem
+ total_mem += used_mem
+ total_batch_seq += batch_seq
+
+ log_master(logger.info, "Prefill warmup successful.\n")
+
+ def ordering_function_max_bs(b):
+ return (-b[0], b[1])
+
+ self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
+ buckets = list(
+ sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
+ )
+ free_mem = HabanaMemoryProfiler.current_free_device_memory()
+ total_batch_seq = 0.001
+ total_mem = 0
+ available_mem = free_mem - self.mem_reserved
+ log_master(
+ logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
+ )
+ for i, (batch_size, block_num) in enumerate(buckets):
+ if batch_size > block_num:
+ continue
+ # Graph memory usage is proportional to seq dimension in a batch
+ batch_seq = batch_size
+ mem_estimate = batch_seq / total_batch_seq * total_mem
+ graphed_bucket = (batch_size, block_num, False)
+ if not mem_estimate >= available_mem:
+ if graphed_bucket not in self.graphed_buckets:
+ self.graphed_buckets.add(graphed_bucket)
+ warmup_shape_count += 1
+ self.log_warmup(False, i, len(buckets), batch_size, block_num)
+ with HabanaMemoryProfiler() as mem_prof:
+ for index in range(warmup_times):
+ self.warmup_decode(batch_size, block_num, batch)
+ synchronize(self.device)
+ used_mem = self.align_workers(
+ mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
+ )
+ if graphed_bucket in self.graphed_buckets:
+ available_mem -= used_mem
+ total_mem += used_mem
+ total_batch_seq += batch_seq
+
+ log_master(logger.info, "Decode warmup successful.\n")
+
+ log_master(
+ logger.info,
+ f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
+ )
+
+ def forward(
+ self,
+ batch: FlashMllamaCausalLMBatch,
+ adapter_data: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ # Model Forward
+ if batch.speculative_ids is not None:
+ input_ids = batch.input_ids
+ position_ids = batch.position_ids
+ cu_seqlen_prefill = batch.cu_seqlen_prefill
+ kv_cache = self.kv_cache
+ block_tables = batch.block_tables_tensor
+ slots = batch.slots[batch.slot_indices]
+ input_lengths = batch.input_lengths_tensor
+ max_s = batch.max_current_length
+ lm_head_indices = batch.prefill_head_indices
+
+ speculative_ids = batch.speculative_ids
+
+ B, speculative_length = speculative_ids.shape
+ new_length = speculative_length + 1
+ new_input_ids = torch.cat(
+ [input_ids.unsqueeze(-1), speculative_ids], dim=1
+ ).reshape(-1)
+ arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+ arange_int = arange.to(dtype=torch.int32)
+ new_position_ids = (
+ position_ids.unsqueeze(-1).expand(B, new_length) + arange
+ ).view(-1)
+ slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+ input_lengths = (
+ input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+ ).view(-1)
+
+ # Copy the block tables for all members of the expanded batch
+ block_tables = (
+ block_tables.unsqueeze(1)
+ .expand(B, new_length, -1)
+ .reshape(B * new_length, -1)
+ .contiguous()
+ )
+ max_s = max_s + speculative_length
+
+ input_ids = new_input_ids
+ position_ids = new_position_ids
+ else:
+ input_ids = batch.input_ids
+ position_ids = batch.position_ids
+ cu_seqlen_prefill = batch.cu_seqlen_prefill
+ kv_cache = self.kv_cache
+ block_tables = batch.block_tables_tensor
+ slots = batch.slots[batch.slot_indices]
+ input_lengths = batch.input_lengths_tensor
+ max_s = batch.max_current_length
+ lm_head_indices = batch.prefill_head_indices
+
+ if cu_seqlen_prefill is None and self.max_past() is not None:
+ # In decode, not prefill, we're actually overwriting the KV-cache
+ # in a circular buffer mode.
+ # This makes sure the max_s for the decode pass is correct.
+ max_s = min(self.max_past(), max_s)
+
+ if batch.pixel_values is not None:
+ cross_attention_states = self.model.vision_forward(
+ pixel_values=batch.pixel_values,
+ aspect_ratio_ids=batch.aspect_ratio_ids,
+ aspect_ratio_mask=batch.aspect_ratio_mask,
+ )
+ batch.cross_attention_states = cross_attention_states
+
+ cross_attention_states = batch.cross_attention_states
+
+ kwargs = {}
+ if htorch.utils.internal.is_lazy():
+ batch_size = input_lengths.shape[0]
+ seqlen = (
+ input_ids.shape[0] // batch_size
+ if batch.prefilling
+ else batch.hpu_attn_meta.block_list.shape[0]
+ )
+ kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+ batch.prefilling, seqlen, batch_size
+ )
+
+ if batch.prefill_cache_indices is not None:
+ slots_pad = torch.zeros_like(input_ids, device=slots.device)
+ slots_pad[batch.prefill_cache_indices] = slots
+ slots = slots_pad
+ else:
+ slots_pad = torch.zeros_like(input_ids, device=slots.device)
+ slots_pad[: slots.shape[0]] = slots
+ slots = slots_pad
+ orig_bs = len(batch)
+ padded_bs = batch.input_lengths_tensor.shape[0]
+ padded_input_len = input_ids.view(padded_bs, -1).shape[-1]
+ image_indices = torch.tensor(batch.image_indices)
+
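+ # Pad the cross-attention states and image indices up to the padded batch size used by the HPU graph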
+ if cross_attention_states is not None:
+ cross_attention_states = F.pad(
+ cross_attention_states,
+ (0, 0, 0, 0, 0, (padded_bs - orig_bs)),
+ value=0,
+ )
+ if len(image_indices) != 0:
+ pad_indices = torch.arange(orig_bs, padded_bs)
+ image_indices = torch.cat((image_indices, pad_indices), dim=0)
+
+ indices, cross_attention_len = generate_cross_attention_states(
+ cross_attention_states,
+ image_indices,
+ input_lengths,
+ padded_input_len,
+ batch.prefilling,
+ )
+ seqlen = Seqlen(
+ input_lengths=_async_h2d_tensor_copy(input_lengths),
+ )
+ logits, speculative_logits = self.model.forward(
+ input_ids=input_ids,
+ position_ids=_async_h2d_tensor_copy(position_ids),
+ cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
+ kv_cache=kv_cache,
+ slots=_async_h2d_tensor_copy(slots),
+ seqlen=trim_seqlen_metadata(seqlen),
+ hpu_attention_meta=batch.hpu_attn_meta,
+ lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
+ # TODO list
+ adapter_data=None,
+ cross_attention_states=cross_attention_states,
+ indices=_async_h2d_tensor_copy(indices),
+ cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
+ **kwargs,
+ )
+ if batch.pixel_values is not None:
+ batch.pixel_values = None
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/model.py b/backends/gaudi/server/text_generation_server/models/model.py
new file mode 100644
index 00000000000..b936bb69039
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/model.py
@@ -0,0 +1,142 @@
+import inspect
+import torch
+
+from abc import ABC, abstractmethod
+from typing import List, Tuple, Optional, TypeVar, Type, Dict
+from collections import defaultdict
+from transformers import PreTrainedTokenizerBase
+
+from text_generation_server.models.types import Batch, Generation
+from text_generation_server.models.globals import BLOCK_SIZE
+from text_generation_server.utils.speculate import get_speculate
+from text_generation_server.pb.generate_pb2 import InfoResponse
+from text_generation_server.adapters.weights import LayerAdapterWeights
+from text_generation_server.pb import generate_pb2
+
+BASE_MODEL_ADAPTER_ID = "__base_model__"
+
+
+B = TypeVar("B", bound=Batch)
+
+
+class Model(ABC):
+ def __init__(
+ self,
+ model_id: str,
+ model: torch.nn.Module,
+ tokenizer: PreTrainedTokenizerBase,
+ requires_padding: bool,
+ dtype: torch.dtype,
+ device: torch.device,
+ rank: int = 0,
+ world_size: int = 1,
+ sliding_window: Optional[int] = None,
+ speculate: Optional[int] = None,
+ adapter_id: str = BASE_MODEL_ADAPTER_ID,
+ support_chunking: bool = False,
+ ):
+ self.model_id = model_id
+ self.model = model.eval()
+ self.tokenizer = tokenizer
+
+ # all_special_ids is not set correctly if the rust tokenizer is unpacked
+ # TODO report this to transformers.
+ other_special_ids = {
+ id for id, token in tokenizer.added_tokens_decoder.items() if token.special
+ }
+ self.all_special_ids = set(tokenizer.all_special_ids)
+ self.all_special_ids.update(other_special_ids)
+ self.requires_padding = requires_padding
+ self.dtype = dtype
+ self.device = device
+ self.rank = rank
+ self.world_size = world_size
+ self.sliding_window = sliding_window if sliding_window != -1 else None
+
+ self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
+ LayerAdapterWeights
+ )
+ self.loaded_adapters = set()
+ self.static_adapter_id = adapter_id
+
+ if speculate is None:
+ speculate = get_speculate()
+ self.speculate = speculate
+
+ self.has_position_ids = (
+ inspect.signature(model.forward).parameters.get("position_ids", None)
+ is not None
+ )
+
+ self.check_initialized()
+
+ @property
+ def info(self) -> InfoResponse:
+ if self.requires_padding and self.sliding_window is not None:
+ raise NotImplementedError("sliding_window is not implemented with padding")
+
+ return InfoResponse(
+ requires_padding=self.requires_padding,
+ dtype=str(self.dtype),
+ device_type=self.device.type,
+ window_size=None,
+ speculate=self.speculate,
+ block_size=BLOCK_SIZE,
+ )
+
+ @property
+ @abstractmethod
+ def batch_type(self) -> Type[B]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def generate_token(
+ self, batch: B
+ ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
+ raise NotImplementedError
+
+ def warmup(
+ self, batch: generate_pb2.WarmupRequest
+ ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
+ self.generate_token(batch)
+ return None, None, None
+
+ def decode_token(
+ self,
+ all_input_ids: List[int],
+ prefix_offset: int = 0,
+ read_offset: int = 0,
+ skip_special_tokens: bool = False,
+ ) -> Tuple[str, int, int]:
+ """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
+
+ # The prefix text is necessary only to defeat cleanup algorithms in the decode
+ # which decide to add a space or not depending on the surrounding ids.
+ prefix_text = self.tokenizer.decode(
+ all_input_ids[prefix_offset:read_offset],
+ skip_special_tokens=skip_special_tokens,
+ )
+
+ new_text = self.tokenizer.decode(
+ all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
+ )
+
+ if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+ # utf-8 char at the end means it's a potential unfinished byte sequence
+ # from byte fallback tokenization.
+ # If it's in the middle, it's probably a real invalid id generated
+ # by the model
+ new_text = new_text[len(prefix_text) :]
+ return new_text, read_offset, len(all_input_ids)
+ else:
+ return "", prefix_offset, read_offset
+
+ def check_initialized(self):
+ uninitialized_parameters = []
+ for n, p in self.model.named_parameters():
+ if p.data.device == torch.device("meta"):
+ uninitialized_parameters.append(n)
+ if uninitialized_parameters:
+ raise RuntimeError(
+ f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
+ )
diff --git a/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py b/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py
new file mode 100644
index 00000000000..0ee6ed16744
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py
@@ -0,0 +1,920 @@
+import torch
+import torch.distributed
+import time
+from dataclasses import dataclass
+from opentelemetry import trace
+from transformers import (
+ AutoTokenizer,
+ AutoModelForSeq2SeqLM,
+ PreTrainedTokenizerBase,
+ AutoConfig,
+)
+from typing import Optional, Tuple, List, Type, Dict
+from text_generation_server.utils import (
+ initialize_torch_distributed,
+ weight_files,
+ Weights,
+)
+from text_generation_server.utils.chunks import concat_text_chunks
+from text_generation_server.utils.quantization import get_loader
+from text_generation_server.utils.tokens import batch_top_tokens
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+ GeneratedText,
+ Batch,
+ Generation,
+ Tokens,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+
+tracer = trace.get_tracer(__name__)
+
+
+@dataclass
+class Seq2SeqLMBatch(Batch):
+ batch_id: int
+ requests: List[generate_pb2.Request]
+ requests_idx_mapping: Dict[int, int]
+
+ # Encoder values
+ input_ids: Optional[torch.Tensor]
+ attention_mask: torch.Tensor
+
+ # Decoder values
+ decoder_input_ids: torch.Tensor
+ decoder_attention_mask: Optional[torch.Tensor]
+ encoder_last_hidden_state: Optional[torch.Tensor]
+
+ # All tokens
+ all_decoder_input_ids: List[torch.Tensor]
+
+ # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
+ past_key_values: Optional[List[Tuple]]
+
+ # Lengths of all generations present in the batch
+ input_lengths: List[int]
+ decoder_input_lengths: List[int]
+ prefix_offsets: List[int]
+ read_offsets: List[int]
+
+ # Generation helpers
+ next_token_choosers: List[NextTokenChooser]
+ stopping_criterias: List[StoppingCriteria]
+ top_n_tokens: List[int]
+ top_n_tokens_tensor: torch.Tensor
+
+ # Metadata used for padding
+ max_input_length: int
+ max_decoder_input_length: int
+ padding_right_offset: int
+
+ # Maximum number of tokens this batch will grow to
+ max_tokens: int
+
+ def to_pb(self) -> generate_pb2.CachedBatch:
+ """Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf"""
+ return generate_pb2.CachedBatch(
+ id=self.batch_id,
+ request_ids=[r.id for r in self.requests],
+ size=len(self),
+ max_tokens=self.max_tokens,
+ )
+
+ @classmethod
+ def from_pb(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "Seq2SeqLMBatch":
+ """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
+ inputs = []
+ next_token_choosers = []
+ stopping_criterias = []
+ top_n_tokens = []
+ decoder_input_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+ requests_idx_mapping = {}
+
+ # Parse batch
+ max_truncation = 0
+ padding_right_offset = 0
+ max_decode_tokens = 0
+ for i, r in enumerate(pb.requests):
+ inputs.append(concat_text_chunks(r.input_chunks.chunks))
+ requests_idx_mapping[r.id] = i
+ decoder_input_lengths.append(1)
+ next_token_choosers.append(
+ NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+ )
+ stopping_criteria = StoppingCriteria.from_pb(
+ r.stopping_parameters, tokenizer
+ )
+ stopping_criterias.append(stopping_criteria)
+ top_n_tokens.append(r.top_n_tokens)
+ max_truncation = max(max_truncation, r.truncate)
+ max_decode_tokens += stopping_criteria.max_new_tokens
+ padding_right_offset = max(
+ padding_right_offset, stopping_criteria.max_new_tokens
+ )
+
+ # Tokenize batch
+ tokenized_inputs = tokenizer(
+ inputs,
+ return_tensors="pt",
+ padding=True,
+ return_token_type_ids=False,
+ truncation=True,
+ max_length=max_truncation,
+ ).to(device)
+
+ input_lengths = tokenized_inputs["attention_mask"].sum(1)
+ max_input_length = input_lengths.max()
+
+ # Decoder sequence only contains the bos_token
+ decoder_input_ids = (
+ torch.tensor(tokenizer.bos_token_id, device=device)
+ .repeat(len(pb.requests))
+ .view(-1, 1)
+ )
+ for _ in pb.requests:
+ prefix_offsets.append(0)
+ read_offsets.append(1)
+ all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
+ top_n_tokens_tensor = torch.tensor(
+ top_n_tokens, device=device, dtype=torch.int64
+ )
+
+ max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
+
+ return cls(
+ batch_id=pb.id,
+ requests=pb.requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=tokenized_inputs["input_ids"],
+ attention_mask=tokenized_inputs["attention_mask"],
+ decoder_input_ids=decoder_input_ids,
+ all_decoder_input_ids=list(all_decoder_input_ids),
+ decoder_attention_mask=None,
+ encoder_last_hidden_state=None,
+ past_key_values=None,
+ input_lengths=input_lengths.tolist(),
+ decoder_input_lengths=decoder_input_lengths,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
+ next_token_choosers=next_token_choosers,
+ stopping_criterias=stopping_criterias,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ max_input_length=max_input_length.item(),
+ max_decoder_input_length=1,
+ padding_right_offset=padding_right_offset,
+ max_tokens=max_tokens,
+ )
+
+ @tracer.start_as_current_span("filter")
+ def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
+ if len(request_ids) == 0:
+ raise ValueError("Batch must have at least one request")
+ if len(request_ids) == len(self):
+ return self
+
+ keep_indices = []
+
+ # New values after filtering
+ requests_idx_mapping = {}
+ requests = []
+ input_lengths = []
+ decoder_input_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+
+ all_decoder_input_ids = []
+
+ next_token_choosers = []
+ stopping_criterias = []
+ top_n_tokens = []
+
+ max_input_length = 0
+ max_decoder_input_length = 0
+ padding_right_offset = 0
+
+ total_remaining_decode_tokens = 0
+
+ for i, request_id in enumerate(request_ids):
+ idx = self.requests_idx_mapping[request_id]
+ requests_idx_mapping[request_id] = i
+ keep_indices.append(idx)
+
+ requests.append(self.requests[idx])
+ prefix_offsets.append(self.prefix_offsets[idx])
+ read_offsets.append(self.read_offsets[idx])
+
+ all_decoder_input_ids.append(self.all_decoder_input_ids[idx])
+
+ request_input_length = self.input_lengths[idx]
+ input_lengths.append(request_input_length)
+ max_input_length = max(max_input_length, request_input_length)
+
+ request_decoder_input_length = self.decoder_input_lengths[idx]
+ decoder_input_lengths.append(request_decoder_input_length)
+ max_decoder_input_length = max(
+ max_decoder_input_length, request_decoder_input_length
+ )
+
+ next_token_choosers.append(self.next_token_choosers[idx])
+ stopping_criteria = self.stopping_criterias[idx]
+ stopping_criterias.append(stopping_criteria)
+ top_n_tokens.append(self.top_n_tokens[idx])
+ remaining_decode_tokens = (
+ stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+ )
+ total_remaining_decode_tokens += remaining_decode_tokens
+ padding_right_offset = max(padding_right_offset, remaining_decode_tokens)
+
+ # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
+ self.decoder_input_ids = self.decoder_input_ids[keep_indices]
+ self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
+ if self.decoder_attention_mask is not None:
+ self.decoder_attention_mask = self.decoder_attention_mask[
+ keep_indices,
+ -(self.padding_right_offset + max_decoder_input_length) : (
+ self.decoder_attention_mask.shape[1] - self.padding_right_offset
+ )
+ + padding_right_offset,
+ ]
+
+ self.encoder_last_hidden_state = self.encoder_last_hidden_state[
+ keep_indices, -max_input_length:
+ ]
+
+ # Ensure that past_key_values tensors can be updated in-place
+ if type(self.past_key_values[0]) is tuple:
+ self.past_key_values = [
+ [t for t in layer] for layer in self.past_key_values
+ ]
+
+ decoder_past_seq_len = max_decoder_input_length - 1
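+ # layer[0:2] hold the decoder self-attention past, layer[2:4] the encoder cross-attention past; slice each to the kept requests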
+ for layer in self.past_key_values:
+ layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
+ layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
+ layer[2] = layer[2][keep_indices, :, -max_input_length:]
+ layer[3] = layer[3][keep_indices, :, -max_input_length:]
+
+ top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
+ max_tokens = (
+ len(request_ids) * (max_input_length + max_decoder_input_length)
+ + total_remaining_decode_tokens
+ )
+
+ self.requests = requests
+ self.requests_idx_mapping = requests_idx_mapping
+ self.input_ids = None
+ self.all_decoder_input_ids = all_decoder_input_ids
+ self.input_lengths = input_lengths
+ self.decoder_input_lengths = decoder_input_lengths
+ self.prefix_offsets = prefix_offsets
+ self.read_offsets = read_offsets
+ self.next_token_choosers = next_token_choosers
+ self.stopping_criterias = stopping_criterias
+ self.top_n_tokens = top_n_tokens
+ self.top_n_tokens_tensor = top_n_tokens_tensor
+ self.max_input_length = max_input_length
+ self.max_decoder_input_length = max_decoder_input_length
+ self.padding_right_offset = padding_right_offset
+ self.max_tokens = max_tokens
+
+ return self
+
+ @classmethod
+ @tracer.start_as_current_span("concatenate")
+ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
+ """Concatenate multiple batches together by padding internal torch tensors"""
+
+ # Used for padding
+ total_batch_size = 0
+ max_input_length = 0
+ max_decoder_input_length = 0
+ padding_right_offset = 0
+ for batch in batches:
+ total_batch_size += len(batch)
+ max_input_length = max(max_input_length, batch.max_input_length)
+ max_decoder_input_length = max(
+ max_decoder_input_length, batch.max_decoder_input_length
+ )
+ padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
+
+ # Batch attributes
+ requests = []
+ requests_idx_mapping = {}
+ all_decoder_input_ids = []
+ input_lengths = []
+ decoder_input_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+ next_token_choosers = []
+ stopping_criterias = []
+ top_n_tokens = []
+ max_tokens = 0
+
+ # Batch tensors
+ attention_mask = None
+ decoder_input_ids = None
+ decoder_attention_mask = None
+ encoder_last_hidden_state = None
+ top_n_tokens_tensor = None
+ past_key_values = []
+
+ # Used for slicing correctly inside the tensors
+ # Equivalent to a cumsum on batch sizes
+ start_index = 0
+
+ for i, batch in enumerate(batches):
+ # Extend all list attributes
+ requests.extend(batch.requests)
+ all_decoder_input_ids.extend(batch.all_decoder_input_ids)
+ input_lengths.extend(batch.input_lengths)
+ decoder_input_lengths.extend(batch.decoder_input_lengths)
+ prefix_offsets.extend(batch.prefix_offsets)
+ read_offsets.extend(batch.read_offsets)
+ next_token_choosers.extend(batch.next_token_choosers)
+ stopping_criterias.extend(batch.stopping_criterias)
+ top_n_tokens.extend(batch.top_n_tokens)
+
+ if i == 0:
+ requests_idx_mapping = batch.requests_idx_mapping
+ else:
+ # We need to offset the mapping for each batch by the cumulative batch size
+ for k, v in batch.requests_idx_mapping.items():
+ requests_idx_mapping[k] = v + start_index
+
+ # Slicing end index for this batch
+ end_index = start_index + len(batch)
+
+ # We only concatenate batches that did at least one step
+ if batch.encoder_last_hidden_state is None:
+ raise ValueError("Batch encoder_last_hidden_state cannot be None")
+
+ # Create padded tensor
+ if attention_mask is None:
+ attention_mask = batch.attention_mask.new_zeros(
+ (total_batch_size, max_input_length),
+ )
+ # Copy to correct indices
+ attention_mask[start_index:end_index, -batch.max_input_length :] = (
+ batch.attention_mask[:, -batch.max_input_length :]
+ )
+
+ # Create padded tensor
+ if decoder_input_ids is None:
+ decoder_input_ids = batch.decoder_input_ids.new_zeros(
+ (total_batch_size, 1),
+ )
+ # Copy to correct indices
+ decoder_input_ids[start_index:end_index] = batch.decoder_input_ids
+
+ # Create padded tensor
+ if decoder_attention_mask is None:
+ # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
+ decoder_attention_mask = batch.attention_mask.new_zeros(
+ (total_batch_size, max_decoder_input_length + padding_right_offset),
+ )
+ # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
+ # this batch. All generations are of length `batch.max_decoder_input_length`.
+ left_offset = max_decoder_input_length - batch.max_decoder_input_length
+ if batch.decoder_attention_mask is None:
+ decoder_attention_mask[
+ start_index:end_index,
+ left_offset:-padding_right_offset,
+ ] = 1
+ # If it exists, we need to index
+ else:
+ batch_left_offset = (
+ batch.decoder_attention_mask.shape[1]
+ - batch.max_decoder_input_length
+ - batch.padding_right_offset
+ )
+ decoder_attention_mask[
+ start_index:end_index,
+ left_offset:-padding_right_offset,
+ ] = batch.decoder_attention_mask[
+ :,
+ batch_left_offset : -batch.padding_right_offset,
+ ]
+
+ # Create padded tensor
+ if encoder_last_hidden_state is None:
+ encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
+ (
+ total_batch_size,
+ max_input_length,
+ batch.encoder_last_hidden_state.shape[-1],
+ ),
+ )
+
+ if top_n_tokens_tensor is None:
+ top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+ total_batch_size,
+ )
+ top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
+ # Copy to correct indices
+ encoder_last_hidden_state[
+ start_index:end_index, -batch.max_input_length :, :
+ ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
+ batch.encoder_last_hidden_state = None
+
+ # Ensure that we can update tensors in-place
+ if isinstance(batch.past_key_values[0], tuple):
+ batch.past_key_values = [
+ [t for t in layer] for layer in batch.past_key_values
+ ]
+
+ # Account for any padding tokens that were added while concatenating
+ max_tokens += batch.max_tokens + (
+ max_input_length
+ - batch.max_input_length
+ + max_decoder_input_length
+ - batch.max_decoder_input_length
+ ) * len(batch)
+
+ start_index = end_index
+
+ # Determine shapes for new past kv tensors
+ first_past_kvs = batches[0].past_key_values
+ _, num_heads, _, head_dim = first_past_kvs[0][0].shape
+
+ padded_dec_t_shape = (
+ total_batch_size,
+ num_heads,
+ (max_decoder_input_length - 1),
+ head_dim,
+ )
+
+ padded_enc_t_shape = (
+ total_batch_size,
+ num_heads,
+ max_input_length,
+ head_dim,
+ )
+
+ # Iterate over attention layers
+ for j in range(len(first_past_kvs)):
+ past_key_values.append([])
+
+ # Decoder past
+ for k in range(0, 2):
+ # Initialize tensors
+ padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)
+ past_key_values[j].append(padded_past_values)
+
+ start_index = 0
+ for batch in batches:
+ t = batch.past_key_values[j][k]
+ # Clear reference to the original tensor
+ batch.past_key_values[j][k] = None
+ # Slicing end index for this batch
+ end_index = start_index + len(batch)
+ # We slice the past keys and values to remove the padding from previous batches
+ past_seq_len = batch.max_decoder_input_length - 1
+ padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
+ :, :, -past_seq_len:, :
+ ]
+ del t
+
+ start_index = end_index
+
+ # Encoder past
+ for k in range(2, 4):
+ # Initialize tensors
+ padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)
+ past_key_values[j].append(padded_past_values)
+
+ start_index = 0
+ for batch in batches:
+ t = batch.past_key_values[j][k]
+ # Clear reference to the original tensor
+ batch.past_key_values[j][k] = None
+ # Slicing end index for this batch
+ end_index = start_index + len(batch)
+ # We slice the past keys and values to remove the padding from previous batches
+ padded_past_values[
+ start_index:end_index, :, -batch.max_input_length :, :
+ ] = t[:, :, -batch.max_input_length :, :]
+ del t
+
+ start_index = end_index
+
+ return cls(
+ batch_id=batches[0].batch_id,
+ requests=requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=None,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ all_decoder_input_ids=all_decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_last_hidden_state=encoder_last_hidden_state,
+ past_key_values=past_key_values,
+ input_lengths=input_lengths,
+ decoder_input_lengths=decoder_input_lengths,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
+ next_token_choosers=next_token_choosers,
+ stopping_criterias=stopping_criterias,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ max_input_length=max_input_length,
+ max_decoder_input_length=max_decoder_input_length,
+ padding_right_offset=padding_right_offset,
+ max_tokens=max_tokens,
+ )
+
+ def __len__(self):
+ return len(self.requests)
+
+
+class Seq2SeqLM(Model):
+ def __init__(
+ self,
+ model_id: str,
+ model_class,
+ revision: Optional[str] = None,
+ quantize: Optional[str] = None,
+ speculator: Optional[str] = None,
+ dtype: Optional[torch.dtype] = None,
+ default_dtype=torch.float16,
+ trust_remote_code: bool = False,
+ config_class=AutoConfig,
+ tokenizer_class=AutoTokenizer,
+ aliases=None,
+ ):
+ self.quantize = quantize
+ self.process_group, rank, world_size = initialize_torch_distributed()
+
+ device = torch.device("hpu")
+ dtype = torch.bfloat16 if dtype is None else dtype
+
+ config = config_class.from_pretrained(
+ model_id,
+ revision=revision,
+ trust_remote_code=trust_remote_code,
+ )
+ config.quantize = quantize
+ config.speculator = speculator
+
+ tokenizer = tokenizer_class.from_pretrained(
+ model_id,
+ revision=revision,
+ padding_side="left",
+ truncation_side="left",
+ trust_remote_code=trust_remote_code,
+ )
+ tokenizer.bos_token_id = config.decoder_start_token_id
+
+ weights_loader = get_loader(
+ quantize=quantize, model_id=model_id, revision=revision
+ )
+ torch.distributed.barrier(group=self.process_group)
+ filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+ weights = Weights(
+ filenames,
+ device=device,
+ dtype=dtype,
+ process_group=self.process_group,
+ aliases=aliases,
+ weights_loader=weights_loader,
+ )
+ if config.quantize in ["awq", "gptq"]:
+ weights._set_gptq_params(model_id, revision)
+
+ model = model_class(config, weights)
+
+ torch.distributed.barrier(group=self.process_group)
+ super().__init__(
+ model_id=model_id,
+ model=model,
+ tokenizer=tokenizer,
+ requires_padding=True,
+ dtype=dtype,
+ device=device,
+ rank=rank,
+ world_size=world_size,
+ )
+
+ @classmethod
+ def fallback(
+ cls,
+ model_id: str,
+ revision: Optional[str] = None,
+ quantize: Optional[str] = None,
+ speculator: Optional[str] = None,
+ dtype: Optional[torch.dtype] = None,
+ trust_remote_code: bool = False,
+ ):
+ if speculator:
+ raise RuntimeError("Speculator decoding is not enabled for AutoModel")
+
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ dtype = torch.float16 if dtype is None else dtype
+ else:
+ if quantize:
+ raise ValueError("quantization is not available on CPU")
+
+ device = torch.device("cpu")
+ dtype = torch.float32 if dtype is None else dtype
+
+ model = AutoModelForSeq2SeqLM.from_pretrained(
+ model_id,
+ revision=revision,
+ torch_dtype=dtype,
+ device_map=(
+ "auto"
+ if torch.cuda.is_available() and torch.cuda.device_count() > 1
+ else None
+ ),
+ load_in_8bit=quantize == "bitsandbytes",
+ trust_remote_code=trust_remote_code,
+ )
+ if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+ model = model.cuda()
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id,
+ revision=revision,
+ padding_side="left",
+ truncation_side="left",
+ trust_remote_code=trust_remote_code,
+ )
+ tokenizer.bos_token_id = model.config.decoder_start_token_id
+
+ self = cls.__new__(
+ cls,
+ )
+ super().__init__(
+ self,
+ model_id=model_id,
+ model=model,
+ tokenizer=tokenizer,
+ requires_padding=True,
+ dtype=dtype,
+ device=device,
+ )
+ self.quantize = quantize
+ return self
+
+ @property
+ def batch_type(self) -> Type[Seq2SeqLMBatch]:
+ return Seq2SeqLMBatch
+
+ def forward(
+ self,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ decoder_attention_mask: Optional[torch.Tensor],
+ encoder_last_hidden_state: Optional[torch.Tensor],
+ past_key_values: Optional[List[Tuple]] = None,
+ ) -> Tuple[
+ torch.Tensor,
+ Optional[torch.Tensor],
+ torch.Tensor,
+ List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
+ ]:
+ # Model Forward
+ outputs = self.model.forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_outputs=encoder_last_hidden_state,
+ past_key_values=past_key_values,
+ use_cache=True,
+ )
+ if isinstance(outputs, tuple):
+ # Our custom models
+ outputs, speculative_logits = outputs
+ else:
+ # Generic transformers models
+ speculative_logits = None
+ return (
+ outputs.logits,
+ speculative_logits,
+ outputs.encoder_last_hidden_state,
+ outputs.past_key_values,
+ )
+
+ @tracer.start_as_current_span("generate_token")
+ def generate_token(
+ self, batch: Seq2SeqLMBatch
+ ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch], Tuple[int, int]]:
+ start = time.time_ns()
+ if batch.decoder_attention_mask is not None:
+ # slice to the correct shape
+ decoder_attention_mask = batch.decoder_attention_mask[
+ :, : -batch.padding_right_offset
+ ]
+ else:
+ decoder_attention_mask = None
+
+ # Wrap `encoder_last_hidden_state` because, for some reason, Transformers indexes it as
+ # `encoder_last_hidden_state[0]` internally...
+ if batch.encoder_last_hidden_state is not None:
+ encoder_last_hidden_state = [batch.encoder_last_hidden_state]
+ else:
+ encoder_last_hidden_state = None
+
+ logits, speculative_logits, encoder_last_hidden_state, past = self.forward(
+ batch.input_ids,
+ batch.attention_mask,
+ batch.decoder_input_ids,
+ decoder_attention_mask,
+ encoder_last_hidden_state,
+ batch.past_key_values,
+ )
+
+ # Speculation is not active for seq2seq
+ accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
+ batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+ batch.top_n_tokens,
+ batch.top_n_tokens_tensor,
+ torch.log_softmax(logits[:, -1], -1),
+ accepted_ids,
+ )
+
+ start_decode = time.time_ns()
+
+ # Finished requests
+ generations: List[Generation] = []
+ stopped = True
+
+ # Zipped iterator
+ iterator = zip(
+ batch.requests,
+ batch.input_lengths,
+ batch.prefix_offsets,
+ batch.read_offsets,
+ batch.decoder_input_lengths,
+ logits,
+ batch.next_token_choosers,
+ batch.stopping_criterias,
+ batch.all_decoder_input_ids,
+ batch.top_n_tokens,
+ batch_top_token_ids,
+ batch_top_token_logprobs,
+ )
+
+ # For each member of the batch
+ for i, (
+ request,
+ input_length,
+ prefix_offset,
+ read_offset,
+ decoder_input_length,
+ logits,
+ next_token_chooser,
+ stopping_criteria,
+ all_decoder_input_ids,
+ top_n_tokens,
+ top_token_ids,
+ top_token_logprobs,
+ ) in enumerate(iterator):
+ # Select next token
+ next_token_id, logprobs = next_token_chooser(
+ all_decoder_input_ids.view(1, -1), logits[-1:, :]
+ )
+
+ # Append next token to decoder tokens
+ all_decoder_input_ids = torch.cat(
+ [all_decoder_input_ids, next_token_id.squeeze(1)]
+ )
+ new_decoder_input_length = decoder_input_length + 1
+
+ # Generated token
+ next_token_logprob = logprobs[-1, next_token_id]
+ next_token_id_squeezed = next_token_id.squeeze()
+ next_token_text, prefix_offset, read_offset = self.decode_token(
+ all_decoder_input_ids, prefix_offset, read_offset
+ )
+
+ # Evaluate stopping criteria
+ stop, reason = stopping_criteria(next_token_id, next_token_text)
+
+ if not stop:
+ stopped = False
+
+ # Shard generations
+ # All generations will be appended in the rust sharded client
+ if i % self.world_size == self.rank:
+ if stop:
+ # Slice with decoder_input_length to remove padding
+ # Decode all tokens
+ output_text, _, _ = self.decode_token(
+ all_decoder_input_ids,
+ prefix_offset=len(all_decoder_input_ids)
+ - decoder_input_length
+ - 1,
+ read_offset=len(all_decoder_input_ids) - decoder_input_length,
+ skip_special_tokens=True,
+ )
+
+ # Get seed
+ if isinstance(next_token_chooser.choice, Sampling):
+ seed = next_token_chooser.choice.seed
+ else:
+ seed = None
+
+ generated_text = GeneratedText(
+ output_text, stopping_criteria.current_tokens, reason, seed
+ )
+ else:
+ generated_text = None
+
+ # Prefill
+ if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
+ prefill_tokens = Tokens(
+ [self.tokenizer.bos_token_id],
+ [float("nan")],
+ [self.tokenizer.bos_token],
+ [False],
+ )
+ else:
+ prefill_tokens = None
+
+ if top_n_tokens > 0:
+ all_top_tokens = []
+ for top_token_ids, top_token_logprobs in zip(
+ top_token_ids, top_token_logprobs
+ ):
+ toptoken_texts = self.tokenizer.batch_decode(
+ top_token_ids,
+ clean_up_tokenization_spaces=False,
+ skip_special_tokens=False,
+ )
+ special_toptokens = [
+ token_id in self.all_special_ids
+ for token_id in top_token_ids
+ ]
+ top_tokens = Tokens(
+ top_token_ids,
+ top_token_logprobs,
+ toptoken_texts,
+ special_toptokens,
+ )
+ all_top_tokens.append(top_tokens)
+ top_tokens = all_top_tokens
+ else:
+ top_tokens = None
+
+ generation = Generation(
+ request.id,
+ prefill_tokens,
+ Tokens(
+ [next_token_id_squeezed],
+ [next_token_logprob],
+ [next_token_text],
+ [next_token_id_squeezed.item() in self.all_special_ids],
+ ),
+ generated_text,
+ top_tokens,
+ )
+
+ generations.append(generation)
+
+ # Update values
+ batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+ next_token_id_squeezed.item()
+ )
+ batch.decoder_input_ids[i] = next_token_id
+ batch.all_decoder_input_ids[i] = all_decoder_input_ids
+ batch.input_lengths[i] = input_length
+ batch.decoder_input_lengths[i] = new_decoder_input_length
+ batch.prefix_offsets[i] = prefix_offset
+ batch.read_offsets[i] = read_offset
+ batch.max_input_length = max(batch.max_input_length, input_length)
+ batch.max_decoder_input_length = max(
+ batch.max_decoder_input_length, new_decoder_input_length
+ )
+
+ # We finished all generations in the batch; there is no next batch
+ if stopped:
+ forward_ns = start_decode - start
+ decode_ns = time.time_ns() - start_decode
+ return generations, None, (forward_ns, decode_ns)
+
+ # We don't need input_ids after the prefill forward
+ batch.input_ids = None
+ batch.encoder_last_hidden_state = encoder_last_hidden_state
+ batch.past_key_values = past
+ # Update decoder_attention_mask as we added a new token to decoder_input_ids
+ if batch.decoder_attention_mask is not None:
+ batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
+ batch.padding_right_offset -= 1
+
+ forward_ns = start_decode - start
+ decode_ns = time.time_ns() - start_decode
+ return generations, batch, (forward_ns, decode_ns)
diff --git a/backends/gaudi/server/text_generation_server/models/types.py b/backends/gaudi/server/text_generation_server/models/types.py
new file mode 100644
index 00000000000..d4e7cca7504
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/types.py
@@ -0,0 +1,102 @@
+import torch
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import List, Optional
+
+from transformers import PreTrainedTokenizerBase
+
+from text_generation_server.pb import generate_pb2
+from text_generation_server.pb.generate_pb2 import FinishReason
+
+
+class Batch(ABC):
+ @abstractmethod
+ def to_pb(self) -> generate_pb2.CachedBatch:
+ raise NotImplementedError
+
+ @classmethod
+ @abstractmethod
+ def from_pb(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "Batch":
+ raise NotImplementedError
+
+ @abstractmethod
+ def filter(self, request_ids: List[int]) -> "Batch":
+ raise NotImplementedError
+
+ @classmethod
+ @abstractmethod
+ def concatenate(cls, batches: List["Batch"]) -> "Batch":
+ raise NotImplementedError
+
+ @abstractmethod
+ def __len__(self):
+ raise NotImplementedError
+
+
+@dataclass
+class GeneratedText:
+ text: str
+ generated_tokens: int
+ finish_reason: FinishReason
+ seed: Optional[int]
+
+ def to_pb(self) -> generate_pb2.GeneratedText:
+ return generate_pb2.GeneratedText(
+ text=self.text,
+ generated_tokens=self.generated_tokens,
+ finish_reason=self.finish_reason,
+ seed=self.seed,
+ )
+
+
+@dataclass
+class Tokens:
+ token_ids: List[int]
+ logprobs: List[float]
+ texts: List[str]
+ is_special: List[bool]
+
+ def to_pb(self) -> generate_pb2.Tokens:
+ return generate_pb2.Tokens(
+ ids=self.token_ids,
+ logprobs=self.logprobs,
+ texts=self.texts,
+ is_special=self.is_special,
+ )
+
+ def __len__(self):
+ return len(self.token_ids)
+
+
+@dataclass
+class Generation:
+ request_id: int
+ prefill_tokens: Optional[Tokens]
+ tokens: Tokens
+ generated_text: Optional[GeneratedText]
+ # Optional for now, since it's not yet supported for every model.
+ top_tokens: Optional[List[Tokens]]
+
+ def to_pb(self) -> generate_pb2.Generation:
+ return generate_pb2.Generation(
+ request_id=self.request_id,
+ prefill_tokens=(
+ self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None
+ ),
+ tokens=self.tokens.to_pb(),
+ generated_text=(
+ self.generated_text.to_pb() if self.generated_text is not None else None
+ ),
+ top_tokens=(
+ [top_tokens.to_pb() for top_tokens in self.top_tokens]
+ if self.top_tokens is not None
+ else None
+ ),
+ )
diff --git a/backends/gaudi/server/text_generation_server/pb/.gitignore b/backends/gaudi/server/text_generation_server/pb/.gitignore
new file mode 100644
index 00000000000..5a68d631354
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/pb/.gitignore
@@ -0,0 +1,3 @@
+*.py
+*.pyi
+*.py-e
diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py
new file mode 100644
index 00000000000..f5080ec3ab0
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/server.py
@@ -0,0 +1,317 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+import asyncio
+import os
+import torch
+import time
+import signal
+
+from grpc import aio
+from loguru import logger
+
+from grpc_reflection.v1alpha import reflection
+from pathlib import Path
+from typing import List, Optional
+
+from text_generation_server.cache import Cache
+from text_generation_server.interceptor import ExceptionInterceptor
+from text_generation_server.models import Model, get_model_with_lora_adapters
+from text_generation_server.pb import generate_pb2_grpc, generate_pb2
+from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
+from text_generation_server.models.globals import set_model_id, ATTENTION
+from text_generation_server.models.globals import set_adapter_to_index
+from text_generation_server.utils.adapter import AdapterInfo
+from text_generation_server.utils.tokens import make_tokenizer_optional
+from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
+from text_generation_server.models import VLM_BATCH_TYPES
+
+from text_generation_server.utils.version import (
+ is_driver_compatible,
+ MIN_TGI_GAUDI_SYNAPSE_VERSION,
+)
+
+
+class SignalHandler:
+ KEEP_PROCESSING = True
+
+ def __init__(self):
+ signal.signal(signal.SIGINT, self.exit_gracefully)
+ signal.signal(signal.SIGTERM, self.exit_gracefully)
+
+ def exit_gracefully(self, signum, frame):
+ print(f"Exiting gracefully: Signal {signum}")
+ self.KEEP_PROCESSING = False
+
+
+class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
+ def __init__(
+ self,
+ model: Model,
+ cache: Cache,
+ server_urls: List[str],
+ ):
+ self.cache = cache
+ self.model = model
+ # Quantize is resolved during model loading
+ self.quantize = model.quantize
+ self.server_urls = server_urls
+ # For some reason, inference_mode does not work well with GLOO, which we use on CPU.
+ # TODO: enabling inference_mode messes up the autograd op dispatch and results in an
+ # "aten::matmul op not optimized" issue. Will investigate further.
+ # if model.device.type == "hpu":
+ # Force inference mode for the lifetime of TextGenerationService
+ # self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+
+ async def Info(self, request, context):
+ return self.model.info
+
+ async def Health(self, request, context):
+ if self.model.device.type == "hpu":
+ torch.zeros((2, 2)).to("hpu")
+ return generate_pb2.HealthResponse()
+
+ async def ServiceDiscovery(self, request, context):
+ return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
+
+ async def ClearCache(self, request, context):
+ if request.HasField("id"):
+ self.cache.delete(request.id)
+ else:
+ self.cache.clear()
+ return generate_pb2.ClearCacheResponse()
+
+ async def FilterBatch(self, request, context):
+ batch = self.cache.pop(request.batch_id)
+ if batch is None:
+ raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
+ filtered_batch = batch.filter(request.request_ids)
+ self.cache.set(filtered_batch)
+
+ return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
+
+ async def Warmup(self, request, context):
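+ # With paged attention, build a real batch from the protobuf request and warm up
+ # the model with it; otherwise the model's warmup consumes the raw request directly.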
+ if ATTENTION == "paged":
+ set_max_prefill_tokens(request.max_prefill_tokens)
+ if (
+ self.model.batch_type in VLM_BATCH_TYPES
+ ): # Hack, I would rather use kwargs in the `from_pb` call
+ batch = self.model.batch_type.from_pb_processor(
+ request.batch,
+ self.model.tokenizer,
+ self.model.processor,
+ self.model.model.config,
+ self.model.dtype,
+ self.model.device,
+ )
+ else:
+ batch = self.model.batch_type.from_pb(
+ request.batch,
+ self.model.tokenizer,
+ self.model.dtype,
+ self.model.device,
+ )
+
+ # Override default values with None for clearer semantics.
+ max_input_tokens = (
+ request.max_input_tokens
+ if request.HasField("max_input_tokens")
+ else None
+ )
+ max_total_tokens = (
+ request.max_total_tokens
+ if request.HasField("max_total_tokens")
+ else None
+ )
+ max_supported_total_tokens, max_input_tokens, max_total_tokens = (
+ self.model.warmup(batch, max_input_tokens, max_total_tokens)
+ )
+ else:
+ max_supported_total_tokens, max_input_tokens, max_total_tokens = (
+ self.model.warmup(request)
+ )
+
+ # Workaround for the skip-tokenizer path:
+ # we need to call make_tokenizer_optional after the warmup,
+ # because the router is not aware of that feature.
+ make_tokenizer_optional(self.model.tokenizer)
+
+ return generate_pb2.WarmupResponse(
+ max_supported_total_tokens=max_supported_total_tokens,
+ max_input_tokens=max_input_tokens,
+ max_total_tokens=max_total_tokens,
+ )
+
+ async def Prefill(self, request, context):
+ start = time.time_ns()
+ if (
+ self.model.batch_type in VLM_BATCH_TYPES
+ ): # Hack, I would rather use kwargs in the `from_pb` call
+ batch = self.model.batch_type.from_pb_processor(
+ request.batch,
+ self.model.tokenizer,
+ self.model.processor,
+ self.model.model.config,
+ self.model.dtype,
+ self.model.device,
+ )
+ else:
+ batch = self.model.batch_type.from_pb(
+ request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+ )
+
+ generations, next_batch, timings = self.model.generate_token([batch])
+ self.cache.set(next_batch)
+
+ return generate_pb2.PrefillResponse(
+ generations=[generation.to_pb() for generation in generations],
+ batch=next_batch.to_pb() if next_batch else None,
+ forward_ns=timings[0],
+ decode_ns=timings[1],
+ total_ns=time.time_ns() - start,
+ )
+
+ async def Decode(self, request, context):
+ start = time.time_ns()
+ if len(request.batches) == 0:
+ raise ValueError("Must provide at least one batch")
+
+ batches = []
+ for batch_pb in request.batches:
+ batch = self.cache.pop(batch_pb.id)
+ if batch is None:
+ raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
+ batches.append(batch)
+
+ if len(batches) == 0:
+ raise ValueError("All batches are empty")
+
+ generations, next_batch, timings = self.model.generate_token(batches)
+ self.cache.set(next_batch)
+
+ return generate_pb2.DecodeResponse(
+ generations=[generation.to_pb() for generation in generations],
+ batch=next_batch.to_pb() if next_batch else None,
+ concat_ns=None,
+ forward_ns=timings[0],
+ decode_ns=timings[1],
+ total_ns=time.time_ns() - start,
+ )
+
+
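+# Blocking entry point: loads the model (with optional LoRA adapters), starts the
+# async gRPC server on a unix domain socket and polls the signal handler until
+# SIGINT/SIGTERM is received.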
+def serve(
+ model_id: str,
+ lora_adapters: Optional[List[AdapterInfo]],
+ revision: Optional[str],
+ sharded: bool,
+ quantize: Optional[str],
+ speculate: Optional[int],
+ dtype: Optional[str],
+ kv_cache_dtype: Optional[str],
+ trust_remote_code: bool,
+ uds_path: Path,
+ max_input_tokens: int,
+):
+ async def serve_inner(
+ model_id: str,
+ lora_adapters: Optional[List[AdapterInfo]],
+ revision: Optional[str],
+ sharded: bool = False,
+ quantize: Optional[str] = None,
+ speculate: Optional[int] = None,
+ dtype: Optional[str] = None,
+ kv_cache_dtype: Optional[str] = None,
+ trust_remote_code: bool = False,
+ ):
+ if not is_driver_compatible():
+ logger.warning(
+ f"Current Synapse version is lower than the minimum version supported: {MIN_TGI_GAUDI_SYNAPSE_VERSION}, this could result in failures"
+ )
+
+ unix_socket_template = "unix://{}-{}"
+ adapter_to_index = {}
+ logger.info("Server:server_inner: sharded ={}".format(sharded))
+
+ if sharded:
+ rank = int(os.environ["RANK"])
+ logger.info("Server:server_inner: rank ={}".format(rank))
+ server_urls = [
+ unix_socket_template.format(uds_path, rank)
+ for rank in range(int(os.environ["WORLD_SIZE"]))
+ ]
+ local_url = server_urls[int(os.environ["RANK"])]
+ else:
+ local_url = unix_socket_template.format(uds_path, 0)
+ server_urls = [local_url]
+
+ logger.info(
+ "Server:server_inner: data type = {}, local_url = {}".format(
+ dtype, local_url
+ )
+ )
+ if dtype == "bfloat16" or None:
+ data_type = torch.bfloat16
+ else:
+ data_type = torch.float
+ if revision == "None":
+ revision = None
+ try:
+ model = get_model_with_lora_adapters(
+ model_id,
+ lora_adapters,
+ revision,
+ sharded,
+ quantize,
+ speculate,
+ data_type,
+ kv_cache_dtype,
+ trust_remote_code,
+ max_input_tokens,
+ adapter_to_index,
+ )
+
+ except Exception:
+ logger.exception("Error when initializing model")
+ raise
+
+ set_adapter_to_index(adapter_to_index)
+ server = aio.server(
+ interceptors=[
+ ExceptionInterceptor(),
+ UDSOpenTelemetryAioServerInterceptor(),
+ ],
+ options=[
+ # Set the maximum possible message length: i32::MAX
+ ("grpc.max_receive_message_length", (1 << 31) - 1)
+ ],
+ )
+ generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
+ TextGenerationService(model, Cache(), server_urls), server
+ )
+ SERVICE_NAMES = (
+ generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
+ reflection.SERVICE_NAME,
+ )
+ reflection.enable_server_reflection(SERVICE_NAMES, server)
+ server.add_insecure_port(local_url)
+
+ await server.start()
+
+ logger.info("Server started at {}".format(local_url))
+ signal_handler = SignalHandler()
+ while signal_handler.KEEP_PROCESSING:
+ await asyncio.sleep(0.5)
+
+ set_model_id(model_id)
+ asyncio.run(
+ serve_inner(
+ model_id,
+ lora_adapters,
+ revision,
+ sharded,
+ quantize,
+ speculate,
+ dtype,
+ kv_cache_dtype,
+ trust_remote_code,
+ )
+ )
diff --git a/backends/gaudi/server/text_generation_server/tracing.py b/backends/gaudi/server/text_generation_server/tracing.py
new file mode 100644
index 00000000000..bc7a04ee758
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/tracing.py
@@ -0,0 +1,63 @@
+import grpc
+
+from opentelemetry import trace
+from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
+from opentelemetry.instrumentation.grpc._aio_server import (
+ OpenTelemetryAioServerInterceptor,
+)
+from opentelemetry.semconv.trace import SpanAttributes
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import (
+ BatchSpanProcessor,
+)
+
+
+class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor):
+ def __init__(self):
+ super().__init__(trace.get_tracer(__name__))
+
+ def _start_span(self, handler_call_details, context, set_status_on_exception=False):
+ """
+ Rewrite _start_span method to support Unix Domain Socket gRPC contexts
+ """
+
+ # standard attributes
+ attributes = {
+ SpanAttributes.RPC_SYSTEM: "grpc",
+ SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[0],
+ }
+
+ # if we have details about the call, split into service and method
+ if handler_call_details.method:
+ service, method = handler_call_details.method.lstrip("/").split("/", 1)
+ attributes.update(
+ {
+ SpanAttributes.RPC_METHOD: method,
+ SpanAttributes.RPC_SERVICE: service,
+ }
+ )
+
+ # add some attributes from the metadata
+ metadata = dict(context.invocation_metadata())
+ if "user-agent" in metadata:
+ attributes["rpc.user_agent"] = metadata["user-agent"]
+
+ # We use gRPC over a UNIX socket
+ attributes.update({SpanAttributes.NET_TRANSPORT: "unix"})
+
+ return self._tracer.start_as_current_span(
+ name=handler_call_details.method,
+ kind=trace.SpanKind.SERVER,
+ attributes=attributes,
+ set_status_on_exception=set_status_on_exception,
+ )
+
+
+def setup_tracing(otlp_service_name: str, otlp_endpoint: str):
+ resource = Resource.create(attributes={"service.name": otlp_service_name})
+ span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
+ span_processor = BatchSpanProcessor(span_exporter)
+
+ trace.set_tracer_provider(TracerProvider(resource=resource))
+ trace.get_tracer_provider().add_span_processor(span_processor)
diff --git a/backends/gaudi/server/text_generation_server/utils/__init__.py b/backends/gaudi/server/text_generation_server/utils/__init__.py
new file mode 100644
index 00000000000..cda3a4da1cb
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/__init__.py
@@ -0,0 +1,50 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+from text_generation_server.utils.convert import convert_file, convert_files
+from text_generation_server.utils.dist import initialize_torch_distributed
+from text_generation_server.utils.weights import Weights
+from text_generation_server.utils.peft import download_and_unload_peft
+from text_generation_server.utils.hub import (
+ weight_files,
+ weight_hub_files,
+ download_weights,
+ EntryNotFoundError,
+ LocalEntryNotFoundError,
+ RevisionNotFoundError,
+)
+from text_generation_server.utils.tokens import (
+ NextTokenChooser,
+ HeterogeneousNextTokenChooser,
+ StoppingCriteria,
+ StopSequenceCriteria,
+ FinishReason,
+ Sampling,
+ Greedy,
+ make_tokenizer_optional,
+ is_tokenizer_transparent,
+ pad_next_token_chooser_parameters,
+)
+
+__all__ = [
+ "convert_file",
+ "convert_files",
+ "initialize_torch_distributed",
+ "weight_files",
+ "weight_hub_files",
+ "download_weights",
+ "download_and_unload_peft",
+ "EntryNotFoundError",
+ "HeterogeneousNextTokenChooser",
+ "LocalEntryNotFoundError",
+ "RevisionNotFoundError",
+ "Greedy",
+ "NextTokenChooser",
+ "Sampling",
+ "StoppingCriteria",
+ "StopSequenceCriteria",
+ "FinishReason",
+ "Weights",
+ "make_tokenizer_optional",
+ "is_tokenizer_transparent",
+ "pad_next_token_chooser_parameters",
+]
diff --git a/backends/gaudi/server/text_generation_server/utils/adapter.py b/backends/gaudi/server/text_generation_server/utils/adapter.py
new file mode 100644
index 00000000000..2b61f9bb448
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/adapter.py
@@ -0,0 +1,320 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/utils/adapter.py
+# License: Apache License Version 2.0, January 2004
+
+import warnings
+import re
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import TYPE_CHECKING, Set, Tuple, Optional, List
+
+from safetensors.torch import load_file
+from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
+
+from text_generation_server.utils.merges.strategies import merge_adapters
+
+from text_generation_server.utils import hub
+from text_generation_server.adapters.lora import LoraConfig
+
+
+if TYPE_CHECKING:
+ from text_generation_server.adapters.config import AdapterConfig, ModuleMap
+
+
+BASE_MODEL_ADAPTER_ID = "__base_model__"
+
+
+@dataclass
+class AdapterInfo:
+ id: str
+ path: Optional[str]
+ revision: Optional[str] = None
+
+
+@dataclass
+class AdapterParameters:
+ adapter_info: Tuple[AdapterInfo]
+ weights: Tuple[float]
+ merge_strategy: NotImplemented
+ density: float
+ majority_sign_method: NotImplemented
+
+
+@dataclass
+class AdapterSource:
+ adapter_id: str
+ model_id: str
+ revision: str
+
+
+def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:
+ if not lora_adapters:
+ return []
+
+ adapter_list = []
+ for adapter in lora_adapters.split(","):
+ adapter = adapter.strip()
+ if adapter.count("=") > 1 or adapter.count("@") > 1:
+ raise ValueError(f"Invalid LoRA adapter format: {adapter}")
+ match = re.match(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$", adapter)
+
+ if match:
+ adapter_id, path, revision = match.groups()
+ adapter_list.append(
+ AdapterInfo(id=adapter_id, path=path, revision=revision)
+ )
+ else:
+ raise ValueError(f"Invalid LoRA adapter format: {adapter}")
+ return adapter_list
+
+
+def load_and_merge_adapters(
+ model_id: str,
+ adapter_parameters: AdapterParameters,
+ adapter_index: int,
+ weight_names: Tuple[str],
+ trust_remote_code: bool = False,
+) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
+ if len(adapter_parameters.adapter_info) == 1:
+ adapter = next(iter(adapter_parameters.adapter_info))
+ return load_module_map(
+ model_id,
+ adapter.revision,
+ adapter.id,
+ adapter.path,
+ weight_names,
+ trust_remote_code,
+ )
+
+ adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
+ return _load_and_merge(
+ model_id,
+ adapter_params,
+ weight_names,
+ trust_remote_code,
+ )
+
+
+@dataclass
+class AdapterParametersContainer:
+ adapter_parameters: AdapterParameters
+ adapter_index: int
+
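+ # Hash by adapter index alone so this container can serve as an lru_cache key
+ # for _load_and_merge.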
+ def __hash__(self) -> int:
+ return self.adapter_index
+
+
+@lru_cache(maxsize=32)
+def _load_and_merge(
+ model_id: str,
+ adapter_params: AdapterParametersContainer,
+ weight_names: Tuple[str],
+ trust_remote_code: bool = False,
+) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
+ params = adapter_params.adapter_parameters
+
+ adapters_to_merge = []
+ merged_weight_names = set()
+ tokenizer = None
+ for adapter in params.adapter_info:
+ if adapter.id == BASE_MODEL_ADAPTER_ID:
+ raise ValueError("Base model adapter cannot be merged.")
+
+ module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
+ load_module_map(
+ model_id,
+ adapter.revision,
+ adapter.id,
+ adapter.path,
+ weight_names,
+ trust_remote_code,
+ )
+ )
+
+ adapters_to_merge.append((module_map, adapter_config))
+ merged_weight_names = merged_weight_names.union(adapter_weight_names)
+ if tokenizer is None:
+ tokenizer = adapter_tokenizer
+
+ if len(adapters_to_merge) == 0:
+ raise ValueError("No adapters to merge.")
+
+ module_map, adapter_config = merge_adapters(adapters_to_merge, params)
+ return module_map, adapter_config, merged_weight_names, tokenizer
+
+
+def check_architectures(
+ model_id: str,
+ adapter_id: str,
+ adapter_config: "AdapterConfig",
+ trust_remote_code: bool = False,
+):
+ try:
+ if not adapter_config.base_model_name_or_path:
+ # Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
+ return
+
+ expected_config = AutoConfig.from_pretrained(
+ model_id, trust_remote_code=trust_remote_code
+ )
+ model_config = AutoConfig.from_pretrained(
+ adapter_config.base_model_name_or_path, trust_remote_code=trust_remote_code
+ )
+ except Exception as e:
+ warnings.warn(
+ f"Unable to check architecture compatibility for adapter '{adapter_id}' "
+ f"against model '{model_id}'. Assuming they are compatible. Error: {e}"
+ )
+ return
+
+ if model_config.architectures == expected_config.architectures:
+ warnings.warn(
+ f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. "
+ f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead."
+ )
+ else:
+ # TODO(travis): revisit this when we support classification heads, which will not use CausalLM
+ raise ValueError(
+ f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
+ f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. "
+ f"Use --model-id '{adapter_config.base_model_name_or_path}' instead."
+ )
+
+
+@lru_cache(maxsize=128)
+def load_module_map(
+ model_id: str,
+ revision: str,
+ adapter_id: str,
+ adapter_path: Optional[str],
+ weight_names: Tuple[str],
+ trust_remote_code: bool = False,
+) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
+ adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
+
+ if not adapter_path and adapter_config.base_model_name_or_path != model_id:
+ check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
+
+ adapter_filenames = (
+ hub._weight_files_from_dir(adapter_path, extension=".safetensors")
+ if adapter_path
+ else hub._cached_weight_files(
+ adapter_id, revision=revision, extension=".safetensors"
+ )
+ )
+
+ # throw an error if no adapter weights are found
+ if not adapter_filenames:
+ raise FileNotFoundError(
+ f"No adapter weights found for adapter '{adapter_id}' and revision '{revision}'."
+ )
+
+ try:
+ adapter_tokenizer = AutoTokenizer.from_pretrained(
+ adapter_config.config_path,
+ trust_remote_code=trust_remote_code,
+ )
+ except Exception:
+ # Adapter does not have a tokenizer, so fallback to base model tokenizer
+ adapter_tokenizer = None
+
+ # load adapter weights from all shards (should have relatively small memory footprint)
+ adapter_weights = {}
+ for filename in adapter_filenames:
+ adapter_weights.update(load_file(filename))
+
+ # map the model weights to the relevant adapter weights (LoRA A and B matrices)
+ module_map, adapter_weight_names = adapter_config.map_weights_for_model(
+ adapter_weights, weight_names
+ )
+ return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
+
+
+def get_attn_weights(i, layer):
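+ # Map q/k/v/qkv projection names for layer i to the fused query_key_value module,
+ # and o_proj to the attention output projection, for LoRA weight lookup.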
+ qkv = layer.self_attn.query_key_value
+ weights = {}
+
+ for k in ["q", "k", "v"]:
+ key = (i, f"{k}_proj")
+ value = (f"model.layers.{i}.self_attn.{k}_proj", qkv)
+ weights[key] = value
+
+ # also add the qkv_proj weight for the adapter
+ weights[(i, "qkv_proj")] = (
+ f"model.layers.{i}.self_attn.qkv_proj",
+ qkv,
+ )
+
+ weights[(i, "o_proj")] = (
+ f"model.layers.{i}.self_attn.o_proj",
+ layer.self_attn.o_proj,
+ )
+
+ return weights
+
+
+def get_mlp_weights(i, layer):
+ weights = {}
+ if hasattr(layer, "mlp"):
+ mlp = layer.mlp
+ if hasattr(mlp, "gate_up_proj"):
+ # handle combined gate_up_proj (e.g., for some LLaMA variants)
+ weights.update(
+ {
+ (i, "gate_proj"): (
+ f"model.layers.{i}.mlp.gate_proj",
+ mlp.gate_up_proj,
+ ),
+ (i, "up_proj"): (f"model.layers.{i}.mlp.up_proj", mlp.gate_up_proj),
+ }
+ )
+ else:
+ # handle separate gate_proj, up_proj, and down_proj (e.g., for Gemma)
+ if hasattr(mlp, "gate_proj"):
+ weights[(i, "gate_proj")] = (
+ f"model.layers.{i}.mlp.gate_proj",
+ mlp.gate_proj,
+ )
+ if hasattr(mlp, "up_proj"):
+ weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj)
+
+ if hasattr(mlp, "down_proj"):
+ weights[(i, "down_proj")] = (
+ f"model.layers.{i}.mlp.down_proj",
+ mlp.down_proj,
+ )
+
+ return weights
+
+
+# build_layer_weight_lookup maps layer identifiers to tuples of (weight tensor path,
+# actual layer object). This mapping is needed for the LoRA adapter to know which
+# weights to update when the adapter is applied.
+def build_layer_weight_lookup(model):
+ if hasattr(model, "language_model"):
+ m = model.language_model.model
+ elif hasattr(model, "text_model"):
+ m = model.text_model.model
+ else:
+ m = model.model
+
+ layer_weights = {}
+
+ for i, layer in enumerate(m.layers):
+ attn_weights = get_attn_weights(i, layer)
+ mlp_weights = get_mlp_weights(i, layer)
+
+ layer_weights.update(attn_weights)
+ layer_weights.update(mlp_weights)
+
+ lm_head = None
+ if hasattr(m, "lm_head"):
+ lm_head = m.lm_head
+ elif hasattr(model, "lm_head"):
+ lm_head = model.lm_head
+
+ if lm_head:
+ layer_weights[(0, "lm_head")] = ("lm_head", lm_head)
+
+ return layer_weights
diff --git a/backends/gaudi/server/text_generation_server/utils/chunks.py b/backends/gaudi/server/text_generation_server/utils/chunks.py
new file mode 100644
index 00000000000..73962ea39e1
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/chunks.py
@@ -0,0 +1,27 @@
+from typing import Iterable
+
+from loguru import logger
+
+from text_generation_server.pb import generate_pb2
+
+
+def concat_text_chunks(chunks: Iterable[generate_pb2.InputChunk]) -> str:
+ """
+ Concatenate text in text chunks. Non-text chunks are dropped.
+ """
+ text = None
+ for chunk in chunks:
+ chunk_type = chunk.WhichOneof("chunk")
+ if chunk_type == "text":
+ if text is None:
+ text = chunk.text
+ else:
+ raise NotImplementedError("Request contained more than one text chunk")
+ else:
+ # We cannot reject this, e.g. warmup sends an image chunk.
+ logger.debug(f"Encountered non-text chunk type {chunk_type}")
+
+ if text is None:
+ raise NotImplementedError("Request without a text chunk")
+
+ return text
diff --git a/backends/gaudi/server/text_generation_server/utils/convert.py b/backends/gaudi/server/text_generation_server/utils/convert.py
new file mode 100644
index 00000000000..d9c3276bc03
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/convert.py
@@ -0,0 +1,114 @@
+import datetime
+import torch
+import os
+
+from loguru import logger
+from pathlib import Path
+from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete
+from typing import List, Dict
+from collections import defaultdict
+
+
+def _remove_duplicate_names(
+ state_dict: Dict[str, torch.Tensor],
+ *,
+ preferred_names: List[str] = None,
+ discard_names: List[str] = None,
+) -> Dict[str, List[str]]:
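+ # Returns a mapping of the tensor name to keep -> list of aliased (shared-storage)
+ # names that can be dropped before saving to safetensors.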
+ if preferred_names is None:
+ preferred_names = []
+ preferred_names = set(preferred_names)
+ if discard_names is None:
+ discard_names = []
+ discard_names = set(discard_names)
+
+ shareds = _find_shared_tensors(state_dict)
+ to_remove = defaultdict(list)
+ for shared in shareds:
+ complete_names = set(
+ [name for name in shared if _is_complete(state_dict[name])]
+ )
+ if not complete_names:
+ if len(shared) == 1:
+ # Force contiguous
+ name = list(shared)[0]
+ state_dict[name] = state_dict[name].clone()
+ complete_names = {name}
+ else:
+ raise RuntimeError(
+ f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
+ )
+
+ keep_name = sorted(list(complete_names))[0]
+
+ # Mechanism to preferentially select keys to keep
+ # coming from the on-disk file, to allow
+ # loading models saved with a different choice
+ # of keep_name.
+ preferred = complete_names.difference(discard_names)
+ if preferred:
+ keep_name = sorted(list(preferred))[0]
+
+ if preferred_names:
+ preferred = preferred_names.intersection(complete_names)
+ if preferred:
+ keep_name = sorted(list(preferred))[0]
+ for name in sorted(shared):
+ if name != keep_name:
+ to_remove[keep_name].append(name)
+ return to_remove
+
+
+def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
+ """
+ Convert a pytorch file to a safetensors file
+ This will remove duplicate tensors from the file.
+
+ Unfortunately, this might not respect *transformers* convention.
+ Forcing us to check for potentially different keys during load when looking
+ for specific tensors (making tensor sharing explicit).
+ """
+ loaded = torch.load(pt_file, map_location="cpu", weights_only=True)
+ if "state_dict" in loaded:
+ loaded = loaded["state_dict"]
+ to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)
+
+ metadata = {"format": "pt"}
+ for kept_name, to_remove_group in to_removes.items():
+ for to_remove in to_remove_group:
+ if to_remove not in metadata:
+ metadata[to_remove] = kept_name
+ del loaded[to_remove]
+ # Force tensors to be contiguous
+ loaded = {k: v.contiguous() for k, v in loaded.items()}
+
+ dirname = os.path.dirname(sf_file)
+ os.makedirs(dirname, exist_ok=True)
+ save_file(loaded, sf_file, metadata=metadata)
+ reloaded = load_file(sf_file)
+ for k in loaded:
+ pt_tensor = loaded[k]
+ sf_tensor = reloaded[k]
+ if not torch.equal(pt_tensor, sf_tensor):
+ raise RuntimeError(f"The output tensors do not match for key {k}")
+
+
+def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]):
+ assert len(pt_files) == len(sf_files)
+
+ N = len(pt_files)
+ # We do this instead of using tqdm because we want to parse the logs with the launcher
+
+ for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
+ # Skip blacklisted files
+ if (
+ "arguments" in pt_file.name
+ or "args" in pt_file.name
+ or "training" in pt_file.name
+ ):
+ continue
+
+ start = datetime.datetime.now()
+ convert_file(pt_file, sf_file, discard_names)
+ elapsed = datetime.datetime.now() - start
+ logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}")
diff --git a/backends/gaudi/server/text_generation_server/utils/debug.py b/backends/gaudi/server/text_generation_server/utils/debug.py
new file mode 100644
index 00000000000..690da54ec71
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/debug.py
@@ -0,0 +1,48 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+import os
+import glob
+import time
+
+import habana_frameworks.torch as htorch
+import numpy as np
+
+START_TS = None
+DBG_TRACE_FILENAME = os.environ.get("DBG_TRACE_FILENAME")
+if "GRAPH_VISUALIZATION" in os.environ:
+ for f in glob.glob(".graph_dumps/*"):
+ os.remove(f)
+
+
+def to_gb_rounded(mem: float) -> float:
+ """
+ Rounds and converts to GB.
+
+ Args:
+ mem (float): memory in bytes
+
+ Returns:
+ float: memory in GB rounded to the second decimal
+ """
+ return np.round(mem / 1024**3, 2)
+
+
+def count_hpu_graphs():
+ return len(glob.glob(".graph_dumps/*PreGraph*"))
+
+
+def dbg_trace(tag, txt):
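+ # Append a timestamped trace line (HPU graph count, memory usage, tag and text)
+ # to DBG_TRACE_FILENAME; only rank 0 writes, and only when the env var is set.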
+ global START_TS
+ if DBG_TRACE_FILENAME is not None and int(os.getenv("RANK", 0)) == 0:
+ if START_TS is None:
+ START_TS = time.perf_counter()
+ time_offset = time.perf_counter() - START_TS
+ mem_stats = htorch.hpu.memory.memory_stats()
+ mem_used = to_gb_rounded(mem_stats["InUse"])
+ max_mem_used = to_gb_rounded(mem_stats["MaxInUse"])
+ print(
+ f"ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB "
+ f"mmu:{max_mem_used:.1f}GB | {tag} | {txt}",
+ flush=True,
+ file=open(DBG_TRACE_FILENAME, "a"),
+ )
diff --git a/backends/gaudi/server/text_generation_server/utils/dist.py b/backends/gaudi/server/text_generation_server/utils/dist.py
new file mode 100644
index 00000000000..1c45713e8d8
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/dist.py
@@ -0,0 +1,66 @@
+import os
+import torch
+from torch.distributed import ProcessGroup
+from datetime import timedelta
+from loguru import logger
+
+# Tensor Parallelism settings
+RANK = int(os.getenv("RANK", "0"))
+WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
+MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8"))
+
+
+class FakeBarrier:
+ def wait(self):
+ pass
+
+
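+# Minimal stand-in for a torch.distributed process group: collectives return a
+# FakeBarrier immediately (allgather just copies the local tensor), so the
+# tensor-parallel code path also works without real communication.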
+class FakeGroup(ProcessGroup):
+ def __init__(self, rank, size):
+ self._rank = rank
+ self._size = size
+ super().__init__(rank, size)
+
+ def allreduce(self, *args, **kwargs):
+ return FakeBarrier()
+
+ def allgather(self, inputs, local_tensor, **kwargs):
+ assert (
+ len(inputs[0]) == len(local_tensor) == 1
+ ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
+ for input_ in inputs:
+ input_[0].data = local_tensor[0].data
+ return FakeBarrier()
+
+ def barrier(self, *args, **kwargs):
+ return FakeBarrier()
+
+ def size(self):
+ return self._size
+
+ def rank(self):
+ return self._rank
+
+ def _get_backend_name(self):
+ return "fake"
+
+
+def initialize_torch_distributed():
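+ # Returns (process_group, rank, world_size). Multi-card runs use the HCCL backend;
+ # single-process and DEBUG runs fall back to the FakeGroup stub above.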
+ if WORLD_SIZE == 1:
+ return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
+ else:
+ if os.getenv("DEBUG", None) == "1":
+ return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
+
+ if not torch.distributed.is_initialized():
+ # Call the init process.
+ torch.distributed.init_process_group(
+ backend="hccl",
+ world_size=WORLD_SIZE,
+ rank=RANK,
+ timeout=timedelta(seconds=120),
+ )
+ else:
+ logger.warning("torch.distributed is already initialized.")
+
+ return torch.distributed.group.WORLD, RANK, WORLD_SIZE
diff --git a/backends/gaudi/server/text_generation_server/utils/hub.py b/backends/gaudi/server/text_generation_server/utils/hub.py
new file mode 100644
index 00000000000..f9c476ac3cc
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/hub.py
@@ -0,0 +1,234 @@
+import time
+import os
+
+from datetime import timedelta
+from loguru import logger
+from pathlib import Path
+from typing import Optional, List
+
+from huggingface_hub import file_download, hf_api, HfApi, hf_hub_download
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from huggingface_hub.utils import (
+ LocalEntryNotFoundError,
+ EntryNotFoundError,
+ RevisionNotFoundError, # noqa # Import here to ease try/except in other part of the lib
+)
+
+WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
+HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
+
+
+def _cached_weight_files(
+ model_id: str, revision: Optional[str], extension: str
+) -> List[str]:
+ """Guess weight files from the cached revision snapshot directory"""
+ d = _get_cached_revision_directory(model_id, revision)
+ if not d:
+ return []
+ filenames = _weight_files_from_dir(d, extension)
+ return filenames
+
+
+def _weight_hub_files_from_model_info(
+ info: hf_api.ModelInfo, extension: str
+) -> List[str]:
+ return [
+ s.rfilename
+ for s in info.siblings
+ if s.rfilename.endswith(extension)
+ and len(s.rfilename.split("/")) == 1
+ and "arguments" not in s.rfilename
+ and "args" not in s.rfilename
+ and "training" not in s.rfilename
+ ]
+
+
+def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
+ # os.walk: do not iterate recursively, only scan the top-level directory;
+ # see _weight_hub_files_from_model_info, which applies the same constraint
+ # with the len(s.rfilename.split("/")) == 1 condition
+ root, _, files = next(os.walk(str(d)))
+ filenames = [
+ os.path.join(root, f)
+ for f in files
+ if f.endswith(extension)
+ and "arguments" not in f
+ and "args" not in f
+ and "training" not in f
+ ]
+ return filenames
+
+
+def _get_cached_revision_directory(
+ model_id: str, revision: Optional[str]
+) -> Optional[Path]:
+ if revision is None:
+ revision = "main"
+
+ repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(
+ file_download.repo_folder_name(repo_id=model_id, repo_type="model")
+ )
+
+ if not repo_cache.is_dir():
+ # No cache for this model
+ return None
+
+ refs_dir = repo_cache / "refs"
+ snapshots_dir = repo_cache / "snapshots"
+
+ # Resolve refs (for instance to convert main to the associated commit sha)
+ if refs_dir.is_dir():
+ revision_file = refs_dir / revision
+ if revision_file.exists():
+ with revision_file.open() as f:
+ revision = f.read()
+
+ # Check if revision folder exists
+ if not snapshots_dir.exists():
+ return None
+ cached_shas = os.listdir(snapshots_dir)
+ if revision not in cached_shas:
+ # No cache for this revision and we won't try to return a random revision
+ return None
+
+ return snapshots_dir / revision
+
+
+def weight_hub_files(
+ model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
+) -> List[str]:
+ """Get the weights filenames on the hub"""
+ api = HfApi()
+
+ if HF_HUB_OFFLINE:
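+ # Offline mode: resolve weight files from the local HF cache only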
+ filenames = _cached_weight_files(model_id, revision, extension)
+ else:
+ # Online case, fetch model info from the Hub
+ info = api.model_info(model_id, revision=revision)
+ filenames = _weight_hub_files_from_model_info(info, extension)
+
+ if not filenames:
+ raise EntryNotFoundError(
+ f"No {extension} weights found for model {model_id} and revision {revision}.",
+ None,
+ )
+
+ return filenames
+
+
+def try_to_load_from_cache(
+ model_id: str, revision: Optional[str], filename: str
+) -> Optional[Path]:
+ """Try to load a file from the Hugging Face cache"""
+
+ d = _get_cached_revision_directory(model_id, revision)
+ if not d:
+ return None
+
+ # Check if file exists in cache
+ cached_file = d / filename
+ return cached_file if cached_file.is_file() else None
+
+
+def weight_files(
+ model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
+) -> List[Path]:
+ """Get the local files"""
+ # Local model
+ d = Path(model_id)
+ if d.exists() and d.is_dir():
+ local_files = _weight_files_from_dir(d, extension)
+ if not local_files:
+ raise FileNotFoundError(
+ f"No local weights found in {model_id} with extension {extension}"
+ )
+ return [Path(f) for f in local_files]
+
+ try:
+ filenames = weight_hub_files(model_id, revision, extension)
+ except EntryNotFoundError as e:
+ if extension != ".safetensors":
+ raise e
+ # Try to see if there are pytorch weights
+ pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
+ # Change pytorch extension to safetensors extension
+ # It is possible that we have safetensors weights locally even though they are not on the
+ # hub if we converted weights locally without pushing them
+ filenames = [
+ f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames
+ ]
+
+ if WEIGHTS_CACHE_OVERRIDE is not None:
+ files = []
+ for filename in filenames:
+ p = Path(WEIGHTS_CACHE_OVERRIDE) / filename
+ if not p.exists():
+ raise FileNotFoundError(
+ f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
+ )
+ files.append(p)
+ return files
+
+ files = []
+ for filename in filenames:
+ cache_file = try_to_load_from_cache(
+ model_id, revision=revision, filename=filename
+ )
+ if cache_file is None:
+ raise LocalEntryNotFoundError(
+ f"File {filename} of model {model_id} not found in "
+ f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
+ f"Please run `text-generation-server download-weights {model_id}` first."
+ )
+ files.append(cache_file)
+
+ return files
+
+
+def download_weights(
+ filenames: List[str], model_id: str, revision: Optional[str] = None
+) -> List[Path]:
+ """Download the safetensors files from the hub"""
+
+ def download_file(fname, tries=5, backoff: int = 5):
+ local_file = try_to_load_from_cache(model_id, revision, fname)
+ if local_file is not None:
+ logger.info(f"File {fname} already present in cache.")
+ return Path(local_file)
+
+ for idx in range(tries):
+ try:
+ logger.info(f"Download file: {fname}")
+ stime = time.time()
+ local_file = hf_hub_download(
+ filename=fname,
+ repo_id=model_id,
+ revision=revision,
+ local_files_only=HF_HUB_OFFLINE,
+ )
+ logger.info(
+ f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - stime))}."
+ )
+ return Path(local_file)
+ except Exception as e:
+ if idx + 1 == tries:
+ raise e
+ logger.error(e)
+ logger.info(f"Retrying in {backoff} seconds")
+ time.sleep(backoff)
+ logger.info(f"Retry {idx + 1}/{tries - 1}")
+
+ # We do this instead of using tqdm because we want to parse the logs with the launcher
+ start_time = time.time()
+ files = []
+ for i, filename in enumerate(filenames):
+ file = download_file(filename)
+
+ elapsed = timedelta(seconds=int(time.time() - start_time))
+ remaining = len(filenames) - (i + 1)
+ eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0
+
+ logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}")
+ files.append(file)
+
+ return files
diff --git a/backends/gaudi/server/text_generation_server/utils/import_utils.py b/backends/gaudi/server/text_generation_server/utils/import_utils.py
new file mode 100644
index 00000000000..bdcfc9fa6e4
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/import_utils.py
@@ -0,0 +1,19 @@
+import torch
+
+
+def get_hpu_free_memory(device, memory_fraction):
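+ # Note: memory_fraction is currently unused; this returns the free HPU memory
+ # reported by the runtime.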
+ free_hpu_memory, _ = torch.hpu.mem_get_info()
+ return free_hpu_memory
+
+
+def synchronize_hpu(device):
+ torch.hpu.synchronize()
+
+
+def noop(*args, **kwargs):
+ pass
+
+
+empty_cache = noop
+synchronize = synchronize_hpu
+get_free_memory = get_hpu_free_memory
diff --git a/backends/gaudi/server/text_generation_server/utils/kernels.py b/backends/gaudi/server/text_generation_server/utils/kernels.py
new file mode 100644
index 00000000000..42745c7165f
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/kernels.py
@@ -0,0 +1,22 @@
+import importlib
+
+from loguru import logger
+from hf_kernels import load_kernel as hf_load_kernel
+
+from text_generation_server.utils.log import log_once
+
+
+def load_kernel(*, module: str, repo_id: str):
+ """
+ Load a kernel. First try to load it as the given module (e.g. for
+ local development), falling back to a locked Hub kernel.
+ """
+ try:
+ m = importlib.import_module(module)
+ log_once(logger.info, f"Using local module for `{module}`")
+ return m
+ except ModuleNotFoundError:
+ return hf_load_kernel(repo_id=repo_id)
+
+
+__all__ = ["load_kernel"]
diff --git a/backends/gaudi/server/text_generation_server/utils/log.py b/backends/gaudi/server/text_generation_server/utils/log.py
new file mode 100644
index 00000000000..4385c71ee96
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/log.py
@@ -0,0 +1,15 @@
+from functools import lru_cache
+from text_generation_server.utils.dist import RANK
+
+
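+# lru_cache(10) deduplicates repeated calls with the same arguments, so a given
+# message is only logged once while it remains in the cache.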
+@lru_cache(10)
+def log_once(log, msg: str, master=True):
+ if master:
+ log_master(log, msg)
+ else:
+ log(msg)
+
+
+def log_master(log, msg: str):
+ if RANK == 0:
+ log(msg)
diff --git a/backends/gaudi/server/text_generation_server/utils/logits_process.py b/backends/gaudi/server/text_generation_server/utils/logits_process.py
new file mode 100644
index 00000000000..c0fd6cbaeb3
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/logits_process.py
@@ -0,0 +1,610 @@
+import math
+import torch
+import habana_frameworks.torch.core as htcore
+
+from loguru import logger
+from typing import Dict
+from text_generation_server.pb.generate_pb2 import GrammarType
+
+from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.json_schema import build_regex_from_schema
+from functools import lru_cache
+from typing import List, Optional, DefaultDict
+import time
+
+from transformers import (
+ LogitsProcessor,
+ TemperatureLogitsWarper,
+ TopKLogitsWarper,
+ TopPLogitsWarper,
+ TypicalLogitsWarper,
+)
+
+mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
+
+
+class StaticWarper:
+ def __init__(
+ self,
+ temperature=1.0,
+ top_k=None,
+ top_p=None,
+ typical_p=None,
+ ):
+ self.warpers = []
+
+ if temperature is not None and temperature != 1.0:
+ temperature = float(temperature)
+ self.warpers.append(TemperatureLogitsWarper(temperature))
+ if top_k is not None and top_k != 0:
+ self.warpers.append(TopKLogitsWarper(top_k=top_k))
+ if top_p is not None and top_p < 1.0:
+ self.warpers.append(TopPLogitsWarper(top_p=top_p))
+ if typical_p is not None and typical_p < 1.0:
+ self.warpers.append(TypicalLogitsWarper(mass=typical_p))
+
+ self.hpu_graph = None
+ self.static_scores = None
+ self.static_warped_scores = None
+ self.static_next_logprob = None
+
+ def __call__(self, scores):
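+ # On the first call, capture the warper chain into an HPU graph using static
+ # input/output buffers; subsequent calls only copy new scores in and replay it.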
+ if self.hpu_graph is None:
+ self.static_scores = scores.clone().contiguous()
+ self.static_warped_scores = scores.clone().contiguous()
+ self.static_next_logprob = scores.clone().contiguous()
+ self.hpu_graph = htcore.hpu.HPUGraph()
+
+ with htcore.hpu.graph(self.hpu_graph):
+ local_scores = self.static_scores
+ for warper in self.warpers:
+ local_scores = warper(None, local_scores)
+
+ self.static_warped_scores.copy_(local_scores)
+ # Compute logprobs
+ self.static_next_logprob.copy_(
+ torch.log_softmax(self.static_warped_scores, -1)
+ )
+
+ self.static_scores.copy_(scores)
+ self.hpu_graph.replay()
+
+ return self.static_warped_scores, self.static_next_logprob
+
+
+@lru_cache(10)
+def static_warper(
+ temperature: Optional[float],
+ top_k: Optional[int],
+ top_p: Optional[float],
+ typical_p: Optional[float],
+) -> StaticWarper:
+ return StaticWarper(
+ temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+ )
+
+
+class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
+ This version allows for a separate value for each sample and runs inplace when possible.
+ It doesn't validate inputs.
+
+ Args:
+ repetition_penalty (`List[float]`):
+ The parameter for repetition penalty. 1.0 means no penalty. See [this
+ paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+ """
+
+ def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
+ self.penalty = penalty
+ self.penalty_tensor = torch.tensor(
+ penalty, dtype=dtype, device=device
+ ).unsqueeze(1)
+
+ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+ score = torch.gather(scores, 1, input_ids)
+
+ # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+ score = torch.where(
+ score < 0, score * self.penalty_tensor, score / self.penalty_tensor
+ )
+
+ scores.scatter_(1, input_ids, score)
+ return scores
+
+ def filter(self, indices):
+ self.penalty = [self.penalty[i] for i in indices]
+ if any([x != 1.0 for x in self.penalty]):
+ self.penalty_tensor = self.penalty_tensor[indices]
+ return self
+ return None
+
+
+class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
+ r"""
+ Frequency penalty as defined by OpenAI
+
+ Args:
+ penalty (`float`):
+ The parameter for frequency penalty. 0.0 means no penalty.
+ """
+
+ def __init__(self, penalty: float):
+ self.penalty = penalty
+
+ def __call__(
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+ ) -> torch.FloatTensor:
+ score = torch.gather(scores, 1, input_ids)
+ # if score < 0 then penalty has to be multiplied to reduce the previous token probability
+ score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
+ # set score to 0 where input_ids is a padding token
+ score *= input_ids.ne(0)
+
+ return scores.scatter_add_(1, input_ids, score)
+
+
+class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
+ r"""
+ Frequency penalty as defined by OpenAI in
+ https://platform.openai.com/docs/guides/text-generation/parameter-details
+
+ Args:
+ frequency_penalty (`List[float]`):
+ The parameter for frequency penalty. 0.0 means no penalty.
+ """
+
+ def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
+ self.penalty = penalty
+ self.penalty_tensor = torch.tensor(
+ penalty, dtype=dtype, device=device
+ ).unsqueeze(1)
+
+ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+ batch_size, input_size = input_ids.size()
+ vocab_size = scores.size(1)
+
+ # Calculate the frequency for each token so far
+ token_freq = torch.zeros(
+ batch_size, vocab_size, dtype=scores.dtype, device=scores.device
+ )
+ token_freq.scatter_add_(
+ 1,
+ input_ids,
+ torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device),
+ )
+ token_freq /= input_size
+
+ # Apply the frequency penalty to logits
+ scores -= token_freq * self.penalty_tensor
+ return scores
+
+ def filter(self, indices):
+ self.penalty = [self.penalty[i] for i in indices]
+ if any([x != 0.0 for x in self.penalty]):
+ self.penalty_tensor = self.penalty_tensor[indices]
+ return self
+ return None
+
+
+class HeterogeneousTemperatureLogitsWarper:
+ r"""
+ [`LogitsProcessor`] for temperature (exponential scaling output probability distribution).
+ This version allows for a separate value for each sample and runs inplace when possible.
+ It doesn't validate inputs.
+
+ Args:
+ temperature (`float`):
+ The value used to module the logits distribution.
+ """
+
+ def __init__(
+ self, temperature: List[float], dtype: torch.dtype, device: torch.device
+ ):
+ self.temperature = temperature
+ self.temperature_tensor = torch.tensor(
+ temperature, dtype=dtype, device=device
+ ).unsqueeze(1)
+
+ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+ scores.div_(self.temperature_tensor)
+ return scores
+
+ def filter(self, indices):
+ self.temperature = [self.temperature[i] for i in indices]
+ if any([x != 1.0 for x in self.temperature]):
+ self.temperature_tensor = self.temperature_tensor[indices]
+ return self
+ return None
+
+
+class HeterogeneousTopPLogitsWarper(LogitsProcessor):
+ """
+ [`LogitsWarper`] that performs top-p filtering, i.e. restricting to the smallest set of top tokens whose cumulative probability reaches top_p.
+ This version allows for a separate value for each sample and runs inplace when possible.
+ It doesn't validate inputs.
+
+ Args:
+ top_p (`float`):
+ If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+ higher are kept for generation.
+ filter_value (`float`, *optional*, defaults to `-float("Inf")`):
+ All filtered values will be set to this float value.
+ min_tokens_to_keep (`int`, *optional*, defaults to 1):
+ Minimum number of tokens that cannot be filtered.
+ """
+
+ def __init__(
+ self,
+ top_p: List[float],
+ dtype: torch.dtype,
+ device: torch.device,
+ filter_value: float = -math.inf,
+ min_tokens_to_keep: int = 1,
+ ):
+ self.top_p = top_p
+ self.top_p_opposite = 1 - torch.tensor(
+ top_p, dtype=dtype, device=device
+ ).unsqueeze(1)
+ self.filter_value = filter_value
+ self.min_tokens_to_keep = min_tokens_to_keep
+
+ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+ sorted_logits, sorted_indices = torch.sort(scores, descending=False)
+ probs = sorted_logits.softmax(dim=-1)
+ # This is way faster for some reason
+ for i in range(probs.shape[0]):
+ probs[i] = probs[i].cumsum(dim=-1)
+
+ # Remove the low-probability tail: tokens whose cumulative probability is <= 1 - top_p (tokens with 0 are kept)
+ sorted_indices_to_remove = probs <= self.top_p_opposite
+ # Keep at least min_tokens_to_keep
+ sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
+
+ # scatter sorted tensors to original indexing
+ indices_to_remove = sorted_indices_to_remove.scatter(
+ 1, sorted_indices, sorted_indices_to_remove
+ )
+ warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
+
+ return warped_scores
+
+ def filter(self, indices):
+ self.top_p = [self.top_p[i] for i in indices]
+ if any([x < 1.0 for x in self.top_p]):
+ self.top_p_opposite = self.top_p_opposite[indices]
+ return self
+ return None
+
+
+class HeterogeneousTopKLogitsWarper(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements.
+ This version allows for a separate value for each sample and runs inplace when possible.
+ It doesn't validate inputs.
+
+ Args:
+ top_k (`int`):
+ The number of highest probability vocabulary tokens to keep for top-k-filtering.
+ filter_value (`float`, *optional*, defaults to `-float("Inf")`):
+ All filtered values will be set to this float value.
+ min_tokens_to_keep (`int`, *optional*, defaults to 1):
+ Minimum number of tokens that cannot be filtered.
+ """
+
+ def __init__(
+ self,
+ top_k: List[int],
+ device: torch.device,
+ filter_value: float = -math.inf,
+ min_tokens_to_keep: int = 1,
+ ):
+ self.top_k = top_k
+ self.max_top_k = max(top_k)
+ # value - 1 because we will use top_k to index, and Python uses 0-based numbering
+ self.top_k_tensor = torch.tensor(
+ [max(x - 1, min_tokens_to_keep - 1) for x in top_k],
+ dtype=torch.int64,
+ device=device,
+ ).unsqueeze(1)
+
+ # 0 is a special value that disables top_k warping for this member of the batch
+ disabled = [x == 0 for x in top_k]
+
+ if any(disabled):
+ self.top_k_disabled_mask = torch.tensor(
+ disabled, dtype=torch.bool, device=device
+ ).view(-1, 1)
+ else:
+ self.top_k_disabled_mask = None
+
+ self.filter_value = filter_value
+
+ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+ # If max_top_k is larger than the vocab size, we need to clamp it or the warper will fail
+ if scores.size(-1) < self.max_top_k:
+ max_top_k = scores.size(-1)
+ top_k = torch.clamp_max(self.top_k_tensor, max_top_k)
+ else:
+ max_top_k = self.max_top_k
+ top_k = self.top_k_tensor
+
+ # Get the kth score for each member of the batch
+ kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
+
+ # Mask member of kth_scores that do not want to use top_k warping
+ if self.top_k_disabled_mask is not None:
+ kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value)
+
+ # Remove all tokens with a probability less than the last token of the top-k
+ indices_to_remove = scores < kth_scores
+ scores.masked_fill_(indices_to_remove, self.filter_value)
+ return scores
+
+ def filter(self, indices):
+ self.top_k = [self.top_k[i] for i in indices]
+ disabled = [x == 0 for x in self.top_k]
+
+ if not all(disabled):
+ self.top_k_tensor = self.top_k_tensor[indices]
+ self.max_top_k = max(self.top_k)
+
+ if self.top_k_disabled_mask is not None:
+ self.top_k_disabled_mask = (
+ self.top_k_disabled_mask[indices] if any(disabled) else None
+ )
+
+ return self
+ return None
+
+
+class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] that performs typical decoding. See [Typical Decoding for Natural Language
+ Generation](https://arxiv.org/abs/2202.00666) for more information.
+ This version allows for a separate value for each sample and runs inplace when possible.
+ It doesn't validate inputs.
+
+ Args:
+ mass (`float`):
+ Value of typical_p between 0 and 1 inclusive, defaults to 0.9.
+ filter_value (`float`, *optional*, defaults to `-float("Inf")`):
+ All filtered values will be set to this float value.
+ min_tokens_to_keep (`int`, *optional*, defaults to 1):
+ Minimum number of tokens that cannot be filtered.
+ """
+
+ def __init__(
+ self,
+ mass: List[float],
+ dtype: torch.dtype,
+ device: torch.device,
+ filter_value: float = -math.inf,
+ min_tokens_to_keep: int = 1,
+ ):
+ self.mass = mass
+ self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
+
+ # 1 is a special value that disables typical_p warping for this member of the batch
+ disabled = [x == 1.0 for x in mass]
+
+ if any(disabled):
+ self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device)
+ else:
+ self.disabled_mask = None
+
+ self.filter_value = filter_value
+ self.min_tokens_to_keep = min_tokens_to_keep
+
+ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+ # calculate entropy
+ normalized = torch.nn.functional.log_softmax(scores, dim=-1)
+ p = torch.exp(normalized)
+ ent = -(normalized * p).nansum(-1, keepdim=True)
+
+ # shift and sort
+ shifted_scores = torch.abs((-normalized) - ent)
+ sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
+ sorted_logits = scores.gather(-1, sorted_indices)
+ probs = sorted_logits.softmax(dim=-1)
+ # This is way faster for some reason
+ for i in range(probs.shape[0]):
+ probs[i] = probs[i].cumsum(dim=-1)
+
+ # Remove tokens with cumulative mass above the threshold
+ last_ind = (probs < self.mass_tensor).sum(dim=1)
+ last_ind[last_ind < 0] = 0
+
+ if self.disabled_mask is not None:
+ last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)
+
+ sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
+ 1, last_ind.view(-1, 1)
+ )
+ if self.min_tokens_to_keep > 1:
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+ sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
+ indices_to_remove = sorted_indices_to_remove.scatter(
+ 1, sorted_indices, sorted_indices_to_remove
+ )
+
+ warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
+
+ return warped_scores
+
+ def filter(self, indices):
+ self.mass = [self.mass[i] for i in indices]
+ disabled = [x == 1.0 for x in self.mass]
+
+ if not all(disabled):
+ self.mass_tensor = self.mass_tensor[indices]
+
+ if self.disabled_mask is not None:
+ self.disabled_mask = (
+ self.disabled_mask[indices] if any(disabled) else None
+ )
+
+ return self
+ return None
+
+
+class HeterogeneousProcessorWrapper(LogitsProcessor):
+ r"""
+ A wrapper for logit warpers or processors without heterogeneous parameter support.
+ Args:
+ processors (`Dict[int, LogitsProcessor]`):
+ A mapping of sample indices to logit warpers or processors, to be run sequentially.
+ """
+
+ def __init__(
+ self,
+ processors: Dict[int, LogitsProcessor],
+ ):
+ self.processors = processors
+
+ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+ for i, processor in self.processors.items():
+ scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])
+ return scores
+
+ def filter(self, indices):
+ new_processors = {}
+ for i, idx in enumerate(indices):
+ if idx in self.processors:
+ new_processors[i] = self.processors[idx]
+
+ if new_processors:
+ self.processors = new_processors
+ return self
+ return None
+
+
+class GrammarLogitProcessor(LogitsProcessor):
+ fsm_state: DefaultDict[int, int]
+ fsm: RegexFSM
+
+ def __init__(self, tokenizer, device, grammar, grammar_type):
+ self.device = device
+ self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
+ self.fsm = GrammarLogitProcessor._cached_compile_fsm(
+ grammar_type, grammar, self.tokenizer
+ )
+
+ def __call__(
+ self,
+ logits: torch.Tensor,
+ fsm_grammar_state: int,
+ ):
+ if fsm_grammar_state == -1 or self.fsm is None:
+ return logits
+ allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
+ mask = torch.full_like(logits, -math.inf)
+ mask[:, allowed_tokens] = 0
+ biased_scores = logits + mask
+ return biased_scores
+
+ def advance(self, next_token_id, fsm_grammar_state):
+ return GrammarLogitProcessor._advance(
+ next_token_id, fsm_grammar_state, self.fsm
+ )
+
+ @staticmethod
+ def _advance(next_token_id, fsm_grammar_state, fsm):
+ if fsm_grammar_state == -1:
+ return fsm_grammar_state
+ return fsm.next_state(fsm_grammar_state, next_token_id)
+
+ # TODO: move grammar compilation into the router
+ @staticmethod
+ @lru_cache(maxsize=32, typed=True)
+ def _cached_compile_fsm(grammar_type, schema, tokenizer):
+ start_time = time.time()
+ if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
+ schema = build_regex_from_schema(schema)
+ elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
+ pass # schema is already a regex just here for clarity
+ fsm = RegexFSM(schema, tokenizer)
+ logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
+ return fsm
+
+ @staticmethod
+ @lru_cache(maxsize=32, typed=True)
+ def _cached_adapt_tokenizer(tokenizer):
+ """Adapt tokenizer to work with the FSM.
+
+        The API of Outlines tokenizers is slightly different from that of
+        `transformers`. In addition, we need to handle the missing spaces in
+        Llama's tokenizer to be able to compile FSMs for this model.
+
+ """
+ start_time = time.time()
+ tokenizer.vocabulary = tokenizer.get_vocab()
+ tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+ def convert_token_to_string(token: str) -> str:
+ from transformers.file_utils import SPIECE_UNDERLINE
+
+ string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces in HF's Llama tokenizers
+ if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+ return " " + string
+
+ return string
+
+ tokenizer.convert_token_to_string = convert_token_to_string
+ logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
+ return tokenizer
+
+
+class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
+ def __init__(self, tokenizer, device, grammars, grammar_types):
+ self.device = device
+ self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
+ self.fsms = []
+ for grammar, grammar_type in zip(grammars, grammar_types):
+ if len(grammar) == 0:
+ self.fsms.append(None)
+ continue
+ fsm = GrammarLogitProcessor._cached_compile_fsm(
+ grammar_type, grammar, self.tokenizer
+ )
+ self.fsms.append(fsm)
+
+ def __call__(
+ self,
+ logits: torch.Tensor,
+ fsm_grammar_states: List[int],
+ ):
+ mask = torch.full_like(logits, -math.inf)
+ for i in range(logits.shape[0]):
+ fsm = self.fsms[i]
+ if fsm is None:
+ continue
+ allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
+ mask[i, allowed_tokens] = 0
+ logits[i] += mask[i]
+ return logits
+
+ def advance_batch(self, next_token_ids, fsm_grammar_states):
+ return [
+ GrammarLogitProcessor._advance(
+ next_token_ids[i], fsm_grammar_states[i], self.fsms[i]
+ )
+ for i in range(len(next_token_ids))
+ ]
+
+ def advance_at_index(self, next_token_id, fsm_grammar_state, index):
+ if self.fsms[index] is None:
+ return fsm_grammar_state
+ return GrammarLogitProcessor._advance(
+ next_token_id, fsm_grammar_state, self.fsms[index]
+ )
+
+ def filter(self, indices):
+ new_fsms = []
+ for i in indices:
+ new_fsms.append(self.fsms[i])
+ self.fsms = new_fsms
+ return self
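
As a quick sanity check on the heterogeneous warpers added above, here is a minimal sketch applying `HeterogeneousTopKLogitsWarper` to a toy batch. The logits and per-request `top_k` values are illustrative, and the import path assumes the module added by this diff is on the Python path.

```python
import torch

from text_generation_server.utils.logits_process import HeterogeneousTopKLogitsWarper

# Toy batch of 3 requests over a 6-token vocabulary (values are illustrative).
scores = torch.tensor(
    [
        [0.1, 2.0, 0.3, 1.5, 0.0, -1.0],
        [1.0, 1.1, 1.2, 1.3, 1.4, 1.5],
        [5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    ]
)
input_ids = torch.zeros((3, 1), dtype=torch.long)  # unused by this warper

# One top_k value per request; 0 disables top-k filtering for that request.
warper = HeterogeneousTopKLogitsWarper(top_k=[2, 3, 0], device=scores.device)
warped = warper(input_ids, scores.clone())  # clone: the warper mutates scores in place

# Rows 0 and 1 keep only their 2 and 3 highest logits; row 2 is left untouched.
print((warped > -float("inf")).sum(dim=-1))  # tensor([2, 3, 6])
```

The `filter(indices)` methods follow the same pattern across the warpers in this file: they shrink the per-request state when requests leave the batch and return `None` once no remaining request needs the warper.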
diff --git a/backends/gaudi/server/text_generation_server/utils/merges/strategies.py b/backends/gaudi/server/text_generation_server/utils/merges/strategies.py
new file mode 100644
index 00000000000..cb39cde1f3e
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/merges/strategies.py
@@ -0,0 +1,220 @@
+import copy
+from abc import ABC
+from collections import defaultdict
+from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union
+from text_generation_server.utils.merges.utils import (
+ calculate_majority_sign_mask,
+ disjoint_merge,
+ prune,
+)
+import torch
+
+if TYPE_CHECKING:
+ from text_generation_server.adapters.lora import LoraConfig
+ from text_generation_server.utils.adapter import ModuleMap
+
+
+class AdapterParameters:
+ def __init__(
+ self, adapter_ids, weights, merge_strategy, density, majority_sign_method
+ ):
+ self.adapter_ids = adapter_ids
+ self.weights = weights
+ self.merge_strategy = merge_strategy
+ self.density = density
+ self.majority_sign_method = majority_sign_method
+
+
+def _apply_weights(
+ tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor
+) -> torch.Tensor:
+ if isinstance(tensors, torch.Tensor):
+ t = tensors
+ else:
+ t = torch.stack(tensors, dim=0)
+
+ # element-wise weighting of each task tensor
+ # need to unsqueeze weights to match task tensor dimensions
+ # for multiplication to apply element-wise
+ while len(t.shape) > len(w.shape):
+ w = w.unsqueeze(-1)
+ return t * w
+
+
+class MergeStrategy(ABC):
+ def merge(
+ self, task_tensors: List[torch.Tensor], weights: torch.Tensor
+ ) -> torch.Tensor:
+ raise NotImplementedError()
+
+
+class LinearMerge(MergeStrategy):
+ def __init__(self, **kwargs):
+ pass
+
+ def merge(
+ self, task_tensors: List[torch.Tensor], weights: torch.Tensor
+ ) -> torch.Tensor:
+ weighted_task_tensors = _apply_weights(task_tensors, weights)
+ return weighted_task_tensors.sum(dim=0)
+
+
+class TiesMerge(MergeStrategy):
+ def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
+ self.density = density
+ self.majority_sign_method = majority_sign_method
+
+ def merge(
+ self, task_tensors: List[torch.Tensor], weights: torch.Tensor
+ ) -> torch.Tensor:
+ # sparsify
+ task_tensors = [
+ prune(tensor, self.density, method="magnitude") for tensor in task_tensors
+ ]
+ task_tensors = torch.stack(task_tensors, dim=0)
+
+ # elect sign before applying weights
+ majority_sign_mask = calculate_majority_sign_mask(
+ task_tensors, method=self.majority_sign_method
+ )
+ weighted_task_tensors = _apply_weights(task_tensors, weights)
+
+ # disjoint merge
+ return disjoint_merge(weighted_task_tensors, majority_sign_mask)
+
+
+class DareLinearMerge(MergeStrategy):
+ def __init__(self, density: float, **kwargs):
+ self.density = density
+
+ def merge(
+ self, task_tensors: List[torch.Tensor], weights: torch.Tensor
+ ) -> torch.Tensor:
+ # sparsify
+ task_tensors = [
+ prune(tensor, self.density, method="random", rescale=True)
+ for tensor in task_tensors
+ ]
+ weighted_task_tensors = _apply_weights(task_tensors, weights)
+ return weighted_task_tensors.sum(dim=0)
+
+
+class DareTiesMerge(MergeStrategy):
+ def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
+ self.density = density
+ self.majority_sign_method = majority_sign_method
+
+ def merge(
+ self, task_tensors: List[torch.Tensor], weights: torch.Tensor
+ ) -> torch.Tensor:
+ # sparsify
+ task_tensors = [
+ prune(tensor, self.density, method="random", rescale=True)
+ for tensor in task_tensors
+ ]
+ task_tensors = torch.stack(task_tensors, dim=0)
+
+ # elect sign before applying weights
+ majority_sign_mask = calculate_majority_sign_mask(
+ task_tensors, method=self.majority_sign_method
+ )
+ weighted_task_tensors = _apply_weights(task_tensors, weights)
+
+ # disjoint merge
+ mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask)
+ return mixed_task_tensors
+
+
+strategy_registry: Dict[str, Type[MergeStrategy]] = {
+ "linear": LinearMerge,
+ "ties": TiesMerge,
+ "dare_linear": DareLinearMerge,
+ "dare_ties": DareTiesMerge,
+}
+
+
+def merge_adapters(
+ adapters: List[Tuple["ModuleMap", "LoraConfig"]],
+ merge_params: AdapterParameters,
+) -> Tuple["ModuleMap", "LoraConfig"]:
+ # strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower()
+ strategy_name = "linear"
+
+ weights = merge_params.weights
+ if not weights:
+ weights = torch.ones(len(adapters))
+ else:
+ weights = torch.tensor(weights)
+
+ merge_config = {
+ "density": merge_params.density,
+ # "majority_sign_method": MajoritySignMethodEnum.Name(
+ # merge_params.majority_sign_method
+ # ).lower(),
+ "majority_sign_method": "total",
+ }
+ merge_strategy = strategy_registry[strategy_name](**merge_config)
+
+ module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict(
+ lambda: defaultdict(lambda: defaultdict(list))
+ )
+ lora_configs = []
+ weight_name_to_adapter_idx = defaultdict(list)
+
+ # input is list of (module_map, lora_config) tuples
+ # convert into dict[k][param_name] -> list of tensors
+ for idx, (module_map, lora_config) in enumerate(adapters):
+ for weight_name, data in module_map.items():
+ weight_name_to_adapter_idx[weight_name].append(idx)
+ for k, (param_data, param_name) in data.items():
+ module_maps[weight_name][k][param_name].append(param_data)
+ lora_configs.append(lora_config)
+
+ # validate lora configs are compatible
+ _validate_lora_configs(lora_configs)
+
+ # merge tensors for each module such that we have a single ModuleMap:
+ # dict[k] -> merged tensor
+ merged_module_map: "ModuleMap" = defaultdict(dict)
+ for weight_name, data in module_maps.items():
+ indices = weight_name_to_adapter_idx[weight_name]
+ param_weights = weights[indices]
+ for k, param_data in data.items():
+ for param_name, tensors in param_data.items():
+ merged_tensor = merge_strategy.merge(tensors, param_weights)
+ merged_module_map[weight_name][k] = (merged_tensor, param_name)
+
+ # merge lora configs
+ merged_lora_config = _merge_lora_configs(lora_configs)
+
+ return merged_module_map, merged_lora_config
+
+
+def _validate_lora_configs(lora_configs: List["LoraConfig"]):
+ # check that all configs have the same rank
+ ranks = set(lora_config.r for lora_config in lora_configs)
+ if len(ranks) > 1:
+ raise ValueError(
+ f"unable to merge adapters, lora configs have different ranks: {ranks}"
+ )
+
+ if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs):
+ raise ValueError(
+ "unable to merge adapters, lora configs have no target modules"
+ )
+
+
+def _merge_lora_configs(lora_configs: List["LoraConfig"]) -> "LoraConfig":
+ merged_lora_config = copy.copy(lora_configs[0])
+
+ # merge target modules as a union operation
+ merged_target_modules = sorted(
+ set(
+ module
+ for lora_config in lora_configs
+ for module in lora_config.target_modules
+ )
+ )
+ merged_lora_config.target_modules = merged_target_modules
+
+ return merged_lora_config
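
A minimal sketch of how the merge strategies above combine per-adapter task tensors; the tensors, density, and weights are toy values, and the import path assumes the module added by this diff.

```python
import torch

from text_generation_server.utils.merges.strategies import LinearMerge, TiesMerge

# Two toy task tensors for the same parameter, merged with equal weights.
task_tensors = [
    torch.tensor([[0.4, -0.2], [0.0, 1.0]]),
    torch.tensor([[0.1, 0.3], [-0.5, 1.0]]),
]
weights = torch.tensor([0.5, 0.5])

# Plain element-wise weighted sum of the task tensors.
print(LinearMerge().merge(task_tensors, weights))

# Sparsify by magnitude, elect a per-entry majority sign, then disjoint-merge.
print(TiesMerge(density=0.5, majority_sign_method="total").merge(task_tensors, weights))
```

Note that `merge_adapters` currently hard-codes the `linear` strategy and the `total` sign method; the registry and the commented-out enum lookups suggest the other strategies are meant to be wired up to request parameters later.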
diff --git a/backends/gaudi/server/text_generation_server/utils/merges/utils.py b/backends/gaudi/server/text_generation_server/utils/merges/utils.py
new file mode 100644
index 00000000000..d9ad3278a54
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/merges/utils.py
@@ -0,0 +1,108 @@
+# coding=utf-8
+# From: https://github.com/huggingface/peft/pull/1364
+# Copyright 2024-present the HuggingFace Inc. team.
+# Modifications by Predibase, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Literal
+
+import torch
+
+
+def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor:
+ """
+ Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction
+ `density`.
+
+ Args:
+ tensor (`torch.Tensor`):The tensor to prune.
+ density (`float`):The fraction of values to preserve. Should be in [0,1].
+ """
+ mask = torch.zeros_like(tensor).reshape(-1)
+ k = int(density * tensor.reshape(-1).shape[0])
+ top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True)
+ mask[top_k[1]] = 1
+ return tensor * mask.reshape(tensor.shape)
+
+
+def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
+ """
+    Randomly prune values of the task tensor, keeping each value with probability `density`, and optionally
+    rescale the result to preserve the expected value of the original tensor.
+
+ Args:
+ tensor (`torch.Tensor`):The tensor to prune.
+ density (`float`):The fraction of values to preserve. Should be in [0,1].
+ rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor.
+ """
+ mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))
+ pruned_tensor = tensor * mask
+ if rescale:
+        pruned_tensor = torch.div(input=pruned_tensor, other=density)
+ return pruned_tensor
+
+
+def prune(
+ tensor: torch.Tensor,
+ density: float,
+ method: Literal["magnitude", "random"],
+ rescale: bool = False,
+) -> torch.Tensor:
+ """
+ Prune the values of task tensors based on the `method`.
+
+ Args:
+ tensor (`torch.Tensor`):The tensor to prune.
+ density (`float`):The fraction of values to preserve. Should be in [0,1].
+ method (`str`):The method to use to prune. Should be one of ["magnitude", "random"].
+ rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor.
+ """
+ if density >= 1:
+ return tensor
+ elif density < 0:
+        raise ValueError(f"Density should be >= 0, got {density}")
+ if method == "magnitude":
+ return magnitude_based_pruning(tensor, density)
+ elif method == "random":
+ return random_pruning(tensor, density, rescale=rescale)
+ else:
+ raise ValueError(f"Unknown method {method}")
+
+
+def calculate_majority_sign_mask(
+ tensor: torch.Tensor, method: Literal["total", "frequency"] = "total"
+):
+ """
+ Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0.
+
+ Args:
+ tensor (`torch.Tensor`):The tensor to get the mask from.
+ method (`str`):The method to use to get the mask. Should be one of ["total", "frequency"].
+ """
+
+ sign = tensor.sign()
+ if method == "total":
+ sign_magnitude = (sign * tensor.abs()).sum(dim=0)
+ elif method == "frequency":
+ sign_magnitude = sign.sum(dim=0)
+ else:
+ raise RuntimeError(f'Unimplemented mask method "{method}"')
+ majority_sign = torch.where(sign_magnitude >= 0, 1, -1)
+ return sign == majority_sign
+
+
+def disjoint_merge(task_tensors, majority_sign_mask):
+ mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0)
+ num_params_preserved = majority_sign_mask.sum(dim=0)
+ return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0)
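
To make the pruning and sign-election helpers above concrete, here is a small worked example on toy tensors (all numbers illustrative):

```python
import torch

from text_generation_server.utils.merges.utils import (
    calculate_majority_sign_mask,
    disjoint_merge,
    prune,
)

# Two toy task tensors stacked on dim 0, as TiesMerge does after pruning.
stacked = torch.stack(
    [
        torch.tensor([1.0, -2.0, 0.5]),
        torch.tensor([-1.5, -1.0, 0.0]),
    ]
)

# Magnitude pruning keeps int(density * numel) = 1 entry of the first tensor.
print(prune(stacked[0], density=0.5, method="magnitude"))  # tensor([ 0., -2.,  0.])

# Majority sign per position, weighted by magnitude ("total" method).
mask = calculate_majority_sign_mask(stacked, method="total")

# Entries disagreeing with the majority sign are dropped; the rest are summed
# and divided by the number of preserved values per position.
print(disjoint_merge(stacked, mask))  # tensor([-1.5000, -1.5000,  0.5000])
```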
diff --git a/backends/gaudi/server/text_generation_server/utils/peft.py b/backends/gaudi/server/text_generation_server/utils/peft.py
new file mode 100644
index 00000000000..d49e73f0096
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/peft.py
@@ -0,0 +1,68 @@
+import os
+from typing import Union
+from loguru import logger
+import torch
+
+from transformers import AutoTokenizer
+from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
+
+
+def download_and_unload_peft(model_id, revision, trust_remote_code):
+ torch_dtype = torch.float16
+
+ logger.info("Trying to load a Peft model. It might take a while without feedback")
+ try:
+ model = AutoPeftModelForCausalLM.from_pretrained(
+ model_id,
+ revision=revision,
+ torch_dtype=torch_dtype,
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=True,
+ )
+ except Exception:
+ model = AutoPeftModelForSeq2SeqLM.from_pretrained(
+ model_id,
+ revision=revision,
+ torch_dtype=torch_dtype,
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=True,
+ )
+ logger.info("Peft model detected.")
+ logger.info("Merging the lora weights.")
+
+ base_model_id = model.peft_config["default"].base_model_name_or_path
+
+ model = model.merge_and_unload()
+
+ os.makedirs(model_id, exist_ok=True)
+ cache_dir = model_id
+ logger.info(f"Saving the newly created merged model to {cache_dir}")
+ tokenizer = AutoTokenizer.from_pretrained(
+ base_model_id, trust_remote_code=trust_remote_code
+ )
+ model.save_pretrained(cache_dir, safe_serialization=True)
+ model.config.save_pretrained(cache_dir)
+ tokenizer.save_pretrained(cache_dir)
+
+
+def download_peft(
+ model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool
+):
+ torch_dtype = torch.float16
+ try:
+ _model = AutoPeftModelForCausalLM.from_pretrained(
+ model_id,
+ revision=revision,
+ torch_dtype=torch_dtype,
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=True,
+ )
+ except Exception:
+ _model = AutoPeftModelForSeq2SeqLM.from_pretrained(
+ model_id,
+ revision=revision,
+ torch_dtype=torch_dtype,
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=True,
+ )
+ logger.info("Peft model downloaded.")
diff --git a/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py b/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py
new file mode 100644
index 00000000000..c227d30f512
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py
@@ -0,0 +1,24 @@
+from typing import Optional
+
+SUPPORT_CHUNKING: Optional[bool] = None
+MAX_PREFILL_TOKENS: Optional[int] = None
+
+
+def set_support_chunking(support_chunking: bool):
+ global SUPPORT_CHUNKING
+ SUPPORT_CHUNKING = support_chunking
+
+
+def get_support_chunking() -> bool:
+ global SUPPORT_CHUNKING
+ return SUPPORT_CHUNKING
+
+
+def set_max_prefill_tokens(max_prefill_tokens: int):
+ global MAX_PREFILL_TOKENS
+ MAX_PREFILL_TOKENS = max_prefill_tokens
+
+
+def get_max_prefill_tokens() -> int:
+ global MAX_PREFILL_TOKENS
+ return MAX_PREFILL_TOKENS
diff --git a/backends/gaudi/server/text_generation_server/utils/quantization.py b/backends/gaudi/server/text_generation_server/utils/quantization.py
new file mode 100644
index 00000000000..192963c46cf
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/quantization.py
@@ -0,0 +1,167 @@
+import json
+import os
+from dataclasses import dataclass
+from typing import Optional, List
+
+from huggingface_hub import hf_hub_download
+from text_generation_server.utils.weights import (
+ WeightsLoader,
+)
+
+
+# TODO: Split this config to have a single config type per quant method
+@dataclass
+class _QuantizerConfig:
+ bits: int
+ checkpoint_format: Optional[str]
+ desc_act: bool
+ groupsize: int
+ quant_method: str
+ sym: bool
+ weight_block_size: Optional[List[int]]
+ modules_to_not_convert: List[str]
+
+
+@dataclass
+class _FP8QuantizerConfig:
+ activation_scale_ub: float
+
+
+def _get_config_json(model_id: str, revision: Optional[str], filename: str):
+    if os.path.exists(model_id):
+ filename = os.path.join(model_id, filename)
+ else:
+ filename = hf_hub_download(model_id, filename=filename, revision=revision)
+ with open(filename, "r") as f:
+ return json.load(f)
+
+
+# We should probably do this with Pydantic JSON deserialization,
+# but for now we'll stay close to the old _set_gptq_params.
+def _get_quantizer_config(model_id, revision):
+ bits = 4
+ groupsize = -1
+ quant_method = "gptq"
+ checkpoint_format = None
+ sym = False
+ desc_act = False
+ weight_block_size = None
+ modules_to_not_convert = []
+
+ filename = "config.json"
+ try:
+ data = _get_config_json(model_id, revision, filename)
+ # FP8 config
+ if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
+ return _FP8QuantizerConfig(
+ activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
+ )
+ weight_block_size = data["quantization_config"].get("weight_block_size", None)
+
+ if "zero_point" in data["quantization_config"]:
+ sym = not data["quantization_config"]["zero_point"]
+ quant_method = "awq"
+ elif "sym" in data["quantization_config"]:
+ sym = data["quantization_config"]["sym"]
+
+ bits = data["quantization_config"]["bits"]
+ groupsize = data["quantization_config"]["group_size"]
+ # Order is important here, desc_act is missing on some real models
+ quant_method = data["quantization_config"]["quant_method"]
+ checkpoint_format = data["quantization_config"].get("checkpoint_format")
+ desc_act = data["quantization_config"].get("desc_act", False)
+ modules_to_not_convert = data["quantization_config"].get(
+ "modules_to_not_convert", []
+ )
+ if modules_to_not_convert is None:
+ modules_to_not_convert = []
+ except Exception:
+ filename = "quantize_config.json"
+ try:
+ data = _get_config_json(model_id, revision, filename)
+ bits = data["bits"]
+ groupsize = data["group_size"]
+
+ if "zero_point" in data:
+ sym = not data["zero_point"]
+ quant_method = "awq"
+ elif "sym" in data:
+ sym = data["sym"]
+
+ desc_act = data["desc_act"]
+ if "version" in data and data["version"] == "GEMM":
+ quant_method = "awq"
+ except Exception:
+ filename = "quant_config.json"
+ try:
+ data = _get_config_json(model_id, revision, filename)
+ bits = data["w_bit"]
+ groupsize = data["q_group_size"]
+ desc_act = data["desc_act"]
+ if "version" in data and data["version"] == "GEMM":
+ quant_method = "awq"
+ except Exception:
+ pass
+
+ return _QuantizerConfig(
+ bits=bits,
+ groupsize=groupsize,
+ quant_method=quant_method,
+ checkpoint_format=checkpoint_format,
+ sym=sym,
+ desc_act=desc_act,
+ weight_block_size=weight_block_size,
+ modules_to_not_convert=modules_to_not_convert,
+ )
+
+
+def get_loader(
+ quantize: Optional[str], model_id: str, revision: Optional[str]
+) -> WeightsLoader:
+ if quantize == "compressed-tensors":
+ config = _get_config_json(model_id, revision, "config.json")
+ from text_generation_server.layers.compressed_tensors import (
+ CompressedTensorsLoader,
+ )
+
+ return CompressedTensorsLoader(config)
+ quantizer_config = _get_quantizer_config(model_id, revision)
+ if quantize in {"awq", "gptq"}:
+ from text_generation_server.layers.gptq import GPTQWeightsLoader
+
+ # TODO: improve check once we have one config type per quantize value
+ if not isinstance(quantizer_config, _QuantizerConfig):
+ raise ValueError(
+ f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
+ )
+
+ return GPTQWeightsLoader(
+ bits=quantizer_config.bits,
+ desc_act=quantizer_config.desc_act,
+ groupsize=quantizer_config.groupsize,
+ quant_method=quantizer_config.quant_method,
+ quantize=quantize,
+ sym=quantizer_config.sym,
+ modules_to_not_convert=quantizer_config.modules_to_not_convert,
+ )
+ elif quantize == "fp8" or quantize is None:
+ from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
+
+ # Since the default for the quantize config is _QuantizerConfig,
+ # we need to add this check to not get an attribute error
+ activation_scale_ub = None
+ weight_block_size = quantizer_config.weight_block_size
+ if isinstance(quantizer_config, _FP8QuantizerConfig):
+ activation_scale_ub = quantizer_config.activation_scale_ub
+
+ return HybridFP8UnquantLoader(
+ activation_scale_ub,
+ to_fp8=quantize == "fp8",
+ weight_block_size=weight_block_size,
+ )
+ else:
+ raise ValueError(f"Unknown quantization method: {quantize}")
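
For reference, here is a hedged sketch of the kind of GPTQ-style `quantization_config` block that `_get_quantizer_config` reads from `config.json`, together with the config it would roughly produce. The field values are illustrative and not taken from any particular model.

```python
from text_generation_server.utils.quantization import _QuantizerConfig

# Illustrative `quantization_config` section of a config.json (GPTQ-style).
example_config = {
    "quantization_config": {
        "quant_method": "gptq",  # -> _QuantizerConfig.quant_method
        "bits": 4,               # -> _QuantizerConfig.bits
        "group_size": 128,       # -> _QuantizerConfig.groupsize
        "sym": True,             # -> _QuantizerConfig.sym
        "desc_act": False,       # optional, defaults to False
    }
}

# Roughly what the parser above would return for such a block.
parsed = _QuantizerConfig(
    bits=4,
    groupsize=128,
    quant_method="gptq",
    checkpoint_format=None,
    sym=True,
    desc_act=False,
    weight_block_size=None,
    modules_to_not_convert=[],
)
```

AWQ-style checkpoints expose `zero_point` instead of `sym`; the parser then sets `sym = not zero_point` and switches `quant_method` to `awq`. When no quantization section can be read at all, the GPTQ-flavoured defaults are returned and `get_loader` falls back to the FP8/unquantized loader for `quantize=None`.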
diff --git a/backends/gaudi/server/text_generation_server/utils/segments.py b/backends/gaudi/server/text_generation_server/utils/segments.py
new file mode 100644
index 00000000000..133049be77a
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/segments.py
@@ -0,0 +1,66 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/utils/segments.py
+# License: Apache License Version 2.0, January 2004
+
+from typing import List, Tuple, Union
+
+import torch
+
+
+def find_segments(
+ adapter_indices: Union[torch.Tensor, List[int]],
+) -> Tuple[List[int], List[int]]:
+ segments = [0]
+ segment_indices = []
+
+ if isinstance(adapter_indices, torch.Tensor):
+ # Calling .item() repeatedly on CUDA tensor is very slow, so we move it to CPU first
+ adapter_indices = adapter_indices.cpu().tolist()
+
+ start_index = 0
+ for i in range(1, len(adapter_indices)):
+ if adapter_indices[i] != adapter_indices[i - 1]:
+ segments.append(i)
+ segment_indices.append(adapter_indices[i - 1])
+ start_index = i
+
+ # Handle the last segment
+ if start_index < len(adapter_indices):
+ segments.append(len(adapter_indices))
+ segment_indices.append(adapter_indices[-1])
+
+ return segments, segment_indices
+
+
+class SegmentConcatBuilder:
+ def __init__(self):
+ self.adapter_segment_indices = []
+ self.adapter_segment_tensors = []
+
+ def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):
+ # Update adapter segments
+ if self.adapter_segment_tensors:
+ # Because we have already processed at least one batch, remove the 0 start index
+ # from this batch denoting the beginning of the segment, then offset all segment
+ # positions by the value of the last segment in the previous batch to account for
+ # the concatenation.
+ adapter_segments = (
+ adapter_segments[1:] + self.adapter_segment_tensors[-1][-1]
+ )
+
+ if (
+ self.adapter_segment_indices
+ and self.adapter_segment_indices[-1] == segment_indices[0]
+ ):
+ # If the last segment in the previous batch is the same as the first segment in this batch,
+ # then we merge them together into a single segment. In effect, this means removing it from
+ # the segment indices of this batch, and extending the segment span by removing the segment
+ # end index from the previous batch.
+ segment_indices = segment_indices[1:]
+ self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1]
+
+ self.adapter_segment_indices.extend(segment_indices)
+ self.adapter_segment_tensors.append(adapter_segments)
+
+ def build(self) -> Tuple[torch.Tensor, List[int]]:
+ return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices
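
A small example of what `find_segments` computes for a toy batch of adapter indices (the indices are illustrative):

```python
import torch

from text_generation_server.utils.segments import find_segments

# Three requests on adapter 0, two on adapter 1, one on adapter 2.
adapter_indices = torch.tensor([0, 0, 0, 1, 1, 2])

segments, segment_indices = find_segments(adapter_indices)
print(segments)         # [0, 3, 5, 6] -> boundaries of each contiguous run
print(segment_indices)  # [0, 1, 2]    -> adapter id owning each run
```

`SegmentConcatBuilder.concat` then shifts and merges these boundaries when batches are concatenated, so adjacent runs of the same adapter collapse into a single segment.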
diff --git a/server/text_generation_server/utils/sgmv.py b/backends/gaudi/server/text_generation_server/utils/sgmv.py
similarity index 100%
rename from server/text_generation_server/utils/sgmv.py
rename to backends/gaudi/server/text_generation_server/utils/sgmv.py
diff --git a/backends/gaudi/server/text_generation_server/utils/speculate.py b/backends/gaudi/server/text_generation_server/utils/speculate.py
new file mode 100644
index 00000000000..a1b37a344fd
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/speculate.py
@@ -0,0 +1,11 @@
+SPECULATE = None
+
+
+def get_speculate() -> int:
+ global SPECULATE
+ return SPECULATE
+
+
+def set_speculate(speculate: int):
+ global SPECULATE
+ SPECULATE = speculate
diff --git a/backends/gaudi/server/text_generation_server/utils/tokens.py b/backends/gaudi/server/text_generation_server/utils/tokens.py
new file mode 100644
index 00000000000..9f5ffb87b50
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/tokens.py
@@ -0,0 +1,767 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+import re
+from typing import List, Optional, Tuple, Set, Union
+
+import torch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
+from text_generation_server.utils.logits_process import (
+ FrequencyPenaltyLogitsProcessor,
+ GrammarLogitProcessor,
+ HeterogeneousProcessorWrapper,
+ HeterogeneousRepetitionPenaltyLogitsProcessor,
+ HeterogeneousFrequencyPenaltyLogitsProcessor,
+ HeterogeneousTemperatureLogitsWarper,
+ HeterogeneousTopKLogitsWarper,
+ HeterogeneousTopPLogitsWarper,
+ HeterogeneousTypicalLogitsWarper,
+ HeterogeneousGrammarLogitProcessor,
+ static_warper,
+)
+from text_generation_server.utils.watermark import WatermarkLogitsProcessor
+from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+import os
+
+
+class NextTokenChooser:
+ def __init__(
+ self,
+ watermark: bool = False,
+ temperature: float = 1.0,
+ repetition_penalty: float = 1.0,
+ frequency_penalty: float = 0.0,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ typical_p: Optional[float] = None,
+ do_sample: bool = False,
+ seed: int = 0,
+ device: str = "cpu",
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
+ grammar: str = "",
+ grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
+ fsm_grammar_state: int = 0,
+ ):
+ self.watermark_processor = (
+ WatermarkLogitsProcessor(device=device) if watermark else None
+ )
+ self.repetition_processor = (
+ RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
+ if repetition_penalty and repetition_penalty != 1.0
+ else None
+ )
+ self.frequency_processor = (
+ FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty)
+ if frequency_penalty and frequency_penalty != 0.0
+ else None
+ )
+ self.grammar_processor = (
+ GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
+ if grammar != ""
+ else None
+ )
+ self.tokenizer = tokenizer
+
+ has_warpers = (
+ (temperature is not None and temperature != 1.0)
+ or (top_k is not None and top_k != 0)
+ or (top_p is not None and top_p < 1.0)
+ or (typical_p is not None and typical_p < 1.0)
+ )
+ if has_warpers:
+ self.static_warper = static_warper(
+ temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+ )
+ else:
+ self.static_warper = None
+
+ sampling = do_sample or has_warpers
+
+ self.choice = Sampling(seed, device) if sampling else Greedy()
+ self.fsm_grammar_state = fsm_grammar_state
+ self.grammar = grammar
+
+ def __call__(self, input_ids, scores):
+ if self.watermark_processor is not None:
+ scores = self.watermark_processor(input_ids, scores)
+ if self.repetition_processor is not None:
+ scores = self.repetition_processor(input_ids, scores)
+ if self.frequency_processor is not None:
+ scores = self.frequency_processor(input_ids, scores)
+ if self.grammar_processor is not None:
+ scores = self.grammar_processor(scores, self.fsm_grammar_state)
+
+ if self.static_warper is None:
+ next_logprob = torch.log_softmax(scores, -1)
+ else:
+ scores, next_logprob = self.static_warper(scores)
+
+ next_id = self.choice(scores[-1]).view(1, 1)
+
+ return next_id, next_logprob
+
+ def advance_grammar(self, next_id: int):
+ if self.grammar_processor is not None:
+ self.fsm_grammar_state = self.grammar_processor.advance(
+ next_id, self.fsm_grammar_state
+ )
+ return self
+
+ @classmethod
+ def from_pb(
+ cls,
+ pb: generate_pb2.NextTokenChooserParameters,
+ device: torch.device,
+ tokenizer: PreTrainedTokenizerBase,
+ ) -> "NextTokenChooser":
+ return NextTokenChooser(
+ watermark=pb.watermark,
+ temperature=pb.temperature,
+ repetition_penalty=pb.repetition_penalty,
+ frequency_penalty=pb.frequency_penalty,
+ top_k=pb.top_k,
+ top_p=pb.top_p,
+ typical_p=pb.typical_p,
+ do_sample=pb.do_sample,
+ seed=pb.seed,
+ device=device,
+ tokenizer=tokenizer,
+ grammar=pb.grammar,
+ grammar_type=pb.grammar_type,
+ )
+
+
+class StopSequenceCriteria:
+ def __init__(self, stop_sequence: str):
+ stop_sequence = re.escape(stop_sequence)
+ self.regex = re.compile(f"{stop_sequence}$")
+
+ def __call__(self, output: str) -> bool:
+ if self.regex.findall(output):
+ return True
+ return False
+
+
+class StoppingCriteria:
+ def __init__(
+ self,
+ eos_token_ids: Optional[Union[Set[int], int]],
+ stop_sequence_criterias: List[StopSequenceCriteria],
+ max_new_tokens: int = 20,
+ ignore_eos_token: bool = False,
+ ):
+ if eos_token_ids is None:
+ eos_token_ids = set()
+ elif isinstance(eos_token_ids, int):
+ eos_token_ids = set([eos_token_ids])
+ elif isinstance(eos_token_ids, set):
+ eos_token_ids = eos_token_ids
+ else:
+ raise RuntimeError(
+ f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
+ )
+ self.eos_token_ids = eos_token_ids
+ self.stop_sequence_criterias = stop_sequence_criterias
+ self.max_new_tokens = max_new_tokens
+ self.current_tokens = 0
+ self.current_output = ""
+
+ if os.getenv("TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN", "false") == "true":
+ self.ignore_eos_token = True
+ else:
+ self.ignore_eos_token = ignore_eos_token
+
+ def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
+ self.current_tokens += 1
+ if self.current_tokens >= self.max_new_tokens:
+ return True, FinishReason.FINISH_REASON_LENGTH
+
+ if isinstance(last_token, torch.Tensor):
+ last_token = last_token.item()
+
+ if not self.ignore_eos_token and last_token in self.eos_token_ids:
+ return True, FinishReason.FINISH_REASON_EOS_TOKEN
+
+ if self.stop_sequence_criterias:
+ self.current_output += last_output
+ # There is no need to keep an output that is too long
+ if len(self.current_output) > 300:
+ # Slice to -200 to avoid doing it all the time
+ self.current_output = self.current_output[-200:]
+ for stop_sequence_criteria in self.stop_sequence_criterias:
+ if stop_sequence_criteria(self.current_output):
+ return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
+
+ return False, None
+
+ @classmethod
+ def from_pb(
+ cls,
+ pb: generate_pb2.StoppingCriteriaParameters,
+ tokenizer: PreTrainedTokenizerBase,
+ ) -> "StoppingCriteria":
+ stop_sequence_criterias = [
+ StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
+ ]
+ # TODO Hack because eos_token_id cannot be what we want.
+ eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
+ return StoppingCriteria(
+ eos_token_id,
+ stop_sequence_criterias,
+ pb.max_new_tokens,
+ pb.ignore_eos_token,
+ )
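
As a quick illustration of the stopping logic above, here is a minimal sketch with a toy EOS id and a toy stop sequence; it assumes the Gaudi server package added by this diff is importable and that `TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN` is unset.

```python
from text_generation_server.utils.tokens import StopSequenceCriteria, StoppingCriteria

criteria = StoppingCriteria(
    eos_token_ids={2},                                    # toy EOS id
    stop_sequence_criterias=[StopSequenceCriteria("\n\n")],
    max_new_tokens=16,
)

# Neither EOS, stop sequence, nor the length limit has been reached yet.
print(criteria(last_token=42, last_output="Hello"))  # (False, None)

# EOS id seen before the length limit.
print(criteria(last_token=2, last_output=" world"))  # (True, FINISH_REASON_EOS_TOKEN)
```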
+
+
+def create_n_gram_speculation(
+ input_ids: torch.Tensor,
+ next_ids: torch.Tensor,
+ accepted_ids: torch.Tensor,
+ speculate: int,
+ verbose: bool,
+):
+ # Very trivial approach, find first match in the string.
+ # This is much less refined than actual n-gram but seems to work
+ # relatively OK in grounded mode and is by far much faster with
+ # much less worst case complexity as everything happens on device.
+ B = accepted_ids.shape[0]
+ device = input_ids.device
+ seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
+ indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
+ all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(
+ speculate, device=device
+ )
+ all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
+
+ speculative_ids = input_ids.gather(dim=-1, index=all_indices)
+ return speculative_ids
+
+
+class HeterogeneousNextTokenChooser:
+ def __init__(
+ self,
+ dtype: torch.dtype,
+ device: torch.device,
+ watermark: List[bool],
+ temperature: List[float],
+ repetition_penalty: List[float],
+ frequency_penalty: List[float],
+ top_k: List[int],
+ top_p: List[float],
+ typical_p: List[float],
+ do_sample: List[bool],
+ seeds: List[int],
+ tokenizer: PreTrainedTokenizerBase,
+ grammars: List[str],
+ grammar_types: List[int],
+ fsm_grammar_states: List[int],
+ quantization_enabled: bool,
+ ):
+ warpers = []
+
+ # TODO: enable watermark with FP8 quantization
+ self.watermark_processor = (
+ HeterogeneousProcessorWrapper(
+ {
+ i: WatermarkLogitsProcessor(device=device)
+ for i, do_watermark in enumerate(watermark)
+ if do_watermark
+ }
+ )
+ if any(watermark) and not quantization_enabled
+ else None
+ )
+
+ self.repetition_processor = (
+ HeterogeneousRepetitionPenaltyLogitsProcessor(
+ repetition_penalty, dtype, device
+ )
+ if any([x != 1.0 for x in repetition_penalty])
+ else None
+ )
+
+ self.frequency_processor = (
+ HeterogeneousFrequencyPenaltyLogitsProcessor(
+ frequency_penalty, dtype, device
+ )
+ if any([x != 0.0 for x in frequency_penalty])
+ else None
+ )
+
+ self.grammar_processor = (
+ HeterogeneousGrammarLogitProcessor(
+ tokenizer, device, grammars, grammar_types
+ )
+ if any([grammar != "" for grammar in grammars])
+ else None
+ )
+
+ if any(x != 1.0 for x in temperature):
+ do_sample = [
+ sample or x != 1.0 for x, sample in zip(temperature, do_sample)
+ ]
+ warpers.append(
+ HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
+ )
+
+ if any(x != 0 for x in top_k):
+ do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
+ warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
+
+ if any(x < 1.0 for x in top_p):
+ do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
+ warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
+
+ if any(x < 1.0 for x in typical_p):
+ do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
+ warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
+
+ self.warpers = warpers
+
+ if any(do_sample):
+ self.choice = HeterogeneousSampling(do_sample, seeds, device)
+ else:
+ self.choice = Greedy()
+
+ self.seeds = seeds
+ self.do_sample = do_sample
+ self.dtype = dtype
+ self.device = device
+ self.tokenizer = tokenizer
+ self.fsm_grammar_states = fsm_grammar_states
+ self.grammars = grammars
+ self.grammar_types = grammar_types
+
+ def __call__(
+ self,
+ input_ids: torch.Tensor,
+ scores: torch.Tensor,
+ speculate: int,
+ speculated_ids: Optional[torch.Tensor] = None,
+ speculative_scores: Optional[torch.Tensor] = None,
+ verbose=False,
+ ):
+ if speculated_ids is not None:
+ B = scores.shape[0] // (speculated_ids.shape[1] + 1)
+ S = speculated_ids.shape[1] + 1
+ scores = scores.view(B, S, -1)
+ else:
+ B = scores.shape[0]
+ S = 1
+ scores = scores.view(B, S, -1)
+
+ next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
+
+ for j in range(S):
+ _scores = scores[:, j]
+ if self.watermark_processor is not None:
+ _scores = self.watermark_processor(input_ids, _scores)
+ if self.repetition_processor is not None:
+ _scores = self.repetition_processor(input_ids, _scores)
+ if self.frequency_processor is not None:
+ _scores = self.frequency_processor(input_ids, _scores)
+ if self.grammar_processor is not None:
+ _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
+ for warper in self.warpers:
+ _scores = warper(input_ids, _scores)
+ _next_ids = self.choice(_scores)
+ scores[:, j] = _scores
+ next_ids[:, j] = _next_ids
+ next_ids = next_ids.view(B * S)
+ allscores = scores.view(B * S, -1)
+ alllogprobs = torch.log_softmax(allscores, -1)
+
+ if speculated_ids is not None:
+ accepted_ids = []
+ B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
+ S = speculated_ids.shape[1] + 1
+ indices = []
+ for i in range(B):
+ _next_ids = next_ids[i * S : (i + 1) * S]
+ _speculated_ids = speculated_ids[i]
+ validate_speculative = _next_ids[:-1] == _speculated_ids
+ index = i * S
+ accepted = 1
+ # First is always valid
+ indices.append(index)
+ for valid in validate_speculative.tolist():
+ if valid:
+ index += 1
+ accepted += 1
+ indices.append(index)
+ else:
+ break
+ accepted_ids.append(accepted)
+
+ accepted_ids = torch.tensor(
+ accepted_ids, device=input_ids.device, dtype=input_ids.dtype
+ )
+ next_ids = next_ids[indices]
+ logprobs = alllogprobs[indices]
+ indices = torch.arange(B, device=input_ids.device) * S
+ if speculative_scores is not None:
+ speculative_scores = speculative_scores[indices + accepted_ids - 1]
+ else:
+ accepted_ids = torch.ones_like(next_ids)
+ logprobs = alllogprobs
+
+ next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
+
+ if speculate > 0:
+ if speculative_scores is not None:
+ # Medusa provided some scores
+ speculative_ids = Greedy()(speculative_scores)
+ else:
+ # n-gram
+ speculative_ids = create_n_gram_speculation(
+ input_ids, next_ids, accepted_ids, speculate, verbose
+ )
+ else:
+ speculative_ids = None
+
+ return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
+
+ def advance_grammar(self, next_ids: List[int]):
+ if self.grammar_processor is not None:
+ other_new_states = self.grammar_processor.advance_batch(
+ next_ids, self.fsm_grammar_states
+ )
+ self.fsm_grammar_states = other_new_states
+ return self
+
+ def advance_grammar_single(self, grammar_state_index: int, next_id: int):
+ if self.grammar_processor is not None:
+ self.fsm_grammar_states[grammar_state_index] = (
+ self.grammar_processor.advance_at_index(
+ next_id,
+ self.fsm_grammar_states[grammar_state_index],
+ grammar_state_index,
+ )
+ )
+ return self
+
+ def advance_grammar_single_with_past_state(
+ self, grammar_state_index: int, next_id: torch.Tensor, past_state: int
+ ):
+ if self.grammar_processor is not None:
+ next_id = next_id.item()
+ self.fsm_grammar_states[grammar_state_index] = (
+ self.grammar_processor.advance_at_index(
+ next_id,
+ past_state,
+ grammar_state_index,
+ )
+ )
+ return self
+
+ def filter(self, indices):
+ if self.watermark_processor is not None:
+ self.watermark_processor = self.watermark_processor.filter(indices)
+
+ if self.repetition_processor is not None:
+ self.repetition_processor = self.repetition_processor.filter(indices)
+
+ if self.frequency_processor is not None:
+ self.frequency_processor = self.frequency_processor.filter(indices)
+
+ if self.grammar_processor is not None:
+ self.grammar_processor = self.grammar_processor.filter(indices)
+
+ filtered_warpers = []
+ for warper in self.warpers:
+ filtered_warper = warper.filter(indices)
+ if filtered_warper is not None:
+ filtered_warpers.append(filtered_warper)
+ self.warpers = filtered_warpers
+
+ self.seeds = [self.seeds[i] for i in indices]
+ self.do_sample = [self.do_sample[i] for i in indices]
+
+ new_grammars = []
+ new_fsm_grammar_states = []
+ new_grammar_types = []
+ for i in indices:
+ new_grammars.append(self.grammars[i])
+ new_fsm_grammar_states.append(self.fsm_grammar_states[i])
+ new_grammar_types.append(self.grammar_types[i])
+
+ self.grammars = new_grammars
+ self.fsm_grammar_states = new_fsm_grammar_states
+ self.grammar_types = new_grammar_types
+
+ if any(self.do_sample):
+ self.choice.filter(indices)
+ else:
+ self.choice = Greedy()
+
+ return self
+
+ @classmethod
+ def from_pb(
+ cls,
+ pb: List[generate_pb2.NextTokenChooserParameters],
+ dtype: torch.dtype,
+ device: torch.device,
+ tokenizer: PreTrainedTokenizerBase,
+ fsm_grammar_states: Optional[List[int]] = None,
+ quantization_enabled: bool = False,
+ ) -> "HeterogeneousNextTokenChooser":
+ return HeterogeneousNextTokenChooser(
+ watermark=[pb_.watermark for pb_ in pb],
+ temperature=[pb_.temperature for pb_ in pb],
+ repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
+ frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
+ top_k=[pb_.top_k for pb_ in pb],
+ top_p=[pb_.top_p for pb_ in pb],
+ typical_p=[pb_.typical_p for pb_ in pb],
+ do_sample=[pb_.do_sample for pb_ in pb],
+ seeds=[pb_.seed for pb_ in pb],
+ device=device,
+ dtype=dtype,
+ tokenizer=tokenizer,
+ grammars=[pb_.grammar for pb_ in pb],
+ grammar_types=[pb_.grammar_type for pb_ in pb],
+ fsm_grammar_states=(
+ fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
+ ),
+ quantization_enabled=quantization_enabled,
+ )
+
+
+def pad_next_token_chooser_parameters(
+ parameters: List[generate_pb2.NextTokenChooserParameters],
+ expected_size: int,
+) -> List[generate_pb2.NextTokenChooserParameters]:
+ # disable all logits processors to minimize padding overhead
+ empty_parameters = generate_pb2.NextTokenChooserParameters(
+ temperature=1.0,
+ top_k=0,
+ top_p=1.0,
+ typical_p=1.0,
+ do_sample=False,
+ seed=0,
+ repetition_penalty=1.0,
+ frequency_penalty=0.0,
+ watermark=False,
+ grammar="",
+ grammar_type=0,
+ )
+ parameters.extend([empty_parameters] * (expected_size - len(parameters)))
+ return parameters
+
+
+class Sampling:
+ def __init__(self, seed: int, device: str = "cpu"):
+ if device in ["hpu", torch.device("hpu")]:
+ import habana_frameworks.torch.hpu.random as htrandom
+
+ self.generator = htrandom.default_generators[0].manual_seed(seed)
+ else:
+ self.generator = torch.Generator("cpu")
+ self.generator.manual_seed(seed)
+ self.seed = seed
+
+ def __call__(self, logits):
+ probs = torch.nn.functional.softmax(logits, -1)
+ # Avoid GPU<->CPU sync done by torch multinomial
+ # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
+ q = torch.empty_like(probs).exponential_(1, generator=self.generator)
+ return probs.div_(q).argmax()
+
+
+class Greedy:
+ def __call__(self, logits):
+ return logits.argmax(dim=-1)
+
+
+class HeterogeneousSampling:
+ r"""
+ Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
+ """
+
+ def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
+ self.seeds = seeds
+
+ self.greedy_indices = []
+ self.sampling_mapping = {}
+ for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
+ if sample:
+ self.sampling_mapping[i] = Sampling(seed, device)
+ else:
+ self.greedy_indices.append(i)
+
+ self.greedy = Greedy()
+
+ def __call__(self, logits):
+ out = torch.zeros(logits.shape[0], dtype=torch.int64, device=logits.device)
+ if self.greedy_indices:
+ # Computing for all indices is faster than slicing
+ torch.argmax(logits, -1, out=out)
+
+ for i, sampling in self.sampling_mapping.items():
+ out[i] = sampling(logits[i])
+ return out
+
+ def filter(self, indices):
+ new_greedy_indices = []
+ new_sampling_mapping = {}
+ for i, idx in enumerate(indices):
+ if idx in self.sampling_mapping:
+ new_sampling_mapping[i] = self.sampling_mapping[idx]
+ else:
+ new_greedy_indices.append(i)
+
+ self.greedy_indices = new_greedy_indices
+ self.sampling_mapping = new_sampling_mapping
+ return self
+
+
+def batch_top_tokens(
+ top_n_tokens: List[int],
+ top_n_tokens_tensor: torch.Tensor,
+ logprobs: torch.Tensor,
+ accepted_ids: torch.Tensor,
+) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
+ """Find the top n most likely tokens for a batch of generations.
+
+ When multiple tokens have equal probabilities and they don't all fit, the
+ remaining tokens are also returned.
+ """
+ max_top_n = max(top_n_tokens)
+ # Early exit when top_n_tokens is not used
+ if max_top_n == 0:
+ return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)
+
+ batch_size = accepted_ids.shape[0]
+ speculate_size = logprobs.shape[0] // batch_size
+ top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
+ # Ensure top_n doesn't exceed vocab size
+ top_n_tokens = [
+ min(tok, logprobs.size(-1))
+ for tok in top_n_tokens
+ for _ in range(speculate_size)
+ ]
+
+ # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
+ # Sorted topk is faster than torch.sort() since we only need a small subset
+ sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values
+
+ nth_highest = torch.gather(
+ sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
+ )
+ nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
+
+ # Find the new "fuzzy" top n values
+ top_n_indices = (logprobs >= nth_highest).nonzero()
+ _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
+
+ k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
+ # Take a new topk for these new max n values
+ top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)
+
+ top_n_ishes = top_n_ishes.tolist()
+ top_indices = top_k.indices.tolist()
+ top_values = top_k.values.tolist()
+
+ batch_top_token_ids = []
+ batch_top_token_logprobs = []
+ accepted_ids_list = accepted_ids.tolist()
+ for i, n_accepted_ids in enumerate(accepted_ids_list):
+ start = speculate_size * i
+ stop = speculate_size * (i + 1)
+ _top_indices = top_indices[start:stop]
+ _top_values = top_values[start:stop]
+ _top_n_ishes = top_n_ishes[start:stop]
+ _top_n_tokens = top_n_tokens[start:stop]
+
+ _top_indices = _top_indices[:n_accepted_ids]
+ _top_values = _top_values[:n_accepted_ids]
+ _top_n_ishes = _top_n_ishes[:n_accepted_ids]
+ _top_n_tokens = _top_n_tokens[:n_accepted_ids]
+
+ row_top_token_ids = []
+ row_top_token_logprobs = []
+
+ for idxs, vals, n, req_n in zip(
+ _top_indices, _top_values, _top_n_ishes, _top_n_tokens
+ ):
+ indices = idxs[:n] if req_n > 0 else []
+ values = vals[:n] if req_n > 0 else []
+
+ row_top_token_ids.append(indices)
+ row_top_token_logprobs.append(values)
+
+ batch_top_token_ids.append(row_top_token_ids)
+ batch_top_token_logprobs.append(row_top_token_logprobs)
+
+ return batch_top_token_ids, batch_top_token_logprobs
+
+
+def make_tokenizer_optional(tokenizer):
+ class _(type(tokenizer)):
+ def __call__(
+ self,
+ text,
+ return_tensors,
+ padding,
+ return_token_type_ids,
+ truncation,
+ max_length,
+ ):
+            assert (
+                return_tensors == "pt"
+            ), "incorrect input arguments when calling TransparentTokenizer"
+            assert (
+                padding == "max_length" or padding == "longest"
+            ), "incorrect input arguments when calling TransparentTokenizer"
+            assert (
+                not return_token_type_ids
+            ), "incorrect input arguments when calling TransparentTokenizer"
+            assert (
+                truncation
+            ), "incorrect input arguments when calling TransparentTokenizer"
+
+ def str_token_to_int(i):
+ if i == "?":
+ return tokenizer.pad_token_id
+ else:
+ return int(i)
+
+ all_tokens = [
+ [str_token_to_int(i.strip()) for i in inner_text.split(",")]
+ for inner_text in text
+ ]
+ if padding == "longest":
+ max_length = max(len(tokens) for tokens in all_tokens)
+ return {
+ "input_ids": torch.tensor(
+ [
+ [tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens
+ for tokens in all_tokens
+ ]
+ ),
+ "attention_mask": torch.tensor(
+ [
+ [0] * (max_length - len(tokens)) + [1] * len(tokens)
+ for tokens in all_tokens
+ ]
+ ),
+ }
+
+ def decode(
+ self,
+ token_ids,
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: bool = None,
+ **kwargs,
+ ) -> str:
+            # I don't think this method is used anywhere and should be removed when doing refactoring
+            from transformers.utils import to_py_obj
+
+            return ",".join(str(i) for i in to_py_obj(token_ids))
+
+ if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":
+ tokenizer.__class__ = _
+ tokenizer.is_transparent = True
+
+
+def is_tokenizer_transparent(tokenizer):
+ return hasattr(tokenizer, "is_transparent") and tokenizer.is_transparent is True
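
A hedged sketch of the pre-tokenized input format accepted once `make_tokenizer_optional` has patched a tokenizer and `SKIP_TOKENIZER_IN_TGI=true`. The `gpt2` tokenizer is only a convenient stand-in; any tokenizer with a `pad_token_id` behaves the same here, since real tokenization is bypassed.

```python
import os

from transformers import AutoTokenizer

from text_generation_server.utils.tokens import make_tokenizer_optional

os.environ["SKIP_TOKENIZER_IN_TGI"] = "true"

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
make_tokenizer_optional(tokenizer)

# Inputs are comma-separated token ids; "?" marks a padding slot.
batch = tokenizer(
    ["15, 27, 4", "?, 101, 7, 33"],
    return_tensors="pt",
    padding="longest",
    return_token_type_ids=False,
    truncation=True,
    max_length=8,
)
print(batch["input_ids"].shape)    # torch.Size([2, 4]) after left-padding
print(batch["attention_mask"][0])  # tensor([0, 1, 1, 1])
```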
diff --git a/backends/gaudi/server/text_generation_server/utils/version.py b/backends/gaudi/server/text_generation_server/utils/version.py
new file mode 100644
index 00000000000..74b53dfa5ee
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/version.py
@@ -0,0 +1,37 @@
+from packaging.version import Version
+from packaging import version
+import subprocess
+
+
+def get_driver_version():
+ """
+ Returns the driver version.
+ """
+ # Enable console printing for `hl-smi` check
+ output = subprocess.run(
+ "hl-smi",
+ shell=True,
+ text=True,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ env={"ENABLE_CONSOLE": "true"},
+ )
+ if output.returncode == 0 and output.stdout:
+ return version.parse(
+ output.stdout.split("\n")[2]
+ .replace(" ", "")
+ .split(":")[1][:-1]
+ .split("-")[0]
+ )
+ return None
+
+
+MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.19.0")
+
+
+def is_driver_compatible():
+ driver_version = get_driver_version()
+ if driver_version is not None:
+ if driver_version < MIN_TGI_GAUDI_SYNAPSE_VERSION:
+ return False
+ return True
diff --git a/backends/gaudi/server/text_generation_server/utils/watermark.py b/backends/gaudi/server/text_generation_server/utils/watermark.py
new file mode 100644
index 00000000000..5092b076c33
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/watermark.py
@@ -0,0 +1,98 @@
+# coding=utf-8
+# Copyright 2023 Authors of "A Watermark for Large Language Models"
+# available at https://arxiv.org/abs/2301.10226
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import torch
+from transformers import LogitsProcessor
+from typing import List, Union
+
+GAMMA = float(os.getenv("WATERMARK_GAMMA", 0.5))
+DELTA = float(os.getenv("WATERMARK_DELTA", 2.0))
+
+
+class WatermarkLogitsProcessor(LogitsProcessor):
+ def __init__(
+ self,
+ gamma: float = GAMMA,
+ delta: float = DELTA,
+ hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width
+ device: str = "cpu",
+ ):
+ # watermarking parameters
+ self.gamma = gamma
+ self.delta = delta
+ self.rng = torch.Generator(device="cpu")
+ self.hash_key = hash_key
+
+ def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):
+ if isinstance(input_ids, list):
+ assert (
+ len(input_ids) >= 1
+ ), "requires at least a 1 token prefix sequence to seed rng"
+ prev_token = input_ids[-1]
+ else:
+ assert len(input_ids) == 1
+ input_ids = input_ids[0]
+ assert (
+ input_ids.shape[-1] >= 1
+ ), "requires at least a 1 token prefix sequence to seed rng"
+ prev_token = input_ids[-1].item()
+ self.rng.manual_seed(self.hash_key * prev_token)
+
+ def _get_greenlist_ids(
+ self,
+ input_ids: Union[List[int], torch.LongTensor],
+ max_value: int,
+ device: torch.device,
+ ) -> List[int]:
+ # seed the rng using the previous tokens/prefix
+ self._seed_rng(input_ids)
+
+ greenlist_size = int(max_value * self.gamma)
+ vocab_permutation = torch.randperm(max_value, device=device, generator=self.rng)
+ greenlist_ids = vocab_permutation[:greenlist_size]
+ return greenlist_ids
+
+ @staticmethod
+ def _calc_greenlist_mask(
+ scores: torch.FloatTensor, greenlist_token_ids
+ ) -> torch.BoolTensor:
+ green_tokens_mask = torch.zeros_like(scores)
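+ # Only the last row of `scores` (the next-token logits) is biased.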
+ green_tokens_mask[-1, greenlist_token_ids] = 1
+ final_mask = green_tokens_mask.bool()
+ return final_mask
+
+ @staticmethod
+ def _bias_greenlist_logits(
+ scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
+ ) -> torch.Tensor:
+ scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
+ return scores
+
+ def __call__(
+ self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor
+ ) -> torch.FloatTensor:
+ greenlist_ids = self._get_greenlist_ids(
+ input_ids, scores.shape[-1], scores.device
+ )
+ green_tokens_mask = self._calc_greenlist_mask(
+ scores=scores, greenlist_token_ids=greenlist_ids
+ )
+
+ scores = self._bias_greenlist_logits(
+ scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta
+ )
+ return scores
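+
+
+# Usage sketch (hypothetical values): bias the next-token logits towards the greenlist.
+#
+#   processor = WatermarkLogitsProcessor(gamma=0.5, delta=2.0)
+#   scores = processor(input_ids, scores)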
diff --git a/backends/gaudi/server/text_generation_server/utils/weights.py b/backends/gaudi/server/text_generation_server/utils/weights.py
new file mode 100644
index 00000000000..4edae0d4a48
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/weights.py
@@ -0,0 +1,455 @@
+import torch
+
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union, Type
+from safetensors import safe_open
+from dataclasses import dataclass
+
+
+class WeightsLoader(ABC):
+ """
+ Instances of this type implement higher-level weight loading.
+
+ At a low-level, every weight is stored in the Safetensors format.
+ The interpretation of weights may be different however, for instance
+ could be packed, quantized weights. Loaders are responsible for
+ interpreting the raw tensors, sharding tensors in a manner compatible
+ with the format, etc.
+ """
+
+ @abstractmethod
+ def get_weights(self, weights: "Weights", prefix: str):
+ """
+ Get weights at the given prefix and apply them without tensor parallelism.
+ """
+ ...
+
+ @abstractmethod
+ def get_weights_col_packed(
+ self,
+ weights: "Weights",
+ prefix: str,
+ block_sizes: Union[int, List[int]],
+ ):
+ """
+ Get the packed weights at the given prefix with column-splitting for
+ tensor parallelism. This method should be used when multiple different
+ weights are packed into a tensor, for instance, query/key/value
+ weights or a gate/up projection.
+
+ The `block_sizes` determines the proportions of the packed tensors.
+ The columns are split into equally sized blocks when `block_sizes` is an
+ `int`, or into blocks proportional to the given sizes. For instance,
+ `[2, 1, 1]` will divide an input with dimensionality `1024` into
+ `[512, 256, 256]`.
+ """
+ ...
+
+ def get_weights_col(self, weights: "Weights", prefix: str):
+ """
+ Get weights at the given prefix and apply column-splitting for tensor
+ parallelism.
+ """
+ return weights.get_multi_weights_col([prefix], 0)
+
+ @abstractmethod
+ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
+ """
+ Get the weights at the given prefixes, column-split them for tensor
+ parallelism, and then concatenate the weights along the given dimension.
+ """
+ ...
+
+ @abstractmethod
+ def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
+ """
+ Get the weights at the given prefixes, column-split them for tensor
+ parallelism, and then concatenate the weights along the given dimension.
+ """
+ ...
+
+ @abstractmethod
+ def get_weights_row(self, weights: "Weights", prefix: str):
+ """
+ Get the weights at the given prefix and apply row-splitting for tensor
+ parallelism.
+ """
+ ...
+
+
+class Weight(ABC):
+ """Instances of this type implement unquantized/quantized/to-be
+ quantized weights."""
+
+ @abstractmethod
+ def get_linear(self, bias: torch.Tensor):
+ """Create a linear layer from this weight."""
+ ...
+
+
+@dataclass
+class UnquantizedWeight(Weight):
+ weight: torch.Tensor
+
+ def get_linear(self, bias: torch.Tensor):
+ from text_generation_server.layers.linear import FastLinear
+
+ return FastLinear(self.weight, bias)
+
+
+class DefaultWeightsLoader(WeightsLoader):
+ """Weight loader that loads (unquantized) Torch tensors."""
+
+ def __init__(self, weight_class: Type[UnquantizedWeight]):
+ """Create a loader. Weights will be wrapped using the given `weights_class`,
+ normally this will be `UnquantizedWeight`, but a quantizer-specific class
+ such as `Fp8Weight` can be used to quantize the weights during loading.
+ """
+ self.weight_class = weight_class
+
+ def get_weights(self, weights: "Weights", prefix: str):
+ return weights.get_tensor(f"{prefix}.weight")
+
+ def get_weights_col_packed(
+ self,
+ weights: "Weights",
+ prefix: str,
+ block_sizes: Union[int, List[int]],
+ ):
+ return self.weight_class(
+ weights.get_packed_sharded(
+ f"{prefix}.weight", dim=0, block_sizes=block_sizes
+ ),
+ )
+
+ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
+ w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
+ return self.weight_class(torch.cat(w, dim=dim))
+
+ def get_weights_row(self, weights: "Weights", prefix: str):
+ return self.weight_class(
+ weights.get_sharded(f"{prefix}.weight", dim=1),
+ )
+
+ def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
+ w = [weights.get_tensor(f"{p}.weight") for p in prefixes]
+ return self.weight_class(torch.cat(w, dim=dim))
+
+
+class Weights:
+ def __init__(
+ self,
+ filenames: List[Path],
+ device,
+ dtype,
+ process_group,
+ weights_loader: WeightsLoader,
+ aliases: Optional[Dict[str, List[str]]] = None,
+ prefix: Optional[str] = None,
+ ):
+ routing = {}
+ for filename in filenames:
+ with safe_open(filename, framework="pytorch") as f:
+ for k in f.keys():
+ if k in routing:
+ raise RuntimeError(
+ f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+ )
+ routing[k] = filename
+ if aliases is None:
+ aliases = {}
+ self.aliases = aliases
+ self.routing = routing
+ self.device = device
+ self.dtype = dtype
+ self.process_group = process_group
+ self.prefix = prefix
+ self.weights_loader = weights_loader
+ self._handles = {}
+
+ def _get_handle(self, filename):
+ if filename not in self._handles:
+ f = safe_open(filename, framework="pytorch")
+ self._handles[filename] = f
+
+ return self._handles[filename]
+
+ def get_filename(self, tensor_name: str) -> Tuple[str, str]:
+ names = [tensor_name]
+ if self.prefix is not None:
+ prefixed = f"{self.prefix}.{tensor_name}"
+ names.append(prefixed)
+ for name in names:
+ filename = self.routing.get(name, None)
+ if filename is not None:
+ return str(filename), name
+
+ aliases = self.aliases.get(name, [])
+ for alias in aliases:
+ filename = self.routing.get(alias, None)
+ if filename is not None:
+ return str(filename), alias
+ raise RuntimeError(f"weight {tensor_name} does not exist")
+
+ def _get_slice(self, tensor_name: str):
+ filename, tensor_name = self.get_filename(tensor_name)
+ f = self._get_handle(filename)
+ slice_ = f.get_slice(tensor_name)
+ return slice_
+
+ def has_tensor(self, tensor_name: str):
+ try:
+ self.get_filename(tensor_name)
+ except Exception:
+ return False
+ return True
+
+ def get_shape(self, tensor_name: str):
+ return self._get_slice(tensor_name).get_shape()
+
+ def get_tensor(
+ self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
+ ) -> torch.Tensor:
+ filename, tensor_name = self.get_filename(tensor_name)
+ f = self._get_handle(filename)
+ tensor = f.get_tensor(tensor_name)
+ # Special case for gptq which shouldn't convert
+ # u4 which are disguised as int32. Exl2 uses int16
+ # as well. FP8 uses torch.float8_e4m3fn
+ if (
+ tensor.dtype
+ not in [
+ torch.float8_e4m3fn,
+ torch.int8,
+ torch.int16,
+ torch.int32,
+ torch.int64,
+ ]
+ and to_dtype
+ ):
+ tensor = tensor.to(dtype=self.dtype)
+ if to_device:
+ tensor = tensor.to(device=self.device)
+ return tensor
+
+ def get_partial_sharded(
+ self, tensor_name: str, dim: int, to_device=True, to_dtype=True
+ ):
+ filename, tensor_name = self.get_filename(tensor_name)
+ f = self._get_handle(filename)
+ slice_ = f.get_slice(tensor_name)
+ world_size = self.process_group.size()
+ rank = self.process_group.rank()
+
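+ # Each rank takes one contiguous block along `dim`; with ceil division the last rank may get fewer elements.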
+ size = slice_.get_shape()[dim]
+ block_size = (size + world_size - 1) // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+
+ if dim == 0:
+ tensor = slice_[start:stop]
+ elif dim == 1:
+ tensor = slice_[:, start:stop]
+ else:
+ raise NotImplementedError("Let's make that generic when needed")
+ # Special case for gptq which shouldn't convert
+ # u4 which are disguised as int32. exl2 uses int16.
+ # FP8 uses torch.float8_e4m3fn.
+ if (
+ tensor.dtype
+ not in (torch.float8_e4m3fn, torch.int8, torch.int16, torch.int32)
+ and to_dtype
+ ):
+ tensor = tensor.to(dtype=self.dtype)
+ if to_device:
+ tensor = tensor.to(device=self.device)
+ return tensor
+
+ def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True):
+ filename, tensor_name = self.get_filename(tensor_name)
+ f = self._get_handle(filename)
+ slice_ = f.get_slice(tensor_name)
+ world_size = self.process_group.size()
+ size = slice_.get_shape()[dim]
+ assert (
+ size % world_size == 0
+ ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
+ return self.get_partial_sharded(
+ tensor_name, dim, to_device=to_device, to_dtype=to_dtype
+ )
+
+ def get_packed_sharded(
+ self,
+ tensor_name: str,
+ dim: int,
+ block_sizes: Union[int, List[int]],
+ to_dtype=True,
+ ) -> torch.Tensor:
+ """
+ Get a shard from a tensor that packs multiple tensors.
+
+ When a tensor packs multiple tensors (such as QKV or an up
+ projection + gate projection), sharding with `get_sharded` is not
+ safe since it would not split the packed tensors across shards.
+
+ This method shards a tensor, such that the packed tensors are
+ split across shards.
+
+ The columns are split into equally sized blocks when `blocks` is an `int`, or
+ into blocks proportional to the given sizes. For instance, `[2, 1, 1]` will
+ divide an input with dimensionality `1024` into `[512, 256, 256]`. This is
+ convenient for e.g. splitting QKV without knowing the storage details of
+ quantized weights.
+ """
+ slice_ = self._get_slice(tensor_name)
+ total_size = slice_.get_shape()[dim]
+ block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)
+
+ world_size = self.process_group.size()
+ rank = self.process_group.rank()
+
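+ # Collect, for every packed block, the indices owned by this rank.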
+ tensors_slices = []
+ block_offset = 0
+ for block_size in block_sizes:
+ assert (
+ block_size % world_size == 0
+ ), f"Prepacked tensor cannot be sharded across {world_size} shards"
+ shard_block_size = block_size // world_size
+ start = rank * shard_block_size
+ stop = (rank + 1) * shard_block_size
+ tensors_slices += range(block_offset + start, block_offset + stop)
+ block_offset += block_size
+
+ if dim == 0:
+ tensor = slice_[tensors_slices, ...]
+ elif dim == 1 or dim == -2:
+ tensor = slice_[:, tensors_slices, ...]
+ elif dim == 2 or dim == -1:
+ tensor = slice_[..., tensors_slices]
+ else:
+ raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
+
+ tensor = tensor.to(device=self.device)
+
+ # Avoid casting quantizer dtypes.
+ if (
+ tensor.dtype
+ not in [
+ torch.float8_e4m3fn,
+ torch.int8,
+ torch.int16,
+ torch.int32,
+ torch.int64,
+ ]
+ and to_dtype
+ ):
+ tensor = tensor.to(dtype=self.dtype)
+
+ return tensor
+
+ def get_weights(self, prefix: str):
+ return self.weights_loader.get_weights(self, prefix)
+
+ def get_weights_col_packed_qkv(
+ self,
+ prefix: str,
+ num_heads: int,
+ num_key_value_heads: int,
+ ):
+ return self.get_weights_col_packed(
+ prefix, [num_heads, num_key_value_heads, num_key_value_heads]
+ )
+
+ def get_weights_col_packed_gate_up(self, prefix: str):
+ return self.get_weights_col_packed(prefix, 2)
+
+ def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):
+ """
+ The columns are split into equally sized blocks when `blocks` is an `int`, or
+ into blocks proportional to the given sizes. For instance, `[2, 1, 1]` will
+ divide an input with dimensionality `1024` into `[512, 256, 256]`. This is
+ convenient for e.g. splitting QKV without knowing the storage details of
+ quantized weights.
+ """
+ return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes)
+
+ def get_weights_col(self, prefix: str):
+ return self.weights_loader.get_weights_col(self, prefix)
+
+ def get_multi_weights_col(self, prefixes: List[str], dim: int):
+ return self.weights_loader.get_multi_weights_col(self, prefixes, dim)
+
+ def get_tensor_shard(self, var, dim):
+ world_size = self.process_group.size()
+ rank = self.process_group.rank()
+ block_size = var.size()[dim] // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ if dim == 0:
+ tensor = var[start:stop]
+ elif dim == 1:
+ tensor = var[:, start:stop]
+ else:
+ raise NotImplementedError("Let's make that generic when needed")
+ tensor = tensor.to(dtype=self.dtype)
+ tensor = tensor.to(device=self.device)
+ return tensor
+
+ def get_weights_row(self, prefix: str):
+ return self.weights_loader.get_weights_row(self, prefix)
+
+ def get_multi_weights(self, prefixes: List[str], dim: int):
+ return self.weights_loader.get_multi_weights(self, prefixes, dim)
+
+ @contextmanager
+ def use_loader(self, weights_loader: WeightsLoader):
+ """
+ This method is a context manager that can be used to use `Weights` with
+ a different loader for the duration of the context.
+ """
+
+ old_loader = self.weights_loader
+ self.weights_loader = weights_loader
+ try:
+ yield
+ finally:
+ self.weights_loader = old_loader
+
+ @property
+ def loader(self):
+ return self.weights_loader
+
+
+def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
+ """
+ Convert block count or proportions to block sizes.
+
+ This function accepts
+
+ - The number of blocks (int), in which case the block size is
+ total_size//blocks; or
+ - A list of block sizes (List[int]).
+
+ In the latter case, if sum(blocks) < total_size, the ratios between
+ the block sizes will be preserved. For instance, if blocks is
+ [2, 1, 1] and total_size is 1024, the returned block sizes are
+ [512, 256, 256].
+ """
+ if isinstance(blocks, list):
+ total_blocks = sum(blocks)
+ assert (
+ total_size % total_blocks == 0
+ ), f"Cannot split {total_size} in proportional blocks: {blocks}"
+ part_size = total_size // total_blocks
+ return [part_size * block for block in blocks]
+ else:
+ assert total_size % blocks == 0, f"Prepacked size {total_size} is not divisible by {blocks}"
+ single_size = total_size // blocks
+ return [single_size] * blocks
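+
+
+# Usage sketch (the prefix and head counts are placeholders):
+#
+#   loader = DefaultWeightsLoader(UnquantizedWeight)
+#   weights = Weights(filenames, device, dtype, process_group, weights_loader=loader)
+#   qkv = weights.get_weights_col_packed_qkv(
+#       "model.layers.0.self_attn.qkv_proj", num_heads=32, num_key_value_heads=8
+#   )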
diff --git a/backends/gaudi/tgi-entrypoint.sh b/backends/gaudi/tgi-entrypoint.sh
new file mode 100644
index 00000000000..a5c3f5e1d4b
--- /dev/null
+++ b/backends/gaudi/tgi-entrypoint.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+ldconfig 2>/dev/null || echo 'unable to refresh ld cache, not a big deal in most cases'
+
+# Check if --sharded argument is present in the command line arguments
+if [[ "$*" == *"--sharded true"* ]]; then
+ echo 'setting PT_HPU_ENABLE_LAZY_COLLECTIVES=1 for sharding'
+ export PT_HPU_ENABLE_LAZY_COLLECTIVES=1
+fi
+
+text-generation-launcher "$@"
diff --git a/backends/grpc-metadata/src/lib.rs b/backends/grpc-metadata/src/lib.rs
index 3068a61c3d2..822b03072da 100644
--- a/backends/grpc-metadata/src/lib.rs
+++ b/backends/grpc-metadata/src/lib.rs
@@ -8,7 +8,7 @@ use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Inject context in the metadata of a gRPC request.
struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap);
-impl<'a> Injector for MetadataInjector<'a> {
+impl Injector for MetadataInjector<'_> {
/// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs
fn set(&mut self, key: &str, value: String) {
if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {
diff --git a/backends/llamacpp/Cargo.toml b/backends/llamacpp/Cargo.toml
new file mode 100644
index 00000000000..685a313f120
--- /dev/null
+++ b/backends/llamacpp/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "text-generation-router-llamacpp"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+homepage.workspace = true
+
+[build-dependencies]
+bindgen = "0.71.1"
+pkg-config = "0.3.31"
+
+[dependencies]
+async-trait = "0.1.85"
+clap = "4.5.27"
+hf-hub.workspace = true
+num_cpus = "1.16.0"
+text-generation-router = { path = "../../router" }
+thiserror = "2.0.11"
+tokenizers.workspace = true
+tokio = { version = "1.43.0", features = ["process"] }
+tokio-stream = "0.1.17"
+tracing = "0.1.41"
diff --git a/backends/llamacpp/README.md b/backends/llamacpp/README.md
new file mode 100644
index 00000000000..0971efc5a39
--- /dev/null
+++ b/backends/llamacpp/README.md
@@ -0,0 +1,24 @@
+# Llamacpp backend
+
+If all your dependencies are installed at the system level, running
+`cargo build` should be sufficient. However, if you want to experiment
+with different versions of llama.cpp, some additional setup is required.
+
+## Install llama.cpp
+
+ LLAMACPP_PREFIX=$(pwd)/llama.cpp.out
+
+ git clone https://github.com/ggerganov/llama.cpp
+ cd llama.cpp
+ cmake -B build \
+ -DCMAKE_INSTALL_PREFIX="$LLAMACPP_PREFIX" \
+ -DLLAMA_BUILD_COMMON=OFF \
+ -DLLAMA_BUILD_TESTS=OFF \
+ -DLLAMA_BUILD_EXAMPLES=OFF \
+ -DLLAMA_BUILD_SERVER=OFF
+ cmake --build build --config Release -j
+ cmake --install build
+
+## Build TGI
+
+ PKG_CONFIG_PATH="$LLAMACPP_PREFIX/lib/pkgconfig" cargo build
diff --git a/backends/llamacpp/build.rs b/backends/llamacpp/build.rs
new file mode 100644
index 00000000000..8f00f3b5bba
--- /dev/null
+++ b/backends/llamacpp/build.rs
@@ -0,0 +1,49 @@
+use bindgen::callbacks::{ItemInfo, ParseCallbacks};
+use std::env;
+use std::path::PathBuf;
+
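+// Strip the `llama_` prefix from generated item names, e.g. `llama_decode` becomes `decode`.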
+#[derive(Debug)]
+struct PrefixStripper;
+
+impl ParseCallbacks for PrefixStripper {
+ fn generated_name_override(&self, item_info: ItemInfo<'_>) -> Option<String> {
+ item_info.name.strip_prefix("llama_").map(str::to_string)
+ }
+}
+
+fn main() {
+ if let Some(cuda_version) = option_env!("CUDA_VERSION") {
+ let mut version: Vec<&str> = cuda_version.split('.').collect();
+ if version.len() > 2 {
+ version.pop();
+ }
+ let cuda_version = format!("cuda-{}", version.join("."));
+ pkg_config::Config::new().probe(&cuda_version).unwrap();
+ }
+ let llama = pkg_config::Config::new().probe("llama").unwrap();
+
+ for path in &llama.link_paths {
+ println!("cargo:rustc-link-arg=-Wl,-rpath,{}", path.display());
+ }
+ if cfg!(target_os = "linux") {
+ println!("cargo:rustc-link-arg=-Wl,--disable-new-dtags");
+ }
+ let bindings = bindgen::Builder::default()
+ .clang_args(
+ llama
+ .include_paths
+ .iter()
+ .map(|p| format!("-I{}", p.display())),
+ )
+ .header_contents("llama_bindings.h", "#include ")
+ .prepend_enum_name(false)
+ .parse_callbacks(Box::new(PrefixStripper))
+ .parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
+ .generate()
+ .expect("Unable to generate bindings");
+
+ let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
+ bindings
+ .write_to_file(out_path.join("llamacpp.rs"))
+ .expect("Couldn't write bindings!");
+}
diff --git a/backends/llamacpp/requirements.txt b/backends/llamacpp/requirements.txt
new file mode 100644
index 00000000000..293cd205583
--- /dev/null
+++ b/backends/llamacpp/requirements.txt
@@ -0,0 +1,4 @@
+transformers==4.49
+huggingface-hub==0.28.1
+hf-transfer==0.1.9
+torch==2.6.0
diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
new file mode 100644
index 00000000000..3405cfadd90
--- /dev/null
+++ b/backends/llamacpp/src/backend.rs
@@ -0,0 +1,674 @@
+use crate::llamacpp;
+
+use async_trait::async_trait;
+use std::ffi::CString;
+use std::mem::replace;
+use std::str::FromStr;
+use std::sync::{mpsc, Once};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::validation::ValidGenerateRequest;
+use text_generation_router::{FinishReason, Token};
+use thiserror::Error;
+use tokenizers::Tokenizer;
+use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
+use tokio::sync::{oneshot, watch};
+use tokio::task::{spawn, spawn_blocking};
+use tokio::time::{timeout, Duration, Instant};
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::instrument;
+use tracing::{debug, error, info, trace, warn};
+
+#[derive(Debug, Clone, Copy)]
+pub enum LlamacppSplitMode {
+ GPU(usize),
+ Layer,
+ Row,
+}
+
+impl FromStr for LlamacppSplitMode {
+ type Err = String;
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ match s.to_lowercase().as_str() {
+ "layer" => Ok(LlamacppSplitMode::Layer),
+ "row" => Ok(LlamacppSplitMode::Row),
+ _ => match s.parse::<usize>() {
+ Ok(n) => Ok(LlamacppSplitMode::GPU(n)),
+ Err(_) => Err("Choose a GPU number or `layer` or `row`".to_string()),
+ },
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy, clap::ValueEnum)]
+pub enum LlamacppNuma {
+ Disabled,
+ Distribute,
+ Isolate,
+ Numactl,
+ Mirror,
+}
+
+#[allow(non_camel_case_types)]
+#[derive(Debug, Clone, Copy, clap::ValueEnum)]
+pub enum LlamacppGGMLType {
+ F32,
+ F16,
+ Q4_0,
+ Q4_1,
+ Q5_0,
+ Q5_1,
+ Q8_0,
+ Q8_1,
+ Q2_K,
+ Q3_K,
+ Q4_K,
+ Q5_K,
+ Q6_K,
+ Q8_K,
+ IQ2_XXS,
+ IQ2_XS,
+ IQ3_XXS,
+ IQ1_S,
+ IQ4_NL,
+ IQ3_S,
+ IQ2_S,
+ IQ4_XS,
+ I8,
+ I16,
+ I32,
+ I64,
+ F64,
+ IQ1_M,
+ BF16,
+ TQ1_0,
+ TQ2_0,
+}
+
+// TODO: macro
+impl LlamacppGGMLType {
+ fn to_ggml_type(self) -> llamacpp::ggml_type {
+ match self {
+ LlamacppGGMLType::F32 => llamacpp::GGML_TYPE_F32,
+ LlamacppGGMLType::F16 => llamacpp::GGML_TYPE_F16,
+ LlamacppGGMLType::Q4_0 => llamacpp::GGML_TYPE_Q4_0,
+ LlamacppGGMLType::Q4_1 => llamacpp::GGML_TYPE_Q4_1,
+ LlamacppGGMLType::Q5_0 => llamacpp::GGML_TYPE_Q5_0,
+ LlamacppGGMLType::Q5_1 => llamacpp::GGML_TYPE_Q5_1,
+ LlamacppGGMLType::Q8_0 => llamacpp::GGML_TYPE_Q8_0,
+ LlamacppGGMLType::Q8_1 => llamacpp::GGML_TYPE_Q8_1,
+ LlamacppGGMLType::Q2_K => llamacpp::GGML_TYPE_Q2_K,
+ LlamacppGGMLType::Q3_K => llamacpp::GGML_TYPE_Q3_K,
+ LlamacppGGMLType::Q4_K => llamacpp::GGML_TYPE_Q4_K,
+ LlamacppGGMLType::Q5_K => llamacpp::GGML_TYPE_Q5_K,
+ LlamacppGGMLType::Q6_K => llamacpp::GGML_TYPE_Q6_K,
+ LlamacppGGMLType::Q8_K => llamacpp::GGML_TYPE_Q8_K,
+ LlamacppGGMLType::IQ2_XXS => llamacpp::GGML_TYPE_IQ2_XXS,
+ LlamacppGGMLType::IQ2_XS => llamacpp::GGML_TYPE_IQ2_XS,
+ LlamacppGGMLType::IQ3_XXS => llamacpp::GGML_TYPE_IQ3_XXS,
+ LlamacppGGMLType::IQ1_S => llamacpp::GGML_TYPE_IQ1_S,
+ LlamacppGGMLType::IQ4_NL => llamacpp::GGML_TYPE_IQ4_NL,
+ LlamacppGGMLType::IQ3_S => llamacpp::GGML_TYPE_IQ3_S,
+ LlamacppGGMLType::IQ2_S => llamacpp::GGML_TYPE_IQ2_S,
+ LlamacppGGMLType::IQ4_XS => llamacpp::GGML_TYPE_IQ4_XS,
+ LlamacppGGMLType::I8 => llamacpp::GGML_TYPE_I8,
+ LlamacppGGMLType::I16 => llamacpp::GGML_TYPE_I16,
+ LlamacppGGMLType::I32 => llamacpp::GGML_TYPE_I32,
+ LlamacppGGMLType::I64 => llamacpp::GGML_TYPE_I64,
+ LlamacppGGMLType::F64 => llamacpp::GGML_TYPE_F64,
+ LlamacppGGMLType::IQ1_M => llamacpp::GGML_TYPE_IQ1_M,
+ LlamacppGGMLType::BF16 => llamacpp::GGML_TYPE_BF16,
+ LlamacppGGMLType::TQ1_0 => llamacpp::GGML_TYPE_TQ1_0,
+ LlamacppGGMLType::TQ2_0 => llamacpp::GGML_TYPE_TQ2_0,
+ }
+ }
+}
+
+pub struct LlamacppConfig {
+ pub model_gguf: String,
+ pub max_batch_total_tokens: usize,
+ pub max_physical_batch_total_tokens: usize,
+ pub max_batch_size: usize,
+ pub batch_timeout: Duration,
+ pub n_threads: usize,
+ pub n_threads_batch: usize,
+ pub n_gpu_layers: usize,
+ pub split_mode: LlamacppSplitMode,
+ pub numa: LlamacppNuma,
+ pub defrag_threshold: f32,
+ pub use_mmap: bool,
+ pub use_mlock: bool,
+ pub offload_kqv: bool,
+ pub flash_attention: bool,
+ pub type_k: LlamacppGGMLType,
+ pub type_v: LlamacppGGMLType,
+}
+
+#[derive(Debug)]
+struct LlamacppRequest {
+ input_ids: Vec<i32>,
+ top_k: i32,
+ top_p: f32,
+ typical_p: f32,
+ min_keep: usize,
+ temp: f32,
+ seed: u32,
+ penalty_last_n: i32,
+ penalty_repeat: f32,
+ penalty_freq: f32,
+ penalty_present: f32,
+ max_new_tokens: usize,
+ tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
+ time: Instant,
+}
+
+pub struct LlamacppBackend {
+ tx: UnboundedSender<LlamacppRequest>,
+ status: watch::Receiver<bool>,
+}
+
+impl LlamacppRequest {
+ fn new(
+ from: &ValidGenerateRequest,
+ tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
+ ) -> Option<Self> {
+ from.input_ids.as_ref().map(|input_ids| LlamacppRequest {
+ input_ids: input_ids.iter().map(|&x| x as i32).collect(),
+ top_k: from.parameters.top_k as _,
+ top_p: from.parameters.top_p as _,
+ typical_p: from.parameters.typical_p as _,
+ min_keep: 0, // disabled
+ temp: from.parameters.temperature as _,
+ seed: from.parameters.seed as _,
+ penalty_last_n: 64, // 0 = disabled, -1 = context size
+ penalty_repeat: from.parameters.repetition_penalty as _,
+ penalty_freq: from.parameters.frequency_penalty as _,
+ penalty_present: 0.0, // disabled
+ max_new_tokens: from.stopping_parameters.max_new_tokens as _,
+ tx,
+ time: Instant::now(),
+ })
+ }
+}
+
+struct Llamacpp {
+ model: *mut llamacpp::llama_model,
+ ctx: *mut llamacpp::llama_context,
+ vocab: *const llamacpp::llama_vocab,
+ logprobs: Vec<llamacpp::llama_token_data>,
+ batch: llamacpp::llama_batch,
+}
+
+extern "C" fn llamacpp_log_callback(
+ level: llamacpp::ggml_log_level,
+ msg: *const std::os::raw::c_char,
+ _user_data: *mut std::os::raw::c_void,
+) {
+ let cmsg = unsafe { std::ffi::CStr::from_ptr(msg) };
+ let rmsg = cmsg.to_string_lossy().trim_end_matches('\n').to_string();
+
+ match level {
+ llamacpp::GGML_LOG_LEVEL_DEBUG => debug!(target: "llamacpp", "{}", rmsg),
+ llamacpp::GGML_LOG_LEVEL_INFO => info!(target: "llamacpp", "{}", rmsg),
+ llamacpp::GGML_LOG_LEVEL_WARN => warn!(target: "llamacpp", "{}", rmsg),
+ llamacpp::GGML_LOG_LEVEL_ERROR => error!(target: "llamacpp", "{}", rmsg),
+ _ => trace!(target: "llamacpp", "{}", rmsg),
+ }
+}
+
+impl Llamacpp {
+ fn new(conf: LlamacppConfig) -> Result<Llamacpp, BackendError> {
+ let gguf = CString::new(conf.model_gguf)?;
+
+ let model = unsafe {
+ let mut params = llamacpp::model_default_params();
+ params.n_gpu_layers = conf.n_gpu_layers as _;
+ params.split_mode = match conf.split_mode {
+ LlamacppSplitMode::GPU(_) => llamacpp::LLAMA_SPLIT_MODE_NONE,
+ LlamacppSplitMode::Layer => llamacpp::LLAMA_SPLIT_MODE_LAYER,
+ LlamacppSplitMode::Row => llamacpp::LLAMA_SPLIT_MODE_ROW,
+ };
+ params.main_gpu = match conf.split_mode {
+ LlamacppSplitMode::GPU(n) => n as _,
+ _ => 0,
+ };
+ params.use_mmap = conf.use_mmap;
+ params.use_mlock = conf.use_mlock;
+ llamacpp::model_load_from_file(gguf.as_ptr(), params)
+ };
+ if model.is_null() {
+ return Err(BackendError::Llamacpp("Failed to load model".to_string()));
+ }
+ let ctx = unsafe {
+ let mut params = llamacpp::context_default_params();
+ params.n_ctx = conf.max_batch_total_tokens as _;
+ params.n_batch = conf.max_batch_total_tokens as _;
+ params.n_ubatch = conf.max_physical_batch_total_tokens as _;
+ params.n_seq_max = conf.max_batch_size as _;
+ params.n_threads = conf.n_threads as _;
+ params.n_threads_batch = conf.n_threads_batch as _;
+ params.defrag_thold = conf.defrag_threshold;
+ params.offload_kqv = conf.offload_kqv;
+ params.flash_attn = conf.flash_attention;
+ params.type_k = conf.type_k.to_ggml_type();
+ params.type_v = conf.type_v.to_ggml_type();
+ params.no_perf = true;
+ llamacpp::init_from_model(model, params)
+ };
+ if ctx.is_null() {
+ return Err(BackendError::Llamacpp("Failed to init context".to_string()));
+ }
+ let vocab = unsafe { llamacpp::model_get_vocab(model) };
+ if vocab.is_null() {
+ return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
+ }
+ let n_tokens = unsafe { llamacpp::vocab_n_tokens(vocab) };
+ let mut logprobs = Vec::with_capacity(n_tokens as usize);
+
+ for token in 0..n_tokens {
+ logprobs.push(llamacpp::llama_token_data {
+ id: token,
+ logit: 0.0,
+ p: 0.0,
+ });
+ }
+ let batch = unsafe { llamacpp::batch_init(conf.max_batch_total_tokens as _, 0, 1) };
+ Ok(Llamacpp {
+ model,
+ ctx,
+ vocab,
+ logprobs,
+ batch,
+ })
+ }
+
+ fn decode(&mut self) -> i32 {
+ unsafe { llamacpp::decode(self.ctx, self.batch) }
+ }
+
+ fn clear_kv_cache(&mut self, seq_id: llamacpp::llama_seq_id) {
+ unsafe {
+ llamacpp::kv_cache_seq_rm(self.ctx, seq_id, -1, -1);
+ }
+ }
+
+ fn batch_push(
+ &mut self,
+ token: llamacpp::llama_token,
+ pos: llamacpp::llama_pos,
+ seq_id: llamacpp::llama_seq_id,
+ logits: bool,
+ ) -> usize {
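+ // Append one token to the batch; the returned index is used later to read its logits.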
+ let n = self.batch.n_tokens as usize;
+ unsafe {
+ *self.batch.token.add(n) = token;
+ *self.batch.pos.add(n) = pos;
+ *self.batch.n_seq_id.add(n) = 1;
+ *(*self.batch.seq_id.add(n)).add(0) = seq_id;
+ *self.batch.logits.add(n) = logits as i8;
+ }
+ self.batch.n_tokens += 1;
+ n
+ }
+}
+
+impl Drop for Llamacpp {
+ fn drop(&mut self) {
+ if !self.ctx.is_null() {
+ unsafe { llamacpp::free(self.ctx) };
+ }
+ if !self.model.is_null() {
+ unsafe { llamacpp::model_free(self.model) };
+ }
+ unsafe { llamacpp::batch_free(self.batch) };
+ }
+}
+
+struct LlamacppSampler {
+ chain: *mut llamacpp::llama_sampler,
+}
+
+impl LlamacppSampler {
+ fn new(req: &LlamacppRequest) -> Option<Self> {
+ let chain = unsafe {
+ let params = llamacpp::sampler_chain_default_params();
+ llamacpp::sampler_chain_init(params)
+ };
+ if chain.is_null() {
+ error!("Failed to init sampler");
+ return None;
+ }
+ let (top_k, top_p, typical_p, temp, penalties, dist) = unsafe {
+ (
+ llamacpp::sampler_init_top_k(req.top_k),
+ llamacpp::sampler_init_top_p(req.top_p, req.min_keep),
+ llamacpp::sampler_init_typical(req.typical_p, req.min_keep),
+ llamacpp::sampler_init_temp(req.temp),
+ llamacpp::sampler_init_penalties(
+ req.penalty_last_n,
+ req.penalty_repeat,
+ req.penalty_freq,
+ req.penalty_present,
+ ),
+ llamacpp::sampler_init_dist(req.seed),
+ )
+ };
+ let all = &[
+ ("top_k", top_k),
+ ("top_p", top_p),
+ ("typical_p", typical_p),
+ ("temp", temp),
+ ("penalties", penalties),
+ ("dist", dist),
+ ];
+ let mut failed = false;
+
+ for (k, v) in all {
+ if v.is_null() {
+ error!("Failed to init {k} sampler");
+ failed = true;
+ } else {
+ unsafe { llamacpp::sampler_chain_add(chain, *v) };
+ }
+ }
+ if failed {
+ unsafe { llamacpp::sampler_free(chain) };
+ None
+ } else {
+ Some(LlamacppSampler { chain })
+ }
+ }
+
+ fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (llamacpp::llama_token, f32) {
+ let logits = unsafe { llamacpp::get_logits_ith(llamacpp.ctx, idx as _) };
+ for (token, logprob) in llamacpp.logprobs.iter_mut().enumerate() {
+ *logprob = llamacpp::llama_token_data {
+ id: token as _,
+ logit: unsafe { *logits.add(token) },
+ p: 0.0,
+ };
+ }
+ let mut view = llamacpp::llama_token_data_array {
+ data: llamacpp.logprobs.as_mut_ptr(),
+ size: llamacpp.logprobs.len(),
+ selected: -1,
+ sorted: false,
+ };
+ unsafe {
+ llamacpp::sampler_apply(self.chain, &mut view);
+ let logprob = *view.data.offset(view.selected as _);
+ llamacpp::sampler_accept(self.chain, logprob.id);
+ (logprob.id, logprob.p.ln())
+ }
+ }
+}
+
+impl Drop for LlamacppSampler {
+ fn drop(&mut self) {
+ if !self.chain.is_null() {
+ unsafe { llamacpp::sampler_free(self.chain) };
+ }
+ }
+}
+
+struct LlamacppSeq {
+ id: usize,
+ batch_pos: usize,
+ token: llamacpp::llama_token,
+ pos: llamacpp::llama_pos,
+ sampler: LlamacppSampler,
+ text: String,
+ n_new_tokens: usize,
+ running: bool,
+}
+
+static INIT: Once = Once::new();
+
+impl LlamacppBackend {
+ pub fn new(
+ conf: LlamacppConfig,
+ tokenizer: Tokenizer,
+ ) -> (
+ Self,
+ oneshot::Receiver<Result<(), BackendError>>,
+ watch::Sender<bool>,
+ ) {
+ // Setup llama & export logs, once and for all
+ INIT.call_once(|| unsafe {
+ llamacpp::log_set(Some(llamacpp_log_callback), std::ptr::null_mut());
+ llamacpp::backend_init();
+ llamacpp::numa_init(match conf.numa {
+ LlamacppNuma::Disabled => llamacpp::GGML_NUMA_STRATEGY_DISABLED,
+ LlamacppNuma::Distribute => llamacpp::GGML_NUMA_STRATEGY_DISTRIBUTE,
+ LlamacppNuma::Isolate => llamacpp::GGML_NUMA_STRATEGY_ISOLATE,
+ LlamacppNuma::Numactl => llamacpp::GGML_NUMA_STRATEGY_NUMACTL,
+ LlamacppNuma::Mirror => llamacpp::GGML_NUMA_STRATEGY_MIRROR,
+ });
+ });
+
+ let (status_tx, status_rx) = watch::channel(false);
+ let (shutdown_tx, shutdown_rx) = watch::channel(false);
+ let (ok_tx, ok_rx) = oneshot::channel();
+ let (tx, mut rx) = unbounded_channel::<LlamacppRequest>();
+ let (sync_tx, sync_rx) = mpsc::channel();
+
+ spawn(async move {
+ let mut n_tokens = 0;
+ let mut requests = Vec::with_capacity(conf.max_batch_size);
+
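+ // Hand the pending requests over to the blocking inference thread and reset the token count.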
+ let flush = |requests: &mut Vec<_>, n_tokens: &mut usize| {
+ if !requests.is_empty() {
+ let _ =
+ sync_tx.send(replace(requests, Vec::with_capacity(conf.max_batch_size)));
+ *n_tokens = 0;
+ }
+ };
+ loop {
+ match timeout(conf.batch_timeout, rx.recv()).await {
+ Ok(Some(request)) => {
+ let n_tokens_to_add = request.input_ids.len();
+
+ if n_tokens + n_tokens_to_add > conf.max_batch_total_tokens {
+ flush(&mut requests, &mut n_tokens);
+ }
+ n_tokens += n_tokens_to_add;
+ requests.push(request);
+
+ if requests.len() == conf.max_batch_size {
+ flush(&mut requests, &mut n_tokens);
+ }
+ }
+ Ok(None) => break, // closed
+ Err(_) => flush(&mut requests, &mut n_tokens), // timeout
+ }
+ }
+ });
+
+ spawn_blocking(move || {
+ let mut llamacpp = match Llamacpp::new(conf) {
+ Ok(v) => {
+ let _ = ok_tx.send(Ok(()));
+ v
+ }
+ Err(e) => {
+ let _ = ok_tx.send(Err(e));
+ return;
+ }
+ };
+ let vocab = tokenizer.get_added_vocabulary();
+
+ // health() returns true
+ let _ = status_tx.send(true);
+
+ while let Ok(requests) = sync_rx.recv() {
+ if *shutdown_rx.borrow() {
+ break;
+ }
+ let start_time = Instant::now();
+ let mut seqs: Vec<LlamacppSeq> = Vec::with_capacity(requests.len());
+ llamacpp.batch.n_tokens = 0;
+
+ for (seq_id, request) in requests.iter().enumerate() {
+ debug!("Request: {:?}", request);
+ // TODO remove this
+ let sampler = match LlamacppSampler::new(request) {
+ Some(sampler) => sampler,
+ _ => {
+ let _ = request.tx.send(Err(InferError::IncompleteGeneration));
+ continue;
+ }
+ };
+ let last_pos = request.input_ids.len() - 1;
+
+ for (pos, &token_id) in request.input_ids.iter().enumerate() {
+ llamacpp.batch_push(
+ token_id as llamacpp::llama_token,
+ pos as llamacpp::llama_pos,
+ seq_id as llamacpp::llama_seq_id,
+ pos == last_pos, // only request logits for the last prompt token
+ );
+ }
+ seqs.push(LlamacppSeq {
+ id: seq_id,
+ batch_pos: llamacpp.batch.n_tokens as usize - 1,
+ token: llamacpp::LLAMA_TOKEN_NULL,
+ pos: last_pos as llamacpp::llama_pos + 1,
+ sampler,
+ text: String::with_capacity(1024),
+ n_new_tokens: 0,
+ running: true,
+ });
+ }
+ while llamacpp.batch.n_tokens > 0 {
+ if llamacpp.decode() != 0 {
+ warn!("llama_decode failed, clearing kv cache");
+ llamacpp.clear_kv_cache(-1);
+ for seq in seqs.iter_mut() {
+ let _ = requests[seq.id]
+ .tx
+ .send(Err(InferError::IncompleteGeneration));
+ seq.running = false;
+ }
+ break;
+ }
+ for seq in seqs.iter_mut() {
+ if !seq.running {
+ continue;
+ }
+ let (next, logprob) = seq.sampler.sample(&mut llamacpp, seq.batch_pos);
+ seq.n_new_tokens += 1;
+ seq.token = next;
+
+ let piece = match tokenizer.decode(&[next as u32], false) {
+ Ok(piece) => piece,
+ Err(e) => {
+ error!("Failed to decode token: {e}");
+ let _ = requests[seq.id]
+ .tx
+ .send(Err(InferError::IncompleteGeneration));
+ seq.running = false;
+ continue;
+ }
+ };
+ let special = vocab.is_special_token(&piece);
+
+ if !special {
+ seq.text.push_str(&piece);
+ }
+ let token = Token {
+ id: next as _,
+ text: piece,
+ logprob,
+ special,
+ };
+ let finish: Option<FinishReason> = {
+ if unsafe { llamacpp::vocab_is_eog(llamacpp.vocab, next) } {
+ Some(FinishReason::EndOfSequenceToken)
+ } else if seq.n_new_tokens == requests[seq.id].max_new_tokens {
+ Some(FinishReason::Length)
+ } else {
+ None
+ }
+ };
+ if let Some(reason) = finish {
+ let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::End {
+ token,
+ top_tokens: vec![],
+ generated_text: GeneratedText {
+ text: seq.text.clone(),
+ generated_tokens: seq.n_new_tokens as _,
+ finish_reason: reason,
+ seed: Some(requests[seq.id].seed as _),
+ },
+ start: start_time,
+ queued: requests[seq.id].time,
+ }));
+ seq.running = false;
+ continue;
+ }
+ let _ = requests[seq.id]
+ .tx
+ .send(Ok(InferStreamResponse::Intermediate {
+ token,
+ top_tokens: vec![],
+ }));
+ }
+ // generate a new batch
+ llamacpp.batch.n_tokens = 0;
+
+ for seq in seqs.iter_mut() {
+ if seq.running {
+ seq.batch_pos =
+ llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
+ seq.pos += 1;
+ } else {
+ llamacpp.clear_kv_cache(seq.id as _);
+ }
+ }
+ }
+ }
+ });
+ (
+ Self {
+ tx,
+ status: status_rx,
+ },
+ ok_rx,
+ shutdown_tx,
+ )
+ }
+}
+
+#[async_trait]
+impl Backend for LlamacppBackend {
+ #[instrument(skip_all)]
+ fn schedule(
+ &self,
+ request: ValidGenerateRequest,
+ ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
+ debug!(?request);
+ let (tx, rx) = unbounded_channel::<Result<InferStreamResponse, InferError>>();
+ match LlamacppRequest::new(&request, tx) {
+ Some(v) => match self.tx.send(v) {
+ Err(e) => Err(InferError::GenerationError(e.to_string())),
+ _ => Ok(UnboundedReceiverStream::new(rx)),
+ },
+ _ => Err(InferError::GenerationError("Bad request".to_string())),
+ }
+ }
+
+ async fn health(&self, _: bool) -> bool {
+ *self.status.borrow()
+ }
+
+ fn name(&self) -> &'static str {
+ "llamacpp"
+ }
+}
+
+#[derive(Debug, Error)]
+pub enum BackendError {
+ #[error("CString error: {0}")]
+ CStringError(#[from] std::ffi::NulError),
+ #[error("Llamacpp error: {0}")]
+ Llamacpp(String),
+}
diff --git a/backends/llamacpp/src/llamacpp.rs b/backends/llamacpp/src/llamacpp.rs
new file mode 100644
index 00000000000..fb206df272d
--- /dev/null
+++ b/backends/llamacpp/src/llamacpp.rs
@@ -0,0 +1,5 @@
+#![allow(non_upper_case_globals)]
+#![allow(non_camel_case_types)]
+#![allow(non_snake_case)]
+#![allow(dead_code)]
+include!(concat!(env!("OUT_DIR"), "/llamacpp.rs"));
diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
new file mode 100644
index 00000000000..9ee61ce6e2e
--- /dev/null
+++ b/backends/llamacpp/src/main.rs
@@ -0,0 +1,351 @@
+mod backend;
+mod llamacpp;
+mod quantize;
+
+use quantize::QuantizeType;
+
+use backend::{
+ BackendError, LlamacppBackend, LlamacppConfig, LlamacppGGMLType, LlamacppNuma,
+ LlamacppSplitMode,
+};
+use clap::Parser;
+use hf_hub::api::tokio::ApiBuilder;
+use hf_hub::{Repo, RepoType};
+use std::path::Path;
+use text_generation_router::{logging, server, usage_stats};
+use thiserror::Error;
+use tokenizers::Tokenizer;
+use tokio::process::Command;
+use tokio::sync::oneshot::error::RecvError;
+use tracing::{error, warn};
+
+/// Backend Configuration
+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+struct Args {
+ /// Name of the model to load.
+ #[clap(long, env)]
+ model_id: String,
+
+ /// Revision of the model.
+ #[clap(default_value = "main", long, env)]
+ revision: String,
+
+ /// Path to the GGUF model file for inference.
+ #[clap(long, env)]
+ model_gguf: Option<String>,
+
+ /// Number of threads to use for generation.
+ #[clap(long, env)]
+ n_threads: Option<usize>,
+
+ /// Number of threads to use for batch processing.
+ #[clap(long, env)]
+ n_threads_batch: Option<usize>,
+
+ /// Number of layers to store in VRAM.
+ #[clap(default_value = "0", long, env)]
+ n_gpu_layers: usize,
+
+ /// Split the model across multiple GPUs.
+ #[clap(default_value = "layer", long, env)]
+ split_mode: LlamacppSplitMode,
+
+ /// Defragment the KV cache if holes/size > threshold.
+ #[clap(default_value = "-1.0", long, env)]
+ defrag_threshold: f32,
+
+ /// Enable NUMA optimizations.
+ #[clap(default_value = "disabled", value_enum, long, env)]
+ numa: LlamacppNuma,
+
+ /// Disable memory mapping of the model (mmap is used by default).
+ #[clap(long, env)]
+ disable_mmap: bool,
+
+ /// Use memory locking to prevent swapping.
+ #[clap(long, env)]
+ use_mlock: bool,
+
+ /// Disable offloading of KQV operations to the GPU (enabled by default).
+ #[clap(long, env)]
+ disable_offload_kqv: bool,
+
+ /// Disable flash attention, which is enabled by default. (EXPERIMENTAL)
+ #[clap(long, env)]
+ disable_flash_attention: bool,
+
+ /// Data type used for K cache.
+ #[clap(default_value = "f16", value_enum, long, env)]
+ type_k: LlamacppGGMLType,
+
+ /// Data type used for V cache.
+ #[clap(default_value = "f16", value_enum, long, env)]
+ type_v: LlamacppGGMLType,
+
+ /// Number of tokenizer workers used for payload validation and truncation.
+ #[clap(default_value = "2", long, env)]
+ validation_workers: usize,
+
+ /// Maximum number of concurrent requests.
+ #[clap(long, env)]
+ max_concurrent_requests: Option<usize>,
+
+ /// Maximum number of input tokens per request.
+ #[clap(default_value = "1024", long, env)]
+ max_input_tokens: usize,
+
+ /// Maximum number of total tokens (input + output) per request.
+ #[clap(default_value = "2048", long, env)]
+ max_total_tokens: usize,
+
+ /// Maximum number of tokens in a batch.
+ #[clap(long, env)]
+ max_batch_total_tokens: Option<usize>,
+
+ /// Maximum number of tokens in a physical batch.
+ #[clap(long, env)]
+ max_physical_batch_total_tokens: Option<usize>,
+
+ /// Maximum number of requests per batch.
+ #[clap(long, env)]
+ max_batch_size: Option<usize>,
+
+ /// IP address to listen on.
+ #[clap(default_value = "0.0.0.0", long)]
+ hostname: String,
+
+ /// Port to listen on.
+ #[clap(default_value = "3000", long, short, env)]
+ port: u16,
+
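+ /// Port to expose Prometheus metrics on.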
+ #[clap(default_value = "9000", long, short, env)]
+ prometheus_port: u16,
+
+ /// Enable JSON output format.
+ #[clap(long, env)]
+ json_output: bool,
+
+ /// OTLP endpoint for telemetry data.
+ #[clap(long, env)]
+ otlp_endpoint: Option<String>,
+
+ /// Service name for OTLP telemetry.
+ #[clap(default_value = "text-generation-inference.router", long, env)]
+ otlp_service_name: String,
+
+ /// Allowed origins for CORS.
+ #[clap(long, env)]
+ cors_allow_origin: Option<Vec<String>>,
+
+ /// Path to the tokenizer configuration file.
+ #[clap(long, env)]
+ tokenizer_config_path: Option<String>,
+
+ /// Disable grammar support.
+ #[clap(long, env)]
+ disable_grammar_support: bool,
+
+ /// Maximum number of inputs per request.
+ #[clap(default_value = "4", long, env)]
+ max_client_batch_size: usize,
+
+ /// Level of usage statistics collection.
+ #[clap(default_value = "on", long, env)]
+ usage_stats: usage_stats::UsageStatsLevel,
+
+ /// Maximum payload size in bytes.
+ #[clap(default_value = "2000000", long, env)]
+ payload_limit: usize,
+}
+
+#[tokio::main]
+async fn main() -> Result<(), RouterError> {
+ let args = Args::parse();
+
+ logging::init_logging(args.otlp_endpoint, args.otlp_service_name, args.json_output);
+
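+ // Resolve defaults: unset (or 0) values fall back to the CPU count or to values derived from the other limits.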
+ let n_threads = match args.n_threads {
+ Some(0) | None => num_cpus::get(),
+ Some(threads) => threads,
+ };
+ let n_threads_batch = match args.n_threads_batch {
+ Some(0) | None => n_threads,
+ Some(threads) => threads,
+ };
+ let max_batch_size = match args.max_batch_size {
+ Some(0) | None => n_threads_batch,
+ Some(threads) => threads,
+ };
+ let max_batch_total_tokens = match args.max_batch_total_tokens {
+ None => max_batch_size * args.max_total_tokens,
+ Some(size) => size,
+ };
+ let max_physical_batch_total_tokens = match args.max_physical_batch_total_tokens {
+ None => max_batch_total_tokens,
+ Some(size) => size,
+ };
+ let max_concurrent_requests = match args.max_concurrent_requests {
+ None => max_batch_size * 2,
+ Some(size) => size,
+ };
+ if args.max_input_tokens >= args.max_total_tokens {
+ return Err(RouterError::ArgumentValidation(
+ "`max_input_tokens` must be < `max_total_tokens`".to_string(),
+ ));
+ }
+ if args.max_total_tokens > max_batch_total_tokens {
+ return Err(RouterError::ArgumentValidation(
+ "`max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
+ ));
+ }
+ if max_batch_size * args.max_total_tokens > max_batch_total_tokens {
+ return Err(RouterError::ArgumentValidation(
+ "`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
+ ));
+ }
+
+ let api_builder = || {
+ let mut builder = ApiBuilder::new().with_progress(true);
+
+ if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
+ builder = builder.with_cache_dir(cache_dir.into());
+ }
+ if let Ok(token) = std::env::var("HF_TOKEN") {
+ builder = builder.with_token(token.into());
+ }
+ if let Ok(origin) = std::env::var("HF_HUB_USER_AGENT_ORIGIN") {
+ builder = builder.with_user_agent("origin", origin.as_str());
+ }
+ builder
+ };
+ let api_repo = api_builder().build()?.repo(Repo::with_revision(
+ args.model_id.clone(),
+ RepoType::Model,
+ args.revision.clone(),
+ ));
+
+ let tokenizer_path = api_repo.get("tokenizer.json").await?;
+ let tokenizer = Tokenizer::from_file(&tokenizer_path)?;
+
+ let model_gguf = if let Some(model_gguf) = args.model_gguf {
+ model_gguf
+ } else {
+ let model_gguf = format!("models/{}/model.gguf", args.model_id);
+ let model_gguf_path = Path::new(&model_gguf);
+
+ if !model_gguf_path.exists() {
+ let tmp_gguf = "models/tmp.gguf";
+
+ if let Some(parent) = Path::new(model_gguf_path).parent() {
+ std::fs::create_dir_all(parent)?;
+ }
+ let cache_path = tokenizer_path.parent().unwrap();
+
+ for sibling in api_repo.info().await?.siblings {
+ let _ = api_repo.get(&sibling.rfilename).await?;
+ }
+ let status = Command::new("convert_hf_to_gguf.py")
+ .arg("--outfile")
+ .arg(tmp_gguf)
+ .arg(cache_path)
+ .spawn()?
+ .wait()
+ .await?;
+
+ if !status.success() {
+ let exit_code = status.code().unwrap_or(-1);
+ error!("Failed to generate GGUF, exit code: {}", exit_code);
+ return Err(RouterError::CommandError(exit_code));
+ }
+ quantize::model(tmp_gguf, &model_gguf, QuantizeType::MostlyQ4_0, n_threads)
+ .map_err(RouterError::QuantizeError)?;
+ }
+ model_gguf
+ };
+
+ let (backend, ok, shutdown) = LlamacppBackend::new(
+ LlamacppConfig {
+ model_gguf,
+ n_threads,
+ n_threads_batch,
+ n_gpu_layers: args.n_gpu_layers,
+ split_mode: args.split_mode,
+ defrag_threshold: args.defrag_threshold,
+ numa: args.numa,
+ use_mmap: !args.disable_mmap,
+ use_mlock: args.use_mlock,
+ flash_attention: !args.disable_flash_attention,
+ type_k: args.type_k,
+ type_v: args.type_v,
+ offload_kqv: !args.disable_offload_kqv,
+ max_batch_total_tokens,
+ max_physical_batch_total_tokens,
+ max_batch_size,
+ batch_timeout: tokio::time::Duration::from_millis(5),
+ },
+ tokenizer,
+ );
+ ok.await??;
+
+ if cfg!(debug_assertions) {
+ warn!("Graceful shutdown disabled!");
+ let _ = tokio::task::spawn(async move {
+ let _ = tokio::signal::ctrl_c().await;
+ let _ = shutdown.send(true);
+ });
+ }
+
+ server::run(
+ backend,
+ max_concurrent_requests,
+ 0, // max_best_of
+ 0, // max_stop_sequences
+ 0, // max_top_n_tokens
+ args.max_input_tokens,
+ args.max_total_tokens,
+ args.validation_workers,
+ None, // api_key
+ args.model_id, // tokenizer_name
+ args.tokenizer_config_path,
+ Some(args.revision),
+ false, // trust_remote_code
+ args.hostname,
+ args.port,
+ args.cors_allow_origin,
+ false, // ngrok,
+ None, // ngrok_authtoken,
+ None, // ngrok_edge,
+ args.disable_grammar_support,
+ args.max_client_batch_size,
+ args.usage_stats,
+ args.payload_limit,
+ args.prometheus_port,
+ )
+ .await?;
+ Ok(())
+}
+
+#[derive(Debug, Error)]
+enum RouterError {
+ #[error("Argument validation error: {0}")]
+ ArgumentValidation(String),
+ #[error("Tokenizer error: {0}")]
+ Tokenizer(#[from] tokenizers::Error),
+ #[error("Backend error: {0}")]
+ Backend(#[from] BackendError),
+ #[error("WebServer error: {0}")]
+ WebServer(#[from] server::WebServerError),
+ #[error("Recv error: {0}")]
+ RecvError(#[from] RecvError),
+ #[error("Io error: {0}")]
+ IoError(#[from] std::io::Error),
+ #[error("Var error: {0}")]
+ VarError(#[from] std::env::VarError),
+ #[error("Quantize error: {0}")]
+ QuantizeError(String),
+ #[error("Command error: {0}")]
+ CommandError(i32),
+ #[error("HF hub error: {0}")]
+ HubError(#[from] hf_hub::api::tokio::ApiError),
+}
diff --git a/backends/llamacpp/src/quantize.rs b/backends/llamacpp/src/quantize.rs
new file mode 100644
index 00000000000..31307becf23
--- /dev/null
+++ b/backends/llamacpp/src/quantize.rs
@@ -0,0 +1,35 @@
+use crate::llamacpp;
+
+use std::ffi::CString;
+
+#[repr(u32)]
+#[derive(Debug, Clone, Copy)]
+pub enum QuantizeType {
+ MostlyQ4_0 = 2,
+}
+
+pub fn model(
+ input_path: &str,
+ output_path: &str,
+ ftype: QuantizeType,
+ n_threads: usize,
+) -> Result<(), String> {
+ let c_input_path =
+ CString::new(input_path).map_err(|e| format!("Failed to convert input path: {}", e))?;
+
+ let c_output_path =
+ CString::new(output_path).map_err(|e| format!("Failed to convert output path: {}", e))?;
+
+ let result = unsafe {
+ let mut params = llamacpp::model_quantize_default_params();
+ params.nthread = n_threads as _;
+ params.ftype = ftype as _;
+ params.quantize_output_tensor = true;
+ llamacpp::model_quantize(c_input_path.as_ptr(), c_output_path.as_ptr(), ¶ms)
+ };
+ if result == 0 {
+ Ok(())
+ } else {
+ Err(format!("Quantization failed, error code: {}", result))
+ }
+}
diff --git a/backends/neuron/Cargo.toml b/backends/neuron/Cargo.toml
new file mode 100644
index 00000000000..72f92e69c7e
--- /dev/null
+++ b/backends/neuron/Cargo.toml
@@ -0,0 +1,47 @@
+[workspace]
+members = [
+ "backends/v2",
+ "backends/grpc-metadata",
+ "launcher",
+ "router"
+]
+default-members = [
+ "backends/v2",
+ "backends/grpc-metadata",
+ "launcher",
+ "router"
+]
+resolver = "2"
+
+[workspace.package]
+version = "3.0.0"
+edition = "2021"
+authors = ["Olivier Dehaene"]
+homepage = "https://github.com/huggingface/text-generation-inference"
+
+[workspace.dependencies]
+base64 = "0.22.0"
+tokenizers = { version = "0.20.0", features = ["http"] }
+hf-hub = { version = "0.4.2", features = ["tokio"] }
+metrics = { version = "0.23.0" }
+metrics-exporter-prometheus = { version = "0.15.1", features = [] }
+minijinja = { version = "2.2.0", features = ["json"] }
+minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
+pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
+
+[profile.release]
+incremental = true
+
+[profile.release-binary]
+inherits = "release"
+debug = 1
+incremental = true
+panic = "abort"
+
+[profile.release-opt]
+inherits = "release"
+debug = 0
+incremental = false
+lto = "fat"
+opt-level = 3
+codegen-units = 1
diff --git a/backends/neuron/Makefile b/backends/neuron/Makefile
new file mode 100644
index 00000000000..0667497133f
--- /dev/null
+++ b/backends/neuron/Makefile
@@ -0,0 +1,35 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
+mkfile_dir := $(dir $(mkfile_path))
+root_dir := "${mkfile_dir}/../.."
+
+.PHONY: image install_server test_server test_integration
+
+VERSION := $(shell gawk 'match($$0, /^version = "(.*)"/, a) {print a[1]}' ${root_dir}/Cargo.toml)
+
+image:
+ docker build --rm -f ${root_dir}/Dockerfile.neuron \
+ --ulimit nofile=100000:100000 \
+ --build-arg VERSION=$(VERSION) \
+ -t text-generation-inference:$(VERSION)-neuron ${root_dir}
+ docker tag text-generation-inference:$(VERSION)-neuron text-generation-inference:latest-neuron
+
+install_server:
+ make -C ${mkfile_dir}/server install VERSION:=${VERSION}
+
+test_server: install_server
+ python -m pip install -r ${mkfile_dir}/tests/requirements.txt
+ python -m pytest -sv ${mkfile_dir}/tests/server
diff --git a/backends/neuron/README.md b/backends/neuron/README.md
new file mode 100644
index 00000000000..55722c3bddb
--- /dev/null
+++ b/backends/neuron/README.md
@@ -0,0 +1,25 @@
+# Text-generation-inference - Neuron backend for AWS Trainium and Inferentia2
+
+## Description
+
+This is the TGI backend for the AWS Trainium and Inferentia family of chips.
+
+This backend is composed of:
+- the AWS Neuron SDK,
+- the legacy v2 TGI launcher and router,
+- a Neuron-specific inference server for text generation.
+
+## Usage
+
+Please refer to the official [documentation](https://huggingface.co/docs/text-generation-inference/backends/neuron).
+
+## Build your own image
+
+The simplest way to build TGI with the Neuron backend is to use the provided `Makefile`:
+
+```shell
+$ make -C backends/neuron image
+```
+
+Alternatively, you can build the image directly from the top directory using a command similar to the one defined
+in the `Makefile` under the `image` target.
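+
+## Run the image
+
+A rough sketch of running the resulting image (the device mapping, model and
+runtime flags depend on your Neuron setup; see the documentation linked above):
+
+```shell
+$ docker run -p 8080:80 \
+ --device=/dev/neuron0 \
+ -v $(pwd)/data:/data \
+ text-generation-inference:latest-neuron \
+ --model-id <model-id>
+```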
diff --git a/backends/neuron/server/.gitignore b/backends/neuron/server/.gitignore
new file mode 100644
index 00000000000..378eac25d31
--- /dev/null
+++ b/backends/neuron/server/.gitignore
@@ -0,0 +1 @@
+build
diff --git a/backends/neuron/server/Makefile b/backends/neuron/server/Makefile
new file mode 100644
index 00000000000..efe34bd0df0
--- /dev/null
+++ b/backends/neuron/server/Makefile
@@ -0,0 +1,74 @@
+# Initialize base variables
+SHELL := /bin/bash
+pkg_name := text_generation_server
+BUILDDIR ?= $(CURDIR)/build
+VERSION ?= 0.0.1
+mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
+mkfile_dir := $(dir $(mkfile_path))
+pkg_dir := $(BUILDDIR)/$(pkg_name)
+py_version := $(subst -,.,${VERSION})
+pkg_dist := ${BUILDDIR}/dist/${pkg_name}-$(py_version).tar.gz
+
+clean:
+ rm -rf $(BUILDDIR)/*
+
+${BUILDDIR}:
+ install -d $@
+
+# List static sources to be deployed in the package
+src_dir := $(mkfile_dir)/$(pkg_name)
+sources := $(wildcard $(src_dir)/*.py)
+deployed_sources := $(subst $(src_dir), $(pkg_dir), $(sources))
+
+# Static files are just copied
+
+define COPY
+ cp -f $< $@
+endef
+
+# We use a PHONY target to represent the VERSION
+.PHONY: VERSION
+
+VERSION: ${BUILDDIR}
+ # The trick is to compare the value of the variable with the content of a file in the build directory
+ @if [[ `cat ${BUILDDIR}/VERSION 2>&1` != '$(VERSION)' ]]; then echo -n $(VERSION) >${BUILDDIR}/VERSION; fi
+
+# Depending on the PHONY VERSION target makes sure the pyproject.toml is regenerated if the version changes
+$(BUILDDIR)/pyproject.toml: $(mkfile_dir)/pyproject.toml VERSION
+ mkdir -p $(BUILDDIR)
+ $(COPY)
+ sed -i -e 's/version = "VERSION"/version = \"${VERSION}\"/' $@
+
+$(pkg_dir)/%.py: $(src_dir)/%.py
+ mkdir -p $(pkg_dir)
+ $(COPY)
+
+# Generated files are produced by grpcio tools
+
+# If not provided, get local proto files
+ifndef PROTODIR
+PROTODIR := $(mkfile_dir)/../../../proto
+endif
+
+# Three python files are generated for each protobuf
+protobufs := $(PROTODIR)/generate.proto
+pkg_pb_dir := $(pkg_dir)/pb
+generated_sources_base := $(foreach proto, $(protobufs), $(proto:.proto=_pb2.py))
+generated_sources := $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base))
+generated_sources += $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base:.py=.pyi))
+generated_sources += $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base:.py=_grpc.py))
+
+$(pkg_pb_dir)/%_pb2.py $(pkg_pb_dir)/%_pb2.pyi $(pkg_pb_dir)/%_pb2_grpc.py: $(PROTODIR)/%.proto
+ mkdir -p $(pkg_pb_dir)
+ python -m grpc_tools.protoc -I$(PROTODIR) --python_out=$(pkg_pb_dir) \
+ --grpc_python_out=$(pkg_pb_dir) --mypy_out=$(pkg_pb_dir) $^
+ sed -i -e 's/^\(import.*pb2\)/from . \1/g' $(pkg_pb_dir)/$*_pb2_grpc.py
+
+${pkg_dist}: $(BUILDDIR)/pyproject.toml $(deployed_sources) $(generated_sources)
+ python -m build $(BUILDDIR)
+
+package: ${pkg_dist}
+
+install: ${pkg_dist}
+ python3 -m pip uninstall -y ${pkg_name}
+ python3 -m pip install ${pkg_dist}
diff --git a/backends/neuron/server/build-requirements.txt b/backends/neuron/server/build-requirements.txt
new file mode 100644
index 00000000000..2083bd73f72
--- /dev/null
+++ b/backends/neuron/server/build-requirements.txt
@@ -0,0 +1,3 @@
+build
+grpcio-tools==1.53.0
+mypy-protobuf
diff --git a/backends/neuron/server/pyproject.toml b/backends/neuron/server/pyproject.toml
new file mode 100644
index 00000000000..6bf4e5eee4b
--- /dev/null
+++ b/backends/neuron/server/pyproject.toml
@@ -0,0 +1,26 @@
+[build-system]
+requires = ["setuptools>=78.1"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "text-generation-server"
+version = "VERSION"
+authors = [{name="David Corvoysier", email="david@huggingface.co" }]
+description = "TGI compatible inference server for AWS Neuronx platforms"
+dependencies = [
+ 'protobuf > 3.20.1, < 4',
+ 'grpcio == 1.57.0',
+ 'grpcio-status == 1.48.2',
+ 'grpcio-reflection == 1.48.2',
+ 'grpc-interceptor == 0.15.2',
+ 'typer == 0.6.1',
+ 'safetensors',
+ 'loguru == 0.6.0',
+ 'optimum-neuron[neuronx] >= 0.0.28',
+]
+
+[tool.setuptools]
+packages = ["text_generation_server", "text_generation_server.pb"]
+
+[project.scripts]
+text-generation-server = 'text_generation_server.cli:app'
diff --git a/backends/neuron/server/text_generation_server/cli.py b/backends/neuron/server/text_generation_server/cli.py
new file mode 100644
index 00000000000..4a9c47345f1
--- /dev/null
+++ b/backends/neuron/server/text_generation_server/cli.py
@@ -0,0 +1,115 @@
+import sys
+from typing import Optional
+
+import typer
+from loguru import logger
+
+
+app = typer.Typer()
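+# Typical invocations (illustrative only; in practice these commands are issued by
+# text-generation-launcher):
+#   text-generation-server download-weights <model_id>
+#   text-generation-server serve <model_id> --uds-path /tmp/text-generation-server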
+
+
+@app.command()
+def serve(
+ model_id: str,
+ revision: Optional[str] = None,
+ sharded: bool = False,
+    trust_remote_code: Optional[bool] = None,
+ uds_path: str = "/tmp/text-generation-server",
+ logger_level: str = "INFO",
+ json_output: bool = False,
+ otlp_endpoint: Optional[str] = None,
+ otlp_service_name: str = "text-generation-inference.server",
+ max_input_tokens: Optional[int] = None,
+):
+ """This is the main entry-point for the server CLI.
+
+ Args:
+ model_id (`str`):
+ The *model_id* of a model on the HuggingFace hub or the path to a local model.
+ revision (`Optional[str]`, defaults to `None`):
+ The revision of the model on the HuggingFace hub.
+ sharded (`bool`):
+ Whether the model must be sharded or not. Kept for compatibility with the
+ text-generation-launcher, but must be set to False.
+        trust_remote_code (`Optional[bool]`, defaults to `None`):
+ Kept for compatibility with text-generation-launcher. Ignored.
+ uds_path (`Union[Path, str]`):
+            The local path on which the server will expose its gRPC services.
+ logger_level (`str`):
+ The server logger level. Defaults to *INFO*.
+ json_output (`bool`):
+ Use JSON format for log serialization.
+ otlp_endpoint (`Optional[str]`, defaults to `None`):
+ The Open Telemetry endpoint to use.
+        otlp_service_name (`str`, defaults to `text-generation-inference.server`):
+ The name to use when pushing data to the Open Telemetry endpoint.
+ max_input_tokens (`Optional[int]`, defaults to `None`):
+ The maximum number of input tokens each request should contain.
+ """
+ if sharded:
+ raise ValueError("Sharding is not supported.")
+ # Remove default handler
+ logger.remove()
+ logger.add(
+ sys.stdout,
+ format="{message}",
+ filter="text_generation_server",
+ level=logger_level,
+ serialize=json_output,
+ backtrace=True,
+ diagnose=False,
+ )
+
+ if trust_remote_code is not None:
+ logger.warning(
+ "'trust_remote_code' argument is not supported and will be ignored."
+ )
+
+ # Import here after the logger is added to log potential import exceptions
+ from .server import serve
+
+ serve(model_id, revision, uds_path)
+
+
+@app.command()
+def download_weights(
+ model_id: str,
+ revision: Optional[str] = None,
+ logger_level: str = "INFO",
+ json_output: bool = False,
+ auto_convert: Optional[bool] = None,
+ extension: Optional[str] = None,
+ trust_remote_code: Optional[bool] = None,
+ merge_lora: Optional[bool] = None,
+):
+ """Download the model weights.
+
+ This command will be called by text-generation-launcher before serving the model.
+ """
+ # Remove default handler
+ logger.remove()
+ logger.add(
+ sys.stdout,
+ format="{message}",
+ filter="text_generation_server",
+ level=logger_level,
+ serialize=json_output,
+ backtrace=True,
+ diagnose=False,
+ )
+
+ if extension is not None:
+ logger.warning("'extension' argument is not supported and will be ignored.")
+ if trust_remote_code is not None:
+ logger.warning(
+ "'trust_remote_code' argument is not supported and will be ignored."
+ )
+ if auto_convert is not None:
+ logger.warning("'auto_convert' argument is not supported and will be ignored.")
+ if merge_lora is not None:
+ logger.warning("'merge_lora' argument is not supported and will be ignored.")
+
+ # Import here after the logger is added to log potential import exceptions
+ from .model import fetch_model
+
+ fetch_model(model_id, revision)
diff --git a/backends/neuron/server/text_generation_server/generator.py b/backends/neuron/server/text_generation_server/generator.py
new file mode 100644
index 00000000000..bd8191bad43
--- /dev/null
+++ b/backends/neuron/server/text_generation_server/generator.py
@@ -0,0 +1,734 @@
+import copy
+import logging
+import time
+from abc import ABC
+from enum import Enum
+from typing import List, Optional, Tuple
+
+import torch
+from loguru import logger
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from optimum.neuron.configuration_utils import NeuronConfig
+from transformers.generation import GenerationConfig
+
+from optimum.neuron import NeuronModelForCausalLM
+from optimum.neuron.generation import TokenSelector
+
+from .model import get_export_kwargs_from_env
+from .pb.generate_pb2 import (
+ Batch,
+ CachedBatch,
+ FinishReason,
+ GeneratedText,
+ Generation,
+ InfoResponse,
+ Request,
+ Tokens,
+)
+
+
+# Disable optimum-neuron warnings as they seem to block the server after a while
+optimum_logger = logging.getLogger("optimum.neuron")
+optimum_logger.setLevel("CRITICAL")
+
+
+class Generator(ABC):
+ """An abstract class to represent the workhorse behind TextGenerationService.
+
+ Ideally, it should not rely on protobuf constructs, but in a first step it does.
+ Implementations would typically need a model and a tokenizer to implement the Generator methods.
+ """
+
+ @property
+ def info(self) -> InfoResponse:
+ """This should simply return the expected InfoResponse"""
+ raise NotImplementedError
+
+ def warmup(self, batch: Batch) -> int:
+ """Verify if the hardware can support the target load.
+
+ Args:
+ batch (`Batch`):
+ A batch corresponding to the maximum number of concurrent requests.
+
+ Return:
+ The maximum number of tokens the model supports.
+ """
+ raise NotImplementedError
+
+ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
+ """Prefill is called whenever new requests need to be added.
+
+ When this method returns successfully, a decode method will follow
+ with both the current and newly prefilled batch(es).
+
+ Args:
+ batch (`Batch`):
+ A batch containing the new requests.
+
+ Return:
+ A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
+ """
+ raise NotImplementedError
+
+ def decode(self, batches: List[Batch]) -> Tuple[List[Generation], CachedBatch]:
+ """Decode after a prefill or another decode."""
+ raise NotImplementedError
+
+ def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
+ """Remove requests that are not listed from the specified batch"""
+ raise NotImplementedError
+
+ def clear(self):
+ """Remove all requests from the generator"""
+ raise NotImplementedError
+
+ @classmethod
+ def from_pretrained(cls, model_id: str, revision: Optional[str]):
+ """Factory method "a la transformers" """
+ raise NotImplementedError
+
+
+class Slot:
+ """Represents a slot in a static batch"""
+
+ class State(Enum):
+ EMPTY = 0
+ PAUSE = 1
+ READY = 2
+
+ def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase):
+ self._id = id
+ self._tokenizer = tokenizer
+ self.clear()
+
+ def clear(self):
+ """Clear the slot and mark it as available."""
+ self._state = Slot.State.EMPTY
+ self._batch_id = None
+ self._request_id = None
+ self._inputs = ""
+ self._truncate = 0
+ self._generation_config = None
+ self._tokens = []
+ self._mask = torch.tensor([])
+ self._selector = None
+ self._generated_tokens = 0
+ self._next_text_token_start = 0
+ self._next_text_token_end = 0
+ self._generated_text = ""
+ self._next_text = ""
+
+ @property
+ def id(self) -> int:
+ return self._id
+
+ @property
+ def state(self) -> "Slot.State":
+ return self._state
+
+ @property
+ def batch_id(self) -> int:
+ return self._batch_id
+
+ @property
+ def request_id(self) -> int:
+ return self._request_id
+
+ @property
+ def cached_text(self) -> str:
+ return self._inputs + self._generated_text
+
+ @property
+ def generation_config(self) -> GenerationConfig:
+ return self._generation_config
+
+ @property
+ def generated_tokens(self) -> int:
+ return self._generated_tokens
+
+ def assign(
+ self, batch_id: int, request: Request, generation_config: GenerationConfig
+ ):
+ """Assign a request to a slot.
+
+ Args:
+ request (`Request`):
+ The request to be assigned. Contains the inputs and tokens selection parameters.
+ generation_config (`transformers.GenerationConfig`):
+ The base generation config (might be modified by the request generation parameters).
+ """
+ self._state = Slot.State.READY
+ self._batch_id = batch_id
+ self._request_id = request.id
+ self._inputs = request.inputs
+ if request.truncate:
+ self._truncate = request.truncate
+ self._generation_config = copy.deepcopy(generation_config)
+ # Update generation config with request parameters
+ self._generation_config.do_sample = request.parameters.do_sample
+ if self._generation_config.do_sample:
+ if request.parameters.temperature != 0:
+ self._generation_config.temperature = request.parameters.temperature
+ if request.parameters.top_k != 0:
+ self._generation_config.top_k = request.parameters.top_k
+ if request.parameters.top_p != 0:
+ self._generation_config.top_p = request.parameters.top_p
+ if request.parameters.typical_p != 0:
+ self._generation_config.typical_p = request.parameters.typical_p
+ else:
+ # Set the sampling parameters to emulate greedy decoding when using on-device sampling
+ self._generation_config.temperature = 1.0
+ self._generation_config.top_k = 1
+ self._generation_config.top_p = 1.0
+ self._generation_config.typical_p = 1.0
+ if request.parameters.repetition_penalty != 0:
+ self._generation_config.repetition_penalty = (
+ request.parameters.repetition_penalty
+ )
+ self.seed = request.parameters.seed
+ self._generation_config.max_new_tokens = (
+ request.stopping_parameters.max_new_tokens
+ )
+ self._max_new_tokens = self._generation_config.max_new_tokens
+ stop_strings = request.stopping_parameters.stop_sequences
+ if stop_strings:
+ self._generation_config.stop_strings = stop_strings
+
+ def reset(
+ self,
+ input_ids: torch.LongTensor,
+ attention_mask: torch.LongTensor,
+ selector: TokenSelector,
+ ):
+ """Reset the slot for the next generation.
+
+ Args:
+            input_ids (`torch.LongTensor`):
+                The new input_ids to use to generate the next token.
+            attention_mask (`torch.LongTensor`):
+                The new attention_mask to use to generate the next token.
+            selector (`optimum.neuron.generation.TokenSelector`):
+                An object implementing the updated token selection logic.
+ """
+ self._tokens = input_ids.clone()
+ self._next_text_token_start = 0
+ self._next_text_token_end = torch.numel(self._tokens)
+ self._next_text = ""
+ self._mask = attention_mask.clone()
+ self._selector = selector
+
+ def pause(self):
+ """Mark the current slot as paused for generation.
+
+ Note that the KV cache for this slot will still be filled.
+ """
+ self._state = Slot.State.PAUSE
+
+ def resume(self):
+ """Mark the slot as ready for generation."""
+ self._state = Slot.State.READY
+
+ def _decode_next_tokens(
+ self,
+ ) -> str:
+ """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
+ # We need to include the tokens that produced the last text to defeat cleanup algorithms in the decode
+ # which decide to add a space or not depending on the surrounding ids.
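+        # Concretely (illustrative): we decode all tokens since the last emitted text and only
+        # return the new suffix, so spaces and multi-byte characters are resolved consistently.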
+ new_text = self._tokenizer.decode(
+ self._tokens[self._next_text_token_start :], skip_special_tokens=False
+ )
+ if new_text.endswith("�"):
+ # utf-8 char at the end means it's a potential unfinished byte sequence
+ # from byte fallback tokenization.
+ return ""
+
+ # Compare the generated text with the one using only the tokens producing the last one
+ last_text = self._tokenizer.decode(
+ self._tokens[self._next_text_token_start : self._next_text_token_end],
+ skip_special_tokens=False,
+ )
+ if len(new_text) == len(last_text):
+ # Nothing new was actually generated
+ return ""
+ # Return the decoded text and store its token offsets
+ self._next_text_token_start = self._next_text_token_end
+ self._next_text_token_end = torch.numel(self._tokens)
+ return new_text[len(last_text) :]
+
+ def append(self, next_token: int) -> str:
+ """Append a new generated token to this slot
+
+ The new token is added to the list of generated tokens, which impacts
+ directly the generated_text and stopped property.
+
+ The new token is however not added immediately to the slot inputs: it will
+ be added later on when it has effectively been used to produce the next token.
+
+ Args:
+ next_token (`int`):
+ The newly generated token.
+
+ Return:
+ The corresponding decoded text (if any).
+ """
+ self._tokens = torch.cat([self._tokens, torch.LongTensor([next_token])])
+ self._mask = torch.cat([self._mask, torch.LongTensor([1])])
+ self._generated_tokens += 1
+ next_text = self._decode_next_tokens()
+ # Now that a new token has been generated, we can append the previous one to the generated text
+ self._generated_text += self._next_text
+ self._next_text = next_text
+ return next_text
+
+ def select(
+ self, input_ids: torch.LongTensor, logits: torch.Tensor
+ ) -> torch.LongTensor:
+ """Select the next token from the candidate logits.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ The sequence used as a prompt for the generation (not used in all generation modes).
+ logits (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+ The logits corresponding to the generated tokens.
+
+ Return:
+            `torch.LongTensor`: A scalar `torch.LongTensor` containing the selected token.
+ """
+ return self._selector.select(input_ids, logits)[0]
+
+ @property
+ def stopped(self) -> bool:
+ # Transformers stopping criteria expects a batch of input ids
+ input_ids = torch.unsqueeze(self._tokens, dim=0)
+ return self._selector.stopping_criteria(input_ids, None)
+
+ @property
+ def generated_text(self) -> str:
+ return self._generated_text + self._next_text
+
+ @property
+ def next_token(self) -> int:
+ return None if len(self._tokens) == 0 else self._tokens[-1]
+
+ @property
+ def attention_mask(self) -> torch.LongTensor:
+ return self._mask
+
+ @property
+ def max_token(self) -> int:
+ return self._generation_config.max_length
+
+ @property
+ def max_new_tokens(self) -> int:
+        # The current value of max_new_tokens: it might differ from the target max_new_tokens
+ # if the slot has been paused and resumed.
+ return self._generation_config.max_new_tokens
+
+ @property
+ def truncate(self) -> int:
+ return self._truncate
+
+
+class NeuronGenerator(Generator):
+ """A Generator for Neuron models."""
+
+ def __init__(
+ self,
+ model: NeuronModelForCausalLM,
+ tokenizer: PreTrainedTokenizerBase,
+ ):
+ self.model = model
+ if not isinstance(self.model, NeuronModelForCausalLM):
+ raise ValueError("The model must be a NeuronModelForCausalLM.")
+ if (
+ model.neuron_config.batch_size > 1
+ and not model.neuron_config.continuous_batching
+ ):
+ raise ValueError(
+ "The neuron model must be compiled with continuous_batching=True."
+ )
+ # Specify padding and truncation options for decoder-only architecture
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+ tokenizer.padding_side = "left"
+ tokenizer.truncation_side = "left"
+ self.tokenizer = tokenizer
+ self.special_tokens = self.tokenizer.all_special_ids
+ self.slots = [
+ Slot(i, tokenizer) for i in range(self.model.neuron_config.batch_size)
+ ]
+ self.batch_id = 0
+
+ @property
+ def on_device_sampling(self) -> bool:
+ return getattr(self.model.neuron_config, "on_device_sampling", False)
+
+ @property
+ def info(self) -> InfoResponse:
+ """Returns the expected InfoResponse."""
+ dtype = getattr(self.model.config, "torch_dtype", "float32")
+ return InfoResponse(
+ requires_padding=True,
+ dtype=str(dtype),
+ device_type="xla",
+ )
+
+ def warmup(self, batch: Batch) -> int:
+ """Verify if the hardware can support the target load.
+
+ Args:
+ batch (`Batch`):
+ A batch corresponding to the maximum number of concurrent requests.
+
+ Return:
+ The maximum number of tokens the model supports.
+ """
+ # Just check that the warmup request parameters match the model capacity
+ batch_size = self.model.neuron_config.batch_size
+ if len(batch.requests) > batch_size:
+            raise ValueError(
+                "Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model "
+                f"(currently {batch_size}) matches the batch_size passed to TGI. The compiled "
+                "model.neuron_config.batch_size is usually in the neuron section of the model config.json file. "
+                "You may also have passed it into optimum-cli during the compilation process. "
+                "The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
+            )
+ self.prefill(batch)
+ self.clear()
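+        # With static batching, the total token budget is fixed by the compiled batch size and sequence length.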
+ return (
+ self.model.neuron_config.batch_size
+ * self.model.neuron_config.sequence_length
+ )
+
+ def max_prefill_length(self) -> int:
+ if hasattr(self.model.neuron_config, "max_context_length"):
+ return self.model.neuron_config.max_context_length
+ return self.model.neuron_config.sequence_length
+
+ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
+ """Prefill new requests.
+
+ Args:
+ batch (`Batch`):
+ A batch containing the new requests.
+
+ Return:
+ A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
+ """
+ slots = {state: [] for state in Slot.State}
+ for slot in self.slots:
+ slots[slot.state].append(slot)
+ active_slots = slots[Slot.State.READY]
+ empty_slots = slots[Slot.State.EMPTY]
+ if len(empty_slots) < len(batch.requests):
+ raise ValueError(
+ f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
+ f" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}."
+ )
+ # Assign each request to an empty slot
+ logger.debug(
+ f"Prefilling {len(batch.requests)} new request(s) with {len(empty_slots)} empty slot(s)"
+ )
+ new_slots = []
+ for request in batch.requests:
+ slot = empty_slots.pop()
+ slot.assign(self.batch_id, request, self.model.generation_config)
+ new_slots.append(slot)
+ logger.debug(
+                f"Request {slot.request_id} assigned to slot {slot.id} with max_new_tokens {slot.max_new_tokens}"
+ )
+ prefill_slots = new_slots
+ seq_ids = torch.tensor([slot.id for slot in prefill_slots])
+ # Reconstruct the full inputs (without padding) as seen by the model.
+ # This comprises:
+ # - the inputs for new requests,
+ # - only when rebuilding the cache, the inputs and the generated text that has already
+ # been cached (i.e. excluding the last generated token) for unfinished requests.
+ inputs = []
+ max_length = 0
+ for slot in prefill_slots:
+ inputs.append(slot.cached_text)
+ # Apply truncation, making sure we fit into static dimensions
+ if slot.truncate == 0:
+ max_length = self.max_prefill_length()
+ elif (
+ slot.truncate > max_length and slot.truncate < self.max_prefill_length()
+ ):
+ max_length = slot.truncate
+ # Tokenize with padding and truncation
+ padded_inputs = self.tokenizer(
+ inputs,
+ return_tensors="pt",
+ padding=True,
+ truncation=True,
+ max_length=max_length,
+ )
+ input_ids = padded_inputs.input_ids
+ attention_mask = padded_inputs.attention_mask
+ sampling_params = (
+ torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None
+ )
+ # Pause previously active slots during generation
+ for slot in active_slots:
+ slot.pause()
+ # Each slot must be reset with the padded inputs and masks
+ for i, slot in enumerate(prefill_slots):
+            if slot.state != Slot.State.EMPTY:
+ if slot.truncate > 0 and slot.truncate < input_ids.shape[-1]:
+ # Apply per-request truncation
+ input_ids[i, : -slot.truncate] = self.tokenizer.pad_token_id
+ attention_mask[i, : -slot.truncate] = 0
+ slot_input_ids = input_ids[i : i + 1, :]
+            # Padded input ids are also required to set logits processors and stopping criteria
+ selector = TokenSelector.create(
+ slot_input_ids,
+ slot.generation_config,
+ self.model,
+ self.model.neuron_config.sequence_length,
+ tokenizer=self.tokenizer,
+ seed=slot.seed,
+ )
+ slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
+ slot_attention_mask = attention_mask[i]
+ slot.reset(slot_input_ids, slot_attention_mask, selector)
+ if sampling_params is not None:
+ sampling_params[i, 0] = slot.generation_config.top_k
+ sampling_params[i, 1] = slot.generation_config.top_p
+ sampling_params[i, 2] = slot.generation_config.temperature
+ # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
+ # as they have already been generated and sent back in the last decode.
+ model_inputs = self.model.prepare_inputs_for_prefill(
+ input_ids,
+ attention_mask=attention_mask,
+ seq_ids=seq_ids,
+ sampling_params=sampling_params,
+ )
+ tokens_or_logits = self.model(**model_inputs)[0]
+ generation, next_batch = self._generate_token(
+ prefill_slots, self.batch_id, tokens_or_logits, input_ids
+ )
+ self.batch_id += 1
+ # Reactivate previously active slots for the next decode
+ for i, slot in enumerate(active_slots):
+ slot.resume()
+ logger.debug("Model ready for decoding")
+ if next_batch is not None:
+ logger.debug(
+ f"Next batch is {next_batch.id} with requests: {next_batch.request_ids}"
+ )
+ return generation, next_batch
+
+ def decode(
+ self, batches: List[CachedBatch]
+ ) -> Tuple[List[Generation], CachedBatch]:
+ """Decode the specified prefilled requests.
+
+ Args:
+ batches (`List[CachedBatch]`):
+ A list of previous batches containing the prefilled requests.
+
+ Return:
+ A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
+ """
+ # batches contains a list composed of:
+ # - the batch id returned by the last decode,
+ # - the batch id(s) returned by the last prefill(s)
+ # Batches are always concatenated during prefill, so we can
+ # just carry on with decoding. We adopt the id of the first
+ # batch in the list as our next batch id.
+ next_batch_id = batches[0].id
+ request_ids = []
+ for batch in batches:
+ request_ids += batch.request_ids
+ cleared_request_ids = []
+ for slot in self.slots:
+            if slot.state == Slot.State.READY and slot.request_id not in request_ids:
+ cleared_request_ids.append(slot.request_id)
+ slot.clear()
+ if len(cleared_request_ids) > 0:
+ logger.info(
+                f"Clearing slots for requests {cleared_request_ids} as they are no longer requested."
+ )
+        active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY]
+ if len(active_slots) < len(request_ids):
+ raise ValueError(
+ "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
+ )
+ decode_slots = active_slots
+ seq_ids = torch.tensor([slot.id for slot in decode_slots])
+ # Reconstruct input_ids and attention_mask from decode slots
+ n_slots = len(decode_slots)
+ input_ids = torch.full(
+ [n_slots, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64
+ )
+ max_length = 0
+ for slot in decode_slots:
+ max_length = max(max_length, slot.attention_mask.size(-1))
+ attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)
+ sampling_params = torch.zeros(n_slots, 3) if self.on_device_sampling else None
+ for i, slot in enumerate(decode_slots):
+ if slot.state != Slot.State.EMPTY:
+ # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
+ input_ids[i, 0] = slot.next_token
+ attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask
+ if sampling_params is not None:
+ sampling_params[i, 0] = slot.generation_config.top_k
+ sampling_params[i, 1] = slot.generation_config.top_p
+ sampling_params[i, 2] = slot.generation_config.temperature
+ model_inputs = self.model.prepare_inputs_for_decode(
+ input_ids,
+ attention_mask=attention_mask,
+ seq_ids=seq_ids,
+ sampling_params=sampling_params,
+ )
+ tokens_or_logits = self.model(**model_inputs)[0]
+ return self._generate_token(
+ decode_slots, next_batch_id, tokens_or_logits, input_ids
+ )
+
+ def _generate_token(
+ self,
+ slots: List[Slot],
+ next_batch_id: int,
+ tokens_or_logits: torch.Tensor,
+ input_ids: torch.LongTensor,
+ ) -> Tuple[List[Generation], CachedBatch]:
+ generations = []
+ active_slots = False
+ for i, slot in enumerate(slots):
+ if slot.state != Slot.State.READY:
+ continue
+ request_id = slot.request_id
+ slot_input_ids = input_ids[i : i + 1, :]
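+            # With on-device sampling the model already returns token ids; otherwise we get
+            # logits back and select the next token on the host.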
+ if self.on_device_sampling:
+ next_token = tokens_or_logits[i]
+ else:
+ next_token_logits = tokens_or_logits[i : i + 1, -1, :]
+ next_token = slot.select(slot_input_ids, next_token_logits)
+ next_token_text = slot.append(next_token)
+ generated_text = None
+ finish_reason = None
+ if next_token == self.tokenizer.eos_token_id:
+ finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN
+ elif slot.stopped:
+ if slot.generated_tokens == slot.max_new_tokens:
+ finish_reason = FinishReason.FINISH_REASON_LENGTH
+ else:
+ finish_reason = FinishReason.FINISH_REASON_STOP_SEQUENCE
+ if finish_reason is not None:
+ # We must include the generated text for each finished sequence in the response
+ generated_text = GeneratedText(
+ text=slot.generated_text,
+ generated_tokens=slot.generated_tokens,
+ finish_reason=finish_reason,
+ )
+ logger.debug(
+ f"Decode complete for request {request_id} with {slot.generated_tokens} tokens"
+ )
+ # mark the slot as available
+ slot.clear()
+ else:
+ active_slots = True
+ generations.append(
+ Generation(
+ request_id=request_id,
+ prefill_tokens=None,
+ tokens=Tokens(
+ ids=[next_token],
+ logprobs=[0],
+ texts=[next_token_text],
+ is_special=[next_token in self.special_tokens],
+ ),
+ generated_text=generated_text,
+ )
+ )
+ batch = None
+ if active_slots:
+ # Whatever initial batch these requests came from, we always return all pending requests in a single batch
+ request_ids = [
+ slot.request_id for slot in self.slots if slot.state == Slot.State.READY
+ ]
+ batch = self._cached_batch(next_batch_id, request_ids)
+ else:
+ logger.debug("No more pending requests")
+ return generations, batch
+
+ def _cached_batch(self, batch_id: int, request_ids: List):
+ size = len(request_ids)
+ max_tokens = size * self.model.neuron_config.sequence_length
+ return CachedBatch(
+ id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens
+ )
+
+ def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch:
+ """Remove requests that are not listed from the specified batch
+
+ Args:
+ batch_id (`int`):
+ The id of a cached batch.
+            keep_request_ids (`List[int]`):
+ The list of requests that must be kept.
+
+ Return:
+ A `CachedBatch` containing the pending requests.
+ """
+ keep_slot_ids = [
+ slot.id for slot in self.slots if slot.request_id in keep_request_ids
+ ]
+ self._clear(keep_slot_ids)
+ return self._cached_batch(batch_id, keep_request_ids)
+
+ def clear(self, batch_id: Optional[int] = None):
+ """Remove a subset or all requests from the generator"""
+ keep_ids = []
+ if batch_id is not None:
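+            # Only clear the slots assigned to this batch; all other slots are preserved.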
+ keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id]
+ return self._clear(keep_ids)
+
+ def _clear(self, keep_slot_ids: List):
+ for slot in self.slots:
+ if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids:
+ logger.debug(f"Removing slot {slot.id} with request {slot.request_id}")
+ slot.clear()
+
+ @classmethod
+    def from_pretrained(cls, model_id: str, revision: Optional[str] = None):
+ """Instantiate a NeuronGenerator.
+
+ Args:
+ model_id (`str`):
+ A hub model id or the path to a local model. This path must also contain a Tokenizer.
+ revision (`Optional[str]`, defaults to `None`):
+ The revision of the model on the HuggingFace hub.
+
+ Returns:
+ A NeuronGenerator.
+ """
+ try:
+ neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
+ except Exception as e:
+            logger.debug(
+                f"NeuronConfig.from_pretrained failed for model {model_id}, revision {revision}: {e}"
+            )
+ neuron_config = None
+ start = time.time()
+ if neuron_config is None:
+ export_kwargs = get_export_kwargs_from_env()
+ logger.info(f"Exporting model to neuron with config: {export_kwargs}.")
+ model = NeuronModelForCausalLM.from_pretrained(
+ model_id,
+ revision=revision,
+ low_cpu_mem_usage=True,
+ export=True,
+ **export_kwargs,
+ )
+ else:
+ logger.info(
+ "Loading model on neuron devices (this can take a few minutes)."
+ )
+ model = NeuronModelForCausalLM.from_pretrained(
+ model_id, low_cpu_mem_usage=True, revision=revision
+ )
+ end = time.time()
+ logger.info(f"Model successfully loaded in {end - start:.2f} s.")
+ tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
+ return cls(model, tokenizer)
diff --git a/backends/neuron/server/text_generation_server/interceptor.py b/backends/neuron/server/text_generation_server/interceptor.py
new file mode 100644
index 00000000000..301cafd8739
--- /dev/null
+++ b/backends/neuron/server/text_generation_server/interceptor.py
@@ -0,0 +1,29 @@
+from typing import Any, Callable
+
+import grpc
+from google.rpc import code_pb2, status_pb2
+from grpc_interceptor.server import AsyncServerInterceptor
+from grpc_status import rpc_status
+from loguru import logger
+
+
+class ExceptionInterceptor(AsyncServerInterceptor):
+ async def intercept(
+ self,
+ method: Callable,
+ request_or_iterator: Any,
+ context: grpc.ServicerContext,
+ method_name: str,
+ ) -> Any:
+ try:
+ response = method(request_or_iterator, context)
+ return await response
+ except Exception as err:
+ method_name = method_name.split("/")[-1]
+ logger.exception(f"Method {method_name} encountered an error.")
+
+ await context.abort_with_status(
+ rpc_status.to_status(
+ status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))
+ )
+ )
diff --git a/backends/neuron/server/text_generation_server/model.py b/backends/neuron/server/text_generation_server/model.py
new file mode 100644
index 00000000000..d281b175f93
--- /dev/null
+++ b/backends/neuron/server/text_generation_server/model.py
@@ -0,0 +1,129 @@
+import os
+import shutil
+import time
+from typing import Optional
+
+from huggingface_hub import snapshot_download
+from huggingface_hub.constants import HF_HUB_CACHE
+from loguru import logger
+
+from optimum.neuron.cache import get_hub_cached_entries
+from optimum.neuron.configuration_utils import NeuronConfig
+
+
+from .tgi_env import check_env_and_neuron_config_compatibility
+
+
+def get_export_kwargs_from_env():
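+    # Export parameters are read from the TGI environment variables; unset variables are
+    # returned as None entries.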
+ batch_size = os.environ.get("MAX_BATCH_SIZE", None)
+ if batch_size is not None:
+ batch_size = int(batch_size)
+ sequence_length = os.environ.get("MAX_TOTAL_TOKENS", None)
+ if sequence_length is not None:
+ sequence_length = int(sequence_length)
+ num_cores = os.environ.get("HF_NUM_CORES", None)
+ if num_cores is not None:
+ num_cores = int(num_cores)
+ auto_cast_type = os.environ.get("HF_AUTO_CAST_TYPE", None)
+ return {
+ "batch_size": batch_size,
+ "sequence_length": sequence_length,
+ "num_cores": num_cores,
+ "auto_cast_type": auto_cast_type,
+ }
+
+
+def is_cached(model_id):
+ # Look for cached entries for the specified model
+ in_cache = False
+ entries = get_hub_cached_entries(model_id)
+ # Look for compatible entries
+ for entry in entries:
+ if check_env_and_neuron_config_compatibility(
+ entry, check_compiler_version=True
+ ):
+ in_cache = True
+ break
+ return in_cache
+
+
+def log_cache_size():
+ path = HF_HUB_CACHE
+ if os.path.exists(path):
+ usage = shutil.disk_usage(path)
+ gb = 2**30
+ logger.info(
+ f"Cache disk [{path}]: total = {usage.total / gb:.2f} G, free = {usage.free / gb:.2f} G"
+ )
+ else:
+ raise ValueError(f"The cache directory ({path}) does not exist.")
+
+
+def fetch_model(
+ model_id: str,
+ revision: Optional[str] = None,
+) -> str:
+ """Fetch a neuron model.
+
+ Args:
+ model_id (`str`):
+ The *model_id* of a model on the HuggingFace hub or the path to a local model.
+ revision (`Optional[str]`, defaults to `None`):
+ The revision of the model on the HuggingFace hub.
+
+ Returns:
+ A string corresponding to the model_id or path.
+ """
+ if not os.path.isdir("/sys/class/neuron_device/"):
+ raise SystemError("No neuron cores detected on the host.")
+ if os.path.isdir(model_id) and revision is not None:
+ logger.warning(
+ "Revision {} ignored for local model at {}".format(revision, model_id)
+ )
+ revision = None
+ # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model)
+ # Note that the model may already be present in the cache.
+ try:
+ neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
+ except Exception as e:
+        logger.debug(
+            f"NeuronConfig.from_pretrained failed for model {model_id}, revision {revision}: {e}"
+        )
+ neuron_config = None
+ if neuron_config is not None:
+ if os.path.isdir(model_id):
+ return model_id
+ # Prefetch the neuron model from the Hub
+ logger.info(
+ f"Fetching revision [{revision}] for neuron model {model_id} under {HF_HUB_CACHE}"
+ )
+ log_cache_size()
+ return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
+ # Model needs to be exported: look for compatible cached entries on the hub
+ if not is_cached(model_id):
+ hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache"
+ neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi"
+ error_msg = (
+            f"No cached version found for {model_id} with {get_export_kwargs_from_env()}. "
+            f"You can start a discussion to request it on {hub_cache_url}. "
+            f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}."
+ )
+ raise ValueError(error_msg)
+ logger.warning(
+ f"{model_id} is not a neuron model: it will be exported using cached artifacts."
+ )
+ if os.path.isdir(model_id):
+ return model_id
+ # Prefetch weights, tokenizer and generation config so that they are in cache
+ log_cache_size()
+ start = time.time()
+ snapshot_path = snapshot_download(
+ model_id, revision=revision, ignore_patterns="*.bin"
+ )
+ end = time.time()
+ logger.info(f"Model weights fetched in {end - start:.2f} s.")
+ log_cache_size()
+ return snapshot_path
diff --git a/backends/neuron/server/text_generation_server/server.py b/backends/neuron/server/text_generation_server/server.py
new file mode 100644
index 00000000000..8eb2592d634
--- /dev/null
+++ b/backends/neuron/server/text_generation_server/server.py
@@ -0,0 +1,89 @@
+import asyncio
+from pathlib import Path
+from typing import List
+
+from grpc import aio
+from grpc_reflection.v1alpha import reflection
+from loguru import logger
+
+from .generator import Generator, NeuronGenerator
+from .interceptor import ExceptionInterceptor
+from .pb import generate_pb2, generate_pb2_grpc
+
+
+class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
+ def __init__(self, generator: Generator, server_urls: List[str]):
+ self.generator = generator
+ self.server_urls = server_urls
+
+ async def Info(self, request, context):
+ return self.generator.info
+
+ async def Health(self, request, context):
+ return generate_pb2.HealthResponse()
+
+ async def ServiceDiscovery(self, request, context):
+ return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
+
+ async def ClearCache(self, request, context):
+ if request.HasField("id"):
+ self.generator.clear(request.id)
+ else:
+ self.generator.clear()
+ return generate_pb2.ClearCacheResponse()
+
+ async def FilterBatch(self, request, context):
+ filtered_batch = self.generator.filter(request.batch_id, request.request_ids)
+ return generate_pb2.FilterBatchResponse(batch=filtered_batch)
+
+ async def Warmup(self, request, context):
+ max_tokens = self.generator.warmup(request.batch)
+ return generate_pb2.WarmupResponse(max_supported_total_tokens=max_tokens)
+
+ async def Prefill(self, request, context):
+ generations, batch = self.generator.prefill(request.batch)
+ return generate_pb2.PrefillResponse(generations=generations, batch=batch)
+
+ async def Decode(self, request, context):
+ generations, batch = self.generator.decode(request.batches)
+ return generate_pb2.DecodeResponse(generations=generations, batch=batch)
+
+
+def serve(
+ model_id: str,
+ revision: str,
+ uds_path: Path,
+):
+ async def serve_inner(model_id: str, revision: str):
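+        # A single shard is exposed: the gRPC unix socket path is suffixed with the shard index (0).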
+ unix_socket_template = "unix://{}-{}"
+ local_url = unix_socket_template.format(uds_path, 0)
+ server_urls = [local_url]
+
+ try:
+ generator = NeuronGenerator.from_pretrained(model_id, revision)
+ except Exception:
+ logger.exception("Error when initializing model")
+ raise
+
+ server = aio.server(interceptors=[ExceptionInterceptor()])
+ generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
+ TextGenerationService(generator, server_urls), server
+ )
+ SERVICE_NAMES = (
+ generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
+ reflection.SERVICE_NAME,
+ )
+ reflection.enable_server_reflection(SERVICE_NAMES, server)
+ server.add_insecure_port(local_url)
+
+ await server.start()
+
+ logger.info("Server started at {}".format(local_url))
+
+ try:
+ await server.wait_for_termination()
+ except KeyboardInterrupt:
+ logger.info("Signal received. Shutting down")
+ await server.stop(0)
+
+ asyncio.run(serve_inner(model_id, revision))
diff --git a/backends/neuron/server/text_generation_server/tgi_env.py b/backends/neuron/server/text_generation_server/tgi_env.py
new file mode 100644
index 00000000000..ee97f180eee
--- /dev/null
+++ b/backends/neuron/server/text_generation_server/tgi_env.py
@@ -0,0 +1,281 @@
+#!/usr/bin/env python
+
+import argparse
+import logging
+import os
+import sys
+from typing import Any, Dict, List, Optional
+
+from optimum.neuron.modeling_decoder import get_available_cores
+from optimum.neuron.cache import get_hub_cached_entries
+from optimum.neuron.configuration_utils import NeuronConfig
+from optimum.neuron.utils.version_utils import get_neuronxcc_version
+from optimum.neuron.utils import map_torch_dtype
+
+
+logger = logging.getLogger(__name__)
+
+tgi_router_env_vars = [
+ "MAX_BATCH_SIZE",
+ "MAX_TOTAL_TOKENS",
+ "MAX_INPUT_TOKENS",
+ "MAX_BATCH_PREFILL_TOKENS",
+]
+tgi_server_env_vars = ["HF_NUM_CORES", "HF_AUTO_CAST_TYPE"]
+
+
+# By the end of this script all env vars should be specified properly
+tgi_env_vars = tgi_server_env_vars + tgi_router_env_vars
+
+available_cores = get_available_cores()
+neuronxcc_version = get_neuronxcc_version()
+
+
+def parse_cmdline_and_set_env(argv: Optional[List[str]] = None) -> argparse.Namespace:
+ parser = argparse.ArgumentParser()
+ if not argv:
+ argv = sys.argv
+ # All these are params passed to tgi and intercepted here
+ parser.add_argument(
+ "--max-input-tokens",
+ type=int,
+ default=os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0)),
+ )
+ parser.add_argument(
+ "--max-total-tokens", type=int, default=os.getenv("MAX_TOTAL_TOKENS", 0)
+ )
+ parser.add_argument(
+ "--max-batch-size", type=int, default=os.getenv("MAX_BATCH_SIZE", 0)
+ )
+ parser.add_argument(
+ "--max-batch-prefill-tokens",
+ type=int,
+ default=os.getenv("MAX_BATCH_PREFILL_TOKENS", 0),
+ )
+ parser.add_argument("--model-id", type=str, default=os.getenv("MODEL_ID"))
+ parser.add_argument("--revision", type=str, default=os.getenv("REVISION"))
+
+ args = parser.parse_known_args(argv)[0]
+
+ if not args.model_id:
+        raise Exception(
+            "No model id provided! Either specify it using the --model-id cmdline argument or the MODEL_ID env var"
+        )
+
+ # Override env with cmdline params
+ os.environ["MODEL_ID"] = args.model_id
+
+ # Set all tgi router and tgi server values to consistent values as early as possible
+    # Given the order of the parser defaults, the tgi router values can override the tgi server ones
+ if args.max_total_tokens > 0:
+ os.environ["MAX_TOTAL_TOKENS"] = str(args.max_total_tokens)
+
+ if args.max_input_tokens > 0:
+ os.environ["MAX_INPUT_TOKENS"] = str(args.max_input_tokens)
+
+ if args.max_batch_size > 0:
+ os.environ["MAX_BATCH_SIZE"] = str(args.max_batch_size)
+
+ if args.max_batch_prefill_tokens > 0:
+ os.environ["MAX_BATCH_PREFILL_TOKENS"] = str(args.max_batch_prefill_tokens)
+
+ if args.revision:
+ os.environ["REVISION"] = str(args.revision)
+
+ return args
+
+
+def neuron_config_to_env(neuron_config):
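+    # Writes the shell exports consumed by the TGI entrypoint. Illustrative output for a model
+    # compiled with batch_size=4, sequence_length=4096, tp_degree=2 and bf16 auto-cast, assuming
+    # MAX_INPUT_TOKENS and MAX_BATCH_PREFILL_TOKENS are not already set:
+    #   export MAX_BATCH_SIZE=4
+    #   export MAX_TOTAL_TOKENS=4096
+    #   export HF_NUM_CORES=2
+    #   export HF_AUTO_CAST_TYPE=bf16
+    #   export MAX_INPUT_TOKENS=2048
+    #   export MAX_BATCH_PREFILL_TOKENS=8192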
+ if isinstance(neuron_config, NeuronConfig):
+ neuron_config = neuron_config.to_dict()
+ with open(os.environ["ENV_FILEPATH"], "w") as f:
+ f.write("export MAX_BATCH_SIZE={}\n".format(neuron_config["batch_size"]))
+ f.write("export MAX_TOTAL_TOKENS={}\n".format(neuron_config["sequence_length"]))
+ f.write("export HF_NUM_CORES={}\n".format(neuron_config["tp_degree"]))
+ config_key = (
+ "auto_cast_type" if "auto_cast_type" in neuron_config else "torch_dtype"
+ )
+ auto_cast_type = neuron_config[config_key]
+ f.write("export HF_AUTO_CAST_TYPE={}\n".format(auto_cast_type))
+ max_input_tokens = os.getenv("MAX_INPUT_TOKENS")
+ if not max_input_tokens:
+ max_input_tokens = int(neuron_config["sequence_length"]) // 2
+ if max_input_tokens == 0:
+ raise Exception("Model sequence length should be greater than 1")
+ f.write("export MAX_INPUT_TOKENS={}\n".format(max_input_tokens))
+ max_batch_prefill_tokens = os.getenv("MAX_BATCH_PREFILL_TOKENS")
+ if not max_batch_prefill_tokens:
+ max_batch_prefill_tokens = int(neuron_config["batch_size"]) * int(
+ max_input_tokens
+ )
+ f.write("export MAX_BATCH_PREFILL_TOKENS={}\n".format(max_batch_prefill_tokens))
+
+
+def sort_neuron_configs(dictionary):
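+    # Sort key: highest tp_degree first, then largest batch_size (values are negated
+    # because sorted() is ascending).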
+ return -dictionary["tp_degree"], -dictionary["batch_size"]
+
+
+def lookup_compatible_cached_model(
+ model_id: str, revision: Optional[str]
+) -> Optional[Dict[str, Any]]:
+    # Reuse the same mechanism as the one used to configure the tgi server part
+ # The only difference here is that we stay as flexible as possible on the compatibility part
+ entries = get_hub_cached_entries(model_id)
+
+ logger.debug(
+ "Found %d cached entries for model %s, revision %s",
+ len(entries),
+ model_id,
+ revision,
+ )
+
+ all_compatible = []
+ for entry in entries:
+ if check_env_and_neuron_config_compatibility(
+ entry, check_compiler_version=True
+ ):
+ all_compatible.append(entry)
+
+ if not all_compatible:
+ logger.debug(
+ "No compatible cached entry found for model %s, env %s, available cores %s, neuronxcc version %s",
+ model_id,
+ get_env_dict(),
+ available_cores,
+ neuronxcc_version,
+ )
+ return None
+
+ logger.info("%d compatible neuron cached models found", len(all_compatible))
+
+ all_compatible = sorted(all_compatible, key=sort_neuron_configs)
+
+ entry = all_compatible[0]
+
+ return entry
+
+
+def check_env_and_neuron_config_compatibility(
+ neuron_config_dict: Dict[str, Any], check_compiler_version: bool
+) -> bool:
+ logger.debug(
+ "Checking the provided neuron config %s is compatible with the local setup and provided environment",
+ neuron_config_dict,
+ )
+
+ # Local setup compat checks
+ if neuron_config_dict["tp_degree"] > available_cores:
+ logger.debug(
+ "Not enough neuron cores available to run the provided neuron config"
+ )
+ return False
+
+ if (
+ check_compiler_version
+ and neuron_config_dict["neuronxcc_version"] != neuronxcc_version
+ ):
+ logger.debug(
+ "Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)",
+ neuronxcc_version,
+ neuron_config_dict["neuronxcc_version"],
+ )
+ return False
+
+ batch_size = os.getenv("MAX_BATCH_SIZE", None)
+ if batch_size is not None and neuron_config_dict["batch_size"] < int(batch_size):
+ logger.debug(
+ "The provided MAX_BATCH_SIZE (%s) is higher than the neuron config batch size (%s)",
+ os.getenv("MAX_BATCH_SIZE"),
+ neuron_config_dict["batch_size"],
+ )
+ return False
+ max_total_tokens = os.getenv("MAX_TOTAL_TOKENS", None)
+ if max_total_tokens is not None and neuron_config_dict["sequence_length"] < int(
+ max_total_tokens
+ ):
+ logger.debug(
+ "The provided MAX_TOTAL_TOKENS (%s) is higher than the neuron config sequence length (%s)",
+ max_total_tokens,
+ neuron_config_dict["sequence_length"],
+ )
+ return False
+ num_cores = os.getenv("HF_NUM_CORES", None)
+ if num_cores is not None and neuron_config_dict["tp_degree"] < int(num_cores):
+ logger.debug(
+ "The provided HF_NUM_CORES (%s) is higher than the neuron config tp degree (%s)",
+ num_cores,
+ neuron_config_dict["tp_degree"],
+ )
+ return False
+ auto_cast_type = os.getenv("HF_AUTO_CAST_TYPE", None)
+ if auto_cast_type is not None:
+ config_key = (
+ "auto_cast_type"
+ if "auto_cast_type" in neuron_config_dict
+ else "torch_dtype"
+ )
+ neuron_config_value = map_torch_dtype(str(neuron_config_dict[config_key]))
+ env_value = map_torch_dtype(auto_cast_type)
+ if env_value != neuron_config_value:
+ logger.debug(
+ "The provided auto cast type and the neuron config param differ (%s != %s)",
+ env_value,
+ neuron_config_value,
+ )
+ return False
+ max_input_tokens = int(
+ os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0))
+ )
+ if max_input_tokens > 0:
+        if "max_context_length" in neuron_config_dict:
+ sequence_length = neuron_config_dict["max_context_length"]
+ else:
+ sequence_length = neuron_config_dict["sequence_length"]
+ if max_input_tokens >= sequence_length:
+ logger.debug(
+ "Specified max input tokens is not compatible with config sequence length ( %s >= %s)",
+ max_input_tokens,
+ sequence_length,
+ )
+ return False
+
+ return True
+
+
+def get_env_dict() -> Dict[str, str]:
+ d = {}
+ for k in tgi_env_vars:
+ d[k] = os.getenv(k)
+ return d
+
+
+def get_neuron_config_for_model(
+ model_name_or_path: str, revision: Optional[str] = None
+) -> NeuronConfig:
+ try:
+ neuron_config = NeuronConfig.from_pretrained(
+ model_name_or_path, revision=revision
+ )
+ except Exception as e:
+ logger.debug(
+ "NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
+ model_name_or_path,
+ revision,
+ e,
+ )
+ neuron_config = None
+ if neuron_config is not None:
+ compatible = check_env_and_neuron_config_compatibility(
+ neuron_config.to_dict(), check_compiler_version=False
+ )
+ if not compatible:
+ env_dict = get_env_dict()
+ msg = (
+ "Invalid neuron config and env. Config {}, env {}, available cores {}, neuronxcc version {}"
+ ).format(neuron_config, env_dict, available_cores, neuronxcc_version)
+ logger.error(msg)
+ raise Exception(msg)
+ else:
+ neuron_config = lookup_compatible_cached_model(model_name_or_path, revision)
+
+ return neuron_config
diff --git a/backends/neuron/tests/conftest.py b/backends/neuron/tests/conftest.py
new file mode 100644
index 00000000000..1dd20c8c6ed
--- /dev/null
+++ b/backends/neuron/tests/conftest.py
@@ -0,0 +1 @@
+pytest_plugins = ["fixtures.model"]
diff --git a/backends/neuron/tests/fixtures/model.py b/backends/neuron/tests/fixtures/model.py
new file mode 100644
index 00000000000..ad41fd10a5c
--- /dev/null
+++ b/backends/neuron/tests/fixtures/model.py
@@ -0,0 +1,118 @@
+import copy
+import logging
+import subprocess
+import sys
+from tempfile import TemporaryDirectory
+
+import os
+import pytest
+from transformers import AutoTokenizer
+
+
+from optimum.neuron.cache import synchronize_hub_cache
+
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="[%(asctime)s] %(levelname)s [%(filename)s.%(funcName)s:%(lineno)d] %(message)s",
+ stream=sys.stdout,
+)
+logger = logging.getLogger(__file__)
+
+
+OPTIMUM_CACHE_REPO_ID = "optimum-internal-testing/neuron-testing-cache"
+
+
+# All model configurations below will be added to the neuron_model_config fixture
+MODEL_CONFIGURATIONS = {
+ "llama": {
+ "model_id": "unsloth/Llama-3.2-1B-Instruct",
+ "export_kwargs": {
+ "batch_size": 4,
+ "sequence_length": 4096,
+ "num_cores": 2,
+ "auto_cast_type": "bf16",
+ },
+ },
+ "qwen2": {
+ "model_id": "Qwen/Qwen2.5-0.5B",
+ "export_kwargs": {
+ "batch_size": 4,
+ "sequence_length": 4096,
+ "num_cores": 2,
+ "auto_cast_type": "bf16",
+ },
+ },
+ "granite": {
+ "model_id": "ibm-granite/granite-3.1-2b-instruct",
+ "export_kwargs": {
+ "batch_size": 4,
+ "sequence_length": 4096,
+ "num_cores": 2,
+ "auto_cast_type": "bf16",
+ },
+ },
+}
+
+
+def export_model(model_id, export_kwargs, neuron_model_path):
+ export_command = [
+ "optimum-cli",
+ "export",
+ "neuron",
+ "-m",
+ model_id,
+ "--task",
+ "text-generation",
+ ]
+ for kwarg, value in export_kwargs.items():
+ export_command.append(f"--{kwarg}")
+ export_command.append(str(value))
+ export_command.append(neuron_model_path)
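+    # The resulting command looks like (illustrative, using the "llama" configuration above):
+    #   optimum-cli export neuron -m unsloth/Llama-3.2-1B-Instruct --task text-generation \
+    #     --batch_size 4 --sequence_length 4096 --num_cores 2 --auto_cast_type bf16 <neuron_model_path>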
+ logger.info(f"Exporting {model_id} with {export_kwargs}")
+ try:
+ subprocess.run(export_command, check=True)
+ except subprocess.CalledProcessError as e:
+ raise ValueError(f"Failed to export model: {e}")
+
+
+@pytest.fixture(scope="session", params=MODEL_CONFIGURATIONS.keys())
+def neuron_model_config(request):
+ """Expose a pre-trained neuron model
+
+ The fixture exports a model locally and returns a dictionary containing:
+ - a configuration name,
+ - the original model id,
+ - the export parameters,
+ - the neuron model local path.
+
+ For each exposed model, the local directory is maintained for the duration of the
+ test session and cleaned up afterwards.
+
+ """
+ config_name = request.param
+ model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param])
+ model_id = model_config["model_id"]
+ export_kwargs = model_config["export_kwargs"]
+ with TemporaryDirectory() as neuron_model_path:
+ export_model(model_id, export_kwargs, neuron_model_path)
+ synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ tokenizer.save_pretrained(neuron_model_path)
+ del tokenizer
+ # Add dynamic parameters to the model configuration
+ model_config["neuron_model_path"] = neuron_model_path
+ # Also add model configuration name to allow tests to adapt their expectations
+ model_config["name"] = config_name
+ # Yield instead of returning to keep a reference to the temporary directory.
+ # It will go out of scope and be released only once all tests needing the fixture
+ # have been completed.
+ logger.info(f"{config_name} ready for testing ...")
+ os.environ["CUSTOM_CACHE_REPO"] = OPTIMUM_CACHE_REPO_ID
+ yield model_config
+ logger.info(f"Done with {config_name}")
+
+
+@pytest.fixture(scope="module")
+def neuron_model_path(neuron_model_config):
+ yield neuron_model_config["neuron_model_path"]
diff --git a/backends/neuron/tests/prune_test_models.py b/backends/neuron/tests/prune_test_models.py
new file mode 100644
index 00000000000..448962fb62a
--- /dev/null
+++ b/backends/neuron/tests/prune_test_models.py
@@ -0,0 +1,23 @@
+from argparse import ArgumentParser
+from huggingface_hub import HfApi
+
+
+def main():
+ parser = ArgumentParser()
+ parser.add_argument("--yes", action="http://github.com/store_true", default=False)
+ args = parser.parse_args()
+ api = HfApi()
+ models = api.list_models(search="optimum-internal-testing/neuron-tgi-testing")
+ for model in models:
+ if args.yes:
+ delete = True
+ else:
+ answer = input(f"Do you want to delete {model.id} [y/N] ?")
+ delete = answer == "y"
+ if delete:
+ api.delete_repo(model.id)
+ print(f"Deleted {model.id}.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/backends/neuron/tests/pytest.ini b/backends/neuron/tests/pytest.ini
new file mode 100644
index 00000000000..2f4c80e3075
--- /dev/null
+++ b/backends/neuron/tests/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+asyncio_mode = auto
diff --git a/backends/neuron/tests/requirements.txt b/backends/neuron/tests/requirements.txt
new file mode 100644
index 00000000000..ef3c8543e5d
--- /dev/null
+++ b/backends/neuron/tests/requirements.txt
@@ -0,0 +1,19 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+text-generation >= 0.6.0
+pytest >= 7.4.0
+pytest-asyncio >= 0.21.1
+requests < 2.32.0
+docker >= 6.1.3
+Levenshtein
diff --git a/backends/neuron/tests/server/helpers.py b/backends/neuron/tests/server/helpers.py
new file mode 100644
index 00000000000..f0f81d06d1c
--- /dev/null
+++ b/backends/neuron/tests/server/helpers.py
@@ -0,0 +1,173 @@
+from text_generation_server.generator import NeuronGenerator
+from text_generation_server.pb.generate_pb2 import (
+ Batch,
+ NextTokenChooserParameters,
+ Request,
+ StoppingCriteriaParameters,
+)
+
+
+def create_request(
+ id: int,
+ inputs: str,
+ truncate: int = 0,
+ max_new_tokens: int = 20,
+ do_sample: bool = False,
+ top_k: int = 50,
+ top_p: float = 0.9,
+ temperature: float = 1.0,
+ seed: int = 42,
+ repetition_penalty: float = 1.0,
+):
+ parameters = NextTokenChooserParameters(
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ do_sample=do_sample,
+ seed=seed,
+ repetition_penalty=repetition_penalty,
+ )
+ stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)
+ return Request(
+ id=id,
+ inputs=inputs,
+ truncate=truncate,
+ parameters=parameters,
+ stopping_parameters=stopping_parameters,
+ )
+
+
+def check_prefill(
+ input_text,
+ expected_token_id,
+ expected_token_text,
+ do_sample,
+ batch_size,
+ model_path,
+):
+ """Verify that a prefill for a single request generates the expected output."""
+ generator = NeuronGenerator.from_pretrained(model_path)
+ assert generator.model.batch_size >= batch_size
+ requests = []
+ max_new_tokens = 20
+ for i in range(batch_size):
+ requests.append(
+ create_request(
+ id=0,
+ inputs=input_text,
+ do_sample=do_sample,
+ max_new_tokens=max_new_tokens,
+ )
+ )
+    # max_tokens is not estimated here: with static batching it is derived from max_length below
+ max_length = generator.model.max_length
+ batch = Batch(
+ id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
+ )
+ generations, next_batch = generator.prefill(batch)
+ assert next_batch.size == batch_size
+ # Whatever was passed as max_tokens, the server will correct it
+ # because of static batching
+ assert next_batch.max_tokens == batch_size * max_length
+ assert len(generations) == batch_size
+ for g in generations:
+ tokens = g.tokens
+ assert tokens.ids == [expected_token_id]
+ assert tokens.texts == [expected_token_text]
+
+
+def check_decode_single(
+ input_text, max_new_tokens, generated_text, do_sample, model_path
+):
+ """Verify that a decoding for a single request generates the expected output."""
+ generator = NeuronGenerator.from_pretrained(model_path)
+ request = create_request(
+ id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample
+ )
+ max_length = generator.model.max_length
+ batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
+ generations, next_batch = generator.prefill(batch)
+ # We already generated one token: call decode max_new_tokens - 1 times
+ for _ in range(max_new_tokens - 1):
+ assert next_batch.size == 1
+ assert next_batch.max_tokens == max_length
+ assert len(generations) == 1
+ assert len(generations[0].tokens.ids) == 1
+ generations, next_batch = generator.decode([next_batch])
+ assert next_batch is None
+ assert len(generations) == 1
+ output = generations[0].generated_text
+ assert output.generated_tokens == max_new_tokens
+ assert output.finish_reason == 0
+ assert output.text == generated_text
+
+
+def check_decode_multiple(model_path):
+ """Verify that two requests added to the batch at different generation steps
+ generate the same outputs (continuous batching).
+ """
+ generator = NeuronGenerator.from_pretrained(model_path)
+ assert generator.model.batch_size > 1
+ input_text = "Once upon a time"
+ max_new_tokens = 20
+ # Prefill a single request, remembering the generated token
+ tokens = {0: [], 1: []}
+ request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
+ max_length = generator.model.max_length
+ batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
+ generations, next_batch = generator.prefill(batch)
+ assert next_batch.size == 1
+ assert len(generations) == 1
+ g = generations[0]
+ tokens[g.request_id].append(g.tokens.ids[0])
+ assert len(tokens[0]) == 1
+ # Decode a few tokens
+ gen_tokens = 4
+ for _ in range(gen_tokens - 1):
+ generations, next_batch = generator.decode([next_batch])
+ assert len(generations) == 1
+ g = generations[0]
+ tokens[g.request_id].append(g.tokens.ids[0])
+ assert len(tokens[0]) == gen_tokens
+ assert next_batch.size == 1
+ # Add a second request
+ request = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens)
+ batch = Batch(id=1, requests=[request], size=1, max_tokens=max_length)
+ generations, next_batch_1 = generator.prefill(batch)
+ assert next_batch_1.size == 1
+ # We should have generated only a single token
+ assert len(generations) == 1
+ g = generations[0]
+ tokens[g.request_id].append(g.tokens.ids[0])
+ assert len(tokens[0]) == gen_tokens
+ assert len(tokens[1]) == 1
+ # Decode more tokens until we reach the maximum for the first request
+ batches = [next_batch, next_batch_1]
+ for _ in range(max_new_tokens - gen_tokens):
+ generations, next_batch = generator.decode(batches)
+ for g in generations:
+ tokens[g.request_id].append(g.tokens.ids[0])
+ batches = [next_batch]
+ # Verify we now only have one pending request
+ assert next_batch.size == 1
+ assert len(tokens[0]) == max_new_tokens
+ assert len(tokens[1]) == max_new_tokens - gen_tokens + 1
+ # Verify we have the output for the first request
+ for g in generations:
+ if g.request_id == 0:
+ output = g.generated_text
+ assert output.text != ""
+ assert output.generated_tokens == max_new_tokens
+ generated_text = output.text
+ # Continue decoding until the end of the second request
+ for _ in range(gen_tokens - 1):
+ generations, next_batch = generator.decode([next_batch])
+ assert len(generations) == 1
+ g = generations[0]
+ tokens[g.request_id].append(g.tokens.ids[0])
+ assert next_batch is None
+ output = generations[0].generated_text
+ assert output.generated_tokens == max_new_tokens
+ assert tokens[0] == tokens[1]
+ assert output.text == generated_text
diff --git a/backends/neuron/tests/server/test_cached_model.py b/backends/neuron/tests/server/test_cached_model.py
new file mode 100644
index 00000000000..73622578722
--- /dev/null
+++ b/backends/neuron/tests/server/test_cached_model.py
@@ -0,0 +1,42 @@
+import os
+import pytest
+
+from text_generation_server.generator import NeuronGenerator
+from text_generation_server.model import fetch_model, is_cached
+
+
+@pytest.fixture(scope="module")
+def cached_model_id(neuron_model_config) -> str:
+ """
+ Fixture to provide a cached model ID for testing.
+ This assumes the model is already cached in the local environment.
+ """
+ export_kwargs = neuron_model_config["export_kwargs"]
+ os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
+ os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
+ os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
+ os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
+ yield neuron_model_config["model_id"]
+ os.environ.pop("MAX_BATCH_SIZE", None)
+ os.environ.pop("MAX_TOTAL_TOKENS", None)
+ os.environ.pop("HF_AUTO_CAST_TYPE", None)
+ os.environ.pop("HF_NUM_CORES", None)
+
+
+def test_model_is_cached(cached_model_id):
+ assert is_cached(cached_model_id), f"Model {cached_model_id} is not cached"
+
+
+def test_fetch_cached_model(cached_model_id: str):
+ model_path = fetch_model(cached_model_id)
+ assert os.path.exists(
+ model_path
+ ), f"Model {cached_model_id} was not fetched successfully"
+ assert os.path.isdir(model_path), f"Model {cached_model_id} is not a directory"
+
+
+def test_generator_from_cached_model(cached_model_id: str):
+ generator = NeuronGenerator.from_pretrained(model_id=cached_model_id)
+ assert generator is not None, "Generator could not be created from cached model"
+ assert generator.model is not None, "Generator model is not initialized"
+ assert generator.tokenizer is not None, "Generator tokenizer is not initialized"
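+
+
+# Illustrative sketch (not part of the test-suite): pre-fetching a cached model outside of
+# pytest, mirroring what the tests above verify. The model id below is a placeholder.
+def prefetch_example(model_id: str = "unsloth/Llama-3.2-1B-Instruct"):
+    assert is_cached(model_id), f"Model {model_id} is not present in the cache"
+    model_path = fetch_model(model_id)
+    return NeuronGenerator.from_pretrained(model_path)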
diff --git a/backends/neuron/tests/server/test_continuous_batching.py b/backends/neuron/tests/server/test_continuous_batching.py
new file mode 100644
index 00000000000..3d9ab50981a
--- /dev/null
+++ b/backends/neuron/tests/server/test_continuous_batching.py
@@ -0,0 +1,74 @@
+from helpers import create_request
+from text_generation_server.generator import NeuronGenerator
+from text_generation_server.pb.generate_pb2 import Batch
+
+
+def test_continuous_batching_two_requests(neuron_model_config):
+ """Verify that two requests added to the batch at different generation steps
+ generate the same outputs (continuous batching).
+ """
+ neuron_model_path = neuron_model_config["neuron_model_path"]
+ generator = NeuronGenerator.from_pretrained(neuron_model_path)
+ assert generator.model.neuron_config.batch_size > 1
+ input_text = "Once upon a time"
+ max_new_tokens = 20
+ # Prefill a single request, remembering the generated token
+ tokens = {0: [], 1: []}
+ request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
+ max_length = generator.model.neuron_config.sequence_length
+ batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
+ generations, next_batch = generator.prefill(batch)
+ assert next_batch.size == 1
+ assert len(generations) == 1
+ g = generations[0]
+ tokens[g.request_id].append(g.tokens.ids[0])
+ assert len(tokens[0]) == 1
+ # Decode a few tokens
+ gen_tokens = 4
+ for _ in range(gen_tokens - 1):
+ generations, next_batch = generator.decode([next_batch])
+ assert len(generations) == 1
+ g = generations[0]
+ tokens[g.request_id].append(g.tokens.ids[0])
+ assert len(tokens[0]) == gen_tokens
+ assert next_batch.size == 1
+ # Add a second request
+ request = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens)
+ batch = Batch(id=1, requests=[request], size=1, max_tokens=max_length)
+ generations, next_batch_1 = generator.prefill(batch)
+ assert next_batch_1.size == 1
+ # We should have generated only a single token
+ assert len(generations) == 1
+ g = generations[0]
+ tokens[g.request_id].append(g.tokens.ids[0])
+ assert len(tokens[0]) == gen_tokens
+ assert len(tokens[1]) == 1
+ # Decode more tokens until we reach the maximum for the first request
+ batches = [next_batch, next_batch_1]
+ for _ in range(max_new_tokens - gen_tokens):
+ generations, next_batch = generator.decode(batches)
+ for g in generations:
+ tokens[g.request_id].append(g.tokens.ids[0])
+ batches = [next_batch]
+ # Verify we now only have one pending request
+ assert next_batch.size == 1
+ assert len(tokens[0]) == max_new_tokens
+ assert len(tokens[1]) == max_new_tokens - gen_tokens + 1
+ # Verify we have the output for the first request
+ for g in generations:
+ if g.request_id == 0:
+ output = g.generated_text
+ assert output.text != ""
+ assert output.generated_tokens == max_new_tokens
+ generated_text = output.text
+ # Continue decoding until the end of the second request
+ for _ in range(gen_tokens - 1):
+ generations, next_batch = generator.decode([next_batch])
+ assert len(generations) == 1
+ g = generations[0]
+ tokens[g.request_id].append(g.tokens.ids[0])
+ assert next_batch is None
+ output = generations[0].generated_text
+ assert output.generated_tokens == max_new_tokens
+ assert tokens[0] == tokens[1]
+ assert output.text == generated_text
diff --git a/backends/neuron/tests/server/test_decode.py b/backends/neuron/tests/server/test_decode.py
new file mode 100644
index 00000000000..5bb00a84144
--- /dev/null
+++ b/backends/neuron/tests/server/test_decode.py
@@ -0,0 +1,52 @@
+from helpers import create_request
+from text_generation_server.generator import NeuronGenerator
+from text_generation_server.pb.generate_pb2 import Batch
+
+
+def test_decode(neuron_model_config):
+ """Verify that a decoding for a single request generates the expected output."""
+ config_name = neuron_model_config["name"]
+ neuron_model_path = neuron_model_config["neuron_model_path"]
+ generator = NeuronGenerator.from_pretrained(neuron_model_path)
+ for do_sample in [True, False]:
+ mode = "sample" if do_sample else "greedy"
+ print(f"{config_name}[{mode}]")
+ generated_text = _test_decode(config_name, generator, do_sample)
+ if not do_sample:
+ expected_text = {
+ "llama": " The world was holding its breath as the world's top scientists and engineers gathered at the secret underground facility",
+ "qwen2": " I was sitting in my room, staring at the clock, when a knock at the door. I",
+ "granite": "\n\nThis opening line is from George Orwell's dystopian novel, \"1",
+ }[config_name]
+ assert generated_text == expected_text
+ generator.clear()
+
+
+def _test_decode(config_name, generator, do_sample):
+ input_text = (
+ "It was a bright cold day in April, and the clocks were striking thirteen."
+ )
+ max_new_tokens = 20
+ request = create_request(
+ id=0,
+ inputs=input_text,
+ max_new_tokens=max_new_tokens,
+ do_sample=do_sample,
+ temperature=0.9,
+ )
+ max_length = generator.model.neuron_config.sequence_length
+ batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
+ generations, next_batch = generator.prefill(batch)
+ # We already generated one token: call decode max_new_tokens - 1 times
+ for _ in range(max_new_tokens - 1):
+ assert next_batch.size == 1
+ assert next_batch.max_tokens == max_length
+ assert len(generations) == 1
+ assert len(generations[0].tokens.ids) == 1
+ generations, next_batch = generator.decode([next_batch])
+ assert next_batch is None
+ assert len(generations) == 1
+ output = generations[0].generated_text
+ assert output.generated_tokens == max_new_tokens
+ assert output.finish_reason == 0  # FINISH_REASON_LENGTH
+ return output.text
diff --git a/backends/neuron/tests/server/test_generator_slot.py b/backends/neuron/tests/server/test_generator_slot.py
new file mode 100644
index 00000000000..0c03e9d1eed
--- /dev/null
+++ b/backends/neuron/tests/server/test_generator_slot.py
@@ -0,0 +1,66 @@
+import pytest
+import torch
+from text_generation_server.generator import Slot
+from text_generation_server.pb.generate_pb2 import Request
+from transformers import AutoTokenizer, GenerationConfig
+
+
+TOKENIZERS = ["NousResearch/Llama-2-7b-hf", "gpt2"]
+
+
+@pytest.fixture(params=TOKENIZERS)
+def tokenizer(request):
+ t = AutoTokenizer.from_pretrained(request.param)
+ t.padding_side = "left"
+ t.pad_token_id = t.eos_token_id
+ return t
+
+
+@pytest.mark.parametrize(
+ "input_text, generated_text",
+ [
+ [
+ "It was a bright cold day in April, and the clocks were striking thirteen.",
+ " Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,"
+ " slipped quickly through the glass doors of Victory Mansions, though not quickly enough"
+ " to prevent a swirl of gritty dust from entering along with him.",
+ ],
+ ["This sentence is written in chinese:", "我很感谢你的热情"],
+ ["Some text might contain a lot of emojis like 😃", "😍💪 👉 👀"],
+ ],
+ ids=["spaces", "chinese-utf8", "emojis"],
+)
+def test_decode_streaming(tokenizer, input_text, generated_text):
+ slot = Slot(0, tokenizer)
+ request = Request(id=0, inputs=input_text)
+ slot.assign(0, request, GenerationConfig())
+ assert slot.cached_text == input_text
+
+ inputs = tokenizer(
+ input_text,
+ padding="max_length",
+ max_length=len(input_text) + 1,
+ return_tensors="pt",
+ )
+ input_ids = inputs["input_ids"][0]
+ attention_mask = inputs["attention_mask"][0]
+ generated_tokens = tokenizer(generated_text, add_special_tokens=False)["input_ids"]
+
+ # We need to regenerate the full text as the tokenizer might change it (extra spaces might be added)
+ all_input_ids = torch.cat([input_ids, torch.tensor(generated_tokens)])
+ full_text = tokenizer.decode(all_input_ids, skip_special_tokens=True)
+ regenerated_text = full_text[len(input_text) :]
+
+ # Initialize the slot with the inputs
+ slot.reset(input_ids, attention_mask, selector=None)
+
+ assert slot.generated_tokens == 0
+
+ # Simulate an iterative generation (i.e. don't call select and use known tokens instead)
+ decoded_text = ""
+ for i in range(len(generated_tokens)):
+ text = slot.append(generated_tokens[i])
+ assert slot.generated_tokens == i + 1
+ decoded_text += text
+
+ assert decoded_text == regenerated_text
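+
+
+# Minimal sketch (illustrative only, not the actual Slot implementation) of the prefix-diff
+# idea exercised above: re-decode the whole sequence at every step and emit only the newly
+# appended suffix, so multi-byte UTF-8 characters and leading spaces render correctly.
+def streaming_decode_sketch(tokenizer, token_ids):
+    emitted = ""
+    for i in range(1, len(token_ids) + 1):
+        full = tokenizer.decode(token_ids[:i], skip_special_tokens=True)
+        new_text = full[len(emitted):]
+        emitted = full
+        yield new_text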
diff --git a/backends/neuron/tests/server/test_info.py b/backends/neuron/tests/server/test_info.py
new file mode 100644
index 00000000000..5913acec471
--- /dev/null
+++ b/backends/neuron/tests/server/test_info.py
@@ -0,0 +1,10 @@
+from text_generation_server.generator import NeuronGenerator
+
+
+def test_info(neuron_model_path):
+ generator = NeuronGenerator.from_pretrained(neuron_model_path)
+ info = generator.info
+ assert info.requires_padding is True
+ assert info.device_type == "xla"
+ assert info.window_size == 0
+ assert info.speculate == 0
diff --git a/backends/neuron/tests/server/test_prefill.py b/backends/neuron/tests/server/test_prefill.py
new file mode 100644
index 00000000000..1061fbc415e
--- /dev/null
+++ b/backends/neuron/tests/server/test_prefill.py
@@ -0,0 +1,93 @@
+from helpers import create_request
+from text_generation_server.generator import NeuronGenerator
+from text_generation_server.pb.generate_pb2 import Batch
+
+
+def test_prefill(neuron_model_config):
+ """Verify that a prefill for a single request generates the expected output."""
+ config_name = neuron_model_config["name"]
+ neuron_model_path = neuron_model_config["neuron_model_path"]
+ generator = NeuronGenerator.from_pretrained(neuron_model_path)
+ max_batch_size = 4
+ assert generator.model.neuron_config.batch_size >= max_batch_size
+ for num_requests in [1, max_batch_size]:
+ for do_sample in [True, False]:
+ mode = "sample" if do_sample else "greedy"
+ print(f"[{mode}]: {num_requests} requests")
+ _test_prefill(config_name, generator, num_requests, do_sample)
+ generator.clear()
+
+
+def _test_prefill(config_name, generator, batch_size, do_sample):
+ requests = []
+ max_new_tokens = 20
+ input_text = (
+ "It was a bright cold day in April, and the clocks were striking thirteen."
+ )
+ for i in range(batch_size):
+ requests.append(
+ create_request(
+ id=i,
+ inputs=input_text,
+ do_sample=do_sample,
+ max_new_tokens=max_new_tokens,
+ )
+ )
+ # Let's be pessimistic when estimating max_tokens
+ max_length = generator.max_prefill_length()
+ batch = Batch(
+ id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
+ )
+ generations, next_batch = generator.prefill(batch)
+ assert next_batch.size == batch_size
+ # Whatever was passed as max_tokens, the server will correct it
+ # because of static batching
+ assert next_batch.max_tokens == batch_size * max_length
+ assert len(generations) == batch_size
+ expectations = {
+ "llama": [578, " The"],
+ "qwen2": [358, " I"],
+ "granite": [203, "\n"],
+ }[config_name]
+ # Greedy mode should always generate the same output
+ if not do_sample:
+ for g in generations:
+ tokens = g.tokens
+ assert tokens.ids[0] == expectations[0]
+ assert tokens.texts[0] == expectations[1]
+
+
+def test_prefill_truncate(neuron_model_config):
+ config_name = neuron_model_config["name"]
+ neuron_model_path = neuron_model_config["neuron_model_path"]
+ generator = NeuronGenerator.from_pretrained(neuron_model_path)
+ batch_size = generator.model.neuron_config.batch_size
+ # We apply truncation to all requests but the first one
+ truncate = [
+ None,
+ ] + [i * 3 for i in range(1, batch_size)]
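+    # Assumption: truncate=N keeps only the last N input tokens (left truncation), so each
+    # request effectively sees a different suffix of the same prompt.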
+ input_text = (
+ "Two gin-scented tears trickled down the sides of his nose."
+ " But it was all right, everything was all right, the struggle was finished."
+ " He had won the victory over himself. He loved Big Brother."
+ )
+ requests = []
+ for i in range(batch_size):
+ requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i]))
+ max_length = generator.max_prefill_length()
+ batch = Batch(
+ id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
+ )
+ generations, _ = generator.prefill(batch)
+ # Even if the input text is identical for all requests, the first generated token might
+ # be different because of the truncation
+ expectations = {
+ "llama": [" He", "iens", "\x08", " He"],
+ "qwen2": [" He", "<|endoftext|>", " ", " The"],
+ "granite": ["\n", "\n", "\n", "\n"],
+ }[config_name]
+ for i, g in enumerate(generations):
+ tokens = g.tokens
+ assert (
+ tokens.texts[0] == expectations[i]
+ ), f"Request {i} expected [{expectations[i]}], got [{tokens.texts[0]}]"
diff --git a/backends/neuron/tests/test_entry_point.py b/backends/neuron/tests/test_entry_point.py
new file mode 100644
index 00000000000..d4ddf338f2f
--- /dev/null
+++ b/backends/neuron/tests/test_entry_point.py
@@ -0,0 +1,63 @@
+import os
+import pytest
+from tempfile import TemporaryDirectory
+
+from optimum.neuron.models.inference.nxd.backend.config import NxDNeuronConfig
+from optimum.neuron.utils import map_torch_dtype
+
+from text_generation_server.tgi_env import (
+ get_neuron_config_for_model,
+ lookup_compatible_cached_model,
+ neuron_config_to_env,
+)
+
+
+def test_get_neuron_config_for_model(neuron_model_config):
+ neuron_model_path = neuron_model_config["neuron_model_path"]
+ export_kwargs = neuron_model_config["export_kwargs"]
+ os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
+ os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
+ os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
+ os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
+ neuron_config = get_neuron_config_for_model(neuron_model_path)
+ assert neuron_config is not None
+ assert neuron_config.batch_size == export_kwargs["batch_size"]
+ assert neuron_config.sequence_length == export_kwargs["sequence_length"]
+ assert neuron_config.tp_degree == export_kwargs["num_cores"]
+ if isinstance(neuron_config, NxDNeuronConfig):
+ assert map_torch_dtype(neuron_config.torch_dtype) == map_torch_dtype(
+ export_kwargs["auto_cast_type"]
+ )
+ else:
+ assert map_torch_dtype(neuron_config.auto_cast_type) == map_torch_dtype(
+ export_kwargs["auto_cast_type"]
+ )
+
+
+@pytest.mark.parametrize("model_id", ["unsloth/Llama-3.2-1B-Instruct"])
+def test_lookup_compatible_cached_model(model_id: str):
+ neuron_config = lookup_compatible_cached_model(model_id, None)
+ assert neuron_config is not None
+
+
+def test_neuron_config_to_env(neuron_model_config) -> None:
+ neuron_model_path = neuron_model_config["neuron_model_path"]
+ neuron_config = get_neuron_config_for_model(neuron_model_path)
+ with TemporaryDirectory() as temp_dir:
+ os.environ["ENV_FILEPATH"] = os.path.join(temp_dir, "env.sh")
+ neuron_config_to_env(neuron_config)
+ with open(os.environ["ENV_FILEPATH"], "r") as env_file:
+ env_content = env_file.read()
+ assert f"export MAX_BATCH_SIZE={neuron_config.batch_size}" in env_content
+ assert (
+ f"export MAX_TOTAL_TOKENS={neuron_config.sequence_length}"
+ in env_content
+ )
+ assert f"export HF_NUM_CORES={neuron_config.tp_degree}" in env_content
+ if hasattr(neuron_config, "torch_dtype"):
+ auto_cast_type = str(map_torch_dtype(neuron_config.torch_dtype)).split(
+ "."
+ )[-1]
+ else:
+ auto_cast_type = neuron_config.auto_cast_type
+ assert f"export HF_AUTO_CAST_TYPE={auto_cast_type}" in env_content
diff --git a/backends/neuron/tgi-entrypoint.sh b/backends/neuron/tgi-entrypoint.sh
new file mode 100755
index 00000000000..7965d1da9be
--- /dev/null
+++ b/backends/neuron/tgi-entrypoint.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+set -e -o pipefail -u
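+
+# Overall flow: create a temporary env file, let tgi_entry_point.py infer any missing
+# TGI/Neuron env variables and write `export ...` lines into it, source that file, then
+# exec text-generation-launcher with the original arguments.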
+
+export ENV_FILEPATH=$(mktemp)
+
+trap "rm -f ${ENV_FILEPATH}" EXIT
+
+touch $ENV_FILEPATH
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+
+${SCRIPT_DIR}/tgi_entry_point.py "$@"
+
+source $ENV_FILEPATH
+
+exec text-generation-launcher "$@"
diff --git a/backends/neuron/tgi_entry_point.py b/backends/neuron/tgi_entry_point.py
new file mode 100755
index 00000000000..7e81d0e4775
--- /dev/null
+++ b/backends/neuron/tgi_entry_point.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+
+import logging
+import os
+import sys
+
+
+from text_generation_server.tgi_env import (
+ available_cores,
+ get_env_dict,
+ get_neuron_config_for_model,
+ neuron_config_to_env,
+ neuronxcc_version,
+ parse_cmdline_and_set_env,
+ tgi_env_vars,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+def main():
+ """
+ This script determines proper default TGI env variables for the neuron precompiled models to
+ work properly
+ :return:
+ """
+ args = parse_cmdline_and_set_env()
+
+ # The else branch of this for loop only runs when no env var is missing
+ for env_var in tgi_env_vars:
+ if not os.getenv(env_var):
+ break
+ else:
+ logger.info(
+ "All env vars %s already set, skipping; assuming the user knows what they are doing",
+ tgi_env_vars,
+ )
+ sys.exit(0)
+
+ neuron_config = get_neuron_config_for_model(args.model_id, args.revision)
+
+ if not neuron_config:
+ msg = (
+ "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
+ ).format(get_env_dict(), available_cores, neuronxcc_version)
+ logger.error(msg)
+ raise Exception(msg)
+
+ neuron_config_to_env(neuron_config)
+
+
+if __name__ == "__main__":
+ main()
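+
+
+# Illustrative note: tgi-entrypoint.sh forwards the launcher arguments to this script, e.g.
+# something like `tgi_entry_point.py --model-id <model-id-or-path> --revision <revision>`
+# (exact flags depend on parse_cmdline_and_set_env), and the inferred variables end up in
+# the file pointed to by ENV_FILEPATH.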
diff --git a/backends/trtllm/CMakeLists.txt b/backends/trtllm/CMakeLists.txt
index 831372cdf99..e54fd1169f3 100644
--- a/backends/trtllm/CMakeLists.txt
+++ b/backends/trtllm/CMakeLists.txt
@@ -1,25 +1,19 @@
cmake_minimum_required(VERSION 3.20)
-if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
- find_program(CCACHE_EXECUTABLE "ccache")
- if (CCACHE_EXECUTABLE)
- message(STATUS "Using ccache")
- set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE)
- endif ()
-endif ()
-
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
cmake_policy(SET CMP0135 NEW)
endif ()
project(tgi-trtllm-backend VERSION 1.0.0)
-set(CMAKE_CXX_STANDARD 20)
+set(CMAKE_CXX_STANDARD 23)
include(FetchContent)
include(ExternalProject)
+include(CheckCXXCompilerFlag)
option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
+option(TGI_TRTLLM_BACKEND_BUILD_USE_LLD "Enable lld linker instead of ld" OFF)
set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
@@ -27,13 +21,24 @@ set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE ST
# We are using nvidia-ml to query device information at runtime and enable some architecture-specific features
find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
+find_package(MPI REQUIRED)
#### External dependencies ####
-include(cmake/fmt.cmake)
include(cmake/json.cmake)
include(cmake/spdlog.cmake)
include(cmake/trtllm.cmake)
+if (CMAKE_BUILD_TYPE STREQUAL "Debug")
+ set(TGI_TRTLLM_BACKEND_DEBUG ON)
+ add_compile_definitions(TGI_TRTLLM_BACKEND_DEBUG=1)
+ add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE)
+endif ()
+
+if (${TGI_TRTLLM_BACKEND_BUILD_USE_LLD})
+ message(STATUS "Using lld linker")
+ add_link_options("-fuse-ld=lld")
+endif ()
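+
+# Example (illustrative): pass -DTGI_TRTLLM_BACKEND_BUILD_USE_LLD=ON at configure time to
+# opt into the lld linker.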
+
# Let's build TRTLLM as part of CMake
add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
@@ -41,35 +46,75 @@ add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
set_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE)
# TGI TRTLLM Backend definition
-add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp include/hardware.h)
+add_library(tgi_trtllm_backend_impl STATIC csrc/hardware.hpp csrc/backend.hpp csrc/backend.cpp)
include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
target_include_directories(tgi_trtllm_backend_impl PRIVATE
- $
- $
+ $
+ # $
)
target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
-target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml)
-target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt)
+target_link_libraries(tgi_trtllm_backend_impl PRIVATE CUDA::cudart CUDA::nvml)
+target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog)
+target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
# This installs all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make it easy to link against / find them back
-install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
-install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
+install(TARGETS tgi_trtllm_backend_impl)
+#install(TARGETS cutlass_src fb_gemm_src fpA_intB_gemm_src gemm_swiglu_sm90_src kernels_src)
+install(TARGETS decoder_attention_0 decoder_attention_1)
+install(TARGETS tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention_src executorWorker)
+install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} TYPE LIB)
+if (NOT ${TGI_TRTLLM_BACKEND_DEBUG})
+ install(FILES ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
+endif ()
+
#### Unit Tests ####
-if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
+if (${TGI_TRTLLM_BACKEND_BUILD_TESTS} AND CMAKE_BUILD_TYPE MATCHES "Debug")
message(STATUS "Building tests")
+ option(TGI_TRTLLM_BACKEND_ENABLE_ASAN "Enable AddressSanitizer")
+ option(TGI_TRTLLM_BACKEND_ENABLE_UBSAN "Enable UndefinedSanitizer")
+
FetchContent_Declare(
Catch2
- GIT_REPOSITORY https://github.com/catchorg/Catch2
- GIT_TAG v3.6.0
+ URL https://github.com/catchorg/Catch2/archive/refs/tags/v3.7.1.tar.gz
)
FetchContent_MakeAvailable(Catch2)
- # add_executable(tgi_trtllm_backend_tests tests/infer_test.cpp)
- # target_link_libraries(tgi_trtllm_backend_tests PRIVATE tgi_trtllm_backend_impl Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt CUDA::cudart CUDA::nvml)
+ # This attempts to detect whether the compiler can warn when it cannot apply named return value optimization (NRVO) to a function
+ check_cxx_compiler_flag("-Wnrvo" COMPILER_SUPPORT_WARNING_ON_NRVO)
+ if (${COMPILER_SUPPORT_WARNING_ON_NRVO})
+ message(STATUS "Enabling missed-NRVO detection (-Wnrvo)")
+ target_compile_options(tgi_trtllm_backend_impl PRIVATE -Wnrvo)
+ endif ()
+ target_compile_options(tgi_trtllm_backend_impl PRIVATE -Wall)
+
+ cmake_path(GET TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH PARENT_PATH TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH)
+ message(STATUS "Adding linking path: ${TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH}")
+
+ add_executable(tgi_trtllm_backend_tests tests/test_hardware.cpp tests/test_backend.cpp)
+
+ # target_compile_options(tgi_trtllm_backend_tests PRIVATE -Werror)
+ target_link_directories(tgi_trtllm_backend_tests PRIVATE "${TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH}")
+ target_include_directories(tgi_trtllm_backend_tests PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
+ target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/")
+ target_link_libraries(tgi_trtllm_backend_tests PRIVATE ${TRTLLM_LIBS} CUDA::cudart CUDA::nvml)
+ target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl)
+ target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
+
+ if (${TGI_TRTLLM_BACKEND_ENABLE_ASAN})
+ message(STATUS "Enabled AddressSanitizer")
+ target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=address)
+ endif ()
+
+ if (${TGI_TRTLLM_BACKEND_ENABLE_UBSAN})
+ message(STATUS "Enabled UndefinedSanitizer")
+ target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=undefined)
+ endif ()
+
+ install(TARGETS tgi_trtllm_backend_tests)
- list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
- include(CTest)
- include(Catch)
+ # list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
+ # include(CTest)
+ # include(Catch)
# catch_discover_tests(tgi_trtllm_backend_tests)
endif ()
diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml
index 97ef1a76891..b6c39346aa1 100644
--- a/backends/trtllm/Cargo.toml
+++ b/backends/trtllm/Cargo.toml
@@ -7,20 +7,17 @@ homepage.workspace = true
[dependencies]
async-trait = "0.1"
-async-stream = "0.3"
clap = { version = "4.5", features = ["derive"] }
cxx = "1.0"
-hashbrown = "0.14"
+hashbrown = "0.15"
hf-hub = { workspace = true }
-log = { version = "0.4", features = [] }
text-generation-router = { path = "../../router" }
tokenizers = { workspace = true }
-tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
-tokio-stream = "0.1.15"
+tokio = { version = "1.43.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.17"
thiserror = "1.0.63"
tracing = "0.1"
-tracing-opentelemetry = "0.25"
-tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
+pyo3 = { workspace = true }
[build-dependencies]
cmake = "0.1"
diff --git a/backends/trtllm/build.rs b/backends/trtllm/build.rs
index 985019260b8..c9918e2c52b 100644
--- a/backends/trtllm/build.rs
+++ b/backends/trtllm/build.rs
@@ -3,24 +3,34 @@ use pkg_config;
use std::env;
use std::env::consts::ARCH;
use std::path::{absolute, PathBuf};
+use std::sync::LazyLock;
-const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
+const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 1] = ["spdlog"];
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
-const CUDA_REQUIRED_VERSION: &str = "12.6";
+const CUDA_REQUIRED_VERSION: &str = "12.8";
const MPI_REQUIRED_VERSION: &str = "4.1";
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR");
+const IS_GHA_BUILD: LazyLock<bool> = LazyLock::new(|| {
+ option_env!("SCCACHE_GHA_ENABLED").map_or(false, |value| match value.to_lowercase().as_str() {
+ "on" => true,
+ "true" => true,
+ "1" => true,
+ _ => false,
+ })
+});
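+
+// When sccache-for-GHA is enabled, build_backend() below also turns on the C++ test-suite
+// and the address/undefined sanitizers so that CI exercises those code paths.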
+
// Dependencies
-const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"];
+const BACKEND_DEPS: &str = "tgi_trtllm_backend_impl";
const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"];
const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [
("dylib", "tensorrt_llm"),
- ("static", "tensorrt_llm_executor_static"),
("dylib", "tensorrt_llm_nvrtc_wrapper"),
("dylib", "nvinfer_plugin_tensorrt_llm"),
- ("dylib", "decoder_attention"),
+ ("dylib", "decoder_attention_0"),
+ ("dylib", "decoder_attention_1"),
];
macro_rules! probe {
@@ -32,6 +42,48 @@ macro_rules! probe {
};
}
+fn get_compiler_flag(
+ switch: bool,
+ true_case: &'static str,
+ false_case: &'static str,
+) -> &'static str {
+ match switch {
+ true => true_case,
+ false => false_case,
+ }
+}
+
+fn get_library_architecture() -> &'static str {
+ let os = env::var("CARGO_CFG_TARGET_OS").unwrap();
+ let arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap();
+ let env = env::var("CARGO_CFG_TARGET_ENV").unwrap();
+
+ match os.as_str() {
+ "linux" => {
+ if env != "gnu" {
+ panic!("unsupported linux ABI {env}, only 'gnu' is supported")
+ }
+
+ match arch.as_str() {
+ "x86_64" => "x86_64-linux-gnu",
+ "aarch64" => "aarch64-linux-gnu",
+ _ => panic!("unsupported linux architecture {arch}"),
+ }
+ }
+ "windows" => {
+ if env != "msvc" {
+ panic!("unsupported windows ABI {env}, only 'msvc' is supported")
+ }
+
+ match arch.as_str() {
+ "x86_64" => "x86_64-windows-msvc",
+ _ => panic!("unsupported windows architecture {arch}"),
+ }
+ }
+ _ => panic!("unsupported OS {os}"),
+ }
+}
+
fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) {
// Build the backend implementation through CMake
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
@@ -43,7 +95,8 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
install_path = absolute(out_dir).expect("cannot happen").join(install_path);
}
- let _ = cmake::Config::new(".")
+ let mut config = cmake::Config::new(".");
+ config
.uses_cxx11()
.generator("Ninja")
.profile(match is_debug {
@@ -53,9 +106,50 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
.env("OPT_LEVEL", opt_level)
.define("CMAKE_INSTALL_PREFIX", &install_path)
.define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
+ .define("CMAKE_LIBRARY_ARCHITECTURE", get_library_architecture())
.define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
- .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path)
- .build();
+ .define(
+ "TGI_TRTLLM_BACKEND_DEBUG",
+ get_compiler_flag(is_debug, "ON", "OFF"),
+ )
+ .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path);
+
+ if is_debug || *IS_GHA_BUILD {
+ config.define("TGI_TRTLLM_BACKEND_BUILD_TESTS", "ON");
+ }
+
+ if option_env!("USE_LLD_LINKER").is_some() {
+ println!("cargo:warning=Using lld linker");
+ config.define("TGI_TRTLLM_BACKEND_BUILD_USE_LLD", "ON");
+ }
+
+ if (is_debug && option_env!("ENABLE_ASAN").is_some()) || *IS_GHA_BUILD {
+ println!("cargo:warning=Enabling Address Sanitizer");
+ config.define("TGI_TRTLLM_BACKEND_ENABLE_ASAN", "ON");
+ }
+
+ if (is_debug && option_env!("ENABLE_UBSAN").is_some()) || *IS_GHA_BUILD {
+ println!("cargo:warning=Enabling Undefined Sanitizer");
+ config.define("TGI_TRTLLM_BACKEND_ENABLE_UBSAN", "ON");
+ }
+
+ if let Some(nvcc_host_compiler) = option_env!("CMAKE_CUDA_HOST_COMPILER") {
+ config.define("CMAKE_CUDA_HOST_COMPILER", nvcc_host_compiler);
+ }
+
+ if let Some(wrapper) = option_env!("RUSTC_WRAPPER") {
+ println!("cargo:warning=Using caching tool: {wrapper}");
+ config.define("CMAKE_C_COMPILER_LAUNCHER", wrapper);
+ config.define("CMAKE_CXX_COMPILER_LAUNCHER", wrapper);
+ config.define("CMAKE_CUDA_COMPILER_LAUNCHER", wrapper);
+ }
+
+ // Allow overriding which Python interpreter to use ...
+ if let Some(python3) = option_env!("Python3_EXECUTABLE") {
+ config.define("Python3_EXECUTABLE", python3);
+ }
+
+ config.build();
// Additional transitive CMake dependencies
let deps_folder = out_dir.join("build").join("_deps");
@@ -70,46 +164,43 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
}
// Emit linkage information from the artifacts we just built
- let install_lib_path = install_path.join("lib");
-
- println!(
- r"cargo:warning=Adding link search path: {}",
- install_lib_path.display()
- );
- println!(r"cargo:rustc-link-search={}", install_lib_path.display());
-
+ for path in ["lib", "lib64"] {
+ let install_lib_path = install_path.join(path);
+ println!(
+ r"cargo:warning=Adding link search path: {}",
+ install_lib_path.display()
+ );
+ println!(r"cargo:rustc-link-search={}", install_lib_path.display());
+ }
(PathBuf::from(install_path), deps_folder)
}
fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
- let ndebug = match is_debug {
- true => "1",
- false => "0",
- };
-
CFG.include_prefix = "backends/trtllm";
cxx_build::bridge("src/lib.rs")
.static_flag(true)
- .include(deps_folder.join("fmt-src").join("include"))
+ .std("c++23")
.include(deps_folder.join("spdlog-src").join("include"))
.include(deps_folder.join("json-src").join("include"))
.include(deps_folder.join("trtllm-src").join("cpp").join("include"))
.include("/usr/local/cuda/include")
.include("/usr/local/tensorrt/include")
- .file("src/ffi.cpp")
- .std("c++20")
- .define("NDEBUG", ndebug)
+ .include("csrc/")
+ .file("csrc/ffi.hpp")
+ .define(
+ "TGI_TRTLLM_BACKEND_DEBUG",
+ get_compiler_flag(is_debug, "ON", "OFF"),
+ )
.compile("tgi_trtllm_backend");
println!("cargo:rerun-if-changed=CMakeLists.txt");
println!("cargo:rerun-if-changed=cmake/trtllm.cmake");
println!("cargo:rerun-if-changed=cmake/json.cmake");
- println!("cargo:rerun-if-changed=cmake/fmt.cmake");
println!("cargo:rerun-if-changed=cmake/spdlog.cmake");
- println!("cargo:rerun-if-changed=include/backend.h");
- println!("cargo:rerun-if-changed=lib/backend.cpp");
- println!("cargo:rerun-if-changed=include/ffi.h");
- println!("cargo:rerun-if-changed=src/ffi.cpp");
+ println!("cargo:rerun-if-changed=csrc/backend.hpp");
+ println!("cargo:rerun-if-changed=csrc/backend.cpp");
+ println!("cargo:rerun-if-changed=csrc/hardware.hpp");
+ println!("cargo:rerun-if-changed=csrc/ffi.hpp");
}
fn main() {
@@ -118,6 +209,7 @@ fn main() {
let build_profile = env::var("PROFILE").unwrap();
let (is_debug, opt_level) = match build_profile.as_ref() {
"debug" => (true, "0"),
+ "dev" => (true, "0"),
_ => (false, "3"),
};
@@ -154,7 +246,5 @@ fn main() {
});
// Backend
- BACKEND_DEPS.iter().for_each(|name| {
- println!("cargo:rustc-link-lib=static={}", name);
- });
+ println!("cargo:rustc-link-lib=static={}", &BACKEND_DEPS);
}
diff --git a/backends/trtllm/cmake/fmt.cmake b/backends/trtllm/cmake/fmt.cmake
deleted file mode 100644
index afd6ea5f090..00000000000
--- a/backends/trtllm/cmake/fmt.cmake
+++ /dev/null
@@ -1,6 +0,0 @@
-FetchContent_Declare(
- fmt
- DOWNLOAD_EXTRACT_TIMESTAMP
- URL https://github.com/fmtlib/fmt/archive/refs/tags/11.0.2.tar.gz
-)
-FetchContent_MakeAvailable(fmt)
diff --git a/backends/trtllm/cmake/json.cmake b/backends/trtllm/cmake/json.cmake
index 67eff2fe606..d6cdbe3aa90 100644
--- a/backends/trtllm/cmake/json.cmake
+++ b/backends/trtllm/cmake/json.cmake
@@ -1,6 +1,6 @@
fetchcontent_declare(
json
- DOWNLOAD_EXTRACT_TIMESTAMP
- URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
+# DOWNLOAD_EXTRACT_TIMESTAMP
+ URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.tar.gz
)
fetchcontent_makeavailable(json)
diff --git a/backends/trtllm/cmake/spdlog.cmake b/backends/trtllm/cmake/spdlog.cmake
index 7f529a7d29e..e7566cd7302 100644
--- a/backends/trtllm/cmake/spdlog.cmake
+++ b/backends/trtllm/cmake/spdlog.cmake
@@ -1,17 +1,17 @@
set(SPDLOG_USE_FMT ON)
set(SPDLOG_BUILD_SHARED OFF)
-set(SPDLOG_FMT_EXTERNAL ON)
+set(SPDLOG_FMT_EXTERNAL OFF)
# Define the level at which SPDLOG_ compilation level is defined
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
- add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
+ add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE)
else ()
- add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO)
+ add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
endif ()
fetchcontent_declare(
spdlog
- DOWNLOAD_EXTRACT_TIMESTAMP
- URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
+ # DOWNLOAD_EXTRACT_TIMESTAMP
+ URL https://github.com/gabime/spdlog/archive/refs/tags/v1.15.0.tar.gz
)
fetchcontent_makeavailable(spdlog)
diff --git a/backends/trtllm/cmake/trtllm.cmake b/backends/trtllm/cmake/trtllm.cmake
index 5f1b6c19c01..95a99e9bccd 100644
--- a/backends/trtllm/cmake/trtllm.cmake
+++ b/backends/trtllm/cmake/trtllm.cmake
@@ -11,20 +11,25 @@ set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST})
message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
+set(ENABLE_UCX OFF)
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
set(FAST_BUILD ON)
- set(NVTX_DISABLE OFF)
+ set(NVTX_DISABLE ON)
+ set(INDEX_RANGE_CHECK ON)
else ()
set(FAST_BUILD OFF)
set(FAST_MATH ON)
- set(NVTX_DISABLE ON)
+ set(NVTX_DISABLE OFF)
+ set(INDEX_RANGE_CHECK OFF)
endif ()
+find_package(Python3 REQUIRED Interpreter)
+
fetchcontent_declare(
trtllm
- GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
- GIT_TAG 201135e58aa525af7e523d091d4c9584229524bc
- GIT_SHALLOW FALSE
+ GIT_REPOSITORY https://github.com/nvidia/TensorRT-LLM.git
+ GIT_TAG v0.17.0
+ GIT_SHALLOW ON
DOWNLOAD_EXTRACT_TIMESTAMP
)
fetchcontent_makeavailable(trtllm)
diff --git a/backends/trtllm/csrc/backend.cpp b/backends/trtllm/csrc/backend.cpp
new file mode 100644
index 00000000000..2151466be6e
--- /dev/null
+++ b/backends/trtllm/csrc/backend.cpp
@@ -0,0 +1,80 @@
+#include
+
+#include
+
+#include "backend.hpp"
+#include "hardware.hpp"
+
+namespace huggingface::tgi::backends::trtllm {
+ tle::ParallelConfig backend_workspace_t::parallel_config() const {
+ // Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
+ const auto world_size = config_["/pretrained_config/mapping/world_size"_json_pointer].get();
+
+ auto mode = tle::CommunicationMode::kLEADER;
+ std::optional orchestratorConfig = std::nullopt;
+
+ if (world_size > 1) {
+ SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
+ mode = tle::CommunicationMode::kORCHESTRATOR;
+ orchestratorConfig = std::make_optional(true, executor_worker_path_, nullptr,
+ true);
+ } else {
+ SPDLOG_INFO("Detected single engine deployment, using leader mode");
+ }
+
+ return tle::ParallelConfig(tle::CommunicationType::kMPI, mode, std::nullopt, std::nullopt, orchestratorConfig);
+ }
+
+
+ tle::ExecutorConfig backend_workspace_t::executor_config() const {
+ // Retrieve the compute capabilities to enable some options at runtime
+ const auto compute_capabilities = hardware::cuda::compute_capabilities_t();
+
+ // Allocate the config
+ tle::ExecutorConfig executor_config(/* maxBeamWidth = */ 1);
+
+ // Set the parallel config as inferred
+ executor_config.setParallelConfig(parallel_config());
+
+ // Define some configuration variables
+ executor_config.setKvCacheConfig(tle::KvCacheConfig(true));
+ executor_config.setEnableChunkedContext(compute_capabilities.is_at_least_ampere());
+ executor_config.setSchedulerConfig(tle::SchedulerConfig(tle::CapacitySchedulerPolicy::kMAX_UTILIZATION));
+ return executor_config;
+ }
+
+ backend_t::backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path)
+ : workspace(engines_folder, executor_worker_path), executor_(executor_factory_initializer(workspace)) {}
+
+ size_t backend_t::num_tokens_ready() const noexcept {
+ return executor_.getNumResponsesReady();
+ }
+
+ std::expected
+ backend_t::submit(std::span token_ids, const generation_params_t g_params,
+ const sampling_params_t s_params) noexcept {
+ SPDLOG_DEBUG("Submit {:d} tokens for scheduling ({}, {})", token_ids.size(), g_params, s_params);
+ return executor_.enqueueRequest(tle::Request{
+ {token_ids.begin(), token_ids.end()}, // Making actual copy of the tokens
+ static_cast(g_params.max_new_tokens),
+ true,
+ (tle::SamplingConfig) s_params,
+ tle::OutputConfig{ /* returnLogProbs= */ true},
+ std::nullopt,
+ std::nullopt,
+ std::nullopt,
+ std::nullopt,
+ workspace.generation_config().stop_words
+ });
+ }
+
+ std::vector backend_t::pull_tokens() noexcept {
+ SPDLOG_TRACE(FMT_STRING("Pulling out tokens ({:d} available)"), num_tokens_ready());
+ return executor_.awaitResponses();
+ }
+
+ void backend_t::cancel(request_id_t request_id) noexcept {
+ SPDLOG_TRACE(FMT_STRING("Cancelling request: {:d}"), request_id);
+ executor_.cancelRequest(request_id);
+ }
+}
diff --git a/backends/trtllm/csrc/backend.hpp b/backends/trtllm/csrc/backend.hpp
new file mode 100644
index 00000000000..40b44a842b3
--- /dev/null
+++ b/backends/trtllm/csrc/backend.hpp
@@ -0,0 +1,231 @@
+#ifndef TGI_BACKEND_TRTLLM
+#define TGI_BACKEND_TRTLLM
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+
+namespace huggingface::tgi::backends::trtllm {
+ namespace tle = tensorrt_llm::executor;
+ using json = nlohmann::json;
+ using request_id_t = uint64_t;
+ using token_id_t = tle::TokenIdType;
+
+ /**
+ * Represent the parameters used for generation
+ */
+ struct generation_params_t {
+ uint32_t max_new_tokens;
+ };
+
+ /**
+ * Represent the parameters used to sample tokens from the logit distribution
+ */
+ struct sampling_params_t {
+ uint32_t top_k;
+ float_t top_p;
+ float_t repetition_penalty;
+ float_t frequency_penalty;
+ float_t temperature;
+ uint64_t seed;
+
+ constexpr explicit operator tle::SamplingConfig() const {
+ return tle::SamplingConfig{
+ 1,
+ top_k,
+ top_p,
+ std::nullopt,
+ std::nullopt,
+ std::nullopt,
+ seed,
+ temperature,
+ std::nullopt,
+ std::nullopt,
+ repetition_penalty,
+ std::nullopt,
+ frequency_penalty,
+ std::nullopt
+ };
+ }
+ };
+
+ /**
+ * Represent possible values from transformers generation `generation_config.json`.
+ * It usually stores default sampling parameters to use, such as top_p, temperature, etc.
+ */
+ struct generation_config_t {
+ float_t top_p;
+ float_t temperature;
+ std::list> stop_words;
+
+ constexpr explicit generation_config_t(const json &config) :
+ top_p(config.value("top_p", 1.0f)), temperature(config.value("temperature", 1.0f)), stop_words(0) {
+ if (config.contains("/eos_token_id"_json_pointer) && config["/eos_token_id"_json_pointer].is_array()) {
+ const auto &eos_token_id = config["/eos_token_id"_json_pointer];
+ std::for_each(eos_token_id.begin(), eos_token_id.end(), [this](const auto token_id) {
+ stop_words.emplace_back(1, token_id.template get());
+ });
+
+ SPDLOG_DEBUG("Detected {:d} predefined stop_words from generation_config.json", stop_words.size());
+ }
+ }
+ };
+
+ /**
+ * Helper class representing various items which are stored within the TensorRT-LLM engines folder and
+ * can be retrieved at runtime
+ */
+ class backend_workspace_t {
+ private:
+ constexpr static auto as_json = [](const std::filesystem::path &path) -> json {
+ std::ifstream config_f(path);
+ return json::parse(config_f);
+ };
+
+ std::filesystem::path engines_folder_;
+ std::filesystem::path executor_worker_path_;
+ json config_;
+ generation_config_t generation_config_;
+
+ public:
+ backend_workspace_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path) :
+ engines_folder_(engines_folder),
+ executor_worker_path_(executor_worker_path),
+ config_(as_json(engines_folder / "config.json")),
+ generation_config_(as_json(engines_folder / "generation_config.json")) {};
+
+ backend_workspace_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path) :
+ engines_folder_(engines_folder),
+ executor_worker_path_(executor_worker_path),
+ config_(as_json(engines_folder / "config.json")),
+ generation_config_(as_json(engines_folder / "generation_config.json")) {};
+
+ /**
+ * Path to the folder containing the TensorRT-LLM engines
+ * @return local filesystem path to the folder
+ */
+ [[nodiscard]] constexpr std::filesystem::path engines_folder() const { return engines_folder_; }
+
+ /**
+ * Hugging Face transformers' generated `generation_config_t` mapping information stored in the
+ * `generation_config.json` holding default generation parameters.
+ * @return `generation_config_t`
+ */
+ [[nodiscard]] constexpr const generation_config_t &generation_config() const { return generation_config_; }
+
+ /**
+ * Factory method returning new `tensorrt_llm::executor::ParallelConfig` instance used
+ * to initialize `tensorrt_llm::executor::Executor` with multi-instance communication information
+ * @return `tensorrt_llm::executor::ParallelConfig` instance
+ */
+ [[nodiscard]] tle::ParallelConfig parallel_config() const;
+
+ /**
+ * Factory method returning new `tensorrt_llm::executor::ExecutorConfig` instance used
+ * to initialize `tensorrt_llm::executor::Executor`
+ * @return `tensorrt_llm::executor::ExecutorConfig` instance
+ */
+ [[nodiscard]] tle::ExecutorConfig executor_config() const;
+ };
+
+ /**
+ * Error raised by the underlying backend implementation
+ */
+ enum backend_error_t {
+ EXECUTOR_NOT_READY = 3,
+ EXECUTOR_SCHEDULING_FAILED = 4,
+ };
+
+
+ /**
+ * Actual TensorRT-LLM backend implementation interacting with TensorRT-LLM Executor service to
+ * - schedule new request
+ * - pull status of submitted request(s)
+ * - cancel submitted request(s)
+ */
+ class backend_t {
+ private:
+ backend_workspace_t workspace;
+ tle::Executor executor_;
+
+ public:
+ backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path);
+
+ backend_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path)
+ : backend_t(engines_folder, executor_worker_path) {};
+
+ /**
+ * Submit a new request to the executor
+ * @param token_ids
+ * @param generation_params
+ * @param sampling_params
+ * @return Either newly submitted request's id or the error why it failed to submit
+ */
+ [[nodiscard("Discarded executor request_id needs to be assigned")]]
+ std::expected
+ submit(std::span token_ids, generation_params_t generation_params,
+ sampling_params_t sampling_params) noexcept;
+
+ /**
+ * Query the number of tokens available across all in-flight generations
+ * @return
+ */
+ [[nodiscard("Pulling out the number of tokens")]]
+ size_t num_tokens_ready() const noexcept;
+
+ /**
+ * Pull out newly generated tokens from the executor
+ * @return
+ */
+ [[nodiscard("")]]
+ std::vector pull_tokens() noexcept;
+
+ /**
+ * Cancel the specified request on the executor's set
+ * @param request_id Request's Identifier to remove from the in-flight executor
+ */
+ void cancel(request_id_t) noexcept;
+ };
+
+ /**
+ * Create a TensorRT-LLM executor from a workspace
+ */
+ const auto executor_factory_initializer = [](const backend_workspace_t &workspace) -> tle::Executor {
+ return {workspace.engines_folder(), tensorrt_llm::executor::ModelType::kDECODER_ONLY,
+ workspace.executor_config()};
+ };
+}
+
+/**
+ * Helper structures to define formatting strategies for various types in the backend
+ */
+template<>
+struct fmt::formatter : formatter {
+ auto format(huggingface::tgi::backends::trtllm::generation_params_t const &c,
+ format_context &ctx) const -> format_context::iterator {
+ return fmt::format_to(ctx.out(), "generation_params_t{{ max_new_tokens={:d} }}", c.max_new_tokens);
+ }
+};
+
+template<>
+struct fmt::formatter : formatter {
+ auto format(huggingface::tgi::backends::trtllm::sampling_params_t const &c,
+ format_context &ctx) const -> format_context::iterator {
+ return fmt::format_to(
+ ctx.out(),
+ "sampling_params_t{{ top_k={:d}, top_p={:.3f}, repetition_penalty={:.3f}, frequency_penalty={:.3f}, temperature={:.3f}, seed={:d} }}",
+ c.top_k, c.top_p, c.repetition_penalty, c.frequency_penalty, c.temperature, c.seed
+ );
+ }
+};
+
+#endif
diff --git a/backends/trtllm/csrc/ffi.hpp b/backends/trtllm/csrc/ffi.hpp
new file mode 100644
index 00000000000..840614bbcfe
--- /dev/null
+++ b/backends/trtllm/csrc/ffi.hpp
@@ -0,0 +1,191 @@
+#ifndef TGI_BACKEND_TRTLLM_FFI
+#define TGI_BACKEND_TRTLLM_FFI
+
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+
+namespace rust::behavior {
+ template
+ static void trycatch(Try &&func, Fail &&fail) noexcept try {
+ func();
+ } catch (tensorrt_llm::common::TllmException &e) {
+ fail(e.what());
+ }
+}
+
+namespace huggingface::tgi::backends::trtllm {
+ class tensorrt_llm_backend_t;
+}
+
+#include "backends/trtllm/src/lib.rs.h"
+
+
+namespace huggingface::tgi::backends::trtllm {
+ std::once_flag backend_initialized_flag;
+
+ constexpr finish_reason_t as_finish_reason_t(const tle::FinishReason reason) noexcept {
+ switch (reason) {
+ case tle::FinishReason::kNOT_FINISHED:
+ return finish_reason_t::kNOT_FINISHED;
+ case tle::FinishReason::kSTOP_WORDS:
+ return finish_reason_t::kSTOP_WORDS;
+ case tle::FinishReason::kEND_ID:
+ return finish_reason_t::kEND_ID;
+ case tle::FinishReason::kLENGTH:
+ return finish_reason_t::kLENGTH;
+ default:
+ std::unreachable();
+ }
+ }
+
+ static auto as_generation_step = [](const tle::Response &r) {
+ const auto reqId = r.getRequestId();
+ if (!r.hasError()) [[likely]] {
+ const auto result = r.getResult();
+ const auto logits = result.logProbs.value()[0];
+ return generation_step_t{
+ reqId,
+ static_cast(result.outputTokenIds[0][0]),
+ logits.back(),
+ result.isFinal,
+ as_finish_reason_t(result.finishReasons[0]),
+ false,
+ std::string()
+ };
+ } else {
+ return generation_step_t{
+ reqId,
+ 0,
+ 0.0,
+ true,
+ finish_reason_t::kNOT_FINISHED,
+ true,
+ std::move(r.getErrorMsg())
+ };
+ }
+ };
+
+
+ class tensorrt_llm_backend_t {
+ private:
+ backend_t inner_;
+
+ public:
+ tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path)
+ : inner_(engine_folder, executor_worker_path) {}
+
+ size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); }
+
+ request_id_t submit(
+ rust::Slice tokens,
+ uint32_t max_new_tokens,
+ uint32_t top_k,
+ float_t top_p,
+ float_t temperature,
+ float_t repetition_penalty,
+ float_t frequency_penalty,
+ uint64_t seed
+ ) {
+ // This is enabled only if using add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_TRACE)
+ SPDLOG_TRACE(FMT_STRING("[FFI] Submitting {:d} prompt tokens to the executor"));
+
+ // Submit the request to the executor and get back a potential request_id used to track request status
+ const auto signed_tokens = std::vector(tokens.begin(), tokens.end());
+ const auto maybe_request_id = inner_.submit(
+ signed_tokens,
+ {max_new_tokens},
+ {top_k, top_p, repetition_penalty, frequency_penalty, temperature, seed}
+ );
+
+ // If we do have a value, let's return the request_id
+ if (maybe_request_id.has_value()) [[likely]] {
+ return *maybe_request_id;
+ } else {
+ SPDLOG_WARN("[FFI] Failed to submit request to the executor");
+ return maybe_request_id.error();
+ }
+ }
+
+ std::unique_ptr> pull_tokens() noexcept {
+ if (num_tokens_ready() > 0) [[likely]] {
+ const auto responses = inner_.pull_tokens();
+
+ SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size());
+
+ // Transform tle::Response to generation_step_t
+#ifdef __cpp_lib_ranges_to_container
+ auto steps = responses | std::views::transform(as_generation_step) | std::ranges::to();
+#else
+ auto steps = std::vector();
+ steps.reserve(responses.size());
+ std::transform(responses.begin(), responses.end(), std::back_inserter(steps), as_generation_step);
+#endif
+ return std::make_unique>(steps);
+
+ } else {
+ return std::make_unique>();
+ }
+ }
+
+ void cancel(request_id_t request_id) noexcept {
+ SPDLOG_DEBUG("[FFI] cancelling request {:d}", request_id);
+ inner_.cancel(request_id);
+ }
+ };
+
+ void initialize_logging() {
+#ifndef TGI_TRTLLM_BACKEND_DEBUG
+ if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) {
+ std::string log_level(TRTLLM_LOG_LEVEL_CSTR);
+ std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {
+ return std::tolower(c);
+ });
+
+ if (log_level == "debug")
+ spdlog::set_level(spdlog::level::debug);
+ else
+ spdlog::set_level(spdlog::level::info);
+ }
+#else
+ spdlog::set_level(spdlog::level::debug);
+#endif
+ }
+
+ void initialize_tensorrt_llm_backend() {
+ SPDLOG_INFO("Initializing TGI - TensoRT-LLM Backend (v{})", tle::version());
+
+ // Initialize everyone
+ initialize_logging();
+ nvmlInit_v2();
+ initTrtLlmPlugins();
+
+ const auto numGpus = huggingface::tgi::hardware::cuda::get_device_count();
+ if (numGpus.has_value()) {
+ SPDLOG_INFO("[FFI] Detected {:d} Nvidia GPU(s)", *numGpus);
+ } else {
+ SPDLOG_WARN("[FFI] Failed to detected Nvidia GPU(s) on the system");
+ // todo: throw
+ }
+ }
+
+ std::unique_ptr
+ create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
+ std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend);
+ return std::make_unique(
+ std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()),
+ std::filesystem::path::format::auto_format),
+ std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()),
+ std::filesystem::path::format::auto_format)
+ );
+ }
+}
+#endif
diff --git a/backends/trtllm/csrc/hardware.hpp b/backends/trtllm/csrc/hardware.hpp
new file mode 100644
index 00000000000..abfb4afd51d
--- /dev/null
+++ b/backends/trtllm/csrc/hardware.hpp
@@ -0,0 +1,81 @@
+#ifndef TGI_HARDWARE_CUDA
+#define TGI_HARDWARE_CUDA
+#include
+#include
+
+#include
+
+namespace huggingface::tgi::hardware::cuda {
+ static constexpr auto VOLTA = std::make_tuple(7u, 0u);
+ static constexpr auto TURING = std::make_tuple(7u, 5u);
+ static constexpr auto AMPERE = std::make_tuple(8u, 0u);
+ static constexpr auto HOPPER = std::make_tuple(9u, 0u);
+ static constexpr auto ADA_LOVELACE = std::make_tuple(8u, 9u);
+
+ /**
+ * Get the number of GPUs on the local machine
+ * @return std::nullopt if no device is available, otherwise >= 1
+ */
+ inline std::optional get_device_count() {
+ uint32_t numGpus = 0;
+ if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {
+ return numGpus;
+ }
+ return std::nullopt;
+ }
+
+ /**
+ * Store information about the version of the CUDA Compute Capabilities detected on the device
+ */
+ struct compute_capabilities_t {
+ int32_t major;
+ int32_t minor;
+
+ compute_capabilities_t(): compute_capabilities_t(0) {}
+ explicit compute_capabilities_t(size_t device_idx): major(-1), minor(-1) {
+ nvmlDevice_t device;
+ if (nvmlDeviceGetHandleByIndex_v2(device_idx, &device) == NVML_SUCCESS) {
+ nvmlDeviceGetCudaComputeCapability(device, &major, &minor);
+ }
+ };
+ compute_capabilities_t(int32_t major, int32_t minor): major(major), minor(minor) {}
+
+ /**
+ * Evaluate if the underlying capabilities is at least greater or equals to the provided 2-tuple (major, minor)
+ * @param sm Architecture version (major, minor)
+ * @return True if greater or equals to the underlying compute capabilities
+ */
+ [[nodiscard]] constexpr auto is_at_least(std::tuple sm) const -> decltype(auto) { return std::tie(major, minor) >= sm; }
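+ // e.g. compute_capabilities_t{8, 6}.is_at_least(AMPERE) evaluates to true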
+
+ /**
+ * Check if the capabilities match at least Volta architecture (sm_70)
+ * @return true if at least Volta (>= sm_70), false otherwise
+ */
+ [[nodiscard]] constexpr bool is_at_least_volta() const { return is_at_least(VOLTA); }
+
+ /**
+ * Check if the capabilities match at least Turing architecture (sm_75)
+ * @return true if at least Turing (>= sm_75), false otherwise
+ */
+ [[nodiscard]] constexpr bool is_at_least_turing() const { return is_at_least(TURING); }
+
+ /**
+ * Check if the capabilities match at least Ampere architecture (sm_80)
+ * @return true if at least Ampere (>= sm_80), false otherwise
+ */
+ [[nodiscard]] constexpr bool is_at_least_ampere() const { return is_at_least(AMPERE); }
+
+ /**
+ * Check if the capabilities match at least Ada Lovelace architecture (sm_89)
+ * @return true if at least Ada Lovelace (>= sm_89), false otherwise
+ */
+ [[nodiscard]] constexpr bool is_at_least_ada_lovelace() const { return is_at_least(ADA_LOVELACE); }
+
+ /**
+ * Check if the capabilities match at least Hopper architecture (sm_90)
+ * @return true if at least Hopper (>= sm_90), false otherwise
+ */
+ [[nodiscard]] constexpr bool is_at_least_hopper() const { return is_at_least(HOPPER); }
+ };
+}
+#endif
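Because hardware.hpp is header-only, callers can probe the GPU topology without pulling in the executor. The sketch below is an illustrative example, not part of the patch; it uses only the functions declared above plus nvmlInit_v2/nvmlShutdown from NVML, and shows how an Ampere-or-newer feature such as chunked context could be gated on the detected compute capability.

    #include <cstdio>
    #include <nvml.h>
    #include "backends/trtllm/csrc/hardware.hpp"

    int main() {
        namespace cuda = huggingface::tgi::hardware::cuda;
        if (nvmlInit_v2() != NVML_SUCCESS) return 1;

        if (const auto count = cuda::get_device_count(); count.has_value()) {
            std::printf("Detected %u GPU(s)\n", *count);

            // Query the compute capability of device 0 and gate features on it.
            const cuda::compute_capabilities_t capabilities(0);
            const bool enable_chunked_context = capabilities.is_at_least_ampere();
            std::printf("sm_%d%d -> chunked context %s\n",
                        capabilities.major, capabilities.minor,
                        enable_chunked_context ? "enabled" : "disabled");
        } else {
            std::printf("No NVIDIA GPU detected\n");
        }

        nvmlShutdown();
        return 0;
    }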
diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h
deleted file mode 100644
index d23f6288964..00000000000
--- a/backends/trtllm/include/backend.h
+++ /dev/null
@@ -1,144 +0,0 @@
-//
-// Created by Morgan Funtowicz on 6/30/24.
-//
-
-#ifndef TGI_TRTLLM_BACKEND_H
-#define TGI_TRTLLM_BACKEND_H
-
-#include
-#include
-#include
-#include
-#include
-
-#include
-
-#include
-#include
-#include
-
-using json = nlohmann::json;
-namespace tle = tensorrt_llm::executor;
-
-
-#define CAST_SIZETYPE(x) static_cast<tle::SizeType32>(x)
-
-namespace huggingface::tgi::backends {
- using RequestId = tle::IdType;
- using TokenId = tle::TokenIdType;
-
- const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false);
- constexpr auto FMT_NOT_ENOUGH_GPUS = FMT_STRING(
- "Not enough GPUs to allocate requested model (detected: {:d}, required: {:d})");
- constexpr auto FMT_EXECUTOR_STATS = FMT_STRING(
- "Submitting inference [{}] to the executor ({:d} already in-flight)");
- constexpr auto FMT_SAMPLING_CONFIG = FMT_STRING(
- "Sampling: topK={:d}, topP={:.1f}, temperature={:.1f}, repetition_penalty={:.1f}, frequency_penalty={:.1f}, seed={:d}");
-
- /**
- * Initialize all the components required by TRTLLM.
- * It is required to call this function before attempting to load any engine
- */
- void InitializeBackend();
-
- /**
- * Initialize logging mechanism
- */
- void InitializeLogging();
-
-
- /**
- *
- * @param config TensorRT-LLM configuration object
- * @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
- * @return
- */
- tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
-
- /**
- *
- * @param worldSize
- * @param workerPath
- * @return
- */
- tle::ParallelConfig GetParallelConfig(size_t worldSize, std::string workerPath) noexcept;
-
- /**
- * Get the sampling configuration from the parameters provided by TGI
- * @param topK
- * @param topP
- * @param temperature
- * @param repetition_penalty
- * @param frequency_penalty
- * @param seed
- * @return
- */
- tle::SamplingConfig GetSamplingConfig(
- uint32_t topK,
- float_t topP,
- float_t temperature,
- float_t repetition_penalty,
- float_t frequency_penalty,
- uint64_t seed
- ) noexcept;
-
- /**
- * Attempt to retrieve the
- * @param generationConfigPath
- * @return
- */
- std::optional<std::list<std::vector<TokenId>>>
- GetStopWordsFromConfig(const std::filesystem::path &generationConfigPath) noexcept;
-
- /**
- *
- */
- class TensorRtLlmBackend {
- private:
- const json config;
- tle::Executor executor;
-
- /** Frequently accessed variables cached here **/
- uint32_t maxNumTokens;
- std::list<std::vector<TokenId>> stopWords;
-
- public:
- explicit TensorRtLlmBackend(
- const std::filesystem::path &engineFolder,
- const std::filesystem::path &executorWorker
- );
-
- /**
- * Query the executor for the number of token available for pulling
- * @return
- */
- [[nodiscard]] size_t NumResponsesReady() const;
-
- /**
- * Submit a new generation task to the executor
- * @param tokens
- * @param topK
- * @param topP
- * @param temperature
- * @param repetitionPenalty
- * @param frequencyPenalty
- * @param seed
- * @return Request id related to this generation for reference
- */
- [[nodiscard]] RequestId Submit(
- const std::vector<TokenId> &tokens,
- uint32_t maxNewTokens,
- int32_t topK,
- float_t topP,
- float_t temperature,
- float_t repetitionPenalty,
- float_t frequencyPenalty,
- uint64_t seed
- );
-
- [[nodiscard]] std::vector<tle::Response> PullNewTokens();
- };
-}
-
-
-#endif //TGI_TRTLLM_BACKEND_H
diff --git a/backends/trtllm/include/ffi.h b/backends/trtllm/include/ffi.h
deleted file mode 100644
index 449bcd4d739..00000000000
--- a/backends/trtllm/include/ffi.h
+++ /dev/null
@@ -1,75 +0,0 @@
-//
-// Created by mfuntowicz on 7/11/24.
-//
-
-#ifndef TGI_TRTLLM_BACKEND_FFI_H
-#define TGI_TRTLLM_BACKEND_FFI_H
-
-#include
-#include
-#include
-#include "backend.h"
-
-namespace huggingface::tgi::backends {
- class TensorRtLlmBackendImpl;
-}
-
-// Template to support returning error from TllmException back to Rust in a Result<>
-#include
-
-namespace rust::behavior {
- template <typename Try, typename Fail>
- static void trycatch(Try &&func, Fail &&fail) noexcept try {
- func();
- } catch (tensorrt_llm::common::TllmException &e) {
- fail(e.what());
- }
-}
-
-#include "backends/trtllm/src/lib.rs.h"
-
-namespace huggingface::tgi::backends {
-
- class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
- public:
- /***
- *
- * @param engineFolder
- * @param executorWorker
- */
- TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
-
- /***
- *
- * @param tokens
- * @param maxNewTokens
- * @param topK
- * @param topP
- * @param temperature
- * @param repetition_penalty
- * @param frequency_penalty
- * @param seed
- * @return
- */
- [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
- uint64_t
- Submit(rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
- int32_t topK, float_t topP, float_t temperature,
- float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
-
- /***
- *
- * @return
- */
- std::unique_ptr<std::vector<GenerationStep>> PullTokens();
- };
-
- /***
- *
- * @param engineFolder
- * @return
- */
- std::unique_ptr<TensorRtLlmBackendImpl> CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker);
-}
-
-#endif //TGI_TRTLLM_BACKEND_FFI_H
diff --git a/backends/trtllm/include/hardware.h b/backends/trtllm/include/hardware.h
deleted file mode 100644
index 9633495f4fd..00000000000
--- a/backends/trtllm/include/hardware.h
+++ /dev/null
@@ -1,59 +0,0 @@
-//
-// Created by mfuntowicz on 7/23/24.
-//
-
-#ifndef TGI_TRTLLM_BACKEND_HARDWARE_H
-#define TGI_TRTLLM_BACKEND_HARDWARE_H
-
-#include
-#include
-#include
-#include
-#include
-
-namespace huggingface::hardware::cuda {
-
-#define AMPERE_SM_MAJOR 8
-#define HOPPER_SM_MAJOR 9
-
- /**
- * Store information about the version of the CUDA Compute Capabilities detected on the device
- */
- struct CudaComputeCapabilities {
- int32_t major;
- int32_t minor;
-
- [[nodiscard]] constexpr bool IsPostAmpere() const { return major >= AMPERE_SM_MAJOR; }
-
- [[nodiscard]] constexpr bool IsPostHopper() const { return major >= HOPPER_SM_MAJOR; }
- };
-
- CudaComputeCapabilities GetCudaComputeCapabilities() {
- // Get the compute capabilities of the current hardware
- nvmlDevice_t device;
- CudaComputeCapabilities capabilities{0, 0};
- if (nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) {
- SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0");
- if (nvmlDeviceGetCudaComputeCapability(device, &capabilities.major, &capabilities.minor) == NVML_SUCCESS) {
- SPDLOG_INFO("Detected sm_{:d}{:d} compute capabilities", capabilities.major, capabilities.minor);
- }
- }
-
- return capabilities;
- }
-
- /**
- * Return the number of GPU detected. If no GPU is detected, return size_t::max()
- * @return
- */
- std::optional<size_t> GetNumDevices() {
- uint32_t numGpus = 0;
- if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {
- return std::optional<size_t>(numGpus);
- } else {
- return std::nullopt;
- }
- }
-}
-
-#endif //TGI_TRTLLM_BACKEND_HARDWARE_H
diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp
deleted file mode 100644
index 4dd41de0072..00000000000
--- a/backends/trtllm/lib/backend.cpp
+++ /dev/null
@@ -1,203 +0,0 @@
-#include
-#include
-
-#include
-#include
-#include
-
-#include "backend.h"
-#include "hardware.h"
-
-
-void huggingface::tgi::backends::InitializeLogging() {
-#ifdef NDEBUG
- if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) {
- std::string log_level(TRTLLM_LOG_LEVEL_CSTR);
- std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {
- return std::tolower(c);
- });
-
- if (log_level == "debug")
- spdlog::set_level(spdlog::level::debug);
- else
- spdlog::set_level(spdlog::level::info);
- }
-#else
- spdlog::set_level(spdlog::level::debug);
-#endif
-}
-
-void huggingface::tgi::backends::InitializeBackend() {
- SPDLOG_INFO("Initializing Backend...");
- nvmlInit_v2();
- initTrtLlmPlugins();
-
- InitializeLogging();
-
- SPDLOG_INFO("Backend Executor Version: {}", tle::version());
- const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
- if (numGpus.has_value()) {
- SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
- } else {
- SPDLOG_WARN("Failed to detected Nvidia GPU(s) on the system");
- }
-}
-
-[[nodiscard]]
-tle::ParallelConfig
-huggingface::tgi::backends::GetParallelConfig(const size_t worldSize, const std::string workerPath) noexcept {
- auto mode = tle::CommunicationMode::kLEADER;
- std::optional<tle::OrchestratorConfig> orchestratorConfig = std::nullopt;
-
- if (worldSize > 1) {
- SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
- mode = tle::CommunicationMode::kORCHESTRATOR;
- orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, workerPath, nullptr, true);
- } else {
- SPDLOG_INFO("Detected single engine deployment, using leader mode");
- }
-
- return tle::ParallelConfig(tle::CommunicationType::kMPI, mode, std::nullopt, std::nullopt, orchestratorConfig);
-}
-
-[[nodiscard]]
-tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
- tle::ExecutorConfig execConfig(/* maxBeamWidth = */ 1);
-
- // Retrieve the compute capabilities to enable some options at runtime
- const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
-
- // Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
- const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get();
- execConfig.setParallelConfig(GetParallelConfig(worldSize, workerPath));
-
- // Define some configuration variables
- execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
- execConfig.setEnableChunkedContext(computeCapabilities.IsPostAmpere());
- execConfig.setSchedulerConfig(tle::SchedulerConfig(tle::CapacitySchedulerPolicy::kMAX_UTILIZATION));
- return execConfig;
-}
-
-tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
- const uint32_t topK,
- const float_t topP,
- const float_t temperature,
- const float_t repetition_penalty,
- const float_t frequency_penalty,
- const uint64_t seed) noexcept {
-
- return tle::SamplingConfig(
- 1, // TGI only use a single beam
- topK,
- topP,
- std::nullopt,
- std::nullopt,
- std::nullopt,
- seed,
- temperature,
- temperature,
- std::nullopt,
- repetition_penalty,
- std::nullopt,
- frequency_penalty
- );
-}
-
-std::optional<std::list<std::vector<huggingface::tgi::backends::TokenId>>>
-huggingface::tgi::backends::GetStopWordsFromConfig(
- const std::filesystem::path &generationConfigPath) noexcept {
- if (exists(generationConfigPath)) {
- const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
- if (const auto eosTokenIds = generationConfig["/eos_token_id"_json_pointer]; eosTokenIds.is_array()) {
- SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
- std::list<std::vector<TokenId>> stopWords(eosTokenIds.size());
-
- const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
- return {tokenIdObj.template get<TokenId>()};
- };
-
- std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token);
- return stopWords;
- } else {
- SPDLOG_INFO("Invalid EOS tokens entry found (not an array)");
- }
- } else {
- SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
- }
-
- return std::nullopt;
-}
-
-huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
- const std::filesystem::path &enginesFolder,
- const std::filesystem::path &executorWorker
-) :
- config(json::parse(std::ifstream(enginesFolder / "config.json"))),
- executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
- GetExecutorConfig(config, executorWorker.string())) {
-
- SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get());
-
- // Ensure we have enough GPUs on the system
- const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get();
- const auto numGpus = huggingface::hardware::cuda::GetNumDevices().value_or(0);
- if (numGpus < worldSize) {
- SPDLOG_CRITICAL(FMT_NOT_ENOUGH_GPUS, numGpus, worldSize);
- // todo : raise exception to catch on rust side
- }
-
- // Cache variables
- maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get();
-
- // Attempt to discover stopWords from the generation_config.json
- const auto generationConfigPath = enginesFolder / "generation_config.json";
- stopWords = GetStopWordsFromConfig(generationConfigPath).value_or(std::list>());
-}
-
-[[nodiscard("Returned number of requests needs to be consumed")]]
-size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
-#ifdef NDEBUG
- return executor.getNumResponsesReady();
-#else
- const auto numResponses = executor.getNumResponsesReady();
- if (numResponses > 0) SPDLOG_INFO(FMT_STRING("Num responses ready: {:d}"), numResponses);
- return numResponses;
-#endif
-}
-
-[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
-tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
- const std::vector &tokens,
- const uint32_t maxNewTokens,
- const int32_t topK,
- const float_t topP,
- const float_t temperature,
- const float_t repetitionPenalty,
- const float_t frequencyPenalty,
- const uint64_t seed
-) {
- const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
-#ifndef NDEBUG
- {
- const auto &iterations = executor.getLatestIterationStats();
- const auto &lastIteration = iterations.front();
-
- SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
- SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
- SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
- }
-#endif
-
- const auto sampling = GetSamplingConfig(topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
-
- // Build the request
- auto request = tle::Request{tokens, CAST_SIZETYPE(maxNewTokensChecked), true, sampling, OUTPUT_CONFIG};
- request.setStopWords(stopWords);
-
- // Submit to the executor for batching
- return executor.enqueueRequest(request);
-}
-
-std::vector huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
- return executor.awaitResponses();
-}
diff --git a/backends/trtllm/scripts/install_tensorrt.sh b/backends/trtllm/scripts/install_tensorrt.sh
index 4c2dc26b6bf..e09db6b1296 100755
--- a/backends/trtllm/scripts/install_tensorrt.sh
+++ b/backends/trtllm/scripts/install_tensorrt.sh
@@ -2,13 +2,13 @@
set -ex
-TRT_VER_BASE="10.4.0"
-TRT_VER_FULL="${TRT_VER_BASE}.26"
-CUDA_VER="12.6"
-CUDNN_VER="9.5.0.50-1"
-NCCL_VER="2.22.3-1+cuda12.6"
-CUBLAS_VER="12.6.3.3-1"
-NVRTC_VER="12.6.77-1"
+TRT_VER_BASE="10.8.0"
+TRT_VER_FULL="${TRT_VER_BASE}.43"
+CUDA_VER="12.8"
+CUDNN_VER="9.7.0.66-1"
+NCCL_VER="2.25.1-1+cuda${CUDA_VER}"
+CUBLAS_VER="${CUDA_VER}.3.14-1"
+NVRTC_VER="${CUDA_VER}.61-1"
for i in "$@"; do
case $i in
@@ -73,7 +73,7 @@ install_centos_requirements() {
install_tensorrt() {
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
- TRT_CUDA_VERSION="12.6"
+ TRT_CUDA_VERSION="12.8"
if [ -z "$RELEASE_URL_TRT" ];then
ARCH=${TRT_TARGETARCH}
diff --git a/backends/trtllm/scripts/setup_sccache.py b/backends/trtllm/scripts/setup_sccache.py
new file mode 100644
index 00000000000..65fdee23537
--- /dev/null
+++ b/backends/trtllm/scripts/setup_sccache.py
@@ -0,0 +1,51 @@
+from argparse import ArgumentParser
+
+AWS_S3_CACHING_VARIABLES = {
+ "AWS_ACCESS_KEY_ID": "aws_access_key_id",
+ "AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
+ "AWS_SESSION_TOKEN": "aws_session_token",
+ "SCCACHE_REGION": "s3_region",
+ "SCCACHE_BUCKET": "s3_bucket_name",
+}
+
+ALL_CACHING_STORAGE_VARIABLES = {"AWS_S3_CACHING_VARIABLES"}
+
+
+def setup_sccache_locally():
+ from os import environ
+
+ print("Setting up Local Caching Layer")
+ for target in ALL_CACHING_STORAGE_VARIABLES:
+ for envvar in globals()[target].keys():
+ if envvar in environ:
+ print(f"Deleted {envvar} from environment variables")
+ del environ[envvar]
+
+
+def setup_sccache_for_s3():
+ from os import environ
+
+ print("Setting up AWS S3 Caching Layer")
+ for envvar in AWS_S3_CACHING_VARIABLES.keys():
+ if envvar not in environ or not environ[envvar] or len(environ[envvar]) == 0:
+ print(f"Missing definition for environment variable {envvar}")
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser("TensorRT-LLM Build Caching Setup")
+
+ parser.add_argument(
+ "--is-gha-build",
+ type=str,
+ default="FALSE",
+ help="Indicate if the build is from Github Actions",
+ )
+
+ # Parse args
+ args = parser.parse_args()
+ args.is_gha_build = args.is_gha_build.lower() in {"on", "true", "1"}
+
+ if args.is_gha_build:
+ setup_sccache_for_s3()
+ else:
+ setup_sccache_locally()
diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp
deleted file mode 100644
index 0a92c050f65..00000000000
--- a/backends/trtllm/src/ffi.cpp
+++ /dev/null
@@ -1,89 +0,0 @@
-//
-// Created by mfuntowicz on 6/30/24.
-//
-#pragma once
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include
-#include "backends/trtllm/include/ffi.h"
-
-
-huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
- const std::string_view &engineFolder,
- const std::string_view &executorWorker
-) : TensorRtLlmBackend(engineFolder, executorWorker) {}
-
-
-uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
- rust::Slice tokens,
- uint32_t maxNewTokens,
- int32_t topK,
- float_t topP,
- float_t temperature,
- float_t repetition_penalty,
- float_t frequency_penalty,
- uint64_t seed) {
-
- // This will copy all the items from the initial slice
- std::vector<int32_t> tokens_(tokens.begin(), tokens.end());
- return TensorRtLlmBackend::Submit(
- std::move(tokens_), maxNewTokens, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
-}
-
-std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
-huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
- const auto responses = TensorRtLlmBackend::PullNewTokens();
-
- auto steps = std::make_unique<std::vector<GenerationStep>>();
- steps->reserve(responses.size());
-
-#ifndef NDEBUG
- SPDLOG_DEBUG(FMT_STRING("Pulled out {:d} new tokens"), responses->size());
-#endif
-
- // Transform tle::Response to GenerationStep
- std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) {
- const auto reqId = r.getRequestId();
- if (!r.hasError()) {
- const auto result = r.getResult();
- return GenerationStep{
- reqId,
- static_cast<uint32_t>(result.outputTokenIds[0][0]),
- result.logProbs.value()[0][0],
- result.isFinal,
- false,
- std::string()
- };
- } else {
- return GenerationStep{
- reqId,
- 0,
- 0.0,
- true,
- true,
- std::move(r.getErrorMsg())
- };
- }
- });
-
- return steps;
-}
-
-std::unique_ptr
-huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
- SPDLOG_INFO("Creating TensorRT-LLM Backend");
- // Unconditionally call this to initialize and discover TRTLLM plugins
- InitializeBackend();
-
- const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
- const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
- return std::make_unique<TensorRtLlmBackendImpl>(std::move(enginePath), std::move(executorPath));
-}
diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs
index edd8caff154..085072561f1 100644
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -4,24 +4,47 @@ pub mod errors;
mod looper;
mod utils;
-#[cxx::bridge(namespace = "huggingface::tgi::backends")]
+#[cxx::bridge(namespace = "huggingface::tgi::backends::trtllm")]
mod ffi {
+ #[cxx_name = "finish_reason_t"]
+ #[derive(Debug, Clone, Copy)]
+ pub enum FinishReason {
+ /// The request is not finished.
+ #[cxx_name = "kNOT_FINISHED"]
+ NotFinished = 0u8,
+
+ /// The request finished because the end id was generated.
+ #[cxx_name = "kEND_ID"]
+ EndTokenId = 1u8,
+
+ /// The request finished because a stop word was generated.
+ #[cxx_name = "kSTOP_WORDS"]
+ StopWords = 2u8,
+
+ /// The request finished because the maximum number of tokens was reached.
+ #[cxx_name = "kLENGTH"]
+ MaxLength = 3u8,
+ }
+
/// Struct used as shared type between rust and C++ to represent the result
/// of a single decoding iteration
+ #[cxx_name = "generation_step_t"]
#[derive(Debug, Clone)]
pub struct GenerationStep {
request_id: u64,
token_id: u32,
log_prob: f32,
is_final: bool,
+ finish_reason: FinishReason,
has_error: bool,
error_msg: String,
}
unsafe extern "C++" {
- include!("backends/trtllm/src/ffi.cpp");
+ include!("backends/trtllm/csrc/ffi.hpp");
/// Represent an instance of the underlying TensorRT-LLM backend
+ #[cxx_name = "tensorrt_llm_backend_t"]
type TensorRtLlmBackendImpl;
/// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend
@@ -38,21 +61,18 @@ mod ffi {
/// ```
///
/// ```
- #[rust_name = "create_tensorrt_llm_backend"]
- fn CreateTensorRtLlmBackend(
+ fn create_backend_from_engine_folder(
engine_folder: &str,
executor_worker: &str,
) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;
- #[rust_name = "num_responses_ready"]
- fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
+ fn num_tokens_ready(self: &TensorRtLlmBackendImpl) -> usize;
- #[rust_name = "submit"]
- fn Submit(
+ fn submit(
self: Pin<&mut TensorRtLlmBackendImpl>,
tokens: &[u32],
max_new_tokens: u32,
- top_k: i32,
+ top_k: u32,
top_p: f32,
temperature: f32,
repetition_penalty: f32,
@@ -60,9 +80,24 @@ mod ffi {
seed: u64,
) -> Result<u64>;
- #[rust_name = "pull_tokens"]
- fn PullTokens(
+ fn pull_tokens(
self: Pin<&mut TensorRtLlmBackendImpl>,
) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
+
+ fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64);
+ }
+}
+
+use ffi::FinishReason;
+use text_generation_router::FinishReason as InferFinishReason;
+
+impl From<FinishReason> for InferFinishReason {
+ fn from(reason: FinishReason) -> Self {
+ match reason {
+ FinishReason::StopWords => InferFinishReason::StopSequence,
+ FinishReason::MaxLength => InferFinishReason::Length,
+ FinishReason::EndTokenId => InferFinishReason::EndOfSequenceToken,
+ _ => panic!("Cannot convert {reason:?} to text_generation_router::FinishReason"),
+ }
}
}
diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
index e26155c163c..5fed954fff7 100644
--- a/backends/trtllm/src/looper.rs
+++ b/backends/trtllm/src/looper.rs
@@ -1,14 +1,13 @@
-use std::hint;
-use std::ops::Deref;
-use std::path::Path;
-
use async_trait::async_trait;
use cxx::UniquePtr;
use hashbrown::HashMap;
+use std::hint;
+use std::ops::Deref;
+use std::path::Path;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::TryAcquireError;
-use tokio::task::{spawn_blocking, JoinHandle};
+use tokio::task::spawn_blocking;
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, warn};
@@ -19,10 +18,12 @@ use text_generation_router::validation::ValidationError::{
EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
};
use text_generation_router::validation::{Chunk, ValidGenerateRequest};
-use text_generation_router::{FinishReason, Token};
+use text_generation_router::Token;
use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
+use crate::ffi::{
+ create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl,
+};
use crate::utils::first_line;
type InferResult<T> = Result<T, InferError>;
@@ -30,9 +31,10 @@ type InferResult<T> = Result<T, InferError>;
/// Wrap the requests along with the channel used to stream back to the client the decoded tokens
struct GenerationContext {
request: ValidGenerateRequest,
+ streamer: UnboundedSender<InferResult<InferStreamResponse>>,
+ tokens: Vec<u32>,
start: Option<Instant>,
queued: Instant,
- streamer: UnboundedSender<InferResult<InferStreamResponse>>,
}
#[derive(Debug, Copy, Clone)]
@@ -40,6 +42,7 @@ struct DecodedToken {
id: u32,
log_prob: f32,
is_final: bool,
+ finish_reason: FinishReason,
}
impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
@@ -51,6 +54,7 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
id: step.token_id,
log_prob: step.log_prob,
is_final: step.is_final,
+ finish_reason: step.finish_reason,
})
} else {
Err(GenerationError(step.error_msg.clone()))
@@ -58,31 +62,22 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
}
}
-/// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
-struct DecodedTokenContext {
- token: DecodedToken,
- start: Option<Instant>,
- queued: Instant,
- channel: UnboundedSender<InferResult<InferStreamResponse>>,
-}
-
fn executor_status_looper(
- mut backend: UniquePtr<TensorRtLlmBackendImpl>,
max_inflight_requests: usize,
- mut waiting_requests: UnboundedReceiver<GenerationContext>,
- post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
+ tokenizer: Tokenizer,
+ mut backend: UniquePtr<TensorRtLlmBackendImpl>,
+ mut backlog: UnboundedReceiver<GenerationContext>,
) {
// Track the tuple (request_id, stream) for each request
let mut in_flights =
HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);
- // TODO: Does it need a spin-loop?
'scheduler: loop {
// Is there any request pending to be scheduled?
- let awaiting_requests = waiting_requests.len();
+ let awaiting_requests = backlog.len();
for _ in 0..awaiting_requests {
// Retrieve all the requests
- if let Some(mut ctx) = waiting_requests.blocking_recv() {
+ if let Some(ctx) = backlog.blocking_recv() {
// Submit all the request to the executor and move the context to the in-flight tracker
let request = &ctx.request;
let generation_params = &request.parameters;
@@ -93,7 +88,7 @@ fn executor_status_looper(
match backend.pin_mut().submit(
&input_ids.unwrap(), // This is checked beforehand in validate()
stopping_params.max_new_tokens,
- generation_params.top_k as i32,
+ generation_params.top_k,
generation_params.top_p,
generation_params.temperature,
generation_params.repetition_penalty,
@@ -103,7 +98,6 @@ fn executor_status_looper(
Ok(request_id) => {
// Insert the context linked to the generated request id in the tracker
debug!("[in-flight] Added {}", request_id);
- ctx.start = Some(Instant::now());
in_flights.insert(request_id, ctx);
}
Err(e) => {
@@ -117,29 +111,43 @@ fn executor_status_looper(
}
}
};
+ } else {
+ break 'scheduler;
}
}
- if backend.num_responses_ready() > 0 {
- match backend.pin_mut().pull_tokens() {
+ if backend.num_tokens_ready() > 0 {
+ let mut backend = backend.pin_mut();
+ match backend.as_mut().pull_tokens() {
Ok(responses) => {
// Iterate through all the decoded token
for step in responses.deref() {
- if let Some(ctx) = in_flights.get(&step.request_id) {
- // Remove from tracked requests
- let parcel =
- DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
- token: dt,
- start: ctx.start,
- queued: ctx.queued,
- channel: ctx.streamer.clone(),
- });
-
- // Submit the work to p:the post_processor
- let posted = post_processor_sender.send((step.request_id, parcel));
-
- if posted.is_err() || step.is_final {
- debug!("Removing {}", step.request_id);
+ if let Some(ctx) = in_flights.get_mut(&step.request_id) {
+ // Update the starting timestamp if not set
+ // This value might not be the actual real starting time of the request
+ // on the executor side - Need to expose more info from the executor to
+ // retrieve this value
+ // TODO : Expose actual real starting time for a request on FFI layer
+ if ctx.start.is_none() {
+ ctx.start = Some(Instant::now());
+ }
+
+ // Try to map the generation step to a DecodedToken
+ let response = match DecodedToken::try_from(step) {
+ Ok(decoded_token) => {
+ post_process_decoded_token(&tokenizer, ctx, decoded_token)
+ }
+ Err(err) => Err(err),
+ };
+
+ // Attempt to send back the response to the client
+ if let Err(_) = ctx.streamer.send(response) {
+ // Client has dropped, remove from tracked requests
+ debug!(
+ "Client dropped - removing request {} from tracked requests",
+ step.request_id
+ );
+ backend.as_mut().cancel(step.request_id);
let _ = in_flights.remove(&step.request_id);
}
} else {
@@ -159,80 +167,51 @@ fn executor_status_looper(
}
}
-fn post_processor_looper<const MAX_NUM_TOKENS: usize>(
- tokenizer: Tokenizer,
- max_inflight_requests: usize,
- mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
-) {
- let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(max_inflight_requests * 2);
-
- 'post_processor: loop {
- if decoded_tokens.is_closed() {
- warn!("Post processor IPC is closed, loop will exit now.");
- break 'post_processor;
- }
-
- if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
- match decoded {
- Ok(ctx) => {
- states
- .entry(request_id)
- .and_modify(|s| s.push(*&ctx.token.id))
- .or_insert_with(|| {
- let mut state = Vec::with_capacity(MAX_NUM_TOKENS);
- state.push(*&ctx.token.id);
- state
- });
-
- let out = match tokenizer.decode(&[ctx.token.id], false) {
- Ok(text) => {
- let is_special =
- tokenizer.get_added_vocabulary().is_special_token(&text);
- let token = Token {
- id: ctx.token.id,
- text,
- logprob: ctx.token.log_prob,
- special: is_special,
- };
-
- let out = if !ctx.token.is_final {
- InferStreamResponse::Intermediate {
- token,
- top_tokens: vec![],
- }
- } else {
- let tokens = states.remove(&request_id).unwrap();
- let text = tokenizer.decode(&tokens, true);
- let generated_text = GeneratedText {
- text: text.unwrap(),
- generated_tokens: tokens.len() as u32,
- finish_reason: FinishReason::EndOfSequenceToken,
- seed: None,
- };
-
- InferStreamResponse::End {
- token,
- top_tokens: vec![],
- generated_text,
- start: ctx.start.unwrap(),
- queued: ctx.queued,
- }
- };
-
- Ok(out)
- }
- Err(err) => Err(GenerationError(err.to_string())),
- };
-
- if let Err(_) = ctx.channel.send(out) {
- warn!("Failed to send decoded token back to the user")
- }
+fn post_process_decoded_token(
+ tokenizer: &Tokenizer,
+ ctx: &mut GenerationContext,
+ decoded_token: DecodedToken,
+) -> InferResult<InferStreamResponse> {
+ match tokenizer.decode(&[decoded_token.id], false) {
+ Ok(text) => {
+ let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
+ let token = Token {
+ id: decoded_token.id,
+ text,
+ logprob: decoded_token.log_prob,
+ special: is_special,
+ };
+
+ // Append the token to the tracked generated tokens
+ ctx.tokens.push(token.id);
+
+ // Map the correct response depending on the step is final or not
+ let out = if !decoded_token.is_final {
+ InferStreamResponse::Intermediate {
+ token,
+ top_tokens: vec![],
}
- Err(_err) => {
- todo!("what do we do?")
+ } else {
+ let text = tokenizer.decode(&ctx.tokens, true);
+ let generated_text = GeneratedText {
+ text: text.unwrap(),
+ generated_tokens: ctx.tokens.len() as u32,
+ finish_reason: decoded_token.finish_reason.into(),
+ seed: None,
+ };
+
+ InferStreamResponse::End {
+ token,
+ top_tokens: vec![],
+ generated_text,
+ start: ctx.start.unwrap(),
+ queued: ctx.queued,
}
- }
+ };
+
+ Ok(out)
}
+ Err(err) => Err(GenerationError(err.to_string())),
}
}
@@ -277,11 +256,7 @@ fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
unsafe impl Send for TensorRtLlmBackendImpl {}
-pub struct TensorRtLlmBackendV2 {
- executor_looper: JoinHandle<()>,
- post_processor_looper: JoinHandle<()>,
- executor: UnboundedSender<GenerationContext>,
-}
+pub struct TensorRtLlmBackendV2(UnboundedSender<GenerationContext>);
impl TensorRtLlmBackendV2 {
pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
@@ -295,32 +270,17 @@ impl TensorRtLlmBackendV2 {
// Allocate the IPC layer to communicate with the backend
let (executor_sender, executor_receiver) = unbounded_channel();
- let (post_processor_sender, post_processor_receiver) = unbounded_channel();
// Create the FFI backend
- let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
+ let backend = create_backend_from_engine_folder(&engine_folder, &executor_worker_path)
.map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
// Executor looper is responsible for scheduling and pulling requests state at regular interval
- let executor_looper = spawn_blocking(move || {
- executor_status_looper(
- backend,
- max_inflight_requests,
- executor_receiver,
- post_processor_sender,
- )
+ spawn_blocking(move || {
+ executor_status_looper(max_inflight_requests, tokenizer, backend, executor_receiver)
});
- // Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
- let post_processor_looper = spawn_blocking(move || {
- post_processor_looper::<256>(tokenizer, max_inflight_requests, post_processor_receiver)
- });
-
- Ok(TensorRtLlmBackendV2 {
- executor_looper,
- post_processor_looper,
- executor: executor_sender,
- })
+ Ok(TensorRtLlmBackendV2(executor_sender))
}
fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
@@ -354,20 +314,21 @@ impl TensorRtLlmBackendV2 {
impl Backend for TensorRtLlmBackendV2 {
fn schedule(
&self,
- inner: ValidGenerateRequest,
+ request: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<InferResult<InferStreamResponse>>, InferError> {
- Self::validate(&inner)?;
+ Self::validate(&request)?;
// Open-up the stream to send tokens
let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
// Send the context to the executor for scheduling
let queued = Instant::now();
- match self.executor.send(GenerationContext {
- request: inner,
+ match self.0.send(GenerationContext {
+ request,
+ streamer,
+ tokens: Vec::with_capacity(256),
start: None,
queued,
- streamer,
}) {
Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
Err(_) => Err(GenerationError(
@@ -377,6 +338,10 @@ impl Backend for TensorRtLlmBackendV2 {
}
async fn health(&self, _: bool) -> bool {
- !self.executor_looper.is_finished() & !self.post_processor_looper.is_finished()
+ true
+ }
+
+ fn name(&self) -> &'static str {
+ "TensorRT-LLM"
}
}
diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs
index 8ab8c533cfb..543f8e6e352 100644
--- a/backends/trtllm/src/main.rs
+++ b/backends/trtllm/src/main.rs
@@ -3,14 +3,15 @@ use std::path::{Path, PathBuf};
use clap::Parser;
use hf_hub::api::tokio::{Api, ApiBuilder};
use hf_hub::{Cache, Repo, RepoType};
-use tokenizers::Tokenizer;
use tracing::info;
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackendV2;
-use text_generation_router::server::get_base_tokenizer;
+use text_generation_router::server::{
+ get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer,
+};
use text_generation_router::usage_stats::UsageStatsLevel;
-use text_generation_router::{server, HubTokenizerConfig};
+use text_generation_router::{server, Tokenizer};
/// App Configuration
#[derive(Parser, Debug)]
@@ -36,6 +37,8 @@ struct Args {
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
+ #[clap(default_value = "9000", long, short, env)]
+ prometheus_port: u16,
#[clap(long, env, required = true)]
tokenizer_name: String,
#[clap(long, env)]
@@ -61,16 +64,12 @@ struct Args {
#[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
executor_worker: PathBuf,
#[clap(default_value = "on", long, env)]
- usage_stats: usage_stats::UsageStatsLevel,
+ usage_stats: UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}
-async fn get_tokenizer(
- tokenizer_name: &str,
- tokenizer_config_path: Option<&str>,
- revision: Option<&str>,
-) -> Option<Tokenizer> {
+async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
@@ -89,6 +88,10 @@ async fn get_tokenizer(
builder = builder.with_cache_dir(cache_dir.into());
}
+ if let Ok(origin) = std::env::var("HF_HUB_USER_AGENT_ORIGIN") {
+ builder = builder.with_user_agent("origin", origin.as_str());
+ }
+
builder
};
@@ -126,18 +129,18 @@ async fn get_tokenizer(
// Load tokenizer and model info
let (
- tokenizer_filename,
- _config_filename,
- tokenizer_config_filename,
+ config_filename,
+ _tokenizer_config_filename,
_preprocessor_config_filename,
_processor_config_filename,
+ _model_info,
) = match api {
Type::None => (
- Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
+ None,
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
@@ -146,21 +149,23 @@ async fn get_tokenizer(
revision.unwrap_or_else(|| "main").to_string(),
));
- let tokenizer_filename = match api_repo.get("tokenizer.json").await {
- Ok(tokenizer_filename) => Some(tokenizer_filename),
- Err(_) => get_base_tokenizer(&api, &api_repo).await,
- };
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
+ let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
+ Some(model_info)
+ } else {
+ tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+ None
+ };
(
- tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
+ model_info,
)
}
Type::Cache(cache) => {
@@ -170,24 +175,42 @@ async fn get_tokenizer(
revision.clone().unwrap_or_else(|| "main").to_string(),
));
(
- repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
+ None,
)
}
};
- // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
- let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
- {
- HubTokenizerConfig::from_file(filename)
- } else {
- tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
+ let tokenizer: Tokenizer = {
+ use pyo3::prelude::*;
+ pyo3::Python::with_gil(|py| -> PyResult<()> {
+ py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), false)?;
+ Ok(())
+ })
+ .inspect_err(|err| {
+ tracing::error!("Failed to import python tokenizer {err}");
+ })
+ .or_else(|err| {
+ let out = legacy_tokenizer_handle(config_filename.as_ref());
+ out.ok_or(err)
+ })
+ .expect("We cannot load a tokenizer");
+ let filename = "out/tokenizer.json";
+ if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
+ Tokenizer::Rust(tok)
+ } else {
+ Tokenizer::Python {
+ tokenizer_name: tokenizer_name.to_string(),
+ revision: revision.map(|revision| revision.to_string()),
+ trust_remote_code: false,
+ }
+ }
};
- tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
+ Some(tokenizer)
}
#[tokio::main]
@@ -206,6 +229,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
max_batch_total_tokens,
hostname,
port,
+ prometheus_port,
tokenizer_name,
tokenizer_config_path,
revision,
@@ -258,50 +282,53 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
}
// Create the backend
- let tokenizer = get_tokenizer(
- &tokenizer_name,
- tokenizer_config_path.as_deref(),
- revision.as_deref(),
- )
- .await
- .expect("Failed to retrieve tokenizer implementation");
-
- info!("Successfully retrieved tokenizer {}", &tokenizer_name);
- let backend = TensorRtLlmBackendV2::new(
- tokenizer,
- model_id,
- executor_worker,
- max_concurrent_requests,
- )?;
+ match get_tokenizer(&tokenizer_name, revision.as_deref())
+ .await
+ .expect("Failed to retrieve tokenizer implementation")
+ {
+ Tokenizer::Python { .. } => Err(TensorRtLlmBackendError::Tokenizer(
+ "Failed to retrieve Rust based tokenizer".to_string(),
+ )),
+ Tokenizer::Rust(tokenizer) => {
+ info!("Successfully retrieved tokenizer {}", &tokenizer_name);
+ let backend = TensorRtLlmBackendV2::new(
+ tokenizer,
+ model_id,
+ executor_worker,
+ max_concurrent_requests,
+ )?;
- info!("Successfully created backend");
+ info!("Successfully created backend");
- // Run server
- server::run(
- backend,
- max_concurrent_requests,
- max_best_of,
- max_stop_sequences,
- max_top_n_tokens,
- max_input_tokens,
- max_total_tokens,
- validation_workers,
- auth_token,
- tokenizer_name,
- tokenizer_config_path,
- revision,
- false,
- hostname,
- port,
- cors_allow_origin,
- false,
- None,
- None,
- true,
- max_client_batch_size,
- usage_stats,
- payload_limit,
- )
- .await?;
- Ok(())
+ // Run server
+ server::run(
+ backend,
+ max_concurrent_requests,
+ max_best_of,
+ max_stop_sequences,
+ max_top_n_tokens,
+ max_input_tokens,
+ max_total_tokens,
+ validation_workers,
+ auth_token,
+ tokenizer_name,
+ tokenizer_config_path,
+ revision,
+ false,
+ hostname,
+ port,
+ cors_allow_origin,
+ false,
+ None,
+ None,
+ true,
+ max_client_batch_size,
+ usage_stats,
+ payload_limit,
+ prometheus_port,
+ )
+ .await?;
+ Ok(())
+ }
+ }
}
diff --git a/backends/trtllm/tests/infer_test.cpp b/backends/trtllm/tests/infer_test.cpp
deleted file mode 100644
index 8520065a759..00000000000
--- a/backends/trtllm/tests/infer_test.cpp
+++ /dev/null
@@ -1,14 +0,0 @@
-//
-// Created by mfuntowicz on 7/2/24.
-//
-#include
-#include
-#include "../include/backend.h"
-
-TEST_CASE("Load TRTLLM Engine on the TGI Backend", "[trtllm][engine][load]") {
- const auto engines = std::filesystem::path("/home/mfuntowicz/.cache/huggingface/assets/trtllm/0.11.0.dev2024062500/meta-llama--Meta-Llama-3-8B-Instruct/4090/engines/");
- const auto executor = std::filesystem::path("/home/mfuntowicz/Workspace/text-generation-inference/backends/trtllm/cmake-build-debug/cmake-build-debug/_deps/trtllm-src/cpp/tensorrt_llm/executor_worker/executorWorker");
-
- spdlog::info("Loading config from: {}", absolute(engines).string());
- huggingface::tgi::backends::TensorRtLlmBackend backend(engines, executor);
-}
diff --git a/backends/trtllm/tests/test_backend.cpp b/backends/trtllm/tests/test_backend.cpp
new file mode 100644
index 00000000000..f44cc03f9b4
--- /dev/null
+++ b/backends/trtllm/tests/test_backend.cpp
@@ -0,0 +1,154 @@
+//
+// Created by mfuntowicz on 12/3/24.
+//
+
+#include
+#include
+#include
+
+#include "backend.hpp"
+
+using namespace huggingface::tgi::backends::trtllm;
+
+TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
+{
+ const json config_j = {{"temperature", 0.6},
+ {"top_p", 0.95},
+ {"eos_token_id", {1, 2, 3}}};
+ const auto generation_config = generation_config_t(config_j);
+
+ REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(0.6, 1e-6));
+ REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(0.95, 1e-6));
+
+ // Stop words
+ REQUIRE_FALSE(generation_config.stop_words.empty());
+ REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
+
+ for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},
+ {2},
+ {3}})) {
+ // Currently we do not support multi-tokens stop words
+ REQUIRE(lhs.size() == 1);
+ REQUIRE(rhs.size() == 1);
+ REQUIRE_THAT(lhs, Catch::Matchers::UnorderedEquals(rhs));
+ }
+}
+
+TEST_CASE("parse generation_config.json default", "[generation_config_t]")
+{
+ const json config_j = {{"eos_token_id", {1, 2, 3}}};
+ const auto generation_config = generation_config_t(config_j);
+
+ REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
+ REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));
+
+ REQUIRE_FALSE(generation_config.stop_words.empty());
+ REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
+
+ for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},
+ {2},
+ {3}})) {
+ // Currently we do not support multi-tokens stop words
+ REQUIRE(lhs.size() == 1);
+ REQUIRE(rhs.size() == 1);
+ REQUIRE_THAT(lhs, Catch::Matchers::UnorderedEquals(rhs));
+ }
+}
+
+TEST_CASE("parse generation_config.json empty", "[generation_config_t]")
+{
+ const json config_j = {{"eos_token_id", {}}};
+ const auto generation_config = generation_config_t(config_j);
+
+ REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
+ REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));
+
+ REQUIRE(generation_config.stop_words.empty());
+
+ const json config_j2 = {};
+ const auto generation_config2 = generation_config_t(config_j2);
+
+ REQUIRE_THAT(generation_config2.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
+ REQUIRE_THAT(generation_config2.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));
+
+ REQUIRE(generation_config2.stop_words.empty());
+}
+
+TEST_CASE("parallel_config single", "[backend_workspace_t]")
+{
+ // Generate temporary folder
+ const auto tmp_p = std::filesystem::temp_directory_path();
+ const auto config_p = tmp_p / "config.json";
+ const auto generation_config_p = tmp_p / "generation_config.json";
+
+ // Generate content
+ std::ofstream o_config(config_p);
+ o_config << R"({"pretrained_config": {"mapping": {"world_size": 2}}})"_json;
+ o_config.close();
+
+ std::ofstream o_generation_config(generation_config_p);
+ o_generation_config << R"({"eos_token_id": []})"_json;
+ o_generation_config.close();
+
+ const auto workspace = backend_workspace_t(tmp_p.generic_string(), tmp_p.generic_string());
+ const auto parallel = workspace.parallel_config();
+ REQUIRE(parallel.getCommunicationMode() == tle::CommunicationMode::kORCHESTRATOR);
+ REQUIRE(parallel.getCommunicationType() == tle::CommunicationType::kMPI);
+
+ std::filesystem::remove(config_p);
+ std::filesystem::remove(generation_config_p);
+}
+
+TEST_CASE("parallel_config multi", "[backend_workspace_t]")
+{
+ // Generate temporary folder
+ const auto tmp_p = std::filesystem::temp_directory_path();
+ const auto config_p = tmp_p / "config.json";
+ const auto generation_config_p = tmp_p / "generation_config.json";
+
+ // Generate content
+ std::ofstream o_config(config_p);
+ o_config << R"({"pretrained_config": {"mapping": {"world_size": 1}}})"_json;
+ o_config.close();
+
+ std::ofstream o_generation_config(generation_config_p);
+ o_generation_config << R"({"eos_token_id": []})"_json;
+ o_generation_config.close();
+
+ const auto workspace = backend_workspace_t(tmp_p.generic_string(), tmp_p.generic_string());
+ const auto parallel = workspace.parallel_config();
+ REQUIRE(parallel.getCommunicationMode() == tle::CommunicationMode::kLEADER);
+ REQUIRE(parallel.getCommunicationType() == tle::CommunicationType::kMPI);
+
+ std::filesystem::remove(config_p);
+ std::filesystem::remove(generation_config_p);
+}
+
+TEST_CASE("executor_config", "[backend_workspace_t]")
+{
+
+}
+
+TEST_CASE("sampling_params_t to tle::SamplingConfig", "[backend_t]")
+{
+ const sampling_params_t params = {40, 0.95, 0.9, 1.0, 0.6, 2014};
+ const auto config = static_cast<tle::SamplingConfig>(params);
+
+ REQUIRE(config.getTopK().has_value());
+ REQUIRE(config.getTopK().value() == params.top_k);
+
+ REQUIRE(config.getSeed().has_value());
+ REQUIRE(config.getSeed().value() == params.seed);
+
+ REQUIRE(config.getTopP().has_value());
+ REQUIRE_THAT(*config.getTopP(), Catch::Matchers::WithinAbs(params.top_p, 1e-6f));
+
+ REQUIRE(config.getRepetitionPenalty().has_value());
+ REQUIRE_THAT(*config.getRepetitionPenalty(), Catch::Matchers::WithinAbs(params.repetition_penalty, 1e-6f));
+
+ REQUIRE(config.getFrequencyPenalty().has_value());
+ REQUIRE_THAT(*config.getFrequencyPenalty(), Catch::Matchers::WithinAbs(params.frequency_penalty, 1e-6f));
+
+ REQUIRE(config.getTemperature().has_value());
+ REQUIRE_THAT(*config.getTemperature(), Catch::Matchers::WithinAbs(params.temperature, 1e-6f));
+}
diff --git a/backends/trtllm/tests/test_hardware.cpp b/backends/trtllm/tests/test_hardware.cpp
new file mode 100644
index 00000000000..e14f1f357f4
--- /dev/null
+++ b/backends/trtllm/tests/test_hardware.cpp
@@ -0,0 +1,82 @@
+//
+// Created by mfuntowicz on 11/16/24.
+//
+
+#include <catch2/catch_all.hpp>
+#include "../csrc/hardware.hpp"
+
+using namespace huggingface::tgi::hardware::cuda;
+
+TEST_CASE("is_at_least_") {
+ const static auto VOLTA_CAPABILITIES = compute_capabilities_t(7, 0);
+ REQUIRE(VOLTA_CAPABILITIES.is_at_least_volta());
+ REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least_turing());
+ REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least_ampere());
+ REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least_ada_lovelace());
+ REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least_hopper());
+
+ const static auto TURING_CAPABILITIES = compute_capabilities_t(7, 5);
+ REQUIRE(TURING_CAPABILITIES.is_at_least_volta());
+ REQUIRE(TURING_CAPABILITIES.is_at_least_turing());
+ REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least_ampere());
+ REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least_ada_lovelace());
+ REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least_hopper());
+
+ const static auto AMPERE_CAPABILITIES = compute_capabilities_t(8, 0);
+ REQUIRE(AMPERE_CAPABILITIES.is_at_least_volta());
+ REQUIRE(AMPERE_CAPABILITIES.is_at_least_turing());
+ REQUIRE(AMPERE_CAPABILITIES.is_at_least_ampere());
+ REQUIRE_FALSE(AMPERE_CAPABILITIES.is_at_least_ada_lovelace());
+ REQUIRE_FALSE(AMPERE_CAPABILITIES.is_at_least_hopper());
+
+ const static auto ADA_LOVELACE_CAPABILITIES = compute_capabilities_t(8, 9);
+ REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least_volta());
+ REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least_turing());
+ REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least_ampere());
+ REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least_ada_lovelace());
+ REQUIRE_FALSE(ADA_LOVELACE_CAPABILITIES.is_at_least_hopper());
+
+ const static auto HOPPER_CAPABILITIES = compute_capabilities_t(9, 0);
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least_volta());
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least_turing());
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least_ampere());
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least_ada_lovelace());
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least_hopper());
+}
+
+TEST_CASE("is_at_least") {
+ const static auto VOLTA_CAPABILITIES = compute_capabilities_t(7, 0);
+ REQUIRE(VOLTA_CAPABILITIES.is_at_least(VOLTA));
+ REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least(TURING));
+ REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least(AMPERE));
+ REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least(ADA_LOVELACE));
+ REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least(HOPPER));
+
+ const static auto TURING_CAPABILITIES = compute_capabilities_t(7, 5);
+ REQUIRE(TURING_CAPABILITIES.is_at_least(VOLTA));
+ REQUIRE(TURING_CAPABILITIES.is_at_least(TURING));
+ REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least(AMPERE));
+ REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least(ADA_LOVELACE));
+ REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least(HOPPER));
+
+ const static auto AMPERE_CAPABILITIES = compute_capabilities_t(8, 0);
+ REQUIRE(AMPERE_CAPABILITIES.is_at_least(VOLTA));
+ REQUIRE(AMPERE_CAPABILITIES.is_at_least(TURING));
+ REQUIRE(AMPERE_CAPABILITIES.is_at_least(AMPERE));
+ REQUIRE_FALSE(AMPERE_CAPABILITIES.is_at_least(ADA_LOVELACE));
+ REQUIRE_FALSE(AMPERE_CAPABILITIES.is_at_least(HOPPER));
+
+ const static auto ADA_LOVELACE_CAPABILITIES = compute_capabilities_t(8, 9);
+ REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least(VOLTA));
+ REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least(TURING));
+ REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least(AMPERE));
+ REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least(ADA_LOVELACE));
+ REQUIRE_FALSE(ADA_LOVELACE_CAPABILITIES.is_at_least(HOPPER));
+
+ const static auto HOPPER_CAPABILITIES = compute_capabilities_t (9, 0);
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least(VOLTA));
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least(TURING));
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least(AMPERE));
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least(ADA_LOVELACE));
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least(HOPPER));
+}
diff --git a/backends/v2/Cargo.toml b/backends/v2/Cargo.toml
index 4d32474e77f..0decf41ad0d 100644
--- a/backends/v2/Cargo.toml
+++ b/backends/v2/Cargo.toml
@@ -23,7 +23,7 @@ clap = { version = "4.4.5", features = ["derive", "env"] }
grpc-metadata = { path = "../grpc-metadata" }
futures = "0.3.28"
hf-hub = { workspace = true }
-jsonschema = { version = "0.17.1", features = ["draft202012"] }
+jsonschema = { version = "0.28.0" }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0"
diff --git a/backends/v2/src/backend.rs b/backends/v2/src/backend.rs
index cfe87f98fa9..adca3d5d25f 100644
--- a/backends/v2/src/backend.rs
+++ b/backends/v2/src/backend.rs
@@ -108,6 +108,10 @@ impl Backend for BackendV2 {
fn start_health(&self) -> bool {
true
}
+
+ fn name(&self) -> &'static str {
+ "tgi-v2"
+ }
}
/// Batching logic
diff --git a/backends/v2/src/main.rs b/backends/v2/src/main.rs
index f537690e4f8..60b5d52bbe2 100644
--- a/backends/v2/src/main.rs
+++ b/backends/v2/src/main.rs
@@ -36,6 +36,8 @@ struct Args {
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
+ #[clap(default_value = "9000", long, short, env)]
+ prometheus_port: u16,
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
@@ -99,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
max_batch_size,
hostname,
port,
+ prometheus_port,
master_shard_uds_path,
tokenizer_name,
tokenizer_config_path,
@@ -198,6 +201,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
+ prometheus_port,
)
.await?;
Ok(())
diff --git a/backends/v2/src/queue.rs b/backends/v2/src/queue.rs
index 61a3eebc927..c9a9335dd9d 100644
--- a/backends/v2/src/queue.rs
+++ b/backends/v2/src/queue.rs
@@ -213,8 +213,7 @@ impl State {
}
// Pad prefill_token_budget to be a multiple of block size
- let prefill_token_budget =
- ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
+ let prefill_token_budget = prefill_token_budget.div_ceil(self.block_size) * self.block_size;
// Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
@@ -245,9 +244,8 @@ impl State {
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
} else {
// pad to block size
- prefill_tokens += ((entry.request.input_length + self.block_size - 1)
- / self.block_size)
- * self.block_size;
+ prefill_tokens +=
+ entry.request.input_length.div_ceil(self.block_size) * self.block_size;
}
if self.requires_padding {
@@ -262,8 +260,7 @@ impl State {
};
// pad to block size
- decode_tokens +=
- ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
+ decode_tokens += max_new_tokens.div_ceil(self.block_size) * self.block_size;
}
if prefill_tokens > prefill_token_budget
diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml
index 69dad072fac..588a2716fe1 100644
--- a/backends/v3/Cargo.toml
+++ b/backends/v3/Cargo.toml
@@ -23,7 +23,7 @@ clap = { version = "4.4.5", features = ["derive", "env"] }
grpc-metadata = { path = "../grpc-metadata" }
futures = "0.3.28"
hf-hub = { workspace = true }
-jsonschema = { version = "0.17.1", features = ["draft202012"] }
+jsonschema = { version = "0.28.0" }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0"
@@ -71,6 +71,7 @@ prost-build = "0.12.1"
[dev-dependencies]
criterion = "0.3"
itertools = "0.13"
+rustc-hash = "2"
[features]
default = ["ngrok"]
diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index 736301b3382..98e8d76f09f 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -115,6 +115,10 @@ impl Backend for BackendV3 {
fn start_health(&self) -> bool {
true
}
+
+ fn name(&self) -> &'static str {
+ "tgi-v3"
+ }
}
/// Batching logic
diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs
index 4fea172b65a..c8b2920414e 100644
--- a/backends/v3/src/block_allocator.rs
+++ b/backends/v3/src/block_allocator.rs
@@ -2,7 +2,7 @@ use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use crate::radix::RadixAllocator;
-
+use text_generation_router::usage_stats::Env;
#[derive(Debug, Clone)]
pub struct BlockAllocation {
pub allocation_id: u64,
@@ -141,6 +141,7 @@ pub struct SimpleAllocator {
free_blocks: Vec<u32>,
block_size: u32,
window_size: Option<u32>,
+ is_hpu_device: bool,
}
impl SimpleAllocator {
@@ -150,6 +151,7 @@ impl SimpleAllocator {
// Block 0 is reserved for health checks
free_blocks: (1..blocks).collect(),
window_size,
+ is_hpu_device: Env::new().is_hpu_device(),
}
}
}
@@ -160,28 +162,38 @@ impl Allocator for SimpleAllocator {
tokens: u32,
_prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
+ let mut tokens = tokens;
+ if self.is_hpu_device {
+ // need 1 slot for ping-pong optimization
+ tokens += 1;
+ }
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match self.window_size {
None => (tokens, 1),
Some(window_size) => {
- let repeats = (tokens + window_size - 1) / window_size;
+ let repeats = tokens.div_ceil(window_size);
let tokens = core::cmp::min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
- let required_blocks = (tokens + self.block_size - 1) / self.block_size;
+ let required_blocks = tokens.div_ceil(self.block_size);
(required_blocks, repeats)
};
-
let tokens = tokens as usize;
if required_blocks > self.free_blocks.len() as u32 {
None
} else {
- let blocks = self
+ if self.is_hpu_device {
+ self.free_blocks.sort_by(|a, b| b.cmp(a));
+ }
+ let mut blocks = self
.free_blocks
.split_off(self.free_blocks.len() - required_blocks as usize);
+ if self.is_hpu_device {
+ blocks.sort();
+ }
let mut slots =
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
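
On HPU the allocator above sorts the free list in descending order before `split_off`, so the blocks taken from the tail are the lowest-numbered ones; the second sort hands them back in ascending order. A standalone sketch of that behaviour (illustrative values, not the allocator itself):

    fn take_lowest(free_blocks: &mut Vec<u32>, n: usize) -> Vec<u32> {
        // Descending order puts the smallest block ids at the tail of the vector.
        free_blocks.sort_by(|a, b| b.cmp(a));
        let mut blocks = free_blocks.split_off(free_blocks.len() - n);
        // Hand the taken blocks back in ascending order, as the patch does.
        blocks.sort();
        blocks
    }

    fn main() {
        let mut free = vec![7, 3, 9, 1, 5];
        assert_eq!(take_lowest(&mut free, 2), vec![1, 3]);
        assert_eq!(free, vec![9, 7, 5]); // remaining free list stays sorted descending
    }
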
diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs
index 52e41b55a33..44e63853e04 100644
--- a/backends/v3/src/main.rs
+++ b/backends/v3/src/main.rs
@@ -36,6 +36,8 @@ struct Args {
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
+ #[clap(default_value = "9000", long, short, env)]
+ prometheus_port: u16,
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
@@ -99,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
max_batch_size,
hostname,
port,
+ prometheus_port,
master_shard_uds_path,
tokenizer_name,
tokenizer_config_path,
@@ -214,6 +217,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
+ prometheus_port,
)
.await?;
Ok(())
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index dd27806f97a..8cfee3a5016 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -8,6 +8,7 @@ use std::cmp::max;
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
+use text_generation_router::usage_stats::Env;
use text_generation_router::validation::{
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
ValidStoppingParameters,
@@ -185,6 +186,9 @@ struct State {
/// Paged Attention Block Allocation
block_allocator: Option<BlockAllocator>,
+
+ /// Indicates whether this is an HPU device; the HPU device needs padding to generate the first token.
+ is_hpu_device: bool,
}
impl State {
@@ -214,6 +218,7 @@ impl State {
speculate,
support_chunking,
block_allocator,
+ is_hpu_device: Env::new().is_hpu_device(),
}
}
@@ -257,8 +262,7 @@ impl State {
}
// Pad prefill_token_budget to be a multiple of block size
- let prefill_token_budget =
- ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
+ let prefill_token_budget = prefill_token_budget.div_ceil(self.block_size) * self.block_size;
// Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
@@ -312,7 +316,7 @@ impl State {
+ entry.request.stopping_parameters.max_new_tokens
+ self.speculate
- 1;
- tracing::debug!("Allocating {tokens} with {input_ids:?}");
+ // tracing::debug!("Allocating {tokens} with {input_ids:?}");
let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
None => {
@@ -323,7 +327,7 @@ impl State {
break 'entry_loop;
}
Some(mut block_allocation) => {
- tracing::debug!("Allocation: {block_allocation:?}");
+ // tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
if block_allocation.prefix_len == entry.request.input_length {
@@ -369,6 +373,21 @@ impl State {
}
}
+ if self.is_hpu_device {
+ //HPU needs to pad for the prefill
+ max_input_length = max_input_length.max(entry.request.input_length);
+ let actual_prefill_tokens_for_hpu =
+ (batch.len() + 1) as u32 * max_input_length;
+
+ if actual_prefill_tokens_for_hpu > prefill_token_budget {
+ // Entry is over budget
+ // Add it back to the front
+ tracing::debug!("Over budget: prefill_tokens={actual_prefill_tokens_for_hpu} > {prefill_token_budget}");
+ self.entries.push_front((id, entry));
+ break 'entry_loop;
+ }
+ }
+
prefill_tokens += postfix_len;
Some(block_allocation)
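
The HPU-specific check above budgets prefill as if every request in the batch were padded to the longest input, so admitting one more request costs (batch_len + 1) * max_input_length tokens rather than just its own input length. A small numeric sketch (the figures are made up for illustration):

    fn hpu_prefill_cost(batch_len: usize, max_input_length: u32) -> u32 {
        (batch_len + 1) as u32 * max_input_length
    }

    fn main() {
        let prefill_token_budget = 4096;
        // Three requests already selected, longest input 512 tokens: 4 * 512 = 2048, fits.
        assert!(hpu_prefill_cost(3, 512) <= prefill_token_budget);
        // Same batch but a 1200-token longest input: 4 * 1200 = 4800, over budget.
        assert!(hpu_prefill_cost(3, 1200) > prefill_token_budget);
    }
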
diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs
index 8a544891199..aea69693f96 100644
--- a/backends/v3/src/radix.rs
+++ b/backends/v3/src/radix.rs
@@ -103,7 +103,7 @@ impl Allocator for RadixAllocator {
let prefix_len = blocks.len() * self.block_size as usize;
let suffix_len = tokens - prefix_len as u32;
- let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
+ let suffix_blocks = suffix_len.div_ceil(self.block_size);
tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
@@ -283,7 +283,7 @@ impl RadixTrie {
}
/// Find worker.
- fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
+ fn find_(&mut self, node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
let node = &self.nodes[node_id];
if key.len() >= self.block_size {
@@ -295,9 +295,13 @@ impl RadixTrie {
assert_eq!(shared_prefix_len % self.block_size, 0);
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
+ // A node represents the prefix of its children. So, only
+ // recurse when there is a full prefix match.
let key = &key[shared_prefix_len..];
- if !key.is_empty() {
- node_id = self.find_(child_id, key, blocks);
+ if !key.is_empty() && shared_prefix_len == child.key.len() {
+ return self.find_(child_id, key, blocks);
+ } else {
+ return child_id;
}
}
}
@@ -631,6 +635,12 @@ fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {
mod tests {
use std::sync::Arc;
+ use rand::{
+ distributions::Uniform, prelude::Distribution, rngs::SmallRng, seq::SliceRandom,
+ SeedableRng,
+ };
+ use rustc_hash::FxHashSet;
+
use super::*;
#[test]
@@ -873,4 +883,159 @@ mod tests {
// Clear out the whole trie.
assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
}
+
+ #[test]
+ fn full_match_returns_correct_node() {
+ let mut trie = RadixTrie::new(1);
+ trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
+ let node_id = trie.find(&[0, 1, 2], &mut vec![]);
+ // At this point, there are only two nodes: the root and the node
+ // with tokens 0, 1, 2. Looking up the exact prefix must return
+ // the non-root node.
+ assert_ne!(node_id, trie.root);
+ }
+
+ #[test]
+ fn partial_match_does_not_recurse() {
+ let mut trie = RadixTrie::new(1);
+ trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
+ trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2, 3, 4, 5])
+ .unwrap();
+ let mut blocks = Vec::new();
+ let node_id = trie.find(&[0, 1, 3, 4, 5], &mut blocks);
+ assert_eq!(blocks, vec![0, 1]);
+ assert_eq!(node_id, trie.find(&[0, 1], &mut blocks))
+ }
+
+ struct AllocationWithInfo {
+ allocation: BlockAllocation,
+ // We are doing a lot of set operations and `FxBuildHasher` is
+ // much faster for a set of integers.
+ blockset: FxHashSet<u32>,
+ non_prefix_blocks: FxHashSet<u32>,
+ }
+
+ #[test]
+ fn invariants_hold_on_many_operations_remove_all() {
+ invariants_hold_on_many_insertions(true);
+ }
+
+ #[test]
+ fn invariants_hold_on_many_operations_remove_subset() {
+ invariants_hold_on_many_insertions(false);
+ }
+
+ fn invariants_hold_on_many_insertions(remove_all: bool) {
+ // Small vocabulary sizes lead to violations more quickly due to
+ // prefix sharing, etc.
+ const VOCAB_SIZE: u32 = 2;
+ const DATA_LEN: usize = 1_000;
+
+ const MAX_PREFILL_LEN: usize = 8;
+ const MAX_DECODE_LEN: usize = 8;
+
+ let vocab_range = Uniform::new(0, VOCAB_SIZE);
+ let data_range = Uniform::new(0, DATA_LEN);
+ let prefill_len_range = Uniform::new(0, MAX_PREFILL_LEN);
+ let decode_len_range = Uniform::new(0, MAX_DECODE_LEN);
+
+ let mut rng = SmallRng::seed_from_u64(64);
+ let data = (0..DATA_LEN)
+ .map(|_| vocab_range.sample(&mut rng))
+ .collect::<Vec<u32>>();
+ let mut allocator = RadixAllocator::new(1, 100, None);
+
+ let mut allocations = Vec::new();
+
+ for i in 0..100_000 {
+ // Allocate until all blocks are used.
+ 'allocation: loop {
+ // Use offset 0 half of the times for prefix sharing.
+ let prefill_offset = data_range.sample(&mut rng);
+ let prefill_len = prefill_len_range.sample(&mut rng);
+ let decode_len = decode_len_range.sample(&mut rng);
+
+ let prefill =
+ data[prefill_offset..data.len().min(prefill_offset + prefill_len)].to_vec();
+
+ let allocation = match allocator
+ .allocate((prefill.len() + decode_len) as u32, Some(Arc::new(prefill)))
+ {
+ Some(allocation) => allocation,
+ None => break 'allocation,
+ };
+ let non_prefix_blocks = allocation.blocks[allocation.prefix_len as usize..]
+ .iter()
+ .copied()
+ .collect::<FxHashSet<_>>();
+ let blockset = allocation.blocks.iter().copied().collect::<FxHashSet<_>>();
+
+ // No duplicate blocks in an allocation.
+ assert_eq!(
+ allocation.blocks.len(),
+ blockset.len(),
+ "Duplicate blocks in allocation"
+ );
+
+ allocations.push(AllocationWithInfo {
+ allocation,
+ blockset,
+ non_prefix_blocks,
+ });
+ }
+
+ // Check invariants. Skip first iteration, since there is no prefix sharing yet.
+ if i > 1 {
+ check_allocation_invariants(&allocations);
+ }
+
+ // Remove 20% of the allocations, randomly.
+ if remove_all {
+ allocations.into_iter().for_each(|allocation| {
+ allocator.free(
+ allocation.allocation.blocks.clone(),
+ allocation.allocation.allocation_id,
+ )
+ });
+ allocations = Vec::new();
+ } else {
+ allocations.shuffle(&mut rng);
+ let remove_index = (allocations.len() as f64 * 0.8) as usize;
+ for allocation in allocations.drain(remove_index..) {
+ allocator.free(
+ allocation.allocation.blocks.clone(),
+ allocation.allocation.allocation_id,
+ );
+ }
+ }
+ }
+ }
+
+ fn check_allocation_invariants(allocations: &[AllocationWithInfo]) {
+ for i in 0..allocations.len() {
+ let allocation = &allocations[i];
+
+ // 0 is used for health checks, must not be used.
+ assert!(
+ !allocation.blockset.contains(&0),
+ "Block 0 must not be allocated"
+ );
+
+ // No duplicate blocks in an allocation.
+ assert_eq!(
+ allocation.allocation.blocks.len(),
+ allocation.blockset.len(),
+ "Duplicate blocks in allocation"
+ );
+
+ for other_allocation in &allocations[i + 1..] {
+ assert!(
+ other_allocation
+ .non_prefix_blocks
+ .is_disjoint(&allocation.non_prefix_blocks),
+ "Allocations share non-prefix blocks"
+ )
+ }
+ }
+ }
}
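
The new property test asserts that live allocations never share non-prefix blocks. A compact sketch of that pairwise-disjointness check in isolation (assumes rustc-hash 2.x, matching the dev-dependency added in this patch; the sets are illustrative):

    use rustc_hash::FxHashSet;

    fn pairwise_disjoint(sets: &[FxHashSet<u32>]) -> bool {
        sets.iter()
            .enumerate()
            .all(|(i, a)| sets[i + 1..].iter().all(|b| a.is_disjoint(b)))
    }

    fn main() {
        let a: FxHashSet<u32> = [1, 2].into_iter().collect();
        let b: FxHashSet<u32> = [3, 4].into_iter().collect();
        let c: FxHashSet<u32> = [4, 5].into_iter().collect();
        assert!(pairwise_disjoint(&[a.clone(), b.clone()]));
        assert!(!pairwise_disjoint(&[a, b, c])); // b and c both claim block 4
    }
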
diff --git a/benchmark/src/app.rs b/benchmark/src/app.rs
index 7e3aeaf9ecb..bd9ff602557 100644
--- a/benchmark/src/app.rs
+++ b/benchmark/src/app.rs
@@ -434,7 +434,7 @@ impl Data {
}
/// Progress bar
-fn progress_gauge(title: &str, label: String, progress: f64, color: Color) -> Gauge {
+fn progress_gauge(title: &str, label: String, progress: f64, color: Color) -> Gauge<'_> {
Gauge::default()
.block(Block::default().title(title).borders(Borders::ALL))
.gauge_style(Style::default().fg(color))
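
The `Gauge<'_>` change only spells out an elided lifetime: the widget borrows from its inputs, and the anonymous lifetime in the return type makes that borrow visible (which is what lints such as `elided_lifetimes_in_paths` ask for) without changing behaviour. A minimal analogue, not using ratatui itself:

    struct Gauge<'a> {
        title: &'a str,
    }

    fn progress_gauge(title: &str) -> Gauge<'_> {
        // The returned Gauge borrows `title`; `Gauge<'_>` makes that explicit.
        Gauge { title }
    }

    fn main() {
        let g = progress_gauge("progress");
        println!("{}", g.title);
    }
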
diff --git a/clients/python/poetry.lock b/clients/python/poetry.lock
index 148d9906558..36e82f2a042 100644
--- a/clients/python/poetry.lock
+++ b/clients/python/poetry.lock
@@ -1,124 +1,131 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.0.0 and should not be changed by hand.
+
+[[package]]
+name = "aiohappyeyeballs"
+version = "2.6.1"
+description = "Happy Eyeballs for asyncio"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "aiohappyeyeballs-2.6.1-py3-none-any.whl", hash = "sha256:f349ba8f4b75cb25c99c5c2d84e997e485204d2902a9597802b0371f09331fb8"},
+ {file = "aiohappyeyeballs-2.6.1.tar.gz", hash = "sha256:c3f9d0113123803ccadfdf3f0faa505bc78e6a72d1cc4806cbd719826e943558"},
+]
[[package]]
name = "aiohttp"
-version = "3.8.5"
+version = "3.11.16"
description = "Async http client/server framework (asyncio)"
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "aiohttp-3.8.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a94159871304770da4dd371f4291b20cac04e8c94f11bdea1c3478e557fbe0d8"},
- {file = "aiohttp-3.8.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:13bf85afc99ce6f9ee3567b04501f18f9f8dbbb2ea11ed1a2e079670403a7c84"},
- {file = "aiohttp-3.8.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2ce2ac5708501afc4847221a521f7e4b245abf5178cf5ddae9d5b3856ddb2f3a"},
- {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96943e5dcc37a6529d18766597c491798b7eb7a61d48878611298afc1fca946c"},
- {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2ad5c3c4590bb3cc28b4382f031f3783f25ec223557124c68754a2231d989e2b"},
- {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0c413c633d0512df4dc7fd2373ec06cc6a815b7b6d6c2f208ada7e9e93a5061d"},
- {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df72ac063b97837a80d80dec8d54c241af059cc9bb42c4de68bd5b61ceb37caa"},
- {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c48c5c0271149cfe467c0ff8eb941279fd6e3f65c9a388c984e0e6cf57538e14"},
- {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:368a42363c4d70ab52c2c6420a57f190ed3dfaca6a1b19afda8165ee16416a82"},
- {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7607ec3ce4993464368505888af5beb446845a014bc676d349efec0e05085905"},
- {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:0d21c684808288a98914e5aaf2a7c6a3179d4df11d249799c32d1808e79503b5"},
- {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:312fcfbacc7880a8da0ae8b6abc6cc7d752e9caa0051a53d217a650b25e9a691"},
- {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ad093e823df03bb3fd37e7dec9d4670c34f9e24aeace76808fc20a507cace825"},
- {file = "aiohttp-3.8.5-cp310-cp310-win32.whl", hash = "sha256:33279701c04351a2914e1100b62b2a7fdb9a25995c4a104259f9a5ead7ed4802"},
- {file = "aiohttp-3.8.5-cp310-cp310-win_amd64.whl", hash = "sha256:6e4a280e4b975a2e7745573e3fc9c9ba0d1194a3738ce1cbaa80626cc9b4f4df"},
- {file = "aiohttp-3.8.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ae871a964e1987a943d83d6709d20ec6103ca1eaf52f7e0d36ee1b5bebb8b9b9"},
- {file = "aiohttp-3.8.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:461908b2578955045efde733719d62f2b649c404189a09a632d245b445c9c975"},
- {file = "aiohttp-3.8.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:72a860c215e26192379f57cae5ab12b168b75db8271f111019509a1196dfc780"},
- {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc14be025665dba6202b6a71cfcdb53210cc498e50068bc088076624471f8bb9"},
- {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8af740fc2711ad85f1a5c034a435782fbd5b5f8314c9a3ef071424a8158d7f6b"},
- {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:841cd8233cbd2111a0ef0a522ce016357c5e3aff8a8ce92bcfa14cef890d698f"},
- {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ed1c46fb119f1b59304b5ec89f834f07124cd23ae5b74288e364477641060ff"},
- {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:84f8ae3e09a34f35c18fa57f015cc394bd1389bce02503fb30c394d04ee6b938"},
- {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:62360cb771707cb70a6fd114b9871d20d7dd2163a0feafe43fd115cfe4fe845e"},
- {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:23fb25a9f0a1ca1f24c0a371523546366bb642397c94ab45ad3aedf2941cec6a"},
- {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0ba0d15164eae3d878260d4c4df859bbdc6466e9e6689c344a13334f988bb53"},
- {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5d20003b635fc6ae3f96d7260281dfaf1894fc3aa24d1888a9b2628e97c241e5"},
- {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0175d745d9e85c40dcc51c8f88c74bfbaef9e7afeeeb9d03c37977270303064c"},
- {file = "aiohttp-3.8.5-cp311-cp311-win32.whl", hash = "sha256:2e1b1e51b0774408f091d268648e3d57f7260c1682e7d3a63cb00d22d71bb945"},
- {file = "aiohttp-3.8.5-cp311-cp311-win_amd64.whl", hash = "sha256:043d2299f6dfdc92f0ac5e995dfc56668e1587cea7f9aa9d8a78a1b6554e5755"},
- {file = "aiohttp-3.8.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cae533195e8122584ec87531d6df000ad07737eaa3c81209e85c928854d2195c"},
- {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f21e83f355643c345177a5d1d8079f9f28b5133bcd154193b799d380331d5d3"},
- {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a7a75ef35f2df54ad55dbf4b73fe1da96f370e51b10c91f08b19603c64004acc"},
- {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e2e9839e14dd5308ee773c97115f1e0a1cb1d75cbeeee9f33824fa5144c7634"},
- {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44e65da1de4403d0576473e2344828ef9c4c6244d65cf4b75549bb46d40b8dd"},
- {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78d847e4cde6ecc19125ccbc9bfac4a7ab37c234dd88fbb3c5c524e8e14da543"},
- {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:c7a815258e5895d8900aec4454f38dca9aed71085f227537208057853f9d13f2"},
- {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:8b929b9bd7cd7c3939f8bcfffa92fae7480bd1aa425279d51a89327d600c704d"},
- {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:5db3a5b833764280ed7618393832e0853e40f3d3e9aa128ac0ba0f8278d08649"},
- {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:a0215ce6041d501f3155dc219712bc41252d0ab76474615b9700d63d4d9292af"},
- {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:fd1ed388ea7fbed22c4968dd64bab0198de60750a25fe8c0c9d4bef5abe13824"},
- {file = "aiohttp-3.8.5-cp36-cp36m-win32.whl", hash = "sha256:6e6783bcc45f397fdebc118d772103d751b54cddf5b60fbcc958382d7dd64f3e"},
- {file = "aiohttp-3.8.5-cp36-cp36m-win_amd64.whl", hash = "sha256:b5411d82cddd212644cf9360879eb5080f0d5f7d809d03262c50dad02f01421a"},
- {file = "aiohttp-3.8.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:01d4c0c874aa4ddfb8098e85d10b5e875a70adc63db91f1ae65a4b04d3344cda"},
- {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5980a746d547a6ba173fd5ee85ce9077e72d118758db05d229044b469d9029a"},
- {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2a482e6da906d5e6e653be079b29bc173a48e381600161c9932d89dfae5942ef"},
- {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80bd372b8d0715c66c974cf57fe363621a02f359f1ec81cba97366948c7fc873"},
- {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1161b345c0a444ebcf46bf0a740ba5dcf50612fd3d0528883fdc0eff578006a"},
- {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd56db019015b6acfaaf92e1ac40eb8434847d9bf88b4be4efe5bfd260aee692"},
- {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:153c2549f6c004d2754cc60603d4668899c9895b8a89397444a9c4efa282aaf4"},
- {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4a01951fabc4ce26ab791da5f3f24dca6d9a6f24121746eb19756416ff2d881b"},
- {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bfb9162dcf01f615462b995a516ba03e769de0789de1cadc0f916265c257e5d8"},
- {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:7dde0009408969a43b04c16cbbe252c4f5ef4574ac226bc8815cd7342d2028b6"},
- {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4149d34c32f9638f38f544b3977a4c24052042affa895352d3636fa8bffd030a"},
- {file = "aiohttp-3.8.5-cp37-cp37m-win32.whl", hash = "sha256:68c5a82c8779bdfc6367c967a4a1b2aa52cd3595388bf5961a62158ee8a59e22"},
- {file = "aiohttp-3.8.5-cp37-cp37m-win_amd64.whl", hash = "sha256:2cf57fb50be5f52bda004b8893e63b48530ed9f0d6c96c84620dc92fe3cd9b9d"},
- {file = "aiohttp-3.8.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:eca4bf3734c541dc4f374ad6010a68ff6c6748f00451707f39857f429ca36ced"},
- {file = "aiohttp-3.8.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1274477e4c71ce8cfe6c1ec2f806d57c015ebf84d83373676036e256bc55d690"},
- {file = "aiohttp-3.8.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:28c543e54710d6158fc6f439296c7865b29e0b616629767e685a7185fab4a6b9"},
- {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:910bec0c49637d213f5d9877105d26e0c4a4de2f8b1b29405ff37e9fc0ad52b8"},
- {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5443910d662db951b2e58eb70b0fbe6b6e2ae613477129a5805d0b66c54b6cb7"},
- {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e460be6978fc24e3df83193dc0cc4de46c9909ed92dd47d349a452ef49325b7"},
- {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb1558def481d84f03b45888473fc5a1f35747b5f334ef4e7a571bc0dfcb11f8"},
- {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:34dd0c107799dcbbf7d48b53be761a013c0adf5571bf50c4ecad5643fe9cfcd0"},
- {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:aa1990247f02a54185dc0dff92a6904521172a22664c863a03ff64c42f9b5410"},
- {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0e584a10f204a617d71d359fe383406305a4b595b333721fa50b867b4a0a1548"},
- {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:a3cf433f127efa43fee6b90ea4c6edf6c4a17109d1d037d1a52abec84d8f2e42"},
- {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:c11f5b099adafb18e65c2c997d57108b5bbeaa9eeee64a84302c0978b1ec948b"},
- {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:84de26ddf621d7ac4c975dbea4c945860e08cccde492269db4e1538a6a6f3c35"},
- {file = "aiohttp-3.8.5-cp38-cp38-win32.whl", hash = "sha256:ab88bafedc57dd0aab55fa728ea10c1911f7e4d8b43e1d838a1739f33712921c"},
- {file = "aiohttp-3.8.5-cp38-cp38-win_amd64.whl", hash = "sha256:5798a9aad1879f626589f3df0f8b79b3608a92e9beab10e5fda02c8a2c60db2e"},
- {file = "aiohttp-3.8.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a6ce61195c6a19c785df04e71a4537e29eaa2c50fe745b732aa937c0c77169f3"},
- {file = "aiohttp-3.8.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:773dd01706d4db536335fcfae6ea2440a70ceb03dd3e7378f3e815b03c97ab51"},
- {file = "aiohttp-3.8.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f83a552443a526ea38d064588613aca983d0ee0038801bc93c0c916428310c28"},
- {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f7372f7341fcc16f57b2caded43e81ddd18df53320b6f9f042acad41f8e049a"},
- {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea353162f249c8097ea63c2169dd1aa55de1e8fecbe63412a9bc50816e87b761"},
- {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5d47ae48db0b2dcf70bc8a3bc72b3de86e2a590fc299fdbbb15af320d2659de"},
- {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d827176898a2b0b09694fbd1088c7a31836d1a505c243811c87ae53a3f6273c1"},
- {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3562b06567c06439d8b447037bb655ef69786c590b1de86c7ab81efe1c9c15d8"},
- {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4e874cbf8caf8959d2adf572a78bba17cb0e9d7e51bb83d86a3697b686a0ab4d"},
- {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6809a00deaf3810e38c628e9a33271892f815b853605a936e2e9e5129762356c"},
- {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:33776e945d89b29251b33a7e7d006ce86447b2cfd66db5e5ded4e5cd0340585c"},
- {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:eaeed7abfb5d64c539e2db173f63631455f1196c37d9d8d873fc316470dfbacd"},
- {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e91d635961bec2d8f19dfeb41a539eb94bd073f075ca6dae6c8dc0ee89ad6f91"},
- {file = "aiohttp-3.8.5-cp39-cp39-win32.whl", hash = "sha256:00ad4b6f185ec67f3e6562e8a1d2b69660be43070bd0ef6fcec5211154c7df67"},
- {file = "aiohttp-3.8.5-cp39-cp39-win_amd64.whl", hash = "sha256:c0a9034379a37ae42dea7ac1e048352d96286626251862e448933c0f59cbd79c"},
- {file = "aiohttp-3.8.5.tar.gz", hash = "sha256:b9552ec52cc147dbf1944ac7ac98af7602e51ea2dcd076ed194ca3c0d1c7d0bc"},
+ {file = "aiohttp-3.11.16-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fb46bb0f24813e6cede6cc07b1961d4b04f331f7112a23b5e21f567da4ee50aa"},
+ {file = "aiohttp-3.11.16-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:54eb3aead72a5c19fad07219acd882c1643a1027fbcdefac9b502c267242f955"},
+ {file = "aiohttp-3.11.16-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:38bea84ee4fe24ebcc8edeb7b54bf20f06fd53ce4d2cc8b74344c5b9620597fd"},
+ {file = "aiohttp-3.11.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0666afbe984f6933fe72cd1f1c3560d8c55880a0bdd728ad774006eb4241ecd"},
+ {file = "aiohttp-3.11.16-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ba92a2d9ace559a0a14b03d87f47e021e4fa7681dc6970ebbc7b447c7d4b7cd"},
+ {file = "aiohttp-3.11.16-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ad1d59fd7114e6a08c4814983bb498f391c699f3c78712770077518cae63ff7"},
+ {file = "aiohttp-3.11.16-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98b88a2bf26965f2015a771381624dd4b0839034b70d406dc74fd8be4cc053e3"},
+ {file = "aiohttp-3.11.16-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:576f5ca28d1b3276026f7df3ec841ae460e0fc3aac2a47cbf72eabcfc0f102e1"},
+ {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a2a450bcce4931b295fc0848f384834c3f9b00edfc2150baafb4488c27953de6"},
+ {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:37dcee4906454ae377be5937ab2a66a9a88377b11dd7c072df7a7c142b63c37c"},
+ {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:4d0c970c0d602b1017e2067ff3b7dac41c98fef4f7472ec2ea26fd8a4e8c2149"},
+ {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:004511d3413737700835e949433536a2fe95a7d0297edd911a1e9705c5b5ea43"},
+ {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:c15b2271c44da77ee9d822552201180779e5e942f3a71fb74e026bf6172ff287"},
+ {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ad9509ffb2396483ceacb1eee9134724443ee45b92141105a4645857244aecc8"},
+ {file = "aiohttp-3.11.16-cp310-cp310-win32.whl", hash = "sha256:634d96869be6c4dc232fc503e03e40c42d32cfaa51712aee181e922e61d74814"},
+ {file = "aiohttp-3.11.16-cp310-cp310-win_amd64.whl", hash = "sha256:938f756c2b9374bbcc262a37eea521d8a0e6458162f2a9c26329cc87fdf06534"},
+ {file = "aiohttp-3.11.16-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8cb0688a8d81c63d716e867d59a9ccc389e97ac7037ebef904c2b89334407180"},
+ {file = "aiohttp-3.11.16-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0ad1fb47da60ae1ddfb316f0ff16d1f3b8e844d1a1e154641928ea0583d486ed"},
+ {file = "aiohttp-3.11.16-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:df7db76400bf46ec6a0a73192b14c8295bdb9812053f4fe53f4e789f3ea66bbb"},
+ {file = "aiohttp-3.11.16-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc3a145479a76ad0ed646434d09216d33d08eef0d8c9a11f5ae5cdc37caa3540"},
+ {file = "aiohttp-3.11.16-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d007aa39a52d62373bd23428ba4a2546eed0e7643d7bf2e41ddcefd54519842c"},
+ {file = "aiohttp-3.11.16-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f6ddd90d9fb4b501c97a4458f1c1720e42432c26cb76d28177c5b5ad4e332601"},
+ {file = "aiohttp-3.11.16-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a2f451849e6b39e5c226803dcacfa9c7133e9825dcefd2f4e837a2ec5a3bb98"},
+ {file = "aiohttp-3.11.16-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8df6612df74409080575dca38a5237282865408016e65636a76a2eb9348c2567"},
+ {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:78e6e23b954644737e385befa0deb20233e2dfddf95dd11e9db752bdd2a294d3"},
+ {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:696ef00e8a1f0cec5e30640e64eca75d8e777933d1438f4facc9c0cdf288a810"},
+ {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e3538bc9fe1b902bef51372462e3d7c96fce2b566642512138a480b7adc9d508"},
+ {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:3ab3367bb7f61ad18793fea2ef71f2d181c528c87948638366bf1de26e239183"},
+ {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:56a3443aca82abda0e07be2e1ecb76a050714faf2be84256dae291182ba59049"},
+ {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:61c721764e41af907c9d16b6daa05a458f066015abd35923051be8705108ed17"},
+ {file = "aiohttp-3.11.16-cp311-cp311-win32.whl", hash = "sha256:3e061b09f6fa42997cf627307f220315e313ece74907d35776ec4373ed718b86"},
+ {file = "aiohttp-3.11.16-cp311-cp311-win_amd64.whl", hash = "sha256:745f1ed5e2c687baefc3c5e7b4304e91bf3e2f32834d07baaee243e349624b24"},
+ {file = "aiohttp-3.11.16-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:911a6e91d08bb2c72938bc17f0a2d97864c531536b7832abee6429d5296e5b27"},
+ {file = "aiohttp-3.11.16-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6ac13b71761e49d5f9e4d05d33683bbafef753e876e8e5a7ef26e937dd766713"},
+ {file = "aiohttp-3.11.16-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fd36c119c5d6551bce374fcb5c19269638f8d09862445f85a5a48596fd59f4bb"},
+ {file = "aiohttp-3.11.16-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d489d9778522fbd0f8d6a5c6e48e3514f11be81cb0a5954bdda06f7e1594b321"},
+ {file = "aiohttp-3.11.16-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:69a2cbd61788d26f8f1e626e188044834f37f6ae3f937bd9f08b65fc9d7e514e"},
+ {file = "aiohttp-3.11.16-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd464ba806e27ee24a91362ba3621bfc39dbbb8b79f2e1340201615197370f7c"},
+ {file = "aiohttp-3.11.16-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ce63ae04719513dd2651202352a2beb9f67f55cb8490c40f056cea3c5c355ce"},
+ {file = "aiohttp-3.11.16-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09b00dd520d88eac9d1768439a59ab3d145065c91a8fab97f900d1b5f802895e"},
+ {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7f6428fee52d2bcf96a8aa7b62095b190ee341ab0e6b1bcf50c615d7966fd45b"},
+ {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:13ceac2c5cdcc3f64b9015710221ddf81c900c5febc505dbd8f810e770011540"},
+ {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:fadbb8f1d4140825069db3fedbbb843290fd5f5bc0a5dbd7eaf81d91bf1b003b"},
+ {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:6a792ce34b999fbe04a7a71a90c74f10c57ae4c51f65461a411faa70e154154e"},
+ {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:f4065145bf69de124accdd17ea5f4dc770da0a6a6e440c53f6e0a8c27b3e635c"},
+ {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fa73e8c2656a3653ae6c307b3f4e878a21f87859a9afab228280ddccd7369d71"},
+ {file = "aiohttp-3.11.16-cp312-cp312-win32.whl", hash = "sha256:f244b8e541f414664889e2c87cac11a07b918cb4b540c36f7ada7bfa76571ea2"},
+ {file = "aiohttp-3.11.16-cp312-cp312-win_amd64.whl", hash = "sha256:23a15727fbfccab973343b6d1b7181bfb0b4aa7ae280f36fd2f90f5476805682"},
+ {file = "aiohttp-3.11.16-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a3814760a1a700f3cfd2f977249f1032301d0a12c92aba74605cfa6ce9f78489"},
+ {file = "aiohttp-3.11.16-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:9b751a6306f330801665ae69270a8a3993654a85569b3469662efaad6cf5cc50"},
+ {file = "aiohttp-3.11.16-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ad497f38a0d6c329cb621774788583ee12321863cd4bd9feee1effd60f2ad133"},
+ {file = "aiohttp-3.11.16-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca37057625693d097543bd88076ceebeb248291df9d6ca8481349efc0b05dcd0"},
+ {file = "aiohttp-3.11.16-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a5abcbba9f4b463a45c8ca8b7720891200658f6f46894f79517e6cd11f3405ca"},
+ {file = "aiohttp-3.11.16-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f420bfe862fb357a6d76f2065447ef6f484bc489292ac91e29bc65d2d7a2c84d"},
+ {file = "aiohttp-3.11.16-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58ede86453a6cf2d6ce40ef0ca15481677a66950e73b0a788917916f7e35a0bb"},
+ {file = "aiohttp-3.11.16-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6fdec0213244c39973674ca2a7f5435bf74369e7d4e104d6c7473c81c9bcc8c4"},
+ {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:72b1b03fb4655c1960403c131740755ec19c5898c82abd3961c364c2afd59fe7"},
+ {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:780df0d837276276226a1ff803f8d0fa5f8996c479aeef52eb040179f3156cbd"},
+ {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ecdb8173e6c7aa09eee342ac62e193e6904923bd232e76b4157ac0bfa670609f"},
+ {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:a6db7458ab89c7d80bc1f4e930cc9df6edee2200127cfa6f6e080cf619eddfbd"},
+ {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:2540ddc83cc724b13d1838026f6a5ad178510953302a49e6d647f6e1de82bc34"},
+ {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:3b4e6db8dc4879015b9955778cfb9881897339c8fab7b3676f8433f849425913"},
+ {file = "aiohttp-3.11.16-cp313-cp313-win32.whl", hash = "sha256:493910ceb2764f792db4dc6e8e4b375dae1b08f72e18e8f10f18b34ca17d0979"},
+ {file = "aiohttp-3.11.16-cp313-cp313-win_amd64.whl", hash = "sha256:42864e70a248f5f6a49fdaf417d9bc62d6e4d8ee9695b24c5916cb4bb666c802"},
+ {file = "aiohttp-3.11.16-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bbcba75fe879ad6fd2e0d6a8d937f34a571f116a0e4db37df8079e738ea95c71"},
+ {file = "aiohttp-3.11.16-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:87a6e922b2b2401e0b0cf6b976b97f11ec7f136bfed445e16384fbf6fd5e8602"},
+ {file = "aiohttp-3.11.16-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ccf10f16ab498d20e28bc2b5c1306e9c1512f2840f7b6a67000a517a4b37d5ee"},
+ {file = "aiohttp-3.11.16-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb3d0cc5cdb926090748ea60172fa8a213cec728bd6c54eae18b96040fcd6227"},
+ {file = "aiohttp-3.11.16-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d07502cc14ecd64f52b2a74ebbc106893d9a9717120057ea9ea1fd6568a747e7"},
+ {file = "aiohttp-3.11.16-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:776c8e959a01e5e8321f1dec77964cb6101020a69d5a94cd3d34db6d555e01f7"},
+ {file = "aiohttp-3.11.16-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0902e887b0e1d50424112f200eb9ae3dfed6c0d0a19fc60f633ae5a57c809656"},
+ {file = "aiohttp-3.11.16-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e87fd812899aa78252866ae03a048e77bd11b80fb4878ce27c23cade239b42b2"},
+ {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0a950c2eb8ff17361abd8c85987fd6076d9f47d040ebffce67dce4993285e973"},
+ {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:c10d85e81d0b9ef87970ecbdbfaeec14a361a7fa947118817fcea8e45335fa46"},
+ {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:7951decace76a9271a1ef181b04aa77d3cc309a02a51d73826039003210bdc86"},
+ {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:14461157d8426bcb40bd94deb0450a6fa16f05129f7da546090cebf8f3123b0f"},
+ {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9756d9b9d4547e091f99d554fbba0d2a920aab98caa82a8fb3d3d9bee3c9ae85"},
+ {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:87944bd16b7fe6160607f6a17808abd25f17f61ae1e26c47a491b970fb66d8cb"},
+ {file = "aiohttp-3.11.16-cp39-cp39-win32.whl", hash = "sha256:92b7ee222e2b903e0a4b329a9943d432b3767f2d5029dbe4ca59fb75223bbe2e"},
+ {file = "aiohttp-3.11.16-cp39-cp39-win_amd64.whl", hash = "sha256:17ae4664031aadfbcb34fd40ffd90976671fa0c0286e6c4113989f78bebab37a"},
+ {file = "aiohttp-3.11.16.tar.gz", hash = "sha256:16f8a2c9538c14a557b4d309ed4d0a7c60f0253e8ed7b6c9a2859a7582f8b1b8"},
]
[package.dependencies]
+aiohappyeyeballs = ">=2.3.0"
aiosignal = ">=1.1.2"
-async-timeout = ">=4.0.0a3,<5.0"
-asynctest = {version = "0.13.0", markers = "python_version < \"3.8\""}
+async-timeout = {version = ">=4.0,<6.0", markers = "python_version < \"3.11\""}
attrs = ">=17.3.0"
-charset-normalizer = ">=2.0,<4.0"
frozenlist = ">=1.1.1"
multidict = ">=4.5,<7.0"
-typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""}
-yarl = ">=1.0,<2.0"
+propcache = ">=0.2.0"
+yarl = ">=1.17.0,<2.0"
[package.extras]
-speedups = ["Brotli", "aiodns", "cchardet"]
+speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"]
[[package]]
name = "aiosignal"
-version = "1.3.1"
+version = "1.3.2"
description = "aiosignal: a list of registered asynchronous callbacks"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"},
- {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"},
+ {file = "aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5"},
+ {file = "aiosignal-1.3.2.tar.gz", hash = "sha256:a8c255c66fafb1e499c9351d0bf32ff2d8a0321595ebac3b93713656d2436f54"},
]
[package.dependencies]
@@ -126,167 +133,161 @@ frozenlist = ">=1.1.0"
[[package]]
name = "annotated-types"
-version = "0.5.0"
+version = "0.7.0"
description = "Reusable constraint types to use with typing.Annotated"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["main"]
files = [
- {file = "annotated_types-0.5.0-py3-none-any.whl", hash = "sha256:58da39888f92c276ad970249761ebea80ba544b77acddaa1a4d6cf78287d45fd"},
- {file = "annotated_types-0.5.0.tar.gz", hash = "sha256:47cdc3490d9ac1506ce92c7aaa76c579dc3509ff11e098fc867e5130ab7be802"},
+ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"},
+ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
]
-[package.dependencies]
-typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
-
[[package]]
name = "async-timeout"
-version = "4.0.3"
+version = "5.0.1"
description = "Timeout context manager for asyncio programs"
optional = false
-python-versions = ">=3.7"
-files = [
- {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"},
- {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
-]
-
-[package.dependencies]
-typing-extensions = {version = ">=3.6.5", markers = "python_version < \"3.8\""}
-
-[[package]]
-name = "asynctest"
-version = "0.13.0"
-description = "Enhance the standard unittest package with features for testing asyncio libraries"
-optional = false
-python-versions = ">=3.5"
-files = [
- {file = "asynctest-0.13.0-py3-none-any.whl", hash = "sha256:5da6118a7e6d6b54d83a8f7197769d046922a44d2a99c21382f0a6e4fadae676"},
- {file = "asynctest-0.13.0.tar.gz", hash = "sha256:c27862842d15d83e6a34eb0b2866c323880eb3a75e4485b079ea11748fd77fac"},
-]
-
-[[package]]
-name = "atomicwrites"
-version = "1.4.1"
-description = "Atomic file writes."
-optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+python-versions = ">=3.8"
+groups = ["main"]
+markers = "python_version < \"3.11\""
files = [
- {file = "atomicwrites-1.4.1.tar.gz", hash = "sha256:81b2c9071a49367a7f770170e5eec8cb66567cfbbc8c73d20ce5ca4a8d71cf11"},
+ {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"},
+ {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
]
[[package]]
name = "attrs"
-version = "23.1.0"
+version = "25.3.0"
description = "Classes Without Boilerplate"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["main"]
files = [
- {file = "attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"},
- {file = "attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"},
+ {file = "attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3"},
+ {file = "attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b"},
]
-[package.dependencies]
-importlib-metadata = {version = "*", markers = "python_version < \"3.8\""}
-
[package.extras]
-cov = ["attrs[tests]", "coverage[toml] (>=5.3)"]
-dev = ["attrs[docs,tests]", "pre-commit"]
-docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"]
-tests = ["attrs[tests-no-zope]", "zope-interface"]
-tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier"]
+tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"]
[[package]]
name = "certifi"
-version = "2023.7.22"
+version = "2025.1.31"
description = "Python package for providing Mozilla's CA Bundle."
optional = false
python-versions = ">=3.6"
+groups = ["main"]
files = [
- {file = "certifi-2023.7.22-py3-none-any.whl", hash = "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"},
- {file = "certifi-2023.7.22.tar.gz", hash = "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082"},
+ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"},
+ {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"},
]
[[package]]
name = "charset-normalizer"
-version = "3.2.0"
+version = "3.4.1"
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
optional = false
-python-versions = ">=3.7.0"
+python-versions = ">=3.7"
+groups = ["main"]
files = [
- {file = "charset-normalizer-3.2.0.tar.gz", hash = "sha256:3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b87549028f680ca955556e3bd57013ab47474c3124dc069faa0b6545b6c9710"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7c70087bfee18a42b4040bb9ec1ca15a08242cf5867c58726530bdf3945672ed"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a103b3a7069b62f5d4890ae1b8f0597618f628b286b03d4bc9195230b154bfa9"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94aea8eff76ee6d1cdacb07dd2123a68283cb5569e0250feab1240058f53b623"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db901e2ac34c931d73054d9797383d0f8009991e723dab15109740a63e7f902a"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b0dac0ff919ba34d4df1b6131f59ce95b08b9065233446be7e459f95554c0dc8"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193cbc708ea3aca45e7221ae58f0fd63f933753a9bfb498a3b474878f12caaad"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09393e1b2a9461950b1c9a45d5fd251dc7c6f228acab64da1c9c0165d9c7765c"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:baacc6aee0b2ef6f3d308e197b5d7a81c0e70b06beae1f1fcacffdbd124fe0e3"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bf420121d4c8dce6b889f0e8e4ec0ca34b7f40186203f06a946fa0276ba54029"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:c04a46716adde8d927adb9457bbe39cf473e1e2c2f5d0a16ceb837e5d841ad4f"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:aaf63899c94de41fe3cf934601b0f7ccb6b428c6e4eeb80da72c58eab077b19a"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62e51710986674142526ab9f78663ca2b0726066ae26b78b22e0f5e571238dd"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-win32.whl", hash = "sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:48021783bdf96e3d6de03a6e39a1171ed5bd7e8bb93fc84cc649d11490f87cea"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4957669ef390f0e6719db3613ab3a7631e68424604a7b448f079bee145da6e09"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:46fb8c61d794b78ec7134a715a3e564aafc8f6b5e338417cb19fe9f57a5a9bf2"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f25c229a6ba38a35ae6e25ca1264621cc25d4d38dca2942a7fce0b67a4efe918"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2efb1bd13885392adfda4614c33d3b68dee4921fd0ac1d3988f8cbb7d589e72a"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f30b48dd7fa1474554b0b0f3fdfdd4c13b5c737a3c6284d3cdc424ec0ffff3a"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:246de67b99b6851627d945db38147d1b209a899311b1305dd84916f2b88526c6"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd9b3b31adcb054116447ea22caa61a285d92e94d710aa5ec97992ff5eb7cf3"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8c2f5e83493748286002f9369f3e6607c565a6a90425a3a1fef5ae32a36d749d"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3170c9399da12c9dc66366e9d14da8bf7147e1e9d9ea566067bbce7bb74bd9c2"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7a4826ad2bd6b07ca615c74ab91f32f6c96d08f6fcc3902ceeedaec8cdc3bcd6"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:3b1613dd5aee995ec6d4c69f00378bbd07614702a315a2cf6c1d21461fe17c23"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9e608aafdb55eb9f255034709e20d5a83b6d60c054df0802fa9c9883d0a937aa"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-win32.whl", hash = "sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:681eb3d7e02e3c3655d1b16059fbfb605ac464c834a0c629048a30fad2b27489"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c57921cda3a80d0f2b8aec7e25c8aa14479ea92b5b51b6876d975d925a2ea346"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41b25eaa7d15909cf3ac4c96088c1f266a9a93ec44f87f1d13d4a0e86c81b982"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f058f6963fd82eb143c692cecdc89e075fa0828db2e5b291070485390b2f1c9c"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7647ebdfb9682b7bb97e2a5e7cb6ae735b1c25008a70b906aecca294ee96cf4"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eef9df1eefada2c09a5e7a40991b9fc6ac6ef20b1372abd48d2794a316dc0449"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e03b8895a6990c9ab2cdcd0f2fe44088ca1c65ae592b8f795c3294af00a461c3"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ee4006268ed33370957f55bf2e6f4d263eaf4dc3cfc473d1d90baff6ed36ce4a"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c4983bf937209c57240cff65906b18bb35e64ae872da6a0db937d7b4af845dd7"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:3bb7fda7260735efe66d5107fb7e6af6a7c04c7fce9b2514e04b7a74b06bf5dd"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:72814c01533f51d68702802d74f77ea026b5ec52793c791e2da806a3844a46c3"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:70c610f6cbe4b9fce272c407dd9d07e33e6bf7b4aa1b7ffb6f6ded8e634e3592"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-win32.whl", hash = "sha256:a401b4598e5d3f4a9a811f3daf42ee2291790c7f9d74b18d75d6e21dda98a1a1"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c0b21078a4b56965e2b12f247467b234734491897e99c1d51cee628da9786959"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95eb302ff792e12aba9a8b8f8474ab229a83c103d74a750ec0bd1c1eea32e669"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1a100c6d595a7f316f1b6f01d20815d916e75ff98c27a01ae817439ea7726329"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6339d047dab2780cc6220f46306628e04d9750f02f983ddb37439ca47ced7149"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4b749b9cc6ee664a3300bb3a273c1ca8068c46be705b6c31cf5d276f8628a94"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a38856a971c602f98472050165cea2cdc97709240373041b69030be15047691f"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f1b185a01fe560bc8ae5f619e924407efca2191b56ce749ec84982fc59a32a"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1c8a2f4c69e08e89632defbfabec2feb8a8d99edc9f89ce33c4b9e36ab63037"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2f4ac36d8e2b4cc1aa71df3dd84ff8efbe3bfb97ac41242fbcfc053c67434f46"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a386ebe437176aab38c041de1260cd3ea459c6ce5263594399880bbc398225b2"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ccd16eb18a849fd8dcb23e23380e2f0a354e8daa0c984b8a732d9cfaba3a776d"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:e6a5bf2cba5ae1bb80b154ed68a3cfa2fa00fde979a7f50d6598d3e17d9ac20c"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:45de3f87179c1823e6d9e32156fb14c1927fcc9aba21433f088fdfb555b77c10"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-win32.whl", hash = "sha256:1000fba1057b92a65daec275aec30586c3de2401ccdcd41f8a5c1e2c87078706"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b2c760cfc7042b27ebdb4a43a4453bd829a5742503599144d54a032c5dc7e9e"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:855eafa5d5a2034b4621c74925d89c5efef61418570e5ef9b37717d9c796419c"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:203f0c8871d5a7987be20c72442488a0b8cfd0f43b7973771640fc593f56321f"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e857a2232ba53ae940d3456f7533ce6ca98b81917d47adc3c7fd55dad8fab858"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e86d77b090dbddbe78867a0275cb4df08ea195e660f1f7f13435a4649e954e5"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4fb39a81950ec280984b3a44f5bd12819953dc5fa3a7e6fa7a80db5ee853952"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dee8e57f052ef5353cf608e0b4c871aee320dd1b87d351c28764fc0ca55f9f4"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8700f06d0ce6f128de3ccdbc1acaea1ee264d2caa9ca05daaf492fde7c2a7200"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1920d4ff15ce893210c1f0c0e9d19bfbecb7983c76b33f046c13a8ffbd570252"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c1c76a1743432b4b60ab3358c937a3fe1341c828ae6194108a94c69028247f22"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:c8063cf17b19661471ecbdb3df1c84f24ad2e389e326ccaf89e3fb2484d8dd7e"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:cd6dbe0238f7743d0efe563ab46294f54f9bc8f4b9bcf57c3c666cc5bc9d1299"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1249cbbf3d3b04902ff081ffbb33ce3377fa6e4c7356f759f3cd076cc138d020"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-win32.whl", hash = "sha256:6c409c0deba34f147f77efaa67b8e4bb83d2f11c8806405f76397ae5b8c0d1c9"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:7095f6fbfaa55defb6b733cfeb14efaae7a29f0b59d8cf213be4e7ca0b857b80"},
- {file = "charset_normalizer-3.2.0-py3-none-any.whl", hash = "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e218488cd232553829be0664c2292d3af2eeeb94b32bea483cf79ac6a694e037"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80ed5e856eb7f30115aaf94e4a08114ccc8813e6ed1b5efa74f9f82e8509858f"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b010a7a4fd316c3c484d482922d13044979e78d1861f0e0650423144c616a46a"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4532bff1b8421fd0a320463030c7520f56a79c9024a4e88f01c537316019005a"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d973f03c0cb71c5ed99037b870f2be986c3c05e63622c017ea9816881d2dd247"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3a3bd0dcd373514dcec91c411ddb9632c0d7d92aed7093b8c3bbb6d69ca74408"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:d9c3cdf5390dcd29aa8056d13e8e99526cda0305acc038b96b30352aff5ff2bb"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:2bdfe3ac2e1bbe5b59a1a63721eb3b95fc9b6817ae4a46debbb4e11f6232428d"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:eab677309cdb30d047996b36d34caeda1dc91149e4fdca0b1a039b3f79d9a807"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-win32.whl", hash = "sha256:c0429126cf75e16c4f0ad00ee0eae4242dc652290f940152ca8c75c3a4b6ee8f"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:9f0b8b1c6d84c8034a44893aba5e767bf9c7a211e313a9605d9c617d7083829f"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8bfa33f4f2672964266e940dd22a195989ba31669bd84629f05fab3ef4e2d125"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28bf57629c75e810b6ae989f03c0828d64d6b26a5e205535585f96093e405ed1"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f08ff5e948271dc7e18a35641d2f11a4cd8dfd5634f55228b691e62b37125eb3"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:234ac59ea147c59ee4da87a0c0f098e9c8d169f4dc2a159ef720f1a61bbe27cd"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd4ec41f914fa74ad1b8304bbc634b3de73d2a0889bd32076342a573e0779e00"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eea6ee1db730b3483adf394ea72f808b6e18cf3cb6454b4d86e04fa8c4327a12"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c96836c97b1238e9c9e3fe90844c947d5afbf4f4c92762679acfe19927d81d77"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4d86f7aff21ee58f26dcf5ae81a9addbd914115cdebcbb2217e4f0ed8982e146"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:09b5e6733cbd160dcc09589227187e242a30a49ca5cefa5a7edd3f9d19ed53fd"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:5777ee0881f9499ed0f71cc82cf873d9a0ca8af166dfa0af8ec4e675b7df48e6"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:237bdbe6159cff53b4f24f397d43c6336c6b0b42affbe857970cefbb620911c8"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-win32.whl", hash = "sha256:8417cb1f36cc0bc7eaba8ccb0e04d55f0ee52df06df3ad55259b9a323555fc8b"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:d7f50a1f8c450f3925cb367d011448c39239bb3eb4117c36a6d354794de4ce76"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:73d94b58ec7fecbc7366247d3b0b10a21681004153238750bb67bd9012414545"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dad3e487649f498dd991eeb901125411559b22e8d7ab25d3aeb1af367df5efd7"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c30197aa96e8eed02200a83fba2657b4c3acd0f0aa4bdc9f6c1af8e8962e0757"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2369eea1ee4a7610a860d88f268eb39b95cb588acd7235e02fd5a5601773d4fa"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc2722592d8998c870fa4e290c2eec2c1569b87fe58618e67d38b4665dfa680d"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffc9202a29ab3920fa812879e95a9e78b2465fd10be7fcbd042899695d75e616"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:804a4d582ba6e5b747c625bf1255e6b1507465494a40a2130978bda7b932c90b"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0f55e69f030f7163dffe9fd0752b32f070566451afe180f99dbeeb81f511ad8d"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c4c3e6da02df6fa1410a7680bd3f63d4f710232d3139089536310d027950696a"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:5df196eb874dae23dcfb968c83d4f8fdccb333330fe1fc278ac5ceeb101003a9"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e358e64305fe12299a08e08978f51fc21fac060dcfcddd95453eabe5b93ed0e1"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-win32.whl", hash = "sha256:9b23ca7ef998bc739bf6ffc077c2116917eabcc901f88da1b9856b210ef63f35"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:6ff8a4a60c227ad87030d76e99cd1698345d4491638dfa6673027c48b3cd395f"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:aabfa34badd18f1da5ec1bc2715cadc8dca465868a4e73a0173466b688f29dda"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22e14b5d70560b8dd51ec22863f370d1e595ac3d024cb8ad7d308b4cd95f8313"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8436c508b408b82d87dc5f62496973a1805cd46727c34440b0d29d8a2f50a6c9"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d074908e1aecee37a7635990b2c6d504cd4766c7bc9fc86d63f9c09af3fa11b"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:955f8851919303c92343d2f66165294848d57e9bba6cf6e3625485a70a038d11"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:44ecbf16649486d4aebafeaa7ec4c9fed8b88101f4dd612dcaf65d5e815f837f"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0924e81d3d5e70f8126529951dac65c1010cdf117bb75eb02dd12339b57749dd"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2967f74ad52c3b98de4c3b32e1a44e32975e008a9cd2a8cc8966d6a5218c5cb2"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c75cb2a3e389853835e84a2d8fb2b81a10645b503eca9bcb98df6b5a43eb8886"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:09b26ae6b1abf0d27570633b2b078a2a20419c99d66fb2823173d73f188ce601"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fa88b843d6e211393a37219e6a1c1df99d35e8fd90446f1118f4216e307e48cd"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-win32.whl", hash = "sha256:eb8178fe3dba6450a3e024e95ac49ed3400e506fd4e9e5c32d30adda88cbd407"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:b1ac5992a838106edb89654e0aebfc24f5848ae2547d22c2c3f66454daa11971"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f30bf9fd9be89ecb2360c7d94a711f00c09b976258846efe40db3d05828e8089"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:97f68b8d6831127e4787ad15e6757232e14e12060bec17091b85eb1486b91d8d"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7974a0b5ecd505609e3b19742b60cee7aa2aa2fb3151bc917e6e2646d7667dcf"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc54db6c8593ef7d4b2a331b58653356cf04f67c960f584edb7c3d8c97e8f39e"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:311f30128d7d333eebd7896965bfcfbd0065f1716ec92bd5638d7748eb6f936a"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:7d053096f67cd1241601111b698f5cad775f97ab25d81567d3f59219b5f1adbd"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:807f52c1f798eef6cf26beb819eeb8819b1622ddfeef9d0977a8502d4db6d534"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:dccbe65bd2f7f7ec22c4ff99ed56faa1e9f785482b9bbd7c717e26fd723a1d1e"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:2fb9bd477fdea8684f78791a6de97a953c51831ee2981f8e4f583ff3b9d9687e"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:01732659ba9b5b873fc117534143e4feefecf3b2078b0a6a2e925271bb6f4cfa"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-win32.whl", hash = "sha256:7a4f97a081603d2050bfaffdefa5b02a9ec823f8348a572e39032caa8404a487"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:7b1bef6280950ee6c177b326508f86cad7ad4dff12454483b51d8b7d673a2c5d"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ecddf25bee22fe4fe3737a399d0d177d72bc22be6913acfab364b40bce1ba83c"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c60ca7339acd497a55b0ea5d506b2a2612afb2826560416f6894e8b5770d4a9"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b7b2d86dd06bfc2ade3312a83a5c364c7ec2e3498f8734282c6c3d4b07b346b8"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dd78cfcda14a1ef52584dbb008f7ac81c1328c0f58184bf9a84c49c605002da6"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e27f48bcd0957c6d4cb9d6fa6b61d192d0b13d5ef563e5f2ae35feafc0d179c"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:01ad647cdd609225c5350561d084b42ddf732f4eeefe6e678765636791e78b9a"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:619a609aa74ae43d90ed2e89bdd784765de0a25ca761b93e196d938b8fd1dbbd"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:89149166622f4db9b4b6a449256291dc87a99ee53151c74cbd82a53c8c2f6ccd"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:7709f51f5f7c853f0fb938bcd3bc59cdfdc5203635ffd18bf354f6967ea0f824"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:345b0426edd4e18138d6528aed636de7a9ed169b4aaf9d61a8c19e39d26838ca"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0907f11d019260cdc3f94fbdb23ff9125f6b5d1039b76003b5b0ac9d6a6c9d5b"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-win32.whl", hash = "sha256:ea0d8d539afa5eb2728aa1932a988a9a7af94f18582ffae4bc10b3fbdad0626e"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:329ce159e82018d646c7ac45b01a430369d526569ec08516081727a20e9e4af4"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b97e690a2118911e39b4042088092771b4ae3fc3aa86518f84b8cf6888dbdb41"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78baa6d91634dfb69ec52a463534bc0df05dbd546209b79a3880a34487f4b84f"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1a2bc9f351a75ef49d664206d51f8e5ede9da246602dc2d2726837620ea034b2"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:75832c08354f595c760a804588b9357d34ec00ba1c940c15e31e96d902093770"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0af291f4fe114be0280cdd29d533696a77b5b49cfde5467176ecab32353395c4"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0167ddc8ab6508fe81860a57dd472b2ef4060e8d378f0cc555707126830f2537"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2a75d49014d118e4198bcee5ee0a6f25856b29b12dbf7cd012791f8a6cc5c496"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:363e2f92b0f0174b2f8238240a1a30142e3db7b957a5dd5689b0e75fb717cc78"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ab36c8eb7e454e34e60eb55ca5d241a5d18b2c6244f6827a30e451c42410b5f7"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:4c0907b1928a36d5a998d72d64d8eaa7244989f7aaaf947500d3a800c83a3fd6"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:04432ad9479fa40ec0f387795ddad4437a2b50417c69fa275e212933519ff294"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-win32.whl", hash = "sha256:3bed14e9c89dcb10e8f3a29f9ccac4955aebe93c71ae803af79265c9ca5644c5"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:49402233c892a461407c512a19435d1ce275543138294f7ef013f0b63d5d3765"},
+ {file = "charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85"},
+ {file = "charset_normalizer-3.4.1.tar.gz", hash = "sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3"},
]

[[package]]
@@ -295,78 +296,84 @@ version = "0.4.6"
description = "Cross-platform colored terminal text."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
+groups = ["main", "dev"]
files = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
+markers = {main = "platform_system == \"Windows\"", dev = "sys_platform == \"win32\""}

[[package]]
name = "coverage"
-version = "7.2.7"
+version = "7.8.0"
description = "Code coverage measurement for Python"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["dev"]
files = [
- {file = "coverage-7.2.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d39b5b4f2a66ccae8b7263ac3c8170994b65266797fb96cbbfd3fb5b23921db8"},
- {file = "coverage-7.2.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d040ef7c9859bb11dfeb056ff5b3872436e3b5e401817d87a31e1750b9ae2fb"},
- {file = "coverage-7.2.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba90a9563ba44a72fda2e85302c3abc71c5589cea608ca16c22b9804262aaeb6"},
- {file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7d9405291c6928619403db1d10bd07888888ec1abcbd9748fdaa971d7d661b2"},
- {file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31563e97dae5598556600466ad9beea39fb04e0229e61c12eaa206e0aa202063"},
- {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ebba1cd308ef115925421d3e6a586e655ca5a77b5bf41e02eb0e4562a111f2d1"},
- {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:cb017fd1b2603ef59e374ba2063f593abe0fc45f2ad9abdde5b4d83bd922a353"},
- {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62a5c7dad11015c66fbb9d881bc4caa5b12f16292f857842d9d1871595f4495"},
- {file = "coverage-7.2.7-cp310-cp310-win32.whl", hash = "sha256:ee57190f24fba796e36bb6d3aa8a8783c643d8fa9760c89f7a98ab5455fbf818"},
- {file = "coverage-7.2.7-cp310-cp310-win_amd64.whl", hash = "sha256:f75f7168ab25dd93110c8a8117a22450c19976afbc44234cbf71481094c1b850"},
- {file = "coverage-7.2.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:06a9a2be0b5b576c3f18f1a241f0473575c4a26021b52b2a85263a00f034d51f"},
- {file = "coverage-7.2.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5baa06420f837184130752b7c5ea0808762083bf3487b5038d68b012e5937dbe"},
- {file = "coverage-7.2.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdec9e8cbf13a5bf63290fc6013d216a4c7232efb51548594ca3631a7f13c3a3"},
- {file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:52edc1a60c0d34afa421c9c37078817b2e67a392cab17d97283b64c5833f427f"},
- {file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63426706118b7f5cf6bb6c895dc215d8a418d5952544042c8a2d9fe87fcf09cb"},
- {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:afb17f84d56068a7c29f5fa37bfd38d5aba69e3304af08ee94da8ed5b0865833"},
- {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:48c19d2159d433ccc99e729ceae7d5293fbffa0bdb94952d3579983d1c8c9d97"},
- {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0e1f928eaf5469c11e886fe0885ad2bf1ec606434e79842a879277895a50942a"},
- {file = "coverage-7.2.7-cp311-cp311-win32.whl", hash = "sha256:33d6d3ea29d5b3a1a632b3c4e4f4ecae24ef170b0b9ee493883f2df10039959a"},
- {file = "coverage-7.2.7-cp311-cp311-win_amd64.whl", hash = "sha256:5b7540161790b2f28143191f5f8ec02fb132660ff175b7747b95dcb77ac26562"},
- {file = "coverage-7.2.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f2f67fe12b22cd130d34d0ef79206061bfb5eda52feb6ce0dba0644e20a03cf4"},
- {file = "coverage-7.2.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a342242fe22407f3c17f4b499276a02b01e80f861f1682ad1d95b04018e0c0d4"},
- {file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:171717c7cb6b453aebac9a2ef603699da237f341b38eebfee9be75d27dc38e01"},
- {file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49969a9f7ffa086d973d91cec8d2e31080436ef0fb4a359cae927e742abfaaa6"},
- {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b46517c02ccd08092f4fa99f24c3b83d8f92f739b4657b0f146246a0ca6a831d"},
- {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:a3d33a6b3eae87ceaefa91ffdc130b5e8536182cd6dfdbfc1aa56b46ff8c86de"},
- {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:976b9c42fb2a43ebf304fa7d4a310e5f16cc99992f33eced91ef6f908bd8f33d"},
- {file = "coverage-7.2.7-cp312-cp312-win32.whl", hash = "sha256:8de8bb0e5ad103888d65abef8bca41ab93721647590a3f740100cd65c3b00511"},
- {file = "coverage-7.2.7-cp312-cp312-win_amd64.whl", hash = "sha256:9e31cb64d7de6b6f09702bb27c02d1904b3aebfca610c12772452c4e6c21a0d3"},
- {file = "coverage-7.2.7-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:58c2ccc2f00ecb51253cbe5d8d7122a34590fac9646a960d1430d5b15321d95f"},
- {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d22656368f0e6189e24722214ed8d66b8022db19d182927b9a248a2a8a2f67eb"},
- {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a895fcc7b15c3fc72beb43cdcbdf0ddb7d2ebc959edac9cef390b0d14f39f8a9"},
- {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84606b74eb7de6ff581a7915e2dab7a28a0517fbe1c9239eb227e1354064dcd"},
- {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0a5f9e1dbd7fbe30196578ca36f3fba75376fb99888c395c5880b355e2875f8a"},
- {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:419bfd2caae268623dd469eff96d510a920c90928b60f2073d79f8fe2bbc5959"},
- {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2aee274c46590717f38ae5e4650988d1af340fe06167546cc32fe2f58ed05b02"},
- {file = "coverage-7.2.7-cp37-cp37m-win32.whl", hash = "sha256:61b9a528fb348373c433e8966535074b802c7a5d7f23c4f421e6c6e2f1697a6f"},
- {file = "coverage-7.2.7-cp37-cp37m-win_amd64.whl", hash = "sha256:b1c546aca0ca4d028901d825015dc8e4d56aac4b541877690eb76490f1dc8ed0"},
- {file = "coverage-7.2.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:54b896376ab563bd38453cecb813c295cf347cf5906e8b41d340b0321a5433e5"},
- {file = "coverage-7.2.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3d376df58cc111dc8e21e3b6e24606b5bb5dee6024f46a5abca99124b2229ef5"},
- {file = "coverage-7.2.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e330fc79bd7207e46c7d7fd2bb4af2963f5f635703925543a70b99574b0fea9"},
- {file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e9d683426464e4a252bf70c3498756055016f99ddaec3774bf368e76bbe02b6"},
- {file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d13c64ee2d33eccf7437961b6ea7ad8673e2be040b4f7fd4fd4d4d28d9ccb1e"},
- {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b7aa5f8a41217360e600da646004f878250a0d6738bcdc11a0a39928d7dc2050"},
- {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fa03bce9bfbeeef9f3b160a8bed39a221d82308b4152b27d82d8daa7041fee5"},
- {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:245167dd26180ab4c91d5e1496a30be4cd721a5cf2abf52974f965f10f11419f"},
- {file = "coverage-7.2.7-cp38-cp38-win32.whl", hash = "sha256:d2c2db7fd82e9b72937969bceac4d6ca89660db0a0967614ce2481e81a0b771e"},
- {file = "coverage-7.2.7-cp38-cp38-win_amd64.whl", hash = "sha256:2e07b54284e381531c87f785f613b833569c14ecacdcb85d56b25c4622c16c3c"},
- {file = "coverage-7.2.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:537891ae8ce59ef63d0123f7ac9e2ae0fc8b72c7ccbe5296fec45fd68967b6c9"},
- {file = "coverage-7.2.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:06fb182e69f33f6cd1d39a6c597294cff3143554b64b9825d1dc69d18cc2fff2"},
- {file = "coverage-7.2.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:201e7389591af40950a6480bd9edfa8ed04346ff80002cec1a66cac4549c1ad7"},
- {file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6951407391b639504e3b3be51b7ba5f3528adbf1a8ac3302b687ecababf929e"},
- {file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f48351d66575f535669306aa7d6d6f71bc43372473b54a832222803eb956fd1"},
- {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b29019c76039dc3c0fd815c41392a044ce555d9bcdd38b0fb60fb4cd8e475ba9"},
- {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:81c13a1fc7468c40f13420732805a4c38a105d89848b7c10af65a90beff25250"},
- {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:975d70ab7e3c80a3fe86001d8751f6778905ec723f5b110aed1e450da9d4b7f2"},
- {file = "coverage-7.2.7-cp39-cp39-win32.whl", hash = "sha256:7ee7d9d4822c8acc74a5e26c50604dff824710bc8de424904c0982e25c39c6cb"},
- {file = "coverage-7.2.7-cp39-cp39-win_amd64.whl", hash = "sha256:eb393e5ebc85245347950143969b241d08b52b88a3dc39479822e073a1a8eb27"},
- {file = "coverage-7.2.7-pp37.pp38.pp39-none-any.whl", hash = "sha256:b7b4c971f05e6ae490fef852c218b0e79d4e52f79ef0c8475566584a8fb3e01d"},
- {file = "coverage-7.2.7.tar.gz", hash = "sha256:924d94291ca674905fe9481f12294eb11f2d3d3fd1adb20314ba89e94f44ed59"},
+ {file = "coverage-7.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2931f66991175369859b5fd58529cd4b73582461877ecfd859b6549869287ffe"},
+ {file = "coverage-7.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:52a523153c568d2c0ef8826f6cc23031dc86cffb8c6aeab92c4ff776e7951b28"},
+ {file = "coverage-7.8.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c8a5c139aae4c35cbd7cadca1df02ea8cf28a911534fc1b0456acb0b14234f3"},
+ {file = "coverage-7.8.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a26c0c795c3e0b63ec7da6efded5f0bc856d7c0b24b2ac84b4d1d7bc578d676"},
+ {file = "coverage-7.8.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:821f7bcbaa84318287115d54becb1915eece6918136c6f91045bb84e2f88739d"},
+ {file = "coverage-7.8.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a321c61477ff8ee705b8a5fed370b5710c56b3a52d17b983d9215861e37b642a"},
+ {file = "coverage-7.8.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:ed2144b8a78f9d94d9515963ed273d620e07846acd5d4b0a642d4849e8d91a0c"},
+ {file = "coverage-7.8.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:042e7841a26498fff7a37d6fda770d17519982f5b7d8bf5278d140b67b61095f"},
+ {file = "coverage-7.8.0-cp310-cp310-win32.whl", hash = "sha256:f9983d01d7705b2d1f7a95e10bbe4091fabc03a46881a256c2787637b087003f"},
+ {file = "coverage-7.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:5a570cd9bd20b85d1a0d7b009aaf6c110b52b5755c17be6962f8ccd65d1dbd23"},
+ {file = "coverage-7.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e7ac22a0bb2c7c49f441f7a6d46c9c80d96e56f5a8bc6972529ed43c8b694e27"},
+ {file = "coverage-7.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bf13d564d310c156d1c8e53877baf2993fb3073b2fc9f69790ca6a732eb4bfea"},
+ {file = "coverage-7.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5761c70c017c1b0d21b0815a920ffb94a670c8d5d409d9b38857874c21f70d7"},
+ {file = "coverage-7.8.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5ff52d790c7e1628241ffbcaeb33e07d14b007b6eb00a19320c7b8a7024c040"},
+ {file = "coverage-7.8.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d39fc4817fd67b3915256af5dda75fd4ee10621a3d484524487e33416c6f3543"},
+ {file = "coverage-7.8.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b44674870709017e4b4036e3d0d6c17f06a0e6d4436422e0ad29b882c40697d2"},
+ {file = "coverage-7.8.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8f99eb72bf27cbb167b636eb1726f590c00e1ad375002230607a844d9e9a2318"},
+ {file = "coverage-7.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b571bf5341ba8c6bc02e0baeaf3b061ab993bf372d982ae509807e7f112554e9"},
+ {file = "coverage-7.8.0-cp311-cp311-win32.whl", hash = "sha256:e75a2ad7b647fd8046d58c3132d7eaf31b12d8a53c0e4b21fa9c4d23d6ee6d3c"},
+ {file = "coverage-7.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:3043ba1c88b2139126fc72cb48574b90e2e0546d4c78b5299317f61b7f718b78"},
+ {file = "coverage-7.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:bbb5cc845a0292e0c520656d19d7ce40e18d0e19b22cb3e0409135a575bf79fc"},
+ {file = "coverage-7.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4dfd9a93db9e78666d178d4f08a5408aa3f2474ad4d0e0378ed5f2ef71640cb6"},
+ {file = "coverage-7.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f017a61399f13aa6d1039f75cd467be388d157cd81f1a119b9d9a68ba6f2830d"},
+ {file = "coverage-7.8.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0915742f4c82208ebf47a2b154a5334155ed9ef9fe6190674b8a46c2fb89cb05"},
+ {file = "coverage-7.8.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a40fcf208e021eb14b0fac6bdb045c0e0cab53105f93ba0d03fd934c956143a"},
+ {file = "coverage-7.8.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a1f406a8e0995d654b2ad87c62caf6befa767885301f3b8f6f73e6f3c31ec3a6"},
+ {file = "coverage-7.8.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:77af0f6447a582fdc7de5e06fa3757a3ef87769fbb0fdbdeba78c23049140a47"},
+ {file = "coverage-7.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f2d32f95922927186c6dbc8bc60df0d186b6edb828d299ab10898ef3f40052fe"},
+ {file = "coverage-7.8.0-cp312-cp312-win32.whl", hash = "sha256:769773614e676f9d8e8a0980dd7740f09a6ea386d0f383db6821df07d0f08545"},
+ {file = "coverage-7.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:e5d2b9be5b0693cf21eb4ce0ec8d211efb43966f6657807f6859aab3814f946b"},
+ {file = "coverage-7.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5ac46d0c2dd5820ce93943a501ac5f6548ea81594777ca585bf002aa8854cacd"},
+ {file = "coverage-7.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:771eb7587a0563ca5bb6f622b9ed7f9d07bd08900f7589b4febff05f469bea00"},
+ {file = "coverage-7.8.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42421e04069fb2cbcbca5a696c4050b84a43b05392679d4068acbe65449b5c64"},
+ {file = "coverage-7.8.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:554fec1199d93ab30adaa751db68acec2b41c5602ac944bb19187cb9a41a8067"},
+ {file = "coverage-7.8.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aaeb00761f985007b38cf463b1d160a14a22c34eb3f6a39d9ad6fc27cb73008"},
+ {file = "coverage-7.8.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:581a40c7b94921fffd6457ffe532259813fc68eb2bdda60fa8cc343414ce3733"},
+ {file = "coverage-7.8.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f319bae0321bc838e205bf9e5bc28f0a3165f30c203b610f17ab5552cff90323"},
+ {file = "coverage-7.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04bfec25a8ef1c5f41f5e7e5c842f6b615599ca8ba8391ec33a9290d9d2db3a3"},
+ {file = "coverage-7.8.0-cp313-cp313-win32.whl", hash = "sha256:dd19608788b50eed889e13a5d71d832edc34fc9dfce606f66e8f9f917eef910d"},
+ {file = "coverage-7.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:a9abbccd778d98e9c7e85038e35e91e67f5b520776781d9a1e2ee9d400869487"},
+ {file = "coverage-7.8.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:18c5ae6d061ad5b3e7eef4363fb27a0576012a7447af48be6c75b88494c6cf25"},
+ {file = "coverage-7.8.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:95aa6ae391a22bbbce1b77ddac846c98c5473de0372ba5c463480043a07bff42"},
+ {file = "coverage-7.8.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e013b07ba1c748dacc2a80e69a46286ff145935f260eb8c72df7185bf048f502"},
+ {file = "coverage-7.8.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d766a4f0e5aa1ba056ec3496243150698dc0481902e2b8559314368717be82b1"},
+ {file = "coverage-7.8.0-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad80e6b4a0c3cb6f10f29ae4c60e991f424e6b14219d46f1e7d442b938ee68a4"},
+ {file = "coverage-7.8.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b87eb6fc9e1bb8f98892a2458781348fa37e6925f35bb6ceb9d4afd54ba36c73"},
+ {file = "coverage-7.8.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:d1ba00ae33be84066cfbe7361d4e04dec78445b2b88bdb734d0d1cbab916025a"},
+ {file = "coverage-7.8.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f3c38e4e5ccbdc9198aecc766cedbb134b2d89bf64533973678dfcf07effd883"},
+ {file = "coverage-7.8.0-cp313-cp313t-win32.whl", hash = "sha256:379fe315e206b14e21db5240f89dc0774bdd3e25c3c58c2c733c99eca96f1ada"},
+ {file = "coverage-7.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:2e4b6b87bb0c846a9315e3ab4be2d52fac905100565f4b92f02c445c8799e257"},
+ {file = "coverage-7.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fa260de59dfb143af06dcf30c2be0b200bed2a73737a8a59248fcb9fa601ef0f"},
+ {file = "coverage-7.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:96121edfa4c2dfdda409877ea8608dd01de816a4dc4a0523356067b305e4e17a"},
+ {file = "coverage-7.8.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b8af63b9afa1031c0ef05b217faa598f3069148eeee6bb24b79da9012423b82"},
+ {file = "coverage-7.8.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89b1f4af0d4afe495cd4787a68e00f30f1d15939f550e869de90a86efa7e0814"},
+ {file = "coverage-7.8.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94ec0be97723ae72d63d3aa41961a0b9a6f5a53ff599813c324548d18e3b9e8c"},
+ {file = "coverage-7.8.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8a1d96e780bdb2d0cbb297325711701f7c0b6f89199a57f2049e90064c29f6bd"},
+ {file = "coverage-7.8.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:f1d8a2a57b47142b10374902777e798784abf400a004b14f1b0b9eaf1e528ba4"},
+ {file = "coverage-7.8.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:cf60dd2696b457b710dd40bf17ad269d5f5457b96442f7f85722bdb16fa6c899"},
+ {file = "coverage-7.8.0-cp39-cp39-win32.whl", hash = "sha256:be945402e03de47ba1872cd5236395e0f4ad635526185a930735f66710e1bd3f"},
+ {file = "coverage-7.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:90e7fbc6216ecaffa5a880cdc9c77b7418c1dcb166166b78dbc630d07f278cc3"},
+ {file = "coverage-7.8.0-pp39.pp310.pp311-none-any.whl", hash = "sha256:b8194fb8e50d556d5849753de991d390c5a1edeeba50f68e3a9253fbd8bf8ccd"},
+ {file = "coverage-7.8.0-py3-none-any.whl", hash = "sha256:dbf364b4c5e7bae9250528167dfe40219b62e2d573c854d74be213e1e52069f7"},
+ {file = "coverage-7.8.0.tar.gz", hash = "sha256:7a3d62b3b03b4b6fd41a085f3574874cf946cb4604d2b4d3e8dca8cd570ca501"},
]

[package.dependencies]
@@ -375,113 +382,151 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1
[package.extras]
toml = ["tomli"]
+[[package]]
+name = "exceptiongroup"
+version = "1.2.2"
+description = "Backport of PEP 654 (exception groups)"
+optional = false
+python-versions = ">=3.7"
+groups = ["dev"]
+markers = "python_version < \"3.11\""
+files = [
+ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
+ {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
+]
+
+[package.extras]
+test = ["pytest (>=6)"]
+
[[package]]
name = "filelock"
-version = "3.12.2"
+version = "3.18.0"
description = "A platform independent file lock."
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "filelock-3.12.2-py3-none-any.whl", hash = "sha256:cbb791cdea2a72f23da6ac5b5269ab0a0d161e9ef0100e653b69049a7706d1ec"},
- {file = "filelock-3.12.2.tar.gz", hash = "sha256:002740518d8aa59a26b0c76e10fb8c6e15eae825d34b6fdf670333fd7b938d81"},
+ {file = "filelock-3.18.0-py3-none-any.whl", hash = "sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de"},
+ {file = "filelock-3.18.0.tar.gz", hash = "sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2"},
]

[package.extras]
-docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"]
-testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"]
+docs = ["furo (>=2024.8.6)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"]
+testing = ["covdefaults (>=2.3)", "coverage (>=7.6.10)", "diff-cover (>=9.2.1)", "pytest (>=8.3.4)", "pytest-asyncio (>=0.25.2)", "pytest-cov (>=6)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.28.1)"]
+typing = ["typing-extensions (>=4.12.2)"]
[[package]]
name = "frozenlist"
-version = "1.3.3"
+version = "1.5.0"
description = "A list-like structure which implements collections.abc.MutableSequence"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["main"]
files = [
- {file = "frozenlist-1.3.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff8bf625fe85e119553b5383ba0fb6aa3d0ec2ae980295aaefa552374926b3f4"},
- {file = "frozenlist-1.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dfbac4c2dfcc082fcf8d942d1e49b6aa0766c19d3358bd86e2000bf0fa4a9cf0"},
- {file = "frozenlist-1.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b1c63e8d377d039ac769cd0926558bb7068a1f7abb0f003e3717ee003ad85530"},
- {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7fdfc24dcfce5b48109867c13b4cb15e4660e7bd7661741a391f821f23dfdca7"},
- {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c926450857408e42f0bbc295e84395722ce74bae69a3b2aa2a65fe22cb14b99"},
- {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1841e200fdafc3d51f974d9d377c079a0694a8f06de2e67b48150328d66d5483"},
- {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f470c92737afa7d4c3aacc001e335062d582053d4dbe73cda126f2d7031068dd"},
- {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:783263a4eaad7c49983fe4b2e7b53fa9770c136c270d2d4bbb6d2192bf4d9caf"},
- {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:924620eef691990dfb56dc4709f280f40baee568c794b5c1885800c3ecc69816"},
- {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ae4dc05c465a08a866b7a1baf360747078b362e6a6dbeb0c57f234db0ef88ae0"},
- {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:bed331fe18f58d844d39ceb398b77d6ac0b010d571cba8267c2e7165806b00ce"},
- {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:02c9ac843e3390826a265e331105efeab489ffaf4dd86384595ee8ce6d35ae7f"},
- {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9545a33965d0d377b0bc823dcabf26980e77f1b6a7caa368a365a9497fb09420"},
- {file = "frozenlist-1.3.3-cp310-cp310-win32.whl", hash = "sha256:d5cd3ab21acbdb414bb6c31958d7b06b85eeb40f66463c264a9b343a4e238642"},
- {file = "frozenlist-1.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:b756072364347cb6aa5b60f9bc18e94b2f79632de3b0190253ad770c5df17db1"},
- {file = "frozenlist-1.3.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:b4395e2f8d83fbe0c627b2b696acce67868793d7d9750e90e39592b3626691b7"},
- {file = "frozenlist-1.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:14143ae966a6229350021384870458e4777d1eae4c28d1a7aa47f24d030e6678"},
- {file = "frozenlist-1.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5d8860749e813a6f65bad8285a0520607c9500caa23fea6ee407e63debcdbef6"},
- {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23d16d9f477bb55b6154654e0e74557040575d9d19fe78a161bd33d7d76808e8"},
- {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb82dbba47a8318e75f679690190c10a5e1f447fbf9df41cbc4c3afd726d88cb"},
- {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9309869032abb23d196cb4e4db574232abe8b8be1339026f489eeb34a4acfd91"},
- {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a97b4fe50b5890d36300820abd305694cb865ddb7885049587a5678215782a6b"},
- {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c188512b43542b1e91cadc3c6c915a82a5eb95929134faf7fd109f14f9892ce4"},
- {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:303e04d422e9b911a09ad499b0368dc551e8c3cd15293c99160c7f1f07b59a48"},
- {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0771aed7f596c7d73444c847a1c16288937ef988dc04fb9f7be4b2aa91db609d"},
- {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:66080ec69883597e4d026f2f71a231a1ee9887835902dbe6b6467d5a89216cf6"},
- {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:41fe21dc74ad3a779c3d73a2786bdf622ea81234bdd4faf90b8b03cad0c2c0b4"},
- {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f20380df709d91525e4bee04746ba612a4df0972c1b8f8e1e8af997e678c7b81"},
- {file = "frozenlist-1.3.3-cp311-cp311-win32.whl", hash = "sha256:f30f1928162e189091cf4d9da2eac617bfe78ef907a761614ff577ef4edfb3c8"},
- {file = "frozenlist-1.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:a6394d7dadd3cfe3f4b3b186e54d5d8504d44f2d58dcc89d693698e8b7132b32"},
- {file = "frozenlist-1.3.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8df3de3a9ab8325f94f646609a66cbeeede263910c5c0de0101079ad541af332"},
- {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0693c609e9742c66ba4870bcee1ad5ff35462d5ffec18710b4ac89337ff16e27"},
- {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd4210baef299717db0a600d7a3cac81d46ef0e007f88c9335db79f8979c0d3d"},
- {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:394c9c242113bfb4b9aa36e2b80a05ffa163a30691c7b5a29eba82e937895d5e"},
- {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6327eb8e419f7d9c38f333cde41b9ae348bec26d840927332f17e887a8dcb70d"},
- {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e24900aa13212e75e5b366cb9065e78bbf3893d4baab6052d1aca10d46d944c"},
- {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3843f84a6c465a36559161e6c59dce2f2ac10943040c2fd021cfb70d58c4ad56"},
- {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:84610c1502b2461255b4c9b7d5e9c48052601a8957cd0aea6ec7a7a1e1fb9420"},
- {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:c21b9aa40e08e4f63a2f92ff3748e6b6c84d717d033c7b3438dd3123ee18f70e"},
- {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:efce6ae830831ab6a22b9b4091d411698145cb9b8fc869e1397ccf4b4b6455cb"},
- {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:40de71985e9042ca00b7953c4f41eabc3dc514a2d1ff534027f091bc74416401"},
- {file = "frozenlist-1.3.3-cp37-cp37m-win32.whl", hash = "sha256:180c00c66bde6146a860cbb81b54ee0df350d2daf13ca85b275123bbf85de18a"},
- {file = "frozenlist-1.3.3-cp37-cp37m-win_amd64.whl", hash = "sha256:9bbbcedd75acdfecf2159663b87f1bb5cfc80e7cd99f7ddd9d66eb98b14a8411"},
- {file = "frozenlist-1.3.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:034a5c08d36649591be1cbb10e09da9f531034acfe29275fc5454a3b101ce41a"},
- {file = "frozenlist-1.3.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ba64dc2b3b7b158c6660d49cdb1d872d1d0bf4e42043ad8d5006099479a194e5"},
- {file = "frozenlist-1.3.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:47df36a9fe24054b950bbc2db630d508cca3aa27ed0566c0baf661225e52c18e"},
- {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:008a054b75d77c995ea26629ab3a0c0d7281341f2fa7e1e85fa6153ae29ae99c"},
- {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:841ea19b43d438a80b4de62ac6ab21cfe6827bb8a9dc62b896acc88eaf9cecba"},
- {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e235688f42b36be2b6b06fc37ac2126a73b75fb8d6bc66dd632aa35286238703"},
- {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca713d4af15bae6e5d79b15c10c8522859a9a89d3b361a50b817c98c2fb402a2"},
- {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ac5995f2b408017b0be26d4a1d7c61bce106ff3d9e3324374d66b5964325448"},
- {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a4ae8135b11652b08a8baf07631d3ebfe65a4c87909dbef5fa0cdde440444ee4"},
- {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4ea42116ceb6bb16dbb7d526e242cb6747b08b7710d9782aa3d6732bd8d27649"},
- {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:810860bb4bdce7557bc0febb84bbd88198b9dbc2022d8eebe5b3590b2ad6c842"},
- {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:ee78feb9d293c323b59a6f2dd441b63339a30edf35abcb51187d2fc26e696d13"},
- {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0af2e7c87d35b38732e810befb9d797a99279cbb85374d42ea61c1e9d23094b3"},
- {file = "frozenlist-1.3.3-cp38-cp38-win32.whl", hash = "sha256:899c5e1928eec13fd6f6d8dc51be23f0d09c5281e40d9cf4273d188d9feeaf9b"},
- {file = "frozenlist-1.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:7f44e24fa70f6fbc74aeec3e971f60a14dde85da364aa87f15d1be94ae75aeef"},
- {file = "frozenlist-1.3.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:2b07ae0c1edaa0a36339ec6cce700f51b14a3fc6545fdd32930d2c83917332cf"},
- {file = "frozenlist-1.3.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ebb86518203e12e96af765ee89034a1dbb0c3c65052d1b0c19bbbd6af8a145e1"},
- {file = "frozenlist-1.3.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5cf820485f1b4c91e0417ea0afd41ce5cf5965011b3c22c400f6d144296ccbc0"},
- {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c11e43016b9024240212d2a65043b70ed8dfd3b52678a1271972702d990ac6d"},
- {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8fa3c6e3305aa1146b59a09b32b2e04074945ffcfb2f0931836d103a2c38f936"},
- {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:352bd4c8c72d508778cf05ab491f6ef36149f4d0cb3c56b1b4302852255d05d5"},
- {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:65a5e4d3aa679610ac6e3569e865425b23b372277f89b5ef06cf2cdaf1ebf22b"},
- {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e2c1185858d7e10ff045c496bbf90ae752c28b365fef2c09cf0fa309291669"},
- {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f163d2fd041c630fed01bc48d28c3ed4a3b003c00acd396900e11ee5316b56bb"},
- {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:05cdb16d09a0832eedf770cb7bd1fe57d8cf4eaf5aced29c4e41e3f20b30a784"},
- {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:8bae29d60768bfa8fb92244b74502b18fae55a80eac13c88eb0b496d4268fd2d"},
- {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:eedab4c310c0299961ac285591acd53dc6723a1ebd90a57207c71f6e0c2153ab"},
- {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3bbdf44855ed8f0fbcd102ef05ec3012d6a4fd7c7562403f76ce6a52aeffb2b1"},
- {file = "frozenlist-1.3.3-cp39-cp39-win32.whl", hash = "sha256:efa568b885bca461f7c7b9e032655c0c143d305bf01c30caf6db2854a4532b38"},
- {file = "frozenlist-1.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:cfe33efc9cb900a4c46f91a5ceba26d6df370ffddd9ca386eb1d4f0ad97b9ea9"},
- {file = "frozenlist-1.3.3.tar.gz", hash = "sha256:58bcc55721e8a90b88332d6cd441261ebb22342e238296bb330968952fbb3a6a"},
+ {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5b6a66c18b5b9dd261ca98dffcb826a525334b2f29e7caa54e182255c5f6a65a"},
+ {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d1b3eb7b05ea246510b43a7e53ed1653e55c2121019a97e60cad7efb881a97bb"},
+ {file = "frozenlist-1.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:15538c0cbf0e4fa11d1e3a71f823524b0c46299aed6e10ebb4c2089abd8c3bec"},
+ {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e79225373c317ff1e35f210dd5f1344ff31066ba8067c307ab60254cd3a78ad5"},
+ {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9272fa73ca71266702c4c3e2d4a28553ea03418e591e377a03b8e3659d94fa76"},
+ {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:498524025a5b8ba81695761d78c8dd7382ac0b052f34e66939c42df860b8ff17"},
+ {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:92b5278ed9d50fe610185ecd23c55d8b307d75ca18e94c0e7de328089ac5dcba"},
+ {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f3c8c1dacd037df16e85227bac13cca58c30da836c6f936ba1df0c05d046d8d"},
+ {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f2ac49a9bedb996086057b75bf93538240538c6d9b38e57c82d51f75a73409d2"},
+ {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e66cc454f97053b79c2ab09c17fbe3c825ea6b4de20baf1be28919460dd7877f"},
+ {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:5a3ba5f9a0dfed20337d3e966dc359784c9f96503674c2faf015f7fe8e96798c"},
+ {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:6321899477db90bdeb9299ac3627a6a53c7399c8cd58d25da094007402b039ab"},
+ {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:76e4753701248476e6286f2ef492af900ea67d9706a0155335a40ea21bf3b2f5"},
+ {file = "frozenlist-1.5.0-cp310-cp310-win32.whl", hash = "sha256:977701c081c0241d0955c9586ffdd9ce44f7a7795df39b9151cd9a6fd0ce4cfb"},
+ {file = "frozenlist-1.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:189f03b53e64144f90990d29a27ec4f7997d91ed3d01b51fa39d2dbe77540fd4"},
+ {file = "frozenlist-1.5.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:fd74520371c3c4175142d02a976aee0b4cb4a7cc912a60586ffd8d5929979b30"},
+ {file = "frozenlist-1.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2f3f7a0fbc219fb4455264cae4d9f01ad41ae6ee8524500f381de64ffaa077d5"},
+ {file = "frozenlist-1.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f47c9c9028f55a04ac254346e92977bf0f166c483c74b4232bee19a6697e4778"},
+ {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0996c66760924da6e88922756d99b47512a71cfd45215f3570bf1e0b694c206a"},
+ {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2fe128eb4edeabe11896cb6af88fca5346059f6c8d807e3b910069f39157869"},
+ {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1a8ea951bbb6cacd492e3948b8da8c502a3f814f5d20935aae74b5df2b19cf3d"},
+ {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de537c11e4aa01d37db0d403b57bd6f0546e71a82347a97c6a9f0dcc532b3a45"},
+ {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c2623347b933fcb9095841f1cc5d4ff0b278addd743e0e966cb3d460278840d"},
+ {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cee6798eaf8b1416ef6909b06f7dc04b60755206bddc599f52232606e18179d3"},
+ {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f5f9da7f5dbc00a604fe74aa02ae7c98bcede8a3b8b9666f9f86fc13993bc71a"},
+ {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:90646abbc7a5d5c7c19461d2e3eeb76eb0b204919e6ece342feb6032c9325ae9"},
+ {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:bdac3c7d9b705d253b2ce370fde941836a5f8b3c5c2b8fd70940a3ea3af7f4f2"},
+ {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03d33c2ddbc1816237a67f66336616416e2bbb6beb306e5f890f2eb22b959cdf"},
+ {file = "frozenlist-1.5.0-cp311-cp311-win32.whl", hash = "sha256:237f6b23ee0f44066219dae14c70ae38a63f0440ce6750f868ee08775073f942"},
+ {file = "frozenlist-1.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:0cc974cc93d32c42e7b0f6cf242a6bd941c57c61b618e78b6c0a96cb72788c1d"},
+ {file = "frozenlist-1.5.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:31115ba75889723431aa9a4e77d5f398f5cf976eea3bdf61749731f62d4a4a21"},
+ {file = "frozenlist-1.5.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7437601c4d89d070eac8323f121fcf25f88674627505334654fd027b091db09d"},
+ {file = "frozenlist-1.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7948140d9f8ece1745be806f2bfdf390127cf1a763b925c4a805c603df5e697e"},
+ {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feeb64bc9bcc6b45c6311c9e9b99406660a9c05ca8a5b30d14a78555088b0b3a"},
+ {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:683173d371daad49cffb8309779e886e59c2f369430ad28fe715f66d08d4ab1a"},
+ {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7d57d8f702221405a9d9b40f9da8ac2e4a1a8b5285aac6100f3393675f0a85ee"},
+ {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30c72000fbcc35b129cb09956836c7d7abf78ab5416595e4857d1cae8d6251a6"},
+ {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:000a77d6034fbad9b6bb880f7ec073027908f1b40254b5d6f26210d2dab1240e"},
+ {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5d7f5a50342475962eb18b740f3beecc685a15b52c91f7d975257e13e029eca9"},
+ {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:87f724d055eb4785d9be84e9ebf0f24e392ddfad00b3fe036e43f489fafc9039"},
+ {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:6e9080bb2fb195a046e5177f10d9d82b8a204c0736a97a153c2466127de87784"},
+ {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:9b93d7aaa36c966fa42efcaf716e6b3900438632a626fb09c049f6a2f09fc631"},
+ {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:52ef692a4bc60a6dd57f507429636c2af8b6046db8b31b18dac02cbc8f507f7f"},
+ {file = "frozenlist-1.5.0-cp312-cp312-win32.whl", hash = "sha256:29d94c256679247b33a3dc96cce0f93cbc69c23bf75ff715919332fdbb6a32b8"},
+ {file = "frozenlist-1.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:8969190d709e7c48ea386db202d708eb94bdb29207a1f269bab1196ce0dcca1f"},
+ {file = "frozenlist-1.5.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7a1a048f9215c90973402e26c01d1cff8a209e1f1b53f72b95c13db61b00f953"},
+ {file = "frozenlist-1.5.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:dd47a5181ce5fcb463b5d9e17ecfdb02b678cca31280639255ce9d0e5aa67af0"},
+ {file = "frozenlist-1.5.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1431d60b36d15cda188ea222033eec8e0eab488f39a272461f2e6d9e1a8e63c2"},
+ {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6482a5851f5d72767fbd0e507e80737f9c8646ae7fd303def99bfe813f76cf7f"},
+ {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:44c49271a937625619e862baacbd037a7ef86dd1ee215afc298a417ff3270608"},
+ {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:12f78f98c2f1c2429d42e6a485f433722b0061d5c0b0139efa64f396efb5886b"},
+ {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce3aa154c452d2467487765e3adc730a8c153af77ad84096bc19ce19a2400840"},
+ {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b7dc0c4338e6b8b091e8faf0db3168a37101943e687f373dce00959583f7439"},
+ {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:45e0896250900b5aa25180f9aec243e84e92ac84bd4a74d9ad4138ef3f5c97de"},
+ {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:561eb1c9579d495fddb6da8959fd2a1fca2c6d060d4113f5844b433fc02f2641"},
+ {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:df6e2f325bfee1f49f81aaac97d2aa757c7646534a06f8f577ce184afe2f0a9e"},
+ {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:140228863501b44b809fb39ec56b5d4071f4d0aa6d216c19cbb08b8c5a7eadb9"},
+ {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7707a25d6a77f5d27ea7dc7d1fc608aa0a478193823f88511ef5e6b8a48f9d03"},
+ {file = "frozenlist-1.5.0-cp313-cp313-win32.whl", hash = "sha256:31a9ac2b38ab9b5a8933b693db4939764ad3f299fcaa931a3e605bc3460e693c"},
+ {file = "frozenlist-1.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:11aabdd62b8b9c4b84081a3c246506d1cddd2dd93ff0ad53ede5defec7886b28"},
+ {file = "frozenlist-1.5.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:dd94994fc91a6177bfaafd7d9fd951bc8689b0a98168aa26b5f543868548d3ca"},
+ {file = "frozenlist-1.5.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2d0da8bbec082bf6bf18345b180958775363588678f64998c2b7609e34719b10"},
+ {file = "frozenlist-1.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:73f2e31ea8dd7df61a359b731716018c2be196e5bb3b74ddba107f694fbd7604"},
+ {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:828afae9f17e6de596825cf4228ff28fbdf6065974e5ac1410cecc22f699d2b3"},
+ {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1577515d35ed5649d52ab4319db757bb881ce3b2b796d7283e6634d99ace307"},
+ {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2150cc6305a2c2ab33299453e2968611dacb970d2283a14955923062c8d00b10"},
+ {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a72b7a6e3cd2725eff67cd64c8f13335ee18fc3c7befc05aed043d24c7b9ccb9"},
+ {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c16d2fa63e0800723139137d667e1056bee1a1cf7965153d2d104b62855e9b99"},
+ {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:17dcc32fc7bda7ce5875435003220a457bcfa34ab7924a49a1c19f55b6ee185c"},
+ {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:97160e245ea33d8609cd2b8fd997c850b56db147a304a262abc2b3be021a9171"},
+ {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:f1e6540b7fa044eee0bb5111ada694cf3dc15f2b0347ca125ee9ca984d5e9e6e"},
+ {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:91d6c171862df0a6c61479d9724f22efb6109111017c87567cfeb7b5d1449fdf"},
+ {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c1fac3e2ace2eb1052e9f7c7db480818371134410e1f5c55d65e8f3ac6d1407e"},
+ {file = "frozenlist-1.5.0-cp38-cp38-win32.whl", hash = "sha256:b97f7b575ab4a8af9b7bc1d2ef7f29d3afee2226bd03ca3875c16451ad5a7723"},
+ {file = "frozenlist-1.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:374ca2dabdccad8e2a76d40b1d037f5bd16824933bf7bcea3e59c891fd4a0923"},
+ {file = "frozenlist-1.5.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9bbcdfaf4af7ce002694a4e10a0159d5a8d20056a12b05b45cea944a4953f972"},
+ {file = "frozenlist-1.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1893f948bf6681733aaccf36c5232c231e3b5166d607c5fa77773611df6dc336"},
+ {file = "frozenlist-1.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2b5e23253bb709ef57a8e95e6ae48daa9ac5f265637529e4ce6b003a37b2621f"},
+ {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f253985bb515ecd89629db13cb58d702035ecd8cfbca7d7a7e29a0e6d39af5f"},
+ {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:04a5c6babd5e8fb7d3c871dc8b321166b80e41b637c31a995ed844a6139942b6"},
+ {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a9fe0f1c29ba24ba6ff6abf688cb0b7cf1efab6b6aa6adc55441773c252f7411"},
+ {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:226d72559fa19babe2ccd920273e767c96a49b9d3d38badd7c91a0fdeda8ea08"},
+ {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15b731db116ab3aedec558573c1a5eec78822b32292fe4f2f0345b7f697745c2"},
+ {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:366d8f93e3edfe5a918c874702f78faac300209a4d5bf38352b2c1bdc07a766d"},
+ {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:1b96af8c582b94d381a1c1f51ffaedeb77c821c690ea5f01da3d70a487dd0a9b"},
+ {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:c03eff4a41bd4e38415cbed054bbaff4a075b093e2394b6915dca34a40d1e38b"},
+ {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:50cf5e7ee9b98f22bdecbabf3800ae78ddcc26e4a435515fc72d97903e8488e0"},
+ {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1e76bfbc72353269c44e0bc2cfe171900fbf7f722ad74c9a7b638052afe6a00c"},
+ {file = "frozenlist-1.5.0-cp39-cp39-win32.whl", hash = "sha256:666534d15ba8f0fda3f53969117383d5dc021266b3c1a42c9ec4855e4b58b9d3"},
+ {file = "frozenlist-1.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:5c28f4b5dbef8a0d8aad0d4de24d1e9e981728628afaf4ea0792f5d0939372f0"},
+ {file = "frozenlist-1.5.0-py3-none-any.whl", hash = "sha256:d994863bba198a4a518b467bb971c56e1db3f180a25c6cf7bb1949c267f748c3"},
+ {file = "frozenlist-1.5.0.tar.gz", hash = "sha256:81d5af29e61b9c8348e876d442253723928dce6433e0e76cd925cd83f1b4b817"},
]
[[package]]
name = "fsspec"
-version = "2023.1.0"
+version = "2025.3.2"
description = "File-system specification"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "fsspec-2023.1.0-py3-none-any.whl", hash = "sha256:b833e2e541e9e8cde0ab549414187871243177feb3d344f9d27b25a93f5d8139"},
- {file = "fsspec-2023.1.0.tar.gz", hash = "sha256:fbae7f20ff801eb5f7d0bedf81f25c787c0dfac5e982d98fa3884a9cde2b5411"},
+ {file = "fsspec-2025.3.2-py3-none-any.whl", hash = "sha256:2daf8dc3d1dfa65b6aa37748d112773a7a08416f6c70d96b264c96476ecaf711"},
+ {file = "fsspec-2025.3.2.tar.gz", hash = "sha256:e52c77ef398680bbd6a98c0e628fbc469491282981209907bbc8aea76a04fdc6"},
]
[package.extras]
@@ -489,8 +534,10 @@ abfs = ["adlfs"]
adl = ["adlfs"]
arrow = ["pyarrow (>=1)"]
dask = ["dask", "distributed"]
+dev = ["pre-commit", "ruff"]
+doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"]
dropbox = ["dropbox", "dropboxdrivefs", "requests"]
-entrypoints = ["importlib-metadata"]
+full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"]
fuse = ["fusepy"]
gcs = ["gcsfs"]
git = ["pygit2"]
@@ -498,30 +545,33 @@ github = ["requests"]
gs = ["gcsfs"]
gui = ["panel"]
hdfs = ["pyarrow (>=1)"]
-http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"]
+http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"]
libarchive = ["libarchive-c"]
oci = ["ocifs"]
s3 = ["s3fs"]
sftp = ["paramiko"]
smb = ["smbprotocol"]
ssh = ["paramiko"]
+test = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "numpy", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "requests"]
+test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"]
+test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"]
tqdm = ["tqdm"]
[[package]]
name = "huggingface-hub"
-version = "0.16.4"
+version = "0.30.2"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false
-python-versions = ">=3.7.0"
+python-versions = ">=3.8.0"
+groups = ["main"]
files = [
- {file = "huggingface_hub-0.16.4-py3-none-any.whl", hash = "sha256:0d3df29932f334fead024afc7cb4cc5149d955238b8b5e42dcf9740d6995a349"},
- {file = "huggingface_hub-0.16.4.tar.gz", hash = "sha256:608c7d4f3d368b326d1747f91523dbd1f692871e8e2e7a4750314a2dd8b63e14"},
+ {file = "huggingface_hub-0.30.2-py3-none-any.whl", hash = "sha256:68ff05969927058cfa41df4f2155d4bb48f5f54f719dd0390103eefa9b191e28"},
+ {file = "huggingface_hub-0.30.2.tar.gz", hash = "sha256:9a7897c5b6fd9dad3168a794a8998d6378210f5b9688d0dfc180b1a228dc2466"},
]
[package.dependencies]
filelock = "*"
-fsspec = "*"
-importlib-metadata = {version = "*", markers = "python_version < \"3.8\""}
+fsspec = ">=2023.5.0"
packaging = ">=20.9"
pyyaml = ">=5.1"
requests = "*"
@@ -529,314 +579,429 @@ tqdm = ">=4.42.1"
typing-extensions = ">=3.7.4.3"
[package.extras]
-all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"]
+all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
cli = ["InquirerPy (==0.3.4)"]
-dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"]
+dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
-inference = ["aiohttp", "pydantic"]
-quality = ["black (>=23.1,<24.0)", "mypy (==0.982)", "ruff (>=0.0.241)"]
+hf-transfer = ["hf-transfer (>=0.1.4)"]
+hf-xet = ["hf-xet (>=0.1.4)"]
+inference = ["aiohttp"]
+quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.9.0)"]
tensorflow = ["graphviz", "pydot", "tensorflow"]
-testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
-torch = ["torch"]
-typing = ["pydantic", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
+tensorflow-testing = ["keras (<3.0)", "tensorflow"]
+testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
+torch = ["safetensors[torch]", "torch"]
+typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
[[package]]
name = "idna"
-version = "3.4"
+version = "3.10"
description = "Internationalized Domain Names in Applications (IDNA)"
optional = false
-python-versions = ">=3.5"
-files = [
- {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"},
- {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"},
-]
-
-[[package]]
-name = "importlib-metadata"
-version = "6.7.0"
-description = "Read metadata from Python packages"
-optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.6"
+groups = ["main"]
files = [
- {file = "importlib_metadata-6.7.0-py3-none-any.whl", hash = "sha256:cb52082e659e97afc5dac71e79de97d8681de3aa07ff18578330904a9d18e5b5"},
- {file = "importlib_metadata-6.7.0.tar.gz", hash = "sha256:1aaf550d4f73e5d6783e7acb77aec43d49da8017410afae93822cc9cca98c4d4"},
+ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"},
+ {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"},
]
-[package.dependencies]
-typing-extensions = {version = ">=3.6.4", markers = "python_version < \"3.8\""}
-zipp = ">=0.5"
-
[package.extras]
-docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
-perf = ["ipython"]
-testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"]
+all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"]
[[package]]
name = "iniconfig"
-version = "2.0.0"
+version = "2.1.0"
description = "brain-dead simple config-ini parsing"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["dev"]
files = [
- {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
- {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
+ {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"},
+ {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"},
]
[[package]]
name = "multidict"
-version = "6.0.4"
+version = "6.4.3"
description = "multidict implementation"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b1a97283e0c85772d613878028fec909f003993e1007eafa715b24b377cb9b8"},
- {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eeb6dcc05e911516ae3d1f207d4b0520d07f54484c49dfc294d6e7d63b734171"},
- {file = "multidict-6.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d6d635d5209b82a3492508cf5b365f3446afb65ae7ebd755e70e18f287b0adf7"},
- {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c048099e4c9e9d615545e2001d3d8a4380bd403e1a0578734e0d31703d1b0c0b"},
- {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea20853c6dbbb53ed34cb4d080382169b6f4554d394015f1bef35e881bf83547"},
- {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16d232d4e5396c2efbbf4f6d4df89bfa905eb0d4dc5b3549d872ab898451f569"},
- {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36c63aaa167f6c6b04ef2c85704e93af16c11d20de1d133e39de6a0e84582a93"},
- {file = "multidict-6.0.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:64bdf1086b6043bf519869678f5f2757f473dee970d7abf6da91ec00acb9cb98"},
- {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:43644e38f42e3af682690876cff722d301ac585c5b9e1eacc013b7a3f7b696a0"},
- {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7582a1d1030e15422262de9f58711774e02fa80df0d1578995c76214f6954988"},
- {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ddff9c4e225a63a5afab9dd15590432c22e8057e1a9a13d28ed128ecf047bbdc"},
- {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ee2a1ece51b9b9e7752e742cfb661d2a29e7bcdba2d27e66e28a99f1890e4fa0"},
- {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a2e4369eb3d47d2034032a26c7a80fcb21a2cb22e1173d761a162f11e562caa5"},
- {file = "multidict-6.0.4-cp310-cp310-win32.whl", hash = "sha256:574b7eae1ab267e5f8285f0fe881f17efe4b98c39a40858247720935b893bba8"},
- {file = "multidict-6.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:4dcbb0906e38440fa3e325df2359ac6cb043df8e58c965bb45f4e406ecb162cc"},
- {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0dfad7a5a1e39c53ed00d2dd0c2e36aed4650936dc18fd9a1826a5ae1cad6f03"},
- {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:64da238a09d6039e3bd39bb3aee9c21a5e34f28bfa5aa22518581f910ff94af3"},
- {file = "multidict-6.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ff959bee35038c4624250473988b24f846cbeb2c6639de3602c073f10410ceba"},
- {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01a3a55bd90018c9c080fbb0b9f4891db37d148a0a18722b42f94694f8b6d4c9"},
- {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5cb09abb18c1ea940fb99360ea0396f34d46566f157122c92dfa069d3e0e982"},
- {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:666daae833559deb2d609afa4490b85830ab0dfca811a98b70a205621a6109fe"},
- {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11bdf3f5e1518b24530b8241529d2050014c884cf18b6fc69c0c2b30ca248710"},
- {file = "multidict-6.0.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d18748f2d30f94f498e852c67d61261c643b349b9d2a581131725595c45ec6c"},
- {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:458f37be2d9e4c95e2d8866a851663cbc76e865b78395090786f6cd9b3bbf4f4"},
- {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b1a2eeedcead3a41694130495593a559a668f382eee0727352b9a41e1c45759a"},
- {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7d6ae9d593ef8641544d6263c7fa6408cc90370c8cb2bbb65f8d43e5b0351d9c"},
- {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5979b5632c3e3534e42ca6ff856bb24b2e3071b37861c2c727ce220d80eee9ed"},
- {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dcfe792765fab89c365123c81046ad4103fcabbc4f56d1c1997e6715e8015461"},
- {file = "multidict-6.0.4-cp311-cp311-win32.whl", hash = "sha256:3601a3cece3819534b11d4efc1eb76047488fddd0c85a3948099d5da4d504636"},
- {file = "multidict-6.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:81a4f0b34bd92df3da93315c6a59034df95866014ac08535fc819f043bfd51f0"},
- {file = "multidict-6.0.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:67040058f37a2a51ed8ea8f6b0e6ee5bd78ca67f169ce6122f3e2ec80dfe9b78"},
- {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:853888594621e6604c978ce2a0444a1e6e70c8d253ab65ba11657659dcc9100f"},
- {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:39ff62e7d0f26c248b15e364517a72932a611a9b75f35b45be078d81bdb86603"},
- {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af048912e045a2dc732847d33821a9d84ba553f5c5f028adbd364dd4765092ac"},
- {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e8b901e607795ec06c9e42530788c45ac21ef3aaa11dbd0c69de543bfb79a9"},
- {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62501642008a8b9871ddfccbf83e4222cf8ac0d5aeedf73da36153ef2ec222d2"},
- {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:99b76c052e9f1bc0721f7541e5e8c05db3941eb9ebe7b8553c625ef88d6eefde"},
- {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:509eac6cf09c794aa27bcacfd4d62c885cce62bef7b2c3e8b2e49d365b5003fe"},
- {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:21a12c4eb6ddc9952c415f24eef97e3e55ba3af61f67c7bc388dcdec1404a067"},
- {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:5cad9430ab3e2e4fa4a2ef4450f548768400a2ac635841bc2a56a2052cdbeb87"},
- {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ab55edc2e84460694295f401215f4a58597f8f7c9466faec545093045476327d"},
- {file = "multidict-6.0.4-cp37-cp37m-win32.whl", hash = "sha256:5a4dcf02b908c3b8b17a45fb0f15b695bf117a67b76b7ad18b73cf8e92608775"},
- {file = "multidict-6.0.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6ed5f161328b7df384d71b07317f4d8656434e34591f20552c7bcef27b0ab88e"},
- {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5fc1b16f586f049820c5c5b17bb4ee7583092fa0d1c4e28b5239181ff9532e0c"},
- {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1502e24330eb681bdaa3eb70d6358e818e8e8f908a22a1851dfd4e15bc2f8161"},
- {file = "multidict-6.0.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b692f419760c0e65d060959df05f2a531945af31fda0c8a3b3195d4efd06de11"},
- {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45e1ecb0379bfaab5eef059f50115b54571acfbe422a14f668fc8c27ba410e7e"},
- {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddd3915998d93fbcd2566ddf9cf62cdb35c9e093075f862935573d265cf8f65d"},
- {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59d43b61c59d82f2effb39a93c48b845efe23a3852d201ed2d24ba830d0b4cf2"},
- {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc8e1d0c705233c5dd0c5e6460fbad7827d5d36f310a0fadfd45cc3029762258"},
- {file = "multidict-6.0.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6aa0418fcc838522256761b3415822626f866758ee0bc6632c9486b179d0b52"},
- {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6748717bb10339c4760c1e63da040f5f29f5ed6e59d76daee30305894069a660"},
- {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4d1a3d7ef5e96b1c9e92f973e43aa5e5b96c659c9bc3124acbbd81b0b9c8a951"},
- {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4372381634485bec7e46718edc71528024fcdc6f835baefe517b34a33c731d60"},
- {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:fc35cb4676846ef752816d5be2193a1e8367b4c1397b74a565a9d0389c433a1d"},
- {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4b9d9e4e2b37daddb5c23ea33a3417901fa7c7b3dee2d855f63ee67a0b21e5b1"},
- {file = "multidict-6.0.4-cp38-cp38-win32.whl", hash = "sha256:e41b7e2b59679edfa309e8db64fdf22399eec4b0b24694e1b2104fb789207779"},
- {file = "multidict-6.0.4-cp38-cp38-win_amd64.whl", hash = "sha256:d6c254ba6e45d8e72739281ebc46ea5eb5f101234f3ce171f0e9f5cc86991480"},
- {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:16ab77bbeb596e14212e7bab8429f24c1579234a3a462105cda4a66904998664"},
- {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc779e9e6f7fda81b3f9aa58e3a6091d49ad528b11ed19f6621408806204ad35"},
- {file = "multidict-6.0.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ceef517eca3e03c1cceb22030a3e39cb399ac86bff4e426d4fc6ae49052cc60"},
- {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:281af09f488903fde97923c7744bb001a9b23b039a909460d0f14edc7bf59706"},
- {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52f2dffc8acaba9a2f27174c41c9e57f60b907bb9f096b36b1a1f3be71c6284d"},
- {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b41156839806aecb3641f3208c0dafd3ac7775b9c4c422d82ee2a45c34ba81ca"},
- {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5e3fc56f88cc98ef8139255cf8cd63eb2c586531e43310ff859d6bb3a6b51f1"},
- {file = "multidict-6.0.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8316a77808c501004802f9beebde51c9f857054a0c871bd6da8280e718444449"},
- {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f70b98cd94886b49d91170ef23ec5c0e8ebb6f242d734ed7ed677b24d50c82cf"},
- {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bf6774e60d67a9efe02b3616fee22441d86fab4c6d335f9d2051d19d90a40063"},
- {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:e69924bfcdda39b722ef4d9aa762b2dd38e4632b3641b1d9a57ca9cd18f2f83a"},
- {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:6b181d8c23da913d4ff585afd1155a0e1194c0b50c54fcfe286f70cdaf2b7176"},
- {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52509b5be062d9eafc8170e53026fbc54cf3b32759a23d07fd935fb04fc22d95"},
- {file = "multidict-6.0.4-cp39-cp39-win32.whl", hash = "sha256:27c523fbfbdfd19c6867af7346332b62b586eed663887392cff78d614f9ec313"},
- {file = "multidict-6.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:33029f5734336aa0d4c0384525da0387ef89148dc7191aae00ca5fb23d7aafc2"},
- {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"},
+ {file = "multidict-6.4.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:32a998bd8a64ca48616eac5a8c1cc4fa38fb244a3facf2eeb14abe186e0f6cc5"},
+ {file = "multidict-6.4.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a54ec568f1fc7f3c313c2f3b16e5db346bf3660e1309746e7fccbbfded856188"},
+ {file = "multidict-6.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a7be07e5df178430621c716a63151165684d3e9958f2bbfcb644246162007ab7"},
+ {file = "multidict-6.4.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b128dbf1c939674a50dd0b28f12c244d90e5015e751a4f339a96c54f7275e291"},
+ {file = "multidict-6.4.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:b9cb19dfd83d35b6ff24a4022376ea6e45a2beba8ef3f0836b8a4b288b6ad685"},
+ {file = "multidict-6.4.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3cf62f8e447ea2c1395afa289b332e49e13d07435369b6f4e41f887db65b40bf"},
+ {file = "multidict-6.4.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:909f7d43ff8f13d1adccb6a397094adc369d4da794407f8dd592c51cf0eae4b1"},
+ {file = "multidict-6.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0bb8f8302fbc7122033df959e25777b0b7659b1fd6bcb9cb6bed76b5de67afef"},
+ {file = "multidict-6.4.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:224b79471b4f21169ea25ebc37ed6f058040c578e50ade532e2066562597b8a9"},
+ {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a7bd27f7ab3204f16967a6f899b3e8e9eb3362c0ab91f2ee659e0345445e0078"},
+ {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:99592bd3162e9c664671fd14e578a33bfdba487ea64bcb41d281286d3c870ad7"},
+ {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:a62d78a1c9072949018cdb05d3c533924ef8ac9bcb06cbf96f6d14772c5cd451"},
+ {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:3ccdde001578347e877ca4f629450973c510e88e8865d5aefbcb89b852ccc666"},
+ {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:eccb67b0e78aa2e38a04c5ecc13bab325a43e5159a181a9d1a6723db913cbb3c"},
+ {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8b6fcf6054fc4114a27aa865f8840ef3d675f9316e81868e0ad5866184a6cba5"},
+ {file = "multidict-6.4.3-cp310-cp310-win32.whl", hash = "sha256:f92c7f62d59373cd93bc9969d2da9b4b21f78283b1379ba012f7ee8127b3152e"},
+ {file = "multidict-6.4.3-cp310-cp310-win_amd64.whl", hash = "sha256:b57e28dbc031d13916b946719f213c494a517b442d7b48b29443e79610acd887"},
+ {file = "multidict-6.4.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f6f19170197cc29baccd33ccc5b5d6a331058796485857cf34f7635aa25fb0cd"},
+ {file = "multidict-6.4.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f2882bf27037eb687e49591690e5d491e677272964f9ec7bc2abbe09108bdfb8"},
+ {file = "multidict-6.4.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fbf226ac85f7d6b6b9ba77db4ec0704fde88463dc17717aec78ec3c8546c70ad"},
+ {file = "multidict-6.4.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e329114f82ad4b9dd291bef614ea8971ec119ecd0f54795109976de75c9a852"},
+ {file = "multidict-6.4.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:1f4e0334d7a555c63f5c8952c57ab6f1c7b4f8c7f3442df689fc9f03df315c08"},
+ {file = "multidict-6.4.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:740915eb776617b57142ce0bb13b7596933496e2f798d3d15a20614adf30d229"},
+ {file = "multidict-6.4.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:255dac25134d2b141c944b59a0d2f7211ca12a6d4779f7586a98b4b03ea80508"},
+ {file = "multidict-6.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4e8535bd4d741039b5aad4285ecd9b902ef9e224711f0b6afda6e38d7ac02c7"},
+ {file = "multidict-6.4.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30c433a33be000dd968f5750722eaa0991037be0be4a9d453eba121774985bc8"},
+ {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4eb33b0bdc50acd538f45041f5f19945a1f32b909b76d7b117c0c25d8063df56"},
+ {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:75482f43465edefd8a5d72724887ccdcd0c83778ded8f0cb1e0594bf71736cc0"},
+ {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ce5b3082e86aee80b3925ab4928198450d8e5b6466e11501fe03ad2191c6d777"},
+ {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e413152e3212c4d39f82cf83c6f91be44bec9ddea950ce17af87fbf4e32ca6b2"},
+ {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:8aac2eeff69b71f229a405c0a4b61b54bade8e10163bc7b44fcd257949620618"},
+ {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ab583ac203af1d09034be41458feeab7863c0635c650a16f15771e1386abf2d7"},
+ {file = "multidict-6.4.3-cp311-cp311-win32.whl", hash = "sha256:1b2019317726f41e81154df636a897de1bfe9228c3724a433894e44cd2512378"},
+ {file = "multidict-6.4.3-cp311-cp311-win_amd64.whl", hash = "sha256:43173924fa93c7486402217fab99b60baf78d33806af299c56133a3755f69589"},
+ {file = "multidict-6.4.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1f1c2f58f08b36f8475f3ec6f5aeb95270921d418bf18f90dffd6be5c7b0e676"},
+ {file = "multidict-6.4.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:26ae9ad364fc61b936fb7bf4c9d8bd53f3a5b4417142cd0be5c509d6f767e2f1"},
+ {file = "multidict-6.4.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:659318c6c8a85f6ecfc06b4e57529e5a78dfdd697260cc81f683492ad7e9435a"},
+ {file = "multidict-6.4.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e1eb72c741fd24d5a28242ce72bb61bc91f8451877131fa3fe930edb195f7054"},
+ {file = "multidict-6.4.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3cd06d88cb7398252284ee75c8db8e680aa0d321451132d0dba12bc995f0adcc"},
+ {file = "multidict-6.4.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4543d8dc6470a82fde92b035a92529317191ce993533c3c0c68f56811164ed07"},
+ {file = "multidict-6.4.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:30a3ebdc068c27e9d6081fca0e2c33fdf132ecea703a72ea216b81a66860adde"},
+ {file = "multidict-6.4.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b038f10e23f277153f86f95c777ba1958bcd5993194fda26a1d06fae98b2f00c"},
+ {file = "multidict-6.4.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c605a2b2dc14282b580454b9b5d14ebe0668381a3a26d0ac39daa0ca115eb2ae"},
+ {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8bd2b875f4ca2bb527fe23e318ddd509b7df163407b0fb717df229041c6df5d3"},
+ {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:c2e98c840c9c8e65c0e04b40c6c5066c8632678cd50c8721fdbcd2e09f21a507"},
+ {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:66eb80dd0ab36dbd559635e62fba3083a48a252633164857a1d1684f14326427"},
+ {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c23831bdee0a2a3cf21be057b5e5326292f60472fb6c6f86392bbf0de70ba731"},
+ {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:1535cec6443bfd80d028052e9d17ba6ff8a5a3534c51d285ba56c18af97e9713"},
+ {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3b73e7227681f85d19dec46e5b881827cd354aabe46049e1a61d2f9aaa4e285a"},
+ {file = "multidict-6.4.3-cp312-cp312-win32.whl", hash = "sha256:8eac0c49df91b88bf91f818e0a24c1c46f3622978e2c27035bfdca98e0e18124"},
+ {file = "multidict-6.4.3-cp312-cp312-win_amd64.whl", hash = "sha256:11990b5c757d956cd1db7cb140be50a63216af32cd6506329c2c59d732d802db"},
+ {file = "multidict-6.4.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7a76534263d03ae0cfa721fea40fd2b5b9d17a6f85e98025931d41dc49504474"},
+ {file = "multidict-6.4.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:805031c2f599eee62ac579843555ed1ce389ae00c7e9f74c2a1b45e0564a88dd"},
+ {file = "multidict-6.4.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c56c179839d5dcf51d565132185409d1d5dd8e614ba501eb79023a6cab25576b"},
+ {file = "multidict-6.4.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c64f4ddb3886dd8ab71b68a7431ad4aa01a8fa5be5b11543b29674f29ca0ba3"},
+ {file = "multidict-6.4.3-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3002a856367c0b41cad6784f5b8d3ab008eda194ed7864aaa58f65312e2abcac"},
+ {file = "multidict-6.4.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3d75e621e7d887d539d6e1d789f0c64271c250276c333480a9e1de089611f790"},
+ {file = "multidict-6.4.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:995015cf4a3c0d72cbf453b10a999b92c5629eaf3a0c3e1efb4b5c1f602253bb"},
+ {file = "multidict-6.4.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2b0fabae7939d09d7d16a711468c385272fa1b9b7fb0d37e51143585d8e72e0"},
+ {file = "multidict-6.4.3-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:61ed4d82f8a1e67eb9eb04f8587970d78fe7cddb4e4d6230b77eda23d27938f9"},
+ {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:062428944a8dc69df9fdc5d5fc6279421e5f9c75a9ee3f586f274ba7b05ab3c8"},
+ {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:b90e27b4674e6c405ad6c64e515a505c6d113b832df52fdacb6b1ffd1fa9a1d1"},
+ {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7d50d4abf6729921e9613d98344b74241572b751c6b37feed75fb0c37bd5a817"},
+ {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:43fe10524fb0a0514be3954be53258e61d87341008ce4914f8e8b92bee6f875d"},
+ {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:236966ca6c472ea4e2d3f02f6673ebfd36ba3f23159c323f5a496869bc8e47c9"},
+ {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:422a5ec315018e606473ba1f5431e064cf8b2a7468019233dcf8082fabad64c8"},
+ {file = "multidict-6.4.3-cp313-cp313-win32.whl", hash = "sha256:f901a5aace8e8c25d78960dcc24c870c8d356660d3b49b93a78bf38eb682aac3"},
+ {file = "multidict-6.4.3-cp313-cp313-win_amd64.whl", hash = "sha256:1c152c49e42277bc9a2f7b78bd5fa10b13e88d1b0328221e7aef89d5c60a99a5"},
+ {file = "multidict-6.4.3-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:be8751869e28b9c0d368d94f5afcb4234db66fe8496144547b4b6d6a0645cfc6"},
+ {file = "multidict-6.4.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0d4b31f8a68dccbcd2c0ea04f0e014f1defc6b78f0eb8b35f2265e8716a6df0c"},
+ {file = "multidict-6.4.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:032efeab3049e37eef2ff91271884303becc9e54d740b492a93b7e7266e23756"},
+ {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9e78006af1a7c8a8007e4f56629d7252668344442f66982368ac06522445e375"},
+ {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:daeac9dd30cda8703c417e4fddccd7c4dc0c73421a0b54a7da2713be125846be"},
+ {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1f6f90700881438953eae443a9c6f8a509808bc3b185246992c4233ccee37fea"},
+ {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f84627997008390dd15762128dcf73c3365f4ec0106739cde6c20a07ed198ec8"},
+ {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3307b48cd156153b117c0ea54890a3bdbf858a5b296ddd40dc3852e5f16e9b02"},
+ {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ead46b0fa1dcf5af503a46e9f1c2e80b5d95c6011526352fa5f42ea201526124"},
+ {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:1748cb2743bedc339d63eb1bca314061568793acd603a6e37b09a326334c9f44"},
+ {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:acc9fa606f76fc111b4569348cc23a771cb52c61516dcc6bcef46d612edb483b"},
+ {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:31469d5832b5885adeb70982e531ce86f8c992334edd2f2254a10fa3182ac504"},
+ {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:ba46b51b6e51b4ef7bfb84b82f5db0dc5e300fb222a8a13b8cd4111898a869cf"},
+ {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:389cfefb599edf3fcfd5f64c0410da686f90f5f5e2c4d84e14f6797a5a337af4"},
+ {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:64bc2bbc5fba7b9db5c2c8d750824f41c6994e3882e6d73c903c2afa78d091e4"},
+ {file = "multidict-6.4.3-cp313-cp313t-win32.whl", hash = "sha256:0ecdc12ea44bab2807d6b4a7e5eef25109ab1c82a8240d86d3c1fc9f3b72efd5"},
+ {file = "multidict-6.4.3-cp313-cp313t-win_amd64.whl", hash = "sha256:7146a8742ea71b5d7d955bffcef58a9e6e04efba704b52a460134fefd10a8208"},
+ {file = "multidict-6.4.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5427a2679e95a642b7f8b0f761e660c845c8e6fe3141cddd6b62005bd133fc21"},
+ {file = "multidict-6.4.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:24a8caa26521b9ad09732972927d7b45b66453e6ebd91a3c6a46d811eeb7349b"},
+ {file = "multidict-6.4.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6b5a272bc7c36a2cd1b56ddc6bff02e9ce499f9f14ee4a45c45434ef083f2459"},
+ {file = "multidict-6.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edf74dc5e212b8c75165b435c43eb0d5e81b6b300a938a4eb82827119115e840"},
+ {file = "multidict-6.4.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:9f35de41aec4b323c71f54b0ca461ebf694fb48bec62f65221f52e0017955b39"},
+ {file = "multidict-6.4.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae93e0ff43b6f6892999af64097b18561691ffd835e21a8348a441e256592e1f"},
+ {file = "multidict-6.4.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5e3929269e9d7eff905d6971d8b8c85e7dbc72c18fb99c8eae6fe0a152f2e343"},
+ {file = "multidict-6.4.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb6214fe1750adc2a1b801a199d64b5a67671bf76ebf24c730b157846d0e90d2"},
+ {file = "multidict-6.4.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6d79cf5c0c6284e90f72123f4a3e4add52d6c6ebb4a9054e88df15b8d08444c6"},
+ {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2427370f4a255262928cd14533a70d9738dfacadb7563bc3b7f704cc2360fc4e"},
+ {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:fbd8d737867912b6c5f99f56782b8cb81f978a97b4437a1c476de90a3e41c9a1"},
+ {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:0ee1bf613c448997f73fc4efb4ecebebb1c02268028dd4f11f011f02300cf1e8"},
+ {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:578568c4ba5f2b8abd956baf8b23790dbfdc953e87d5b110bce343b4a54fc9e7"},
+ {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:a059ad6b80de5b84b9fa02a39400319e62edd39d210b4e4f8c4f1243bdac4752"},
+ {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:dd53893675b729a965088aaadd6a1f326a72b83742b056c1065bdd2e2a42b4df"},
+ {file = "multidict-6.4.3-cp39-cp39-win32.whl", hash = "sha256:abcfed2c4c139f25c2355e180bcc077a7cae91eefbb8b3927bb3f836c9586f1f"},
+ {file = "multidict-6.4.3-cp39-cp39-win_amd64.whl", hash = "sha256:b1b389ae17296dd739015d5ddb222ee99fd66adeae910de21ac950e00979d897"},
+ {file = "multidict-6.4.3-py3-none-any.whl", hash = "sha256:59fe01ee8e2a1e8ceb3f6dbb216b09c8d9f4ef1c22c4fc825d045a147fa2ebc9"},
+ {file = "multidict-6.4.3.tar.gz", hash = "sha256:3ada0b058c9f213c5f95ba301f922d402ac234f1111a7d8fd70f1b99f3c281ec"},
]
+[package.dependencies]
+typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""}
+
[[package]]
name = "packaging"
-version = "23.1"
+version = "24.2"
description = "Core utilities for Python packages"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["main", "dev"]
files = [
- {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"},
- {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"},
+ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"},
+ {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"},
]
[[package]]
name = "pluggy"
-version = "1.2.0"
+version = "1.5.0"
description = "plugin and hook calling mechanisms for python"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["dev"]
files = [
- {file = "pluggy-1.2.0-py3-none-any.whl", hash = "sha256:c2fd55a7d7a3863cba1a013e4e2414658b1d07b6bc57b3919e0c63c9abb99849"},
- {file = "pluggy-1.2.0.tar.gz", hash = "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"},
+ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
+ {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
]
-[package.dependencies]
-importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""}
-
[package.extras]
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]
[[package]]
-name = "py"
-version = "1.11.0"
-description = "library with cross-python path, ini-parsing, io, code, log facilities"
+name = "propcache"
+version = "0.3.1"
+description = "Accelerated property cache"
optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"},
- {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
+ {file = "propcache-0.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f27785888d2fdd918bc36de8b8739f2d6c791399552333721b58193f68ea3e98"},
+ {file = "propcache-0.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4e89cde74154c7b5957f87a355bb9c8ec929c167b59c83d90654ea36aeb6180"},
+ {file = "propcache-0.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:730178f476ef03d3d4d255f0c9fa186cb1d13fd33ffe89d39f2cda4da90ceb71"},
+ {file = "propcache-0.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:967a8eec513dbe08330f10137eacb427b2ca52118769e82ebcfcab0fba92a649"},
+ {file = "propcache-0.3.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b9145c35cc87313b5fd480144f8078716007656093d23059e8993d3a8fa730f"},
+ {file = "propcache-0.3.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e64e948ab41411958670f1093c0a57acfdc3bee5cf5b935671bbd5313bcf229"},
+ {file = "propcache-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:319fa8765bfd6a265e5fa661547556da381e53274bc05094fc9ea50da51bfd46"},
+ {file = "propcache-0.3.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c66d8ccbc902ad548312b96ed8d5d266d0d2c6d006fd0f66323e9d8f2dd49be7"},
+ {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2d219b0dbabe75e15e581fc1ae796109b07c8ba7d25b9ae8d650da582bed01b0"},
+ {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:cd6a55f65241c551eb53f8cf4d2f4af33512c39da5d9777694e9d9c60872f519"},
+ {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:9979643ffc69b799d50d3a7b72b5164a2e97e117009d7af6dfdd2ab906cb72cd"},
+ {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:4cf9e93a81979f1424f1a3d155213dc928f1069d697e4353edb8a5eba67c6259"},
+ {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:2fce1df66915909ff6c824bbb5eb403d2d15f98f1518e583074671a30fe0c21e"},
+ {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:4d0dfdd9a2ebc77b869a0b04423591ea8823f791293b527dc1bb896c1d6f1136"},
+ {file = "propcache-0.3.1-cp310-cp310-win32.whl", hash = "sha256:1f6cc0ad7b4560e5637eb2c994e97b4fa41ba8226069c9277eb5ea7101845b42"},
+ {file = "propcache-0.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:47ef24aa6511e388e9894ec16f0fbf3313a53ee68402bc428744a367ec55b833"},
+ {file = "propcache-0.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:7f30241577d2fef2602113b70ef7231bf4c69a97e04693bde08ddab913ba0ce5"},
+ {file = "propcache-0.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:43593c6772aa12abc3af7784bff4a41ffa921608dd38b77cf1dfd7f5c4e71371"},
+ {file = "propcache-0.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a75801768bbe65499495660b777e018cbe90c7980f07f8aa57d6be79ea6f71da"},
+ {file = "propcache-0.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6f1324db48f001c2ca26a25fa25af60711e09b9aaf4b28488602776f4f9a744"},
+ {file = "propcache-0.3.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5cdb0f3e1eb6dfc9965d19734d8f9c481b294b5274337a8cb5cb01b462dcb7e0"},
+ {file = "propcache-0.3.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1eb34d90aac9bfbced9a58b266f8946cb5935869ff01b164573a7634d39fbcb5"},
+ {file = "propcache-0.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f35c7070eeec2cdaac6fd3fe245226ed2a6292d3ee8c938e5bb645b434c5f256"},
+ {file = "propcache-0.3.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b23c11c2c9e6d4e7300c92e022046ad09b91fd00e36e83c44483df4afa990073"},
+ {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3e19ea4ea0bf46179f8a3652ac1426e6dcbaf577ce4b4f65be581e237340420d"},
+ {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:bd39c92e4c8f6cbf5f08257d6360123af72af9f4da75a690bef50da77362d25f"},
+ {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b0313e8b923b3814d1c4a524c93dfecea5f39fa95601f6a9b1ac96cd66f89ea0"},
+ {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e861ad82892408487be144906a368ddbe2dc6297074ade2d892341b35c59844a"},
+ {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:61014615c1274df8da5991a1e5da85a3ccb00c2d4701ac6f3383afd3ca47ab0a"},
+ {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:71ebe3fe42656a2328ab08933d420df5f3ab121772eef78f2dc63624157f0ed9"},
+ {file = "propcache-0.3.1-cp311-cp311-win32.whl", hash = "sha256:58aa11f4ca8b60113d4b8e32d37e7e78bd8af4d1a5b5cb4979ed856a45e62005"},
+ {file = "propcache-0.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:9532ea0b26a401264b1365146c440a6d78269ed41f83f23818d4b79497aeabe7"},
+ {file = "propcache-0.3.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:f78eb8422acc93d7b69964012ad7048764bb45a54ba7a39bb9e146c72ea29723"},
+ {file = "propcache-0.3.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:89498dd49c2f9a026ee057965cdf8192e5ae070ce7d7a7bd4b66a8e257d0c976"},
+ {file = "propcache-0.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:09400e98545c998d57d10035ff623266927cb784d13dd2b31fd33b8a5316b85b"},
+ {file = "propcache-0.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa8efd8c5adc5a2c9d3b952815ff8f7710cefdcaf5f2c36d26aff51aeca2f12f"},
+ {file = "propcache-0.3.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2fe5c910f6007e716a06d269608d307b4f36e7babee5f36533722660e8c4a70"},
+ {file = "propcache-0.3.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a0ab8cf8cdd2194f8ff979a43ab43049b1df0b37aa64ab7eca04ac14429baeb7"},
+ {file = "propcache-0.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:563f9d8c03ad645597b8d010ef4e9eab359faeb11a0a2ac9f7b4bc8c28ebef25"},
+ {file = "propcache-0.3.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb6e0faf8cb6b4beea5d6ed7b5a578254c6d7df54c36ccd3d8b3eb00d6770277"},
+ {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1c5c7ab7f2bb3f573d1cb921993006ba2d39e8621019dffb1c5bc94cdbae81e8"},
+ {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:050b571b2e96ec942898f8eb46ea4bfbb19bd5502424747e83badc2d4a99a44e"},
+ {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e1c4d24b804b3a87e9350f79e2371a705a188d292fd310e663483af6ee6718ee"},
+ {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:e4fe2a6d5ce975c117a6bb1e8ccda772d1e7029c1cca1acd209f91d30fa72815"},
+ {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:feccd282de1f6322f56f6845bf1207a537227812f0a9bf5571df52bb418d79d5"},
+ {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ec314cde7314d2dd0510c6787326bbffcbdc317ecee6b7401ce218b3099075a7"},
+ {file = "propcache-0.3.1-cp312-cp312-win32.whl", hash = "sha256:7d2d5a0028d920738372630870e7d9644ce437142197f8c827194fca404bf03b"},
+ {file = "propcache-0.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:88c423efef9d7a59dae0614eaed718449c09a5ac79a5f224a8b9664d603f04a3"},
+ {file = "propcache-0.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:f1528ec4374617a7a753f90f20e2f551121bb558fcb35926f99e3c42367164b8"},
+ {file = "propcache-0.3.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:dc1915ec523b3b494933b5424980831b636fe483d7d543f7afb7b3bf00f0c10f"},
+ {file = "propcache-0.3.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a110205022d077da24e60b3df8bcee73971be9575dec5573dd17ae5d81751111"},
+ {file = "propcache-0.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d249609e547c04d190e820d0d4c8ca03ed4582bcf8e4e160a6969ddfb57b62e5"},
+ {file = "propcache-0.3.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5ced33d827625d0a589e831126ccb4f5c29dfdf6766cac441d23995a65825dcb"},
+ {file = "propcache-0.3.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4114c4ada8f3181af20808bedb250da6bae56660e4b8dfd9cd95d4549c0962f7"},
+ {file = "propcache-0.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:975af16f406ce48f1333ec5e912fe11064605d5c5b3f6746969077cc3adeb120"},
+ {file = "propcache-0.3.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a34aa3a1abc50740be6ac0ab9d594e274f59960d3ad253cd318af76b996dd654"},
+ {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9cec3239c85ed15bfaded997773fdad9fb5662b0a7cbc854a43f291eb183179e"},
+ {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:05543250deac8e61084234d5fc54f8ebd254e8f2b39a16b1dce48904f45b744b"},
+ {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:5cb5918253912e088edbf023788de539219718d3b10aef334476b62d2b53de53"},
+ {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f3bbecd2f34d0e6d3c543fdb3b15d6b60dd69970c2b4c822379e5ec8f6f621d5"},
+ {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:aca63103895c7d960a5b9b044a83f544b233c95e0dcff114389d64d762017af7"},
+ {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5a0a9898fdb99bf11786265468571e628ba60af80dc3f6eb89a3545540c6b0ef"},
+ {file = "propcache-0.3.1-cp313-cp313-win32.whl", hash = "sha256:3a02a28095b5e63128bcae98eb59025924f121f048a62393db682f049bf4ac24"},
+ {file = "propcache-0.3.1-cp313-cp313-win_amd64.whl", hash = "sha256:813fbb8b6aea2fc9659815e585e548fe706d6f663fa73dff59a1677d4595a037"},
+ {file = "propcache-0.3.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:a444192f20f5ce8a5e52761a031b90f5ea6288b1eef42ad4c7e64fef33540b8f"},
+ {file = "propcache-0.3.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0fbe94666e62ebe36cd652f5fc012abfbc2342de99b523f8267a678e4dfdee3c"},
+ {file = "propcache-0.3.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:f011f104db880f4e2166bcdcf7f58250f7a465bc6b068dc84c824a3d4a5c94dc"},
+ {file = "propcache-0.3.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e584b6d388aeb0001d6d5c2bd86b26304adde6d9bb9bfa9c4889805021b96de"},
+ {file = "propcache-0.3.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8a17583515a04358b034e241f952f1715243482fc2c2945fd99a1b03a0bd77d6"},
+ {file = "propcache-0.3.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5aed8d8308215089c0734a2af4f2e95eeb360660184ad3912686c181e500b2e7"},
+ {file = "propcache-0.3.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d8e309ff9a0503ef70dc9a0ebd3e69cf7b3894c9ae2ae81fc10943c37762458"},
+ {file = "propcache-0.3.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b655032b202028a582d27aeedc2e813299f82cb232f969f87a4fde491a233f11"},
+ {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9f64d91b751df77931336b5ff7bafbe8845c5770b06630e27acd5dbb71e1931c"},
+ {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:19a06db789a4bd896ee91ebc50d059e23b3639c25d58eb35be3ca1cbe967c3bf"},
+ {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:bef100c88d8692864651b5f98e871fb090bd65c8a41a1cb0ff2322db39c96c27"},
+ {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:87380fb1f3089d2a0b8b00f006ed12bd41bd858fabfa7330c954c70f50ed8757"},
+ {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:e474fc718e73ba5ec5180358aa07f6aded0ff5f2abe700e3115c37d75c947e18"},
+ {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:17d1c688a443355234f3c031349da69444be052613483f3e4158eef751abcd8a"},
+ {file = "propcache-0.3.1-cp313-cp313t-win32.whl", hash = "sha256:359e81a949a7619802eb601d66d37072b79b79c2505e6d3fd8b945538411400d"},
+ {file = "propcache-0.3.1-cp313-cp313t-win_amd64.whl", hash = "sha256:e7fb9a84c9abbf2b2683fa3e7b0d7da4d8ecf139a1c635732a8bda29c5214b0e"},
+ {file = "propcache-0.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ed5f6d2edbf349bd8d630e81f474d33d6ae5d07760c44d33cd808e2f5c8f4ae6"},
+ {file = "propcache-0.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:668ddddc9f3075af019f784456267eb504cb77c2c4bd46cc8402d723b4d200bf"},
+ {file = "propcache-0.3.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0c86e7ceea56376216eba345aa1fc6a8a6b27ac236181f840d1d7e6a1ea9ba5c"},
+ {file = "propcache-0.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83be47aa4e35b87c106fc0c84c0fc069d3f9b9b06d3c494cd404ec6747544894"},
+ {file = "propcache-0.3.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:27c6ac6aa9fc7bc662f594ef380707494cb42c22786a558d95fcdedb9aa5d035"},
+ {file = "propcache-0.3.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64a956dff37080b352c1c40b2966b09defb014347043e740d420ca1eb7c9b908"},
+ {file = "propcache-0.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82de5da8c8893056603ac2d6a89eb8b4df49abf1a7c19d536984c8dd63f481d5"},
+ {file = "propcache-0.3.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c3c3a203c375b08fd06a20da3cf7aac293b834b6f4f4db71190e8422750cca5"},
+ {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:b303b194c2e6f171cfddf8b8ba30baefccf03d36a4d9cab7fd0bb68ba476a3d7"},
+ {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:916cd229b0150129d645ec51614d38129ee74c03293a9f3f17537be0029a9641"},
+ {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:a461959ead5b38e2581998700b26346b78cd98540b5524796c175722f18b0294"},
+ {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:069e7212890b0bcf9b2be0a03afb0c2d5161d91e1bf51569a64f629acc7defbf"},
+ {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:ef2e4e91fb3945769e14ce82ed53007195e616a63aa43b40fb7ebaaf907c8d4c"},
+ {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:8638f99dca15b9dff328fb6273e09f03d1c50d9b6512f3b65a4154588a7595fe"},
+ {file = "propcache-0.3.1-cp39-cp39-win32.whl", hash = "sha256:6f173bbfe976105aaa890b712d1759de339d8a7cef2fc0a1714cc1a1e1c47f64"},
+ {file = "propcache-0.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:603f1fe4144420374f1a69b907494c3acbc867a581c2d49d4175b0de7cc64566"},
+ {file = "propcache-0.3.1-py3-none-any.whl", hash = "sha256:9a8ecf38de50a7f518c21568c80f985e776397b902f1ce0b01f799aba1608b40"},
+ {file = "propcache-0.3.1.tar.gz", hash = "sha256:40d980c33765359098837527e18eddefc9a24cea5b45e078a7f3bb5b032c6ecf"},
]
[[package]]
name = "pydantic"
-version = "2.5.3"
+version = "2.11.3"
description = "Data validation using Python type hints"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "pydantic-2.5.3-py3-none-any.whl", hash = "sha256:d0caf5954bee831b6bfe7e338c32b9e30c85dfe080c843680783ac2b631673b4"},
- {file = "pydantic-2.5.3.tar.gz", hash = "sha256:b3ef57c62535b0941697cce638c08900d87fcb67e29cfa99e8a68f747f393f7a"},
+ {file = "pydantic-2.11.3-py3-none-any.whl", hash = "sha256:a082753436a07f9ba1289c6ffa01cd93db3548776088aa917cc43b63f68fa60f"},
+ {file = "pydantic-2.11.3.tar.gz", hash = "sha256:7471657138c16adad9322fe3070c0116dd6c3ad8d649300e3cbdfe91f4db4ec3"},
]
[package.dependencies]
-annotated-types = ">=0.4.0"
-importlib-metadata = {version = "*", markers = "python_version == \"3.7\""}
-pydantic-core = "2.14.6"
-typing-extensions = ">=4.6.1"
+annotated-types = ">=0.6.0"
+pydantic-core = "2.33.1"
+typing-extensions = ">=4.12.2"
+typing-inspection = ">=0.4.0"
[package.extras]
email = ["email-validator (>=2.0.0)"]
+timezone = ["tzdata"]
[[package]]
name = "pydantic-core"
-version = "2.14.6"
-description = ""
+version = "2.33.1"
+description = "Core functionality for Pydantic validation and serialization"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "pydantic_core-2.14.6-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:72f9a942d739f09cd42fffe5dc759928217649f070056f03c70df14f5770acf9"},
- {file = "pydantic_core-2.14.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6a31d98c0d69776c2576dda4b77b8e0c69ad08e8b539c25c7d0ca0dc19a50d6c"},
- {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5aa90562bc079c6c290f0512b21768967f9968e4cfea84ea4ff5af5d917016e4"},
- {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:370ffecb5316ed23b667d99ce4debe53ea664b99cc37bfa2af47bc769056d534"},
- {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f85f3843bdb1fe80e8c206fe6eed7a1caeae897e496542cee499c374a85c6e08"},
- {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9862bf828112e19685b76ca499b379338fd4c5c269d897e218b2ae8fcb80139d"},
- {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:036137b5ad0cb0004c75b579445a1efccd072387a36c7f217bb8efd1afbe5245"},
- {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:92879bce89f91f4b2416eba4429c7b5ca22c45ef4a499c39f0c5c69257522c7c"},
- {file = "pydantic_core-2.14.6-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0c08de15d50fa190d577e8591f0329a643eeaed696d7771760295998aca6bc66"},
- {file = "pydantic_core-2.14.6-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:36099c69f6b14fc2c49d7996cbf4f87ec4f0e66d1c74aa05228583225a07b590"},
- {file = "pydantic_core-2.14.6-cp310-none-win32.whl", hash = "sha256:7be719e4d2ae6c314f72844ba9d69e38dff342bc360379f7c8537c48e23034b7"},
- {file = "pydantic_core-2.14.6-cp310-none-win_amd64.whl", hash = "sha256:36fa402dcdc8ea7f1b0ddcf0df4254cc6b2e08f8cd80e7010d4c4ae6e86b2a87"},
- {file = "pydantic_core-2.14.6-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:dea7fcd62915fb150cdc373212141a30037e11b761fbced340e9db3379b892d4"},
- {file = "pydantic_core-2.14.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ffff855100bc066ff2cd3aa4a60bc9534661816b110f0243e59503ec2df38421"},
- {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b027c86c66b8627eb90e57aee1f526df77dc6d8b354ec498be9a757d513b92b"},
- {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:00b1087dabcee0b0ffd104f9f53d7d3eaddfaa314cdd6726143af6bc713aa27e"},
- {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:75ec284328b60a4e91010c1acade0c30584f28a1f345bc8f72fe8b9e46ec6a96"},
- {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e1f4744eea1501404b20b0ac059ff7e3f96a97d3e3f48ce27a139e053bb370b"},
- {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2602177668f89b38b9f84b7b3435d0a72511ddef45dc14446811759b82235a1"},
- {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6c8edaea3089bf908dd27da8f5d9e395c5b4dc092dbcce9b65e7156099b4b937"},
- {file = "pydantic_core-2.14.6-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:478e9e7b360dfec451daafe286998d4a1eeaecf6d69c427b834ae771cad4b622"},
- {file = "pydantic_core-2.14.6-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b6ca36c12a5120bad343eef193cc0122928c5c7466121da7c20f41160ba00ba2"},
- {file = "pydantic_core-2.14.6-cp311-none-win32.whl", hash = "sha256:2b8719037e570639e6b665a4050add43134d80b687288ba3ade18b22bbb29dd2"},
- {file = "pydantic_core-2.14.6-cp311-none-win_amd64.whl", hash = "sha256:78ee52ecc088c61cce32b2d30a826f929e1708f7b9247dc3b921aec367dc1b23"},
- {file = "pydantic_core-2.14.6-cp311-none-win_arm64.whl", hash = "sha256:a19b794f8fe6569472ff77602437ec4430f9b2b9ec7a1105cfd2232f9ba355e6"},
- {file = "pydantic_core-2.14.6-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:667aa2eac9cd0700af1ddb38b7b1ef246d8cf94c85637cbb03d7757ca4c3fdec"},
- {file = "pydantic_core-2.14.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:cdee837710ef6b56ebd20245b83799fce40b265b3b406e51e8ccc5b85b9099b7"},
- {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c5bcf3414367e29f83fd66f7de64509a8fd2368b1edf4351e862910727d3e51"},
- {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:26a92ae76f75d1915806b77cf459811e772d8f71fd1e4339c99750f0e7f6324f"},
- {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a983cca5ed1dd9a35e9e42ebf9f278d344603bfcb174ff99a5815f953925140a"},
- {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cb92f9061657287eded380d7dc455bbf115430b3aa4741bdc662d02977e7d0af"},
- {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4ace1e220b078c8e48e82c081e35002038657e4b37d403ce940fa679e57113b"},
- {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ef633add81832f4b56d3b4c9408b43d530dfca29e68fb1b797dcb861a2c734cd"},
- {file = "pydantic_core-2.14.6-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7e90d6cc4aad2cc1f5e16ed56e46cebf4877c62403a311af20459c15da76fd91"},
- {file = "pydantic_core-2.14.6-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e8a5ac97ea521d7bde7621d86c30e86b798cdecd985723c4ed737a2aa9e77d0c"},
- {file = "pydantic_core-2.14.6-cp312-none-win32.whl", hash = "sha256:f27207e8ca3e5e021e2402ba942e5b4c629718e665c81b8b306f3c8b1ddbb786"},
- {file = "pydantic_core-2.14.6-cp312-none-win_amd64.whl", hash = "sha256:b3e5fe4538001bb82e2295b8d2a39356a84694c97cb73a566dc36328b9f83b40"},
- {file = "pydantic_core-2.14.6-cp312-none-win_arm64.whl", hash = "sha256:64634ccf9d671c6be242a664a33c4acf12882670b09b3f163cd00a24cffbd74e"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:24368e31be2c88bd69340fbfe741b405302993242ccb476c5c3ff48aeee1afe0"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:e33b0834f1cf779aa839975f9d8755a7c2420510c0fa1e9fa0497de77cd35d2c"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6af4b3f52cc65f8a0bc8b1cd9676f8c21ef3e9132f21fed250f6958bd7223bed"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d15687d7d7f40333bd8266f3814c591c2e2cd263fa2116e314f60d82086e353a"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:095b707bb287bfd534044166ab767bec70a9bba3175dcdc3371782175c14e43c"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:94fc0e6621e07d1e91c44e016cc0b189b48db053061cc22d6298a611de8071bb"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ce830e480f6774608dedfd4a90c42aac4a7af0a711f1b52f807130c2e434c06"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a306cdd2ad3a7d795d8e617a58c3a2ed0f76c8496fb7621b6cd514eb1532cae8"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:2f5fa187bde8524b1e37ba894db13aadd64faa884657473b03a019f625cee9a8"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:438027a975cc213a47c5d70672e0d29776082155cfae540c4e225716586be75e"},
- {file = "pydantic_core-2.14.6-cp37-none-win32.whl", hash = "sha256:f96ae96a060a8072ceff4cfde89d261837b4294a4f28b84a28765470d502ccc6"},
- {file = "pydantic_core-2.14.6-cp37-none-win_amd64.whl", hash = "sha256:e646c0e282e960345314f42f2cea5e0b5f56938c093541ea6dbf11aec2862391"},
- {file = "pydantic_core-2.14.6-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:db453f2da3f59a348f514cfbfeb042393b68720787bbef2b4c6068ea362c8149"},
- {file = "pydantic_core-2.14.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3860c62057acd95cc84044e758e47b18dcd8871a328ebc8ccdefd18b0d26a21b"},
- {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36026d8f99c58d7044413e1b819a67ca0e0b8ebe0f25e775e6c3d1fabb3c38fb"},
- {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8ed1af8692bd8d2a29d702f1a2e6065416d76897d726e45a1775b1444f5928a7"},
- {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:314ccc4264ce7d854941231cf71b592e30d8d368a71e50197c905874feacc8a8"},
- {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:982487f8931067a32e72d40ab6b47b1628a9c5d344be7f1a4e668fb462d2da42"},
- {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2dbe357bc4ddda078f79d2a36fc1dd0494a7f2fad83a0a684465b6f24b46fe80"},
- {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2f6ffc6701a0eb28648c845f4945a194dc7ab3c651f535b81793251e1185ac3d"},
- {file = "pydantic_core-2.14.6-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:7f5025db12fc6de7bc1104d826d5aee1d172f9ba6ca936bf6474c2148ac336c1"},
- {file = "pydantic_core-2.14.6-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:dab03ed811ed1c71d700ed08bde8431cf429bbe59e423394f0f4055f1ca0ea60"},
- {file = "pydantic_core-2.14.6-cp38-none-win32.whl", hash = "sha256:dfcbebdb3c4b6f739a91769aea5ed615023f3c88cb70df812849aef634c25fbe"},
- {file = "pydantic_core-2.14.6-cp38-none-win_amd64.whl", hash = "sha256:99b14dbea2fdb563d8b5a57c9badfcd72083f6006caf8e126b491519c7d64ca8"},
- {file = "pydantic_core-2.14.6-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:4ce8299b481bcb68e5c82002b96e411796b844d72b3e92a3fbedfe8e19813eab"},
- {file = "pydantic_core-2.14.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b9a9d92f10772d2a181b5ca339dee066ab7d1c9a34ae2421b2a52556e719756f"},
- {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd9e98b408384989ea4ab60206b8e100d8687da18b5c813c11e92fd8212a98e0"},
- {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4f86f1f318e56f5cbb282fe61eb84767aee743ebe32c7c0834690ebea50c0a6b"},
- {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86ce5fcfc3accf3a07a729779d0b86c5d0309a4764c897d86c11089be61da160"},
- {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3dcf1978be02153c6a31692d4fbcc2a3f1db9da36039ead23173bc256ee3b91b"},
- {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eedf97be7bc3dbc8addcef4142f4b4164066df0c6f36397ae4aaed3eb187d8ab"},
- {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d5f916acf8afbcab6bacbb376ba7dc61f845367901ecd5e328fc4d4aef2fcab0"},
- {file = "pydantic_core-2.14.6-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8a14c192c1d724c3acbfb3f10a958c55a2638391319ce8078cb36c02283959b9"},
- {file = "pydantic_core-2.14.6-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0348b1dc6b76041516e8a854ff95b21c55f5a411c3297d2ca52f5528e49d8411"},
- {file = "pydantic_core-2.14.6-cp39-none-win32.whl", hash = "sha256:de2a0645a923ba57c5527497daf8ec5df69c6eadf869e9cd46e86349146e5975"},
- {file = "pydantic_core-2.14.6-cp39-none-win_amd64.whl", hash = "sha256:aca48506a9c20f68ee61c87f2008f81f8ee99f8d7f0104bff3c47e2d148f89d9"},
- {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:d5c28525c19f5bb1e09511669bb57353d22b94cf8b65f3a8d141c389a55dec95"},
- {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:78d0768ee59baa3de0f4adac9e3748b4b1fffc52143caebddfd5ea2961595277"},
- {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b93785eadaef932e4fe9c6e12ba67beb1b3f1e5495631419c784ab87e975670"},
- {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a874f21f87c485310944b2b2734cd6d318765bcbb7515eead33af9641816506e"},
- {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b89f4477d915ea43b4ceea6756f63f0288941b6443a2b28c69004fe07fde0d0d"},
- {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:172de779e2a153d36ee690dbc49c6db568d7b33b18dc56b69a7514aecbcf380d"},
- {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:dfcebb950aa7e667ec226a442722134539e77c575f6cfaa423f24371bb8d2e94"},
- {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:55a23dcd98c858c0db44fc5c04fc7ed81c4b4d33c653a7c45ddaebf6563a2f66"},
- {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:4241204e4b36ab5ae466ecec5c4c16527a054c69f99bba20f6f75232a6a534e2"},
- {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e574de99d735b3fc8364cba9912c2bec2da78775eba95cbb225ef7dda6acea24"},
- {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1302a54f87b5cd8528e4d6d1bf2133b6aa7c6122ff8e9dc5220fbc1e07bffebd"},
- {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f8e81e4b55930e5ffab4a68db1af431629cf2e4066dbdbfef65348b8ab804ea8"},
- {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:c99462ffc538717b3e60151dfaf91125f637e801f5ab008f81c402f1dff0cd0f"},
- {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e4cf2d5829f6963a5483ec01578ee76d329eb5caf330ecd05b3edd697e7d768a"},
- {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:cf10b7d58ae4a1f07fccbf4a0a956d705356fea05fb4c70608bb6fa81d103cda"},
- {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:399ac0891c284fa8eb998bcfa323f2234858f5d2efca3950ae58c8f88830f145"},
- {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c6a5c79b28003543db3ba67d1df336f253a87d3112dac3a51b94f7d48e4c0e1"},
- {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:599c87d79cab2a6a2a9df4aefe0455e61e7d2aeede2f8577c1b7c0aec643ee8e"},
- {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43e166ad47ba900f2542a80d83f9fc65fe99eb63ceec4debec160ae729824052"},
- {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3a0b5db001b98e1c649dd55afa928e75aa4087e587b9524a4992316fa23c9fba"},
- {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:747265448cb57a9f37572a488a57d873fd96bf51e5bb7edb52cfb37124516da4"},
- {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:7ebe3416785f65c28f4f9441e916bfc8a54179c8dea73c23023f7086fa601c5d"},
- {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:86c963186ca5e50d5c8287b1d1c9d3f8f024cbe343d048c5bd282aec2d8641f2"},
- {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:e0641b506486f0b4cd1500a2a65740243e8670a2549bb02bc4556a83af84ae03"},
- {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71d72ca5eaaa8d38c8df16b7deb1a2da4f650c41b58bb142f3fb75d5ad4a611f"},
- {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27e524624eace5c59af499cd97dc18bb201dc6a7a2da24bfc66ef151c69a5f2a"},
- {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a3dde6cac75e0b0902778978d3b1646ca9f438654395a362cb21d9ad34b24acf"},
- {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:00646784f6cd993b1e1c0e7b0fdcbccc375d539db95555477771c27555e3c556"},
- {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:23598acb8ccaa3d1d875ef3b35cb6376535095e9405d91a3d57a8c7db5d29341"},
- {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7f41533d7e3cf9520065f610b41ac1c76bc2161415955fbcead4981b22c7611e"},
- {file = "pydantic_core-2.14.6.tar.gz", hash = "sha256:1fd0c1d395372843fba13a51c28e3bb9d59bd7aebfeb17358ffaaa1e4dbbe948"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3077cfdb6125cc8dab61b155fdd714663e401f0e6883f9632118ec12cf42df26"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8ffab8b2908d152e74862d276cf5017c81a2f3719f14e8e3e8d6b83fda863927"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5183e4f6a2d468787243ebcd70cf4098c247e60d73fb7d68d5bc1e1beaa0c4db"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:398a38d323f37714023be1e0285765f0a27243a8b1506b7b7de87b647b517e48"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87d3776f0001b43acebfa86f8c64019c043b55cc5a6a2e313d728b5c95b46969"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c566dd9c5f63d22226409553531f89de0cac55397f2ab8d97d6f06cfce6d947e"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0d5f3acc81452c56895e90643a625302bd6be351e7010664151cc55b7b97f89"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d3a07fadec2a13274a8d861d3d37c61e97a816beae717efccaa4b36dfcaadcde"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f99aeda58dce827f76963ee87a0ebe75e648c72ff9ba1174a253f6744f518f65"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:902dbc832141aa0ec374f4310f1e4e7febeebc3256f00dc359a9ac3f264a45dc"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fe44d56aa0b00d66640aa84a3cbe80b7a3ccdc6f0b1ca71090696a6d4777c091"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-win32.whl", hash = "sha256:ed3eb16d51257c763539bde21e011092f127a2202692afaeaccb50db55a31383"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-win_amd64.whl", hash = "sha256:694ad99a7f6718c1a498dc170ca430687a39894a60327f548e02a9c7ee4b6504"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6e966fc3caaf9f1d96b349b0341c70c8d6573bf1bac7261f7b0ba88f96c56c24"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bfd0adeee563d59c598ceabddf2c92eec77abcb3f4a391b19aa7366170bd9e30"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91815221101ad3c6b507804178a7bb5cb7b2ead9ecd600041669c8d805ebd595"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9fea9c1869bb4742d174a57b4700c6dadea951df8b06de40c2fedb4f02931c2e"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d20eb4861329bb2484c021b9d9a977566ab16d84000a57e28061151c62b349a"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fb935c5591573ae3201640579f30128ccc10739b45663f93c06796854405505"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c964fd24e6166420d18fb53996d8c9fd6eac9bf5ae3ec3d03015be4414ce497f"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:681d65e9011f7392db5aa002b7423cc442d6a673c635668c227c6c8d0e5a4f77"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e100c52f7355a48413e2999bfb4e139d2977a904495441b374f3d4fb4a170961"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:048831bd363490be79acdd3232f74a0e9951b11b2b4cc058aeb72b22fdc3abe1"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:bdc84017d28459c00db6f918a7272a5190bec3090058334e43a76afb279eac7c"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-win32.whl", hash = "sha256:32cd11c5914d1179df70406427097c7dcde19fddf1418c787540f4b730289896"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-win_amd64.whl", hash = "sha256:2ea62419ba8c397e7da28a9170a16219d310d2cf4970dbc65c32faf20d828c83"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-win_arm64.whl", hash = "sha256:fc903512177361e868bc1f5b80ac8c8a6e05fcdd574a5fb5ffeac5a9982b9e89"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:1293d7febb995e9d3ec3ea09caf1a26214eec45b0f29f6074abb004723fc1de8"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:99b56acd433386c8f20be5c4000786d1e7ca0523c8eefc995d14d79c7a081498"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35a5ec3fa8c2fe6c53e1b2ccc2454398f95d5393ab398478f53e1afbbeb4d939"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b172f7b9d2f3abc0efd12e3386f7e48b576ef309544ac3a63e5e9cdd2e24585d"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9097b9f17f91eea659b9ec58148c0747ec354a42f7389b9d50701610d86f812e"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cc77ec5b7e2118b152b0d886c7514a4653bcb58c6b1d760134a9fab915f777b3"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5e3d15245b08fa4a84cefc6c9222e6f37c98111c8679fbd94aa145f9a0ae23d"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ef99779001d7ac2e2461d8ab55d3373fe7315caefdbecd8ced75304ae5a6fc6b"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fc6bf8869e193855e8d91d91f6bf59699a5cdfaa47a404e278e776dd7f168b39"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:b1caa0bc2741b043db7823843e1bde8aaa58a55a58fda06083b0569f8b45693a"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ec259f62538e8bf364903a7d0d0239447059f9434b284f5536e8402b7dd198db"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-win32.whl", hash = "sha256:e14f369c98a7c15772b9da98987f58e2b509a93235582838bd0d1d8c08b68fda"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-win_amd64.whl", hash = "sha256:1c607801d85e2e123357b3893f82c97a42856192997b95b4d8325deb1cd0c5f4"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-win_arm64.whl", hash = "sha256:8d13f0276806ee722e70a1c93da19748594f19ac4299c7e41237fc791d1861ea"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:70af6a21237b53d1fe7b9325b20e65cbf2f0a848cf77bed492b029139701e66a"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:282b3fe1bbbe5ae35224a0dbd05aed9ccabccd241e8e6b60370484234b456266"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b315e596282bbb5822d0c7ee9d255595bd7506d1cb20c2911a4da0b970187d3"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1dfae24cf9921875ca0ca6a8ecb4bb2f13c855794ed0d468d6abbec6e6dcd44a"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6dd8ecfde08d8bfadaea669e83c63939af76f4cf5538a72597016edfa3fad516"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f593494876eae852dc98c43c6f260f45abdbfeec9e4324e31a481d948214764"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:948b73114f47fd7016088e5186d13faf5e1b2fe83f5e320e371f035557fd264d"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e11f3864eb516af21b01e25fac915a82e9ddad3bb0fb9e95a246067398b435a4"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:549150be302428b56fdad0c23c2741dcdb5572413776826c965619a25d9c6bde"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:495bc156026efafd9ef2d82372bd38afce78ddd82bf28ef5276c469e57c0c83e"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ec79de2a8680b1a67a07490bddf9636d5c2fab609ba8c57597e855fa5fa4dacd"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-win32.whl", hash = "sha256:ee12a7be1742f81b8a65b36c6921022301d466b82d80315d215c4c691724986f"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-win_amd64.whl", hash = "sha256:ede9b407e39949d2afc46385ce6bd6e11588660c26f80576c11c958e6647bc40"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-win_arm64.whl", hash = "sha256:aa687a23d4b7871a00e03ca96a09cad0f28f443690d300500603bd0adba4b523"},
+ {file = "pydantic_core-2.33.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:401d7b76e1000d0dd5538e6381d28febdcacb097c8d340dde7d7fc6e13e9f95d"},
+ {file = "pydantic_core-2.33.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7aeb055a42d734c0255c9e489ac67e75397d59c6fbe60d155851e9782f276a9c"},
+ {file = "pydantic_core-2.33.1-cp313-cp313t-win_amd64.whl", hash = "sha256:338ea9b73e6e109f15ab439e62cb3b78aa752c7fd9536794112e14bee02c8d18"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5ab77f45d33d264de66e1884fca158bc920cb5e27fd0764a72f72f5756ae8bdb"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e7aaba1b4b03aaea7bb59e1b5856d734be011d3e6d98f5bcaa98cb30f375f2ad"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7fb66263e9ba8fea2aa85e1e5578980d127fb37d7f2e292773e7bc3a38fb0c7b"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3f2648b9262607a7fb41d782cc263b48032ff7a03a835581abbf7a3bec62bcf5"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:723c5630c4259400818b4ad096735a829074601805d07f8cafc366d95786d331"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d100e3ae783d2167782391e0c1c7a20a31f55f8015f3293647544df3f9c67824"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:177d50460bc976a0369920b6c744d927b0ecb8606fb56858ff542560251b19e5"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a3edde68d1a1f9af1273b2fe798997b33f90308fb6d44d8550c89fc6a3647cf6"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a62c3c3ef6a7e2c45f7853b10b5bc4ddefd6ee3cd31024754a1a5842da7d598d"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:c91dbb0ab683fa0cd64a6e81907c8ff41d6497c346890e26b23de7ee55353f96"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9f466e8bf0a62dc43e068c12166281c2eca72121dd2adc1040f3aa1e21ef8599"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-win32.whl", hash = "sha256:ab0277cedb698749caada82e5d099dc9fed3f906a30d4c382d1a21725777a1e5"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-win_amd64.whl", hash = "sha256:5773da0ee2d17136b1f1c6fbde543398d452a6ad2a7b54ea1033e2daa739b8d2"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5c834f54f8f4640fd7e4b193f80eb25a0602bba9e19b3cd2fc7ffe8199f5ae02"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:049e0de24cf23766f12cc5cc71d8abc07d4a9deb9061b334b62093dedc7cb068"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a28239037b3d6f16916a4c831a5a0eadf856bdd6d2e92c10a0da3a59eadcf3e"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d3da303ab5f378a268fa7d45f37d7d85c3ec19769f28d2cc0c61826a8de21fe"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:25626fb37b3c543818c14821afe0fd3830bc327a43953bc88db924b68c5723f1"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3ab2d36e20fbfcce8f02d73c33a8a7362980cff717926bbae030b93ae46b56c7"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:2f9284e11c751b003fd4215ad92d325d92c9cb19ee6729ebd87e3250072cdcde"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:048c01eee07d37cbd066fc512b9d8b5ea88ceeb4e629ab94b3e56965ad655add"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5ccd429694cf26af7997595d627dd2637e7932214486f55b8a357edaac9dae8c"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3a371dc00282c4b84246509a5ddc808e61b9864aa1eae9ecc92bb1268b82db4a"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:f59295ecc75a1788af8ba92f2e8c6eeaa5a94c22fc4d151e8d9638814f85c8fc"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08530b8ac922003033f399128505f513e30ca770527cc8bbacf75a84fcc2c74b"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bae370459da6a5466978c0eacf90690cb57ec9d533f8e63e564ef3822bfa04fe"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e3de2777e3b9f4d603112f78006f4ae0acb936e95f06da6cb1a45fbad6bdb4b5"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3a64e81e8cba118e108d7126362ea30e021291b7805d47e4896e52c791be2761"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:52928d8c1b6bda03cc6d811e8923dffc87a2d3c8b3bfd2ce16471c7147a24850"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:1b30d92c9412beb5ac6b10a3eb7ef92ccb14e3f2a8d7732e2d739f58b3aa7544"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:f995719707e0e29f0f41a8aa3bcea6e761a36c9136104d3189eafb83f5cec5e5"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:7edbc454a29fc6aeae1e1eecba4f07b63b8d76e76a748532233c4c167b4cb9ea"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:ad05b683963f69a1d5d2c2bdab1274a31221ca737dbbceaa32bcb67359453cdd"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df6a94bf9452c6da9b5d76ed229a5683d0306ccb91cca8e1eea883189780d568"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7965c13b3967909a09ecc91f21d09cfc4576bf78140b988904e94f130f188396"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3f1fdb790440a34f6ecf7679e1863b825cb5ffde858a9197f851168ed08371e5"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:5277aec8d879f8d05168fdd17ae811dd313b8ff894aeeaf7cd34ad28b4d77e33"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:8ab581d3530611897d863d1a649fb0644b860286b4718db919bfd51ece41f10b"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0483847fa9ad5e3412265c1bd72aad35235512d9ce9d27d81a56d935ef489672"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:de9e06abe3cc5ec6a2d5f75bc99b0bdca4f5c719a5b34026f8c57efbdecd2ee3"},
+ {file = "pydantic_core-2.33.1.tar.gz", hash = "sha256:bcc9c6fdb0ced789245b02b7d6603e17d1563064ddcfc36f046b61c0c05dd9df"},
]
[package.dependencies]
@@ -844,134 +1009,139 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
[[package]]
name = "pytest"
-version = "6.2.5"
+version = "8.3.5"
description = "pytest: simple powerful testing with Python"
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.8"
+groups = ["dev"]
files = [
- {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"},
- {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"},
+ {file = "pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820"},
+ {file = "pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845"},
]
[package.dependencies]
-atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""}
-attrs = ">=19.2.0"
colorama = {version = "*", markers = "sys_platform == \"win32\""}
-importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""}
+exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
packaging = "*"
-pluggy = ">=0.12,<2.0"
-py = ">=1.8.2"
-toml = "*"
+pluggy = ">=1.5,<2"
+tomli = {version = ">=1", markers = "python_version < \"3.11\""}
[package.extras]
-testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"]
+dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-asyncio"
-version = "0.17.2"
+version = "0.26.0"
description = "Pytest support for asyncio"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["dev"]
files = [
- {file = "pytest-asyncio-0.17.2.tar.gz", hash = "sha256:6d895b02432c028e6957d25fc936494e78c6305736e785d9fee408b1efbc7ff4"},
- {file = "pytest_asyncio-0.17.2-py3-none-any.whl", hash = "sha256:e0fe5dbea40516b661ef1bcfe0bd9461c2847c4ef4bb40012324f2454fb7d56d"},
+ {file = "pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0"},
+ {file = "pytest_asyncio-0.26.0.tar.gz", hash = "sha256:c4df2a697648241ff39e7f0e4a73050b03f123f760673956cf0d72a4990e312f"},
]
[package.dependencies]
-pytest = ">=6.1.0"
-typing-extensions = {version = ">=4.0", markers = "python_version < \"3.8\""}
+pytest = ">=8.2,<9"
+typing-extensions = {version = ">=4.12", markers = "python_version < \"3.10\""}
[package.extras]
-testing = ["coverage (==6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (==0.931)"]
+docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
+testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
[[package]]
name = "pytest-cov"
-version = "3.0.0"
+version = "6.1.1"
description = "Pytest plugin for measuring coverage."
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.9"
+groups = ["dev"]
files = [
- {file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"},
- {file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"},
+ {file = "pytest_cov-6.1.1-py3-none-any.whl", hash = "sha256:bddf29ed2d0ab6f4df17b4c55b0a657287db8684af9c42ea546b21b1041b3dde"},
+ {file = "pytest_cov-6.1.1.tar.gz", hash = "sha256:46935f7aaefba760e716c2ebfbe1c216240b9592966e7da99ea8292d4d3e2a0a"},
]
[package.dependencies]
-coverage = {version = ">=5.2.1", extras = ["toml"]}
+coverage = {version = ">=7.5", extras = ["toml"]}
pytest = ">=4.6"
[package.extras]
-testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
+testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
[[package]]
name = "pyyaml"
-version = "6.0.1"
+version = "6.0.2"
description = "YAML parser and emitter for Python"
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.8"
+groups = ["main"]
files = [
- {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"},
- {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"},
- {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
- {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
- {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
- {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
- {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
- {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
- {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
- {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"},
- {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
- {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
- {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
- {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
- {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
- {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
- {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
- {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
- {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
- {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
- {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
- {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
- {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
- {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
- {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
- {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
- {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"},
- {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"},
- {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"},
- {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"},
- {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"},
- {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"},
- {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"},
- {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"},
- {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"},
- {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"},
- {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
- {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
- {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
- {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
- {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
- {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
- {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
- {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"},
- {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
- {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
- {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
- {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
- {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
- {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
- {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
+ {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"},
+ {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"},
+ {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"},
+ {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"},
+ {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"},
+ {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"},
+ {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"},
+ {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"},
+ {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"},
+ {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"},
+ {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"},
+ {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"},
+ {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"},
+ {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"},
+ {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"},
+ {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"},
+ {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"},
+ {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"},
+ {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"},
+ {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"},
+ {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"},
+ {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"},
+ {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"},
+ {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"},
+ {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"},
+ {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"},
+ {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"},
+ {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"},
+ {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"},
+ {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"},
+ {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"},
+ {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"},
+ {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"},
+ {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"},
+ {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"},
+ {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"},
+ {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"},
+ {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"},
+ {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"},
+ {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"},
+ {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"},
+ {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"},
+ {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"},
+ {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"},
+ {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"},
+ {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"},
+ {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"},
+ {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"},
+ {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"},
+ {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"},
+ {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"},
+ {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"},
+ {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"},
]
[[package]]
name = "requests"
-version = "2.31.0"
+version = "2.32.3"
description = "Python HTTP for Humans."
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["main"]
files = [
- {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"},
- {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"},
+ {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
+ {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
]
[package.dependencies]
@@ -984,180 +1154,220 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
-[[package]]
-name = "toml"
-version = "0.10.2"
-description = "Python Library for Tom's Obvious, Minimal Language"
-optional = false
-python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
-files = [
- {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
- {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
-]
-
[[package]]
name = "tomli"
-version = "2.0.1"
+version = "2.2.1"
description = "A lil' TOML parser"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["dev"]
+markers = "python_full_version <= \"3.11.0a6\""
files = [
- {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
- {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
+ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
+ {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
+ {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a"},
+ {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee"},
+ {file = "tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e"},
+ {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4"},
+ {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106"},
+ {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8"},
+ {file = "tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff"},
+ {file = "tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b"},
+ {file = "tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea"},
+ {file = "tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8"},
+ {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192"},
+ {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222"},
+ {file = "tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77"},
+ {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6"},
+ {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd"},
+ {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e"},
+ {file = "tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98"},
+ {file = "tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4"},
+ {file = "tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7"},
+ {file = "tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c"},
+ {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13"},
+ {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281"},
+ {file = "tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272"},
+ {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140"},
+ {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2"},
+ {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744"},
+ {file = "tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec"},
+ {file = "tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69"},
+ {file = "tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc"},
+ {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"},
]
[[package]]
name = "tqdm"
-version = "4.66.1"
+version = "4.67.1"
description = "Fast, Extensible Progress Meter"
optional = false
python-versions = ">=3.7"
+groups = ["main"]
files = [
- {file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"},
- {file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"},
+ {file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"},
+ {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"},
]
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[package.extras]
-dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"]
+dev = ["nbval", "pytest (>=6)", "pytest-asyncio (>=0.24)", "pytest-cov", "pytest-timeout"]
+discord = ["requests"]
notebook = ["ipywidgets (>=6)"]
slack = ["slack-sdk"]
telegram = ["requests"]
[[package]]
name = "typing-extensions"
-version = "4.7.1"
-description = "Backported and Experimental Type Hints for Python 3.7+"
+version = "4.13.2"
+description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["main", "dev"]
files = [
- {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"},
- {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"},
+ {file = "typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c"},
+ {file = "typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef"},
]
+markers = {dev = "python_version < \"3.10\""}
+
+[[package]]
+name = "typing-inspection"
+version = "0.4.0"
+description = "Runtime typing introspection tools"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f"},
+ {file = "typing_inspection-0.4.0.tar.gz", hash = "sha256:9765c87de36671694a67904bf2c96e395be9c6439bb6c87b5142569dcdd65122"},
+]
+
+[package.dependencies]
+typing-extensions = ">=4.12.0"
[[package]]
name = "urllib3"
-version = "2.0.5"
+version = "2.4.0"
description = "HTTP library with thread-safe connection pooling, file post, and more."
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "urllib3-2.0.5-py3-none-any.whl", hash = "sha256:ef16afa8ba34a1f989db38e1dbbe0c302e4289a47856990d0682e374563ce35e"},
- {file = "urllib3-2.0.5.tar.gz", hash = "sha256:13abf37382ea2ce6fb744d4dad67838eec857c9f4f57009891805e0b5e123594"},
+ {file = "urllib3-2.4.0-py3-none-any.whl", hash = "sha256:4e16665048960a0900c702d4a66415956a584919c03361cac9f1df5c5dd7e813"},
+ {file = "urllib3-2.4.0.tar.gz", hash = "sha256:414bc6535b787febd7567804cc015fee39daab8ad86268f1310a9250697de466"},
]
[package.extras]
brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
-secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"]
+h2 = ["h2 (>=4,<5)"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "yarl"
-version = "1.9.2"
+version = "1.19.0"
description = "Yet another URL library"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c2ad583743d16ddbdf6bb14b5cd76bf43b0d0006e918809d5d4ddf7bde8dd82"},
- {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:82aa6264b36c50acfb2424ad5ca537a2060ab6de158a5bd2a72a032cc75b9eb8"},
- {file = "yarl-1.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c0c77533b5ed4bcc38e943178ccae29b9bcf48ffd1063f5821192f23a1bd27b9"},
- {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee4afac41415d52d53a9833ebae7e32b344be72835bbb589018c9e938045a560"},
- {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9bf345c3a4f5ba7f766430f97f9cc1320786f19584acc7086491f45524a551ac"},
- {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a96c19c52ff442a808c105901d0bdfd2e28575b3d5f82e2f5fd67e20dc5f4ea"},
- {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:891c0e3ec5ec881541f6c5113d8df0315ce5440e244a716b95f2525b7b9f3608"},
- {file = "yarl-1.9.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c3a53ba34a636a256d767c086ceb111358876e1fb6b50dfc4d3f4951d40133d5"},
- {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:566185e8ebc0898b11f8026447eacd02e46226716229cea8db37496c8cdd26e0"},
- {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2b0738fb871812722a0ac2154be1f049c6223b9f6f22eec352996b69775b36d4"},
- {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:32f1d071b3f362c80f1a7d322bfd7b2d11e33d2adf395cc1dd4df36c9c243095"},
- {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e9fdc7ac0d42bc3ea78818557fab03af6181e076a2944f43c38684b4b6bed8e3"},
- {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:56ff08ab5df8429901ebdc5d15941b59f6253393cb5da07b4170beefcf1b2528"},
- {file = "yarl-1.9.2-cp310-cp310-win32.whl", hash = "sha256:8ea48e0a2f931064469bdabca50c2f578b565fc446f302a79ba6cc0ee7f384d3"},
- {file = "yarl-1.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:50f33040f3836e912ed16d212f6cc1efb3231a8a60526a407aeb66c1c1956dde"},
- {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:646d663eb2232d7909e6601f1a9107e66f9791f290a1b3dc7057818fe44fc2b6"},
- {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aff634b15beff8902d1f918012fc2a42e0dbae6f469fce134c8a0dc51ca423bb"},
- {file = "yarl-1.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a83503934c6273806aed765035716216cc9ab4e0364f7f066227e1aaea90b8d0"},
- {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b25322201585c69abc7b0e89e72790469f7dad90d26754717f3310bfe30331c2"},
- {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:22a94666751778629f1ec4280b08eb11815783c63f52092a5953faf73be24191"},
- {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ec53a0ea2a80c5cd1ab397925f94bff59222aa3cf9c6da938ce05c9ec20428d"},
- {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:159d81f22d7a43e6eabc36d7194cb53f2f15f498dbbfa8edc8a3239350f59fe7"},
- {file = "yarl-1.9.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:832b7e711027c114d79dffb92576acd1bd2decc467dec60e1cac96912602d0e6"},
- {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:95d2ecefbcf4e744ea952d073c6922e72ee650ffc79028eb1e320e732898d7e8"},
- {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d4e2c6d555e77b37288eaf45b8f60f0737c9efa3452c6c44626a5455aeb250b9"},
- {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:783185c75c12a017cc345015ea359cc801c3b29a2966c2655cd12b233bf5a2be"},
- {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:b8cc1863402472f16c600e3e93d542b7e7542a540f95c30afd472e8e549fc3f7"},
- {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:822b30a0f22e588b32d3120f6d41e4ed021806418b4c9f0bc3048b8c8cb3f92a"},
- {file = "yarl-1.9.2-cp311-cp311-win32.whl", hash = "sha256:a60347f234c2212a9f0361955007fcf4033a75bf600a33c88a0a8e91af77c0e8"},
- {file = "yarl-1.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:be6b3fdec5c62f2a67cb3f8c6dbf56bbf3f61c0f046f84645cd1ca73532ea051"},
- {file = "yarl-1.9.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38a3928ae37558bc1b559f67410df446d1fbfa87318b124bf5032c31e3447b74"},
- {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac9bb4c5ce3975aeac288cfcb5061ce60e0d14d92209e780c93954076c7c4367"},
- {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3da8a678ca8b96c8606bbb8bfacd99a12ad5dd288bc6f7979baddd62f71c63ef"},
- {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13414591ff516e04fcdee8dc051c13fd3db13b673c7a4cb1350e6b2ad9639ad3"},
- {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf74d08542c3a9ea97bb8f343d4fcbd4d8f91bba5ec9d5d7f792dbe727f88938"},
- {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e7221580dc1db478464cfeef9b03b95c5852cc22894e418562997df0d074ccc"},
- {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:494053246b119b041960ddcd20fd76224149cfea8ed8777b687358727911dd33"},
- {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:52a25809fcbecfc63ac9ba0c0fb586f90837f5425edfd1ec9f3372b119585e45"},
- {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:e65610c5792870d45d7b68c677681376fcf9cc1c289f23e8e8b39c1485384185"},
- {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:1b1bba902cba32cdec51fca038fd53f8beee88b77efc373968d1ed021024cc04"},
- {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:662e6016409828ee910f5d9602a2729a8a57d74b163c89a837de3fea050c7582"},
- {file = "yarl-1.9.2-cp37-cp37m-win32.whl", hash = "sha256:f364d3480bffd3aa566e886587eaca7c8c04d74f6e8933f3f2c996b7f09bee1b"},
- {file = "yarl-1.9.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6a5883464143ab3ae9ba68daae8e7c5c95b969462bbe42e2464d60e7e2698368"},
- {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5610f80cf43b6202e2c33ba3ec2ee0a2884f8f423c8f4f62906731d876ef4fac"},
- {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b9a4e67ad7b646cd6f0938c7ebfd60e481b7410f574c560e455e938d2da8e0f4"},
- {file = "yarl-1.9.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:83fcc480d7549ccebe9415d96d9263e2d4226798c37ebd18c930fce43dfb9574"},
- {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fcd436ea16fee7d4207c045b1e340020e58a2597301cfbcfdbe5abd2356c2fb"},
- {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84e0b1599334b1e1478db01b756e55937d4614f8654311eb26012091be109d59"},
- {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3458a24e4ea3fd8930e934c129b676c27452e4ebda80fbe47b56d8c6c7a63a9e"},
- {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:838162460b3a08987546e881a2bfa573960bb559dfa739e7800ceeec92e64417"},
- {file = "yarl-1.9.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f4e2d08f07a3d7d3e12549052eb5ad3eab1c349c53ac51c209a0e5991bbada78"},
- {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:de119f56f3c5f0e2fb4dee508531a32b069a5f2c6e827b272d1e0ff5ac040333"},
- {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:149ddea5abf329752ea5051b61bd6c1d979e13fbf122d3a1f9f0c8be6cb6f63c"},
- {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:674ca19cbee4a82c9f54e0d1eee28116e63bc6fd1e96c43031d11cbab8b2afd5"},
- {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:9b3152f2f5677b997ae6c804b73da05a39daa6a9e85a512e0e6823d81cdad7cc"},
- {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5415d5a4b080dc9612b1b63cba008db84e908b95848369aa1da3686ae27b6d2b"},
- {file = "yarl-1.9.2-cp38-cp38-win32.whl", hash = "sha256:f7a3d8146575e08c29ed1cd287068e6d02f1c7bdff8970db96683b9591b86ee7"},
- {file = "yarl-1.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:63c48f6cef34e6319a74c727376e95626f84ea091f92c0250a98e53e62c77c72"},
- {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:75df5ef94c3fdc393c6b19d80e6ef1ecc9ae2f4263c09cacb178d871c02a5ba9"},
- {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c027a6e96ef77d401d8d5a5c8d6bc478e8042f1e448272e8d9752cb0aff8b5c8"},
- {file = "yarl-1.9.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3b078dbe227f79be488ffcfc7a9edb3409d018e0952cf13f15fd6512847f3f7"},
- {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59723a029760079b7d991a401386390c4be5bfec1e7dd83e25a6a0881859e716"},
- {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b03917871bf859a81ccb180c9a2e6c1e04d2f6a51d953e6a5cdd70c93d4e5a2a"},
- {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c1012fa63eb6c032f3ce5d2171c267992ae0c00b9e164efe4d73db818465fac3"},
- {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74dcbfe780e62f4b5a062714576f16c2f3493a0394e555ab141bf0d746bb955"},
- {file = "yarl-1.9.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8c56986609b057b4839968ba901944af91b8e92f1725d1a2d77cbac6972b9ed1"},
- {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2c315df3293cd521033533d242d15eab26583360b58f7ee5d9565f15fee1bef4"},
- {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b7232f8dfbd225d57340e441d8caf8652a6acd06b389ea2d3222b8bc89cbfca6"},
- {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:53338749febd28935d55b41bf0bcc79d634881195a39f6b2f767870b72514caf"},
- {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:066c163aec9d3d073dc9ffe5dd3ad05069bcb03fcaab8d221290ba99f9f69ee3"},
- {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8288d7cd28f8119b07dd49b7230d6b4562f9b61ee9a4ab02221060d21136be80"},
- {file = "yarl-1.9.2-cp39-cp39-win32.whl", hash = "sha256:b124e2a6d223b65ba8768d5706d103280914d61f5cae3afbc50fc3dfcc016623"},
- {file = "yarl-1.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:61016e7d582bc46a5378ffdd02cd0314fb8ba52f40f9cf4d9a5e7dbef88dee18"},
- {file = "yarl-1.9.2.tar.gz", hash = "sha256:04ab9d4b9f587c06d801c2abfe9317b77cdf996c65a90d5e84ecc45010823571"},
+ {file = "yarl-1.19.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0bae32f8ebd35c04d6528cedb4a26b8bf25339d3616b04613b97347f919b76d3"},
+ {file = "yarl-1.19.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8015a076daf77823e7ebdcba474156587391dab4e70c732822960368c01251e6"},
+ {file = "yarl-1.19.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9973ac95327f5d699eb620286c39365990b240031672b5c436a4cd00539596c5"},
+ {file = "yarl-1.19.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd4b5fbd7b9dde785cfeb486b8cca211a0b138d4f3a7da27db89a25b3c482e5c"},
+ {file = "yarl-1.19.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:75460740005de5a912b19f657848aef419387426a40f581b1dc9fac0eb9addb5"},
+ {file = "yarl-1.19.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:57abd66ca913f2cfbb51eb3dbbbac3648f1f6983f614a4446e0802e241441d2a"},
+ {file = "yarl-1.19.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:46ade37911b7c99ce28a959147cb28bffbd14cea9e7dd91021e06a8d2359a5aa"},
+ {file = "yarl-1.19.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8346ec72ada749a6b5d82bff7be72578eab056ad7ec38c04f668a685abde6af0"},
+ {file = "yarl-1.19.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e4cb14a6ee5b6649ccf1c6d648b4da9220e8277d4d4380593c03cc08d8fe937"},
+ {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:66fc1c2926a73a2fb46e4b92e3a6c03904d9bc3a0b65e01cb7d2b84146a8bd3b"},
+ {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:5a70201dd1e0a4304849b6445a9891d7210604c27e67da59091d5412bc19e51c"},
+ {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4807aab1bdeab6ae6f296be46337a260ae4b1f3a8c2fcd373e236b4b2b46efd"},
+ {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:ae584afe81a1de4c1bb06672481050f0d001cad13163e3c019477409f638f9b7"},
+ {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:30eaf4459df6e91f21b2999d1ee18f891bcd51e3cbe1de301b4858c84385895b"},
+ {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:0e617d45d03c8dec0dfce6f51f3e1b8a31aa81aaf4a4d1442fdb232bcf0c6d8c"},
+ {file = "yarl-1.19.0-cp310-cp310-win32.whl", hash = "sha256:32ba32d0fa23893fd8ea8d05bdb05de6eb19d7f2106787024fd969f4ba5466cb"},
+ {file = "yarl-1.19.0-cp310-cp310-win_amd64.whl", hash = "sha256:545575ecfcd465891b51546c2bcafdde0acd2c62c2097d8d71902050b20e4922"},
+ {file = "yarl-1.19.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:163ff326680de5f6d4966954cf9e3fe1bf980f5fee2255e46e89b8cf0f3418b5"},
+ {file = "yarl-1.19.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a626c4d9cca298d1be8625cff4b17004a9066330ac82d132bbda64a4c17c18d3"},
+ {file = "yarl-1.19.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:961c3e401ea7f13d02b8bb7cb0c709152a632a6e14cdc8119e9c6ee5596cd45d"},
+ {file = "yarl-1.19.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a39d7b807ab58e633ed760f80195cbd145b58ba265436af35f9080f1810dfe64"},
+ {file = "yarl-1.19.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:c4228978fb59c6b10f60124ba8e311c26151e176df364e996f3f8ff8b93971b5"},
+ {file = "yarl-1.19.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ba536b17ecf3c74a94239ec1137a3ad3caea8c0e4deb8c8d2ffe847d870a8c5"},
+ {file = "yarl-1.19.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a251e00e445d2e9df7b827c9843c0b87f58a3254aaa3f162fb610747491fe00f"},
+ {file = "yarl-1.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9b92431d8b4d4ca5ccbfdbac95b05a3a6cd70cd73aa62f32f9627acfde7549c"},
+ {file = "yarl-1.19.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ec2f56edaf476f70b5831bbd59700b53d9dd011b1f77cd4846b5ab5c5eafdb3f"},
+ {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:acf9b92c4245ac8b59bc7ec66a38d3dcb8d1f97fac934672529562bb824ecadb"},
+ {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:57711f1465c06fee8825b95c0b83e82991e6d9425f9a042c3c19070a70ac92bf"},
+ {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:528e86f5b1de0ad8dd758ddef4e0ed24f5d946d4a1cef80ffb2d4fca4e10f122"},
+ {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:3b77173663e075d9e5a57e09d711e9da2f3266be729ecca0b8ae78190990d260"},
+ {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:d8717924cf0a825b62b1a96fc7d28aab7f55a81bf5338b8ef41d7a76ab9223e9"},
+ {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0df9f0221a78d858793f40cbea3915c29f969c11366646a92ca47e080a14f881"},
+ {file = "yarl-1.19.0-cp311-cp311-win32.whl", hash = "sha256:8b3ade62678ee2c7c10dcd6be19045135e9badad53108f7d2ed14896ee396045"},
+ {file = "yarl-1.19.0-cp311-cp311-win_amd64.whl", hash = "sha256:0626ee31edb23ac36bdffe607231de2cca055ad3a5e2dc5da587ef8bc6a321bc"},
+ {file = "yarl-1.19.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:7b687c334da3ff8eab848c9620c47a253d005e78335e9ce0d6868ed7e8fd170b"},
+ {file = "yarl-1.19.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b0fe766febcf523a2930b819c87bb92407ae1368662c1bc267234e79b20ff894"},
+ {file = "yarl-1.19.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:742ceffd3c7beeb2b20d47cdb92c513eef83c9ef88c46829f88d5b06be6734ee"},
+ {file = "yarl-1.19.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2af682a1e97437382ee0791eacbf540318bd487a942e068e7e0a6c571fadbbd3"},
+ {file = "yarl-1.19.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:63702f1a098d0eaaea755e9c9d63172be1acb9e2d4aeb28b187092bcc9ca2d17"},
+ {file = "yarl-1.19.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3560dcba3c71ae7382975dc1e912ee76e50b4cd7c34b454ed620d55464f11876"},
+ {file = "yarl-1.19.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:68972df6a0cc47c8abaf77525a76ee5c5f6ea9bbdb79b9565b3234ded3c5e675"},
+ {file = "yarl-1.19.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5684e7ff93ea74e47542232bd132f608df4d449f8968fde6b05aaf9e08a140f9"},
+ {file = "yarl-1.19.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8182ad422bfacdebd4759ce3adc6055c0c79d4740aea1104e05652a81cd868c6"},
+ {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:aee5b90a5a9b71ac57400a7bdd0feaa27c51e8f961decc8d412e720a004a1791"},
+ {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:8c0b2371858d5a814b08542d5d548adb03ff2d7ab32f23160e54e92250961a72"},
+ {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:cd430c2b7df4ae92498da09e9b12cad5bdbb140d22d138f9e507de1aa3edfea3"},
+ {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:a93208282c0ccdf73065fd76c6c129bd428dba5ff65d338ae7d2ab27169861a0"},
+ {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:b8179280cdeb4c36eb18d6534a328f9d40da60d2b96ac4a295c5f93e2799e9d9"},
+ {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eda3c2b42dc0c389b7cfda2c4df81c12eeb552019e0de28bde8f913fc3d1fcf3"},
+ {file = "yarl-1.19.0-cp312-cp312-win32.whl", hash = "sha256:57f3fed859af367b9ca316ecc05ce79ce327d6466342734305aa5cc380e4d8be"},
+ {file = "yarl-1.19.0-cp312-cp312-win_amd64.whl", hash = "sha256:5507c1f7dd3d41251b67eecba331c8b2157cfd324849879bebf74676ce76aff7"},
+ {file = "yarl-1.19.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:59281b9ed27bc410e0793833bcbe7fc149739d56ffa071d1e0fe70536a4f7b61"},
+ {file = "yarl-1.19.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d27a6482ad5e05e8bafd47bf42866f8a1c0c3345abcb48d4511b3c29ecc197dc"},
+ {file = "yarl-1.19.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7a8e19fd5a6fdf19a91f2409665c7a089ffe7b9b5394ab33c0eec04cbecdd01f"},
+ {file = "yarl-1.19.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cda34ab19099c3a1685ad48fe45172536610c312b993310b5f1ca3eb83453b36"},
+ {file = "yarl-1.19.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7908a25d33f94852b479910f9cae6cdb9e2a509894e8d5f416c8342c0253c397"},
+ {file = "yarl-1.19.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e66c14d162bac94973e767b24de5d7e6c5153f7305a64ff4fcba701210bcd638"},
+ {file = "yarl-1.19.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c03607bf932aa4cfae371e2dc9ca8b76faf031f106dac6a6ff1458418140c165"},
+ {file = "yarl-1.19.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9931343d1c1f4e77421687b6b94bbebd8a15a64ab8279adf6fbb047eff47e536"},
+ {file = "yarl-1.19.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:262087a8a0d73e1d169d45c2baf968126f93c97cf403e1af23a7d5455d52721f"},
+ {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:70f384921c24e703d249a6ccdabeb57dd6312b568b504c69e428a8dd3e8e68ca"},
+ {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:756b9ea5292a2c180d1fe782a377bc4159b3cfefaca7e41b5b0a00328ef62fa9"},
+ {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:cbeb9c145d534c240a63b6ecc8a8dd451faeb67b3dc61d729ec197bb93e29497"},
+ {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:087ae8f8319848c18e0d114d0f56131a9c017f29200ab1413b0137ad7c83e2ae"},
+ {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:362f5480ba527b6c26ff58cff1f229afe8b7fdd54ee5ffac2ab827c1a75fc71c"},
+ {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f408d4b4315e814e5c3668094e33d885f13c7809cbe831cbdc5b1bb8c7a448f4"},
+ {file = "yarl-1.19.0-cp313-cp313-win32.whl", hash = "sha256:24e4c367ad69988a2283dd45ea88172561ca24b2326b9781e164eb46eea68345"},
+ {file = "yarl-1.19.0-cp313-cp313-win_amd64.whl", hash = "sha256:0110f91c57ab43d1538dfa92d61c45e33b84df9257bd08fcfcda90cce931cbc9"},
+ {file = "yarl-1.19.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:85ac908cd5a97bbd3048cca9f1bf37b932ea26c3885099444f34b0bf5d5e9fa6"},
+ {file = "yarl-1.19.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6ba0931b559f1345df48a78521c31cfe356585670e8be22af84a33a39f7b9221"},
+ {file = "yarl-1.19.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5bc503e1c1fee1b86bcb58db67c032957a52cae39fe8ddd95441f414ffbab83e"},
+ {file = "yarl-1.19.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d995122dcaf180fd4830a9aa425abddab7c0246107c21ecca2fa085611fa7ce9"},
+ {file = "yarl-1.19.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:217f69e60a14da4eed454a030ea8283f8fbd01a7d6d81e57efb865856822489b"},
+ {file = "yarl-1.19.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aad67c8f13a4b79990082f72ef09c078a77de2b39899aabf3960a48069704973"},
+ {file = "yarl-1.19.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dff065a1a8ed051d7e641369ba1ad030d5a707afac54cf4ede7069b959898835"},
+ {file = "yarl-1.19.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ada882e26b16ee651ab6544ce956f2f4beaed38261238f67c2a96db748e17741"},
+ {file = "yarl-1.19.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:67a56b1acc7093451ea2de0687aa3bd4e58d6b4ef6cbeeaad137b45203deaade"},
+ {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:e97d2f0a06b39e231e59ebab0e6eec45c7683b339e8262299ac952707bdf7688"},
+ {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:a5288adb7c59d0f54e4ad58d86fb06d4b26e08a59ed06d00a1aac978c0e32884"},
+ {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:1efbf4d03e6eddf5da27752e0b67a8e70599053436e9344d0969532baa99df53"},
+ {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:f228f42f29cc87db67020f7d71624102b2c837686e55317b16e1d3ef2747a993"},
+ {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:c515f7dd60ca724e4c62b34aeaa603188964abed2eb66bb8e220f7f104d5a187"},
+ {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:4815ec6d3d68a96557fa71bd36661b45ac773fb50e5cfa31a7e843edb098f060"},
+ {file = "yarl-1.19.0-cp39-cp39-win32.whl", hash = "sha256:9fac2dd1c5ecb921359d9546bc23a6dcc18c6acd50c6d96f118188d68010f497"},
+ {file = "yarl-1.19.0-cp39-cp39-win_amd64.whl", hash = "sha256:5864f539ce86b935053bfa18205fa08ce38e9a40ea4d51b19ce923345f0ed5db"},
+ {file = "yarl-1.19.0-py3-none-any.whl", hash = "sha256:a727101eb27f66727576630d02985d8a065d09cd0b5fcbe38a5793f71b2a97ef"},
+ {file = "yarl-1.19.0.tar.gz", hash = "sha256:01e02bb80ae0dbed44273c304095295106e1d9470460e773268a27d11e594892"},
]
[package.dependencies]
idna = ">=2.0"
multidict = ">=4.0"
-typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""}
-
-[[package]]
-name = "zipp"
-version = "3.15.0"
-description = "Backport of pathlib-compatible object wrapper for zip files"
-optional = false
-python-versions = ">=3.7"
-files = [
- {file = "zipp-3.15.0-py3-none-any.whl", hash = "sha256:48904fc76a60e542af151aded95726c1a5c34ed43ab4134b597665c86d7ad556"},
- {file = "zipp-3.15.0.tar.gz", hash = "sha256:112929ad649da941c23de50f356a2b5570c954b65150642bccdd66bf194d224b"},
-]
-
-[package.extras]
-docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
-testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
+propcache = ">=0.2.1"
[metadata]
-lock-version = "2.0"
-python-versions = "^3.7"
-content-hash = "b7fab8703967f2616ea59a98a437cd30f97f0c8d2a06e399d688814a2a2c64f8"
+lock-version = "2.1"
+python-versions = "^3.9"
+content-hash = "f136e898d37b7c7db1ccceb1822ade280d3542ca19cdd9dcf583cb9aefef11c6"
diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml
index 47ef9d71784..1448d76181b 100644
--- a/clients/python/pyproject.toml
+++ b/clients/python/pyproject.toml
@@ -11,15 +11,15 @@ repository = "https://github.com/huggingface/text-generation-inference"
[tool.poetry.dependencies]
-python = "^3.7"
+python = "^3.9"
pydantic = "> 2, < 3"
-aiohttp = "^3.8"
+aiohttp = "^3.11"
huggingface-hub = ">= 0.12, < 1.0"
-[tool.poetry.dev-dependencies]
-pytest = "^6.2.5"
-pytest-asyncio = "^0.17.2"
-pytest-cov = "^3.0.0"
+[tool.poetry.group.dev.dependencies]
+pytest = "^8"
+pytest-asyncio = "^0.26"
+pytest-cov = "^6.0.0"
[tool.pytest.ini_options]
asyncio_mode = "auto"
diff --git a/clients/python/tests/conftest.py b/clients/python/tests/conftest.py
index 17bb73b5d4b..f3db6e68a47 100644
--- a/clients/python/tests/conftest.py
+++ b/clients/python/tests/conftest.py
@@ -21,7 +21,7 @@ def fake_model():
@pytest.fixture
def unsupported_model():
- return "gpt2"
+ return "google-bert/bert-base-uncased"
@pytest.fixture
diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py
index 8aed865b7b9..0c702c63687 100644
--- a/clients/python/tests/test_client.py
+++ b/clients/python/tests/test_client.py
@@ -2,7 +2,7 @@
from text_generation import Client, AsyncClient
from text_generation.errors import NotFoundError, ValidationError
-from text_generation.types import FinishReason, InputToken
+from text_generation.types import FinishReason
def test_generate(llama_7b_url, hf_headers):
@@ -13,8 +13,8 @@ def test_generate(llama_7b_url, hf_headers):
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
- assert len(response.details.prefill) == 2
- assert response.details.prefill[0] == InputToken(id=1, text="", logprob=None)
+ assert len(response.details.prefill) == 0
+ # assert response.details.prefill[0] == InputToken(id=1, text="", logprob=None)
assert len(response.details.tokens) == 1
assert response.details.tokens[0].id == 29918
assert response.details.tokens[0].text == "_"
@@ -83,11 +83,11 @@ async def test_generate_async(llama_7b_url, hf_headers):
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
- assert len(response.details.prefill) == 2
- assert response.details.prefill[0] == InputToken(id=1, text="", logprob=None)
- assert response.details.prefill[1] == InputToken(
- id=1243, text="test", logprob=-10.96875
- )
+ assert len(response.details.prefill) == 0
+ # assert response.details.prefill[0] == InputToken(id=1, text="", logprob=None)
+ # assert response.details.prefill[1] == InputToken(
+ # id=1243, text="test", logprob=-10.96875
+ # )
assert len(response.details.tokens) == 1
assert response.details.tokens[0].id == 29918
assert response.details.tokens[0].text == "_"
diff --git a/clients/python/tests/test_inference_api.py b/clients/python/tests/test_inference_api.py
index 59297c26e94..5a2584059f3 100644
--- a/clients/python/tests/test_inference_api.py
+++ b/clients/python/tests/test_inference_api.py
@@ -1,42 +1,42 @@
-import pytest
-
-from text_generation import (
- InferenceAPIClient,
- InferenceAPIAsyncClient,
- Client,
- AsyncClient,
-)
-from text_generation.errors import NotSupportedError, NotFoundError
-from text_generation.inference_api import check_model_support, deployed_models
-
-
-def test_check_model_support(flan_t5_xxl, unsupported_model, fake_model):
- assert check_model_support(flan_t5_xxl)
- assert not check_model_support(unsupported_model)
-
- with pytest.raises(NotFoundError):
- check_model_support(fake_model)
-
-
-def test_deployed_models():
- deployed_models()
-
-
-def test_client(flan_t5_xxl):
- client = InferenceAPIClient(flan_t5_xxl)
- assert isinstance(client, Client)
-
-
-def test_client_unsupported_model(unsupported_model):
- with pytest.raises(NotSupportedError):
- InferenceAPIClient(unsupported_model)
-
-
-def test_async_client(flan_t5_xxl):
- client = InferenceAPIAsyncClient(flan_t5_xxl)
- assert isinstance(client, AsyncClient)
-
-
-def test_async_client_unsupported_model(unsupported_model):
- with pytest.raises(NotSupportedError):
- InferenceAPIAsyncClient(unsupported_model)
+# import pytest
+#
+# from text_generation import (
+# InferenceAPIClient,
+# InferenceAPIAsyncClient,
+# Client,
+# AsyncClient,
+# )
+# from text_generation.errors import NotSupportedError, NotFoundError
+# from text_generation.inference_api import check_model_support, deployed_models
+#
+#
+# def test_check_model_support(flan_t5_xxl, unsupported_model, fake_model):
+# assert check_model_support(flan_t5_xxl)
+# assert not check_model_support(unsupported_model)
+#
+# with pytest.raises(NotFoundError):
+# check_model_support(fake_model)
+#
+#
+# def test_deployed_models():
+# deployed_models()
+#
+#
+# def test_client(flan_t5_xxl):
+# client = InferenceAPIClient(flan_t5_xxl)
+# assert isinstance(client, Client)
+#
+#
+# def test_client_unsupported_model(unsupported_model):
+# with pytest.raises(NotSupportedError):
+# InferenceAPIClient(unsupported_model)
+#
+#
+# def test_async_client(flan_t5_xxl):
+# client = InferenceAPIAsyncClient(flan_t5_xxl)
+# assert isinstance(client, AsyncClient)
+#
+#
+# def test_async_client_unsupported_model(unsupported_model):
+# with pytest.raises(NotSupportedError):
+# InferenceAPIAsyncClient(unsupported_model)
diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index 45301b63256..0b60d93aa0a 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -867,7 +867,7 @@ async def generate(
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
- async with session.post(self.base_url, json=request.dict()) as resp:
+ async with session.post(self.base_url, json=request.model_dump()) as resp:
payload = await resp.json()
if resp.status != 200:
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 1085075e438..6f51c153eab 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -67,7 +67,7 @@ class ChoiceDeltaToolCall(BaseModel):
class ChoiceDelta(BaseModel):
role: str
content: Optional[str] = None
- tool_calls: Optional[ChoiceDeltaToolCall] = None
+ tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
class Choice(BaseModel):
diff --git a/docs/openapi.json b/docs/openapi.json
index 1caf67525f0..89172386388 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
- "version": "3.0.1-dev0"
+ "version": "3.3.6-dev0"
},
"paths": {
"/": {
@@ -57,7 +57,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Input validation error"
+ "error": "Input validation error",
+ "error_type": "validation"
}
}
}
@@ -70,7 +71,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Request failed during generation"
+ "error": "Request failed during generation",
+ "error_type": "generation"
}
}
}
@@ -83,7 +85,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Model is overloaded"
+ "error": "Model is overloaded",
+ "error_type": "overloaded"
}
}
}
@@ -96,7 +99,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Incomplete generation"
+ "error": "Incomplete generation",
+ "error_type": "incomplete_generation"
}
}
}
@@ -181,7 +185,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Input validation error"
+ "error": "Input validation error",
+ "error_type": "validation"
}
}
}
@@ -194,7 +199,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Request failed during generation"
+ "error": "Request failed during generation",
+ "error_type": "generation"
}
}
}
@@ -207,7 +213,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Model is overloaded"
+ "error": "Model is overloaded",
+ "error_type": "overloaded"
}
}
}
@@ -220,7 +227,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Incomplete generation"
+ "error": "Incomplete generation",
+ "error_type": "incomplete_generation"
}
}
}
@@ -264,7 +272,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Input validation error"
+ "error": "Input validation error",
+ "error_type": "validation"
}
}
}
@@ -277,7 +286,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Request failed during generation"
+ "error": "Request failed during generation",
+ "error_type": "generation"
}
}
}
@@ -290,7 +300,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Model is overloaded"
+ "error": "Model is overloaded",
+ "error_type": "overloaded"
}
}
}
@@ -303,7 +314,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Incomplete generation"
+ "error": "Incomplete generation",
+ "error_type": "incomplete_generation"
}
}
}
@@ -558,7 +570,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Input validation error"
+ "error": "Input validation error",
+ "error_type": "validation"
}
}
}
@@ -571,7 +584,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Request failed during generation"
+ "error": "Request failed during generation",
+ "error_type": "generation"
}
}
}
@@ -584,7 +598,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Model is overloaded"
+ "error": "Model is overloaded",
+ "error_type": "overloaded"
}
}
}
@@ -597,7 +612,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Incomplete generation"
+ "error": "Incomplete generation",
+ "error_type": "incomplete_generation"
}
}
}
@@ -646,7 +662,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Input validation error"
+ "error": "Input validation error",
+ "error_type": "validation"
}
}
}
@@ -659,7 +676,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Request failed during generation"
+ "error": "Request failed during generation",
+ "error_type": "generation"
}
}
}
@@ -672,7 +690,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Model is overloaded"
+ "error": "Model is overloaded",
+ "error_type": "overloaded"
}
}
}
@@ -685,7 +704,8 @@
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "error": "Incomplete generation"
+ "error": "Incomplete generation",
+ "error_type": "incomplete_generation"
}
}
}
@@ -1771,6 +1791,24 @@
"type": "string"
}
}
+ },
+ {
+ "type": "object",
+ "required": [
+ "type",
+ "value"
+ ],
+ "properties": {
+ "type": {
+ "type": "string",
+ "enum": [
+ "json_schema"
+ ]
+ },
+ "value": {
+ "$ref": "#/components/schemas/JsonSchemaConfig"
+ }
+ }
}
],
"discriminator": {
@@ -1864,27 +1902,75 @@
}
}
},
- "Message": {
+ "JsonSchemaConfig": {
"type": "object",
"required": [
- "role",
- "content"
+ "schema"
],
"properties": {
- "content": {
- "$ref": "#/components/schemas/MessageContent"
- },
"name": {
"type": "string",
- "example": "\"David\"",
+ "description": "Optional name identifier for the schema",
"nullable": true
},
- "role": {
- "type": "string",
- "example": "user"
+ "schema": {
+ "description": "The actual JSON schema definition"
}
}
},
+ "Message": {
+ "allOf": [
+ {
+ "$ref": "#/components/schemas/MessageBody"
+ },
+ {
+ "type": "object",
+ "required": [
+ "role"
+ ],
+ "properties": {
+ "name": {
+ "type": "string",
+ "example": "\"David\"",
+ "nullable": true
+ },
+ "role": {
+ "type": "string",
+ "example": "user"
+ }
+ }
+ }
+ ]
+ },
+ "MessageBody": {
+ "oneOf": [
+ {
+ "type": "object",
+ "required": [
+ "content"
+ ],
+ "properties": {
+ "content": {
+ "$ref": "#/components/schemas/MessageContent"
+ }
+ }
+ },
+ {
+ "type": "object",
+ "required": [
+ "tool_calls"
+ ],
+ "properties": {
+ "tool_calls": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ToolCall"
+ }
+ }
+ }
+ }
+ ]
+ },
"MessageChunk": {
"oneOf": [
{
@@ -2116,9 +2202,6 @@
},
"StreamOptions": {
"type": "object",
- "required": [
- "include_usage"
- ],
"properties": {
"include_usage": {
"type": "boolean",
@@ -2179,6 +2262,10 @@
"role": {
"type": "string",
"example": "user"
+ },
+ "tool_call_id": {
+ "type": "string",
+ "nullable": true
}
}
},
@@ -2266,7 +2353,10 @@
"example": "assistant"
},
"tool_calls": {
- "$ref": "#/components/schemas/DeltaToolCall"
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/DeltaToolCall"
+ }
}
}
},
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index e31a3788485..4c6f0151d9b 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -12,11 +12,15 @@
- local: installation_gaudi
title: Using TGI with Intel Gaudi
- local: installation_inferentia
- title: Using TGI with AWS Inferentia
+ title: Using TGI with AWS Trainium and Inferentia
+ - local: installation_tpu
+ title: Using TGI with Google TPUs
- local: installation_intel
title: Using TGI with Intel GPUs
- local: installation
title: Installation from source
+ - local: multi_backend_support
+ title: Multi-backend support
- local: architecture
title: Internal Architecture
@@ -45,6 +49,16 @@
- local: basic_tutorials/train_medusa
title: Train Medusa
title: Tutorials
+- sections:
+ - local: backends/neuron
+ title: Neuron
+ - local: backends/gaudi
+ title: Gaudi
+ - local: backends/trtllm
+ title: TensorRT-LLM
+ - local: backends/llamacpp
+ title: Llamacpp
+ title: Backends
- sections:
- local: reference/launcher
title: All TGI CLI options
diff --git a/docs/source/architecture.md b/docs/source/architecture.md
index 6660630d061..b475bb6dc7e 100644
--- a/docs/source/architecture.md
+++ b/docs/source/architecture.md
@@ -9,8 +9,10 @@ A high-level architecture diagram can be seen here:
This diagram shows well there are these separate components:
- **The router**, also named `webserver`, that receives the client requests, buffers them, creates some batches, and prepares gRPC calls to a model server.
-- **The model server**, responsible of receiving the gRPC requests and to process the inference on the model. If the model is sharded across multiple accelerators (e.g.: multiple GPUs), the model server shards might be synchronized via NCCL or equivalent.
- **The launcher** is a helper that will be able to launch one or several model servers (if model is sharded), and it launches the router with the compatible arguments.
+- **The model server**, responsible for receiving the gRPC requests and processing the inference on the model. If the model is sharded across multiple accelerators (e.g. multiple GPUs), the model server shards might be synchronized via NCCL or equivalent.
+
+Note that for other backends (e.g. TRTLLM), the model server and launcher are specific to the backend.
The router and the model server can be two different machines, they do not need to be deployed together.
@@ -105,7 +107,7 @@ Several variants of the model server exist that are actively supported by Huggin
- A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ.
- A [version optimized for Intel GPUs](https://huggingface.co/docs/text-generation-inference/installation_intel) is hosted in the main TGI repository. Some model features differ.
- The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi).
-- A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference).
+- A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained in the main TGI repository. Some model features differ.
- A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference).
Not all variants provide the same features, as hardware and middleware capabilities do not provide the same optimizations.
diff --git a/docs/source/backends/gaudi.mdx b/docs/source/backends/gaudi.mdx
new file mode 100644
index 00000000000..07d34a8204e
--- /dev/null
+++ b/docs/source/backends/gaudi.mdx
@@ -0,0 +1,264 @@
+# Gaudi Backend for Text Generation Inference
+
+## Overview
+Text Generation Inference (TGI) has been optimized to run on Gaudi hardware via the Gaudi backend for TGI.
+
+## Supported Hardware
+- **Gaudi1**: Available on [AWS EC2 DL1 instances](https://aws.amazon.com/ec2/instance-types/dl1/)
+- **Gaudi2**: Available on [Intel Cloud](https://console.cloud.intel.com/docs/reference/ai_instances.html)
+- **Gaudi3**: Available on [Intel Cloud](https://console.cloud.intel.com/docs/reference/ai_instances.html)
+
+## Tutorial: Getting Started with TGI on Gaudi
+
+### Basic Usage
+The easiest way to run TGI on Gaudi is to use the official Docker image:
+
+```bash
+model=meta-llama/Meta-Llama-3.1-8B-Instruct
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+hf_token=YOUR_HF_ACCESS_TOKEN
+
+docker run --runtime=habana --cap-add=sys_nice --ipc=host \
+ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
+ --model-id $model
+```
+
+Once you see the `Connected` log, the server is ready to accept requests:
+> 2024-05-22T19:31:48.302239Z INFO text_generation_router: router/src/main.rs:378: Connected
+
+You can find your `YOUR_HF_ACCESS_TOKEN` at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens). This is necessary to access gated models like Llama 3.1.
+
+### Making Your First Request
+You can send a request from a separate terminal:
+
+```bash
+curl 127.0.0.1:8080/generate \
+ -X POST \
+ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":32}}' \
+ -H 'Content-Type: application/json'
+```
+
+## How-to Guides
+
+You can view the full list of supported models in the [Supported Models](https://huggingface.co/docs/text-generation-inference/backends/gaudi#supported-models) section.
+
+For example, to run Llama3.1-8B, you can use the following command:
+
+```bash
+model=meta-llama/Meta-Llama-3.1-8B-Instruct
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+hf_token=YOUR_ACCESS_TOKEN
+
+docker run --runtime=habana --cap-add=sys_nice --ipc=host \
+ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
+ --model-id $model
+```
+
+For the full list of service parameters, refer to the [launcher-arguments page](https://huggingface.co/docs/text-generation-inference/reference/launcher).
+
+The validated docker commands can be found in the [examples/docker_commands folder](https://github.com/huggingface/text-generation-inference/tree/main/backends/gaudi/examples/docker_commands).
+
+> Note: `--runtime=habana --cap-add=sys_nice --ipc=host` is required to enable Docker to use the Gaudi hardware (more details [here](https://docs.habana.ai/en/latest/Installation_Guide/Additional_Installation/Docker_Installation.html)).
+
+### How to Enable Multi-Card Inference (Sharding)
+
+TGI-Gaudi supports sharding for multi-card inference, allowing you to distribute the load across multiple Gaudi cards. This is recommended to run large models and to speed up inference.
+
+For example, on a machine with 8 Gaudi cards, you can run:
+
+```bash
+docker run --runtime=habana --ipc=host --cap-add=sys_nice \
+ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
+ --model-id $model --sharded true --num-shard 8
+```
+
+
+We recommend always using sharding when running on a multi-card machine.
+
+
+### How to Use Different Precision Formats
+
+#### BF16 Precision (Default)
+By default, all models run with BF16 precision on Gaudi hardware.
+
+#### FP8 Precision
+TGI-Gaudi supports FP8 precision inference, which can significantly reduce memory usage and improve performance for large models. Supported checkpoints include models with W8A8 FP8 compressed-tensors parameters, such as [RedHatAI/Mixtral-8x7B-Instruct-v0.1-FP8](https://huggingface.co/RedHatAI/Mixtral-8x7B-Instruct-v0.1-FP8), and AutoFP8-generated models, such as [RedHatAI/Meta-Llama-3-8B-Instruct-FP8](https://huggingface.co/RedHatAI/Meta-Llama-3-8B-Instruct-FP8).
+TGI-Gaudi supports FP8 precision inference with [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html).
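+
+For example, to serve one of the FP8 checkpoints listed above, a minimal sketch is to reuse the tutorial command and only swap the model id (depending on the checkpoint, additional quantization configuration may be required):
+
+```bash
+model=RedHatAI/Meta-Llama-3-8B-Instruct-FP8
+volume=$PWD/data
+hf_token=YOUR_HF_ACCESS_TOKEN
+
+docker run --runtime=habana --cap-add=sys_nice --ipc=host \
+ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
+ --model-id $model
+```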
+
+
+### How to Run Vision-Language Models (VLMs)
+
+Gaudi supports VLM inference.
+
+Example for Llava-v1.6-Mistral-7B on 1 card:
+
+Start the TGI server via the following command:
+```bash
+model=llava-hf/llava-v1.6-mistral-7b-hf
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
+ --model-id $model \
+ --max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
+ --max-total-tokens 8192 --max-batch-size 4
+```
+
+You can then send a request to the server via the following command:
+```bash
+curl -N 127.0.0.1:8080/generate \
+ -X POST \
+ -d '{"inputs":"What is this a picture of?\n\n","parameters":{"max_new_tokens":32}}' \
+ -H 'Content-Type: application/json'
+```
+
+> Note: In Llava-v1.6-Mistral-7B, an image usually accounts for a few thousand input tokens (an image of size 512x512, for example, is represented by 2800 tokens). `max-input-tokens` must therefore be larger than the number of tokens associated with the image, otherwise the image may be truncated. `max-batch-prefill-tokens` is set to 16384 here, following `prefill_batch_size` = `max-batch-prefill-tokens` / `max-input-tokens`; a quick check is shown below.
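+
+The arithmetic behind these values can be checked directly (plain shell arithmetic, nothing Gaudi-specific):
+
+```bash
+# prefill_batch_size = max-batch-prefill-tokens / max-input-tokens
+echo $(( 16384 / 4096 ))   # -> 4, which matches --max-batch-size 4 above
+```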
+
+### How to Benchmark Performance
+
+We recommend using the [inference-benchmarker tool](https://github.com/huggingface/inference-benchmarker) to benchmark performance on Gaudi hardware.
+
+This benchmark tool simulates user requests and measures the performance of the model on realistic scenarios.
+
+To run it on the same machine, you can do the following:
+```bash
+MODEL=meta-llama/Llama-3.1-8B-Instruct
+HF_TOKEN=
+# run a benchmark to evaluate the performance of the model for chat use case
+# we mount results to the current directory
+docker run \
+ --rm \
+ -it \
+ --net host \
+ -v $(pwd):/opt/inference-benchmarker/results \
+ -e "HF_TOKEN=$HF_TOKEN" \
+ ghcr.io/huggingface/inference-benchmarker:latest \
+ inference-benchmarker \
+ --tokenizer-name "$MODEL" \
+ --url http://localhost:8080 \
+ --profile chat
+```
+
+Please refer to the [inference-benchmarker README](https://github.com/huggingface/inference-benchmarker) for more details.
+
+## Explanation: Understanding TGI on Gaudi
+
+### The Warmup Process
+
+Intel Gaudi accelerators perform best when operating on models with fixed tensor shapes. [Intel Gaudi Graph Compiler](https://docs.habana.ai/en/latest/Gaudi_Overview/Intel_Gaudi_Software_Suite.html#graph-compiler-and-runtime)
+generates optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be highly dependent on input and output tensor shapes, requiring graph recompilation
+when encountering tensors with different shapes within the same topology. While these binaries efficiently utilize Gaudi, the compilation process itself can introduce noticeable overhead in end-to-end execution.
+In dynamic inference serving scenarios, minimizing the number of graph compilations and reducing the risk of graph compilation occurring during server runtime is important.
+
+To ensure optimal performance, warmup is performed at the beginning of each server run. This process creates queries with various input shapes based on provided parameters and runs basic TGI operations (prefill, decode).
+
+Note: Model warmup can take several minutes, especially for FP8 inference. For faster subsequent runs, refer to [Disk Caching Eviction Policy](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#disk-caching-eviction-policy).
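+
+If you are iterating on a configuration and want to shorten startup at the cost of runtime recompilations, warmup can be skipped with the `VLLM_SKIP_WARMUP` variable documented in the reference section below. This is a sketch for debugging only and is not recommended for production:
+
+```bash
+docker run --runtime=habana --cap-add=sys_nice --ipc=host \
+ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
+ -e VLLM_SKIP_WARMUP=True \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
+ --model-id $model
+```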
+
+### Understanding Parameter Tuning
+
+#### Sequence Length Parameters
+- `--max-input-tokens` is the maximum possible input prompt length. Default value is `4095`.
+- `--max-total-tokens` is the maximum possible total length of the sequence (input and output). Default value is `4096`.
+
+#### Batch Size Parameters
+- For the prefill operation, set `--max-batch-prefill-tokens` to `bs * max-input-tokens`, where `bs` is your expected maximum prefill batch size.
+- For the decode operation, set `--max-batch-size` to `bs`, where `bs` is your expected maximum decode batch size.
+- Note that the batch size will always be padded to the nearest shape that has been warmed up. This avoids out-of-memory issues and ensures that compiled graphs are reused efficiently. A combined example is shown below.
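+
+A minimal sketch combining these parameters (the values are illustrative assumptions, not tuned recommendations): with `bs = 4` for both prefill and decode and prompts up to 2048 tokens, `--max-batch-prefill-tokens` becomes `4 * 2048 = 8192`.
+
+```bash
+docker run --runtime=habana --cap-add=sys_nice --ipc=host \
+ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
+ --model-id $model \
+ --max-input-tokens 2048 --max-total-tokens 4096 \
+ --max-batch-prefill-tokens 8192 --max-batch-size 4
+```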
+
+
+## Reference
+
+This section contains reference information about the Gaudi backend.
+
+### Supported Models
+
+Text Generation Inference enables serving optimized models on Gaudi hardware. The following sections list which models (VLMs & LLMs) are supported on Gaudi.
+
+**Large Language Models (LLMs)**
+- [deepseek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)
+- [deepseek-v2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
+- [Llama2](https://huggingface.co/collections/meta-llama/llama-2-family-661da1f90a9d678b6f55773b)
+- [Llama3](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
+- [CodeLlama](https://huggingface.co/codellama/CodeLlama-13b-hf)
+- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
+- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)
+- [Qwen 2](https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f)
+- [Qwen 3](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f)
+- [Qwen 3 Moe](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f)
+- [Phi-1.5](https://huggingface.co/microsoft/phi-1_5)
+- [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
+- [PhiMoe](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct)
+- [Gemma](https://huggingface.co/google/gemma-7b-it)
+- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
+- [Gemma3 Text](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)
+- [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)
+- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
+- [dbrx](https://huggingface.co/databricks/dbrx-instruct)
+- [Starcoder2](https://huggingface.co/bigcode/starcoder2-3b)
+- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)
+- [GPT-2](https://huggingface.co/openai-community/gpt2)
+- [gpt-j-6b](https://huggingface.co/EleutherAI/gpt-j-6b)
+- [gpt-bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder)
+- [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
+
+
+**Vision-Language Models (VLMs)**
+- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf)
+- [Mllama (Multimodal Llama from Meta)](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
+- [idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b)
+- [idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3)
+- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
+- [Llama4](https://huggingface.co/collections/meta-llama/llama-4-67f0c30d9fe03840bc9d0164)
+- [Gemma3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)
+- [Qwen 2.5 VL](https://huggingface.co/collections/Qwen/qwen25-vl-6795ffac22b334a837c0f9a5)
+- [Qwen 2 VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d)
+
+If you have an issue with a model, please open an issue on the [Gaudi backend repository](https://github.com/huggingface/text-generation-inference/issues).
+
+### Environment Variables
+
+The following table contains the environment variables that can be used to configure the Gaudi backend:
+
+| Name | Value(s) | Default | Description | Usage |
+|-----------------------------| :--------- | :--------------- | :------------------------------------------------------------------------------------------------------------------------------- | :--------------------------- |
+| LIMIT_HPU_GRAPH | True/False | True | Skip HPU graph usage for prefill to save memory; set to `True` for large sequence/decoding lengths (e.g. 300/212) | add -e in docker run command |
+| SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
+| VLLM_SKIP_WARMUP | True/False | False | Skip graph warmup during server initialization; not recommended, but can be used for debugging. | add -e in docker run command |
+
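+For example, to pass one of these variables to the tutorial command above (a sketch; `LIMIT_HPU_GRAPH` already defaults to `True` and is shown here only to illustrate the `-e` usage):
+
+```bash
+docker run --runtime=habana --cap-add=sys_nice --ipc=host \
+ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
+ -e LIMIT_HPU_GRAPH=True -e SKIP_TOKENIZER_IN_TGI=False \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
+ --model-id $model
+```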
+
+## Contributing
+
+Contributions to the TGI-Gaudi project are welcome. Please refer to the [contributing guide](https://github.com/huggingface/text-generation-inference/blob/main/CONTRIBUTING.md).
+
+**Guidelines for contributing to Gaudi on TGI:** All changes should be made within the `backends/gaudi` folder. In general, you should avoid modifying the router, launcher, or benchmark to accommodate Gaudi hardware, as all Gaudi-specific logic should be contained within the `backends/gaudi` folder.
+
+### Building the Docker Image from Source
+
+To build the Docker image from source:
+
+```bash
+make -C backends/gaudi image
+```
+
+This builds the image and saves it as `tgi-gaudi`. You can then run TGI-Gaudi with this image:
+
+```bash
+model=meta-llama/Meta-Llama-3.1-8B-Instruct
+volume=$PWD/data
+hf_token=YOUR_ACCESS_TOKEN
+
+docker run --runtime=habana --ipc=host --cap-add=sys_nice \
+ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
+ tgi-gaudi \
+ --model-id $model
+```
+
+For more details, see the [README of the Gaudi backend](https://github.com/huggingface/text-generation-inference/blob/main/backends/gaudi/README.md) and the [Makefile of the Gaudi backend](https://github.com/huggingface/text-generation-inference/blob/main/backends/gaudi/Makefile).
diff --git a/docs/source/backends/llamacpp.md b/docs/source/backends/llamacpp.md
new file mode 100644
index 00000000000..8014dd5f660
--- /dev/null
+++ b/docs/source/backends/llamacpp.md
@@ -0,0 +1,144 @@
+# Llamacpp Backend
+
+The llamacpp backend facilitates the deployment of large language models
+(LLMs) by integrating [llama.cpp][llama.cpp], an advanced inference engine
+optimized for both CPU and GPU computation. This backend is a component
+of Hugging Face’s **Text Generation Inference (TGI)** suite,
+specifically designed to streamline the deployment of LLMs in production
+environments.
+
+## Key Capabilities
+
+- Full compatibility with GGUF format and all quantization formats
+ (GGUF-related constraints may be mitigated dynamically by on-the-fly
+ generation in future updates)
+- Optimized inference on CPU and GPU architectures
+- Containerized deployment, eliminating dependency complexity
+- Seamless interoperability with the Hugging Face ecosystem
+
+## Model Compatibility
+
+This backend leverages models formatted in **GGUF**, providing an
+optimized balance between computational efficiency and model accuracy.
+You will find the best models on [Hugging Face][GGUF].
+
+## Build Docker image
+
+For optimal performance, the Docker image is compiled with native CPU
+instructions by default. As a result, it is strongly recommended to run
+the container on the same host architecture used during the build
+process. Efforts are ongoing to improve portability across different
+systems while preserving high computational efficiency.
+
+To build the Docker image, use the following command:
+
+```bash
+docker build \
+ -t tgi-llamacpp \
+ https://github.com/huggingface/text-generation-inference.git \
+ -f Dockerfile_llamacpp
+```
+
+### Build parameters
+
+| Parameter (with --build-arg) | Description |
+| ----------------------------------------- | -------------------------------- |
+| `llamacpp_version=bXXXX` | Specific version of llama.cpp |
+| `llamacpp_cuda=ON` | Enables CUDA acceleration |
+| `llamacpp_native=OFF` | Disable automatic CPU detection |
+| `llamacpp_cpu_arm_arch=ARCH[+FEATURE]...` | Specific ARM CPU and features |
+| `cuda_arch=ARCH` | Defines target CUDA architecture |
+
+For example, to target Graviton4 when building on another ARM
+architecture:
+
+```bash
+docker build \
+ -t tgi-llamacpp \
+ --build-arg llamacpp_native=OFF \
+ --build-arg llamacpp_cpu_arm_arch=armv9-a+i8mm \
+ https://github.com/huggingface/text-generation-inference.git \
+ -f Dockerfile_llamacpp
+```
+
+## Run Docker image
+
+### CPU-based inference
+
+```bash
+docker run \
+ -p 3000:3000 \
+ -e "HF_TOKEN=$HF_TOKEN" \
+ -v "$HOME/models:/app/models" \
+ tgi-llamacpp \
+ --model-id "Qwen/Qwen2.5-3B-Instruct"
+```
+
+### GPU-Accelerated inference
+
+```bash
+docker run \
+ --gpus all \
+ -p 3000:3000 \
+ -e "HF_TOKEN=$HF_TOKEN" \
+ -v "$HOME/models:/app/models" \
+ tgi-llamacpp \
+ --n-gpu-layers 99 \
+ --model-id "Qwen/Qwen2.5-3B-Instruct"
+```
+
+## Using a custom GGUF
+
+GGUF files are optional as they will be automatically generated at
+startup if not already present in the `models` directory. However, if
+the default GGUF generation is not suitable for your use case, you can
+provide your own GGUF file with `--model-gguf`, for example:
+
+```bash
+docker run \
+ -p 3000:3000 \
+ -e "HF_TOKEN=$HF_TOKEN" \
+ -v "$HOME/models:/app/models" \
+ tgi-llamacpp \
+ --model-id "Qwen/Qwen2.5-3B-Instruct" \
+ --model-gguf "models/qwen2.5-3b-instruct-q4_0.gguf"
+```
+
+Note that `--model-id` is still required.
+
+## Advanced parameters
+
+A full listing of configurable parameters is available in the `--help`:
+
+```bash
+docker run tgi-llamacpp --help
+```
+
+The table below summarizes key options:
+
+| Parameter | Description |
+|-------------------------------------|------------------------------------------------------------------------|
+| `--n-threads` | Number of threads to use for generation |
+| `--n-threads-batch` | Number of threads to use for batch processing |
+| `--n-gpu-layers` | Number of layers to store in VRAM |
+| `--split-mode` | Split the model across multiple GPUs |
+| `--defrag-threshold` | Defragment the KV cache if holes/size > threshold |
+| `--numa` | Enable NUMA optimizations |
+| `--disable-mmap` | Disable memory mapping for the model |
+| `--use-mlock` | Use memory locking to prevent swapping |
+| `--disable-offload-kqv` | Disable offloading of KQV operations to the GPU |
+| `--disable-flash-attention` | Disable flash attention |
+| `--type-k` | Data type used for K cache |
+| `--type-v` | Data type used for V cache |
+| `--validation-workers` | Number of tokenizer workers used for payload validation and truncation |
+| `--max-concurrent-requests` | Maximum number of concurrent requests |
+| `--max-input-tokens` | Maximum number of input tokens per request |
+| `--max-total-tokens` | Maximum number of total tokens (input + output) per request |
+| `--max-batch-total-tokens` | Maximum number of tokens in a batch |
+| `--max-physical-batch-total-tokens` | Maximum number of tokens in a physical batch |
+| `--max-batch-size` | Maximum number of requests per batch |
+
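+For example, a launch combining a few of these options might look like the following sketch (the values are illustrative, not tuned recommendations):
+
+```bash
+docker run \
+ -p 3000:3000 \
+ -e "HF_TOKEN=$HF_TOKEN" \
+ -v "$HOME/models:/app/models" \
+ tgi-llamacpp \
+ --n-threads 16 \
+ --max-input-tokens 4096 \
+ --max-total-tokens 8192 \
+ --max-batch-size 4 \
+ --model-id "Qwen/Qwen2.5-3B-Instruct"
+```
+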
+---
+[llama.cpp]: https://github.com/ggerganov/llama.cpp
+[GGUF]: https://huggingface.co/models?library=gguf&sort=trending
diff --git a/docs/source/backends/neuron.md b/docs/source/backends/neuron.md
new file mode 100644
index 00000000000..17d720db56a
--- /dev/null
+++ b/docs/source/backends/neuron.md
@@ -0,0 +1,179 @@
+# Neuron backend for AWS Trainium and Inferentia
+
+The Neuron backend allows the deployment of TGI on AWS Trainium and Inferentia family of chips.
+
+The following hardware targets are supported:
+- Trainium 1,
+- Inferentia 2.
+
+## Features
+
+The basic TGI features are supported:
+
+- continuous batching,
+- token streaming,
+- greedy search and multinomial sampling using [transformers](https://huggingface.co/docs/transformers/generation_strategies#customize-text-generation).
+
+
+## Deploy the service from the Hugging Face hub
+
+The simplest way to deploy the NeuronX TGI service for a specific model is to follow the
+deployment instructions in the model card:
+
+- click on the "Deploy" button on the right,
+- select your deployment service ("Inference Endpoints" and "SageMaker" are supported),
+- select "AWS Trainum & Inferentia",
+- follow the instructions.
+
+
+## Deploy the service on a dedicated host
+
+The service is launched simply by running the text-generation-inference container with two sets of parameters:
+
+```
+docker run ghcr.io/huggingface/text-generation-inference:3.3.5-neuron
+```
+
+- system parameters are used to map ports, volumes and devices between the host and the service,
+- service parameters are forwarded to the `text-generation-launcher`.
+
+When deploying a service, you will need a pre-compiled Neuron model. The Neuron TGI backend supports two main modes of operation:
+
+- you can either deploy the service on a model that has already been exported to Neuron,
+- or alternatively you can take advantage of the Neuron Model Cache to export your own model.
+
+### Common system parameters
+
+Whenever you launch a TGI service, we highly recommend that you mount a shared volume as `/data` in the container: this is where
+the models will be cached to speed up further instantiations of the service.
+
+Note also that enough neuron devices should be made visible to the container, knowing that each neuron device has two cores (so when deploying on two cores you need to expose at least one device).
+The recommended way to expose devices in a production environment is to use the `--device` option explicitly (e.g. `--device /dev/neuron0`), repeated as many times as there are devices to be exposed.
+
+Note: alternatively, for a quick local test it is also possible to launch the service in `privileged` mode to get access to all neuron devices.
+
+Finally, you might want to export the `HF_TOKEN` if you want to access gated repositories.
+
+Here is an example of a service instantiation exposing only the first device:
+
+```
+docker run -p 8080:80 \
+ -v $(pwd)/data:/data \
+ --device=/dev/neuron0 \
+ -e HF_TOKEN=${HF_TOKEN} \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-neuron \
+ <service_parameters>
+```
+
+### Using a standard model from the 🤗 [HuggingFace Hub](https://huggingface.co/aws-neuron) (recommended)
+
+We maintain a Neuron Model Cache of the most popular architecture and deployment parameters under [aws-neuron/optimum-neuron-cache](https://huggingface.co/aws-neuron/optimum-neuron-cache).
+
+If you just want to try the service quickly using a model that has not been exported to Neuron yet, this is still possible, under the following conditions:
+- you must specify the export parameters when launching the service (or use default parameters),
+- the model configuration must be cached.
+
+The snippet below shows how you can deploy a service from a hub standard model:
+
+```
+export HF_TOKEN=
+docker run -p 8080:80 \
+ -v $(pwd)/data:/data \
+ --device=/dev/neuron0 \
+ --device=/dev/neuron1 \
+ --device=/dev/neuron2 \
+ --device=/dev/neuron3 \
+ -e HF_TOKEN=${HF_TOKEN} \
+ -e HF_AUTO_CAST_TYPE="fp16" \
+ -e HF_NUM_CORES=8 \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-neuron \
+ --model-id meta-llama/Meta-Llama-3-8B \
+ --max-batch-size 1 \
+ --max-input-length 3164 \
+ --max-total-tokens 4096
+```
+
+### Using a model exported to a local path
+
+Alternatively, you can first [export the model to neuron format](https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-text-generation-inference) locally.
+
+You can then deploy the service inside the shared volume:
+
+```
+docker run -p 8080:80 \
+ -v $(pwd)/data:/data \
+ --device=/dev/neuron0 \
+ --device=/dev/neuron1 \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-neuron \
+ --model-id /data/<neuron_model>
+```
+
+Note: You don't need to specify any service parameters, as they will all be deduced from the model export configuration. You must however expose enough devices to match the number of cores specified during the export phase.
+
+
+### Using a neuron model from the 🤗 [HuggingFace Hub](https://huggingface.co/)
+
+The easiest way to share a neuron model inside your organization is to push it to the Hugging Face hub, so that it can be deployed directly without requiring an export.
+
+The snippet below shows how you can deploy a service from a hub neuron model:
+
+```
+docker run -p 8080:80 \
+ -v $(pwd)/data:/data \
+ --device=/dev/neuron0 \
+ --device=/dev/neuron1 \
+ -e HF_TOKEN=${HF_TOKEN} \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-neuron \
+ --model-id <organization>/<neuron_model>
+```
+
+### Choosing service parameters
+
+Use the following command to list the available service parameters:
+
+```
+docker run ghcr.io/huggingface/text-generation-inference:3.3.5-neuron --help
+```
+
+The configuration of an inference endpoint is always a compromise between throughput and latency: serving more requests in parallel will allow a higher throughput, but it will increase the latency.
+
+The neuron models have static input dimensions `[batch_size, max_length]`.
+
+This adds several restrictions to the following parameters:
+
+- `--max-batch-size` must be set to `batch_size`,
+- `--max-input-length` must be lower than `max_length`,
+- `--max-total-tokens` must be set to `max_length` (it is per-request).
+
+Although not strictly necessary, the following is important for efficient prefilling:
+
+- `--max-batch-prefill-tokens` should be set to `batch_size` * `max-input-length`, as in the sketch below.
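+
+The following is a minimal sketch (assuming the model was exported with `batch_size = 4` and `max_length = 4096`; replace the placeholder model path and adjust the values to your own export configuration):
+
+```
+docker run -p 8080:80 \
+ -v $(pwd)/data:/data \
+ --device=/dev/neuron0 \
+ --device=/dev/neuron1 \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-neuron \
+ --model-id /data/<neuron_model> \
+ --max-batch-size 4 \
+ --max-input-length 3072 \
+ --max-total-tokens 4096 \
+ --max-batch-prefill-tokens 12288
+```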
+
+### Choosing the correct batch size
+
+As seen in the previous paragraph, the neuron model's static batch size has a direct influence on the endpoint latency and throughput.
+
+Please refer to [text-generation-inference](https://github.com/huggingface/text-generation-inference) for optimization hints.
+
+Note that the main constraint is to be able to fit the model for the specified `batch_size` within the total device memory available
+on your instance (16GB per neuron core, with 2 cores per device).
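+
+A rough way to budget this (a sketch with assumed values, here two exposed devices):
+
+```
+devices=2                   # number of /dev/neuron* devices exposed to the container
+cores=$(( devices * 2 ))    # 2 cores per device
+echo "Available device memory: $(( cores * 16 )) GB"   # 16GB per core; weights + KV cache for the chosen batch_size must fit here
+```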
+
+## Query the service
+
+You can query the model using either the `/generate` or `/generate_stream` routes:
+
+```
+curl 127.0.0.1:8080/generate \
+ -X POST \
+ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
+ -H 'Content-Type: application/json'
+```
+
+```
+curl 127.0.0.1:8080/generate_stream \
+ -X POST \
+ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
+ -H 'Content-Type: application/json'
+```
+
+Note: replace 127.0.0.1:8080 with your actual IP address and port.
diff --git a/docs/source/backends/trtllm.md b/docs/source/backends/trtllm.md
new file mode 100644
index 00000000000..92edbf0903d
--- /dev/null
+++ b/docs/source/backends/trtllm.md
@@ -0,0 +1,182 @@
+# TensorRT-LLM backend
+
+The NVIDIA TensorRT-LLM (TRTLLM) backend is a high-performance backend for LLMs
+that uses NVIDIA's TensorRT library for inference acceleration.
+It makes use of specific optimizations for NVIDIA GPUs, such as custom kernels.
+
+To use the TRTLLM backend **you need to compile** `engines` for the models you want to use.
+Each `engine` must be compiled for a given set of:
+- GPU architecture that you will use for inference (e.g. A100, L40, etc.)
+- Maximum batch size
+- Maximum input length
+- Maximum output length
+- Maximum beam width
+
+## Supported models
+
+Check the [support matrix](https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html) to see which models are
+supported.
+
+## Compiling engines
+
+You can use [Optimum-NVIDIA](https://github.com/huggingface/optimum-nvidia) to compile engines for the models you
+want to use.
+
+```bash
+MODEL_NAME="meta-llama/Llama-3.1-8B-Instruct"
+DESTINATION="/tmp/engines/$MODEL_NAME"
+HF_TOKEN="hf_xxx"
+# Compile the engine using Optimum-NVIDIA
+# This will create a compiled engine in the /tmp/engines/meta-llama/Llama-3.1-8B-Instruct
+# directory for 1 GPU
+docker run \
+ --rm \
+ -it \
+ --gpus=1 \
+ --shm-size=1g \
+ -v "$DESTINATION":/engine \
+ -e HF_TOKEN=$HF_TOKEN \
+ -e HF_HUB_ENABLE_HF_TRANSFER=1 \
+ huggingface/optimum-nvidia:v0.1.0b9-py310 \
+ bash -c "optimum-cli export trtllm \
+ --tp=1 \
+ --pp=1 \
+ --max-batch-size=64 \
+ --max-input-length 4096 \
+ --max-output-length 8192 \
+ --max-beams-width=1 \
+ --destination /tmp/engine \
+ $MODEL_NAME && cp -rL /tmp/engine/* /engine/"
+```
+
+Your compiled engine will be saved in the `/tmp/engines/$MODEL_NAME` directory, in a subfolder named after the GPU used to compile the model.
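+
+To locate the compiled engine on the host (a quick check; the subfolder name depends on the GPU used for compilation):
+
+```bash
+find "$DESTINATION" -maxdepth 2 -type d
+```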
+
+## Using the TRTLLM backend
+
+Run TGI-TRTLLM Docker image with the compiled engine:
+
+```bash
+MODEL_NAME="meta-llama/Llama-3.1-8B-Instruct"
+DESTINATION="/tmp/engines/$MODEL_NAME"
+HF_TOKEN="hf_xxx"
+docker run \
+ --gpus 1 \
+ --shm-size=1g \
+ -it \
+ --rm \
+ -p 3000:3000 \
+ -e MODEL=$MODEL_NAME \
+ -e PORT=3000 \
+ -e HF_TOKEN=$HF_TOKEN \
+ -v "$DESTINATION"//engines:/data \
+ ghcr.io/huggingface/text-generation-inference:latest-trtllm \
+ --model-id /data/ \
+ --tokenizer-name $MODEL_NAME
+```
+
+## Development
+
+To develop the TRTLLM backend, you can use [dev containers](https://containers.dev/) with the following `.devcontainer.json` file:
+```json
+{
+ "name": "CUDA",
+ "build": {
+ "dockerfile": "Dockerfile_trtllm",
+ "context": ".."
+ },
+ "remoteEnv": {
+ "PATH": "${containerEnv:PATH}:/usr/local/cuda/bin",
+ "LD_LIBRARY_PATH": "$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64",
+ "XLA_FLAGS": "--xla_gpu_cuda_data_dir=/usr/local/cuda"
+ },
+ "customizations" : {
+ "jetbrains" : {
+ "backend" : "CLion"
+ }
+ }
+}
+```
+
+and `Dockerfile_trtllm`:
+
+```Dockerfile
+ARG cuda_arch_list="75-real;80-real;86-real;89-real;90-real"
+ARG build_type=release
+ARG ompi_version=4.1.7
+
+# CUDA dependent dependencies resolver stage
+FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 AS cuda-builder
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+ build-essential \
+ cmake \
+ curl \
+ gcc-14 \
+ g++-14 \
+ git \
+ git-lfs \
+ lld \
+ libssl-dev \
+ libucx-dev \
+ libasan8 \
+ libubsan1 \
+ ninja-build \
+ pkg-config \
+ pipx \
+ python3 \
+ python3-dev \
+ python3-setuptools \
+ tar \
+ wget --no-install-recommends && \
+ pipx ensurepath
+
+ENV TGI_INSTALL_PREFIX=/usr/local/tgi
+ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
+
+# Install OpenMPI
+FROM cuda-builder AS mpi-builder
+WORKDIR /opt/src/mpi
+
+ARG ompi_version
+ENV OMPI_VERSION=${ompi_version}
+ENV OMPI_TARBALL_FILENAME=openmpi-${OMPI_VERSION}.tar.bz2
+ADD --checksum=sha256:54a33cb7ad81ff0976f15a6cc8003c3922f0f3d8ceed14e1813ef3603f22cd34 \
+ https://download.open-mpi.org/release/open-mpi/v4.1/${OMPI_TARBALL_FILENAME} .
+
+RUN tar --strip-components=1 -xf ${OMPI_TARBALL_FILENAME} &&\
+ ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
+ make -j all && \
+ make install && \
+ rm -rf ${OMPI_TARBALL_FILENAME}/..
+
+# Install TensorRT
+FROM cuda-builder AS trt-builder
+COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh
+RUN chmod +x /opt/install_tensorrt.sh && \
+ /opt/install_tensorrt.sh
+
+# Build Backend
+FROM cuda-builder AS tgi-builder
+WORKDIR /usr/src/text-generation-inference
+
+# Scoped global args reuse
+ARG cuda_arch_list
+ARG build_type
+ARG sccache_gha_enabled
+ARG actions_results_url
+ARG actions_runtime_token
+
+# Install Rust
+ENV PATH="/root/.cargo/bin:$PATH"
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
+ chmod -R a+w /root/.rustup && \
+ chmod -R a+w /root/.cargo && \
+ cargo install sccache --version ">=0.10.0" --locked
+
+ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
+ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"
+ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt"
+
+ENV USE_LLD_LINKER=ON
+ENV CUDA_ARCH_LIST=${cuda_arch_list}
+```
diff --git a/docs/source/basic_tutorials/consuming_tgi.md b/docs/source/basic_tutorials/consuming_tgi.md
index b07e7219d7c..56237f9ad6f 100644
--- a/docs/source/basic_tutorials/consuming_tgi.md
+++ b/docs/source/basic_tutorials/consuming_tgi.md
@@ -152,14 +152,10 @@ def inference(message, history):
gr.ChatInterface(
inference,
- chatbot=gr.Chatbot(height=300),
- textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7),
+ type="messages",
description="This is the demo for Gradio UI consuming TGI endpoint.",
title="Gradio 🤝 TGI",
examples=["Are tomatoes vegetables?"],
- retry_btn="Retry",
- undo_btn="Undo",
- clear_btn="Clear",
).queue().launch()
```
diff --git a/docs/source/basic_tutorials/gated_model_access.md b/docs/source/basic_tutorials/gated_model_access.md
index 60b347b13ad..d42bac7a689 100644
--- a/docs/source/basic_tutorials/gated_model_access.md
+++ b/docs/source/basic_tutorials/gated_model_access.md
@@ -19,6 +19,6 @@ docker run --gpus all \
--shm-size 1g \
-e HF_TOKEN=$token \
-p 8080:80 \
- -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0 \
+ -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.5 \
--model-id $model
```
diff --git a/docs/source/basic_tutorials/monitoring.md b/docs/source/basic_tutorials/monitoring.md
index 509b0aff1e8..aa60e7af980 100644
--- a/docs/source/basic_tutorials/monitoring.md
+++ b/docs/source/basic_tutorials/monitoring.md
@@ -38,7 +38,7 @@ In this guide, Prometheus monitoring data will be consumed on a local computer.
* Use ssh [local port forwarding](https://www.ssh.com/academy/ssh/tunneling-example)
* Use ngrok port tunneling
-For simplicity, we will use [Ngrok](https://ngrok.com/docs/) in this guide to tunnel Prometheus port from the TGI server to the outside word.
+For simplicity, we will use [Ngrok](https://ngrok.com/docs/) in this guide to tunnel Prometheus port from the TGI server to the outside world.
For that, you should follow the steps at https://dashboard.ngrok.com/get-started/setup/linux, and once Ngrok is installed, use:
```bash
diff --git a/docs/source/basic_tutorials/using_guidance.md b/docs/source/basic_tutorials/using_guidance.md
index 2d55c9528c1..d25a6bd02f6 100644
--- a/docs/source/basic_tutorials/using_guidance.md
+++ b/docs/source/basic_tutorials/using_guidance.md
@@ -138,10 +138,10 @@ client = InferenceClient("http://localhost:3000/")
user_input = "I saw a puppy a cat and a raccoon during my bike ride in the park"
resp = client.text_generation(
- f"convert to JSON: 'f{user_input}'. please use the following schema: {Animals.schema()}",
+ f"convert to JSON: 'f{user_input}'. please use the following schema: {Animals.model_json_schema()}",
max_new_tokens=100,
seed=42,
- grammar={"type": "json", "value": Animals.schema()},
+ grammar={"type": "json", "value": Animals.model_json_schema()},
)
print(resp)
@@ -187,8 +187,6 @@ In addition to the grammar parameter, we've also introduced a set of tools and f
Tools are a set of user defined functions that can be used in tandem with the chat functionality to enhance the LLM's capabilities. Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API.
-Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API.
-
```json
curl localhost:3000/v1/chat/completions \
-X POST \
diff --git a/docs/source/basic_tutorials/visual_language_models.md b/docs/source/basic_tutorials/visual_language_models.md
index f152a2f0b2e..f3c8c836aef 100644
--- a/docs/source/basic_tutorials/visual_language_models.md
+++ b/docs/source/basic_tutorials/visual_language_models.md
@@ -22,7 +22,7 @@ To infer with vision language models through Python, you can use the [`huggingfa
```python
from huggingface_hub import InferenceClient
-client = InferenceClient("http://127.0.0.1:3000/")
+client = InferenceClient(base_url="http://127.0.0.1:3000/")
image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
prompt = f"What is this a picture of?\n\n"
for token in client.text_generation(prompt, max_new_tokens=16, stream=True):
@@ -37,7 +37,7 @@ import base64
import requests
import io
-client = InferenceClient("http://127.0.0.1:3000/")
+client = InferenceClient(base_url="http://127.0.0.1:3000/")
# read image from local file
image_path = "rabbit.png"
@@ -58,7 +58,7 @@ or via the `chat_completion` endpoint:
```python
from huggingface_hub import InferenceClient
-client = InferenceClient("http://127.0.0.1:3000/")
+client = InferenceClient(base_url="http://127.0.0.1:3000/")
chat = client.chat_completion(
messages=[
@@ -137,19 +137,19 @@ First, we need to install the `@huggingface/inference` library.
npm install @huggingface/inference
```
-If you're using the free Inference API, you can use [Huggingface.js](https://huggingface.co/docs/huggingface.js/inference/README)'s `HfInference`. If you're using inference endpoints, you can use `HfInferenceEndpoint` class to easily interact with the Inference API.
+Whether you use Inference Providers (our serverless API), or Inference Endpoints, you can call `InferenceClient`.
-We can create a `HfInferenceEndpoint` providing our endpoint URL and We can create a `HfInferenceEndpoint` providing our endpoint URL and [Hugging Face access token](https://huggingface.co/settings/tokens).
+We can create an `InferenceClient` providing our endpoint URL and [Hugging Face access token](https://huggingface.co/settings/tokens).
```js
-import { HfInferenceEndpoint } from "@huggingface/inference";
+import { InferenceClient } from "@huggingface/inference";
-const hf = new HfInferenceEndpoint("http://127.0.0.1:3000/", "HF_TOKEN");
+const client = new InferenceClient('hf_YOUR_TOKEN', { endpointUrl: 'https://your_endpoint.endpoints.huggingface.cloud/' });
const prompt =
"What is this a picture of?\n\n";
-const stream = hf.textGenerationStream({
+const stream = client.textGenerationStream({
inputs: prompt,
parameters: { max_new_tokens: 16, seed: 42 },
});
diff --git a/docs/source/conceptual/quantization.md b/docs/source/conceptual/quantization.md
index 953e36837aa..ad6483e296a 100644
--- a/docs/source/conceptual/quantization.md
+++ b/docs/source/conceptual/quantization.md
@@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models.
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
```bash
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model --quantize bitsandbytes
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.5 --model-id $model --quantize bitsandbytes
```
4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
@@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
```bash
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model --quantize bitsandbytes-nf4
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.5 --model-id $model --quantize bitsandbytes-nf4
```
You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
@@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$
TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇
```bash
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model --quantize gptq
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.5 --model-id $model --quantize gptq
```
Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.
diff --git a/docs/source/conceptual/speculation.md b/docs/source/conceptual/speculation.md
index 45618ae3feb..74e010c8aea 100644
--- a/docs/source/conceptual/speculation.md
+++ b/docs/source/conceptual/speculation.md
@@ -27,7 +27,7 @@ You can check a few existing fine-tunes for popular models:
- [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa)
-In order to create your own medusa heads for your own finetune, you should check own the original medusa repo. [../basic_tutorials/train_medusa.md](../basic_tutorials/train_medusa.md)
+In order to create your own medusa heads for your own finetune, you should check out the original medusa repo. Read more in [Train Medusa](../basic_tutorials/train_medusa#training).
In order to use medusa models in TGI, simply point to a medusa enabled model, and everything will load automatically.
diff --git a/docs/source/conceptual/streaming.md b/docs/source/conceptual/streaming.md
index b8154ba4355..2a2f6490fcc 100644
--- a/docs/source/conceptual/streaming.md
+++ b/docs/source/conceptual/streaming.md
@@ -27,14 +27,14 @@ For example, a system can generate 100 tokens per second. If the system generate
@@ -125,24 +125,26 @@ curl localhost:8080/v1/chat/completions \
### Streaming with JavaScript
First, we need to install the `@huggingface/inference` library.
-`npm install @huggingface/inference`
-If you're using the free Inference API, you can use `HfInference`. If you're using inference endpoints, you can use `HfInferenceEndpoint`.
+```bash
+npm install @huggingface/inference
+```
+
+Whether you use Inference Providers (our serverless API), or Inference Endpoints, you can call `InferenceClient`.
-We can create a `HfInferenceEndpoint` providing our endpoint URL and credential.
```js
-import { HfInferenceEndpoint } from '@huggingface/inference'
+import { InferenceClient } from '@huggingface/inference';
-const hf = new HfInferenceEndpoint('https://your_endpoint.endpoints.huggingface.cloud/', 'hf_YOUR_TOKEN')
+const client = new InferenceClient('hf_YOUR_TOKEN', { endpointUrl: 'https://your_endpoint.endpoints.huggingface.cloud/' });
// prompt
-const prompt = 'What can you do in Nuremberg, Germany? Give me 3 Tips'
+const prompt = 'What can you do in Nuremberg, Germany? Give me 3 Tips';
-const stream = hf.textGenerationStream({ inputs: prompt })
+const stream = client.textGenerationStream({ inputs: prompt });
for await (const r of stream) {
// yield the generated token
- process.stdout.write(r.token.text)
+ process.stdout.write(r.token.text);
}
```
diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md
index 3c9c0eecf05..df4abb3b43c 100644
--- a/docs/source/installation_amd.md
+++ b/docs/source/installation_amd.md
@@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--device=/dev/kfd --device=/dev/dri --group-add video \
--ipc=host --shm-size 256g --net host -v $volume:/data \
- ghcr.io/huggingface/text-generation-inference:3.0.0-rocm \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-rocm \
--model-id $model
```
@@ -39,6 +39,6 @@ The custom kernel supports bf16 and fp16 data types, block size of 16, head size
## Unsupported features
-The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
+The following features are currently not supported in the ROCm version of TGI, and the support may be extended in the future:
* Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints.
* Kernel for sliding window attention (Mistral)
diff --git a/docs/source/installation_gaudi.md b/docs/source/installation_gaudi.md
index 1ddf2b47225..51aa667dcad 100644
--- a/docs/source/installation_gaudi.md
+++ b/docs/source/installation_gaudi.md
@@ -1,3 +1,3 @@
# Using TGI with Intel Gaudi
-Check out this [repository](https://github.com/huggingface/tgi-gaudi) to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index).
+You can use TGI on Intel Gaudi using the [TGI gaudi backend](https://huggingface.co/docs/text-generation-inference/backends/gaudi).
diff --git a/docs/source/installation_inferentia.md b/docs/source/installation_inferentia.md
index 0394e6ded37..bfd0f657754 100644
--- a/docs/source/installation_inferentia.md
+++ b/docs/source/installation_inferentia.md
@@ -1,3 +1,3 @@
# Using TGI with Inferentia
-Check out this [guide](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference) on how to serve models with TGI on Inferentia2.
+You can use TGI on AWS Trainium and Inferentia platforms using the [TGI neuron backend](https://huggingface.co/docs/text-generation-inference/backends/neuron).
diff --git a/docs/source/installation_intel.md b/docs/source/installation_intel.md
index a5578e769c0..60b0bcc056f 100644
--- a/docs/source/installation_intel.md
+++ b/docs/source/installation_intel.md
@@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
- ghcr.io/huggingface/text-generation-inference:3.0.0-intel-xpu \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-intel-xpu \
--model-id $model --cuda-graphs 0
```
@@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
- ghcr.io/huggingface/text-generation-inference:3.0.0-intel-cpu \
+ ghcr.io/huggingface/text-generation-inference:3.3.5-intel-cpu \
--model-id $model --cuda-graphs 0
```
diff --git a/docs/source/installation_nvidia.md b/docs/source/installation_nvidia.md
index d954481eda5..37cb841c8b3 100644
--- a/docs/source/installation_nvidia.md
+++ b/docs/source/installation_nvidia.md
@@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
- ghcr.io/huggingface/text-generation-inference:3.0.0 \
+ ghcr.io/huggingface/text-generation-inference:3.3.5 \
--model-id $model
```
diff --git a/docs/source/installation_tpu.md b/docs/source/installation_tpu.md
new file mode 100644
index 00000000000..559e83aa7dd
--- /dev/null
+++ b/docs/source/installation_tpu.md
@@ -0,0 +1,3 @@
+# Using TGI with Google TPUs
+
+Check out this [guide](https://huggingface.co/docs/optimum-tpu) on how to serve models with TGI on TPUs.
diff --git a/docs/source/multi_backend_support.md b/docs/source/multi_backend_support.md
new file mode 100644
index 00000000000..997503a4f19
--- /dev/null
+++ b/docs/source/multi_backend_support.md
@@ -0,0 +1,16 @@
+# Multi-backend support
+
+TGI (Text Generation Inference) offers flexibility by supporting multiple backends for serving large language models (LLMs).
+With multi-backend support, you can choose the backend that best suits your needs,
+whether you prioritize performance, ease of use, or compatibility with specific hardware. API interaction with
+TGI remains consistent across backends, allowing you to switch between them seamlessly.
+
+**Supported backends:**
+* **TGI CUDA backend**: This high-performance backend is optimized for NVIDIA GPUs and serves as the default option
+ within TGI. Developed in-house, it boasts numerous optimizations and is used in production by various projects, including those by Hugging Face.
+* **[TGI TRTLLM backend](./backends/trtllm)**: This backend leverages NVIDIA's TensorRT library to accelerate LLM inference.
+ It utilizes specialized optimizations and custom kernels for enhanced performance.
+ However, it requires a model-specific compilation step for each GPU architecture.
+* **[TGI Llamacpp backend](./backends/llamacpp)**: This backend facilitates the deployment of large language models
+ (LLMs) by integrating [llama.cpp][llama.cpp], an advanced inference engine optimized for both CPU and GPU computation.
+* **[TGI Neuron backend](./backends/neuron)**: This backend leverages the [AWS Neuron SDK](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/) to allow the deployment of large language models (LLMs) on [AWS Trainium and Inferentia chips](https://aws.amazon.com/ai/machine-learning/trainium/).
diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md
index d1f3efa4f40..bd8495c52e2 100644
--- a/docs/source/quicktour.md
+++ b/docs/source/quicktour.md
@@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
- ghcr.io/huggingface/text-generation-inference:3.0.0 \
+ ghcr.io/huggingface/text-generation-inference:3.3.5 \
--model-id $model
```
@@ -96,7 +96,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash
-docker run ghcr.io/huggingface/text-generation-inference:3.0.0 --help
+docker run ghcr.io/huggingface/text-generation-inference:3.3.5 --help
```
diff --git a/docs/source/reference/api_reference.md b/docs/source/reference/api_reference.md
index 42a777039f3..7d21eca70a8 100644
--- a/docs/source/reference/api_reference.md
+++ b/docs/source/reference/api_reference.md
@@ -163,7 +163,7 @@ hub = {
# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
- image_uri=get_huggingface_llm_image_uri("huggingface",version="3.0.0"),
+ image_uri=get_huggingface_llm_image_uri("huggingface",version="3.3.5"),
env=hub,
role=role,
)
diff --git a/docs/source/reference/launcher.md b/docs/source/reference/launcher.md
index 159b22e7352..5b7321b73a3 100644
--- a/docs/source/reference/launcher.md
+++ b/docs/source/reference/launcher.md
@@ -58,8 +58,6 @@ Options:
Quantization method to use for the model. It is not necessary to specify this option for pre-quantized models, since the quantization method is read from the model configuration.
Marlin kernels will be used automatically for GPTQ/AWQ models.
-
- [env: QUANTIZE=]
Possible values:
- awq: 4 bit quantization. Requires a specific AWQ quantized model:
. Should replace GPTQ models wherever possible because of the better latency
@@ -72,6 +70,8 @@ Options:
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
- fp8: [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above This dtype has native ops should be the fastest if available. This is currently not the fastest because of local unpacking + padding to satisfy matrix multiplication limitations
+
+ [env: QUANTIZE=]
```
## SPECULATE
@@ -198,7 +198,7 @@ Options:
For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100` or a single query of `1000` tokens.
- Overall this number should be the largest possible amount that fits the remaining memory (after the model is loaded). Since the actual memory overhead depends on other parameters like if you're using quantization, flash attention or the model implementation, text-generation-inference cannot infer this number automatically.
+ Overall this number should be the largest possible amount that fits the remaining memory (after the model is loaded). Since the actual memory overhead depends on other parameters such as whether you're using quantization, flash attention or the model implementation, text-generation-inference infers this number automatically if not provided, ensuring that the value is as large as possible.
[env: MAX_BATCH_TOTAL_TOKENS=]
@@ -251,6 +251,15 @@ Options:
[env: PORT=]
[default: 3000]
+```
+## PROMETHEUS_PORT
+```shell
+ -p, --prometheus-port
+ The Prometheus port to listen on
+
+ [env: PROMETHEUS_PORT=]
+ [default: 9000]
+
```
## SHARD_UDS_PATH
```shell
@@ -447,14 +456,14 @@ Options:
```shell
--usage-stats
Control if anonymous usage stats are collected. Options are "on", "off" and "no-stack" Defaul is on
-
- [env: USAGE_STATS=]
- [default: on]
Possible values:
- on: Default option, usage statistics are collected anonymously
- off: Disables all collection of usage statistics
- no-stack: Doesn't send the error stack trace or error type, but allows sending a crash event
+
+ [env: USAGE_STATS=]
+ [default: on]
```
## PAYLOAD_LIMIT
@@ -477,6 +486,15 @@ Options:
[env: ENABLE_PREFILL_LOGPROBS=]
+```
+## GRACEFUL_TERMINATION_TIMEOUT
+```shell
+ -g, --graceful-termination-timeout
+ Change timeout of graceful termination of the TGI server
+
+ [env: GRACEFUL_TERMINATION_TIMEOUT=]
+ [default: 90]
+
```
## HELP
```shell
diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md
index 0f39ff28e42..5d855a5b905 100644
--- a/docs/source/supported_models.md
+++ b/docs/source/supported_models.md
@@ -4,14 +4,19 @@
Text Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported.
- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
+- [Deepseek V3](https://huggingface.co/deepseek-ai/DeepSeek-V3)
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
+- [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal)
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
+- [Llama4](https://huggingface.co/collections/meta-llama/llama-4-67f0c30d9fe03840bc9d0164)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b)
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
+- [Gemma3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)
+- [Gemma3 Text](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
- [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
- [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)
@@ -25,6 +30,7 @@ Text Generation Inference enables serving optimized models. The following sectio
- [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)
- [Qwen 2](https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f)
- [Qwen 2 VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d)
+- [Qwen 2.5 VL](https://huggingface.co/collections/Qwen/qwen25-vl-6795ffac22b334a837c0f9a5)
- [Opt](https://huggingface.co/facebook/opt-6.7b)
- [T5](https://huggingface.co/google/flan-t5-xxl)
- [Galactica](https://huggingface.co/facebook/galactica-120b)
diff --git a/docs/source/usage_statistics.md b/docs/source/usage_statistics.md
index d3878b53ccb..dbf36c77921 100644
--- a/docs/source/usage_statistics.md
+++ b/docs/source/usage_statistics.md
@@ -3,7 +3,7 @@
Text Generation Inference collects anonymous usage statistics to help us improve the service. The collected data is used to improve TGI and to understand what causes failures. The data is collected transparently and any sensitive information is omitted.
-Data is sent twice, once on server startup and once when server stops. Also, usage statistics are only enabled when TGI is running in docker to avoid collecting data then TGI runs directly on the host machine.
+Usage statistics are collected only when TGI is running in a Docker container. This prevents data collection when TGI is run directly on the host machine. The collected data includes startup and shutdown events, as well as a heartbeat signal sent every 15 minutes.
## What data is collected
diff --git a/flake.lock b/flake.lock
index ec87d569231..e57990c89a1 100644
--- a/flake.lock
+++ b/flake.lock
@@ -102,17 +102,17 @@
"flake-parts": "flake-parts_3",
"nix-test-runner": "nix-test-runner_3",
"nixpkgs": [
- "tgi-nix",
+ "hf-nix",
"nixpkgs"
],
"pre-commit-hooks": "pre-commit-hooks_3"
},
"locked": {
- "lastModified": 1732039290,
- "narHash": "sha256-LQKY7bShf2H9kJouxa9ZspfdrulnZF9o4kLTqGqCDYM=",
+ "lastModified": 1739473963,
+ "narHash": "sha256-ItAhpjNUzEWd/cgZVyW/jvoGbCec4TK29e1Mnmn1oJE=",
"owner": "nix-community",
"repo": "crate2nix",
- "rev": "9ff208ce7f5a482272b1bcefbe363c772d7ff914",
+ "rev": "be31feae9a82c225c0fd1bdf978565dc452a483a",
"type": "github"
},
"original": {
@@ -305,11 +305,11 @@
},
"flake-compat_4": {
"locked": {
- "lastModified": 1696426674,
- "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
+ "lastModified": 1733328505,
+ "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
"owner": "edolstra",
"repo": "flake-compat",
- "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
+ "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
"type": "github"
},
"original": {
@@ -579,6 +579,26 @@
"type": "github"
}
},
+ "hf-nix": {
+ "inputs": {
+ "flake-compat": "flake-compat_4",
+ "flake-utils": "flake-utils_7",
+ "nixpkgs": "nixpkgs_6"
+ },
+ "locked": {
+ "lastModified": 1747919133,
+ "narHash": "sha256-VvF1naQOvv7yulQ5/cDiaxkNxlh1Y84QMZnderv1szk=",
+ "owner": "huggingface",
+ "repo": "hf-nix",
+ "rev": "9c71e026d6c7c8588ef85a5f7c77f57d598e038c",
+ "type": "github"
+ },
+ "original": {
+ "owner": "huggingface",
+ "repo": "hf-nix",
+ "type": "github"
+ }
+ },
"nix-filter": {
"locked": {
"lastModified": 1731533336,
@@ -718,16 +738,16 @@
},
"nixpkgs_6": {
"locked": {
- "lastModified": 1732034459,
- "narHash": "sha256-Zais/zMRuJdlALidkUgEuasXOd37ZZLqkPkF9bIYSrY=",
+ "lastModified": 1747820358,
+ "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
"owner": "danieldk",
"repo": "nixpkgs",
- "rev": "40280e7bf9743cdf563494db4ece2a43aa674fa8",
+ "rev": "d3c1681180717528068082103bf323147de6ab0b",
"type": "github"
},
"original": {
"owner": "danieldk",
- "ref": "outlines-v0.1.4-tgi",
+ "ref": "cudatoolkit-12.9-kernel-builder",
"repo": "nixpkgs",
"type": "github"
}
@@ -836,28 +856,28 @@
"inputs": {
"crate2nix": "crate2nix",
"flake-utils": "flake-utils_6",
+ "hf-nix": "hf-nix",
"nix-filter": "nix-filter",
"nixpkgs": [
- "tgi-nix",
+ "hf-nix",
"nixpkgs"
],
- "rust-overlay": "rust-overlay",
- "tgi-nix": "tgi-nix"
+ "rust-overlay": "rust-overlay"
}
},
"rust-overlay": {
"inputs": {
"nixpkgs": [
- "tgi-nix",
+ "hf-nix",
"nixpkgs"
]
},
"locked": {
- "lastModified": 1732242723,
- "narHash": "sha256-NWI8csIK0ujFlFuEXKnoc+7hWoCiEtINK9r48LUUMeU=",
+ "lastModified": 1743993291,
+ "narHash": "sha256-u8GHvduU1gCtoFXvTS/wGjH1ouv5S/GRGq6MAT+sG/k=",
"owner": "oxalica",
"repo": "rust-overlay",
- "rev": "a229311fcb45b88a95fdfa5cecd8349c809a272a",
+ "rev": "0cb3c8979c65dc6a5812dfe67499a8c7b8b4325b",
"type": "github"
},
"original": {
@@ -970,26 +990,6 @@
"repo": "default",
"type": "github"
}
- },
- "tgi-nix": {
- "inputs": {
- "flake-compat": "flake-compat_4",
- "flake-utils": "flake-utils_7",
- "nixpkgs": "nixpkgs_6"
- },
- "locked": {
- "lastModified": 1732218602,
- "narHash": "sha256-BElslL34KjOJCFMPkNtilOz6S/7iY7Vd72FNbRRWKDY=",
- "owner": "huggingface",
- "repo": "text-generation-inference-nix",
- "rev": "f79638ac4e420e661321261744e745a3a747e182",
- "type": "github"
- },
- "original": {
- "owner": "huggingface",
- "repo": "text-generation-inference-nix",
- "type": "github"
- }
}
},
"root": "root",
diff --git a/flake.nix b/flake.nix
index 83cedfa620f..b5b13cad8bf 100644
--- a/flake.nix
+++ b/flake.nix
@@ -2,15 +2,15 @@
inputs = {
crate2nix = {
url = "github:nix-community/crate2nix";
- inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
+ inputs.nixpkgs.follows = "hf-nix/nixpkgs";
};
nix-filter.url = "github:numtide/nix-filter";
- tgi-nix.url = "github:huggingface/text-generation-inference-nix";
- nixpkgs.follows = "tgi-nix/nixpkgs";
+ hf-nix.url = "github:huggingface/hf-nix";
+ nixpkgs.follows = "hf-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
rust-overlay = {
url = "github:oxalica/rust-overlay";
- inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
+ inputs.nixpkgs.follows = "hf-nix/nixpkgs";
};
};
outputs =
@@ -21,7 +21,7 @@
nixpkgs,
flake-utils,
rust-overlay,
- tgi-nix,
+ hf-nix,
}:
flake-utils.lib.eachDefaultSystem (
system:
@@ -33,10 +33,10 @@
};
pkgs = import nixpkgs {
inherit system;
- inherit (tgi-nix.lib) config;
+ inherit (hf-nix.lib) config;
overlays = [
rust-overlay.overlays.default
- tgi-nix.overlays.default
+ hf-nix.overlays.default
(import nix/overlay.nix)
];
};
@@ -44,9 +44,24 @@
benchmark = cargoNix.workspaceMembers.text-generation-benchmark.build.override {
inherit crateOverrides;
};
- launcher = cargoNix.workspaceMembers.text-generation-launcher.build.override {
- inherit crateOverrides;
- };
+ launcher =
+ let
+ launcherUnwrapped = cargoNix.workspaceMembers.text-generation-launcher.build.override {
+ inherit crateOverrides;
+ };
+ packagePath =
+ with pkgs.python3.pkgs;
+ makePythonPath [
+ torch
+ ];
+ in
+ pkgs.writeShellApplication {
+ name = "text-generation-launcher";
+ text = ''
+ PYTHONPATH="${packagePath}" ${launcherUnwrapped}/bin/text-generation-launcher "$@"
+ '';
+ };
+
router =
let
routerUnwrapped = cargoNix.workspaceMembers.text-generation-router-v3.build.override {
@@ -161,11 +176,15 @@
'';
};
- dockerImage = pkgs.callPackage nix/docker.nix {
+ # Use plain nixpkgs without overlays for dockerTools. dockerTools
+ # uses a Python package for computing the layers from the transitive
+ # closure. However, this needs a lot of rebuilds due to our overlay.
+
+ dockerImage = nixpkgs.legacyPackages.${system}.callPackage nix/docker.nix {
text-generation-inference = default;
};
- dockerImageStreamed = pkgs.callPackage nix/docker.nix {
+ dockerImageStreamed = nixpkgs.legacyPackages.${system}.callPackage nix/docker.nix {
text-generation-inference = default;
stream = true;
};
diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index c9c477665a3..9cc334168aa 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -1,4 +1,18 @@
+pytest_plugins = [
+ "fixtures.neuron.service",
+ "fixtures.neuron.export_models",
+ "fixtures.gaudi.service",
+]
# ruff: noqa: E402
+from _pytest.fixtures import SubRequest
+from huggingface_hub.inference._generated.types.chat_completion import (
+ ChatCompletionStreamOutput,
+ ChatCompletionOutput,
+)
+from openai.types.chat.chat_completion_chunk import (
+ ChatCompletionChunk as OAIChatCompletionChunk,
+)
+from openai.types.completion import Completion as OAICompletion
import requests
@@ -10,13 +24,13 @@ def request(self, *args, **kwargs):
requests.sessions.Session = SessionTimeoutFix
+import warnings
import asyncio
import contextlib
import json
import math
import os
import random
-import shutil
import subprocess
import sys
import tempfile
@@ -30,7 +44,6 @@ def request(self, *args, **kwargs):
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from docker.errors import NotFound
from syrupy.extensions.json import JSONSnapshotExtension
-
from text_generation import AsyncClient
from text_generation.types import (
BestOfSequence,
@@ -56,45 +69,124 @@ def pytest_addoption(parser):
parser.addoption(
"--release", action="http://github.com/store_true", default=False, help="run release tests"
)
+ parser.addoption(
+ "--neuron", action="http://github.com/store_true", default=False, help="run neuron tests"
+ )
+ parser.addoption(
+ "--gaudi", action="http://github.com/store_true", default=False, help="run gaudi tests"
+ )
+ parser.addoption(
+ "--gaudi-all-models",
+ action="http://github.com/store_true",
+ default=False,
+ help="Run tests for all models instead of just the default subset",
+ )
def pytest_configure(config):
config.addinivalue_line("markers", "release: mark test as a release-only test")
+ config.addinivalue_line("markers", "neuron: mark test as a neuron test")
def pytest_collection_modifyitems(config, items):
- if config.getoption("--release"):
- # --release given in cli: do not skip release tests
- return
- skip_release = pytest.mark.skip(reason="need --release option to run")
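+ # Build a list of selector callbacks; each one inspects an item's markers and may attach a skip marker based on the CLI options.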
+ selectors = []
+ if not config.getoption("--release"):
+ # --release not given in cli: skip release tests
+ def skip_release(item):
+ if "release" in item.keywords:
+ item.add_marker(pytest.mark.skip(reason="need --release option to run"))
+
+ selectors.append(skip_release)
+
+ if config.getoption("--gaudi"):
+
+ def skip_not_gaudi(item):
+ if "gaudi" not in item.keywords:
+ item.add_marker(pytest.mark.skip(reason="requires --gaudi to run"))
+
+ selectors.append(skip_not_gaudi)
+ else:
+
+ def skip_gaudi(item):
+ if "gaudi" in item.keywords:
+ item.add_marker(pytest.mark.skip(reason="requires --gaudi to run"))
+
+ selectors.append(skip_gaudi)
+
+ if config.getoption("--neuron"):
+
+ def skip_not_neuron(item):
+ if "neuron" not in item.keywords:
+ item.add_marker(
+ pytest.mark.skip(reason="incompatible with --neuron option")
+ )
+
+ selectors.append(skip_not_neuron)
+ else:
+
+ def skip_neuron(item):
+ if "neuron" in item.keywords:
+ item.add_marker(pytest.mark.skip(reason="requires --neuron to run"))
+
+ selectors.append(skip_neuron)
+
for item in items:
- if "release" in item.keywords:
- item.add_marker(skip_release)
+ for selector in selectors:
+ selector(item)
+
+
+@pytest.fixture(autouse=True, scope="module")
+def container_log(request: SubRequest):
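+ """If any test has failed so far, replay the accumulated launcher/container logs to stderr; otherwise reset the log."""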
+ error_log = request.getfixturevalue("error_log")
+ assert error_log is not None
+ yield
+ if request.session.testsfailed:
+ error_log.seek(0)
+ print(error_log.read(), file=sys.stderr)
+ else:
+ error_log.truncate(0)
+ error_log.seek(0)
class ResponseComparator(JSONSnapshotExtension):
rtol = 0.2
ignore_logprob = False
- def serialize(
+ def _serialize(
self,
data,
- *,
- include=None,
- exclude=None,
- matcher=None,
):
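+ """Recursively convert TGI and OpenAI response objects into plain dicts/lists for JSON snapshotting."""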
if (
isinstance(data, Response)
or isinstance(data, ChatComplete)
or isinstance(data, ChatCompletionChunk)
or isinstance(data, ChatCompletionComplete)
+ or isinstance(data, Completion)
+ or isinstance(data, OAIChatCompletionChunk)
+ or isinstance(data, OAICompletion)
):
data = data.model_dump()
+ elif isinstance(data, ChatCompletionStreamOutput) or isinstance(
+ data, ChatCompletionOutput
+ ):
+ data = dict(data)
+ elif isinstance(data, List):
+ data = [self._serialize(d) for d in data]
+ elif isinstance(data, dict):
+ return data
+ else:
+ raise RuntimeError(f"Unexpected data {type(data)} : {data}")
+ return data
- if isinstance(data, List):
- data = [d.model_dump() for d in data]
-
+ def serialize(
+ self,
+ data,
+ *,
+ include=None,
+ exclude=None,
+ matcher=None,
+ ):
+ data = self._serialize(data)
data = self._filter(
data=data,
depth=0,
@@ -103,7 +195,8 @@ def serialize(
include=include,
matcher=matcher,
)
- return json.dumps(data, indent=2, ensure_ascii=False, sort_keys=False) + "\n"
+ data = json.dumps(data, indent=2, ensure_ascii=False, sort_keys=False) + "\n"
+ return data
def matches(
self,
@@ -119,7 +212,7 @@ def _convert_data(data):
if isinstance(data, Dict):
if "choices" in data:
data["choices"] = list(
- sorted(data["choices"], key=lambda x: x["index"])
+ sorted(data["choices"], key=lambda x: int(x["index"]))
)
choices = data["choices"]
if isinstance(choices, List) and len(choices) >= 1:
@@ -132,7 +225,7 @@ def _convert_data(data):
return Response(**data)
if isinstance(data, List):
return [_convert_data(d) for d in data]
- raise NotImplementedError
+ raise NotImplementedError(f"Data: {data}")
def eq_token(token: Token, other: Token) -> bool:
return (
@@ -230,7 +323,25 @@ def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
def eq_chat_complete_chunk(
response: ChatCompletionChunk, other: ChatCompletionChunk
) -> bool:
- return response.choices[0].delta.content == other.choices[0].delta.content
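+ # A chunk can carry a content delta, a tool-call delta, or (for the final chunk) only usage information; compare whichever is present.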
+ if response.choices:
+ if response.choices[0].delta.content is not None:
+ return (
+ response.choices[0].delta.content
+ == other.choices[0].delta.content
+ )
+ elif response.choices[0].delta.tool_calls is not None:
+ return (
+ response.choices[0].delta.tool_calls
+ == other.choices[0].delta.tool_calls
+ )
+ else:
+ raise RuntimeError(
+ f"Invalid empty chat chunk {response} vs {other}"
+ )
+ elif response.usage is not None:
+ return response.usage == other.usage
+ else:
+ raise RuntimeError(f"Invalid empty chat {response} vs {other}")
def eq_response(response: Response, other: Response) -> bool:
return response.generated_text == other.generated_text and eq_details(
@@ -245,6 +356,9 @@ def eq_response(response: Response, other: Response) -> bool:
if not isinstance(snapshot_data, List):
snapshot_data = [snapshot_data]
+ if len(serialized_data) == 0:
+ return len(snapshot_data) == len(serialized_data)
+
if isinstance(serialized_data[0], Completion):
return len(snapshot_data) == len(serialized_data) and all(
[eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)]
@@ -278,8 +392,10 @@ class IgnoreLogProbResponseComparator(ResponseComparator):
class LauncherHandle:
- def __init__(self, port: int):
- self.client = AsyncClient(f"http://localhost:{port}", timeout=30)
+ def __init__(self, port: int, error_log):
+ with warnings.catch_warnings(action="http://github.com/ignore"):
+ self.client = AsyncClient(f"http://localhost:{port}", timeout=30)
+ self.error_log = error_log
def _inner_health(self):
raise NotImplementedError
@@ -288,6 +404,8 @@ async def health(self, timeout: int = 60):
assert timeout > 0
for _ in range(timeout):
if not self._inner_health():
+ self.error_log.seek(0)
+ print(self.error_log.read(), file=sys.stderr)
raise RuntimeError("Launcher crashed")
try:
@@ -295,12 +413,14 @@ async def health(self, timeout: int = 60):
return
except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
time.sleep(1)
+ self.error_log.seek(0)
+ print(self.error_log.read(), file=sys.stderr)
raise RuntimeError("Health check failed")
class ContainerLauncherHandle(LauncherHandle):
- def __init__(self, docker_client, container_name, port: int):
- super(ContainerLauncherHandle, self).__init__(port)
+ def __init__(self, docker_client, container_name, port: int, error_log):
+ super().__init__(port, error_log)
self.docker_client = docker_client
self.container_name = container_name
@@ -310,8 +430,8 @@ def _inner_health(self) -> bool:
class ProcessLauncherHandle(LauncherHandle):
- def __init__(self, process, port: int):
- super(ProcessLauncherHandle, self).__init__(port)
+ def __init__(self, process, port: int, error_log):
+ super().__init__(port, error_log)
self.process = process
def _inner_health(self) -> bool:
@@ -333,15 +453,14 @@ def ignore_logprob_response_snapshot(snapshot):
return snapshot.use_extension(IgnoreLogProbResponseComparator)
-@pytest.fixture(scope="module")
-def event_loop():
- loop = asyncio.get_event_loop()
- yield loop
- loop.close()
+@pytest.fixture(scope="session")
+def error_log():
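+ """Session-scoped temporary file that accumulates launcher and container output for failure reporting."""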
+ with tempfile.TemporaryFile("w+") as tmp:
+ yield tmp
-@pytest.fixture(scope="module")
-def launcher(event_loop):
+@pytest.fixture(scope="session")
+async def launcher(error_log):
@contextlib.contextmanager
def local_launcher(
model_id: str,
@@ -354,6 +473,7 @@ def local_launcher(
kv_cache_dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
+ max_input_tokens: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None,
@@ -402,6 +522,9 @@ def local_launcher(
if max_input_length:
args.append("--max-input-length")
args.append(str(max_input_length))
+ if max_input_tokens:
+ args.append("--max-input-tokens")
+ args.append(str(max_input_tokens))
if max_batch_prefill_tokens:
args.append("--max-batch-prefill-tokens")
args.append(str(max_batch_prefill_tokens))
@@ -425,22 +548,19 @@ def local_launcher(
if attention is not None:
env["ATTENTION"] = attention
- with tempfile.TemporaryFile("w+") as tmp:
+ # with tempfile.TemporaryFile("w+") as tmp:
# We'll output stdout/stderr to a temporary file. Using a pipe
# cause the process to block until stdout is read.
- with subprocess.Popen(
- args,
- stdout=tmp,
- stderr=subprocess.STDOUT,
- env=env,
- ) as process:
- yield ProcessLauncherHandle(process, port)
+ with subprocess.Popen(
+ args,
+ stdout=error_log,
+ stderr=subprocess.STDOUT,
+ env=env,
+ ) as process:
+ yield ProcessLauncherHandle(process, port, error_log=error_log)
- process.terminate()
- process.wait(60)
-
- tmp.seek(0)
- shutil.copyfileobj(tmp, sys.stderr)
+ process.terminate()
+ process.wait(60)
if not use_flash_attention:
del env["USE_FLASH_ATTENTION"]
@@ -558,6 +678,7 @@ def docker_launcher(
docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
]
+ client.api.timeout = 1000
container = client.containers.run(
DOCKER_IMAGE,
command=args,
@@ -569,12 +690,25 @@ def docker_launcher(
devices=devices,
volumes=volumes,
ports={"80/tcp": port},
- healthcheck={"timeout": int(60 * 1e9), "retries": 2}, # 60s
+ healthcheck={"timeout": int(180 * 1e9), "retries": 2}, # 60s
shm_size="1G",
)
+ def pipe():
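+ # Continuously copy the container's log stream into the shared error_log file.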
+ for log in container.logs(stream=True):
+ log = log.decode("utf-8")
+ error_log.write(log)
+
+ # Start looping to pipe the logs
+ import threading
+
+ t = threading.Thread(target=pipe, args=())
+ t.start()
+
try:
- yield ContainerLauncherHandle(client, container.name, port)
+ yield ContainerLauncherHandle(
+ client, container.name, port, error_log=error_log
+ )
if not use_flash_attention:
del env["USE_FLASH_ATTENTION"]
@@ -585,9 +719,6 @@ def docker_launcher(
except NotFound:
pass
- container_output = container.logs().decode("utf-8")
- print(container_output, file=sys.stderr)
-
finally:
try:
container.remove()
diff --git a/integration-tests/fixtures/gaudi/service.py b/integration-tests/fixtures/gaudi/service.py
new file mode 100644
index 00000000000..f4f43691cd9
--- /dev/null
+++ b/integration-tests/fixtures/gaudi/service.py
@@ -0,0 +1,311 @@
+import asyncio
+import contextlib
+import os
+import shlex
+import subprocess
+import sys
+import threading
+import time
+from tempfile import TemporaryDirectory
+from typing import List
+import socket
+
+import docker
+import pytest
+from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
+from docker.errors import NotFound
+import logging
+from huggingface_hub import AsyncInferenceClient, TextGenerationOutput
+import huggingface_hub
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
+ stream=sys.stdout,
+)
+logger = logging.getLogger(__file__)
+
+# Use the latest image from the local docker build
+DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "tgi-gaudi")
+DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", None)
+HF_TOKEN = huggingface_hub.get_token()
+
+assert (
+ HF_TOKEN is not None
+), "HF_TOKEN is not set, please set it as some models are gated and thus the test will fail without it"
+
+if DOCKER_VOLUME is None:
+ logger.warning(
+ "DOCKER_VOLUME is not set, this will lead to the tests redownloading the models on each run, consider setting it to speed up testing"
+ )
+
+LOG_LEVEL = os.getenv("LOG_LEVEL", "info")
+
+BASE_ENV = {
+ "HF_HUB_ENABLE_HF_TRANSFER": "1",
+ "LOG_LEVEL": LOG_LEVEL,
+ "HF_TOKEN": os.getenv("HF_TOKEN", None),
+}
+
+
+HABANA_RUN_ARGS = {
+ "runtime": "habana",
+ "ipc_mode": "host",
+ "cap_add": ["sys_nice"],
+}
+
+
+def stream_container_logs(container, test_name):
+ """Stream container logs in a separate thread."""
+ try:
+ for log in container.logs(stream=True, follow=True):
+ print(
+ f"[TGI Server Logs - {test_name}] {log.decode('utf-8')}",
+ end="",
+ file=sys.stderr,
+ flush=True,
+ )
+ except Exception as e:
+ logger.error(f"Error streaming container logs: {str(e)}")
+
+
+class TestClient(AsyncInferenceClient):
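+ """AsyncInferenceClient wrapper that remembers the name of the service it targets."""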
+ def __init__(self, service_name: str, base_url: str):
+ super().__init__(model=base_url)
+ self.service_name = service_name
+
+
+class LauncherHandle:
+ def __init__(self, service_name: str, port: int):
+ self.client = TestClient(service_name, f"http://localhost:{port}")
+
+ def _inner_health(self):
+ raise NotImplementedError
+
+ async def health(self, timeout: int = 60):
+ assert timeout > 0
+ start_time = time.time()
+ logger.info(f"Starting health check with timeout of {timeout}s")
+
+ for attempt in range(timeout):
+ if not self._inner_health():
+ logger.error("Launcher crashed during health check")
+ raise RuntimeError("Launcher crashed")
+
+ try:
+ await self.client.text_generation("test", max_new_tokens=1)
+ elapsed = time.time() - start_time
+ logger.info(f"Health check passed after {elapsed:.1f}s")
+ return
+ except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
+ if attempt == timeout - 1:
+ logger.error(f"Health check failed after {timeout}s: {str(e)}")
+ raise RuntimeError(f"Health check failed: {str(e)}")
+ if attempt % 10 == 0 and attempt != 0: # Only log every 10th attempt
+ logger.debug(
+ f"Connection attempt {attempt}/{timeout} failed: {str(e)}"
+ )
+ time.sleep(1)
+ except Exception as e:
+ logger.error(f"Unexpected error during health check: {str(e)}")
+ # Get full traceback for debugging
+ import traceback
+
+ logger.error(f"Full traceback:\n{traceback.format_exc()}")
+ raise
+
+
+class ContainerLauncherHandle(LauncherHandle):
+ def __init__(self, docker_client, container_name, port: int):
+ service_name = container_name # Use container name as service name
+ super(ContainerLauncherHandle, self).__init__(service_name, port)
+ self.docker_client = docker_client
+ self.container_name = container_name
+
+ def _inner_health(self) -> bool:
+ try:
+ container = self.docker_client.containers.get(self.container_name)
+ status = container.status
+ if status not in ["running", "created"]:
+ logger.warning(f"Container status is {status}")
+ # Get container logs for debugging
+ logs = container.logs().decode("utf-8")
+ logger.debug(f"Container logs:\n{logs}")
+ return status in ["running", "created"]
+ except Exception as e:
+ logger.error(f"Error checking container health: {str(e)}")
+ return False
+
+
+class ProcessLauncherHandle(LauncherHandle):
+ def __init__(self, process, port: int):
+ service_name = "process" # Use generic name for process launcher
+ super(ProcessLauncherHandle, self).__init__(service_name, port)
+ self.process = process
+
+ def _inner_health(self) -> bool:
+ return self.process.poll() is None
+
+
+@pytest.fixture(scope="module")
+def data_volume():
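+ """Module-scoped temporary directory intended for use as the container data volume; cleaned up with sudo since the container may leave root-owned files."""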
+ tmpdir = TemporaryDirectory()
+ yield tmpdir.name
+ try:
+ # Cleanup the temporary directory using sudo as it contains root files created by the container
+ subprocess.run(shlex.split(f"sudo rm -rf {tmpdir.name}"), check=True)
+ except subprocess.CalledProcessError as e:
+ logger.error(f"Error cleaning up temporary directory: {str(e)}")
+
+
+@pytest.fixture(scope="module")
+def gaudi_launcher():
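+ """Return a context manager that starts a TGI Gaudi container for the given model and test, streams its logs, and removes it afterwards."""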
+ @contextlib.contextmanager
+ def docker_launcher(
+ model_id: str,
+ test_name: str,
+ tgi_args: List[str] = None,
+ env_config: dict = None,
+ ):
+ logger.info(
+ f"Starting docker launcher for model {model_id} and test {test_name}"
+ )
+
+ # Get a random available port
+ def get_free_port():
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ s.listen(1)
+ port = s.getsockname()[1]
+ return port
+
+ port = get_free_port()
+ logger.debug(f"Using port {port}")
+
+ client = docker.from_env()
+
+ container_name = f"tgi-gaudi-test-{test_name.replace('/', '-')}"
+
+ try:
+ container = client.containers.get(container_name)
+ logger.info(
+ f"Stopping existing container {container_name} for test {test_name}"
+ )
+ container.stop()
+ container.wait()
+ container.remove()
+ logger.info(f"Removed existing container {container_name}")
+ except NotFound:
+ pass
+ except Exception as e:
+ logger.error(f"Error handling existing container: {str(e)}")
+
+ if tgi_args is None:
+ tgi_args = []
+ else:
+ tgi_args = tgi_args.copy()
+
+ env = BASE_ENV.copy()
+
+ # Add model_id to env
+ env["MODEL_ID"] = model_id
+
+ # Add env config that is defined in the fixture parameter
+ if env_config is not None:
+ env.update(env_config.copy())
+
+ volumes = []
+ if DOCKER_VOLUME:
+ volumes = [f"{DOCKER_VOLUME}:/data"]
+ logger.debug(f"Using volume {volumes}")
+
+ try:
+ logger.debug(f"Using command {tgi_args}")
+ logger.info(f"Creating container with name {container_name}")
+
+ logger.debug(f"Using environment {env}")
+ logger.debug(f"Using volumes {volumes}")
+ logger.debug(f"HABANA_RUN_ARGS {HABANA_RUN_ARGS}")
+
+ # The values logged above mirror the equivalent `docker run` invocation for debugging; the container itself is started via the Docker SDK below
+ container = client.containers.run(
+ DOCKER_IMAGE,
+ command=tgi_args,
+ name=container_name,
+ environment=env,
+ detach=True,
+ volumes=volumes,
+ ports={"80/tcp": port},
+ **HABANA_RUN_ARGS,
+ )
+
+ logger.info(f"Container {container_name} started successfully")
+
+ # Start log streaming in a background thread
+ log_thread = threading.Thread(
+ target=stream_container_logs,
+ args=(container, test_name),
+ daemon=True, # This ensures the thread will be killed when the main program exits
+ )
+ log_thread.start()
+
+ # Add a small delay to allow container to initialize
+ time.sleep(2)
+
+ # Check container status after creation
+ status = container.status
+ logger.debug(f"Initial container status: {status}")
+ if status not in ["running", "created"]:
+ logs = container.logs().decode("utf-8")
+ logger.error(f"Container failed to start properly. Logs:\n{logs}")
+
+ yield ContainerLauncherHandle(client, container.name, port)
+
+ except Exception as e:
+ logger.error(f"Error starting container: {str(e)}")
+ # Get full traceback for debugging
+ import traceback
+
+ logger.error(f"Full traceback:\n{traceback.format_exc()}")
+ raise
+ finally:
+ try:
+ container = client.containers.get(container_name)
+ logger.info(f"Stopping container {container_name}")
+ container.stop()
+ container.wait()
+
+ container_output = container.logs().decode("utf-8")
+ print(container_output, file=sys.stderr)
+
+ container.remove()
+ logger.info(f"Container {container_name} removed successfully")
+ except NotFound:
+ pass
+ except Exception as e:
+ logger.warning(f"Error cleaning up container: {str(e)}")
+
+ return docker_launcher
+
+
+@pytest.fixture(scope="module")
+def gaudi_generate_load():
+ async def generate_load_inner(
+ client: AsyncInferenceClient, prompt: str, max_new_tokens: int, n: int
+ ) -> List[TextGenerationOutput]:
+ try:
+ futures = [
+ client.text_generation(
+ prompt,
+ max_new_tokens=max_new_tokens,
+ details=True,
+ decoder_input_details=True,
+ )
+ for _ in range(n)
+ ]
+ return await asyncio.gather(*futures)
+ except Exception as e:
+ logger.error(f"Error generating load: {str(e)}")
+ raise
+
+ return generate_load_inner
diff --git a/integration-tests/fixtures/neuron/export_models.py b/integration-tests/fixtures/neuron/export_models.py
new file mode 100644
index 00000000000..beee2ba7127
--- /dev/null
+++ b/integration-tests/fixtures/neuron/export_models.py
@@ -0,0 +1,274 @@
+import copy
+import logging
+import sys
+from tempfile import TemporaryDirectory
+
+import huggingface_hub
+import pytest
+import docker
+import hashlib
+import os
+import tempfile
+
+from docker.errors import NotFound
+
+
+TEST_ORGANIZATION = "optimum-internal-testing"
+TEST_CACHE_REPO_ID = f"{TEST_ORGANIZATION}/neuron-testing-cache"
+HF_TOKEN = huggingface_hub.get_token()
+
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="[%(asctime)s] %(levelname)s [%(filename)s.%(funcName)s:%(lineno)d] %(message)s",
+ stream=sys.stdout,
+)
+logger = logging.getLogger(__file__)
+
+
+# All model configurations below will be added to the neuron_model_config fixture
+MODEL_CONFIGURATIONS = {
+ "llama": {
+ "model_id": "unsloth/Llama-3.2-1B-Instruct",
+ "export_kwargs": {
+ "batch_size": 4,
+ "sequence_length": 2048,
+ "num_cores": 2,
+ "auto_cast_type": "fp16",
+ },
+ },
+ "qwen2": {
+ "model_id": "Qwen/Qwen2.5-0.5B",
+ "export_kwargs": {
+ "batch_size": 4,
+ "sequence_length": 4096,
+ "num_cores": 2,
+ "auto_cast_type": "fp16",
+ },
+ },
+ "qwen3": {
+ "model_id": "Qwen/Qwen3-1.7B",
+ "export_kwargs": {
+ "batch_size": 4,
+ "sequence_length": 4096,
+ "num_cores": 2,
+ "auto_cast_type": "bf16",
+ },
+ },
+ "granite": {
+ "model_id": "ibm-granite/granite-3.1-2b-instruct",
+ "export_kwargs": {
+ "batch_size": 4,
+ "sequence_length": 4096,
+ "num_cores": 2,
+ "auto_cast_type": "bf16",
+ },
+ },
+ "phi3": {
+ "model_id": "microsoft/Phi-3-mini-4k-instruct",
+ "export_kwargs": {
+ "batch_size": 4,
+ "sequence_length": 4096,
+ "num_cores": 2,
+ "auto_cast_type": "bf16",
+ },
+ },
+}
+
+
+def get_neuron_backend_hash():
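+ """Return a short hash of the neuron backend sources (backends/neuron and Dockerfile.neuron) as tracked by git."""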
+ import subprocess
+
+ res = subprocess.run(
+ ["git", "rev-parse", "--show-toplevel"], capture_output=True, text=True
+ )
+ root_dir = res.stdout.split("\n")[0]
+
+ def get_sha(path):
+ res = subprocess.run(
+ ["git", "ls-tree", "HEAD", f"{root_dir}/{path}"],
+ capture_output=True,
+ text=True,
+ )
+ # Output of the command is in the form '040000 tree|blob <sha>\t<path>\n'
+ sha = res.stdout.split("\t")[0].split(" ")[-1]
+ return sha.encode()
+
+ # We hash both the neuron backends directory and Dockerfile and create a smaller hash out of that
+ m = hashlib.sha256()
+ m.update(get_sha("backends/neuron"))
+ m.update(get_sha("Dockerfile.neuron"))
+ return m.hexdigest()[:10]
+
+
+def get_neuron_model_name(config_name: str):
+ return f"neuron-tgi-testing-{config_name}-{get_neuron_backend_hash()}"
+
+
+def get_tgi_docker_image():
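+ """Return the TGI docker image to test: $DOCKER_IMAGE if set, otherwise the first local text-generation-inference image found."""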
+ docker_image = os.getenv("DOCKER_IMAGE", None)
+ if docker_image is None:
+ client = docker.from_env()
+ images = client.images.list(filters={"reference": "text-generation-inference"})
+ if not images:
+ raise ValueError(
+ "No text-generation-inference image found on this host to run tests."
+ )
+ docker_image = images[0].tags[0]
+ return docker_image
+
+
+def maybe_export_model(config_name, model_config):
+ """Export a neuron model for the specified test configuration.
+
+ If the neuron model has not already been compiled and pushed to the hub, it is
+ exported by a custom image built on the fly from the base TGI image.
+ This makes sure the exported model and image are aligned and avoids introducing
+ neuron-specific imports in the test suite.
+
+ Args:
+ config_name (`str`):
+ Used to identify test configurations
+ model_config (`dict`):
+ The model configuration for export (includes the original model id)
+ """
+ neuron_model_name = get_neuron_model_name(config_name)
+ neuron_model_id = f"{TEST_ORGANIZATION}/{neuron_model_name}"
+ hub = huggingface_hub.HfApi()
+ if hub.repo_exists(neuron_model_id):
+ logger.info(
+ f"Skipping model export for config {config_name} as {neuron_model_id} already exists"
+ )
+ return neuron_model_id
+
+ client = docker.from_env()
+
+ env = {"LOG_LEVEL": "info", "CUSTOM_CACHE_REPO": TEST_CACHE_REPO_ID}
+ if HF_TOKEN is not None:
+ env["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
+ env["HF_TOKEN"] = HF_TOKEN
+
+ # Create a sub-image to export the model, working around docker-in-docker issues that prevent
+ # sharing a volume from the container running the tests
+ model_id = model_config["model_id"]
+ export_kwargs = model_config["export_kwargs"]
+ base_image = get_tgi_docker_image()
+ export_image = f"neuron-tgi-tests-{config_name}-export-img"
+ logger.info(f"Building temporary image {export_image} from {base_image}")
+ with tempfile.TemporaryDirectory() as context_dir:
+ # Create entrypoint
+ model_path = "/data/neuron_model"
+ export_command = (
+ f"optimum-cli export neuron -m {model_id} --task text-generation"
+ )
+ for kwarg, value in export_kwargs.items():
+ export_command += f" --{kwarg} {str(value)}"
+ export_command += f" {model_path}"
+ entrypoint_content = f"""#!/bin/sh
+ {export_command}
+ huggingface-cli repo create --organization {TEST_ORGANIZATION} {neuron_model_name}
+ huggingface-cli upload {TEST_ORGANIZATION}/{neuron_model_name} {model_path} --exclude *.bin *.safetensors
+ optimum-cli neuron cache synchronize --repo_id {TEST_CACHE_REPO_ID}
+ """
+ with open(os.path.join(context_dir, "entrypoint.sh"), "wb") as f:
+ f.write(entrypoint_content.encode("utf-8"))
+ f.flush()
+ # Create Dockerfile
+ docker_content = f"""
+ FROM {base_image}
+ COPY entrypoint.sh /export-entrypoint.sh
+ RUN chmod +x /export-entrypoint.sh
+ ENTRYPOINT ["/export-entrypoint.sh"]
+ """
+ with open(os.path.join(context_dir, "Dockerfile"), "wb") as f:
+ f.write(docker_content.encode("utf-8"))
+ f.flush()
+ image, logs = client.images.build(
+ path=context_dir, dockerfile=f.name, tag=export_image
+ )
+ logger.info("Successfully built image %s", image.id)
+ logger.debug("Build logs %s", logs)
+
+ try:
+ client.containers.run(
+ export_image,
+ environment=env,
+ auto_remove=True,
+ detach=False,
+ devices=["/dev/neuron0"],
+ shm_size="1G",
+ )
+ logger.info(f"Successfully exported model for config {config_name}")
+ except Exception as e:
+ logger.exception(f"An exception occurred while running container: {e}.")
+ pass
+ finally:
+ # Cleanup the export image
+ logger.info("Cleaning image %s", image.id)
+ try:
+ image.remove(force=True)
+ except NotFound:
+ pass
+ except Exception as e:
+ logger.error("Error while removing image %s, skipping", image.id)
+ logger.exception(e)
+ return neuron_model_id
+
+
+def maybe_export_models():
+ for config_name, model_config in MODEL_CONFIGURATIONS.items():
+ maybe_export_model(config_name, model_config)
+
+
+@pytest.fixture(scope="session", params=MODEL_CONFIGURATIONS.keys())
+def neuron_model_config(request):
+ """Expose a pre-trained neuron model
+
+ The fixture first makes sure the following model artifacts are present on the hub:
+ - exported neuron model under optimum-internal-testing/neuron-testing-