diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
deleted file mode 100644
index 124e6a33ee0..00000000000
--- a/.github/workflows/build.yaml
+++ /dev/null
@@ -1,246 +0,0 @@
-name: Build and push docker image to internal registry
-
-on:
- workflow_dispatch:
- push:
- branches:
- - 'main'
- tags:
- - 'v*'
- pull_request:
- paths:
- - ".github/workflows/build.yaml"
- - "integration-tests/**"
- - "server/**"
- - "proto/**"
- - "router/**"
- - "launcher/**"
- - "Cargo.lock"
- - "rust-toolchain.toml"
- - "Dockerfile"
- branches:
- - 'main'
-
-jobs:
- start-runner:
- name: Start self-hosted EC2 runner
- runs-on: ubuntu-latest
- env:
- AWS_REGION: us-east-1
- EC2_AMI_ID: ami-03cfed9ea28f4b002
- EC2_INSTANCE_TYPE: g5.12xlarge
- EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
- EC2_SECURITY_GROUP: sg-030175c435ac141d6
- outputs:
- label: ${{ steps.start-ec2-runner.outputs.label }}
- ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
- steps:
- - name: Configure AWS credentials
- uses: aws-actions/configure-aws-credentials@v1
- with:
- aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
- aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
- aws-region: ${{ env.AWS_REGION }}
- - name: Start EC2 runner
- id: start-ec2-runner
- uses: philschmid/philschmid-ec2-github-runner@main
- with:
- mode: start
- github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
- ec2-image-id: ${{ env.EC2_AMI_ID }}
- ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
- subnet-id: ${{ env.EC2_SUBNET_ID }}
- security-group-id: ${{ env.EC2_SECURITY_GROUP }}
- aws-resource-tags: > # optional, requires additional permissions
- [
- {"Key": "Name", "Value": "ec2-tgi-github-runner"},
- {"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
- ]
-
- build-and-push-image:
- concurrency:
- group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
- cancel-in-progress: true
- needs: start-runner # required to start the main job when the runner is ready
- runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
- permissions:
- contents: write
- packages: write
- # This is used to complete the identity challenge
- # with sigstore/fulcio when running outside of PRs.
- id-token: write
- security-events: write
- steps:
- - name: Checkout repository
- uses: actions/checkout@v3
- - name: Initialize Docker Buildx
- uses: docker/setup-buildx-action@v2.0.0
- with:
- install: true
- - name: Inject slug/short variables
- uses: rlespinasse/github-slug-action@v4.4.1
- - name: Install cosign
- if: github.event_name != 'pull_request'
- uses: sigstore/cosign-installer@f3c664df7af409cb4873aa5068053ba9d61a57b6 #v2.6.0
- with:
- cosign-release: 'v1.13.1'
- - name: Tailscale
- uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
- with:
- authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
- - name: Login to GitHub Container Registry
- if: github.event_name != 'pull_request'
- uses: docker/login-action@v2
- with:
- registry: ghcr.io
- username: ${{ github.actor }}
- password: ${{ secrets.GITHUB_TOKEN }}
- - name: Login to internal Container Registry
- uses: docker/login-action@v2.1.0
- with:
- username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }}
- password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
- registry: registry.internal.huggingface.tech
- - name: Login to Azure Container Registry
- if: github.event_name != 'pull_request'
- uses: docker/login-action@v2.1.0
- with:
- username: ${{ secrets.AZURE_DOCKER_USERNAME }}
- password: ${{ secrets.AZURE_DOCKER_PASSWORD }}
- registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io
- # If pull request
- - name: Extract metadata (tags, labels) for Docker
- if: ${{ github.event_name == 'pull_request' }}
- id: meta-pr
- uses: docker/metadata-action@v4.3.0
- with:
- images: |
- registry.internal.huggingface.tech/api-inference/community/text-generation-inference
- tags: |
- type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}
- # If main, release or tag
- - name: Extract metadata (tags, labels) for Docker
- if: ${{ github.event_name != 'pull_request' }}
- id: meta
- uses: docker/metadata-action@v4.3.0
- with:
- flavor: |
- latest=auto
- images: |
- registry.internal.huggingface.tech/api-inference/community/text-generation-inference
- ghcr.io/huggingface/text-generation-inference
- db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
- tags: |
- type=semver,pattern={{version}}
- type=semver,pattern={{major}}.{{minor}}
- type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
- type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}
- - name: Build and push Docker image
- id: build-and-push
- uses: docker/build-push-action@v4
- with:
- context: .
- file: Dockerfile
- push: true
- platforms: 'linux/amd64'
- build-args: |
- GIT_SHA=${{ env.GITHUB_SHA }}
- DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}
- tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
- labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
- cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min
- cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min
- # Sign the resulting Docker image digest except on PRs.
- # This will only write to the public Rekor transparency log when the Docker
- # repository is public to avoid leaking data.
- - name: Sign the published Docker image
- if: ${{ github.event_name != 'pull_request' }}
- env:
- COSIGN_EXPERIMENTAL: "true"
- # This step uses the identity token to provision an ephemeral certificate
- # against the sigstore community Fulcio instance.
- run: echo "${{ steps.meta.outputs.tags }}" | xargs -I {} cosign sign {}@${{ steps.build-and-push.outputs.digest }}
- - name: Run Trivy in GitHub SBOM mode and submit results to Dependency Graph
- uses: aquasecurity/trivy-action@master
- if: ${{ github.event_name != 'pull_request' }}
- with:
- image-ref: 'ghcr.io/huggingface/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}'
- format: 'github'
- output: 'dependency-results.sbom.json'
- github-pat: ${{ secrets.GITHUB_TOKEN }}
- scanners: 'vuln'
- - name: Run Trivy vulnerability scanner
- uses: aquasecurity/trivy-action@master
- if: ${{ github.event_name != 'pull_request' }}
- with:
- image-ref: 'ghcr.io/huggingface/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}'
- format: 'sarif'
- output: 'trivy-results.sarif'
- severity: 'CRITICAL'
- scanners: 'vuln'
- - name: Upload Trivy scan results to GitHub Security tab
- uses: github/codeql-action/upload-sarif@v2
- if: ${{ github.event_name != 'pull_request' }}
- with:
- sarif_file: 'trivy-results.sarif'
-
- integration-tests:
- concurrency:
- group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
- cancel-in-progress: true
- needs:
- - start-runner
- - build-and-push-image # Wait for the docker image to be built
- runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
- env:
- DOCKER_VOLUME: /cache
- steps:
- - uses: actions/checkout@v2
- - name: Inject slug/short variables
- uses: rlespinasse/github-slug-action@v4.4.1
- - name: Set up Python
- uses: actions/setup-python@v4
- with:
- python-version: 3.9
- - name: Tailscale
- uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
- with:
- authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
- - name: Prepare disks
- run: |
- sudo mkfs -t ext4 /dev/nvme1n1
- sudo mkdir ${{ env.DOCKER_VOLUME }}
- sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
- - name: Install
- run: |
- make install-integration-tests
- - name: Run tests
- run: |
- export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
- export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
- pytest -s -vv integration-tests
-
- stop-runner:
- name: Stop self-hosted EC2 runner
- needs:
- - start-runner
- - build-and-push-image
- - integration-tests
- runs-on: ubuntu-latest
- env:
- AWS_REGION: us-east-1
- if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
- steps:
- - name: Configure AWS credentials
- uses: aws-actions/configure-aws-credentials@v1
- with:
- aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
- aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
- aws-region: ${{ env.AWS_REGION }}
- - name: Stop EC2 runner
- uses: philschmid/philschmid-ec2-github-runner@main
- with:
- mode: stop
- github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
- label: ${{ needs.start-runner.outputs.label }}
- ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
diff --git a/.github/workflows/client-tests.yaml b/.github/workflows/client-tests.yaml
deleted file mode 100644
index 1fa0b39d7db..00000000000
--- a/.github/workflows/client-tests.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-name: Python Client Tests
-
-on:
- pull_request:
- paths:
- - ".github/workflows/client-tests.yaml"
- - "clients/python/**"
-
-jobs:
- run_tests:
- runs-on: ubuntu-latest
-
- steps:
- - uses: actions/checkout@v2
- - name: Set up Python
- uses: actions/setup-python@v1
- with:
- python-version: 3.9
- - name: Install
- run: |
- cd clients/python && pip install .
- - name: Run tests
- run: |
- pip install pytest pytest-asyncio
- make python-client-tests
diff --git a/.github/workflows/free_disk_space.sh b/.github/workflows/free_disk_space.sh
new file mode 100644
index 00000000000..416884b6a4b
--- /dev/null
+++ b/.github/workflows/free_disk_space.sh
@@ -0,0 +1,49 @@
+#!/usr/bin/env bash
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+#
+# The Azure provided machines typically have the following disk allocation:
+# Total space: 85GB
+# Allocated: 67 GB
+# Free: 17 GB
+# This script frees up 28 GB of disk space by deleting unneeded packages and
+# large directories.
+# The Flink end to end tests download and generate more than 17 GB of files,
+# causing unpredictable behavior and build failures.
+#
+echo "=============================================================================="
+echo "Freeing up disk space on CI system"
+echo "=============================================================================="
+
+echo "Listing 100 largest packages"
+dpkg-query -Wf '${Installed-Size}\t${Package}\n' | sort -n | tail -n 100
+df -h
+echo "Removing large packages"
+sudo apt-get remove -y '^ghc-8.*'
+sudo apt-get remove -y '^dotnet-.*'
+sudo apt-get remove -y '^llvm-.*'
+sudo apt-get remove -y 'php.*'
+sudo apt-get remove -y azure-cli google-cloud-sdk hhvm google-chrome-stable firefox microsoft-edge-stable powershell mono-devel
+sudo apt-get remove -y '^gcc-.*'
+sudo apt-get remove -y '^g++-.*'
+sudo apt-get autoremove -y
+sudo apt-get clean
+df -h
+echo "Removing large directories"
+# deleting 15GB
+rm -rf /usr/share/dotnet/
+df -h
\ No newline at end of file
diff --git a/.github/workflows/load_test.yaml b/.github/workflows/load_test.yaml
deleted file mode 100644
index fd22e395780..00000000000
--- a/.github/workflows/load_test.yaml
+++ /dev/null
@@ -1,108 +0,0 @@
-name: Nightly load test
-
-on:
- schedule:
- - cron: '0 0 * * 1-5'
-
- pull_request:
- paths:
- - ".github/workflows/load_test.yaml"
- branches:
- - 'main'
-
-jobs:
- start-runner:
- name: Start self-hosted EC2 runner
- runs-on: ubuntu-latest
- env:
- AWS_REGION: eu-central-1
- EC2_AMI_ID: ami-0ab09c07cfd194259
- EC2_INSTANCE_TYPE: g5.12xlarge
- EC2_SUBNET_ID: subnet-988fd9f2,subnet-6f56db13,subnet-6a039326
- EC2_SECURITY_GROUP: sg-072f92ae3082936c6
- outputs:
- label: ${{ steps.start-ec2-runner.outputs.label }}
- ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
- steps:
- - name: Configure AWS credentials
- uses: aws-actions/configure-aws-credentials@v1
- with:
- aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
- aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
- aws-region: ${{ env.AWS_REGION }}
- - name: Start EC2 runner
- id: start-ec2-runner
- uses: philschmid/philschmid-ec2-github-runner@main
- with:
- mode: start
- github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
- ec2-image-id: ${{ env.EC2_AMI_ID }}
- ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
- subnet-id: ${{ env.EC2_SUBNET_ID }}
- security-group-id: ${{ env.EC2_SECURITY_GROUP }}
- aws-resource-tags: > # optional, requires additional permissions
- [
- {"Key": "Name", "Value": "ec2-tgi-github-runner"},
- {"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
- ]
-
- load-tests:
- concurrency:
- group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
- cancel-in-progress: true
- needs: start-runner # required to start the main job when the runner is ready
- runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
- env:
- DOCKER_VOLUME: /cache
- steps:
- - name: Checkout repository
- uses: actions/checkout@v3
-
- - name: Prepare disks
- run: |
- sudo mkfs -t ext4 /dev/nvme1n1
- sudo mkdir ${{ env.DOCKER_VOLUME }}
- sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
-
- - name: Install k6
- run: |
- curl https://github.com/grafana/k6/releases/download/v0.44.0/k6-v0.44.0-linux-amd64.tar.gz -L | tar xvz --strip-components 1
-
- - name: Start starcoder
- run: |
- docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v ${{ env.DOCKER_VOLUME }}:/data -e HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
- sleep 10
- wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health
-
- - name: Run k6
- run: |
- ./k6 run load_tests/starcoder_load.js
-
- - name: Stop starcoder
- if: ${{ always() }}
- run: |
- docker stop tgi-starcoder || true
-
- stop-runner:
- name: Stop self-hosted EC2 runner
- needs:
- - start-runner
- - load-tests
- runs-on: ubuntu-latest
- env:
- AWS_REGION: eu-central-1
- if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
- steps:
- - name: Configure AWS credentials
- uses: aws-actions/configure-aws-credentials@v1
- with:
- aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
- aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
- aws-region: ${{ env.AWS_REGION }}
- - name: Stop EC2 runner
- uses: philschmid/philschmid-ec2-github-runner@main
- with:
- mode: stop
- github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
- label: ${{ needs.start-runner.outputs.label }}
- ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
diff --git a/.github/workflows/setup_conda.yml b/.github/workflows/setup_conda.yml
new file mode 100644
index 00000000000..816b9fd6db1
--- /dev/null
+++ b/.github/workflows/setup_conda.yml
@@ -0,0 +1,27 @@
+name: Test Conda Setup
+
+on: [push, pull_request]
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v2
+
+ - name: Set up Conda
+ uses: conda-incubator/setup-miniconda@v2
+ with:
+ auto-activate-base: true
+ activate-environment: ""
+ auto-update-conda: true
+
+ - name: Free up disk space
+ run: |
+ bash .github/workflows/free_disk_space.sh
+
+ - name: Run Conda Server Setup
+ shell: bash -l {0}
+ run: |
+ bash ./setup_scripts/conda_server.sh --light-mode --no-tests
\ No newline at end of file
diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
deleted file mode 100644
index 7e5ba52cb5e..00000000000
--- a/.github/workflows/tests.yaml
+++ /dev/null
@@ -1,82 +0,0 @@
-name: Server Tests
-
-on:
- pull_request:
- paths:
- - ".github/workflows/tests.yaml"
- - "server/**"
- - "proto/**"
- - "router/**"
- - "launcher/**"
- - "Cargo.lock"
- - "rust-toolchain.toml"
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
- cancel-in-progress: true
-
-jobs:
- run_tests:
- runs-on: ubuntu-latest
-
- env:
- SCCACHE_GHA_ENABLED: "on"
- RUSTC_WRAPPER: /usr/local/bin/sccache
- SCCACHE: 0.3.3
-
- steps:
- - uses: actions/checkout@v2
- - name: Set up Python
- uses: actions/setup-python@v1
- with:
- python-version: 3.9
- - name: Install Rust
- uses: actions-rs/toolchain@v1
- with:
- toolchain: 1.65.0
- override: true
- components: rustfmt, clippy
- - name: Install Protoc
- uses: arduino/setup-protoc@v1
- - name: Install sccache
- run: |
- curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache
- chmod +x /usr/local/bin/sccache
- - name: configure sccache
- uses: actions/github-script@v6
- with:
- script: |
- core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
- core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
- core.exportVariable('SCCACHE_GHA_CACHE_TO', 'sccache-${{runner.os}}-${{github.ref_name}}');
- core.exportVariable('SCCACHE_GHA_CACHE_FROM', 'sccache-${{runner.os}}-main,sccache-${{runner.os}}-');
- - name: cargo registry cache
- uses: actions/cache@v3
- with:
- key: cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-${{ github.sha }}
- restore-keys: |
- cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-
- cargo-${{ runner.os }}-
- path: |
- ~/.cargo/registry
- ~/.cargo/git
- - name: Install
- run: |
- make install
- - name: Run server tests
- run: |
- pip install pytest
- export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
- pytest -s -vv server/tests
- - name: Run Rust fmt
- run: |
- cargo fmt --check
- - name: Run Rust clippy
- run: |
- cargo clippy
- - name: Run Rust tests
- run: |
- cargo test
- - name: sccache stats
- run: |
- /usr/local/bin/sccache --show-stats
diff --git a/.gitignore b/.gitignore
index 20c9baee226..de17588f0a1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
.idea
target
router/tokenizer.json
-*__pycache__*
+.openssl
+*__pycache__*
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000000..ae1eeb3c438
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "chat-ui"]
+ path = chat-ui
+ url = git@github.com:CoderPat/chat-ui.git
diff --git a/Cargo.lock b/Cargo.lock
index 8984ea6ad4b..c2866b114fe 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -57,6 +57,21 @@ dependencies = [
"memchr",
]
+[[package]]
+name = "android-tzdata"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
+
+[[package]]
+name = "android_system_properties"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
+dependencies = [
+ "libc",
+]
+
[[package]]
name = "anstream"
version = "0.3.2"
@@ -398,6 +413,21 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
+[[package]]
+name = "chrono"
+version = "0.4.26"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ec837a71355b28f6556dbd569b37b3f363091c0bd4b2e735674521b4c5fd9bc5"
+dependencies = [
+ "android-tzdata",
+ "iana-time-zone",
+ "js-sys",
+ "num-traits",
+ "time 0.1.45",
+ "wasm-bindgen",
+ "winapi",
+]
+
[[package]]
name = "cipher"
version = "0.4.4"
@@ -991,7 +1021,7 @@ dependencies = [
"cfg-if",
"js-sys",
"libc",
- "wasi",
+ "wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
@@ -1057,6 +1087,31 @@ version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
+[[package]]
+name = "headers"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f3e372db8e5c0d213e0cd0b9be18be2aca3d44cf2fe30a9d46a65581cd454584"
+dependencies = [
+ "base64 0.13.1",
+ "bitflags 1.3.2",
+ "bytes",
+ "headers-core",
+ "http",
+ "httpdate",
+ "mime",
+ "sha1",
+]
+
+[[package]]
+name = "headers-core"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429"
+dependencies = [
+ "http",
+]
+
[[package]]
name = "heck"
version = "0.4.1"
@@ -1178,6 +1233,29 @@ dependencies = [
"tokio-native-tls",
]
+[[package]]
+name = "iana-time-zone"
+version = "0.1.57"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2fad5b825842d2b38bd206f3e81d6957625fd7f0a361e345c30e01a0ae2dd613"
+dependencies = [
+ "android_system_properties",
+ "core-foundation-sys",
+ "iana-time-zone-haiku",
+ "js-sys",
+ "wasm-bindgen",
+ "windows",
+]
+
+[[package]]
+name = "iana-time-zone-haiku"
+version = "0.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
+dependencies = [
+ "cc",
+]
+
[[package]]
name = "ident_case"
version = "1.0.1"
@@ -1546,7 +1624,7 @@ checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2"
dependencies = [
"libc",
"log",
- "wasi",
+ "wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.48.0",
]
@@ -1571,6 +1649,24 @@ dependencies = [
"syn 2.0.25",
]
+[[package]]
+name = "multer"
+version = "2.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "01acbdc23469fd8fe07ab135923371d5f5a422fbf9c522158677c8eb15bc51c2"
+dependencies = [
+ "bytes",
+ "encoding_rs",
+ "futures-util",
+ "http",
+ "httparse",
+ "log",
+ "memchr",
+ "mime",
+ "spin 0.9.8",
+ "version_check",
+]
+
[[package]]
name = "multimap"
version = "0.8.3"
@@ -1905,7 +2001,7 @@ dependencies = [
"once_cell",
"pin-project-lite",
"thiserror",
- "urlencoding",
+ "urlencoding 2.1.2",
]
[[package]]
@@ -2195,7 +2291,7 @@ dependencies = [
"mach2",
"once_cell",
"raw-cpuid",
- "wasi",
+ "wasi 0.11.0+wasi-snapshot-preview1",
"web-sys",
"winapi",
]
@@ -2547,6 +2643,12 @@ dependencies = [
"windows-sys 0.48.0",
]
+[[package]]
+name = "scoped-tls"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294"
+
[[package]]
name = "scopeguard"
version = "1.1.0"
@@ -2893,7 +2995,7 @@ dependencies = [
[[package]]
name = "text-generation-benchmark"
-version = "0.9.4"
+version = "0.9.5"
dependencies = [
"average",
"clap",
@@ -2911,9 +3013,25 @@ dependencies = [
"tracing-subscriber",
]
+[[package]]
+name = "text-generation-central"
+version = "0.9.5"
+dependencies = [
+ "bytes",
+ "chrono",
+ "clap",
+ "reqwest",
+ "serde",
+ "serde_json",
+ "tokio",
+ "urlencoding 1.3.3",
+ "warp",
+ "whoami",
+]
+
[[package]]
name = "text-generation-client"
-version = "0.9.4"
+version = "0.9.5"
dependencies = [
"futures",
"grpc-metadata",
@@ -2929,7 +3047,7 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
-version = "0.9.4"
+version = "0.9.5"
dependencies = [
"clap",
"ctrlc",
@@ -2940,12 +3058,14 @@ dependencies = [
"serde_json",
"tracing",
"tracing-subscriber",
+ "urlencoding 1.3.3",
"vergen",
+ "whoami",
]
[[package]]
name = "text-generation-router"
-version = "0.9.4"
+version = "0.9.5"
dependencies = [
"async-stream",
"axum",
@@ -3006,6 +3126,17 @@ dependencies = [
"once_cell",
]
+[[package]]
+name = "time"
+version = "0.1.45"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a"
+dependencies = [
+ "libc",
+ "wasi 0.10.0+wasi-snapshot-preview1",
+ "winapi",
+]
+
[[package]]
name = "time"
version = "0.3.23"
@@ -3159,6 +3290,18 @@ dependencies = [
"tokio",
]
+[[package]]
+name = "tokio-tungstenite"
+version = "0.18.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "54319c93411147bced34cb5609a80e0a8e44c5999c93903a81cd866630ec0bfd"
+dependencies = [
+ "futures-util",
+ "log",
+ "tokio",
+ "tungstenite",
+]
+
[[package]]
name = "tokio-util"
version = "0.7.8"
@@ -3436,6 +3579,25 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
+[[package]]
+name = "tungstenite"
+version = "0.18.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "30ee6ab729cd4cf0fd55218530c4522ed30b7b6081752839b68fcec8d0960788"
+dependencies = [
+ "base64 0.13.1",
+ "byteorder",
+ "bytes",
+ "http",
+ "httparse",
+ "log",
+ "rand",
+ "sha1",
+ "thiserror",
+ "url",
+ "utf-8",
+]
+
[[package]]
name = "typenum"
version = "1.16.0"
@@ -3516,12 +3678,24 @@ dependencies = [
"percent-encoding",
]
+[[package]]
+name = "urlencoding"
+version = "1.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a1f0175e03a0973cf4afd476bef05c26e228520400eb1fd473ad417b1c00ffb"
+
[[package]]
name = "urlencoding"
version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9"
+[[package]]
+name = "utf-8"
+version = "0.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
+
[[package]]
name = "utf8parse"
version = "0.2.1"
@@ -3591,7 +3765,7 @@ dependencies = [
"rustc_version",
"rustversion",
"sysinfo",
- "time",
+ "time 0.3.23",
]
[[package]]
@@ -3619,6 +3793,43 @@ dependencies = [
"try-lock",
]
+[[package]]
+name = "warp"
+version = "0.3.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ba431ef570df1287f7f8b07e376491ad54f84d26ac473489427231e1718e1f69"
+dependencies = [
+ "bytes",
+ "futures-channel",
+ "futures-util",
+ "headers",
+ "http",
+ "hyper",
+ "log",
+ "mime",
+ "mime_guess",
+ "multer",
+ "percent-encoding",
+ "pin-project",
+ "rustls-pemfile",
+ "scoped-tls",
+ "serde",
+ "serde_json",
+ "serde_urlencoded",
+ "tokio",
+ "tokio-stream",
+ "tokio-tungstenite",
+ "tokio-util",
+ "tower-service",
+ "tracing",
+]
+
+[[package]]
+name = "wasi"
+version = "0.10.0+wasi-snapshot-preview1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f"
+
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
@@ -3722,6 +3933,16 @@ dependencies = [
"once_cell",
]
+[[package]]
+name = "whoami"
+version = "1.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "22fc3756b8a9133049b26c7f61ab35416c130e8c09b660f5b3958b446f52cc50"
+dependencies = [
+ "wasm-bindgen",
+ "web-sys",
+]
+
[[package]]
name = "winapi"
version = "0.3.9"
@@ -3753,6 +3974,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
+[[package]]
+name = "windows"
+version = "0.48.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f"
+dependencies = [
+ "windows-targets 0.48.1",
+]
+
[[package]]
name = "windows-sys"
version = "0.45.0"
@@ -3919,7 +4149,7 @@ dependencies = [
"hmac",
"pbkdf2",
"sha1",
- "time",
+ "time 0.3.23",
"zstd",
]
diff --git a/Cargo.toml b/Cargo.toml
index 3bfe9831b9e..886bb06c188 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,14 +4,15 @@ members = [
"router",
"router/client",
"router/grpc-metadata",
+ "central",
"launcher"
]
[workspace.package]
-version = "0.9.4"
+version = "0.9.5"
edition = "2021"
-authors = ["Olivier Dehaene"]
-homepage = "https://github.com/huggingface/text-generation-inference"
+authors = ["Patrick Fernandes"]
+homepage = "https://github.com/coderpat/text-generation-inference"
[profile.release]
debug = 1
diff --git a/Makefile b/Makefile
index 7f534c7ccd7..458508d66f7 100644
--- a/Makefile
+++ b/Makefile
@@ -14,10 +14,13 @@ install-router:
install-launcher:
cd launcher && cargo install --path .
+install-central:
+ cd central && cargo install --path .
+
install-benchmark:
cd benchmark && cargo install --path .
-install: install-server install-router install-launcher install-custom-kernels
+install: install-server install-router install-launcher install-central install-custom-kernels
server-dev:
cd server && make run-dev
@@ -42,6 +45,27 @@ python-client-tests:
python-tests: python-server-tests python-client-tests
+run-llama2-benchmark:
+ text-generation-launcher --model-id lmsys/vicuna-7b-v1.5
+
+run-llama2-vicuna-7b:
+ text-generation-launcher --model-id lmsys/vicuna-7b-v1.5 --port 8080
+
+run-llama2-vicuna-7b-quantize:
+ text-generation-launcher --model-id lmsys/vicuna-7b-v1.5 --port 8080 --quantize bitsandbytes
+
+run-llama2-vicuna-13b:
+ text-generation-launcher --model-id lmsys/vicuna-13b-v1.5 --port 8080
+
+run-llama2-vicuna-13b-quantize:
+ text-generation-launcher --model-id lmsys/vicuna-13b-v1.5 --port 8080 --quantize bitsandbytes
+
+run-llama2-vicuna-33b-quantize:
+ text-generation-launcher --model-id lmsys/vicuna-33b-v1.3 --port 8080 --quantize bitsandbytes
+
+run-llama2-70b-instruct-quantize:
+ text-generation-launcher --model-id upstage/Llama-2-70b-instruct-v2 --port 8080 --quantize bitsandbytes
+
run-falcon-7b-instruct:
text-generation-launcher --model-id tiiuae/falcon-7b-instruct --port 8080
diff --git a/README.md b/README.md
index 2bbb6583788..9ef56cb6be9 100644
--- a/README.md
+++ b/README.md
@@ -1,22 +1,124 @@
-
+# LTI's **Text Generation Inference** Fork
-# Text Generation Inference
-
-
-
-
-
-
-
-
-
-
-A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
-to power LLMs api-inference widgets.
+A Rust, Python and gRPC server for text generation inference.
+
+Forked from [HuggingFace](https://huggingface.co)'s [Text Generation Inference](https://github.com/huggingface/text-generation-inference/) project (prior to its re-licensing), it is commercial-friendly and licensed under the Apache 2.0 License.
+
+## *A note on this fork*
+
+This fork was created for two main reasons:
+1. Primarily, it allows us faster iteration and more flexibility, which is essential for our research use. It also gives us more control over development and documentation, which is crucial for our in-house use at CMU.
+2. While we understand the reasons behind the re-licensing, we don't want our (research) contributions to be locked behind a restrictive license. This fork will not sync with the upstream repository and will be updated independently.
+
+*For contributors*: If HuggingFace's upstream has a feature that you want to use, please open an issue first and discuss porting the functionality independently.
+Do not just copy the code over, as it will be rejected.
+
+## *For LTI/cluster users*
+
+### Getting started
+
+If you are new to this library and it is already being used in your cluster, we recommend starting with a *client-only* installation and using models launched by other users.
+
+To start, the `TGI_CENTRAL_ADDRESS` environment variable needs to be set so that the client knows which servers to connect to. For example, in the LTI cluster, run
+
+```shell
+echo "export TGI_CENTRAL_ADDRESS=babel-3-36:8765" >> ~/.bashrc # if using a single machine, use `0.0.0.0:8765` instead
+source ~/.bashrc
+```
+
+To use the python client, install it with
+
+```shell
+cd clients/python
+pip install .
+```
+
+You can then query the API to list the models available in your cluster, and use models for inference.
+
+```python
+from text_generation import Client
+
+# get current models and pick the first one
+models = Client.list_from_central()
+model_name, model_addr = models[0]["name"], models[0]["address"]
+print(f"Using model {model_name} at {model_addr}")
+
+client = Client("http://" + model_addr)
+print(client.generate("What is Deep Learning?", max_new_tokens=20).generated_text)
+```
+
+#### Updating the environment
+
+In general, you don't have to recreate the environment every time you want to update the library.
+To update just the library, run the following in the base directory (inside a previously created environment):
+
+```shell
+export DIR=`pwd`
+OPENSSL_DIR=${DIR}/.openssl \
+OPENSSL_LIB_DIR=${DIR}/.openssl/lib \
+OPENSSL_INCLUDE_DIR=${DIR}/.openssl/include \
+BUILD_EXTENSIONS=false \
+ make install
+```
+
+### Running your own servers
+
+If you are an LTI student using one of its clusters (or you generally belong to an academic cluster that doesn't have Docker installed), you can sidestep problems with installing system dependencies by using the [(mini)conda](https://docs.conda.io/en/latest/miniconda.html) package manager.
+
+Then, ***from your base environment***, run the install script:
+
+```shell
+bash setup_scripts/conda_server.sh
+```
+
+*Note*: This **takes a really long time**, up to 1.5-3 hours; sit back and relax while you wait.
+
+*Note*: if you are running in a cluster with `module` installed, make sure you deactivate all modules before running the script.
+
+This will create a conda environment with all the dependencies needed to run the model servers.
+
+You should then be able to launch models with the `text-generation-launcher` command, or by using one of the predefined Make rules:
+```shell
+conda activate tgi-env
+make run-llama2-vicuna-7b
+```
+
+### Setting up a Central server
+
+If you are setting up this library for use in your group/cluster for the first time, you will need (or at least benefit from) setting up a central server.
+See the instructions [in the package folder](./central/README.md).
+
+Remember to set the `TGI_CENTRAL_ADDRESS` environment variable (ideally for all the users in your cluster) to the address of the central server.
+
+### Chat-UI
+
+It is also possible to use a simple web [chat-ui](./clients/chat-ui) to interact with models running on your server/cluster.
+This is a simple fork of [HuggingFace's Chat UI](https://github.com/huggingface/chat-ui) that communicates with the central controller to get the list of models available in the cluster, and then connects to the corresponding servers to generate text.
+
+For example, in Babel, you can access a running Chat-UI web-server with *port forwarding* by running
+
+```shell
+ssh babel -L 8888:babel-3-36:4173
+```
+
+and going to `localhost:8888` in your browser.
+
+![Chat-UI screenshot](assets/chatui01.png)
+![Chat-UI screenshot](assets/chatui02.png)
+
+Check the [README](./chat-ui/README.md) for more details.
+
+*Content below is from the original README.*
+
+---
+
+
## Table of contents
@@ -129,6 +231,7 @@ for response in client.generate_stream("What is Deep Learning?", max_new_tokens=
print(text)
```
+
### API documentation
You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
diff --git a/assets/chatui01.png b/assets/chatui01.png
new file mode 100755
index 00000000000..6b24e7171b6
Binary files /dev/null and b/assets/chatui01.png differ
diff --git a/assets/chatui02.png b/assets/chatui02.png
new file mode 100755
index 00000000000..775f8c52b94
Binary files /dev/null and b/assets/chatui02.png differ
diff --git a/benchmark/dump_fast_tokenizer.py b/benchmark/dump_fast_tokenizer.py
new file mode 100644
index 00000000000..c2799b1faa6
--- /dev/null
+++ b/benchmark/dump_fast_tokenizer.py
@@ -0,0 +1,19 @@
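+# Example usage (hypothetical invocation; any Hugging Face tokenizer id works here, e.g. "gpt2"):
+#   python dump_fast_tokenizer.py --tokenizer-name gpt2 --output ./fast_tokenizer/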
+import os
+import json
+import argparse
+from transformers import AutoTokenizer
+
+def dump_fast_tokenizer(tokenizer_name, output_path):
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+ tokenizer.save_pretrained(output_path)
+
+def main():
+ parser = argparse.ArgumentParser(description="Dump fast tokenizer json file")
+ parser.add_argument("--tokenizer-name", required=True, help="Name of the Hugging Face tokenizer")
+ parser.add_argument("--output", required=True, help="Output path for the fast tokenizer json file")
+ args = parser.parse_args()
+
+ dump_fast_tokenizer(args.tokenizer_name, args.output)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs
index b57c652b9b4..85f236dcf94 100644
--- a/benchmark/src/generation.rs
+++ b/benchmark/src/generation.rs
@@ -218,6 +218,6 @@ fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
// Decode
tokenizer
- .decode(Vec::from(encoding.get_ids()), false)
+ .decode(&Vec::from(encoding.get_ids()), false)
.unwrap()
}
diff --git a/central/Cargo.toml b/central/Cargo.toml
new file mode 100644
index 00000000000..0b84813b8a1
--- /dev/null
+++ b/central/Cargo.toml
@@ -0,0 +1,24 @@
+[package]
+name = "text-generation-central"
+description = "Text Generation Central controller tool"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+homepage.workspace = true
+
+[[bin]]
+name = "text-generation-central"
+path = "src/main.rs"
+
+[dependencies]
+clap = { version = "4.1.4", features = ["derive", "env"] }
+warp = "0.3"
+serde = {version = "1.0.142", features = ["derive"]}
+serde_json = "1.0.93"
+reqwest = { version = "0.11.14", features = [] }
+tokio = { version = "1.25.0", features = ["full"] }
+chrono = "0.4"
+urlencoding = "1.1.1"
+bytes = "1.1"
+whoami = "1.4"
+
diff --git a/central/README.md b/central/README.md
new file mode 100644
index 00000000000..3a27d73a1dc
--- /dev/null
+++ b/central/README.md
@@ -0,0 +1,24 @@
+
+
+# Text Generation Central controller tool
+
+
+A lightweight tool for tracking which models are running, who is running them, and on what IP and port.
+
+## Install
+
+From the root of the project run:
+
+```shell
+make install-central
+```
+
+## Run
+
+To run the central controller tool, run:
+
+```shell
+text-generation-central --port $PORT
+```
+
+TODO: Docs on environment variables
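+
+## Querying the central
+
+The central exposes two HTTP endpoints (see `src/main.rs`): servers register themselves with `POST /model_up/<url-encoded model id>`, and clients list the currently registered models with `GET /list_models`. As a quick sanity check, below is a minimal sketch that queries the listing endpoint with Python's `requests` library; the address `localhost:8765` is just an assumption for illustration, use whatever `--hostname`/`--port` your central is running on.
+
+```python
+import requests
+
+# assumes a central running on localhost:8765
+resp = requests.get("http://localhost:8765/list_models")
+resp.raise_for_status()
+for model in resp.json():  # each entry mirrors the ModelRecord struct
+    print(model["name"], model["address"], model["owner"], model["is_quantized"])
+```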
\ No newline at end of file
diff --git a/central/src/main.rs b/central/src/main.rs
new file mode 100644
index 00000000000..cc51b453f89
--- /dev/null
+++ b/central/src/main.rs
@@ -0,0 +1,230 @@
+use clap::Parser;
+use warp::{Filter, Reply};
+use serde::{Serialize,Deserialize};
+use std::{sync::Arc, collections::HashMap};
+use tokio::sync::Mutex;
+
+/// App Configuration
+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+struct Args {
+ // The address the central uses to listen to server requests
+ #[clap(default_value = "0.0.0.0", long, env)]
+ hostname: String,
+
+ // The port the central uses to listen to server requests
+ #[clap(default_value = "8086", long, env)]
+ port: u16,
+
+ // The interval (in seconds) between central pings to models
+ #[clap(default_value = "60", long, env)]
+ ping_interval: u64,
+
+ // The maximum number of failed pings before a model is dropped
+ #[clap(default_value = "3", long, env)]
+ max_failed_pings: u32,
+
+ // Defaults to None; if set, the central pings this server on launch and registers it if alive
+ #[clap(default_value = None, long, env)]
+ initial_ping: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Clone, Debug)]
+pub struct ModelRecord {
+ pub name: String,
+ pub address: String,
+ pub owner: String,
+ pub is_quantized: bool,
+}
+
+#[derive(Deserialize, Clone, Debug)]
+struct ModelInfo {
+ docker_label: Option<String>,
+ max_batch_total_tokens: u32,
+ max_best_of: u32,
+ max_concurrent_requests: u32,
+ max_input_length: u32,
+ max_stop_sequences: u32,
+ max_total_tokens: u32,
+ max_waiting_tokens: u32,
+ model_device_type: String,
+ model_dtype: String,
+ model_id: String,
+ model_pipeline_tag: String,
+ model_sha: String,
+ sha: String,
+ validation_workers: u32,
+ version: String,
+ waiting_served_ratio: f32,
+}
+
+type Models = Arc<Mutex<HashMap<String, ModelRecord>>>;
+
+
+// define function to print model info
+fn print_model_record(record: &ModelRecord) {
+ if record.is_quantized {
+ println!("\t{} (quant) - {} by {}", record.name, record.address, record.owner);
+ } else {
+ println!("\t{} - {} by {}", record.name, record.address, record.owner);
+ }
+}
+
+#[tokio::main]
+
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+ let args = Args::parse();
+ let hostname = args.hostname;
+ let port = args.port;
+ let ping_interval = args.ping_interval;
+ let initial_ping = args.initial_ping;
+ // get current user from env
+ let user = whoami::username();
+
+ let models: Models = Arc::new(Mutex::new(HashMap::new()));
+
+ fn with_models(models: Models) -> impl Filter<Extract = (Models,), Error = std::convert::Infallible> + Clone {
+ warp::any().map(move || models.clone())
+ }
+
+ async fn handle_model_notice(encoded_id: String, record: ModelRecord, models: Models) -> Result<impl warp::Reply, warp::Rejection> {
+ println!("Received model notice for {}", encoded_id);
+ let model_id = urlencoding::decode(&encoded_id).unwrap();
+ models.lock().await.insert(model_id, record);
+ Ok(warp::reply::with_status(
+ "Model registered successfully",
+ warp::http::StatusCode::OK,
+ ))
+ }
+
+ async fn handle_list_models(models: Models) -> Result<impl warp::Reply, warp::Rejection> {
+ let models = models.lock().await;
+ // collect the registered models into a list for the JSON response
+ let mut models_list: Vec<ModelRecord> = vec![];
+ for (_, record) in models.iter() {
+ models_list.push(record.clone());
+ }
+ Ok(warp::reply::with_status(
+ warp::reply::json(&models_list).into_response(),
+ warp::http::StatusCode::OK,
+ ))
+ }
+
+ let model_notice_route = warp::path("model_up")
+ .and(warp::path::param::<String>())
+ .and(warp::post())
+ .and(warp::body::json())
+ .and(with_models(models.clone()))
+ .and_then(handle_model_notice);
+
+ let list_models_route = warp::path("list_models")
+ .and(warp::get())
+ .and(with_models(models.clone()))
+ .and_then(handle_list_models);
+
+ let catch_all = warp::any()
+ .map(||{
+ println!("Warning: Received a request on an unhandled route");
+ warp::reply::with_status(
+ "Unhandled route!",
+ warp::http::StatusCode::NOT_FOUND,
+ )
+ });
+
+ let routes = model_notice_route
+ .or(list_models_route)
+ .or(catch_all);
+
+ let listener = warp::serve(routes).run((hostname.parse::<std::net::IpAddr>().unwrap(), port));
+ let monitor = async {
+ // ping server if provided
+ if let Some(model_addr) = initial_ping {
+ // split server into ip and port variables strings
+ let model_ip = model_addr.split(":").collect::<Vec<&str>>()[0];
+ let model_port = model_addr.split(":").collect::<Vec<&str>>()[1];
+
+ let url = format!("http://{}:{}/info", model_ip, model_port);
+ let response = reqwest::get(&url).await;
+ match response {
+ Ok(response) => {
+ if response.status().is_success() {
+ let body = response.text().await?;
+ let info: ModelInfo = serde_json::from_str(&body)?;
+ let address = format!("{}:{}", model_ip, model_port);
+ models.lock().await.insert(
+ info.model_id.clone(),
+ // TODO: these are not necessarily the correct values;
+ // we should get them from the model's /info response
+ ModelRecord {
+ name: info.model_id.clone(),
+ address: address,
+ owner: user.to_string(),
+ is_quantized: false,
+ });
+ } else {
+ println!("Model not alive");
+ }
+ },
+ Err(e) => {
+ println!("Model not alive");
+ }
+ };
+ }
+
+ // every `ping_interval` seconds, ping each model's /health endpoint and drop models that do not respond
+ loop {
+ let mut models = models.lock().await;
+ let mut keys_removal: Vec<String> = vec![];
+
+ for (model, record) in models.iter() {
+ let url = format!("http://{}/health", record.address);
+ let response = reqwest::get(&url).await;
+ match response {
+ Ok(response) => {
+ if !response.status().is_success() {
+ keys_removal.push(model.to_string());
+ }
+ },
+ Err(_) => {
+ keys_removal.push(model.to_string());
+ }
+ }
+ };
+
+ let mut dropped_models: HashMap<String, ModelRecord> = HashMap::new();
+ for key in keys_removal {
+ if let Some(record) = models.remove(&key) {
+ dropped_models.insert(key, record);
+ }
+ }
+
+ // print current time
+ println!("------------------");
+ println!("Current time: {}", chrono::Local::now().format("%Y-%m-%d %H:%M:%S"));
+ // print models that are still registered, one per line
+ println!("Current Models:");
+ for (_, record) in models.iter() {
+ print_model_record(record);
+ }
+ // print dropped models
+ println!("Dropped Models:");
+ for (_, record) in dropped_models.iter() {
+ print_model_record(record);
+ }
+
+ std::mem::drop(models);
+ tokio::time::sleep(std::time::Duration::from_secs(ping_interval)).await;
+ }
+
+ Ok(()) as Result<(), Box<dyn std::error::Error>>
+ };
+
+ // wrap the server future so it can go into try_join alongside the monitor
+ let listener = async {
+ listener.await;
+ Ok::<(), Box<dyn std::error::Error>>(())
+ };
+ tokio::try_join!(listener, monitor)?;
+ Ok(())
+}
+
diff --git a/chat-ui b/chat-ui
new file mode 160000
index 00000000000..f65ca708ec2
--- /dev/null
+++ b/chat-ui
@@ -0,0 +1 @@
+Subproject commit f65ca708ec2d018fee108a6edf5e48f545d4032f
diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index bf045d47735..486b4d7775b 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -1,5 +1,6 @@
import json
import requests
+import os
from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
@@ -35,6 +36,36 @@ class Client:
```
"""
+ @classmethod
+ def list_from_central(
+ cls,
+ central_url: Optional[str] = None,
+ ):
+ """
+ Get the list of available models from the central model hub
+
+ Args:
+ central_url (`str`):
+ Text Generation Central URL
+
+ Returns:
+ List[Dict[str, str]]: List of available models
+ """
+ if central_url is None:
+ # check if environment variable is set
+ if os.environ.get("TGI_CENTRAL_ADDRESS") is None:
+ raise ValueError(
+ "No Central url provided and TGI_CENTRAL_ADDRESS environment variable is not set"
+ )
+ central_url = f"http://%7Bos.environ.get('TGI_CENTRAL_ADDRESS')%7D"
+
+ # query from /models endpoint
+ resp = requests.get(f"{central_url}/list_models")
+ payload = resp.json()
+ if resp.status_code != 200:
+ raise parse_error(resp.status_code, payload)
+ return payload
+
def __init__(
self,
base_url: str,
@@ -75,6 +106,7 @@ def generate(
typical_p: Optional[float] = None,
watermark: bool = False,
decoder_input_details: bool = False,
+ top_tokens: Optional[int] = None,
) -> Response:
"""
Given a prompt, generate the following text
@@ -113,6 +145,8 @@ def generate(
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
+ top_tokens (`Optional[int]`):
+ Return the top `top_tokens` tokens with the highest logprobs at each step
Returns:
Response: generated response
@@ -134,6 +168,7 @@ def generate(
typical_p=typical_p,
watermark=watermark,
decoder_input_details=decoder_input_details,
+ top_tokens=top_tokens,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -148,6 +183,25 @@ def generate(
if resp.status_code != 200:
raise parse_error(resp.status_code, payload)
return Response(**payload[0])
+
+ def score(
+ self: str,
+ target: str,
+ ):
+ """ Utility function to score a target string (i.e. compute its logprob).
+
+ Mostly wraps the generate function, asking for 1 new token and returning
+ the logprob of the prompt.
+ """
+ # Use generate to get the score
+ resp = self.generate(
+ prompt=target,
+ do_sample=False,
+ max_new_tokens=1,
+ decoder_input_details=True,
+ )
+ # extract prefill details, dropping the first token (which has no logprob)
+ return resp.details.prefill[1:]
def generate_stream(
self,
@@ -164,6 +218,7 @@ def generate_stream(
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
+ top_tokens: Optional[int] = None,
) -> Iterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens
@@ -198,6 +253,8 @@ def generate_stream(
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+ top_tokens (`Optional[int]`):
+ Return the top `top_tokens` tokens with the highest logprobs at each step
Returns:
Iterator[StreamResponse]: stream of generated tokens
@@ -219,6 +276,7 @@ def generate_stream(
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
+ top_tokens=top_tokens,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
@@ -317,6 +375,7 @@ async def generate(
typical_p: Optional[float] = None,
watermark: bool = False,
decoder_input_details: bool = False,
+ top_tokens: Optional[int] = None,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
@@ -355,6 +414,8 @@ async def generate(
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
+ top_tokens (`Optional[int]`):
+ Return the top `top_tokens` tokens with the highest logprobs at each step
Returns:
Response: generated response
@@ -404,6 +465,8 @@ async def generate_stream(
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
+ top_tokens: Optional[int] = None,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
@@ -438,6 +501,8 @@ async def generate_stream(
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+ top_tokens (`Optional[int]`):
+ Return the top `top_tokens` tokens with the highest logprobs at each step
Returns:
AsyncIterator[StreamResponse]: stream of generated tokens
@@ -459,6 +524,7 @@ async def generate_stream(
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
+ top_tokens=top_tokens,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 548f0b639ce..d15bc82378e 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -33,6 +33,8 @@ class Parameters(BaseModel):
typical_p: Optional[float]
# Generate best_of sequences and return the one with the highest token logprobs
best_of: Optional[int]
+ # Return the `top_tokens` most likely tokens at each step
+ top_tokens: Optional[int]
# Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
watermark: bool = False
# Get generation details
@@ -100,6 +102,13 @@ def valid_typical_p(cls, v):
if v is not None and (v <= 0 or v >= 1.0):
raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
return v
+
+ @validator("top_tokens")
+ def valid_top_tokens(cls, v):
+ if v is not None and v <= 0:
+ raise ValidationError("`top_tokens` must be strictly positive")
+ return v
+
class Request(BaseModel):
@@ -193,6 +202,8 @@ class Details(BaseModel):
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
+ # Most likely tokens at each step
+ top_tokens: Optional[List[List[Token]]]
# Additional sequences when using the `best_of` parameter
best_of_sequences: Optional[List[BestOfSequence]]
@@ -219,6 +230,8 @@ class StreamDetails(BaseModel):
class StreamResponse(BaseModel):
# Generated token
token: Token
+ # Most likely tokens at each step
+ top_tokens: Optional[List[Token]]
# Complete generated text
# Only available when the generation is finished
generated_text: Optional[str]
diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml
index 3e7f86d4e6f..aebf8e7e5d9 100644
--- a/launcher/Cargo.toml
+++ b/launcher/Cargo.toml
@@ -10,14 +10,17 @@ homepage.workspace = true
clap = { version = "4.1.4", features = ["derive", "env"] }
ctrlc = { version = "3.2.5", features = ["termination"] }
nix = "0.26.2"
+reqwest = { version = "0.11.14", features = ["blocking", "json"] }
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.93"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
+urlencoding = "1.1.1"
+whoami = "1.4.0"
+hostname = "0.3"
[dev-dependencies]
float_eq = "1.0.1"
-reqwest = { version = "0.11.14", features = ["blocking", "json"] }
[build-dependencies]
vergen = { version = "8.0.0", features = ["build", "cargo", "git", "gitcl", "rustc", "si"] }
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 2ad788a405b..ad9379bc9f3 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -16,6 +16,8 @@ use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use tracing_subscriber::EnvFilter;
+use serde::Serialize;
+use serde_json::json;
mod env_runtime;
@@ -25,6 +27,15 @@ enum Quantization {
Gptq,
}
+#[derive(Serialize)]
+struct ModelRecord {
+ name: String,
+ address: String,
+ owner: String,
+ is_quantized: bool,
+}
+
+
impl std::fmt::Display for Quantization {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// To keep in track with `server`.
@@ -122,6 +133,12 @@ struct Args {
#[clap(default_value = "2", long, env)]
max_best_of: usize,
+ /// This is the maximum allowed value for clients to set `top_tokens`.
+ /// `top_tokens` is used to return the most likely tokens at each generation
+ /// step, rather than just the single most likely one.
+ #[clap(default_value = "10", long, env)]
+ max_top_tokens: u32,
+
/// This is the maximum allowed value for clients to set `stop_sequences`.
/// Stop sequences are used to allow the model to stop on more than just
/// the EOS token, and enable more complex "prompting" where users can preprompt
@@ -276,6 +293,10 @@ struct Args {
#[clap(long, env)]
ngrok_edge: Option<String>,
+ /// Central address, used to register the server in the central registry
+ #[clap(long, env)]
+ central_address: Option<String>,
+
/// Display a lot of information about your runtime environment
#[clap(long, short, action)]
env: bool,
@@ -852,6 +873,8 @@ fn spawn_webserver(
args.max_concurrent_requests.to_string(),
"--max-best-of".to_string(),
args.max_best_of.to_string(),
+ "--max-top-tokens".to_string(),
+ args.max_top_tokens.to_string(),
"--max-stop-sequences".to_string(),
args.max_stop_sequences.to_string(),
"--max-input-length".to_string(),
@@ -1083,11 +1106,6 @@ fn main() -> Result<(), LauncherError> {
// Download and convert model weights
download_convert_model(&args, running.clone())?;
- if !running.load(Ordering::SeqCst) {
- // Launcher was asked to stop
- return Ok(());
- }
-
// Shared shutdown bool
let shutdown = Arc::new(AtomicBool::new(false));
// Shared shutdown channel
@@ -1097,6 +1115,34 @@ fn main() -> Result<(), LauncherError> {
// Shared channel to track shard status
let (status_sender, status_receiver) = mpsc::channel();
+ // Resolve the central address: use --central-address if given (cloned to avoid borrowing issues),
+ // otherwise fall back to the TGI_CENTRAL_ADDRESS environment variable
+ let central_address = match args.central_address.clone() {
+ Some(central_address) => Some(central_address),
+ None => match env::var("TGI_CENTRAL_ADDRESS") {
+ Ok(central_address) => Some(central_address),
+ Err(_) => None,
+ },
+ };
+ let encoded_id = urlencoding::encode(&args.model_id);
+ let ip = args.hostname.to_string();
+ let hostname = match ip.parse::<std::net::SocketAddr>() {
+ Ok(ip) => ip.ip().to_string(),
+ Err(_) => {
+ tracing::warn!("invalid hostname passed! will use system's hostname...");
+ // fall back to the system hostname
+ whoami::hostname()
+ }
+ };
+ println!("final hostname: {}", hostname);
+ let model_record = ModelRecord {
+ name: args.model_id.clone(),
+ // build the address string from hostname and port
+ address: format!("{}:{}", hostname, args.port),
+ owner: whoami::username(),
+ is_quantized: args.quantize.is_some()
+ };
+
spawn_shards(
num_shard,
&args,
@@ -1123,6 +1169,37 @@ fn main() -> Result<(), LauncherError> {
// Default exit code
let mut exit_code = Ok(());
+ // Ping central server to register model, using request
+ if let Some(central_address) = central_address {
+ println!("Attempting to register in Central at {}", central_address);
+ let url = format!("http://{}/model_up/%7B%7D", central_address, encoded_id.to_string());
+ let client = reqwest::blocking::Client::new();
+ let res = client
+ .post(&url)
+ .json(&model_record)
+ .send();
+
+ match res {
+ Ok(response) => {
+ if response.status().is_success() {
+ println!("Successfully registered on central server");
+ } else {
+ println!("Failed to register on central server");
+ // dump the full response for debugging
+ println!("Response: {:?}", response);
+ }
+ },
+ Err(e) => println!("Error occurred while initiating connection with central server: {}", e)
+ }
+ } else {
+ println!("No central server address provided. Skipping registration");
+ }
+
+ if !running.load(Ordering::SeqCst) {
+ // Launcher was asked to stop
+ return Ok(());
+ }
+
while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
tracing::error!("Shard {rank} crashed");
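
The registration above boils down to a single HTTP POST of the `ModelRecord` JSON to `http://<central_address>/model_up/<url-encoded model id>`. Below is a minimal sketch of a registry service that could accept it; the `/model_up/{model_id}` path and the record fields come from this diff, while FastAPI and the `/models` listing route are assumptions for illustration only (the actual listing endpoint behind `Client.list_from_central()` is not part of this excerpt).

# registry_sketch.py -- hypothetical central registry accepting the launcher's POST
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()
registry: dict[str, dict] = {}

class ModelRecord(BaseModel):
    # Mirrors the Rust ModelRecord serialized by the launcher.
    name: str
    address: str
    owner: str
    is_quantized: bool

@app.post("/model_up/{model_id:path}")
def model_up(model_id: str, record: ModelRecord):
    # The launcher URL-encodes the model id before putting it in the path;
    # the :path converter also tolerates ids containing slashes.
    registry[model_id] = record.dict()
    return {"status": "ok"}

@app.get("/models")  # assumed route name, not taken from the diff
def list_models():
    return list(registry.values())
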
diff --git a/notebooks/test_client.ipynb b/notebooks/test_client.ipynb
new file mode 100644
index 00000000000..f828d91b091
--- /dev/null
+++ b/notebooks/test_client.ipynb
@@ -0,0 +1,273 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/mnt/data_2/patrick/conda/envs/tgi-env/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "import text_generation as tg"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# set environment variable\n",
+ "import os\n",
+ "os.environ['TGI_CENTRAL_ADDRESS'] = 'localhost:8765'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[{'name': '/mnt/data_2/patrick/croissantllm-models/small4_equals/', 'address': 'frightened-frank-flowers-fin-01:3000', 'owner': 'patrick', 'is_quantized': False}]\n"
+ ]
+ }
+ ],
+ "source": [
+ "servers = tg.Client.list_from_central()\n",
+ "print(servers)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "server_addr = servers[0]['address']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "client = tg.Client(f\"http://{server_addr}/")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "among the best in the country in their field of study. They are also among the best in the\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(client.generate(\"CMU's PhD students are\", max_new_tokens=20).generated_text)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " among the best in the country in their field of study. They are also among the best in the\n"
+ ]
+ }
+ ],
+ "source": [
+ "text = \"\"\n",
+ "for response in client.generate_stream(\"CMU's PhD students are\", max_new_tokens=20):\n",
+ " if not response.token.special:\n",
+ " text += response.token.text\n",
+ "print(text)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Getting Top K tokens at each step"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "among the best in\n",
+ "[Token(id=5684, text='among', logprob=-2.5429688, special=False), Token(id=4645, text='working', logprob=-3.4179688, special=False), Token(id=1135, text='the', logprob=-3.6757812, special=False)]\n",
+ "[Token(id=1135, text='the', logprob=-0.40478516, special=False), Token(id=3108, text='those', logprob=-2.7402344, special=False), Token(id=488, text='', logprob=-2.8417969, special=False)]\n",
+ "[Token(id=3284, text='best', logprob=-1.5273438, special=False), Token(id=2481, text='most', logprob=-1.5664062, special=False), Token(id=3263, text='top', logprob=-2.2148438, special=False)]\n",
+ "[Token(id=1147, text='in', logprob=-0.45898438, special=False), Token(id=1171, text='and', logprob=-3.7089844, special=False), Token(id=5208, text='students', logprob=-3.9511719, special=False)]\n"
+ ]
+ }
+ ],
+ "source": [
+ "resp = client.generate(\"CMU's PhD students are\", max_new_tokens=4, top_tokens=3)\n",
+ "print(resp.generated_text)\n",
+ "for top_tokens in resp.details.top_tokens:\n",
+ " print(top_tokens)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Benchmarking"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create 4 random sentences\n",
+ "SAMPLES = [\n",
+ " \"The quick brown fox jumps over the lazy dog.\",\n",
+ " \"The five boxing wizards jump quickly.\",\n",
+ " \"All questions asked by five watch experts amazed the judge.\",\n",
+ " \"Jack quietly moved up front and seized the big ball of wax.\",\n",
+ "]"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Sync Client"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "The quick brown fox jumps over the lazy dog.\n",
+ "The quick brown fox j\n",
+ "\n",
+ "The first step in the process is to create a list of potential candidates. This list should include\n",
+ "\n",
+ "The first time I heard the term “fake news” was in the context of the \n",
+ "He was a master of disguise, and he had a knack for getting into places he\n",
+ "CPU times: user 36.8 ms, sys: 3.42 ms, total: 40.2 ms\n",
+ "Wall time: 1.95 s\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "for sample in SAMPLES:\n",
+ " print(client.generate(sample, max_new_tokens=20).generated_text)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Async Client"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import asyncio\n",
+ "import nest_asyncio\n",
+ "nest_asyncio.apply()\n",
+ "\n",
+ "async_client = tg.AsyncClient(f\"http://{server_addr}/")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "The quick brown fox jumps over the lazy dog.\n",
+ "The quick brown fox j\n",
+ "\n",
+ "The first step in the process is to create a list of potential candidates. This list should include\n",
+ "\n",
+ "The first time I heard the term “fake news” was in the context of the \n",
+ "He was a master of disguise, and he had a knack for getting into places he\n",
+ "CPU times: user 105 ms, sys: 5.03 ms, total: 110 ms\n",
+ "Wall time: 620 ms\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "async def batch_generate():\n",
+ " return await asyncio.gather(*[async_client.generate(sample, max_new_tokens=20) for sample in SAMPLES])\n",
+ "\n",
+ "results = asyncio.run(batch_generate())\n",
+ "for r in results:\n",
+ " print(r.generated_text)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "tgi-env",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.18"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/proto/generate.proto b/proto/generate.proto
index 57d79bcaf27..157598cf304 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -91,6 +91,8 @@ message Request {
StoppingCriteriaParameters stopping_parameters = 5;
/// Return prefill logprobs
bool prefill_logprobs = 6;
+ /// Return most likely n tokens
+ uint32 top_tokens = 7;
}
message Batch {
@@ -141,6 +143,17 @@ message PrefillTokens {
repeated string texts = 3;
}
+message TopTokens {
+ /// Top Token IDs
+ repeated uint32 ids = 1;
+ /// Top Logprobs
+ repeated float logprobs = 2;
+ /// Top Token Texts
+ repeated string texts = 3;
+ /// If the tokens are special
+ repeated bool is_special = 6;
+}
+
message Generation {
/// Request ID
uint64 request_id = 1;
@@ -156,6 +169,8 @@ message Generation {
bool token_is_special = 6;
/// Complete generated text
optional GeneratedText generated_text = 7;
+ /// Top tokens
+ TopTokens top_tokens = 8;
}
message FilterBatchRequest {
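
The new TopTokens message encodes the candidate set as four parallel repeated fields, so the i-th entry of ids, logprobs, texts and is_special all describe the same candidate token. A small illustration using the Python module generated from this proto (generate_pb2, as imported elsewhere in the server code); the sample values are taken from the notebook output above:

from text_generation_server.pb import generate_pb2

top = generate_pb2.TopTokens(
    ids=[5684, 4645, 1135],
    logprobs=[-2.54, -3.42, -3.68],
    texts=["among", "working", "the"],
    is_special=[False, False, False],
)
# Zip the parallel arrays back into per-candidate tuples.
candidates = list(zip(top.ids, top.logprobs, top.texts, top.is_special))
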
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 7753f307c0a..10212adb6f1 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -131,6 +131,7 @@ impl Client {
ignore_eos_token: false,
}),
prefill_logprobs: true,
+ top_tokens: 20,
});
n_tokens += max_input_length;
}
diff --git a/router/src/health.rs b/router/src/health.rs
index a3cacdcd016..972dfe486fa 100644
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -50,6 +50,7 @@ impl Health {
stop_sequences: vec![],
ignore_eos_token: false,
}),
+ top_tokens: 0,
};
let batch = Batch {
id: BATCH_ID,
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 188ddc6420c..079b28e1a2c 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -138,12 +138,15 @@ impl Infer {
&self,
request: GenerateRequest,
) -> Result<InferResponse, InferError> {
+ let use_top_tokens = request.parameters.top_tokens.is_some_and(|x| x > 0);
+
// Create stream and keep semaphore permit as long as generate lives
let (_permit, mut stream) = self.generate_stream(request).await?;
// Return values
let mut result_prefill = Vec::new();
let mut result_tokens = Vec::new();
+ let mut result_top_tokens = Vec::new();
let mut result_generated_text = None;
let mut result_start = None;
let mut result_queued = None;
@@ -164,7 +167,10 @@ impl Infer {
.collect();
}
// Push last token
- InferStreamResponse::Token(token) => result_tokens.push(token),
+ InferStreamResponse::Intermediate { token, top_tokens } => {
+ result_tokens.push(token);
+ result_top_tokens.push(top_tokens);
+ }
// Final message
// Set return values
InferStreamResponse::End {
@@ -172,8 +178,10 @@ impl Infer {
generated_text,
start,
queued,
+ top_tokens
} => {
result_tokens.push(token);
+ result_top_tokens.push(top_tokens);
result_generated_text = Some(generated_text);
result_start = Some(start);
result_queued = Some(queued)
@@ -185,12 +193,16 @@ impl Infer {
if let (Some(generated_text), Some(queued), Some(start)) =
(result_generated_text, result_queued, result_start)
{
+ let top_tokens = if use_top_tokens {
+ result_top_tokens
+ } else {
+ Vec::new()
+ };
Ok(InferResponse {
prefill: result_prefill,
tokens: result_tokens,
generated_text,
queued,
start,
+ top_tokens,
})
} else {
let err = InferError::IncompleteGeneration;
@@ -520,6 +532,24 @@ fn send_responses(
special: generation.token_is_special,
};
+ let mut top_tokens = Vec::new();
+ if let Some(top_tokens_) = generation.top_tokens {
+ top_tokens.extend(
+ top_tokens_
+ .ids
+ .into_iter()
+ .zip(top_tokens_.logprobs.into_iter())
+ .zip(top_tokens_.texts.into_iter())
+ .zip(top_tokens_.is_special.into_iter())
+ .map(|(((id, logprob), text), special)| Token {
+ id,
+ text,
+ logprob,
+ special,
+ })
+ )
+ }
+
if let Some(generated_text) = generation.generated_text {
// Generation has ended
stopped = true;
@@ -527,6 +557,7 @@ fn send_responses(
entry.response_tx.send_timeout(
Ok(InferStreamResponse::End {
token,
+ top_tokens,
generated_text,
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
@@ -536,7 +567,7 @@ fn send_responses(
} else {
// Send message
entry.response_tx.send_timeout(
- Ok(InferStreamResponse::Token(token)),
+ Ok(InferStreamResponse::Intermediate { token, top_tokens }),
Duration::from_millis(10),
)?;
}
@@ -566,10 +597,14 @@ pub(crate) enum InferStreamResponse {
// Optional first message
Prefill(PrefillTokens),
// Intermediate messages
- Token(Token),
+ Intermediate {
+ token: Token,
+ top_tokens: Vec,
+ },
// Last message
End {
token: Token,
+ top_tokens: Vec<Token>,
generated_text: GeneratedText,
start: Instant,
queued: Instant,
@@ -583,6 +618,7 @@ pub(crate) struct InferResponse {
pub(crate) generated_text: GeneratedText,
pub(crate) queued: Instant,
pub(crate) start: Instant,
+ pub(crate) top_tokens: Vec<Vec<Token>>,
}
#[derive(Debug, Error)]
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 7dff7a114ec..a948cf444b1 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -67,6 +67,9 @@ pub(crate) struct GenerateParameters {
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
pub best_of: Option<usize>,
#[serde(default)]
+ #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
+ pub top_tokens: Option<u32>,
+ #[serde(default)]
#[schema(
exclusive_minimum = 0.0,
nullable = true,
@@ -144,6 +147,7 @@ fn default_max_new_tokens() -> u32 {
fn default_parameters() -> GenerateParameters {
GenerateParameters {
best_of: None,
+ top_tokens: None,
temperature: None,
repetition_penalty: None,
top_k: None,
@@ -235,6 +239,8 @@ pub(crate) struct BestOfSequence {
pub seed: Option<u64>,
pub prefill: Vec<PrefillToken>,
pub tokens: Vec<Token>,
+ #[serde(skip_serializing_if = "Vec::is_empty")]
+ pub top_tokens: Vec<Vec<Token>>,
}
#[derive(Serialize, ToSchema)]
@@ -249,6 +255,8 @@ pub(crate) struct Details {
pub tokens: Vec<Token>,
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of_sequences: Option<Vec<BestOfSequence>>,
+ #[serde(skip_serializing_if = "Vec::is_empty")]
+ pub top_tokens: Vec<Vec<Token>>,
}
#[derive(Serialize, ToSchema)]
@@ -272,6 +280,8 @@ pub(crate) struct StreamDetails {
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
pub token: Token,
+ #[serde(skip_serializing_if = "Vec::is_empty")]
+ pub top_tokens: Vec<Token>,
#[schema(nullable = true, default = "null", example = "test")]
pub generated_text: Option<String>,
#[schema(nullable = true, default = "null")]
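
Over the HTTP API these additions surface as a new `top_tokens` request parameter and a `top_tokens` array on the details, one entry per generated token, each entry being a list of candidate tokens carrying the id, text, logprob and special fields of the Token struct used above. A rough sketch of a raw request, assuming a server at a placeholder address and the usual /generate route with details enabled; the exact response layout beyond the fields added in this diff is not guaranteed:

import requests

server_addr = "localhost:3000"  # placeholder
resp = requests.post(
    f"http://{server_addr}/generate",
    json={
        "inputs": "CMU's PhD students are",
        "parameters": {"max_new_tokens": 4, "top_tokens": 3, "details": True},
    },
)
details = resp.json().get("details", {})
# One list of candidates per generated token.
for step in details.get("top_tokens", []):
    print([(t["text"], t["logprob"]) for t in step])
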
diff --git a/router/src/main.rs b/router/src/main.rs
index 484643cb252..d8994bac71e 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -27,6 +27,8 @@ struct Args {
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
+ #[clap(default_value = "10", long, env)]
+ max_top_tokens: u32,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "1024", long, env)]
@@ -74,6 +76,7 @@ fn main() -> Result<(), RouterError> {
let Args {
max_concurrent_requests,
max_best_of,
+ max_top_tokens,
max_stop_sequences,
max_input_length,
max_total_tokens,
@@ -258,6 +261,7 @@ fn main() -> Result<(), RouterError> {
compat_return_full_text,
max_concurrent_requests,
max_best_of,
+ max_top_tokens,
max_stop_sequences,
max_input_length,
max_total_tokens,
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 2d8d6d1c1c2..aab3530da47 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -235,6 +235,7 @@ impl State {
truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
+ top_tokens: entry.request.top_tokens,
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@@ -323,6 +324,7 @@ mod tests {
repetition_penalty: 0.0,
watermark: false,
},
+ top_tokens: 0,
stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: false,
max_new_tokens: 1,
diff --git a/router/src/server.rs b/router/src/server.rs
index 9af94951b2a..b3ba94edf73 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -193,6 +193,7 @@ async fn generate(
generated_tokens: response.generated_text.generated_tokens,
prefill: response.prefill,
tokens: response.tokens,
+ top_tokens: response.top_tokens,
seed: response.generated_text.seed,
}
})
@@ -206,6 +207,7 @@ async fn generate(
tokens: response.tokens,
seed: response.generated_text.seed,
best_of_sequences,
+ top_tokens: response.top_tokens,
})
}
false => None,
@@ -387,12 +389,16 @@ async fn generate_stream(
// Prefill is ignored
InferStreamResponse::Prefill(_) => {}
// Yield event for every new token
- InferStreamResponse::Token(token) => {
+ InferStreamResponse::Intermediate {
+ token,
+ top_tokens,
+ } => {
tracing::debug!(parent: &span, "Token: {:?}", token);
// StreamResponse
let stream_token = StreamResponse {
token,
+ top_tokens,
generated_text: None,
details: None,
};
@@ -402,6 +408,7 @@ async fn generate_stream(
// Yield event for last token and compute timings
InferStreamResponse::End {
token,
+ top_tokens,
generated_text,
start,
queued,
@@ -453,6 +460,7 @@ async fn generate_stream(
let stream_token = StreamResponse {
token,
+ top_tokens,
generated_text: Some(output_text),
details
};
@@ -510,6 +518,7 @@ pub async fn run(
compat_return_full_text: bool,
max_concurrent_requests: usize,
max_best_of: usize,
+ max_top_tokens: u32,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
@@ -572,6 +581,7 @@ pub async fn run(
validation_workers,
tokenizer,
max_best_of,
+ max_top_tokens,
max_stop_sequences,
max_input_length,
max_total_tokens,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index be835bf0a07..ca345c0bbec 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -14,6 +14,7 @@ use tracing::{instrument, Span};
pub struct Validation {
/// Validation parameters
max_best_of: usize,
+ max_top_tokens: u32,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
@@ -26,6 +27,7 @@ impl Validation {
workers: usize,
tokenizer: Option<Tokenizer>,
max_best_of: usize,
+ max_top_tokens: u32,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
@@ -53,6 +55,7 @@ impl Validation {
Self {
max_best_of,
sender,
+ max_top_tokens,
max_stop_sequences,
max_input_length,
max_total_tokens,
@@ -130,6 +133,7 @@ impl Validation {
) -> Result<ValidGenerateRequest, ValidationError> {
let GenerateParameters {
best_of,
+ top_tokens,
temperature,
repetition_penalty,
top_k,
@@ -218,6 +222,15 @@ impl Validation {
}
};
+ let top_tokens = top_tokens
+ .map(|value| {
+ if value > self.max_top_tokens {
+ return Err(ValidationError::TopTokens(self.max_top_tokens, value));
+ }
+ Ok(value)
+ })
+ .unwrap_or(Ok(0))?;
+
// Check if inputs is empty
if request.inputs.is_empty() {
return Err(EmptyInput);
@@ -263,6 +276,7 @@ impl Validation {
truncate: truncate.unwrap_or(self.max_input_length) as u32,
parameters,
stopping_parameters,
+ top_tokens,
})
}
@@ -311,7 +325,7 @@ fn prepare_input(
// truncate encoding and decode new inputs
encoding.truncate(truncate, 0, TruncationDirection::Left);
let inputs = tokenizer
- .decode(Vec::from(encoding.get_ids()), false)
+ .decode(&Vec::from(encoding.get_ids()), false)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
(inputs, encoding.len())
}
@@ -336,6 +350,7 @@ pub(crate) struct ValidGenerateRequest {
pub decoder_input_details: bool,
pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters,
+ pub top_tokens: u32,
}
#[derive(Error, Debug)]
@@ -344,6 +359,10 @@ pub enum ValidationError {
BestOf(usize, usize),
#[error("`best_of` != 1 is not allowed for this endpoint")]
BestOfDisabled,
+ #[error("`top_tokens` must be >= 0 and <= {0}. Given: {1}")]
+ TopTokens(u32, u32),
+ #[error("`top_tokens` != 0 is not allowed for this endpoint")]
+ TopTokensDisabled,
#[error("you must use sampling when `best_of` is > 1")]
BestOfSampling,
#[error("`seed` must not be set when `best_of` > 1")]
@@ -390,14 +409,16 @@ mod tests {
async fn test_validation_max_new_tokens() {
let tokenizer = None;
let max_best_of = 2;
- let max_stop_sequence = 3;
- let max_input_length = 4;
- let max_total_tokens = 5;
+ let max_top_tokens = 3;
+ let max_stop_sequence = 4;
+ let max_input_length = 5;
+ let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
+ max_top_tokens,
max_stop_sequence,
max_input_length,
max_total_tokens,
@@ -417,9 +438,10 @@ mod tests {
async fn test_validation_input_length() {
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
- let max_stop_sequence = 3;
- let max_input_length = 4;
- let max_total_tokens = 5;
+ let max_tokens = 3;
+ let max_stop_sequence = 4;
+ let max_input_length = 5;
+ let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
@@ -435,7 +457,7 @@ mod tests {
.validate_input("Hello".to_string(), None, max_new_tokens)
.await
{
- Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
+ Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
_ => panic!("Unexpected not max new tokens"),
}
}
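
The clamp above means a request asking for more candidates than --max-top-tokens is rejected before it ever reaches the model, with the message defined by ValidationError::TopTokens. A quick behaviour sketch (the HTTP status code and exact JSON error shape are assumptions; only the error text comes from this diff):

import requests

resp = requests.post(
    "http://localhost:3000/generate",  # placeholder address
    json={
        "inputs": "Hello",
        # 50 > the default --max-top-tokens of 10, so validation should fail.
        "parameters": {"max_new_tokens": 1, "top_tokens": 50},
    },
)
print(resp.status_code)  # expected: a 4xx validation error
print(resp.json())       # e.g. {"error": "`top_tokens` must be >= 0 and <= 10. Given: 50", ...}
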
diff --git a/server/Makefile b/server/Makefile
index a4ce6d8b7a2..ea6bd0052c4 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -1,5 +1,4 @@
include Makefile-flash-att
-include Makefile-flash-att-v2
include Makefile-vllm
unit-tests:
diff --git a/server/Makefile-flash-att b/server/Makefile-flash-att
index bc1d37ef5e2..d938480e094 100644
--- a/server/Makefile-flash-att
+++ b/server/Makefile-flash-att
@@ -1,16 +1,18 @@
-flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
+flash_attention_commit := v2.0.4
flash-attention:
# Clone flash attention
pip install packaging
- git clone https://github.com/HazyResearch/flash-attention.git
+ git clone https://github.com/Dao-AILab/flash-attention.git
+ cd flash-attention && git fetch && git checkout $(flash_attention_commit)
-build-flash-attention: flash-attention
- cd flash-attention && git fetch && git checkout $(flash_att_commit)
- cd flash-attention && python setup.py build
- cd flash-attention/csrc/rotary && python setup.py build
- cd flash-attention/csrc/layer_norm && python setup.py build
+install-flash-attention: flash-attention
+ pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true
+ cd flash-attention && pip install .
+ cd flash-attention/csrc/layer_norm && pip install .
+ cd flash-attention/csrc/rotary && pip install .
-install-flash-attention: build-flash-attention
- pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true
- cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
\ No newline at end of file
+
+test-flash-attention: flash-attention
+ pip install pytest
+ cd flash-attention && pytest -q -s tests/test_flash_attn.py
\ No newline at end of file
diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2
deleted file mode 100644
index a7d633563d8..00000000000
--- a/server/Makefile-flash-att-v2
+++ /dev/null
@@ -1,13 +0,0 @@
-flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
-
-flash-attention-v2:
- # Clone flash attention
- pip install packaging
- git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
-
-build-flash-attention-v2: flash-attention-v2
- cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
- cd flash-attention-v2 && python setup.py build
-
-install-flash-attention-v2: build-flash-attention-v2
- cd flash-attention-v2 && python setup.py install
\ No newline at end of file
diff --git a/server/Makefile-vllm b/server/Makefile-vllm
index 9100fff4e3c..1621c481571 100644
--- a/server/Makefile-vllm
+++ b/server/Makefile-vllm
@@ -1,13 +1,14 @@
-vllm_commit := 084ca75d4271f8f67be731bc58e0d41d8e0afd3a
+vllm_commit := "3d40c83"
vllm:
# Clone vllm
- git clone https://github.com/OlivierDehaene/vllm.git
-
-build-vllm: vllm
+ git clone https://github.com/vllm-project/vllm.git
cd vllm && git fetch && git checkout $(vllm_commit)
- cd vllm && python setup.py build
-install-vllm: build-vllm
+install-vllm: vllm
pip uninstall vllm -y || true
- cd vllm && python setup.py install
\ No newline at end of file
+ cd vllm && pip install .
+
+test-vllm: vllm
+ pip install pytest
+ cd vllm && pytest -q -s tests/kernels/test_attention.py
\ No newline at end of file
diff --git a/server/pyproject.toml b/server/pyproject.toml
index 3ee3351c6e0..28d1ea3bc12 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -15,21 +15,22 @@ grpcio-status = "^1.51.1"
grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.0"
typer = "^0.6.1"
-accelerate = { version = "^0.19.0", optional = true }
-bitsandbytes = { version = "^0.38.1", optional = true }
-safetensors = "0.3.1"
+accelerate = { version = "^0.20.0", optional = true }
+bitsandbytes = { version = "^0.41.1", optional = true }
+safetensors = "^0.4.0"
loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
-tokenizers = "0.13.3"
-huggingface-hub = "^0.14.1"
-transformers = "4.29.2"
+tokenizers = "^0.13.3"
+huggingface-hub = "^0.16.4"
+transformers = "^4.32.2"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }
+scipy = "^1.11.1"
[tool.poetry.extras]
accelerate = ["accelerate"]
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index e74c03311c6..06e0862c6a7 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -161,20 +161,24 @@ def download_weights(
for p in local_pt_files
]
try:
- from transformers import AutoConfig
import transformers
+ import json
+ from huggingface_hub import hf_hub_download
- config = AutoConfig.from_pretrained(
- model_id,
- revision=revision,
- )
- architecture = config.architectures[0]
+ logger.info(f"is_local_model: {is_local_model}")
+ if is_local_model:
+ config_filename = os.path.join(model_id, "config.json")
+ else:
+ config_filename = hf_hub_download(model_id, revision=revision, filename="config.json")
+
+ with open(config_filename, "r") as f:
+ config = json.load(f)
+ architecture = config["architectures"][0]
class_ = getattr(transformers, architecture)
# Name for this variable depends on transformers version.
discard_names = getattr(class_, "_tied_weights_keys", [])
- discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []))
except Exception as e:
discard_names = []
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index cbdf480837c..7e3ea652418 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -12,9 +12,11 @@
PrefillTokens,
Generation,
GeneratedText,
+ TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils.tokens import batch_top_tokens
tracer = trace.get_tracer(__name__)
@@ -42,6 +44,8 @@ class CausalLMBatch(Batch):
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
+ top_tokens: List[int]
+ top_tokens_tensor: torch.Tensor
# Metadata used for padding
max_input_length: int
@@ -72,6 +76,7 @@ def from_pb(
inputs = []
next_token_choosers = []
stopping_criterias = []
+ top_tokens = []
prefix_offsets = []
read_offsets = []
requests_idx_mapping = {}
@@ -88,6 +93,7 @@ def from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
+ top_tokens.append(r.top_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
@@ -123,6 +129,9 @@ def from_pb(
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
+ top_tokens_tensor = torch.tensor(
+ top_tokens, device=device, dtype=torch.int64
+ )
return cls(
batch_id=pb.id,
@@ -138,6 +147,8 @@ def from_pb(
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
+ top_tokens=top_tokens,
+ top_tokens_tensor=top_tokens_tensor,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
@@ -163,6 +174,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
next_token_choosers = []
stopping_criterias = []
+ top_tokens = []
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
@@ -184,6 +196,8 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
+ top_tokens.append(self.top_tokens[idx])
+
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@@ -223,6 +237,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
del past_values
+ top_tokens_tensor = self.top_tokens_tensor[keep_indices]
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
self.requests = requests
@@ -235,6 +250,8 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
+ self.top_tokens = top_tokens
+ self.top_tokens_tensor = top_tokens_tensor
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
@@ -262,6 +279,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
+ top_tokens = []
max_tokens = 0
# Batch tensors
@@ -281,6 +299,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
+ top_tokens.extend(batch.top_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
@@ -310,6 +329,12 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
(total_batch_size, max_input_length + padding_right_offset),
)
+ if top_tokens_tensor is None:
+ top_tokens_tensor = batches[0].top_tokens_tensor.new_zeros(
+ total_batch_size,
+ )
+ top_tokens_tensor[start_index:end_index] = batch.top_tokens_tensor
+
# We need to slice the attention mask to remove padding from previous steps
# and to remove unused allocated space
left_offset = max_input_length - batch.max_input_length
@@ -438,6 +463,8 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
+ top_tokens=top_tokens,
+ top_tokens_tensor=top_tokens_tensor,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
@@ -545,6 +572,12 @@ def generate_token(
batch.past_key_values,
)
+ batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+ batch.top_tokens,
+ batch.top_tokens_tensor,
+ torch.softmax(logits[:, -1], -1),
+ )
+
# Results
generations: List[Generation] = []
stopped = True
@@ -559,6 +592,9 @@ def generate_token(
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
+ batch.top_tokens,
+ batch_top_token_ids,
+ batch_top_token_logprobs,
)
# For each member of the batch
@@ -571,6 +607,9 @@ def generate_token(
next_token_chooser,
stopping_criteria,
all_input_ids,
+ top_tokens,
+ top_token_ids,
+ top_token_logprobs,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
@@ -637,6 +676,24 @@ def generate_token(
else:
prefill_tokens = None
+ if top_tokens > 0:
+ toptoken_texts = self.tokenizer.batch_decode(
+ top_token_ids,
+ clean_up_tokenization_spaces=False,
+ skip_special_tokens=False,
+ )
+ special_toptokens = [
+ token_id in self.all_special_ids for token_id in top_token_ids
+ ]
+ top_tokens_obj = TopTokens(
+ top_token_ids,
+ top_token_logprobs,
+ toptoken_texts,
+ special_toptokens,
+ )
+ else:
+ top_tokens_obj = None
+
generation = Generation(
request.id,
prefill_tokens,
@@ -645,6 +702,7 @@ def generate_token(
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
+ top_tokens_obj,
)
generations.append(generation)
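
generate_token above relies on batch_top_tokens from text_generation_server.utils.tokens, whose implementation is not part of this excerpt. A minimal sketch of what such a helper could look like, assuming it receives the per-request top_tokens counts, their tensor copy, and the [batch_size, vocab_size] probabilities passed at the call site, and returns per-request lists of candidate ids and logprobs (empty when top_tokens is 0):

from typing import List, Tuple
import torch

def batch_top_tokens(
    top_tokens: List[int],            # requested candidates per request (0 = disabled)
    top_tokens_tensor: torch.Tensor,  # same values, already on the model device
    probs: torch.Tensor,              # [batch_size, vocab_size] softmaxed next-token probabilities
) -> Tuple[List[List[int]], List[List[float]]]:
    max_k = int(top_tokens_tensor.max().item())
    if max_k == 0:
        return [[] for _ in top_tokens], [[] for _ in top_tokens]
    # One batched top-k at the largest requested k, then trim per request.
    top_probs, top_ids = probs.topk(k=max_k, dim=-1)
    top_logprobs = top_probs.log()
    ids: List[List[int]] = []
    logprobs: List[List[float]] = []
    for i, k in enumerate(top_tokens):
        ids.append(top_ids[i, :k].tolist())
        logprobs.append(top_logprobs[i, :k].tolist())
    return ids, logprobs
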
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index b6285856fc3..d528178259e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -30,8 +30,8 @@
import dropout_layer_norm
# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
+import vllm.cache_ops
+import vllm.attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
@@ -185,10 +185,12 @@ def __init__(
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
- self.rotary_emb = PositionRotaryEmbedding.load(
- prefix=f"{prefix}.rotary_emb", weights=weights
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=self.head_size,
+ base=10000.0,
+ device=weights.device
)
-
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
@@ -247,7 +249,7 @@ def forward(
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
- vllm_cache_ops.reshape_and_cache(
+ vllm.cache_ops.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
@@ -270,7 +272,7 @@ def forward(
else:
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
block_size = kv_cache[1].shape[3]
- vllm_attention_ops.single_query_cached_kv_attention(
+ vllm.attention_ops.paged_attention_v1(
attn_output,
query,
kv_cache[0],
@@ -281,6 +283,7 @@ def forward(
input_lengths,
block_size,
max_s,
+ None
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -383,6 +386,7 @@ def forward(
class FlashLlamaModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
+ self.config = config
process_group = weights.process_group
self.tp_rank = process_group.rank()
@@ -485,4 +489,4 @@ def forward(
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
- return logits
+ return logits
\ No newline at end of file
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index e7c8ced4ca7..6d524d308f0 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -28,8 +28,8 @@
from typing import Optional, List, Tuple
# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
+import vllm.cache_ops
+import vllm.attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
@@ -141,7 +141,7 @@ def forward(
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
- vllm_cache_ops.reshape_and_cache(
+ vllm.cache_ops.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
)
@@ -164,7 +164,7 @@ def forward(
else:
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
block_size = kv_cache[1].shape[3]
- vllm_attention_ops.single_query_cached_kv_attention(
+ vllm.attention_ops.paged_attention_v1(
attn_output,
qkv[:, 0],
kv_cache[0],
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 3570b283e59..fc731a95cb5 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -7,8 +7,8 @@
from typing import Optional, List, Tuple
# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
+import vllm.cache_ops
+import vllm.attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
@@ -191,7 +191,7 @@ def forward(
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
- vllm_cache_ops.reshape_and_cache(
+ vllm.cache_ops.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
@@ -214,7 +214,7 @@ def forward(
else:
# kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
block_size = kv_cache[1].shape[3]
- vllm_attention_ops.single_query_cached_kv_attention(
+ vllm.attention_ops.paged_attention_v1(
attn_output,
query,
kv_cache[0],
@@ -307,7 +307,7 @@ def forward(
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
- vllm_cache_ops.reshape_and_cache(
+ vllm.cache_ops.reshape_and_cache(
kv[:, :, 0].contiguous(),
kv[:, :, 1].contiguous(),
kv_cache[0],
@@ -334,7 +334,7 @@ def forward(
else:
# kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
block_size = kv_cache[1].shape[3]
- vllm_attention_ops.single_query_cached_kv_attention(
+ vllm.attention_ops.paged_attention_v1(
attn_output,
query,
kv_cache[0],
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 2dd0a5ee4e0..4b587752d7e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -6,8 +6,8 @@
from typing import Optional, List, Tuple
# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
+import vllm.cache_ops
+import vllm.attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
@@ -258,7 +258,7 @@ def forward(
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)
- vllm_cache_ops.reshape_and_cache(
+ vllm.cache_ops.reshape_and_cache(
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
)
@@ -281,7 +281,7 @@ def forward(
else:
# kv_cache[1] => [num_blocks, 1, head_size, block_size]
block_size = kv_cache[1].shape[3]
- vllm_attention_ops.single_query_cached_kv_attention(
+ vllm.attention_ops.paged_attention_v1(
attn_output,
query,
kv_cache[0],
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 7de51358c8d..c8fe1061460 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -16,10 +16,12 @@
PrefillTokens,
Generation,
GeneratedText,
+ TopTokens
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
+from text_generation_server.utils.tokens import batch_top_tokens
tracer = trace.get_tracer(__name__)
@@ -165,6 +167,8 @@ class FlashCausalLMBatch(Batch):
# Generation helpers
next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria]
+ top_tokens: List[int]
+ top_tokens_tensor: torch.Tensor
# Number of blocks in this batch
blocks: int
@@ -217,6 +221,7 @@ def from_pb(
next_token_chooser_parameters = []
stopping_criterias = []
+ top_tokens = []
# Cumulative length
cumulative_length = 0
@@ -259,6 +264,7 @@ def from_pb(
)
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)
+ top_tokens.append(r.top_tokens)
# Paged attention
# Remove one as the first token des not have a past
@@ -353,6 +359,10 @@ def from_pb(
prefill_next_token_indices, dtype=torch.int64, device=device
)
+ top_tokens_tensor = torch.tensor(
+ top_tokens, device=device, dtype=torch.int64
+ )
+
return cls(
batch_id=pb.id,
requests=pb.requests,
@@ -378,6 +388,8 @@ def from_pb(
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
+ top_tokens=top_tokens,
+ top_tokens_tensor=top_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
)
@@ -417,6 +429,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
read_offsets = []
stopping_criterias = []
+ top_tokens = []
blocks = 0
max_blocks = 0
@@ -443,6 +456,8 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
+ top_tokens.append(self.top_tokens[idx])
+
remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@@ -487,6 +502,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices)
+ top_tokens_tensor = self.top_tokens_tensor[indices]
start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -518,6 +534,8 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
+ top_tokens=top_tokens,
+ top_tokens_tensor=top_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
)
@@ -567,6 +585,10 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
(total_batch_size, max_length)
)
+ top_tokens_tensor = batches[0].top_tokens_tensor.new_zeros(
+ total_batch_size,
+ )
+
start_slots = []
block_tables = []
all_input_ids = []
@@ -577,6 +599,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
next_token_chooser_parameters = []
stopping_criterias = []
+ top_tokens = []
# Cumulative length
cumulative_batch_size = 0
@@ -601,6 +624,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
+ top_tokens_tensor[start_index:end_index] = batch.top_tokens_tensor
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
slots[slots_start_index:slots_end_index] = batch.slots
@@ -623,6 +647,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
stopping_criterias.extend(batch.stopping_criterias)
+ top_tokens.extend(batch.top_tokens)
# Update
cumulative_batch_size += len(batch)
@@ -666,6 +691,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
+ top_tokens=top_tokens,
+ top_tokens_tensor=top_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
)
@@ -831,10 +858,14 @@ def generate_token(
else:
next_token_logits = out
- next_input_ids, next_token_logprobs = batch.next_token_chooser(
+ next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
)
+ batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+ batch.top_tokens, batch.top_tokens_tensor, logprobs
+ )
+
if prefill:
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -931,8 +962,11 @@ def generate_token(
batch.all_input_ids,
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
+ batch.top_tokens,
next_token_ids,
next_token_logprobs,
+ batch_top_token_ids,
+ batch_top_token_logprobs,
)
# For each member of the batch
@@ -945,8 +979,11 @@ def generate_token(
all_input_ids,
do_sample,
seed,
+ top_tokens,
next_token_id,
next_token_logprob,
+ top_token_ids,
+ top_token_logprobs,
) in enumerate(iterator):
# Append next token to all tokens
all_input_ids.append(next_token_id)
@@ -1005,6 +1042,24 @@ def generate_token(
else:
prefill_tokens = None
+ if top_tokens > 0:
+ toptoken_texts = self.tokenizer.batch_decode(
+ top_token_ids,
+ clean_up_tokenization_spaces=False,
+ skip_special_tokens=False,
+ )
+ special_toptokens = [
+ token_id in self.all_special_ids for token_id in top_token_ids
+ ]
+ top_tokens_obj = TopTokens(
+ top_token_ids,
+ top_token_logprobs,
+ toptoken_texts,
+ special_toptokens,
+ )
+ else:
+ top_tokens_obj = None
+
generation = Generation(
request.id,
prefill_tokens,
@@ -1013,6 +1068,7 @@ def generate_token(
next_token_text,
next_token_id in self.all_special_ids,
generated_text,
+ top_tokens_obj,
)
generations.append(generation)
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 96fb0c266cb..f725a627c6b 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -2,7 +2,8 @@
import torch.distributed
from opentelemetry import trace
-from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
+from transformers import AutoTokenizer
+from transformers.models.llama import LlamaTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
@@ -44,7 +45,8 @@ def __init__(
trust_remote_code=trust_remote_code,
)
except Exception:
- tokenizer = LlamaTokenizerFast.from_pretrained(
+ # fall back to AutoTokenizer in case it's a custom tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 9e5c21d1542..5422236d4ed 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -11,9 +11,11 @@
Batch,
Generation,
PrefillTokens,
+ TopTokens
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils.tokens import batch_top_tokens
tracer = trace.get_tracer(__name__)
@@ -51,6 +53,8 @@ class Seq2SeqLMBatch(Batch):
# Metadata used for padding
max_input_length: int
+ top_tokens: List[int]
+ top_tokens_tensor: torch.Tensor
max_decoder_input_length: int
padding_right_offset: int
@@ -78,6 +82,7 @@ def from_pb(
inputs = []
next_token_choosers = []
stopping_criterias = []
+ top_tokens = []
decoder_input_lengths = []
prefix_offsets = []
@@ -97,6 +102,7 @@ def from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
+ top_tokens.append(r.top_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
@@ -127,6 +133,9 @@ def from_pb(
read_offsets.append(1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
+ top_tokens_tensor = torch.tensor(
+ top_tokens, device=device, dtype=torch.int64
+ )
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls(
@@ -146,6 +155,8 @@ def from_pb(
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
+ top_tokens=top_tokens,
+ top_tokens_tensor=top_tokens_tensor,
max_input_length=max_input_length.item(),
max_decoder_input_length=1,
padding_right_offset=padding_right_offset,
@@ -173,6 +184,7 @@ def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
next_token_choosers = []
stopping_criterias = []
+ top_tokens = []
max_input_length = 0
max_decoder_input_length = 0
@@ -204,6 +216,7 @@ def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
+ top_tokens.append(self.top_tokens[idx])
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@@ -239,6 +252,7 @@ def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
layer[2] = layer[2][keep_indices, :, -max_input_length:]
layer[3] = layer[3][keep_indices, :, -max_input_length:]
+ top_tokens_tensor = self.top_tokens_tensor[keep_indices]
max_tokens = (
len(request_ids) * (max_input_length + max_decoder_input_length)
+ remaining_decode_tokens
@@ -254,6 +268,8 @@ def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
+ self.top_tokens = top_tokens
+ self.top_tokens_tensor = top_tokens_tensor
self.max_input_length = max_input_length
self.max_decoder_input_length = max_decoder_input_length
self.padding_right_offset = padding_right_offset
@@ -289,6 +305,7 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
read_offsets = []
next_token_choosers = []
stopping_criterias = []
+ top_tokens = []
max_tokens = 0
# Batch tensors
@@ -312,6 +329,7 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
read_offsets.extend(batch.read_offsets)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
+ top_tokens.extend(batch.top_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
@@ -384,6 +402,12 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
),
)
+ if top_tokens_tensor is None:
+ top_tokens_tensor = batches[0].top_tokens_tensor.new_zeros(
+ total_batch_size,
+ )
+ top_tokens_tensor[start_index:end_index] = batch.top_tokens_tensor
+
# Copy to correct indices
encoder_last_hidden_state[
start_index:end_index, -batch.max_input_length :, :
@@ -488,6 +512,8 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
+ top_tokens=top_tokens,
+ top_tokens_tensor=top_tokens_tensor,
max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length,
padding_right_offset=padding_right_offset,
@@ -613,6 +639,12 @@ def generate_token(
batch.past_key_values,
)
+ batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+ batch.top_tokens,
+ batch.top_tokens_tensor,
+ torch.softmax(logits[:, -1], -1),
+ )
+
# Finished requests
generations: List[Generation] = []
stopped = True
@@ -628,6 +660,9 @@ def generate_token(
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_decoder_input_ids,
+ batch.top_tokens,
+ batch_top_token_ids,
+ batch_top_token_logprobs,
)
# For each member of the batch
@@ -641,6 +676,9 @@ def generate_token(
next_token_chooser,
stopping_criteria,
all_decoder_input_ids,
+ top_tokens,
+ top_token_ids,
+ top_token_logprobs,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
@@ -698,6 +736,24 @@ def generate_token(
else:
prefill_tokens = None
+ if top_tokens > 0:
+ toptoken_texts = self.tokenizer.batch_decode(
+ top_token_ids,
+ clean_up_tokenization_spaces=False,
+ skip_special_tokens=False,
+ )
+ special_toptokens = [
+ token_id in self.all_special_ids for token_id in top_token_ids
+ ]
+ top_tokens_obj = TopTokens(
+ top_token_ids,
+ top_token_logprobs,
+ toptoken_texts,
+ special_toptokens,
+ )
+ else:
+ top_tokens_obj = None
+
generation = Generation(
request.id,
prefill_tokens,
@@ -706,6 +762,7 @@ def generate_token(
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
+ top_tokens_obj,
)
generations.append(generation)
diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index 28ca8147eb9..98ad95c2859 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -71,6 +71,24 @@ def __len__(self):
return len(self.token_ids)
+@dataclass
+class TopTokens:
+ token_ids: List[int]
+ logprobs: List[float]
+ texts: List[str]
+ is_special: List[bool]
+
+ def to_pb(self) -> generate_pb2.TopTokens:
+ return generate_pb2.TopTokens(
+ ids=self.token_ids,
+ logprobs=self.logprobs,
+ texts=self.texts,
+ is_special=self.is_special,
+ )
+
+ def __len__(self):
+ return len(self.token_ids)
+
@dataclass
class Generation:
request_id: int
@@ -80,6 +98,8 @@ class Generation:
token_text: str
token_is_special: bool
generated_text: Optional[GeneratedText]
+ # Optional for now, since it's not yet supported for every model.
+ top_tokens: Optional[TopTokens]
def to_pb(self) -> generate_pb2.Generation:
return generate_pb2.Generation(
@@ -94,4 +114,5 @@ def to_pb(self) -> generate_pb2.Generation:
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
+ top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
)
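
Since top_tokens on Generation is optional for now, models that do not produce candidates pass None and to_pb() simply leaves the protobuf field unset. A small illustration against the generated message classes:

from text_generation_server.pb import generate_pb2

# Passing None for a message field leaves it unset.
gen = generate_pb2.Generation(request_id=0, top_tokens=None)
assert not gen.HasField("top_tokens")

top = generate_pb2.TopTokens(ids=[5684], logprobs=[-2.54], texts=["among"], is_special=[False])
gen = generate_pb2.Generation(request_id=0, top_tokens=top)
assert gen.HasField("top_tokens")
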
diff --git a/server/text_generation_server/utils/convert.py b/server/text_generation_server/utils/convert.py
index 8d414ecac91..0b62f520836 100644
--- a/server/text_generation_server/utils/convert.py
+++ b/server/text_generation_server/utils/convert.py
@@ -29,9 +29,15 @@ def _remove_duplicate_names(
[name for name in shared if _is_complete(state_dict[name])]
)
if not complete_names:
- raise RuntimeError(
- f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
- )
+ if len(shared) == 1:
+ # Force contiguous
+ name = list(shared)[0]
+ state_dict[name] = state_dict[name].clone()
+ complete_names = {name}
+ else:
+ raise RuntimeError(
+ f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
+ )
keep_name = sorted(list(complete_names))[0]
diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index c472d1fceab..6c8e1ca2cae 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -15,39 +15,22 @@
is_sm90 = major == 9 and minor == 0
HAS_FLASH_ATTN = False
-HAS_FLASH_ATTN_V2 = False
try:
- try:
- import flash_attn_2_cuda
- except ImportError:
- raise ImportError(
- "Flash Attention V2 is not installed.\n"
- "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
- "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
- )
- if not (is_sm8x or is_sm90):
- raise ImportError(
- f"GPU with CUDA capability {major} {minor} is not supported for "
- "Flash Attention V2"
- )
- HAS_FLASH_ATTN_V2 = True
-except ImportError as e:
- try:
- import flash_attn_cuda
- except ImportError:
- raise ImportError(
- "Flash Attention is not installed.\n"
- "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
- "or install flash attention with `cd server && make install install-flash-attention`"
- ) from e
-
- if not (is_sm75 or is_sm8x or is_sm90):
- raise ImportError(
- f"GPU with CUDA capability {major} {minor} is not supported"
- ) from e
- logger.warning(f"Unable to use Flash Attention V2: {e}")
+ import flash_attn_2_cuda as flash_attn_cuda
+ import flash_attn
HAS_FLASH_ATTN = True
-
+except ImportError as e:
+ raise ImportError(
+ f"Flash Attention V2 is not installed.\n"
+ f"Error message: {e}\n"
+ "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+ "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
+ )
+if not (is_sm8x or is_sm90):
+ raise ImportError(
+ f"GPU with CUDA capability {major} {minor} is not supported for "
+ "Flash Attention V2"
+ )
def attention(
q,
@@ -58,52 +41,8 @@ def attention(
max_s,
softmax_scale,
):
- if HAS_FLASH_ATTN_V2:
- return flash_attn_2_cuda.varlen_fwd(
- q,
- k,
- v,
- out,
- cu_seqlens,
- cu_seqlens,
- max_s,
- max_s,
- 0.0,
- softmax_scale,
- False,
- True,
- False,
- None,
- )
-
if HAS_FLASH_ATTN:
- # Flash attention v1 requires q, k and v to have the same number of heads
- if k.shape[1] != q.shape[1]:
- # MQA expand
- if k.shape[1] == 1:
- k = k.expand(-1, q.shape[1], -1)
- # Grouped attention reshape
- else:
- original_shape = k.shape
- k = (
- k.unsqueeze(2)
- .expand(-1, -1, q.shape[1] // k.shape[1], -1)
- .reshape(original_shape[0], -1, original_shape[2])
- )
- if v.shape[1] != q.shape[1]:
- # MQA expand
- if v.shape[1] == 1:
- v = v.expand(-1, q.shape[1], -1)
- # Grouped attention reshape
- else:
- original_shape = v.shape
- v = (
- v.unsqueeze(2)
- .expand(-1, -1, q.shape[1] // v.shape[1], -1)
- .reshape(original_shape[0], -1, original_shape[2])
- )
-
- return flash_attn_cuda.fwd(
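+ # call the Flash Attention V2 variable-length kernel directly (dropout disabled, causal masking enabled)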
+ return flash_attn_cuda.varlen_fwd(
q,
k,
v,
@@ -117,7 +56,6 @@ def attention(
False,
True,
False,
- 0,
None,
)
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 7a45808ec92..bae70f755ee 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -375,35 +375,68 @@ def forward(self, hidden_states, residual=None):
try:
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
+
+ def _create_inv_freq(dim, base, device):
+ inv_freq = 1.0 / (
+ base
+ ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+ )
+ return inv_freq
+
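+ # the ROPE_SCALING / ROPE_FACTOR environment variables take precedence over the model's rope_scaling config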
+ def _get_rope_config(config):
+ if os.getenv("ROPE_SCALING", None) is not None:
+ rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
+ return rope_scaling
+ return getattr(config, "rope_scaling", None)
+
class PositionRotaryEmbedding(nn.Module):
- def __init__(self, inv_freq):
+ def __init__(self, inv_freq, scaling_factor):
super().__init__()
-
self.inv_freq = inv_freq
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
+ self.scaling_factor = scaling_factor
+ self.dynamic_args = None
@classmethod
- def static(cls, dim, base, device):
- inv_freq = 1.0 / (
- base
- ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
- )
- return cls(inv_freq)
+ def static(cls, config, dim, base, device):
+ inv_freq = _create_inv_freq(dim, base, device)
+ scaling_factor = None
+ rope_scaling = _get_rope_config(config)
+ if rope_scaling is not None:
+ scaling_factor = rope_scaling["factor"]
+ if rope_scaling["type"] == "linear":
+ pass
+ elif rope_scaling["type"] == "dynamic":
+ return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor)
+ else:
+ raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+ return cls(inv_freq, scaling_factor)
@classmethod
- def load(cls, prefix, weights):
+ def load(cls, config, prefix, weights):
# XXX: Always load this in float32 !
dtype = weights.dtype
weights.dtype = torch.float32
inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
weights.dtype = dtype
- return cls(inv_freq)
+ scaling_factor = None
+ rope_scaling = _get_rope_config(config)
+ if rope_scaling is not None:
+ scaling_factor = rope_scaling["factor"]
+ if rope_scaling["type"] == "linear":
+ pass
+ elif rope_scaling["type"] == "dynamic":
+ return DynamicPositionRotaryEmbedding(dim=2*inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor)
+ else:
+ raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+ return cls(inv_freq, scaling_factor)
+
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
@@ -441,5 +474,35 @@ def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
return x
+ class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
+ def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
+ inv_freq = _create_inv_freq(dim, base, device)
+ super().__init__(inv_freq, scaling_factor)
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+
+ def _update_cos_sin_cache(self, dtype, device, seqlen):
+ # Reset the tables if the sequence length has changed,
+ # or if we're on a new device (possibly due to tracing for instance)
+ if (
+ seqlen > self._seq_len_cached
+ or self._cos_cached.device != device
+ or self._cos_cached.dtype != dtype
+ ):
+ if seqlen > self.max_position_embeddings:
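+ # dynamic NTK scaling: grow the rope base once the sequence exceeds the trained context length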
+ newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
+ self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
+ self._seq_len_cached = seqlen
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+ if self.scaling_factor is not None:
+ t /= self.scaling_factor
+ # Don't do einsum, it converts fp32 to fp16
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+ freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+ self._cos_cached = torch.cos(freqs).to(dtype)
+ self._sin_cached = torch.sin(freqs).to(dtype)
+
except ImportError:
pass
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index b83af59150f..5b8400f2136 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -229,11 +229,11 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
scores = warper(input_ids, scores)
next_ids = self.choice(scores)
- next_logprobs = torch.gather(
- torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
- ).view(-1)
+ logprobs = torch.log_softmax(scores, -1)
+ next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
+ return next_ids, next_logprobs, logprobs
+
- return next_ids, next_logprobs
def filter(self, indices):
if self.watermark_processor is not None:
@@ -339,3 +339,59 @@ def filter(self, indices):
self.greedy_indices = new_greedy_indices
self.sampling_mapping = new_sampling_mapping
return self
+
+
+def batch_top_tokens(
+ top_tokens: list[int], top_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
+) -> Tuple[List[List[int]], List[List[float]]]:
+ """Find the top n most likely tokens for a batch of generations.
+ When multiple tokens tie in probability and do not all fit within the requested
+ top n, the tied tokens are returned as well.
+
+ Basically copied from HF's original repo to save some time
+
+ Args:
+ top_tokens: List specifying the number of top tokens to retrieve for each item in the batch.
+ top_tokens_tensor: Torch tensor equivalent of top_tokens for use in tensor operations.
+ logprobs: Torch tensor of log probabilities, shape (batch_size, vocab_size).
+
+ Returns:
+ A tuple containing two lists:
+ 1. The indices of the top tokens for each logprob tensor in the batch.
+ 2. The values of the top tokens for each logprob tensor in the batch.
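+
+ Example (illustrative): with top_tokens=[2, 0] and logprobs of shape (2, vocab_size),
+ the first request gets its top-2 token ids and logprobs (plus any ties at the cutoff)
+ and the second request gets empty lists.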
+ """
+ max_top_n = max(top_tokens)
+ # Early exit when top_tokens is not used
+ if max_top_n == 0:
+ return [[]] * len(top_tokens), [[]] * len(top_tokens)
+
+ # Ensure top_n doesn't exceed vocab size
+ top_tokens = [min(tok, logprobs.size(-1)) for tok in top_tokens]
+
+ # From https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
+ # Sorted topk is faster than torch.sort() since we only need a small subset
+ sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
+ nth_highest = torch.gather(
+ sorted_top_k, 1, (top_tokens_tensor - 1).clip(min=0).unsqueeze(1)
+ )
+ nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
+
+ # Find the new "fuzzy" top n values
+ top_n_indices = (logprobs >= nth_highest).nonzero()
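+ # count, per batch row, how many tokens score at least as high as the requested n-th highest logprob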
+ _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
+
+ top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
+ top_n_ishes = top_n_ishes.tolist()
+ top_indices = top_k.indices.tolist()
+ top_values = top_k.values.tolist()
+
+ return (
+ [
+ idxs[:n] if req_n > 0 else []
+ for idxs, n, req_n in zip(top_indices, top_n_ishes, top_tokens)
+ ],
+ [
+ vals[:n] if req_n > 0 else []
+ for vals, n, req_n in zip(top_values, top_n_ishes, top_tokens)
+ ],
+ )
\ No newline at end of file
diff --git a/server/vllm_testscript.py b/server/vllm_testscript.py
new file mode 100644
index 00000000000..eedbb75aca9
--- /dev/null
+++ b/server/vllm_testscript.py
@@ -0,0 +1,22 @@
+# Smoke test: check that vLLM works correctly and time a small generation
+import vllm
+import time
+
+prompts = [
+ 'Hello, my name is',
+ 'CMU\'s PhD students are',
+]
+sampling_params = vllm.SamplingParams(temperature=0.8, top_p=0.95)
+
+llm = vllm.LLM(model="lmsys/vicuna-7b-v1.5")
+
+# time the generation
+start = time.time()
+outputs = llm.generate(prompts, sampling_params)
+end = time.time()
+for output in outputs:
+ prompt = output.prompt
+ generated = output.outputs[0].text
+ print(f'Prompt: {prompt!r}, Generated: {generated!r}')
+print()
+print(f'Time taken: {end - start:.2f}s')
\ No newline at end of file
diff --git a/setup_scripts/conda_client.sh b/setup_scripts/conda_client.sh
new file mode 100644
index 00000000000..8bc3bee87d1
--- /dev/null
+++ b/setup_scripts/conda_client.sh
@@ -0,0 +1,34 @@
+ENV_NAME=tgi-env-client
+# get the directory of this script, and go one up to get the root directory
+DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+DIR="$(dirname "$DIR")"
+N_THREADS=8
+INSTALL_CHATUI=true
+
+set -eo pipefail
+
+# check if CONDA_HOME is set and create environment
+if [ -z "$CONDA_HOME" ]
+then
+ echo "Please set CONDA_HOME to the location of your conda installation"
+ exit 1
+fi
+
+source ${CONDA_HOME}/etc/profile.d/conda.sh
+conda create -y -n ${ENV_NAME} python=3.9
+conda activate ${ENV_NAME}
+conda install -y -c conda-forge mamba
+
+# install client
+cd ${DIR}/clients/python
+pip install .
+
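+# print the relevant environment paths for debugging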
+echo $PATH
+echo $LD_LIBRARY_PATH
+
+if [ "$INSTALL_CHATUI" = true ] ; then
+ # install chat-ui
+ cd ${DIR}/chat-ui
+ mamba install -y -c conda-forge mongodb pymongo "nodejs>=18"
+ npm install
+fi
\ No newline at end of file
diff --git a/setup_scripts/conda_server.sh b/setup_scripts/conda_server.sh
new file mode 100644
index 00000000000..2da4e053e16
--- /dev/null
+++ b/setup_scripts/conda_server.sh
@@ -0,0 +1,208 @@
+#!/bin/zsh
+# Script for setting up a conda environment for launching servers
+# It sidesteps system-wide installations by relying on conda for most packages
+# and by building openssl from source
+# TODO: only got it to work with a static build of OpenSSL, which is not ideal
+# get the directory of this script, and go one up to get the root directory
+DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+DIR="$(dirname "$DIR")"
+
+# currently we can only build without extensions on TIR;
+# this seems unimportant, as the extensions only affect BLOOM/NEOX
+ENV_NAME=tgi-env
+BUILD_EXTENSIONS=false
+BUILD_VLLM=true
+BUILD_FLASHATTN=true
+TEST_EXTRA=true
+BENCHMARK=false
+SERVER_WAIT=180
+N_THREADS=4
+
+# Parse command line arguments
+while (( "$#" )); do
+ case "$1" in
+ --env-name)
+ ENV_NAME=$2
+ shift 2
+ ;;
+ --build-extensions)
+ BUILD_EXTENSIONS=true
+ shift 1
+ ;;
+ --light-mode)
+ BUILD_VLLM=false
+ BUILD_FLASHATTN=false
+ shift 1
+ ;;
+ --no-tests)
+ TEST_EXTRA=false
+ shift 1
+ ;;
+ --benchmark)
+ BENCHMARK=true
+ shift 1
+ ;;
+ --server-wait)
+ SERVER_WAIT=$2
+ shift 2
+ ;;
+ --n-threads)
+ N_THREADS=$2
+ shift 2
+ ;;
+ --) # end argument parsing
+ shift
+ break
+ ;;
+ -*|--*=) # unsupported flags
+ echo "Error: Unsupported flag $1" >&2
+ exit 1
+ ;;
+ *) # preserve positional arguments
+ PARAMS="$PARAMS $1"
+ shift
+ ;;
+ esac
+done
+# set positional arguments in their proper place
+eval set -- "$PARAMS"
+
+set -eo pipefail
+
+# check if CONDA_PREFIX is set and create environment
+if [ -z "$CONDA_PREFIX" ]
+then
+ echo "(Mini)conda does not seem to be installed, please install it first or set CONDA_PREFIX appropriately"
+ exit 1
+fi
+export CONDA_HOME=$CONDA_PREFIX
+source ${CONDA_HOME}/etc/profile.d/conda.sh
+
+# conda's Python solver can't handle this dependency madness, so switch to mamba (C++)
+conda install -y -c conda-forge mamba
+# we need to add the base path to get mamba to work inside the new environment
+export PATH=${CONDA_HOME}/bin:$PATH
+
+echo "Creating conda environment ${ENV_NAME}..."
+mamba create -y -n ${ENV_NAME} python=3.9
+conda activate ${ENV_NAME}
+
+# Install dependencies
+mamba install -y -c conda-forge coreutils "gxx<12.0"
+mamba install -y -c conda-forge curl git tar
+mamba install -y -c conda-forge "rust>=1.65.0"
+mamba install -y -c conda-forge openssh
+mamba install -y -c "nvidia/label/cuda-11.8.0" cuda-toolkit
+# pin pytorch: pytorch==2.1.0 has a CUDA-related issue that interferes with vllm
+mamba install -y -c pytorch -c nvidia pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8
+
+# bring the conda environment's paths to the front of the search paths
+# (not sure if needed, but added during debugging and kept for now)
+export LD_LIBRARY_PATH=${CONDA_HOME}/envs/${ENV_NAME}/lib:$LD_LIBRARY_PATH
+export PATH=${CONDA_HOME}/envs/${ENV_NAME}/bin:$PATH
+export CUDA_HOME=${CONDA_HOME}/envs/${ENV_NAME}
+
+# add cargo bin
+export PATH=~/.cargo/bin:$PATH
+
+# add protoc
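+# (needed to compile the gRPC protos used by the router and server)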
+export PROTOC_ZIP=protoc-21.12-linux-x86_64.zip
+mkdir -p /tmp/protoc
+mkdir -p ~/local/bin
+mkdir -p ~/local/include
+cd /tmp/protoc
+curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP
+unzip -o $PROTOC_ZIP -d ~/local/ bin/protoc
+unzip -o $PROTOC_ZIP -d ~/local/ 'include/*'
+cd $DIR
+rm -rf /tmp/protoc
+export LD_LIBRARY_PATH=~/local/lib:$LD_LIBRARY_PATH
+export PATH=~/local/bin:$PATH
+
+# download and build openssl
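+# (a local static build lets the rust components link against OpenSSL without system-wide packages)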
+mkdir -p /tmp/openssl
+cd /tmp/openssl
+wget https://www.openssl.org/source/openssl-1.1.1l.tar.gz -O openssl.tar.gz
+tar -xzf openssl.tar.gz
+cd openssl-1.1.1l
+./config --prefix=${DIR}/.openssl --openssldir=${DIR}/.openssl
+make -j $N_THREADS
+make install
+cd $DIR
+rm -rf /tmp/openssl
+export LD_LIBRARY_PATH=${DIR}/.openssl/lib:$LD_LIBRARY_PATH
+export PATH=${DIR}/.openssl/bin:$PATH
+
+# install ninja for faster compilation of CUDA kernels and setup workdir
+pip install ninja
+# export MAX_JOBS to limit ninja's parallelism
+export MAX_JOBS=$N_THREADS
+cd ${DIR}/server
+mkdir -p workdir
+rm -rf workdir/*
+
+# install vllm
+if [ "$BUILD_VLLM" = true ] ; then
+ cp Makefile-vllm workdir/Makefile
+ cd workdir && sleep 1
+ make install-vllm
+
+ cd ${DIR}/server
+ if [ "$TEST_EXTRA" = true ] ; then
+ python3 vllm_testscript.py
+ fi
+ rm -rf workdir/*
+fi
+
+# install base package
+cd ${DIR}
+OPENSSL_DIR=${DIR}/.openssl \
+OPENSSL_LIB_DIR=${DIR}/.openssl/lib \
+OPENSSL_INCLUDE_DIR=${DIR}/.openssl/include \
+BUILD_EXTENSIONS=$BUILD_EXTENSIONS \
+ make install
+
+
+# install flash attention
+if [ "$BUILD_FLASHATTN" = true ] ; then
+ cd ${DIR}/server
+ cp Makefile-flash-att workdir/Makefile
+ cd workdir && sleep 1
+ make install-flash-attention
+ if [ "$TEST_EXTRA" = true ] ; then
+ make test-flash-attention
+ fi
+ cd ${DIR}/server
+fi
+
+rm -rf workdir
+
+# override the installed protobuf with a compatible (<3.21) version
+pip install 'protobuf<3.21'
+
+# install python client
+cd ${DIR}/clients/python
+pip install .
+
+# run an example server and benchmark it
+if [ "$BENCHMARK" = true ] ; then
+ cd ${DIR}
+ # trap signals to avoid an orphaned server process
+ trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT
+ # launch server as background process, checking for errors
+ make run-llama2-benchmark &
+ # sleep to make sure the server has time to boot
+ sleep $SERVER_WAIT
+
+ OPENSSL_DIR=${DIR}/.openssl \
+ OPENSSL_LIB_DIR=${DIR}/.openssl/lib \
+ OPENSSL_INCLUDE_DIR=${DIR}/.openssl/include \
+ make install-benchmark
+ python benchmark/dump_fast_tokenizer.py --tokenizer-name=lmsys/vicuna-7b-v1.5 --output=/tmp/vicuna-7b-v1.5/
+ text-generation-benchmark --tokenizer-name=/tmp/vicuna-7b-v1.5
+fi
+
+# set default conda environment variables
+conda env config vars set LD_LIBRARY_PATH=${LD_LIBRARY_PATH}
+conda env config vars set PATH=${PATH}
+conda env config vars set CUDA_HOME=${CUDA_HOME}