diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml deleted file mode 100644 index 124e6a33ee0..00000000000 --- a/.github/workflows/build.yaml +++ /dev/null @@ -1,246 +0,0 @@ -name: Build and push docker image to internal registry - -on: - workflow_dispatch: - push: - branches: - - 'main' - tags: - - 'v*' - pull_request: - paths: - - ".github/workflows/build.yaml" - - "integration-tests/**" - - "server/**" - - "proto/**" - - "router/**" - - "launcher/**" - - "Cargo.lock" - - "rust-toolchain.toml" - - "Dockerfile" - branches: - - 'main' - -jobs: - start-runner: - name: Start self-hosted EC2 runner - runs-on: ubuntu-latest - env: - AWS_REGION: us-east-1 - EC2_AMI_ID: ami-03cfed9ea28f4b002 - EC2_INSTANCE_TYPE: g5.12xlarge - EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc - EC2_SECURITY_GROUP: sg-030175c435ac141d6 - outputs: - label: ${{ steps.start-ec2-runner.outputs.label }} - ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }} - steps: - - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v1 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: ${{ env.AWS_REGION }} - - name: Start EC2 runner - id: start-ec2-runner - uses: philschmid/philschmid-ec2-github-runner@main - with: - mode: start - github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} - ec2-image-id: ${{ env.EC2_AMI_ID }} - ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }} - subnet-id: ${{ env.EC2_SUBNET_ID }} - security-group-id: ${{ env.EC2_SECURITY_GROUP }} - aws-resource-tags: > # optional, requires additional permissions - [ - {"Key": "Name", "Value": "ec2-tgi-github-runner"}, - {"Key": "GitHubRepository", "Value": "${{ github.repository }}"} - ] - - build-and-push-image: - concurrency: - group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - needs: start-runner # required to start the main job when the runner is ready - runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner - permissions: - contents: write - packages: write - # This is used to complete the identity challenge - # with sigstore/fulcio when running outside of PRs. 
- id-token: write - security-events: write - steps: - - name: Checkout repository - uses: actions/checkout@v3 - - name: Initialize Docker Buildx - uses: docker/setup-buildx-action@v2.0.0 - with: - install: true - - name: Inject slug/short variables - uses: rlespinasse/github-slug-action@v4.4.1 - - name: Install cosign - if: github.event_name != 'pull_request' - uses: sigstore/cosign-installer@f3c664df7af409cb4873aa5068053ba9d61a57b6 #v2.6.0 - with: - cosign-release: 'v1.13.1' - - name: Tailscale - uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966 - with: - authkey: ${{ secrets.TAILSCALE_AUTHKEY }} - - name: Login to GitHub Container Registry - if: github.event_name != 'pull_request' - uses: docker/login-action@v2 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - name: Login to internal Container Registry - uses: docker/login-action@v2.1.0 - with: - username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }} - password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }} - registry: registry.internal.huggingface.tech - - name: Login to Azure Container Registry - if: github.event_name != 'pull_request' - uses: docker/login-action@v2.1.0 - with: - username: ${{ secrets.AZURE_DOCKER_USERNAME }} - password: ${{ secrets.AZURE_DOCKER_PASSWORD }} - registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io - # If pull request - - name: Extract metadata (tags, labels) for Docker - if: ${{ github.event_name == 'pull_request' }} - id: meta-pr - uses: docker/metadata-action@v4.3.0 - with: - images: | - registry.internal.huggingface.tech/api-inference/community/text-generation-inference - tags: | - type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }} - # If main, release or tag - - name: Extract metadata (tags, labels) for Docker - if: ${{ github.event_name != 'pull_request' }} - id: meta - uses: docker/metadata-action@v4.3.0 - with: - flavor: | - latest=auto - images: | - registry.internal.huggingface.tech/api-inference/community/text-generation-inference - ghcr.io/huggingface/text-generation-inference - db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference - tags: | - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} - type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }} - - name: Build and push Docker image - id: build-and-push - uses: docker/build-push-action@v4 - with: - context: . - file: Dockerfile - push: true - platforms: 'linux/amd64' - build-args: | - GIT_SHA=${{ env.GITHUB_SHA }} - DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} - tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} - labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} - cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min - cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min - # Sign the resulting Docker image digest except on PRs. - # This will only write to the public Rekor transparency log when the Docker - # repository is public to avoid leaking data. - - name: Sign the published Docker image - if: ${{ github.event_name != 'pull_request' }} - env: - COSIGN_EXPERIMENTAL: "true" - # This step uses the identity token to provision an ephemeral certificate - # against the sigstore community Fulcio instance. 
- run: echo "${{ steps.meta.outputs.tags }}" | xargs -I {} cosign sign {}@${{ steps.build-and-push.outputs.digest }} - - name: Run Trivy in GitHub SBOM mode and submit results to Dependency Graph - uses: aquasecurity/trivy-action@master - if: ${{ github.event_name != 'pull_request' }} - with: - image-ref: 'ghcr.io/huggingface/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}' - format: 'github' - output: 'dependency-results.sbom.json' - github-pat: ${{ secrets.GITHUB_TOKEN }} - scanners: 'vuln' - - name: Run Trivy vulnerability scanner - uses: aquasecurity/trivy-action@master - if: ${{ github.event_name != 'pull_request' }} - with: - image-ref: 'ghcr.io/huggingface/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}' - format: 'sarif' - output: 'trivy-results.sarif' - severity: 'CRITICAL' - scanners: 'vuln' - - name: Upload Trivy scan results to GitHub Security tab - uses: github/codeql-action/upload-sarif@v2 - if: ${{ github.event_name != 'pull_request' }} - with: - sarif_file: 'trivy-results.sarif' - - integration-tests: - concurrency: - group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - needs: - - start-runner - - build-and-push-image # Wait for the docker image to be built - runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner - env: - DOCKER_VOLUME: /cache - steps: - - uses: actions/checkout@v2 - - name: Inject slug/short variables - uses: rlespinasse/github-slug-action@v4.4.1 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: 3.9 - - name: Tailscale - uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966 - with: - authkey: ${{ secrets.TAILSCALE_AUTHKEY }} - - name: Prepare disks - run: | - sudo mkfs -t ext4 /dev/nvme1n1 - sudo mkdir ${{ env.DOCKER_VOLUME }} - sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }} - - name: Install - run: | - make install-integration-tests - - name: Run tests - run: | - export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }} - export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} - pytest -s -vv integration-tests - - stop-runner: - name: Stop self-hosted EC2 runner - needs: - - start-runner - - build-and-push-image - - integration-tests - runs-on: ubuntu-latest - env: - AWS_REGION: us-east-1 - if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs - steps: - - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v1 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: ${{ env.AWS_REGION }} - - name: Stop EC2 runner - uses: philschmid/philschmid-ec2-github-runner@main - with: - mode: stop - github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} - label: ${{ needs.start-runner.outputs.label }} - ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }} diff --git a/.github/workflows/client-tests.yaml b/.github/workflows/client-tests.yaml deleted file mode 100644 index 1fa0b39d7db..00000000000 --- a/.github/workflows/client-tests.yaml +++ /dev/null @@ -1,25 +0,0 @@ -name: Python Client Tests - -on: - pull_request: - paths: - - ".github/workflows/client-tests.yaml" - - "clients/python/**" - -jobs: - run_tests: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v1 - with: - 
python-version: 3.9 - - name: Install - run: | - cd clients/python && pip install . - - name: Run tests - run: | - pip install pytest pytest-asyncio - make python-client-tests diff --git a/.github/workflows/free_disk_space.sh b/.github/workflows/free_disk_space.sh new file mode 100644 index 00000000000..416884b6a4b --- /dev/null +++ b/.github/workflows/free_disk_space.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# +# The Azure provided machines typically have the following disk allocation: +# Total space: 85GB +# Allocated: 67 GB +# Free: 17 GB +# This script frees up 28 GB of disk space by deleting unneeded packages and +# large directories. +# The Flink end to end tests download and generate more than 17 GB of files, +# causing unpredictable behavior and build failures. +# +echo "==============================================================================" +echo "Freeing up disk space on CI system" +echo "==============================================================================" + +echo "Listing 100 largest packages" +dpkg-query -Wf '${Installed-Size}\t${Package}\n' | sort -n | tail -n 100 +df -h +echo "Removing large packages" +sudo apt-get remove -y '^ghc-8.*' +sudo apt-get remove -y '^dotnet-.*' +sudo apt-get remove -y '^llvm-.*' +sudo apt-get remove -y 'php.*' +sudo apt-get remove -y azure-cli google-cloud-sdk hhvm google-chrome-stable firefox microsoft-edge-stable powershell mono-devel +sudo apt-get remove -y '^gcc-.*' +sudo apt-get remove -y '^g++-.*' +sudo apt-get autoremove -y +sudo apt-get clean +df -h +echo "Removing large directories" +# deleting 15GB +rm -rf /usr/share/dotnet/ +df -h \ No newline at end of file diff --git a/.github/workflows/load_test.yaml b/.github/workflows/load_test.yaml deleted file mode 100644 index fd22e395780..00000000000 --- a/.github/workflows/load_test.yaml +++ /dev/null @@ -1,108 +0,0 @@ -name: Nightly load test - -on: - schedule: - - cron: '0 0 * * 1-5' - - pull_request: - paths: - - ".github/workflows/load_test.yaml" - branches: - - 'main' - -jobs: - start-runner: - name: Start self-hosted EC2 runner - runs-on: ubuntu-latest - env: - AWS_REGION: eu-central-1 - EC2_AMI_ID: ami-0ab09c07cfd194259 - EC2_INSTANCE_TYPE: g5.12xlarge - EC2_SUBNET_ID: subnet-988fd9f2,subnet-6f56db13,subnet-6a039326 - EC2_SECURITY_GROUP: sg-072f92ae3082936c6 - outputs: - label: ${{ steps.start-ec2-runner.outputs.label }} - ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }} - steps: - - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v1 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: ${{ env.AWS_REGION }} - - name: Start EC2 runner - id: start-ec2-runner 
- uses: philschmid/philschmid-ec2-github-runner@main - with: - mode: start - github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} - ec2-image-id: ${{ env.EC2_AMI_ID }} - ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }} - subnet-id: ${{ env.EC2_SUBNET_ID }} - security-group-id: ${{ env.EC2_SECURITY_GROUP }} - aws-resource-tags: > # optional, requires additional permissions - [ - {"Key": "Name", "Value": "ec2-tgi-github-runner"}, - {"Key": "GitHubRepository", "Value": "${{ github.repository }}"} - ] - - load-tests: - concurrency: - group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - needs: start-runner # required to start the main job when the runner is ready - runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner - env: - DOCKER_VOLUME: /cache - steps: - - name: Checkout repository - uses: actions/checkout@v3 - - - name: Prepare disks - run: | - sudo mkfs -t ext4 /dev/nvme1n1 - sudo mkdir ${{ env.DOCKER_VOLUME }} - sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }} - - - name: Install k6 - run: | - curl https://github.com/grafana/k6/releases/download/v0.44.0/k6-v0.44.0-linux-amd64.tar.gz -L | tar xvz --strip-components 1 - - - name: Start starcoder - run: | - docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v ${{ env.DOCKER_VOLUME }}:/data -e HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768 - sleep 10 - wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health - - - name: Run k6 - run: | - ./k6 run load_tests/starcoder_load.js - - - name: Stop starcoder - if: ${{ always() }} - run: | - docker stop tgi-starcoder || true - - stop-runner: - name: Stop self-hosted EC2 runner - needs: - - start-runner - - load-tests - runs-on: ubuntu-latest - env: - AWS_REGION: eu-central-1 - if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs - steps: - - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v1 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: ${{ env.AWS_REGION }} - - name: Stop EC2 runner - uses: philschmid/philschmid-ec2-github-runner@main - with: - mode: stop - github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} - label: ${{ needs.start-runner.outputs.label }} - ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }} diff --git a/.github/workflows/setup_conda.yml b/.github/workflows/setup_conda.yml new file mode 100644 index 00000000000..816b9fd6db1 --- /dev/null +++ b/.github/workflows/setup_conda.yml @@ -0,0 +1,27 @@ +name: Test Conda Setup + +on: [push, pull_request] + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Conda + uses: conda-incubator/setup-miniconda@v2 + with: + auto-activate-base: true + activate-environment: "" + auto-update-conda: true + + - name: Free up disk space + run: | + bash .github/workflows/free_disk_space.sh + + - name: Run Conda Server Setup + shell: bash -l {0} + run: | + bash ./setup_scripts/conda_server.sh --light-mode --no-tests \ No newline at end of file diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml deleted file mode 100644 index 7e5ba52cb5e..00000000000 --- a/.github/workflows/tests.yaml 
+++ /dev/null @@ -1,82 +0,0 @@ -name: Server Tests - -on: - pull_request: - paths: - - ".github/workflows/tests.yaml" - - "server/**" - - "proto/**" - - "router/**" - - "launcher/**" - - "Cargo.lock" - - "rust-toolchain.toml" - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - run_tests: - runs-on: ubuntu-latest - - env: - SCCACHE_GHA_ENABLED: "on" - RUSTC_WRAPPER: /usr/local/bin/sccache - SCCACHE: 0.3.3 - - steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v1 - with: - python-version: 3.9 - - name: Install Rust - uses: actions-rs/toolchain@v1 - with: - toolchain: 1.65.0 - override: true - components: rustfmt, clippy - - name: Install Protoc - uses: arduino/setup-protoc@v1 - - name: Install sccache - run: | - curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache - chmod +x /usr/local/bin/sccache - - name: configure sccache - uses: actions/github-script@v6 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - core.exportVariable('SCCACHE_GHA_CACHE_TO', 'sccache-${{runner.os}}-${{github.ref_name}}'); - core.exportVariable('SCCACHE_GHA_CACHE_FROM', 'sccache-${{runner.os}}-main,sccache-${{runner.os}}-'); - - name: cargo registry cache - uses: actions/cache@v3 - with: - key: cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-${{ github.sha }} - restore-keys: | - cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}- - cargo-${{ runner.os }}- - path: | - ~/.cargo/registry - ~/.cargo/git - - name: Install - run: | - make install - - name: Run server tests - run: | - pip install pytest - export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} - pytest -s -vv server/tests - - name: Run Rust fmt - run: | - cargo fmt --check - - name: Run Rust clippy - run: | - cargo clippy - - name: Run Rust tests - run: | - cargo test - - name: sccache stats - run: | - /usr/local/bin/sccache --show-stats diff --git a/.gitignore b/.gitignore index 20c9baee226..de17588f0a1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .idea target router/tokenizer.json -*__pycache__* +.openssl +*__pycache__* \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000000..ae1eeb3c438 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "chat-ui"] + path = chat-ui + url = git@github.com:CoderPat/chat-ui.git diff --git a/Cargo.lock b/Cargo.lock index 8984ea6ad4b..c2866b114fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,6 +57,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anstream" version = "0.3.2" @@ -398,6 +413,21 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 
+[[package]] +name = "chrono" +version = "0.4.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec837a71355b28f6556dbd569b37b3f363091c0bd4b2e735674521b4c5fd9bc5" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "time 0.1.45", + "wasm-bindgen", + "winapi", +] + [[package]] name = "cipher" version = "0.4.4" @@ -991,7 +1021,7 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "wasm-bindgen", ] @@ -1057,6 +1087,31 @@ version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +[[package]] +name = "headers" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3e372db8e5c0d213e0cd0b9be18be2aca3d44cf2fe30a9d46a65581cd454584" +dependencies = [ + "base64 0.13.1", + "bitflags 1.3.2", + "bytes", + "headers-core", + "http", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429" +dependencies = [ + "http", +] + [[package]] name = "heck" version = "0.4.1" @@ -1178,6 +1233,29 @@ dependencies = [ "tokio-native-tls", ] +[[package]] +name = "iana-time-zone" +version = "0.1.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fad5b825842d2b38bd206f3e81d6957625fd7f0a361e345c30e01a0ae2dd613" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "ident_case" version = "1.0.1" @@ -1546,7 +1624,7 @@ checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" dependencies = [ "libc", "log", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.48.0", ] @@ -1571,6 +1649,24 @@ dependencies = [ "syn 2.0.25", ] +[[package]] +name = "multer" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01acbdc23469fd8fe07ab135923371d5f5a422fbf9c522158677c8eb15bc51c2" +dependencies = [ + "bytes", + "encoding_rs", + "futures-util", + "http", + "httparse", + "log", + "memchr", + "mime", + "spin 0.9.8", + "version_check", +] + [[package]] name = "multimap" version = "0.8.3" @@ -1905,7 +2001,7 @@ dependencies = [ "once_cell", "pin-project-lite", "thiserror", - "urlencoding", + "urlencoding 2.1.2", ] [[package]] @@ -2195,7 +2291,7 @@ dependencies = [ "mach2", "once_cell", "raw-cpuid", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "web-sys", "winapi", ] @@ -2547,6 +2643,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.1.0" @@ -2893,7 +2995,7 @@ dependencies = [ [[package]] name = "text-generation-benchmark" -version = "0.9.4" +version = "0.9.5" dependencies = [ "average", "clap", @@ -2911,9 +3013,25 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "text-generation-central" +version = 
"0.9.5" +dependencies = [ + "bytes", + "chrono", + "clap", + "reqwest", + "serde", + "serde_json", + "tokio", + "urlencoding 1.3.3", + "warp", + "whoami", +] + [[package]] name = "text-generation-client" -version = "0.9.4" +version = "0.9.5" dependencies = [ "futures", "grpc-metadata", @@ -2929,7 +3047,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "0.9.4" +version = "0.9.5" dependencies = [ "clap", "ctrlc", @@ -2940,12 +3058,14 @@ dependencies = [ "serde_json", "tracing", "tracing-subscriber", + "urlencoding 1.3.3", "vergen", + "whoami", ] [[package]] name = "text-generation-router" -version = "0.9.4" +version = "0.9.5" dependencies = [ "async-stream", "axum", @@ -3006,6 +3126,17 @@ dependencies = [ "once_cell", ] +[[package]] +name = "time" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a" +dependencies = [ + "libc", + "wasi 0.10.0+wasi-snapshot-preview1", + "winapi", +] + [[package]] name = "time" version = "0.3.23" @@ -3159,6 +3290,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54319c93411147bced34cb5609a80e0a8e44c5999c93903a81cd866630ec0bfd" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.8" @@ -3436,6 +3579,25 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +[[package]] +name = "tungstenite" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30ee6ab729cd4cf0fd55218530c4522ed30b7b6081752839b68fcec8d0960788" +dependencies = [ + "base64 0.13.1", + "byteorder", + "bytes", + "http", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "typenum" version = "1.16.0" @@ -3516,12 +3678,24 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a1f0175e03a0973cf4afd476bef05c26e228520400eb1fd473ad417b1c00ffb" + [[package]] name = "urlencoding" version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9" +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8parse" version = "0.2.1" @@ -3591,7 +3765,7 @@ dependencies = [ "rustc_version", "rustversion", "sysinfo", - "time", + "time 0.3.23", ] [[package]] @@ -3619,6 +3793,43 @@ dependencies = [ "try-lock", ] +[[package]] +name = "warp" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba431ef570df1287f7f8b07e376491ad54f84d26ac473489427231e1718e1f69" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "headers", + "http", + "hyper", + "log", + "mime", + "mime_guess", + "multer", + "percent-encoding", + "pin-project", + "rustls-pemfile", + "scoped-tls", + "serde", + "serde_json", + "serde_urlencoded", + "tokio", + "tokio-stream", + "tokio-tungstenite", + "tokio-util", + "tower-service", + "tracing", +] + +[[package]] +name = "wasi" +version = 
"0.10.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -3722,6 +3933,16 @@ dependencies = [ "once_cell", ] +[[package]] +name = "whoami" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22fc3756b8a9133049b26c7f61ab35416c130e8c09b660f5b3958b446f52cc50" +dependencies = [ + "wasm-bindgen", + "web-sys", +] + [[package]] name = "winapi" version = "0.3.9" @@ -3753,6 +3974,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" +dependencies = [ + "windows-targets 0.48.1", +] + [[package]] name = "windows-sys" version = "0.45.0" @@ -3919,7 +4149,7 @@ dependencies = [ "hmac", "pbkdf2", "sha1", - "time", + "time 0.3.23", "zstd", ] diff --git a/Cargo.toml b/Cargo.toml index 3bfe9831b9e..886bb06c188 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,14 +4,15 @@ members = [ "router", "router/client", "router/grpc-metadata", + "central", "launcher" ] [workspace.package] -version = "0.9.4" +version = "0.9.5" edition = "2021" -authors = ["Olivier Dehaene"] -homepage = "https://github.com/huggingface/text-generation-inference" +authors = ["Patrick Fernandes"] +homepage = "https://github.com/coderpat/text-generation-inference" [profile.release] debug = 1 diff --git a/Makefile b/Makefile index 7f534c7ccd7..458508d66f7 100644 --- a/Makefile +++ b/Makefile @@ -14,10 +14,13 @@ install-router: install-launcher: cd launcher && cargo install --path . +install-central: + cd central && cargo install --path . + install-benchmark: cd benchmark && cargo install --path . -install: install-server install-router install-launcher install-custom-kernels +install: install-server install-router install-launcher install-central install-custom-kernels server-dev: cd server && make run-dev @@ -42,6 +45,27 @@ python-client-tests: python-tests: python-server-tests python-client-tests +run-llama2-benchmark: + text-generation-launcher --model-id lmsys/vicuna-7b-v1.5 + +run-llama2-vicuna-7b: + text-generation-launcher --model-id lmsys/vicuna-7b-v1.5 --port 8080 + +run-llama2-vicuna-7b-quantize: + text-generation-launcher --model-id lmsys/vicuna-7b-v1.5 --port 8080 --quantize bitsandbytes + +run-llama2-vicuna-13b: + text-generation-launcher --model-id lmsys/vicuna-13b-v1.5 --port 8080 + +run-llama2-vicuna-13b-quantize: + text-generation-launcher --model-id lmsys/vicuna-13b-v1.5 --port 8080 --quantize bitsandbytes + +run-llama2-vicuna-33b-quantize: + text-generation-launcher --model-id lmsys/vicuna-33b-v1.3 --port 8080 --quantize bitsandbytes + +run-llama2-70b-instruct-quantize: + text-generation-launcher --model-id upstage/Llama-2-70b-instruct-v2 --port 8080 --quantize bitsandbytes + run-falcon-7b-instruct: text-generation-launcher --model-id tiiuae/falcon-7b-instruct --port 8080 diff --git a/README.md b/README.md index 2bbb6583788..9ef56cb6be9 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,124 @@
-![image](https://github.com/huggingface/text-generation-inference/assets/3841370/38ba1531-ea0d-4851-b31a-a6d4ddc944b0) +# LTI's **Text Generation Inference** Fork -# Text Generation Inference - - - GitHub Repo stars - - - License - - - Swagger API documentation -
-A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
-to power LLMs api-inference widgets.
+A Rust, Python and gRPC server for text generation inference.
+
+Forked from [HuggingFace](https://huggingface.co)'s [Text Generation Inference](https://github.com/huggingface/text-generation-inference/) project (prior to its re-licensing), this fork is commercial-friendly and licensed under the Apache 2.0 license.
+
+## *A note on this fork*
+
+This fork was created mainly for two reasons:
+1. Primarily, it allows faster iteration and more flexibility, which is essential for our research uses. It also gives us more control over development and documentation, which is crucial for our in-house uses at CMU.
+2. While we understand the reasons behind the re-licensing, we don't want our (research) contributions to be locked behind a restrictive license. This fork will not sync with the upstream repository and will be updated independently.
+
+*For contributors*: If HuggingFace's upstream has a feature that you want to use, please open an issue first and discuss porting the functionality independently.
+Do not just copy the code over, as it will be rejected.
+
+## *For LTI/cluster users*
+
+### Getting started
+
+If you are new to this library and it is already being used in your cluster, we recommend starting with a *client-only* installation and using models launched by other users.
+
+To start, the `TGI_CENTRAL_ADDRESS` environment variable needs to be set so that the client knows which servers to connect to. For example, in the LTI cluster, run
+
+```shell
+echo "export TGI_CENTRAL_ADDRESS=babel-3-36:8765" >> ~/.bashrc # if using a single machine, use `0.0.0.0:8765` instead
+source ~/.bashrc
+```
+
+To use the Python client, install it with
+
+```shell
+cd clients/python
+pip install .
+```
+
+You can then query the API to list the models available in your cluster and use them for inference.
+
+```python
+from text_generation import Client
+
+# get current models and pick the first one
+models = Client.list_from_central()
+model_name, model_addr = models[0]["name"], models[0]["address"]
+print(f"Using model {model_name} at {model_addr}")
+
+client = Client("http://" + model_addr)
+print(client.generate("What is Deep Learning?", max_new_tokens=20).generated_text)
+```
+
+#### Updating the environment
+
+In general, you don't have to recreate the environment every time you want to update the library.
+To update just the library, run the following in the base directory (inside a previously created environment):
+
+```shell
+export DIR=`pwd`
+OPENSSL_DIR=${DIR}/.openssl \
+OPENSSL_LIB_DIR=${DIR}/.openssl/lib \
+OPENSSL_INCLUDE_DIR=${DIR}/.openssl/include \
+BUILD_EXTENSIONS=false \
+  make install
+```
+
+### Running your own servers
+
+If you are an LTI student using one of its clusters (or you generally work on an academic cluster that doesn't have Docker installed), you can sidestep problems with installing system dependencies by using the [(mini)conda](https://docs.conda.io/en/latest/miniconda.html) package manager.
+
+Then, ***from your base environment***, run the install script:
+
+```shell
+bash setup_scripts/conda_server.sh
+```
+
+*Note*: This **takes a really long time** (up to 1.5-3 hours), so sit back and relax while you wait.
+
+*Note*: If you are running on a cluster with `module` installed, make sure you deactivate all modules before running the script.
+
+This will create a conda environment with all the dependencies needed to run the model servers.
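+
+As a quick sanity check (a minimal sketch; it assumes the script created the `tgi-env` environment used in the examples below), activate the environment and confirm the launcher is available:
+
+```shell
+conda activate tgi-env
+text-generation-launcher --help
+```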
+ +You should then be able to launch models with the `text-generation-launcher` command, or by using one of the predefined MAKE rules +```shell +conda activate tgi-env +make run-llama2-vicuna-7b +``` + +### Setting up a Central server + +If you are setting this library for use in your group/cluster for the first time, you will need (or at least benefit) from setting up a central server. +See the instructions [in the package folder](./central/README.md). + +Remember to set the `TGI_CENTRAL_ADDRESS` environment variable (ideally for all the users in your cluster) to the address of the central server. + +### Chat-UI + +It is also possible to a simple web [chat-ui](./clients/chat-ui) to interact with models running in your server/cluster. +This is a simple fork of [HuggingFace's Chat UI](https://github.com/huggingface/chat-ui) that communicates with the central controller to get the list of models available in the cluster, and then connects to the corresponding servers to generate text. + +For example, in Babel, you can access a running Chat-UI web-server with *port forwarding* by running + +```shell +ssh babel -L 8888:babel-3-36:4173 +``` + +and going to `localhost:8888` in your browser. + +

+ + +

+ +Check the [README](./chat-ui/README.md) for more details. + +*Content below is from the original README.* + +--- + +![image](https://github.com/huggingface/text-generation-inference/assets/3841370/38ba1531-ea0d-4851-b31a-a6d4ddc944b0) ## Table of contents @@ -129,6 +231,7 @@ for response in client.generate_stream("What is Deep Learning?", max_new_tokens= print(text) ``` + ### API documentation You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route. diff --git a/assets/chatui01.png b/assets/chatui01.png new file mode 100755 index 00000000000..6b24e7171b6 Binary files /dev/null and b/assets/chatui01.png differ diff --git a/assets/chatui02.png b/assets/chatui02.png new file mode 100755 index 00000000000..775f8c52b94 Binary files /dev/null and b/assets/chatui02.png differ diff --git a/benchmark/dump_fast_tokenizer.py b/benchmark/dump_fast_tokenizer.py new file mode 100644 index 00000000000..c2799b1faa6 --- /dev/null +++ b/benchmark/dump_fast_tokenizer.py @@ -0,0 +1,19 @@ +import os +import json +import argparse +from transformers import AutoTokenizer + +def dump_fast_tokenizer(tokenizer_name, output_path): + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer.save_pretrained(output_path) + +def main(): + parser = argparse.ArgumentParser(description="Dump fast tokenizer json file") + parser.add_argument("--tokenizer-name", required=True, help="Name of the Hugging Face tokenizer") + parser.add_argument("--output", required=True, help="Output path for the fast tokenizer json file") + args = parser.parse_args() + + dump_fast_tokenizer(args.tokenizer_name, args.output) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index b57c652b9b4..85f236dcf94 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -218,6 +218,6 @@ fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String { encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left); // Decode tokenizer - .decode(Vec::from(encoding.get_ids()), false) + .decode(&Vec::from(encoding.get_ids()), false) .unwrap() } diff --git a/central/Cargo.toml b/central/Cargo.toml new file mode 100644 index 00000000000..0b84813b8a1 --- /dev/null +++ b/central/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "text-generation-central" +description = "Text Generation Central controller tool" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[[bin]] +name = "text-generation-central" +path = "src/main.rs" + +[dependencies] +clap = { version = "4.1.4", features = ["derive", "env"] } +warp = "0.3" +serde = {version = "1.0.142", features = ["derive"]} +serde_json = "1.0.93" +reqwest = { version = "0.11.14", features = [] } +tokio = { version = "1.25.0", features = ["full"] } +chrono = "0.4" +urlencoding = "1.1.1" +bytes = "1.1" +whoami = "1.4" + diff --git a/central/README.md b/central/README.md new file mode 100644 index 00000000000..3a27d73a1dc --- /dev/null +++ b/central/README.md @@ -0,0 +1,24 @@ +
+ +# Text Generation Central controller tool +
+ +A lightweight tool for tracking which models are running, who is running them, in what ip and port, etc. + +## Install + +From the root of the project run: + +```shell +make install-central +``` + +## Run + +To run the central controller tool, run: + +```shell +text-generation-central --port $PORT +``` + +TODO: Docs on environment variables \ No newline at end of file diff --git a/central/src/main.rs b/central/src/main.rs new file mode 100644 index 00000000000..cc51b453f89 --- /dev/null +++ b/central/src/main.rs @@ -0,0 +1,230 @@ +use clap::Parser; +use warp::{Filter, Reply}; +use serde::{Serialize,Deserialize}; +use std::{sync::Arc, collections::HashMap}; +use tokio::sync::Mutex; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + // The address the central uses to listen to server requests + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + + // The port the central uses to listen to server requests + #[clap(default_value = "8086", long, env)] + port: u16, + + // The interval (in seconds) between central pings to models + #[clap(default_value = "60", long, env)] + ping_interval: u64, + + // The maximum number of failed pings before a model is dropped + #[clap(default_value = "3", long, env)] + max_failed_pings: u32, + + // By default is None, if set pings a server on launch and if alive registers it + #[clap(default_value = None, long, env)] + initial_ping: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct ModelRecord { + pub name: String, + pub address: String, + pub owner: String, + pub is_quantized: bool, +} + +#[derive(Deserialize, Clone, Debug)] +struct ModelInfo { + docker_label: Option, + max_batch_total_tokens: u32, + max_best_of: u32, + max_concurrent_requests: u32, + max_input_length: u32, + max_stop_sequences: u32, + max_total_tokens: u32, + max_waiting_tokens: u32, + model_device_type: String, + model_dtype: String, + model_id: String, + model_pipeline_tag: String, + model_sha: String, + sha: String, + validation_workers: u32, + version: String, + waiting_served_ratio: f32, +} + +type Models = Arc>>; + + +// define function to print model info +fn print_model_record(record: &ModelRecord) { + if record.is_quantized { + println!("\t{} (quant) - {} by {}", record.name, record.address, record.owner); + } else { + println!("\t{} - {} by {}", record.name, record.address, record.owner); + } +} + +#[tokio::main] + +async fn main() -> Result<(), Box> { + let args = Args::parse(); + let hostname = args.hostname; + let port = args.port; + let ping_interval = args.ping_interval; + let initial_ping = args.initial_ping; + // get current user from env + let user = whoami::username(); + + let models: Models = Arc::new(Mutex::new(HashMap::new())); + + fn with_models(models: Models) -> impl Filter + Clone { + warp::any().map(move || models.clone()) + } + + async fn handle_model_notice(encoded_id: String, record: ModelRecord, models: Models) -> Result { + println!("Received model notice for {}", encoded_id); + let model_id = urlencoding::decode(&encoded_id).unwrap(); + models.lock().await.insert(model_id, record); + Ok(warp::reply::with_status( + "Model registered successfully", + warp::http::StatusCode::OK, + )) + } + + async fn handle_list_models(models: Models) -> Result { + let models = models.lock().await; + // print for debug + let mut models_list: Vec = vec![]; + for (_, record) in models.iter() { + models_list.push(record.clone()); + } + Ok(warp::reply::with_status( + 
warp::reply::json(&models_list).into_response(), + warp::http::StatusCode::OK, + )) + } + + let model_notice_route = warp::path("model_up") + .and(warp::path::param::()) + .and(warp::post()) + .and(warp::body::json()) + .and(with_models(models.clone())) + .and_then(handle_model_notice); + + let list_models_route = warp::path("list_models") + .and(warp::get()) + .and(with_models(models.clone())) + .and_then(handle_list_models); + + let catch_all = warp::any() + .map(||{ + println!("Warning: Received a request on an unhandled route"); + warp::reply::with_status( + "Unhandled route!", + warp::http::StatusCode::NOT_FOUND, + ) + }); + + let routes = model_notice_route + .or(list_models_route) + .or(catch_all); + + let listener = warp::serve(routes).run((hostname.parse::().unwrap(), port)); + let monitor = async { + // ping server if provided + if let Some(model_addr) = initial_ping { + // split server into ip and port variables strings + let model_ip = model_addr.split(":").collect::>()[0]; + let model_port = model_addr.split(":").collect::>()[1]; + + let url = format!("http://{}:{}/info", model_ip, model_port); + let response = reqwest::get(&url).await; + match response { + Ok(response) => { + if response.status().is_success() { + let body = response.text().await?; + let info: ModelInfo = serde_json::from_str(&body)?; + let address = format!("{}:{}", model_ip, model_port); + models.lock().await.insert( + info.model_id.clone(), + // TODO: this is not the correct values + // we should get these from the model + ModelRecord { + name: info.model_id.clone(), + address: address, + owner: user.to_string(), + is_quantized: false, + }); + } else { + println!("Model not alive"); + } + }, + Err(e) => { + println!("Model not alive"); + } + }; + } + + // every Ns, for every model, ping in /health, and if not alive remove from models () + loop { + let mut models = models.lock().await; + let mut keys_removal: Vec = vec![]; + + for (model, record) in models.iter() { + let url = format!("http://{}/health", record.address); + let response = reqwest::get(&url).await; + match response { + Ok(response) => { + if !response.status().is_success() { + keys_removal.push(model.to_string()); + } + }, + Err(e) => { + keys_removal.push(model.to_string()); + } + } + }; + + let mut dropped_models: HashMap = HashMap::new(); + for key in keys_removal { + if let Some(record) = models.remove(&key) { + dropped_models.insert(key, record); + } + } + + // print current time + println!("------------------"); + println!("Current time: {}", chrono::Local::now().format("%Y-%m-%d %H:%M:%S")); + // print models that stayed, one in each line + println!("Current Models:"); + for (model, record) in models.iter() { + print_model_record(record); + } + // print dropped models + println!("Dropped Models:"); + for (model, record) in dropped_models.iter() { + print_model_record(record); + } + + std::mem::drop(models); + tokio::time::sleep(std::time::Duration::from_secs(ping_interval)).await; + } + + Ok(()) as Result<(), Box> + }; + + // wrap listener to go into try join + let listener = async { + listener.await; + Ok(()) + }; + tokio::try_join!(listener, monitor); + Ok(()) +} + diff --git a/chat-ui b/chat-ui new file mode 160000 index 00000000000..f65ca708ec2 --- /dev/null +++ b/chat-ui @@ -0,0 +1 @@ +Subproject commit f65ca708ec2d018fee108a6edf5e48f545d4032f diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index bf045d47735..486b4d7775b 100644 --- a/clients/python/text_generation/client.py 
+++ b/clients/python/text_generation/client.py @@ -1,5 +1,6 @@ import json import requests +import os from aiohttp import ClientSession, ClientTimeout from pydantic import ValidationError @@ -35,6 +36,36 @@ class Client: ``` """ + @classmethod + def list_from_central( + cls, + central_url: str = None, + ): + """ + Get the list of available models from the central model hub + + Args: + central_url (`str`): + Text Generation Central URL + + Returns: + List[Dict[str, str]]: List of available models + """ + if central_url is None: + # check if environment variable is set + if os.environ.get("TGI_CENTRAL_ADDRESS") is None: + raise ValueError( + "No Central url provided and TGI_CENTRAL_ADDRESS environment variable is not set" + ) + central_url = f"http://{os.environ.get(/'TGI_CENTRAL_ADDRESS')}" + + # query from /models endpoint + resp = requests.get(f"{central_url}/list_models") + payload = resp.json() + if resp.status_code != 200: + raise parse_error(resp.status_code, payload) + return payload + def __init__( self, base_url: str, @@ -75,6 +106,7 @@ def generate( typical_p: Optional[float] = None, watermark: bool = False, decoder_input_details: bool = False, + top_tokens: Optional[int] = None, ) -> Response: """ Given a prompt, generate the following text @@ -113,6 +145,8 @@ def generate( Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) decoder_input_details (`bool`): Return the decoder input token logprobs and ids + top_tokens (`Optional[int]`): + Return the top `top_tokens` tokens with the highest logprobs at each step Returns: Response: generated response @@ -134,6 +168,7 @@ def generate( typical_p=typical_p, watermark=watermark, decoder_input_details=decoder_input_details, + top_tokens=top_tokens, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -148,6 +183,25 @@ def generate( if resp.status_code != 200: raise parse_error(resp.status_code, payload) return Response(**payload[0]) + + def score( + self: str, + target: str, + ): + """ Utility function to score a target string (i.e. compute its logprob). + + Mostly wraps the generate function, asking for 1 new token and returning + the logprob of the prompt. 
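+
+        Example (a sketch; assumes `client` is an instantiated `Client`):
+            >>> prefill = client.score("The quick brown fox jumps over the lazy dog.")
+            >>> total_logprob = sum(t.logprob for t in prefill)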
+ """ + # Use generate to get the score + resp = self.generate( + prompt=target, + do_sample=False, + max_new_tokens=1, + decoder_input_details=True, + ) + # extract prefill details and cut off first + return resp.details.prefill[1:] def generate_stream( self, @@ -164,6 +218,7 @@ def generate_stream( truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, + top_tokens: Optional[int] = None, ) -> Iterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens @@ -198,6 +253,8 @@ def generate_stream( See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + top_tokens (`Optional[int]`): + Return the top `top_tokens` tokens with the highest logprobs at each step Returns: Iterator[StreamResponse]: stream of generated tokens @@ -219,6 +276,7 @@ def generate_stream( truncate=truncate, typical_p=typical_p, watermark=watermark, + top_tokens=top_tokens, ) request = Request(inputs=prompt, stream=True, parameters=parameters) @@ -317,6 +375,7 @@ async def generate( typical_p: Optional[float] = None, watermark: bool = False, decoder_input_details: bool = False, + top_tokens: Optional[int] = None, ) -> Response: """ Given a prompt, generate the following text asynchronously @@ -355,6 +414,8 @@ async def generate( Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) decoder_input_details (`bool`): Return the decoder input token logprobs and ids + top_tokens (`Optional[int]`): + Return the top `top_tokens` tokens with the highest logprobs at each step Returns: Response: generated response @@ -404,6 +465,8 @@ async def generate_stream( truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, + top_tokens: Optional[int] = None, + ) -> AsyncIterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens asynchronously @@ -438,6 +501,8 @@ async def generate_stream( See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + top_tokens (`Optional[int]`): + Return the top `top_tokens` tokens with the highest logprobs at each step Returns: AsyncIterator[StreamResponse]: stream of generated tokens @@ -459,6 +524,7 @@ async def generate_stream( truncate=truncate, typical_p=typical_p, watermark=watermark, + top_tokens=top_tokens, ) request = Request(inputs=prompt, stream=True, parameters=parameters) diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index 548f0b639ce..d15bc82378e 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -33,6 +33,8 @@ class Parameters(BaseModel): typical_p: Optional[float] # Generate best_of sequences and return the one if the highest token logprobs best_of: Optional[int] + # Return the `top_tokens` most likely tokens at each step + top_tokens: Optional[int] # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) watermark: bool = False # Get generation details @@ -100,6 +102,13 @@ def valid_typical_p(cls, v): if v is not None and (v <= 0 or v >= 1.0): raise ValidationError("`typical_p` must be > 0.0 and < 1.0") return v + + @validator("top_tokens") + def valid_top_tokens(cls, v): + 
if v is not None and v <= 0: + raise ValidationError("`top_tokens` must be strictly positive") + return v + class Request(BaseModel): @@ -193,6 +202,8 @@ class Details(BaseModel): prefill: List[InputToken] # Generated tokens tokens: List[Token] + # Most likely tokens at each step + top_tokens: Optional[List[List[Token]]] # Additional sequences when using the `best_of` parameter best_of_sequences: Optional[List[BestOfSequence]] @@ -219,6 +230,8 @@ class StreamDetails(BaseModel): class StreamResponse(BaseModel): # Generated token token: Token + # Most likely tokens at each step + top_tokens: Optional[List[Token]] # Complete generated text # Only available when the generation is finished generated_text: Optional[str] diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml index 3e7f86d4e6f..aebf8e7e5d9 100644 --- a/launcher/Cargo.toml +++ b/launcher/Cargo.toml @@ -10,14 +10,17 @@ homepage.workspace = true clap = { version = "4.1.4", features = ["derive", "env"] } ctrlc = { version = "3.2.5", features = ["termination"] } nix = "0.26.2" +reqwest = { version = "0.11.14", features = ["blocking", "json"] } serde = { version = "1.0.152", features = ["derive"] } serde_json = "1.0.93" tracing = "0.1.37" tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } +urlencoding = "1.1.1" +whoami = "1.4.0" +hostname = "0.3" [dev-dependencies] float_eq = "1.0.1" -reqwest = { version = "0.11.14", features = ["blocking", "json"] } [build-dependencies] vergen = { version = "8.0.0", features = ["build", "cargo", "git", "gitcl", "rustc", "si"] } diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 2ad788a405b..ad9379bc9f3 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -16,6 +16,8 @@ use std::thread::sleep; use std::time::{Duration, Instant}; use std::{fs, io}; use tracing_subscriber::EnvFilter; +use serde::Serialize; +use serde_json::json; mod env_runtime; @@ -25,6 +27,15 @@ enum Quantization { Gptq, } +#[derive(Serialize)] +struct ModelRecord { + name: String, + address: String, + owner: String, + is_quantized: bool, +} + + impl std::fmt::Display for Quantization { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // To keep in track with `server`. @@ -122,6 +133,12 @@ struct Args { #[clap(default_value = "2", long, env)] max_best_of: usize, + // This is the maximum allowed value for clients to set `top_tokens`. + // It is used to return the `top_tokens` most likely tokens at each generation + // rather than just the top one. + #[clap(default_value = "10", long, env)] + max_top_tokens: u32, + /// This is the maximum allowed value for clients to set `stop_sequences`. 
/// Stop sequences are used to allow the model to stop on more than just /// the EOS token, and enable more complex "prompting" where users can preprompt @@ -276,6 +293,10 @@ struct Args { #[clap(long, env)] ngrok_edge: Option, + /// Central address, used to register the server in the central registry + #[clap(long, env)] + central_address: Option, + /// Display a lot of information about your runtime environment #[clap(long, short, action)] env: bool, @@ -852,6 +873,8 @@ fn spawn_webserver( args.max_concurrent_requests.to_string(), "--max-best-of".to_string(), args.max_best_of.to_string(), + "--max-top-tokens".to_string(), + args.max_top_tokens.to_string(), "--max-stop-sequences".to_string(), args.max_stop_sequences.to_string(), "--max-input-length".to_string(), @@ -1083,11 +1106,6 @@ fn main() -> Result<(), LauncherError> { // Download and convert model weights download_convert_model(&args, running.clone())?; - if !running.load(Ordering::SeqCst) { - // Launcher was asked to stop - return Ok(()); - } - // Shared shutdown bool let shutdown = Arc::new(AtomicBool::new(false)); // Shared shutdown channel @@ -1097,6 +1115,34 @@ fn main() -> Result<(), LauncherError> { // Shared channel to track shard status let (status_sender, status_receiver) = mpsc::channel(); + // clone args to avoid borrowing issues + // if central_address is None, check if enviroment variable is set + let central_address = match args.central_address.clone() { + Some(central_address) => Some(central_address), + None => match env::var("TGI_CENTRAL_ADDRESS") { + Ok(central_address) => Some(central_address), + Err(_) => None, + }, + }; + let encoded_id = urlencoding::encode(&args.model_id); + let ip = args.hostname.to_string(); + let hostname = match ip.parse::() { + Ok(ip) => ip.ip().to_string(), + Err(_) => { + tracing::warn!("invalid hostname passed! will use system's hostname..."); + // try to resolve hostname.into_string + whoami::hostname() + } + }; + println!("final hostname: {}", hostname); + let model_record = ModelRecord { + name: args.model_id.clone(), + // build address string with hostnmae and port + address: format!("{}:{}", hostname, args.port), + owner: whoami::username(), + is_quantized: args.quantize.is_some() + }; + spawn_shards( num_shard, &args, @@ -1123,6 +1169,37 @@ fn main() -> Result<(), LauncherError> { // Default exit code let mut exit_code = Ok(()); + // Ping central server to register model, using request + if let Some(central_address) = central_address { + println!("Attempting to register in Central at {}", central_address); + let url = format!("http://{}/model_up/%7B%7D", central_address, encoded_id.to_string()); + let client = reqwest::blocking::Client::new(); + let res = client + .post(&url) + .json(&model_record) + .send(); + + match res { + Ok(response) => { + if response.status().is_success() { + println!("Successfully registered on central server"); + } else { + println!("Failed to register on central server"); + // response + println!("Response: {:?}", response); + } + }, + Err(e) => println!("Error occurred while initiating connection with central server: {}", e) + } + } else { + println!("No central server address provided. 
Skipping registration"); + } + + if !running.load(Ordering::SeqCst) { + // Launcher was asked to stop + return Ok(()); + } + while running.load(Ordering::SeqCst) { if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() { tracing::error!("Shard {rank} crashed"); diff --git a/notebooks/test_client.ipynb b/notebooks/test_client.ipynb new file mode 100644 index 00000000000..f828d91b091 --- /dev/null +++ b/notebooks/test_client.ipynb @@ -0,0 +1,273 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/mnt/data_2/patrick/conda/envs/tgi-env/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import text_generation as tg" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# set environment variable\n", + "import os\n", + "os.environ['TGI_CENTRAL_ADDRESS'] = 'localhost:8765'" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[{'name': '/mnt/data_2/patrick/croissantllm-models/small4_equals/', 'address': 'frightened-frank-flowers-fin-01:3000', 'owner': 'patrick', 'is_quantized': False}]\n" + ] + } + ], + "source": [ + "servers = tg.Client.list_from_central()\n", + "print(servers)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "server_addr = servers[0]['address']" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "client = tg.Client(f\"http://{server_addr}/")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "among the best in the country in their field of study. They are also among the best in the\n" + ] + } + ], + "source": [ + "print(client.generate(\"CMU's PhD students are\", max_new_tokens=20).generated_text)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " among the best in the country in their field of study. 
They are also among the best in the\n" + ] + } + ], + "source": [ + "text = \"\"\n", + "for response in client.generate_stream(\"CMU's PhD students are\", max_new_tokens=20):\n", + " if not response.token.special:\n", + " text += response.token.text\n", + "print(text)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Getting Top K tokens at each step" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "among the best in\n", + "[Token(id=5684, text='among', logprob=-2.5429688, special=False), Token(id=4645, text='working', logprob=-3.4179688, special=False), Token(id=1135, text='the', logprob=-3.6757812, special=False)]\n", + "[Token(id=1135, text='the', logprob=-0.40478516, special=False), Token(id=3108, text='those', logprob=-2.7402344, special=False), Token(id=488, text='', logprob=-2.8417969, special=False)]\n", + "[Token(id=3284, text='best', logprob=-1.5273438, special=False), Token(id=2481, text='most', logprob=-1.5664062, special=False), Token(id=3263, text='top', logprob=-2.2148438, special=False)]\n", + "[Token(id=1147, text='in', logprob=-0.45898438, special=False), Token(id=1171, text='and', logprob=-3.7089844, special=False), Token(id=5208, text='students', logprob=-3.9511719, special=False)]\n" + ] + } + ], + "source": [ + "resp = client.generate(\"CMU's PhD students are\", max_new_tokens=4, top_tokens=3)\n", + "print(resp.generated_text)\n", + "for top_tokens in resp.details.top_tokens:\n", + " print(top_tokens)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Benchmarking" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# create 4 random sentences\n", + "SAMPLES = [\n", + " \"The quick brown fox jumps over the lazy dog.\",\n", + " \"The five boxing wizards jump quickly.\",\n", + " \"All questions asked by five watch experts amazed the judge.\",\n", + " \"Jack quietly moved up front and seized the big ball of wax.\",\n", + "]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Sync Client" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "The quick brown fox jumps over the lazy dog.\n", + "The quick brown fox j\n", + "\n", + "The first step in the process is to create a list of potential candidates. 
This list should include\n", + "\n", + "The first time I heard the term “fake news” was in the context of the \n", + "He was a master of disguise, and he had a knack for getting into places he\n", + "CPU times: user 36.8 ms, sys: 3.42 ms, total: 40.2 ms\n", + "Wall time: 1.95 s\n" + ] + } + ], + "source": [ + "%%time\n", + "for sample in SAMPLES:\n", + " print(client.generate(sample, max_new_tokens=20).generated_text)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Async Client" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "import nest_asyncio\n", + "nest_asyncio.apply()\n", + "\n", + "async_client = tg.AsyncClient(f\"http://{server_addr}/")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "The quick brown fox jumps over the lazy dog.\n", + "The quick brown fox j\n", + "\n", + "The first step in the process is to create a list of potential candidates. This list should include\n", + "\n", + "The first time I heard the term “fake news” was in the context of the \n", + "He was a master of disguise, and he had a knack for getting into places he\n", + "CPU times: user 105 ms, sys: 5.03 ms, total: 110 ms\n", + "Wall time: 620 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "async def batch_generate():\n", + " return await asyncio.gather(*[async_client.generate(sample, max_new_tokens=20) for sample in SAMPLES])\n", + "\n", + "results = asyncio.run(batch_generate())\n", + "for r in results:\n", + " print(r.generated_text)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tgi-env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/proto/generate.proto b/proto/generate.proto index 57d79bcaf27..157598cf304 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -91,6 +91,8 @@ message Request { StoppingCriteriaParameters stopping_parameters = 5; /// Return prefill logprobs bool prefill_logprobs = 6; + /// Return most likely n tokens + uint32 top_tokens = 7; } message Batch { @@ -141,6 +143,17 @@ message PrefillTokens { repeated string texts = 3; } +message TopTokens { + /// Top Token IDs + repeated uint32 ids = 1; + /// Top Logprobs + repeated float logprobs = 2; + /// Top Token Texts + repeated string texts = 3; + /// If the tokens are special + repeated bool is_special = 6; +} + message Generation { /// Request ID uint64 request_id = 1; @@ -156,6 +169,8 @@ message Generation { bool token_is_special = 6; /// Complete generated text optional GeneratedText generated_text = 7; + // Top tokens + TopTokens top_tokens = 8; } message FilterBatchRequest { diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 7753f307c0a..10212adb6f1 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -131,6 +131,7 @@ impl Client { ignore_eos_token: false, }), prefill_logprobs: true, + top_tokens: 20, }); n_tokens += max_input_length; } diff --git a/router/src/health.rs b/router/src/health.rs index a3cacdcd016..972dfe486fa 100644 --- a/router/src/health.rs 
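The new `TopTokens` protobuf message carries the candidates as parallel arrays (`ids`, `logprobs`, `texts`, `is_special`), and the router zips them back into per-token structs in `send_responses` below. A minimal Python sketch of that reconstruction; the field names come from the proto, while the `Token` dataclass is only assumed to match the router's `Token` shape.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Token:
    id: int
    text: str
    logprob: float
    special: bool


def top_tokens_to_tokens(ids: List[int], logprobs: List[float],
                         texts: List[str], is_special: List[bool]) -> List[Token]:
    # The same zip the router performs when turning a TopTokens message into Vec<Token>
    return [
        Token(id=i, text=t, logprob=lp, special=sp)
        for i, lp, t, sp in zip(ids, logprobs, texts, is_special)
    ]


# Example: three candidates for a single generation step (illustrative values)
candidates = top_tokens_to_tokens(
    ids=[5684, 4645, 1135],
    logprobs=[-2.54, -3.42, -3.68],
    texts=["among", "working", "the"],
    is_special=[False, False, False],
)
```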
+++ b/router/src/health.rs @@ -50,6 +50,7 @@ impl Health { stop_sequences: vec![], ignore_eos_token: false, }), + top_tokens: 0, }; let batch = Batch { id: BATCH_ID, diff --git a/router/src/infer.rs b/router/src/infer.rs index 188ddc6420c..079b28e1a2c 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -138,12 +138,15 @@ impl Infer { &self, request: GenerateRequest, ) -> Result { + let use_top_tokens = request.parameters.top_tokens.is_some_and(|x| x > 0); + // Create stream and keep semaphore permit as long as generate lives let (_permit, mut stream) = self.generate_stream(request).await?; // Return values let mut result_prefill = Vec::new(); let mut result_tokens = Vec::new(); + let mut result_top_tokens = Vec::new(); let mut result_generated_text = None; let mut result_start = None; let mut result_queued = None; @@ -164,7 +167,10 @@ impl Infer { .collect(); } // Push last token - InferStreamResponse::Token(token) => result_tokens.push(token), + InferStreamResponse::Intermediate {token, top_tokens} => { + result_tokens.push(token); + result_top_tokens.push(top_tokens); + } // Final message // Set return values InferStreamResponse::End { @@ -172,8 +178,10 @@ impl Infer { generated_text, start, queued, + top_tokens } => { result_tokens.push(token); + result_top_tokens.push(top_tokens); result_generated_text = Some(generated_text); result_start = Some(start); result_queued = Some(queued) @@ -185,12 +193,16 @@ impl Infer { if let (Some(generated_text), Some(queued), Some(start)) = (result_generated_text, result_queued, result_start) { + let top_tokens = if use_top_tokens + {result_top_tokens} else + {Vec::new()}; Ok(InferResponse { prefill: result_prefill, tokens: result_tokens, generated_text, queued, start, + top_tokens, }) } else { let err = InferError::IncompleteGeneration; @@ -520,6 +532,24 @@ fn send_responses( special: generation.token_is_special, }; + let mut top_tokens = Vec::new(); + if let Some(top_tokens_) = generation.top_tokens { + top_tokens.extend( + top_tokens_ + .ids + .into_iter() + .zip(top_tokens_.logprobs.into_iter()) + .zip(top_tokens_.texts.into_iter()) + .zip(top_tokens_.is_special.into_iter()) + .map(|(((id, logprob), text), special)| Token { + id, + text, + logprob, + special, + }) + ) + } + if let Some(generated_text) = generation.generated_text { // Generation has ended stopped = true; @@ -527,6 +557,7 @@ fn send_responses( entry.response_tx.send_timeout( Ok(InferStreamResponse::End { token, + top_tokens, generated_text, queued: entry.queue_time, start: entry.batch_time.unwrap(), @@ -536,7 +567,7 @@ fn send_responses( } else { // Send message entry.response_tx.send_timeout( - Ok(InferStreamResponse::Token(token)), + Ok(InferStreamResponse::Intermediate { token, top_tokens }), Duration::from_millis(10), )?; } @@ -566,10 +597,14 @@ pub(crate) enum InferStreamResponse { // Optional first message Prefill(PrefillTokens), // Intermediate messages - Token(Token), + Intermediate { + token: Token, + top_tokens: Vec, + }, // Last message End { token: Token, + top_tokens: Vec, generated_text: GeneratedText, start: Instant, queued: Instant, @@ -583,6 +618,7 @@ pub(crate) struct InferResponse { pub(crate) generated_text: GeneratedText, pub(crate) queued: Instant, pub(crate) start: Instant, + pub(crate) top_tokens: Vec>, } #[derive(Debug, Error)] diff --git a/router/src/lib.rs b/router/src/lib.rs index 7dff7a114ec..a948cf444b1 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -67,6 +67,9 @@ pub(crate) struct GenerateParameters { #[schema(exclusive_minimum = 0, 
nullable = true, default = "null", example = 1)] pub best_of: Option, #[serde(default)] + #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)] + pub top_tokens: Option, + #[serde(default)] #[schema( exclusive_minimum = 0.0, nullable = true, @@ -144,6 +147,7 @@ fn default_max_new_tokens() -> u32 { fn default_parameters() -> GenerateParameters { GenerateParameters { best_of: None, + top_tokens: None, temperature: None, repetition_penalty: None, top_k: None, @@ -235,6 +239,8 @@ pub(crate) struct BestOfSequence { pub seed: Option, pub prefill: Vec, pub tokens: Vec, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub top_tokens: Vec>, } #[derive(Serialize, ToSchema)] @@ -249,6 +255,8 @@ pub(crate) struct Details { pub tokens: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub best_of_sequences: Option>, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub top_tokens: Vec>, } #[derive(Serialize, ToSchema)] @@ -272,6 +280,8 @@ pub(crate) struct StreamDetails { #[derive(Serialize, ToSchema)] pub(crate) struct StreamResponse { pub token: Token, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub top_tokens: Vec, #[schema(nullable = true, default = "null", example = "test")] pub generated_text: Option, #[schema(nullable = true, default = "null")] diff --git a/router/src/main.rs b/router/src/main.rs index 484643cb252..d8994bac71e 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -27,6 +27,8 @@ struct Args { max_concurrent_requests: usize, #[clap(default_value = "2", long, env)] max_best_of: usize, + #[clap(default_value = "10", long, env)] + max_top_tokens: u32, #[clap(default_value = "4", long, env)] max_stop_sequences: usize, #[clap(default_value = "1024", long, env)] @@ -74,6 +76,7 @@ fn main() -> Result<(), RouterError> { let Args { max_concurrent_requests, max_best_of, + max_top_tokens, max_stop_sequences, max_input_length, max_total_tokens, @@ -258,6 +261,7 @@ fn main() -> Result<(), RouterError> { compat_return_full_text, max_concurrent_requests, max_best_of, + max_top_tokens, max_stop_sequences, max_input_length, max_total_tokens, diff --git a/router/src/queue.rs b/router/src/queue.rs index 2d8d6d1c1c2..aab3530da47 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -235,6 +235,7 @@ impl State { truncate: entry.request.truncate, parameters: Some(entry.request.parameters.clone()), stopping_parameters: Some(entry.request.stopping_parameters.clone()), + top_tokens: entry.request.top_tokens, }); // Set batch_time entry.batch_time = Some(Instant::now()); @@ -323,6 +324,7 @@ mod tests { repetition_penalty: 0.0, watermark: false, }, + top_tokens: 0, stopping_parameters: StoppingCriteriaParameters { ignore_eos_token: false, max_new_tokens: 1, diff --git a/router/src/server.rs b/router/src/server.rs index 9af94951b2a..b3ba94edf73 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -193,6 +193,7 @@ async fn generate( generated_tokens: response.generated_text.generated_tokens, prefill: response.prefill, tokens: response.tokens, + top_tokens: response.top_tokens, seed: response.generated_text.seed, } }) @@ -206,6 +207,7 @@ async fn generate( tokens: response.tokens, seed: response.generated_text.seed, best_of_sequences, + top_tokens: response.top_tokens, }) } false => None, @@ -387,12 +389,16 @@ async fn generate_stream( // Prefill is ignored InferStreamResponse::Prefill(_) => {} // Yield event for every new token - InferStreamResponse::Token(token) => { + InferStreamResponse::Intermediate{ + token, + top_tokens, + } => { 
tracing::debug!(parent: &span, "Token: {:?}", token); // StreamResponse let stream_token = StreamResponse { token, + top_tokens: top_tokens, generated_text: None, details: None, }; @@ -402,6 +408,7 @@ async fn generate_stream( // Yield event for last token and compute timings InferStreamResponse::End { token, + top_tokens, generated_text, start, queued, @@ -453,6 +460,7 @@ async fn generate_stream( let stream_token = StreamResponse { token, + top_tokens, generated_text: Some(output_text), details }; @@ -510,6 +518,7 @@ pub async fn run( compat_return_full_text: bool, max_concurrent_requests: usize, max_best_of: usize, + max_top_tokens: u32, max_stop_sequences: usize, max_input_length: usize, max_total_tokens: usize, @@ -572,6 +581,7 @@ pub async fn run( validation_workers, tokenizer, max_best_of, + max_top_tokens, max_stop_sequences, max_input_length, max_total_tokens, diff --git a/router/src/validation.rs b/router/src/validation.rs index be835bf0a07..ca345c0bbec 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -14,6 +14,7 @@ use tracing::{instrument, Span}; pub struct Validation { /// Validation parameters max_best_of: usize, + max_top_tokens: u32, max_stop_sequences: usize, max_input_length: usize, max_total_tokens: usize, @@ -26,6 +27,7 @@ impl Validation { workers: usize, tokenizer: Option, max_best_of: usize, + max_top_tokens: u32, max_stop_sequences: usize, max_input_length: usize, max_total_tokens: usize, @@ -53,6 +55,7 @@ impl Validation { Self { max_best_of, sender, + max_top_tokens, max_stop_sequences, max_input_length, max_total_tokens, @@ -130,6 +133,7 @@ impl Validation { ) -> Result { let GenerateParameters { best_of, + top_tokens, temperature, repetition_penalty, top_k, @@ -218,6 +222,15 @@ impl Validation { } }; + let top_tokens = top_tokens + .map(|value| { + if value > self.max_top_tokens { + return Err(ValidationError::TopTokens(self.max_top_tokens, value)); + } + Ok(value) + }) + .unwrap_or(Ok(0))?; + // Check if inputs is empty if request.inputs.is_empty() { return Err(EmptyInput); @@ -263,6 +276,7 @@ impl Validation { truncate: truncate.unwrap_or(self.max_input_length) as u32, parameters, stopping_parameters, + top_tokens: top_tokens, }) } @@ -311,7 +325,7 @@ fn prepare_input( // truncate encoding and decode new inputs encoding.truncate(truncate, 0, TruncationDirection::Left); let inputs = tokenizer - .decode(Vec::from(encoding.get_ids()), false) + .decode(&Vec::from(encoding.get_ids()), false) .map_err(|err| ValidationError::Tokenizer(err.to_string()))?; (inputs, encoding.len()) } @@ -336,6 +350,7 @@ pub(crate) struct ValidGenerateRequest { pub decoder_input_details: bool, pub parameters: NextTokenChooserParameters, pub stopping_parameters: StoppingCriteriaParameters, + pub top_tokens: u32, } #[derive(Error, Debug)] @@ -344,6 +359,10 @@ pub enum ValidationError { BestOf(usize, usize), #[error("`best_of` != 1 is not allowed for this endpoint")] BestOfDisabled, + #[error("`top_tokens` must be >= 0 and <= {0}. 
Given: {1}")] + TopTokens(u32, u32), + #[error("`top_tokens` != 0 is not allowed for this endpoint")] + TopTokensDisabled, #[error("you must use sampling when `best_of` is > 1")] BestOfSampling, #[error("`seed` must not be set when `best_of` > 1")] @@ -390,14 +409,16 @@ mod tests { async fn test_validation_max_new_tokens() { let tokenizer = None; let max_best_of = 2; - let max_stop_sequence = 3; - let max_input_length = 4; - let max_total_tokens = 5; + let max_top_tokens = 3; + let max_stop_sequence = 4; + let max_input_length = 5; + let max_total_tokens = 6; let workers = 1; let validation = Validation::new( workers, tokenizer, max_best_of, + max_top_tokens, max_stop_sequence, max_input_length, max_total_tokens, @@ -417,9 +438,10 @@ mod tests { async fn test_validation_input_length() { let tokenizer = Some(get_tokenizer().await); let max_best_of = 2; - let max_stop_sequence = 3; - let max_input_length = 4; - let max_total_tokens = 5; + let max_tokens = 3; + let max_stop_sequence = 4; + let max_input_length = 5; + let max_total_tokens = 6; let workers = 1; let validation = Validation::new( workers, @@ -435,7 +457,7 @@ mod tests { .validate_input("Hello".to_string(), None, max_new_tokens) .await { - Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (), + Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (), _ => panic!("Unexpected not max new tokens"), } } diff --git a/server/Makefile b/server/Makefile index a4ce6d8b7a2..ea6bd0052c4 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,5 +1,4 @@ include Makefile-flash-att -include Makefile-flash-att-v2 include Makefile-vllm unit-tests: diff --git a/server/Makefile-flash-att b/server/Makefile-flash-att index bc1d37ef5e2..d938480e094 100644 --- a/server/Makefile-flash-att +++ b/server/Makefile-flash-att @@ -1,16 +1,18 @@ -flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec +flash_attention_commit := v2.0.4 flash-attention: # Clone flash attention pip install packaging - git clone https://github.com/HazyResearch/flash-attention.git + git clone https://github.com/Dao-AILab/flash-attention.git + cd flash-attention && git fetch && git checkout $(flash_attention_commit) -build-flash-attention: flash-attention - cd flash-attention && git fetch && git checkout $(flash_att_commit) - cd flash-attention && python setup.py build - cd flash-attention/csrc/rotary && python setup.py build - cd flash-attention/csrc/layer_norm && python setup.py build +install-flash-attention: flash-attention + pip uninstall flash-attention -y || true + cd flash-attention && pip install . + cd flash-attention/csrc/layer_norm && pip install . + cd flash-attention/csrc/rotary && pip install . 
-install-flash-attention: build-flash-attention - pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true - cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install \ No newline at end of file + +test-flash-attention: flash-attention + pip install pytest + cd flash-attention && pytest -q -s tests/test_flash_attn.py \ No newline at end of file diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 deleted file mode 100644 index a7d633563d8..00000000000 --- a/server/Makefile-flash-att-v2 +++ /dev/null @@ -1,13 +0,0 @@ -flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc - -flash-attention-v2: - # Clone flash attention - pip install packaging - git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2 - -build-flash-attention-v2: flash-attention-v2 - cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit) - cd flash-attention-v2 && python setup.py build - -install-flash-attention-v2: build-flash-attention-v2 - cd flash-attention-v2 && python setup.py install \ No newline at end of file diff --git a/server/Makefile-vllm b/server/Makefile-vllm index 9100fff4e3c..1621c481571 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,13 +1,14 @@ -vllm_commit := 084ca75d4271f8f67be731bc58e0d41d8e0afd3a +vllm_commit := "3d40c83" vllm: # Clone vllm - git clone https://github.com/OlivierDehaene/vllm.git - -build-vllm: vllm + git clone https://github.com/vllm-project/vllm.git cd vllm && git fetch && git checkout $(vllm_commit) - cd vllm && python setup.py build -install-vllm: build-vllm +install-vllm: vllm pip uninstall vllm -y || true - cd vllm && python setup.py install \ No newline at end of file + cd vllm && pip install . 
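Together with these Makefile changes, the modeling code now imports the paged-attention kernels as `vllm.cache_ops` / `vllm.attention_ops` and flash attention as `flash_attn_2_cuda` (see the modeling diffs further below). A quick smoke check for a freshly built environment; the module and attribute names are the ones this diff relies on, not a guaranteed public API, so treat this as an assumption-laden sketch.

```python
def check_kernels() -> None:
    # Extensions this diff expects to be importable after `make install-vllm`
    # and `make install-flash-attention`
    import vllm.cache_ops
    import vllm.attention_ops
    import flash_attn_2_cuda  # flash-attention v2 CUDA extension

    # Entry points used by the modeling code in this diff
    assert hasattr(vllm.cache_ops, "reshape_and_cache")
    assert hasattr(vllm.attention_ops, "paged_attention_v1")
    assert hasattr(flash_attn_2_cuda, "varlen_fwd")


if __name__ == "__main__":
    check_kernels()
    print("attention kernels look importable")
```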
+ +test-vllm: vllm + pip install pytest + cd vllm && pytest -q -s tests/kernels/test_attention.py \ No newline at end of file diff --git a/server/pyproject.toml b/server/pyproject.toml index 3ee3351c6e0..28d1ea3bc12 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -15,21 +15,22 @@ grpcio-status = "^1.51.1" grpcio-reflection = "^1.51.1" grpc-interceptor = "^0.15.0" typer = "^0.6.1" -accelerate = { version = "^0.19.0", optional = true } -bitsandbytes = { version = "^0.38.1", optional = true } -safetensors = "0.3.1" +accelerate = { version = "^0.20.0", optional = true } +bitsandbytes = { version = "^0.41.1", optional = true } +safetensors = "^0.4.0" loguru = "^0.6.0" opentelemetry-api = "^1.15.0" opentelemetry-exporter-otlp = "^1.15.0" opentelemetry-instrumentation-grpc = "^0.36b0" hf-transfer = "^0.1.2" sentencepiece = "^0.1.97" -tokenizers = "0.13.3" -huggingface-hub = "^0.14.1" -transformers = "4.29.2" +tokenizers = "^0.13.3" +huggingface-hub = "^0.16.4" +transformers = "^4.32.2" einops = "^0.6.1" texttable = { version = "^1.6.7", optional = true } datasets = { version = "^2.14.0", optional = true } +scipy = "^1.11.1" [tool.poetry.extras] accelerate = ["accelerate"] diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index e74c03311c6..06e0862c6a7 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -161,20 +161,24 @@ def download_weights( for p in local_pt_files ] try: - from transformers import AutoConfig import transformers + import json + from huggingface_hub import hf_hub_download - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - ) - architecture = config.architectures[0] + logger.info(f"is_local_model: {is_local_model}") + if is_local_model: + config_filename = os.path.join(model_id, "config.json") + else: + config_filename = hf_hub_download(model_id, revision=revision, filename="config.json") + + with open(config_filename, "r") as f: + config = json.load(f) + architecture = config["architectures"][0] class_ = getattr(transformers, architecture) # Name for this varible depends on transformers version. 
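The `download_weights` change above reads `config.json` directly (from the local directory or via `hf_hub_download`) instead of instantiating an `AutoConfig`, then looks up `_tied_weights_keys` on the resolved architecture class. Condensed into a standalone helper for clarity; `resolve_discard_names` is an illustrative name, not part of the CLI.

```python
import json
import os
from typing import List, Optional

import transformers
from huggingface_hub import hf_hub_download


def resolve_discard_names(model_id: str, revision: Optional[str] = None) -> List[str]:
    """Return tied-weight names to drop when converting checkpoints to safetensors."""
    try:
        if os.path.isdir(model_id):  # local model directory
            config_filename = os.path.join(model_id, "config.json")
        else:
            config_filename = hf_hub_download(
                model_id, revision=revision, filename="config.json"
            )
        with open(config_filename, "r") as f:
            architecture = json.load(f)["architectures"][0]
        class_ = getattr(transformers, architecture)
        # Name of this attribute depends on the transformers version
        return getattr(class_, "_tied_weights_keys", [])
    except Exception:
        return []
```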
discard_names = getattr(class_, "_tied_weights_keys", []) - discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", [])) except Exception as e: discard_names = [] diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index cbdf480837c..7e3ea652418 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -12,9 +12,11 @@ PrefillTokens, Generation, GeneratedText, + TopTokens, ) from text_generation_server.pb import generate_pb2 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling +from text_generation_server.utils.tokens import batch_top_tokens tracer = trace.get_tracer(__name__) @@ -42,6 +44,8 @@ class CausalLMBatch(Batch): # Generation helpers next_token_choosers: List[NextTokenChooser] stopping_criterias: List[StoppingCriteria] + top_tokens: List[int] + top_tokens_tensor: torch.Tensor # Metadata used for padding max_input_length: int @@ -72,6 +76,7 @@ def from_pb( inputs = [] next_token_choosers = [] stopping_criterias = [] + top_tokens = [] prefix_offsets = [] read_offsets = [] requests_idx_mapping = {} @@ -88,6 +93,7 @@ def from_pb( r.stopping_parameters, tokenizer ) stopping_criterias.append(stopping_criteria) + top_tokens.append(r.top_tokens) max_truncation = max(max_truncation, r.truncate) max_decode_tokens += stopping_criteria.max_new_tokens padding_right_offset = max( @@ -123,6 +129,9 @@ def from_pb( all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) max_tokens = len(inputs) * (max_input_length + max_decode_tokens) + top_tokens_tensor = torch.tensor( + top_tokens, device=device, dtype=torch.int64 + ) return cls( batch_id=pb.id, @@ -138,6 +147,8 @@ def from_pb( read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, + top_tokens=top_tokens, + top_tokens_tensor=top_tokens_tensor, max_input_length=max_input_length.item(), padding_right_offset=padding_right_offset, max_tokens=max_tokens, @@ -163,6 +174,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: next_token_choosers = [] stopping_criterias = [] + top_tokens = [] total_remaining_decode_tokens = 0 new_padding_right_offset = 0 @@ -184,6 +196,8 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: next_token_choosers.append(self.next_token_choosers[idx]) stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) + top_tokens.append(self.top_tokens[idx]) + remaining_decode_tokens = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) @@ -223,6 +237,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: layer[1] = past_values[keep_indices, :, -past_kv_length:, :] del past_values + top_tokens_tensor = self.top_tokens_tensor[keep_indices] max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens self.requests = requests @@ -235,6 +250,8 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: self.read_offsets = read_offsets self.next_token_choosers = next_token_choosers self.stopping_criterias = stopping_criterias + self.top_tokens = top_tokens + self.top_tokens_tensor = top_tokens_tensor self.max_input_length = max_input_length self.padding_right_offset = new_padding_right_offset self.max_tokens = max_tokens @@ -262,6 +279,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": all_input_ids = [] next_token_choosers = [] 
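Each batch now tracks the requested candidate count twice: as a plain Python list (used when building `TopTokens` messages per request) and as an int64 tensor on the model's device (consumed by `batch_top_tokens`). The `filter` and `concatenate` changes in this file keep both views in sync; the sketch below illustrates that bookkeeping with hypothetical minimal batches.

```python
import torch


def make_top_tokens_state(top_tokens, device="cpu"):
    # List for per-request control flow, tensor for the batched top-k call
    return top_tokens, torch.tensor(top_tokens, device=device, dtype=torch.int64)


# Filtering keeps only surviving requests, in both representations
top_tokens, top_tokens_tensor = make_top_tokens_state([0, 3, 5])
keep_indices = [1, 2]
filtered_list = [top_tokens[i] for i in keep_indices]
filtered_tensor = top_tokens_tensor[keep_indices]

# Concatenating batches copies each slice into a preallocated tensor
batches = [make_top_tokens_state([0, 3]), make_top_tokens_state([5])]
total = sum(len(lst) for lst, _ in batches)
merged_tensor = batches[0][1].new_zeros(total)
merged_list, start = [], 0
for lst, tensor in batches:
    merged_tensor[start:start + len(lst)] = tensor
    merged_list.extend(lst)
    start += len(lst)
```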
stopping_criterias = [] + top_tokens = [] max_tokens = 0 # Batch tensors @@ -281,6 +299,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": all_input_ids.extend(batch.all_input_ids) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) + top_tokens.extend(batch.top_tokens) if i == 0: requests_idx_mapping = batch.requests_idx_mapping @@ -310,6 +329,12 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": (total_batch_size, max_input_length + padding_right_offset), ) + if top_tokens_tensor is None: + top_tokens_tensor = batches[0].top_tokens_tensor.new_zeros( + total_batch_size, + ) + top_tokens_tensor[start_index:end_index] = batch.top_tokens_tensor + # We need to slice the attention mask to remove padding from previous steps # and to remove unused allocated space left_offset = max_input_length - batch.max_input_length @@ -438,6 +463,8 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, + top_tokens=top_tokens, + top_tokens_tensor=top_tokens_tensor, max_input_length=max_input_length, padding_right_offset=padding_right_offset, keys_head_dim_last=batches[0].keys_head_dim_last, @@ -545,6 +572,12 @@ def generate_token( batch.past_key_values, ) + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_tokens, + batch.top_tokens_tensor, + torch.softmax(logits[:, -1], -1), + ) + # Results generations: List[Generation] = [] stopped = True @@ -559,6 +592,9 @@ def generate_token( batch.next_token_choosers, batch.stopping_criterias, batch.all_input_ids, + batch.top_tokens, + batch_top_token_ids, + batch_top_token_logprobs, ) # For each member of the batch @@ -571,6 +607,9 @@ def generate_token( next_token_chooser, stopping_criteria, all_input_ids, + top_tokens, + top_token_ids, + top_token_logprobs, ) in enumerate(iterator): # Select next token next_token_id, logprobs = next_token_chooser( @@ -637,6 +676,24 @@ def generate_token( else: prefill_tokens = None + if top_tokens > 0: + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids for token_id in top_token_ids + ] + top_tokens_obj = TopTokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + else: + top_tokens_obj = None + generation = Generation( request.id, prefill_tokens, @@ -645,6 +702,7 @@ def generate_token( next_token_text, next_token_id_squeezed.item() in self.all_special_ids, generated_text, + top_tokens_obj, ) generations.append(generation) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index b6285856fc3..d528178259e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -30,8 +30,8 @@ import dropout_layer_norm # vllm imports -import vllm_cache_ops -import vllm_attention_ops +import vllm.cache_ops +import vllm.attention_ops from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( @@ -185,10 +185,12 @@ def __init__( self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads - self.rotary_emb = 
PositionRotaryEmbedding.load( - prefix=f"{prefix}.rotary_emb", weights=weights + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=10000.0, + device=weights.device ) - self.softmax_scale = self.head_size**-0.5 if self.num_heads % weights.process_group.size() != 0: @@ -247,7 +249,7 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - vllm_cache_ops.reshape_and_cache( + vllm.cache_ops.reshape_and_cache( kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -270,7 +272,7 @@ def forward( else: # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] block_size = kv_cache[1].shape[3] - vllm_attention_ops.single_query_cached_kv_attention( + vllm.attention_ops.paged_attention_v1( attn_output, query, kv_cache[0], @@ -281,6 +283,7 @@ def forward( input_lengths, block_size, max_s, + None ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -383,6 +386,7 @@ def forward( class FlashLlamaModel(torch.nn.Module): def __init__(self, config, weights): super().__init__() + self.config = config process_group = weights.process_group self.tp_rank = process_group.rank() @@ -485,4 +489,4 @@ def forward( if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) - return logits + return logits \ No newline at end of file diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index e7c8ced4ca7..6d524d308f0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -28,8 +28,8 @@ from typing import Optional, List, Tuple # vllm imports -import vllm_cache_ops -import vllm_attention_ops +import vllm.cache_ops +import vllm.attention_ops from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( @@ -141,7 +141,7 @@ def forward( self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin) - vllm_cache_ops.reshape_and_cache( + vllm.cache_ops.reshape_and_cache( qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots ) @@ -164,7 +164,7 @@ def forward( else: # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] block_size = kv_cache[1].shape[3] - vllm_attention_ops.single_query_cached_kv_attention( + vllm.attention_ops.paged_attention_v1( attn_output, qkv[:, 0], kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 3570b283e59..fc731a95cb5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -7,8 +7,8 @@ from typing import Optional, List, Tuple # vllm imports -import vllm_cache_ops -import vllm_attention_ops +import vllm.cache_ops +import vllm.attention_ops from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( @@ -191,7 +191,7 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - vllm_cache_ops.reshape_and_cache( + vllm.cache_ops.reshape_and_cache( kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -214,7 +214,7 @@ def forward( else: # kv_cache[1] => [num_blocks, num_heads_kv, 
head_size, block_size] block_size = kv_cache[1].shape[3] - vllm_attention_ops.single_query_cached_kv_attention( + vllm.attention_ops.paged_attention_v1( attn_output, query, kv_cache[0], @@ -307,7 +307,7 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin) - vllm_cache_ops.reshape_and_cache( + vllm.cache_ops.reshape_and_cache( kv[:, :, 0].contiguous(), kv[:, :, 1].contiguous(), kv_cache[0], @@ -334,7 +334,7 @@ def forward( else: # kv_cache[1] => [num_blocks, num_groups, head_size, block_size] block_size = kv_cache[1].shape[3] - vllm_attention_ops.single_query_cached_kv_attention( + vllm.attention_ops.paged_attention_v1( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 2dd0a5ee4e0..4b587752d7e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -6,8 +6,8 @@ from typing import Optional, List, Tuple # vllm imports -import vllm_cache_ops -import vllm_attention_ops +import vllm.cache_ops +import vllm.attention_ops from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( @@ -258,7 +258,7 @@ def forward( query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) - vllm_cache_ops.reshape_and_cache( + vllm.cache_ops.reshape_and_cache( key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -281,7 +281,7 @@ def forward( else: # kv_cache[1] => [num_blocks, 1, head_size, block_size] block_size = kv_cache[1].shape[3] - vllm_attention_ops.single_query_cached_kv_attention( + vllm.attention_ops.paged_attention_v1( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 7de51358c8d..c8fe1061460 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -16,10 +16,12 @@ PrefillTokens, Generation, GeneratedText, + TopTokens ) from text_generation_server.pb import generate_pb2 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION +from text_generation_server.utils.tokens import batch_top_tokens tracer = trace.get_tracer(__name__) @@ -165,6 +167,8 @@ class FlashCausalLMBatch(Batch): # Generation helpers next_token_chooser: HeterogeneousNextTokenChooser stopping_criterias: List[StoppingCriteria] + top_tokens: List[int] + top_tokens_tensor: torch.Tensor # Number of blocks in this batch blocks: int @@ -217,6 +221,7 @@ def from_pb( next_token_chooser_parameters = [] stopping_criterias = [] + top_tokens = [] # Cumulative length cumulative_length = 0 @@ -259,6 +264,7 @@ def from_pb( ) max_new_tokens = stopping_criteria.max_new_tokens stopping_criterias.append(stopping_criteria) + top_tokens.append(r.top_tokens) # Paged attention # Remove one as the first token des not have a past @@ -353,6 +359,10 @@ def from_pb( prefill_next_token_indices, dtype=torch.int64, device=device ) + top_tokens_tensor = torch.tensor( + top_tokens, device=device, dtype=torch.int64 + ) + return cls( batch_id=pb.id, requests=pb.requests, @@ -378,6 +388,8 @@ def from_pb( 
all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, + top_tokens=top_tokens, + top_tokens_tensor=top_tokens_tensor, blocks=blocks, max_blocks=max_blocks, ) @@ -417,6 +429,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": read_offsets = [] stopping_criterias = [] + top_tokens = [] blocks = 0 max_blocks = 0 @@ -443,6 +456,8 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) + top_tokens.append(self.top_tokens[idx]) + remaining_tokens = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) @@ -487,6 +502,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": input_lengths_tensor = self.input_lengths_tensor[indices] slots = self.slots[slot_filtering_indices] next_token_chooser = self.next_token_chooser.filter(indices) + top_tokens_tensor = self.top_tokens_tensor[indices] start_slots = torch.tensor(start_slots, dtype=torch.int64) @@ -518,6 +534,8 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, + top_tokens=top_tokens, + top_tokens_tensor=top_tokens_tensor, blocks=blocks, max_blocks=max_blocks, ) @@ -567,6 +585,10 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch (total_batch_size, max_length) ) + top_tokens_tensor = batches[0].top_tokens_tensor.new_zeros( + total_batch_size, + ) + start_slots = [] block_tables = [] all_input_ids = [] @@ -577,6 +599,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch next_token_chooser_parameters = [] stopping_criterias = [] + top_tokens = [] # Cumulative length cumulative_batch_size = 0 @@ -601,6 +624,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots + top_tokens_tensor[start_index:end_index] = batch.top_tokens_tensor input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor slots[slots_start_index:slots_end_index] = batch.slots @@ -623,6 +647,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) stopping_criterias.extend(batch.stopping_criterias) + top_tokens.extend(batch.top_tokens) # Update cumulative_batch_size += len(batch) @@ -666,6 +691,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, + top_tokens=top_tokens, + top_tokens_tensor=top_tokens_tensor, blocks=blocks, max_blocks=max_blocks, ) @@ -831,10 +858,14 @@ def generate_token( else: next_token_logits = out - next_input_ids, next_token_logprobs = batch.next_token_chooser( + next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser( batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits ) + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_tokens, batch.top_tokens_tensor, logprobs + ) + if prefill: if len(batch) > 1 and prefill_logprobs: # We create the prefill_tokens_indices tensor that will be used to gather prefill 
logprobs @@ -931,8 +962,11 @@ def generate_token( batch.all_input_ids, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, + batch.top_tokens, next_token_ids, next_token_logprobs, + batch_top_token_ids, + batch_top_token_logprobs, ) # For each member of the batch @@ -945,8 +979,11 @@ def generate_token( all_input_ids, do_sample, seed, + top_tokens, next_token_id, next_token_logprob, + top_token_ids, + top_token_logprobs, ) in enumerate(iterator): # Append next token to all tokens all_input_ids.append(next_token_id) @@ -1005,6 +1042,24 @@ def generate_token( else: prefill_tokens = None + if top_tokens > 0: + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids for token_id in top_token_ids + ] + top_tokens_obj = TopTokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + else: + top_tokens_obj = None + generation = Generation( request.id, prefill_tokens, @@ -1013,6 +1068,7 @@ def generate_token( next_token_text, next_token_id in self.all_special_ids, generated_text, + top_tokens_obj, ) generations.append(generation) diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 96fb0c266cb..f725a627c6b 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -2,7 +2,8 @@ import torch.distributed from opentelemetry import trace -from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast +from transformers import AutoTokenizer +from transformers.models.llama import LlamaTokenizer from typing import Optional from text_generation_server.models import FlashCausalLM @@ -44,7 +45,8 @@ def __init__( trust_remote_code=trust_remote_code, ) except Exception: - tokenizer = LlamaTokenizerFast.from_pretrained( + # use AutoTokenizer as fallback in case it's a costum tokenizer + tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, padding_side="left", diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 9e5c21d1542..5422236d4ed 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -11,9 +11,11 @@ Batch, Generation, PrefillTokens, + TopTokens ) from text_generation_server.pb import generate_pb2 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling +from text_generation_server.utils.tokens import batch_top_tokens tracer = trace.get_tracer(__name__) @@ -51,6 +53,8 @@ class Seq2SeqLMBatch(Batch): # Metadata used for padding max_input_length: int + top_tokens: List[int] + top_tokens_tensor: torch.Tensor max_decoder_input_length: int padding_right_offset: int @@ -78,6 +82,7 @@ def from_pb( inputs = [] next_token_choosers = [] stopping_criterias = [] + top_tokens = [] decoder_input_lengths = [] prefix_offsets = [] @@ -97,6 +102,7 @@ def from_pb( r.stopping_parameters, tokenizer ) stopping_criterias.append(stopping_criteria) + top_tokens.append(r.top_tokens) max_truncation = max(max_truncation, r.truncate) max_decode_tokens += stopping_criteria.max_new_tokens padding_right_offset = max( @@ -127,6 +133,9 @@ def from_pb( read_offsets.append(1) all_decoder_input_ids = decoder_input_ids.view(-1).split(1) + top_tokens_tensor = torch.tensor( + top_tokens, device=device, dtype=torch.int64 + ) max_tokens 
= len(inputs) * (max_input_length + max_decode_tokens) return cls( @@ -146,6 +155,8 @@ def from_pb( read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, + top_tokens=top_tokens, + top_tokens_tensor=top_tokens_tensor, max_input_length=max_input_length.item(), max_decoder_input_length=1, padding_right_offset=padding_right_offset, @@ -173,6 +184,7 @@ def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]: next_token_choosers = [] stopping_criterias = [] + top_tokens = [] max_input_length = 0 max_decoder_input_length = 0 @@ -204,6 +216,7 @@ def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]: next_token_choosers.append(self.next_token_choosers[idx]) stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) + top_tokens.append(self.top_tokens[idx]) remaining_decode_tokens = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) @@ -239,6 +252,7 @@ def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]: layer[2] = layer[2][keep_indices, :, -max_input_length:] layer[3] = layer[3][keep_indices, :, -max_input_length:] + top_tokens_tensor = self.top_tokens_tensor[keep_indices] max_tokens = ( len(request_ids) * (max_input_length + max_decoder_input_length) + remaining_decode_tokens @@ -254,6 +268,8 @@ def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]: self.read_offsets = read_offsets self.next_token_choosers = next_token_choosers self.stopping_criterias = stopping_criterias + self.top_tokens = top_tokens + self.top_tokens_tensor = top_tokens_tensor self.max_input_length = max_input_length self.max_decoder_input_length = max_decoder_input_length self.padding_right_offset = padding_right_offset @@ -289,6 +305,7 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": read_offsets = [] next_token_choosers = [] stopping_criterias = [] + top_tokens = [] max_tokens = 0 # Batch tensors @@ -312,6 +329,7 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": read_offsets.extend(batch.read_offsets) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) + top_tokens.extend(batch.top_tokens) if i == 0: requests_idx_mapping = batch.requests_idx_mapping @@ -384,6 +402,12 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": ), ) + if top_tokens_tensor is None: + top_tokens_tensor = batches[0].top_tokens_tensor.new_zeros( + total_batch_size, + ) + top_tokens_tensor[start_index:end_index] = batch.top_tokens_tensor + # Copy to correct indices encoder_last_hidden_state[ start_index:end_index, -batch.max_input_length :, : @@ -488,6 +512,8 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, + top_tokens=top_tokens, + top_tokens_tensor=top_tokens_tensor, max_input_length=max_input_length, max_decoder_input_length=max_decoder_input_length, padding_right_offset=padding_right_offset, @@ -613,6 +639,12 @@ def generate_token( batch.past_key_values, ) + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_tokens, + batch.top_tokens_tensor, + torch.softmax(logits[:, -1], -1), + ) + # Finished requests generations: List[Generation] = [] stopped = True @@ -628,6 +660,9 @@ def generate_token( batch.next_token_choosers, batch.stopping_criterias, batch.all_decoder_input_ids, + 
batch.top_tokens, + batch_top_token_ids, + batch_top_token_logprobs, ) # For each member of the batch @@ -641,6 +676,9 @@ def generate_token( next_token_chooser, stopping_criteria, all_decoder_input_ids, + top_tokens, + top_token_ids, + top_token_logprobs, ) in enumerate(iterator): # Select next token next_token_id, logprobs = next_token_chooser( @@ -698,6 +736,24 @@ def generate_token( else: prefill_tokens = None + if top_tokens > 0: + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids for token_id in top_token_ids + ] + top_tokens_obj = TopTokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + else: + top_tokens_obj = None + generation = Generation( request.id, prefill_tokens, @@ -706,6 +762,7 @@ def generate_token( next_token_text, next_token_id_squeezed.item() in self.all_special_ids, generated_text, + top_tokens_obj, ) generations.append(generation) diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index 28ca8147eb9..98ad95c2859 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -71,6 +71,24 @@ def __len__(self): return len(self.token_ids) +@dataclass +class TopTokens: + token_ids: List[int] + logprobs: List[float] + texts: List[str] + is_special: List[bool] + + def to_pb(self) -> generate_pb2.TopTokens: + return generate_pb2.TopTokens( + ids=self.token_ids, + logprobs=self.logprobs, + texts=self.texts, + is_special=self.is_special, + ) + + def __len__(self): + return len(self.token_ids) + @dataclass class Generation: request_id: int @@ -80,6 +98,8 @@ class Generation: token_text: str token_is_special: bool generated_text: Optional[GeneratedText] + # Optional for now, since it's not yet supported for every model. + top_tokens: Optional[TopTokens] def to_pb(self) -> generate_pb2.Generation: return generate_pb2.Generation( @@ -94,4 +114,5 @@ def to_pb(self) -> generate_pb2.Generation: generated_text=self.generated_text.to_pb() if self.generated_text is not None else None, + top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None, ) diff --git a/server/text_generation_server/utils/convert.py b/server/text_generation_server/utils/convert.py index 8d414ecac91..0b62f520836 100644 --- a/server/text_generation_server/utils/convert.py +++ b/server/text_generation_server/utils/convert.py @@ -29,9 +29,15 @@ def _remove_duplicate_names( [name for name in shared if _is_complete(state_dict[name])] ) if not complete_names: - raise RuntimeError( - f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue." - ) + if len(shared) == 1: + # Force contiguous + name = list(shared)[0] + state_dict[name] = state_dict[name].clone() + complete_names = {name} + else: + raise RuntimeError( + f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. 
Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue." + ) keep_name = sorted(list(complete_names))[0] diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index c472d1fceab..6c8e1ca2cae 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -15,39 +15,22 @@ is_sm90 = major == 9 and minor == 0 HAS_FLASH_ATTN = False -HAS_FLASH_ATTN_V2 = False try: - try: - import flash_attn_2_cuda - except ImportError: - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention v2 with `cd server && make install install-flash-attention-v2`" - ) - if not (is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported for " - "Flash Attention V2" - ) - HAS_FLASH_ATTN_V2 = True -except ImportError as e: - try: - import flash_attn_cuda - except ImportError: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e - - if not (is_sm75 or is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported" - ) from e - logger.warning(f"Unable to use Flash Attention V2: {e}") + import flash_attn_2_cuda as flash_attn_cuda + import flash_attn HAS_FLASH_ATTN = True - +except ImportError as e: + raise ImportError( + f"Flash Attention V2 is not installed.\n" + f"Error message: {e}\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention v2 with `cd server && make install install-flash-attention-v2`" + ) +if not (is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported for " + "Flash Attention V2" + ) def attention( q, @@ -58,52 +41,8 @@ def attention( max_s, softmax_scale, ): - if HAS_FLASH_ATTN_V2: - return flash_attn_2_cuda.varlen_fwd( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - 0.0, - softmax_scale, - False, - True, - False, - None, - ) - if HAS_FLASH_ATTN: - # Flash attention v1 requires q, k and v to have the same number of heads - if k.shape[1] != q.shape[1]: - # MQA expand - if k.shape[1] == 1: - k = k.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = k.shape - k = ( - k.unsqueeze(2) - .expand(-1, -1, q.shape[1] // k.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - if v.shape[1] != q.shape[1]: - # MQA expand - if v.shape[1] == 1: - v = v.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = v.shape - v = ( - v.unsqueeze(2) - .expand(-1, -1, q.shape[1] // v.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - - return flash_attn_cuda.fwd( + return flash_attn_cuda.varlen_fwd( q, k, v, @@ -117,7 +56,6 @@ def attention( False, True, False, - 0, None, ) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 7a45808ec92..bae70f755ee 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -375,35 +375,68 @@ def forward(self, hidden_states, residual=None): try: from 
flash_attn.layers.rotary import RotaryEmbedding import rotary_emb + + def _create_inv_freq(dim, base, device): + inv_freq = 1.0 / ( + base + ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) + ) + return inv_freq + + def _get_rope_config(config): + if os.getenv("ROPE_SCALING", None) is not None: + rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])} + return rope_scaling + return getattr(config, "rope_scaling", None) + class PositionRotaryEmbedding(nn.Module): - def __init__(self, inv_freq): + def __init__(self, inv_freq, scaling_factor): super().__init__() - self.inv_freq = inv_freq self._seq_len_cached = 0 self._cos_cached = None self._sin_cached = None self._cos_k_cached = None self._sin_k_cached = None + self.scaling_factor = scaling_factor + self.dynamic_args = None @classmethod - def static(cls, dim, base, device): - inv_freq = 1.0 / ( - base - ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) - ) - return cls(inv_freq) + def static(cls, config, dim, base, device): + inv_freq = _create_inv_freq(dim, base, device) + scaling_factor = None + rope_scaling = _get_rope_config(config) + if rope_scaling is not None: + scaling_factor = rope_scaling["factor"] + if rope_scaling["type"] == "linear": + pass + elif rope_scaling["type"] == "dynamic": + return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor) + else: + raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid") + return cls(inv_freq, scaling_factor) @classmethod - def load(cls, prefix, weights): + def load(cls, config, prefix, weights): # XXX: Always load this in float32 ! 
dtype = weights.dtype weights.dtype = torch.float32 inv_freq = weights.get_tensor(f"{prefix}.inv_freq") weights.dtype = dtype - return cls(inv_freq) + scaling_factor = None + rope_scaling = _get_rope_config(config) + if rope_scaling is not None: + scaling_factor = rope_scaling["factor"] + if rope_scaling["type"] == "linear": + pass + elif rope_scaling["type"] == "dynamic": + return DynamicPositionRotaryEmbedding(dim=2*inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor) + else: + raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid") + return cls(inv_freq, scaling_factor) + def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) @@ -441,5 +474,35 @@ def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False) return x + class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): + def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): + inv_freq = _create_inv_freq(dim, base, device) + super().__init__(inv_freq, scaling_factor) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + if seqlen > self.max_position_embeddings: + newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2)) + self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device) + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + if self.scaling_factor is not None: + t /= self.scaling_factor + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + except ImportError: pass diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index b83af59150f..5b8400f2136 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -229,11 +229,11 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor): scores = warper(input_ids, scores) next_ids = self.choice(scores) - next_logprobs = torch.gather( - torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1) - ).view(-1) + logprobs = torch.log_softmax(scores, -1) + next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1) + return next_ids, next_logprobs, logprobs + - return next_ids, next_logprobs def filter(self, indices): if self.watermark_processor is not None: @@ -339,3 +339,59 @@ def filter(self, indices): self.greedy_indices = new_greedy_indices self.sampling_mapping = new_sampling_mapping return self + + +def batch_top_tokens( + top_tokens: list[int], top_tokens_tensor: torch.Tensor, logprobs: torch.Tensor +) -> Tuple[List[List[int]], List[List[float]]]: + """Find the top n most likely tokens for a batch of 
generations. + When multiple tokens have equal probabilities and they don't all fit, the + remaining tokens are also returned. + + Basically copied from HF's original repo to save some time + + Args: + top_tokens: List specifying the number of top tokens to retrieve for each item in the batch. + top_tokens_tensor: Torch tensor equivalent of top_tokens for use in tensor operations. + logprobs: Torch tensor of log probabilities, shape (batch_size, vocab_size). + + Returns: + A tuple containing two lists: + 1. The indices of the top tokens for each logprob tensor in the batch. + 2. The values of the top tokens for each logprob tensor in the batch. + """ + max_top_n = max(top_tokens) + # Early exit when top_tokens is not used + if max_top_n == 0: + return [[]] * len(top_tokens), [[]] * len(top_tokens) + + # Ensure top_n doesn't exceed vocab size + top_tokens = [min(tok, logprobs.size(-1)) for tok in top_tokens] + + # From https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2 + # Sorted topk is faster than torch.sort() since we only need a small subset + sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values + nth_highest = torch.gather( + sorted_top_k, 1, (top_tokens_tensor - 1).clip(min=0).unsqueeze(1) + ) + nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min + + # Find the new "fuzzy" top n values + top_n_indices = (logprobs >= nth_highest).nonzero() + _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True) + + top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True) + top_n_ishes = top_n_ishes.tolist() + top_indices = top_k.indices.tolist() + top_values = top_k.values.tolist() + + return ( + [ + idxs[:n] if req_n > 0 else [] + for idxs, n, req_n in zip(top_indices, top_n_ishes, top_tokens) + ], + [ + vals[:n] if req_n > 0 else [] + for vals, n, req_n in zip(top_values, top_n_ishes, top_tokens) + ], + ) \ No newline at end of file diff --git a/server/vllm_testscript.py b/server/vllm_testscript.py new file mode 100644 index 00000000000..eedbb75aca9 --- /dev/null +++ b/server/vllm_testscript.py @@ -0,0 +1,22 @@ +# Tests if VLLM works correctly +import vllm +import time + +prompts = [ + 'Hello, my name is', + 'CMU\'s PhD students are', +] +sampling_params = vllm.SamplingParams(temperature=0.8, top_p=0.95) + +llm = vllm.LLM(model="lmsys/vicuna-7b-v1.5") + +# time the generation +start = time.time() +outputs = llm.generate(prompts, sampling_params) +end = time.time() +for output in outputs: + prompt = output.prompt + generated = output.outputs[0].text + print(f'Prompt: {prompt!r}, Generated: {generated!r}') +print() +print(f'Time taken: {end - start:.2f}s') \ No newline at end of file diff --git a/setup_scripts/conda_client.sh b/setup_scripts/conda_client.sh new file mode 100644 index 00000000000..8bc3bee87d1 --- /dev/null +++ b/setup_scripts/conda_client.sh @@ -0,0 +1,34 @@ +ENV_NAME=tgi-env-client +# get the directory of this script, and go one up to get the root directory +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +DIR="$(dirname "$DIR")" +N_THREADS=8 +INSTALL_CHATUI=true + +set -eo pipefail + +# check if CONDA_HOME is set and create environment +if [ -z "$CONDA_HOME" ] +then + echo "Please set CONDA_HOME to the location of your conda installation" + exit 1 +fi + +source ${CONDA_HOME}/etc/profile.d/conda.sh +conda create -y -n ${ENV_NAME} python=3.9 +conda activate ${ENV_NAME} +conda install -y -c conda-forge mamba + +# install client 
+cd ${DIR}/clients/python +pip install . + +echo $PATH +echo $LD_LIBRARY_PATH + +if [ "$INSTALL_CHATUI" = true ] ; then + # install chat-ui + cd ${DIR}/chat-ui + mamba install -y -c conda-forge mongodb pymongo "nodejs>=18" + npm install +fi \ No newline at end of file diff --git a/setup_scripts/conda_server.sh b/setup_scripts/conda_server.sh new file mode 100644 index 00000000000..2da4e053e16 --- /dev/null +++ b/setup_scripts/conda_server.sh @@ -0,0 +1,208 @@ +#!/bin/zsh +# Script for setting up a conda environment for launching servers +# It sidesteps system-wide installations by relying on conda for most packages +# and by building openssl from source +# TODO: only got it to work with a static build of OpenSSL, which is not ideal +# get the directory of this script, and go one up to get the root directory +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +DIR="$(dirname "$DIR")" + +# currently can only build in TIR without extensions +# seems unimportant, as it only affects BLOOM/NEOX +ENV_NAME=tgi-env +BUILD_EXTENSIONS=false +BUILD_VLLM=true +BUILD_FLASHATTN=true +TEST_EXTRA=true +BENCHMARK=false +SERVER_WAIT=180 +N_THREADS=4 + +# Parse command line arguments +while (( "$#" )); do + case "$1" in + --env-name) + ENV_NAME=$2 + shift 2 + ;; + --build-extensions) + BUILD_EXTENSIONS=true + shift 1 + ;; + --light-mode) + BUILD_VLLM=false + BUILD_FLASHATTN=false + shift 1 + ;; + --no-tests) + TEST_EXTRA=false + shift 1 + ;; + --benchmark) + BENCHMARK=true + shift 1 + ;; + --server-wait) + SERVER_WAIT=$2 + shift 2 + ;; + --n-threads) + N_THREADS=$2 + shift 2 + ;; + --) # end argument parsing + shift + break + ;; + -*|--*=) # unsupported flags + echo "Error: Unsupported flag $1" >&2 + exit 1 + ;; + *) # preserve positional arguments + PARAMS="$PARAMS $1" + shift + ;; + esac +done +# set positional arguments in their proper place +eval set -- "$PARAMS" + +set -eo pipefail + +# check if CONDA_PREFIX is set and create environment +if [ -z "$CONDA_PREFIX" ] +then + echo "(Mini)conda does not seem to be installed, please install it first or set CONDA_PREFIX appropriately" + exit 1 +fi +export CONDA_HOME=$CONDA_PREFIX +source ${CONDA_HOME}/etc/profile.d/conda.sh + +# python can't handle this dependency madness, switch to C++ +conda install -y -c conda-forge mamba +# we need to add the base path to get mamba to work inside the new environment +export PATH=${CONDA_HOME}/bin:$PATH + +echo "Creating conda environment ${ENV_NAME}..." 
+mamba create -y -n ${ENV_NAME} python=3.9 +conda activate ${ENV_NAME} + +# Install dependencies +mamba install -y -c conda-forge coreutils "gxx<12.0" +mamba install -y -c conda-forge curl git tar +mamba install -y -c conda-forge "rust>=1.65.0" +mamba install -y -c conda-forge openssh +mamba install -y -c "nvidia/label/cuda-11.8.0" cuda-toolkit +# pin pytorch due to a CUDA issue in pytorch==2.1.0 / something with vllm +mamba install -y -c pytorch -c nvidia pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 + +# bring the conda environment variables forward +# (not sure if needed, but added during debugging and kept for now) +export LD_LIBRARY_PATH=${CONDA_HOME}/envs/${ENV_NAME}/lib:$LD_LIBRARY_PATH +export PATH=${CONDA_HOME}/envs/${ENV_NAME}/bin:$PATH +export CUDA_HOME=${CONDA_HOME}/envs/${ENV_NAME} + +# add cargo bin +export PATH=~/.cargo/bin:$PATH + +# add protoc +export PROTOC_ZIP=protoc-21.12-linux-x86_64.zip +mkdir -p /tmp/protoc +mkdir -p ~/local/bin +mkdir -p ~/local/include +cd /tmp/protoc +curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP +unzip -o $PROTOC_ZIP -d ~/local/ bin/protoc +unzip -o $PROTOC_ZIP -d ~/local/ 'include/*' +cd $DIR +rm -rf /tmp/protoc +export LD_LIBRARY_PATH=~/local/lib:$LD_LIBRARY_PATH +export PATH=~/local/bin:$PATH + +# download and build openssl +mkdir -p /tmp/openssl +cd /tmp/openssl +wget https://www.openssl.org/source/openssl-1.1.1l.tar.gz -O openssl.tar.gz +tar -xzf openssl.tar.gz +cd openssl-1.1.1l +./config --prefix=${DIR}/.openssl --openssldir=${DIR}/.openssl +make -j $N_THREADS +make install +cd $DIR +rm -rf /tmp/openssl +export LD_LIBRARY_PATH=${DIR}/.openssl/lib:$LD_LIBRARY_PATH +export PATH=${DIR}/.openssl/bin:$PATH + +# install ninja for faster compilation of CUDA kernels and set up the workdir +pip install ninja +# export MAX_JOBS to limit ninja's parallelism +export MAX_JOBS=$N_THREADS +cd ${DIR}/server +mkdir -p workdir +rm -rf workdir/* + +# install vllm +if [ "$BUILD_VLLM" = true ] ; then + cp Makefile-vllm workdir/Makefile + cd workdir && sleep 1 + make install-vllm + + cd ${DIR}/server + if [ "$TEST_EXTRA" = true ] ; then + python3 vllm_testscript.py + fi + rm -rf workdir/* +fi + +# install base package +cd ${DIR} +OPENSSL_DIR=${DIR}/.openssl \ +OPENSSL_LIB_DIR=${DIR}/.openssl/lib \ +OPENSSL_INCLUDE_DIR=${DIR}/.openssl/include \ +BUILD_EXTENSIONS=$BUILD_EXTENSIONS \ + make install + + +# install flash attention +if [ "$BUILD_FLASHATTN" = true ] ; then + cd ${DIR}/server + cp Makefile-flash-att workdir/Makefile + cd workdir && sleep 1 + make install-flash-attention + if [ "$TEST_EXTRA" = true ] ; then + make test-flash-attention + fi + cd ${DIR}/server +fi + +rm -rf workdir + +# override protobuf +pip install 'protobuf<3.21' + +# install python client +cd ${DIR}/clients/python +pip install . 
+ +# run an example server +if [ "$BENCHMARK" = true ] ; then + cd ${DIR} + # trap signals to avoid an orphan server process + trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT + # launch server as background process, checking for errors + make run-llama2-benchmark & + # sleep to make sure the server has time to boot + sleep $SERVER_WAIT + + OPENSSL_DIR=${DIR}/.openssl \ + OPENSSL_LIB_DIR=${DIR}/.openssl/lib \ + OPENSSL_INCLUDE_DIR=${DIR}/.openssl/include \ + make install-benchmark + python benchmark/dump_fast_tokenizer.py --tokenizer-name=lmsys/vicuna-7b-v1.5 --output=/tmp/vicuna-7b-v1.5/ + text-generation-benchmark --tokenizer-name=/tmp/vicuna-7b-v1.5 +fi + +# set default conda environment variables +conda env config vars set LD_LIBRARY_PATH=${LD_LIBRARY_PATH} +conda env config vars set PATH=${PATH} +conda env config vars set CUDA_HOME=${CUDA_HOME}
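
A note on the rotary-embedding changes in server/text_generation_server/utils/layers.py: the ROPE_SCALING and ROPE_FACTOR environment variables override config.rope_scaling (see _get_rope_config), and DynamicPositionRotaryEmbedding recomputes the rotary base once the sequence length exceeds max_position_embeddings. The sketch below is not part of the patch; it only mirrors that newbase formula with illustrative numbers.

# Standalone sketch mirroring the dynamic NTK base recomputation in
# DynamicPositionRotaryEmbedding._update_cos_sin_cache (illustrative values only).
def dynamic_ntk_base(base, scaling_factor, seqlen, max_position_embeddings, dim):
    # newbase = base * ((factor * seqlen / max_pos) - (factor - 1)) ** (dim / (dim - 2))
    return base * (
        (scaling_factor * seqlen / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))

# A model trained for 4096 positions, pushed to an 8192-token sequence with
# ROPE_SCALING=dynamic, ROPE_FACTOR=2.0 and head dimension 128:
print(dynamic_ntk_base(10000.0, 2.0, 8192, 4096, 128))  # ~3x the original base, i.e. lower frequencies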
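
Likewise, a minimal usage sketch for the new batch_top_tokens helper in server/text_generation_server/utils/tokens.py, assuming the server package (with its generated protobufs) is installed; the batch, the requested top-n values, and the logprobs below are made up.

import torch

from text_generation_server.utils.tokens import batch_top_tokens

# Two requests in the batch: the first asks for its top 2 tokens, the second for none.
top_n = [2, 0]
top_n_tensor = torch.tensor(top_n)
# Fake per-request log-probabilities over a 4-token vocabulary.
logprobs = torch.log_softmax(
    torch.tensor([[0.1, 2.0, 1.5, -1.0], [0.3, 0.2, 0.1, 0.0]]), dim=-1
)

ids, vals = batch_top_tokens(top_n, top_n_tensor, logprobs)
print(ids)   # [[1, 2], []] -- ties beyond the requested n are also kept
print(vals)  # the matching log-probabilities; empty for requests that asked for 0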